Distributed LLM inference with Ray Data and SGLang on serverless GPU
This notebook demonstrates how to run large language model (LLM) inference at scale using Ray Data and SGLang on Databricks serverless GPU. It leverages the Distributed Serverless GPU API to automatically provision and manage multi-node A10 GPUs for distributed inference.
What this notebook covers:
- Set up Ray and SGLang for distributed LLM inference
- Use Ray Data to efficiently batch and process prompts across multiple GPUs
- Define SGLang prompt functions with structured generation
- Save inference results to Parquet in Unity Catalog Volumes
- Convert Parquet to Delta tables for governance and efficient querying
Use case: Batch inference on thousands of prompts with efficient GPU utilization, SGLang's optimized runtime, and Delta Lake integration.
Requirements
Serverless GPU compute with A10 accelerator
Connect the notebook to serverless GPU compute:
- Click the Connect dropdown at the top.
- Select Serverless GPU.
- Open the Environment side panel on the right side of the notebook.
- Set Accelerator to A10.
- Click Apply and Confirm.
Note: The distributed function launches remote A10 GPUs for multi-node inference. The notebook itself runs on a single A10 for orchestration.
Install dependencies
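The following cell installs all required packages for distributed Ray and SGLang inference:
- PyTorch 2.9.1 (CUDA 12.8 build): Deep learning runtime the other packages build on
- Flash Attention: Optimized attention kernels for faster inference, built against the installed PyTorch (A10 compatible)
- SGLang: High-performance LLM inference and serving framework
- Ray Data: Distributed data processing for batch inference
- hf_transfer: Fast Hugging Face model downloads (used only when the HF_HUB_ENABLE_HF_TRANSFER=1 environment variable is set)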
# PyTorch 2.9.1 built for CUDA 12.8
%pip install --no-cache-dir "torch==2.9.1+cu128" --index-url https://download.pytorch.org/whl/cu128
# Build tools for Flash Attention
%pip install -U --no-cache-dir wheel ninja packaging
# Flash Attention for A10s; --no-build-isolation builds it against the torch installed above
%pip install --force-reinstall --no-cache-dir --no-build-isolation flash-attn
%pip install hf_transfer
%pip install "ray[data]>=2.47.1"
# SGLang with all optional dependencies (it may install its own pinned torch version)
%pip install "sglang[all]>=0.4.7"
%restart_python
Verify package versions
The following cell imports each package to confirm it is installed, prints its version, and checks that Ray meets the minimum required version.
from packaging.version import Version
import torch
import flash_attn
import sglang
import ray
print(f"PyTorch: {torch.__version__}")
print(f"Flash Attention: {flash_attn.__version__}")
print(f"SGLang: {sglang.__version__}")
print(f"Ray: {ray.__version__}")
assert Version(ray.__version__) >= Version("2.47.1"), "Ray version must be at least 2.47.1"
print("\n✓ All version checks passed!")
Configuration
Use widgets to configure inference parameters and optional Hugging Face authentication.
Security Note: Store Hugging Face tokens in Databricks Secrets for production use. See Databricks Secrets documentation.
# Widget configuration
dbutils.widgets.text("hf_secret_scope", "")
dbutils.widgets.text("hf_secret_key", "")
dbutils.widgets.text("model_name", "Qwen/Qwen3-4B-Instruct-2507")
dbutils.widgets.text("num_gpus", "5")
dbutils.widgets.text("num_prompts", "1000")
# Unity Catalog configuration for output storage
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("uc_volume", "ray_data")
dbutils.widgets.text("uc_table", "sglang_inference_results")
# Retrieve widget values
HF_SECRET_SCOPE = dbutils.widgets.get("hf_secret_scope")
HF_SECRET_KEY = dbutils.widgets.get("hf_secret_key")
MODEL_NAME = dbutils.widgets.get("model_name")
NUM_GPUS = int(dbutils.widgets.get("num_gpus"))
NUM_PROMPTS = int(dbutils.widgets.get("num_prompts"))
# Unity Catalog paths
UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_VOLUME = dbutils.widgets.get("uc_volume")
UC_TABLE = dbutils.widgets.get("uc_table")
# Construct paths
UC_VOLUME_PATH = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}"
UC_TABLE_NAME = f"{UC_CATALOG}.{UC_SCHEMA}.{UC_TABLE}"
PARQUET_OUTPUT_PATH = f"{UC_VOLUME_PATH}/sglang_inference_output"
print(f"Model: {MODEL_NAME}")
print(f"Number of GPUs: {NUM_GPUS}")
print(f"Number of prompts: {NUM_PROMPTS}")
print(f"\nUnity Catalog Configuration:")
print(f" Volume Path: {UC_VOLUME_PATH}")
print(f" Table Name: {UC_TABLE_NAME}")
print(f" Parquet Output: {PARQUET_OUTPUT_PATH}")
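The configuration cell derives every output location from four widget values. The sketch below factors that string construction into a hypothetical helper, build_uc_paths, and adds a conservative identifier check (letters, digits, and underscores only; this is an assumption that is stricter than Unity Catalog's actual naming rules) so a mistyped widget value fails before any GPUs are provisioned.

```python
import re

# Conservative identifier rule (assumption): letters, digits, underscores, not starting with a digit
_IDENT = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")

def build_uc_paths(catalog: str, schema: str, volume: str, table: str) -> dict:
    """Hypothetical helper mirroring the f-strings in the configuration cell."""
    for name in (catalog, schema, volume, table):
        if not _IDENT.fullmatch(name):
            raise ValueError(f"Unsupported identifier: {name!r}")
    volume_path = f"/Volumes/{catalog}/{schema}/{volume}"
    return {
        "volume_path": volume_path,
        "table_name": f"{catalog}.{schema}.{table}",
        "parquet_output_path": f"{volume_path}/sglang_inference_output",
    }

paths = build_uc_paths("main", "default", "ray_data", "sglang_inference_results")
print(paths["parquet_output_path"])
```

In the notebook you would pass the four widget values (UC_CATALOG, UC_SCHEMA, UC_VOLUME, UC_TABLE) instead of the literal defaults.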
Authenticate with Hugging Face (optional)
If you use a gated model (such as Llama), authenticate with Hugging Face.
Option 1: Use Databricks Secrets (recommended for production)
from huggingface_hub import login
hf_token = dbutils.secrets.get(scope=HF_SECRET_SCOPE, key=HF_SECRET_KEY)
login(token=hf_token)
Option 2: Interactive login (for development)
from huggingface_hub import login
# Option 2 (interactive login) is active by default. To use Option 1 instead,
# uncomment it and comment out the login() call below.
# Option 1: Use Databricks Secrets (recommended)
# if HF_SECRET_SCOPE and HF_SECRET_KEY:
#     hf_token = dbutils.secrets.get(scope=HF_SECRET_SCOPE, key=HF_SECRET_KEY)
#     login(token=hf_token)
#     print("✓ Logged in using Databricks Secrets")
# Option 2: Interactive login
login()
print("✓ Hugging Face authentication complete")
Set up Unity Catalog resources
The following cell creates the necessary Unity Catalog resources (catalog, schema, and volume) for storing inference results. These resources provide governance, lineage tracking, and centralized storage for the generated outputs.
# Unity Catalog setup
# ⚠️ IMPORTANT: Run this cell BEFORE the inference cell so that the catalog,
# schema, and volume used for the output exist.
import os
from databricks.sdk import WorkspaceClient
from databricks.sdk.service import catalog
# Create Unity Catalog resources
w = WorkspaceClient()
# Create catalog if it doesn't exist
try:
    created_catalog = w.catalogs.create(name=UC_CATALOG)
    print(f"Catalog '{created_catalog.name}' created successfully")
except Exception as e:
    print(f"Catalog '{UC_CATALOG}' already exists or error: {e}")
# Create schema if it doesn't exist
try:
    created_schema = w.schemas.create(name=UC_SCHEMA, catalog_name=UC_CATALOG)
    print(f"Schema '{created_schema.name}' created successfully")
except Exception as e:
    print(f"Schema '{UC_SCHEMA}' already exists in catalog '{UC_CATALOG}' or error: {e}")
# Create volume if it doesn't exist (UC_VOLUME_PATH is defined in the configuration cell)
if not os.path.exists(UC_VOLUME_PATH):
    try:
        created_volume = w.volumes.create(
            catalog_name=UC_CATALOG,
            schema_name=UC_SCHEMA,
            name=UC_VOLUME,
            volume_type=catalog.VolumeType.MANAGED,
        )
        print(f"Volume '{created_volume.name}' created successfully")
    except Exception as e:
        print(f"Volume '{UC_VOLUME}' already exists or error: {e}")
else:
    print(f"Volume {UC_VOLUME_PATH} already exists")
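Each try/except above treats any exception as "already exists", which also hides real failures such as missing privileges. A minimal sketch of a stricter idempotent pattern follows: only an already-exists error is swallowed, and everything else propagates. AlreadyExistsError and ensure_created are hypothetical names; with the Databricks SDK you would catch the SDK's corresponding error class from databricks.sdk.errors instead (verify the exact class name for your SDK version).

```python
from typing import Callable

class AlreadyExistsError(Exception):
    """Hypothetical stand-in for the SDK's 'resource already exists' error."""

def ensure_created(kind: str, name: str, create_fn: Callable[[], object]) -> str:
    """Call create_fn once; report 'created' or 'exists', and re-raise anything else."""
    try:
        create_fn()
    except AlreadyExistsError:
        print(f"{kind} '{name}' already exists")
        return "exists"
    print(f"{kind} '{name}' created successfully")
    return "created"

# Fake create functions so the sketch runs without a workspace
def create_ok():
    return None

def create_existing():
    raise AlreadyExistsError("already exists")

def create_forbidden():
    raise PermissionError("missing CREATE CATALOG privilege")

ensure_created("Catalog", "main", create_ok)
ensure_created("Catalog", "main", create_existing)
```

In the notebook this would wrap the real call, for example ensure_created("Catalog", UC_CATALOG, lambda: w.catalogs.create(name=UC_CATALOG)), so a permission error stops the notebook instead of printing a misleading "already exists" message.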
Ray cluster resource monitoring
The following cell defines a utility function to inspect Ray cluster resources and verify GPU allocation across nodes.
import json
import ray
def print_ray_resources():
    """Print Ray cluster resources and GPU allocation per node."""
    try:
        cluster_resources = ray.cluster_resources()
        print("Ray Cluster Resources:")
        print(json.dumps(cluster_resources, indent=2))
        nodes = ray.nodes()
        print(f"\nDetected {len(nodes)} Ray node(s):")
        for node in nodes:
            node_id = node.get("NodeID", "N/A")[:8]  # Truncate for readability
            ip_address = node.get("NodeManagerAddress", "N/A")
            resources = node.get("Resources", {})
            num_gpus = int(resources.get("GPU", 0))
            print(f"  • Node {node_id}... | IP: {ip_address} | GPUs: {num_gpus}")
            # Show specific GPU IDs if available
            gpu_ids = [k for k in resources.keys() if k.startswith("GPU_ID_")]
            if gpu_ids:
                print(f"    GPU IDs: {', '.join(gpu_ids)}")
    except Exception as e:
        print(f"Error querying Ray cluster: {e}")
# Uncomment to display resources after cluster initialization
# print_ray_resources()
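print_ray_resources only prints. If you want the run to fail fast when fewer GPUs joined the cluster than requested, a small pure function over ray.nodes()-style output is easier to check and test. The sketch below assumes each node entry is a dict with the NodeID and Resources keys used above, plus an Alive flag (an additional assumption about the node dict); the sample node list is made up for illustration, and gpus_per_node and check_gpu_count are hypothetical names.

```python
from typing import Any, Dict, List

def gpus_per_node(nodes: List[Dict[str, Any]]) -> Dict[str, int]:
    """Map each alive node's full ID to its GPU count."""
    counts: Dict[str, int] = {}
    for node in nodes:
        if not node.get("Alive", True):
            continue  # Skip nodes that have left the cluster
        counts[node.get("NodeID", "N/A")] = int(node.get("Resources", {}).get("GPU", 0))
    return counts

def check_gpu_count(nodes: List[Dict[str, Any]], expected: int) -> int:
    """Return the total GPU count, or raise if it is below the requested number."""
    total = sum(gpus_per_node(nodes).values())
    if total < expected:
        raise RuntimeError(f"Expected at least {expected} GPUs, found {total}")
    return total

# Made-up example shaped like ray.nodes() output
sample_nodes = [
    {"NodeID": "a1b2c3d4e5", "Alive": True, "Resources": {"CPU": 12.0, "GPU": 1.0}},
    {"NodeID": "f6g7h8i9j0", "Alive": True, "Resources": {"CPU": 12.0, "GPU": 1.0}},
    {"NodeID": "k1l2m3n4o5", "Alive": False, "Resources": {"CPU": 12.0, "GPU": 1.0}},
]
print(gpus_per_node(sample_nodes))
```

Inside the launched function you could call check_gpu_count(ray.nodes(), NUM_GPUS) before building the dataset.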
Define the distributed inference task
This notebook uses SGLang Runtime with Ray Data for distributed LLM inference. SGLang provides an optimized runtime with features like RadixAttention for efficient prefix caching.
Architecture overview
┌─────────────────┐    ┌─────────────────┐    ┌─────────────────┐
│    Ray Data     │───▶│ SGLang Runtime  │───▶│    Generated    │
│    (Prompts)    │    │  (map_batches)  │    │     Outputs     │
└─────────────────┘    └─────────────────┘    └─────────────────┘
         │                      │
         ▼                      ▼
      Distributed GPU Workers (A10)
     across nodes with SGLang engines
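The middle stage of the diagram relies on Ray Data's map_batches with a callable class: each actor constructs the class once (loading the model), then receives many column-oriented batches. The pure-Python sketch below imitates that contract without Ray so the data flow is easy to follow. It is a simplification, not Ray's scheduler: it uses a single in-process "actor", and batches are dicts of Python lists rather than the dicts of NumPy arrays Ray Data passes by default (which is why the real predictor calls .tolist()). run_batches and EchoPredictor are hypothetical names.

```python
from typing import Callable, Dict, List

Batch = Dict[str, list]

def run_batches(rows: List[dict], predictor: Callable[[Batch], Batch], batch_size: int) -> List[dict]:
    """Feed rows to predictor in column-oriented batches and flatten results back to rows."""
    output: List[dict] = []
    for start in range(0, len(rows), batch_size):
        chunk = rows[start:start + batch_size]
        # Row dicts -> one dict of columns, the shape map_batches hands to __call__
        batch = {key: [row[key] for row in chunk] for key in chunk[0]}
        result = predictor(batch)
        num_rows = len(next(iter(result.values())))
        # Columns -> row dicts again
        output.extend({key: result[key][i] for key in result} for i in range(num_rows))
    return output

class EchoPredictor:
    """Stand-in for SGLangPredictor: one-time setup in __init__, per-batch work in __call__."""

    def __init__(self) -> None:
        self.batches_seen = 0  # A real predictor would load the model here

    def __call__(self, batch: Batch) -> Batch:
        self.batches_seen += 1
        prompts = batch["prompt"]
        return {"prompt": prompts, "generated_text": [p + " ..." for p in prompts]}

rows = [{"prompt": f"prompt {i}"} for i in range(70)]
predictor = EchoPredictor()
results = run_batches(rows, predictor, batch_size=32)
print(len(results), predictor.batches_seen)
```

With batch_size=32, the 70 rows arrive as batches of 32, 32, and 6, so the expensive setup runs once while __call__ runs three times.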
Key benefits of SGLang
- RadixAttention: Efficient KV cache reuse across requests
- Structured generation: Constraint-guided decoding with regex/grammar
- Optimized runtime: High throughput with automatic batching
- Multi-turn support: Efficient handling of conversational prompts
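RadixAttention keeps KV cache entries in a radix tree keyed by token sequences, so a new request can reuse cached computation for the longest prefix it shares with earlier requests. Every request in this notebook starts with the same system prompt, so that prefix is a natural candidate for reuse. The toy sketch below shows only the prefix-matching idea: it uses an uncompressed trie, whitespace "tokens" instead of a real tokenizer, and no eviction, so it is an illustration under those assumptions rather than SGLang's implementation. PrefixTrie and tokenize are hypothetical names.

```python
from typing import Dict, List, Sequence

class PrefixTrie:
    """Toy prefix cache: stores token sequences and reports the longest cached prefix."""

    def __init__(self) -> None:
        self.root: Dict[str, dict] = {}

    def longest_prefix(self, tokens: Sequence[str]) -> int:
        """Number of leading tokens already present in the trie."""
        node, matched = self.root, 0
        for token in tokens:
            if token not in node:
                break
            node = node[token]
            matched += 1
        return matched

    def insert(self, tokens: Sequence[str]) -> None:
        node = self.root
        for token in tokens:
            node = node.setdefault(token, {})

def tokenize(text: str) -> List[str]:
    return text.split()  # Whitespace split stands in for a real tokenizer

SYSTEM = "You are a helpful AI assistant."
cache = PrefixTrie()
reused = []  # (tokens reused from cache, total tokens) per request
for prompt in ["Hello, my name is", "The future of AI is", "Hello, my name is"]:
    tokens = tokenize(SYSTEM) + tokenize(prompt)
    reused.append((cache.longest_prefix(tokens), len(tokens)))
    cache.insert(tokens)
print(reused)
```

The second request reuses the six system-prompt tokens, and the repeated third request is fully cached, which is the effect that makes repeated prompts in this notebook cheap.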
from serverless_gpu.ray import ray_launch
import os
# Set Ray temp directory
os.environ['RAY_TEMP_DIR'] = '/tmp/ray'
# Set the UC Volumes temp directory used by write_databricks_table
os.environ['_RAY_UC_VOLUMES_FUSE_TEMP_DIR'] = f"{UC_VOLUME_PATH}/ray_temp"
@ray_launch(gpus=NUM_GPUS, gpu_type='a10', remote=True)
def run_distributed_inference():
    """Run distributed LLM inference using Ray Data and SGLang Runtime."""
    import ray
    import numpy as np
    from typing import Dict
    import sglang as sgl
    from datetime import datetime
    # Sample prompts for inference
    base_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The future of AI is",
    ]
    # Scale up prompts for distributed processing. Integer division rounds
    # NUM_PROMPTS down to a multiple of len(base_prompts) (1000 -> 999 by default).
    prompts = [{"prompt": p} for p in base_prompts * (NUM_PROMPTS // len(base_prompts))]
    ds = ray.data.from_items(prompts)
    print(f"✓ Created Ray dataset with {ds.count()} prompts")
    # Define SGLang Predictor class for Ray Data map_batches
    class SGLangPredictor:
        """SGLang-based predictor for batch inference with Ray Data."""
        def __init__(self):
            # Initialize SGLang Runtime inside the actor process (once per actor)
            self.runtime = sgl.Runtime(
                model_path=MODEL_NAME,
                dtype="bfloat16",
                trust_remote_code=True,
                mem_fraction_static=0.85,
                tp_size=1,  # Tensor parallelism (1 GPU per worker)
            )
            # Set as default backend for the current process
            sgl.set_default_backend(self.runtime)
            print(f"✓ SGLang runtime initialized with model: {MODEL_NAME}")
        def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
            """Process a batch of prompts using SGLang."""
            input_prompts = batch["prompt"].tolist()
            # Define SGLang prompt function
            @sgl.function
            def chat_completion(s, prompt):
                s += sgl.system("You are a helpful AI assistant.")
                s += sgl.user(prompt)
                s += sgl.assistant(sgl.gen("response", max_tokens=100, temperature=0.8))
            # Run batch inference
            results = chat_completion.run_batch(
                [{"prompt": p} for p in input_prompts],
                progress_bar=False,
            )
            generated_text = [r["response"] for r in results]
            return {
                "prompt": input_prompts,
                "generated_text": generated_text,
                "model": [MODEL_NAME] * len(input_prompts),
                "timestamp": [datetime.now().isoformat()] * len(input_prompts),
            }
        def __del__(self):
            """Clean up runtime when actor dies."""
            try:
                self.runtime.shutdown()
            except Exception:
                pass
    print(f"✓ SGLang predictor configured with model: {MODEL_NAME}")
    # Apply map_batches with SGLang predictor
    ds = ds.map_batches(
        SGLangPredictor,
        concurrency=NUM_GPUS,  # Number of parallel SGLang instances (one actor per GPU)
        batch_size=32,
        num_gpus=1,
        num_cpus=12,
    )
    # =========================================================================
    # Write results to Parquet (stored in Unity Catalog Volume)
    # =========================================================================
    print(f"\n📦 Writing results to Parquet: {PARQUET_OUTPUT_PATH}")
    ds.write_parquet(PARQUET_OUTPUT_PATH, mode="overwrite")
    print("✓ Parquet files written successfully")
    # Collect sample outputs for display
    sample_outputs = ray.data.read_parquet(PARQUET_OUTPUT_PATH).take(limit=10)
    print("\n" + "="*60)
    print("SAMPLE INFERENCE RESULTS")
    print("="*60 + "\n")
    for i, output in enumerate(sample_outputs):
        prompt = output.get("prompt", "N/A")
        generated_text = output.get("generated_text", "")
        display_text = generated_text[:100] if generated_text else "N/A"
        print(f"[{i+1}] Prompt: {prompt!r}")
        print(f"    Generated: {display_text!r}...\n")
    return PARQUET_OUTPUT_PATH
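The prompt list is built with base_prompts * (NUM_PROMPTS // len(base_prompts)), and integer division rounds the total down to a multiple of the three base prompts: with the default of 1000, the dataset holds 1000 // 3 * 3 = 999 rows, which is why the final cell expects 999. If you need exactly NUM_PROMPTS rows, one option (an alternative, not what the notebook does) is to cycle through the base prompts with itertools; the row-count assertion at the end would then need to expect NUM_PROMPTS. repeated_prompts and exact_prompts are hypothetical helper names.

```python
from itertools import cycle, islice

base_prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The future of AI is",
]

def repeated_prompts(base, n):
    """The notebook's approach: whole repetitions only, so n is rounded down."""
    return [{"prompt": p} for p in base * (n // len(base))]

def exact_prompts(base, n):
    """Alternative: cycle through the base prompts until exactly n rows exist."""
    return [{"prompt": p} for p in islice(cycle(base), n)]

print(len(repeated_prompts(base_prompts, 1000)), len(exact_prompts(base_prompts, 1000)))
```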
Run distributed inference
The following cell launches the distributed inference task across multiple A10 GPUs. This will:
- Provision remote A10 GPU workers
- Initialize SGLang engines on each worker
- Distribute prompts across workers using Ray Data
- Save generated outputs to Parquet
Note: Initial startup may take a few minutes as GPU nodes are provisioned and models are loaded.
result = run_distributed_inference.distributed()
parquet_path = result[0] if NUM_GPUS > 1 else result
print(f"\n✓ Inference complete! Results saved to: {parquet_path}")
Load Parquet and convert to Delta table
The following cell loads the Parquet output from Unity Catalog Volumes using Spark and saves it as a Delta table for efficient querying and governance.
# Load Parquet data using Spark
print(f"📖 Loading Parquet from: {PARQUET_OUTPUT_PATH}")
df_spark = spark.read.parquet(PARQUET_OUTPUT_PATH)
# Show schema and row count
print(f"\n✓ Loaded {df_spark.count()} rows")
print("\nSchema:")
df_spark.printSchema()
# Display sample rows
print("\nSample Results:")
display(df_spark.limit(10))
Save as Delta table in Unity Catalog
The following cell writes the inference results to a Unity Catalog Delta table for:
- Governance: Track data lineage and access controls
- Performance: Optimized queries with Delta Lake
- Versioning: Time travel and audit history
# Write to Unity Catalog Delta table
print(f"💾 Writing to Delta table: {UC_TABLE_NAME}")
# Write the DataFrame as a Delta table (overwrite mode)
df_spark.write \
.format("delta") \
.mode("overwrite") \
.option("overwriteSchema", "true") \
.saveAsTable(UC_TABLE_NAME)
print(f"✓ Delta table created successfully: {UC_TABLE_NAME}")
Query the Delta table
The following cell verifies the Delta table was created and queries it using SQL.
# Query the Delta table using SQL
print(f"📊 Querying Delta table: {UC_TABLE_NAME}\n")
# Get table info
display(spark.sql(f"DESCRIBE TABLE {UC_TABLE_NAME}"))
# Query sample results
print("\nSample Results from Delta Table:")
display(spark.sql(f"""
SELECT
prompt,
generated_text,
model,
timestamp
FROM {UC_TABLE_NAME}
LIMIT 10
"""))
# Get row count
row_count = spark.sql(f"SELECT COUNT(*) as count FROM {UC_TABLE_NAME}").collect()[0]["count"]
print(f"\n✓ Total rows in Delta table: {row_count}")
# Assert the expected row count: integer division rounds NUM_PROMPTS down to a multiple of the
# 3 base prompts (with the default NUM_PROMPTS = 1000, that is 1000 // 3 * 3 = 999 rows)
expected_rows = (NUM_PROMPTS // 3) * 3
assert row_count == expected_rows, f"Expected {expected_rows} rows, but got {row_count}"
Next steps
This notebook demonstrated distributed LLM inference with Ray Data and SGLang on Databricks serverless GPU and saved the results to a Delta table.
What was accomplished
- Ran distributed LLM inference across multiple A10 GPUs
- Used SGLang Runtime for optimized batch processing
- Saved results to Parquet in Unity Catalog Volumes
- Converted Parquet to a governed Delta table
Customization options
- Change the model: Update the model_name widget to use a different Hugging Face model.
- Scale up: Increase num_gpus for higher throughput.
- Adjust batch size: Modify batch_size in map_batches based on memory constraints.
- Tune generation: Adjust max_tokens and temperature in the SGLang function.
- Structured output: Use SGLang's regex/grammar constraints for JSON output.
- Multi-turn chat: Chain multiple user/assistant calls in the prompt function.
- Append mode: Change the Delta write to mode("append") for incremental updates.
SGLang advanced features
import sglang as sgl
# Structured JSON output with regex constraint
@sgl.function
def json_output(s, prompt):
    s += sgl.user(prompt)
    s += sgl.assistant(sgl.gen("response", regex=r'\{"name": "\w+", "age": \d+\}'))
# Multi-turn conversation
@sgl.function
def multi_turn(s, question1, question2):
    s += sgl.user(question1)
    s += sgl.assistant(sgl.gen("answer1"))
    s += sgl.user(question2)
    s += sgl.assistant(sgl.gen("answer2"))
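Constrained decoding guarantees that the generated text matches the regex, not that it is valid JSON for downstream code. For example, the \d+ in the json_output pattern also admits a leading zero such as 007, which json.loads rejects. A minimal post-processing sketch, assuming you read the generated string from the response field of each result, checks the pattern and then parses it; parse_person is a hypothetical helper name.

```python
import json
import re

# Same pattern as the json_output example above
PERSON_PATTERN = re.compile(r'\{"name": "\w+", "age": \d+\}')

def parse_person(text: str) -> dict:
    """Validate a generated string against the pattern, then parse it as JSON."""
    if not PERSON_PATTERN.fullmatch(text):
        raise ValueError(f"Output does not match the expected pattern: {text!r}")
    # json.loads raises json.JSONDecodeError (a ValueError subclass) for inputs like "age": 007
    return json.loads(text)

print(parse_person('{"name": "Ada", "age": 36}'))
```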
Alternative: write_databricks_table
In Unity Catalog-enabled workspaces, you can use ray.data.Dataset.write_databricks_table() to write a Ray dataset directly to a Unity Catalog table instead of going through Parquet:
import os
# Set the temp directory environment variable (a Unity Catalog volume path)
os.environ["_RAY_UC_VOLUMES_FUSE_TEMP_DIR"] = "/Volumes/catalog/schema/volume/ray_temp"
# Write directly to a Unity Catalog table (ds is the Ray dataset returned by map_batches)
ds.write_databricks_table(table_name="catalog.schema.table_name")
Resources
- Serverless GPU API Documentation
- SGLang Documentation
- SGLang GitHub Repository
- Ray Data Documentation
- Unity Catalog Documentation
Cleanup
GPU resources are automatically cleaned up when the notebook disconnects. To manually disconnect:
- Click Connected in the compute dropdown
- Hover over Serverless
- Select Terminate from the dropdown menu