
Distributed LLM inference with Ray Data and SGLang on serverless GPU

This notebook demonstrates how to run large language model (LLM) inference at scale using Ray Data and SGLang on Databricks serverless GPU. It leverages the Distributed Serverless GPU API to automatically provision and manage multi-node A10 GPUs for distributed inference.

What this notebook covers:

  • Set up Ray and SGLang for distributed LLM inference
  • Use Ray Data to efficiently batch and process prompts across multiple GPUs
  • Define SGLang prompt functions with structured generation
  • Save inference results to Parquet in Unity Catalog Volumes
  • Convert Parquet to Delta tables for governance and efficient querying

Use case: Batch inference on thousands of prompts with efficient GPU utilization, SGLang's optimized runtime, and Delta Lake integration.

Requirements

Serverless GPU compute with A10 accelerator

Connect the notebook to serverless GPU compute:

  1. Click the Connect dropdown at the top.
  2. Select Serverless GPU.
  3. Open the Environment side panel on the right side of the notebook.
  4. Set Accelerator to A10.
  5. Click Apply and Confirm.

Note: The distributed function launches remote A10 GPUs for multi-node inference. The notebook itself runs on a single A10 for orchestration.

Install dependencies

The following cell installs all required packages for distributed Ray and SGLang inference:

  • Flash Attention: Optimized attention kernels for faster inference (built against the CUDA 12.8 PyTorch wheel pinned below; A10 compatible)
  • SGLang: High-performance LLM inference and serving framework
  • Ray Data: Distributed data processing for batch inference
  • hf_transfer: Fast Hugging Face model downloads
Python
# Pre-compiled Flash Attention for A10s (Essential for speed/compilation)
%pip install --no-cache-dir "torch==2.9.1+cu128" --index-url https://download.pytorch.org/whl/cu128
%pip install -U --no-cache-dir wheel ninja packaging
%pip install --force-reinstall --no-cache-dir --no-build-isolation flash-attn
%pip install hf_transfer
%pip install "ray[data]>=2.47.1"

# SGLang with all dependencies (handles vLLM/Torch automatically)
%pip install "sglang[all]>=0.4.7"

%restart_python
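One caveat about hf_transfer: installing the package alone does not change download behavior. huggingface_hub only uses the accelerated downloader when the `HF_HUB_ENABLE_HF_TRANSFER` environment variable is set, so (after the Python restart) set the flag before any model download starts. A minimal sketch:

```python
import os

# Opt in to hf_transfer for faster Hugging Face downloads.
# huggingface_hub checks this flag at download time, so set it
# before the first model or tokenizer is fetched.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
```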

Verify package versions

The following cell confirms that all required packages are installed with compatible versions.

Python
from packaging.version import Version

import torch
import flash_attn
import sglang
import ray

print(f"PyTorch: {torch.__version__}")
print(f"Flash Attention: {flash_attn.__version__}")
print(f"SGLang: {sglang.__version__}")
print(f"Ray: {ray.__version__}")

assert Version(ray.__version__) >= Version("2.47.1"), "Ray version must be at least 2.47.1"
print("\n✓ All version checks passed!")

Configuration

Use widgets to configure inference parameters and optional Hugging Face authentication.

Security Note: Store Hugging Face tokens in Databricks Secrets for production use. See Databricks Secrets documentation.

Python
# Widget configuration
dbutils.widgets.text("hf_secret_scope", "")
dbutils.widgets.text("hf_secret_key", "")
dbutils.widgets.text("model_name", "Qwen/Qwen3-4B-Instruct-2507")
dbutils.widgets.text("num_gpus", "5")
dbutils.widgets.text("num_prompts", "1000")

# Unity Catalog configuration for output storage
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("uc_volume", "ray_data")
dbutils.widgets.text("uc_table", "sglang_inference_results")

# Retrieve widget values
HF_SECRET_SCOPE = dbutils.widgets.get("hf_secret_scope")
HF_SECRET_KEY = dbutils.widgets.get("hf_secret_key")
MODEL_NAME = dbutils.widgets.get("model_name")
NUM_GPUS = int(dbutils.widgets.get("num_gpus"))
NUM_PROMPTS = int(dbutils.widgets.get("num_prompts"))

# Unity Catalog paths
UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_VOLUME = dbutils.widgets.get("uc_volume")
UC_TABLE = dbutils.widgets.get("uc_table")

# Construct paths
UC_VOLUME_PATH = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}"
UC_TABLE_NAME = f"{UC_CATALOG}.{UC_SCHEMA}.{UC_TABLE}"
PARQUET_OUTPUT_PATH = f"{UC_VOLUME_PATH}/sglang_inference_output"

print(f"Model: {MODEL_NAME}")
print(f"Number of GPUs: {NUM_GPUS}")
print(f"Number of prompts: {NUM_PROMPTS}")
print(f"\nUnity Catalog Configuration:")
print(f" Volume Path: {UC_VOLUME_PATH}")
print(f" Table Name: {UC_TABLE_NAME}")
print(f" Parquet Output: {PARQUET_OUTPUT_PATH}")

Authenticate with Hugging Face (optional)

If using gated models (like Llama), authenticate with Hugging Face.

Option 1: Use Databricks Secrets (recommended for production)

Python
from huggingface_hub import login

hf_token = dbutils.secrets.get(scope=HF_SECRET_SCOPE, key=HF_SECRET_KEY)
login(token=hf_token)

Option 2: Interactive login (for development)

Python
from huggingface_hub import login

# Interactive login (prompts for a Hugging Face token)
login()
print("✓ Hugging Face authentication complete")

Set up Unity Catalog resources

The following cell creates the necessary Unity Catalog resources (catalog, schema, and volume) for storing inference results. These resources provide governance, lineage tracking, and centralized storage for the generated outputs.

Python
# Unity Catalog setup
# ⚠️ IMPORTANT: Run this cell BEFORE launching inference so the output
# volume exists.

import os
from databricks.sdk import WorkspaceClient
from databricks.sdk.service import catalog

w = WorkspaceClient()

# Create catalog if it doesn't exist
try:
    created_catalog = w.catalogs.create(name=UC_CATALOG)
    print(f"Catalog '{created_catalog.name}' created successfully")
except Exception as e:
    print(f"Catalog '{UC_CATALOG}' already exists or error: {e}")

# Create schema if it doesn't exist
try:
    created_schema = w.schemas.create(name=UC_SCHEMA, catalog_name=UC_CATALOG)
    print(f"Schema '{created_schema.name}' created successfully")
except Exception as e:
    print(f"Schema '{UC_SCHEMA}' already exists in catalog '{UC_CATALOG}' or error: {e}")

# Create volume if it doesn't exist
volume_path = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}"
if not os.path.exists(volume_path):
    try:
        created_volume = w.volumes.create(
            catalog_name=UC_CATALOG,
            schema_name=UC_SCHEMA,
            name=UC_VOLUME,
            volume_type=catalog.VolumeType.MANAGED,
        )
        print(f"Volume '{created_volume.name}' created successfully")
    except Exception as e:
        print(f"Volume '{UC_VOLUME}' already exists or error: {e}")
else:
    print(f"Volume {volume_path} already exists")

Ray cluster resource monitoring

The following cell defines a utility function to inspect Ray cluster resources and verify GPU allocation across nodes.

Python
import json
import ray

def print_ray_resources():
    """Print Ray cluster resources and GPU allocation per node."""
    try:
        cluster_resources = ray.cluster_resources()
        print("Ray Cluster Resources:")
        print(json.dumps(cluster_resources, indent=2))

        nodes = ray.nodes()
        print(f"\nDetected {len(nodes)} Ray node(s):")

        for node in nodes:
            node_id = node.get("NodeID", "N/A")[:8]  # Truncate for readability
            ip_address = node.get("NodeManagerAddress", "N/A")
            resources = node.get("Resources", {})
            num_gpus = int(resources.get("GPU", 0))

            print(f"  • Node {node_id}... | IP: {ip_address} | GPUs: {num_gpus}")

            # Show specific GPU IDs if available
            gpu_ids = [k for k in resources.keys() if k.startswith("GPU_ID_")]
            if gpu_ids:
                print(f"    GPU IDs: {', '.join(gpu_ids)}")

    except Exception as e:
        print(f"Error querying Ray cluster: {e}")

# Uncomment to display resources after cluster initialization
# print_ray_resources()

Define the distributed inference task

This notebook uses SGLang Runtime with Ray Data for distributed LLM inference. SGLang provides an optimized runtime with features like RadixAttention for efficient prefix caching.

Architecture overview

┌─────────────────┐    ┌─────────────────┐    ┌─────────────────┐
│    Ray Data     │───▶│ SGLang Runtime  │───▶│    Generated    │
│    (Prompts)    │    │  (map_batches)  │    │     Outputs     │
└─────────────────┘    └─────────────────┘    └─────────────────┘
         │                      │
         ▼                      ▼
        Distributed GPU Workers (A10)
       across nodes with SGLang engines

Key benefits of SGLang

  • RadixAttention: Efficient KV cache reuse across requests
  • Structured generation: Constraint-guided decoding with regex/grammar
  • Optimized runtime: High throughput with automatic batching
  • Multi-turn support: Efficient handling of conversational prompts
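Before the GPU code, the core Ray Data pattern is worth isolating: `map_batches` takes a stateful class whose `__init__` runs once per actor (where the expensive model load happens) and whose `__call__` runs once per batch, mapping a dict of columns to a dict of columns. A CPU-only stand-in, requiring neither Ray nor SGLang, sketches that contract:

```python
import numpy as np

class ToyPredictor:
    """Stand-in for SGLangPredictor: one-time setup in __init__,
    per-batch work in __call__."""

    def __init__(self):
        # In the real predictor, this is where the SGLang runtime is created.
        self.suffix = " ...generated"

    def __call__(self, batch):
        # Ray Data passes batches as dicts of NumPy arrays by default.
        prompts = batch["prompt"].tolist()
        return {
            "prompt": prompts,
            "generated_text": [p + self.suffix for p in prompts],
        }

predictor = ToyPredictor()
out = predictor({"prompt": np.array(["Hello, my name is", "The future of AI is"])})
print(out["generated_text"][0])
```

Any class with this shape can be handed to `map_batches`; the `concurrency` argument controls how many actor copies (and therefore model replicas) run in parallel.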
Python
from serverless_gpu.ray import ray_launch
import os

# Set Ray temp directory
os.environ['RAY_TEMP_DIR'] = '/tmp/ray'

# Set the UC Volumes temp directory for write_databricks_table
os.environ['_RAY_UC_VOLUMES_FUSE_TEMP_DIR'] = f"{UC_VOLUME_PATH}/ray_temp"

@ray_launch(gpus=NUM_GPUS, gpu_type='a10', remote=True)
def run_distributed_inference():
    """Run distributed LLM inference using Ray Data and SGLang Runtime."""
    import ray
    import numpy as np
    from typing import Dict
    import sglang as sgl
    from datetime import datetime

    # Sample prompts for inference
    base_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The future of AI is",
    ]

    # Scale up prompts for distributed processing
    prompts = [{"prompt": p} for p in base_prompts * (NUM_PROMPTS // len(base_prompts))]
    ds = ray.data.from_items(prompts)

    print(f"✓ Created Ray dataset with {ds.count()} prompts")

    # Define the SGLang predictor class for Ray Data map_batches
    class SGLangPredictor:
        """SGLang-based predictor for batch inference with Ray Data."""

        def __init__(self):
            # Initialize the SGLang runtime inside the actor process
            self.runtime = sgl.Runtime(
                model_path=MODEL_NAME,
                dtype="bfloat16",
                trust_remote_code=True,
                mem_fraction_static=0.85,
                tp_size=1,  # Tensor parallelism (1 GPU per worker)
            )
            # Set as default backend for the current process
            sgl.set_default_backend(self.runtime)
            print(f"✓ SGLang runtime initialized with model: {MODEL_NAME}")

        def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
            """Process a batch of prompts using SGLang."""
            input_prompts = batch["prompt"].tolist()

            # Define the SGLang prompt function
            @sgl.function
            def chat_completion(s, prompt):
                s += sgl.system("You are a helpful AI assistant.")
                s += sgl.user(prompt)
                s += sgl.assistant(sgl.gen("response", max_tokens=100, temperature=0.8))

            # Run batch inference
            results = chat_completion.run_batch(
                [{"prompt": p} for p in input_prompts],
                progress_bar=False,
            )

            generated_text = [r["response"] for r in results]

            return {
                "prompt": input_prompts,
                "generated_text": generated_text,
                "model": [MODEL_NAME] * len(input_prompts),
                "timestamp": [datetime.now().isoformat()] * len(input_prompts),
            }

        def __del__(self):
            """Clean up the runtime when the actor dies."""
            try:
                self.runtime.shutdown()
            except Exception:
                pass

    print(f"✓ SGLang predictor configured with model: {MODEL_NAME}")

    # Apply map_batches with the SGLang predictor
    ds = ds.map_batches(
        SGLangPredictor,
        concurrency=NUM_GPUS,  # Number of parallel SGLang instances
        batch_size=32,
        num_gpus=1,
        num_cpus=12,
    )

    # Write results to Parquet (stored in a Unity Catalog Volume)
    print(f"\n📦 Writing results to Parquet: {PARQUET_OUTPUT_PATH}")
    ds.write_parquet(PARQUET_OUTPUT_PATH, mode="overwrite")
    print("✓ Parquet files written successfully")

    # Collect sample outputs for display
    sample_outputs = ray.data.read_parquet(PARQUET_OUTPUT_PATH).take(limit=10)

    print("\n" + "=" * 60)
    print("SAMPLE INFERENCE RESULTS")
    print("=" * 60 + "\n")

    for i, output in enumerate(sample_outputs):
        prompt = output.get("prompt", "N/A")
        generated_text = output.get("generated_text", "")
        display_text = generated_text[:100] if generated_text else "N/A"
        print(f"[{i + 1}] Prompt: {prompt!r}")
        print(f"    Generated: {display_text!r}...\n")

    return PARQUET_OUTPUT_PATH

Run distributed inference

The following cell launches the distributed inference task across multiple A10 GPUs. This will:

  1. Provision remote A10 GPU workers
  2. Initialize SGLang engines on each worker
  3. Distribute prompts across workers using Ray Data
  4. Save generated outputs to Parquet

Note: Initial startup may take a few minutes as GPU nodes are provisioned and models are loaded.

Python
result = run_distributed_inference.distributed()
parquet_path = result[0] if NUM_GPUS > 1 else result
print(f"\n✓ Inference complete! Results saved to: {parquet_path}")

Load Parquet and convert to Delta table

The following cell loads the Parquet output from Unity Catalog Volumes using Spark and saves it as a Delta table for efficient querying and governance.

Python
# Load Parquet data using Spark
print(f"📖 Loading Parquet from: {PARQUET_OUTPUT_PATH}")
df_spark = spark.read.parquet(PARQUET_OUTPUT_PATH)

# Show schema and row count
print(f"\n✓ Loaded {df_spark.count()} rows")
print("\nSchema:")
df_spark.printSchema()

# Display sample rows
print("\nSample Results:")
display(df_spark.limit(10))

Save as Delta table in Unity Catalog

The following cell writes the inference results to a Unity Catalog Delta table for:

  • Governance: Track data lineage and access controls
  • Performance: Optimized queries with Delta Lake
  • Versioning: Time travel and audit history
Python
# Write to Unity Catalog Delta table
print(f"💾 Writing to Delta table: {UC_TABLE_NAME}")

# Write the DataFrame as a Delta table (overwrite mode)
df_spark.write \
    .format("delta") \
    .mode("overwrite") \
    .option("overwriteSchema", "true") \
    .saveAsTable(UC_TABLE_NAME)

print(f"✓ Delta table created successfully: {UC_TABLE_NAME}")
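Because the table is Delta, every overwrite is versioned. A sketch of how to inspect the write history and read an earlier snapshot (version 0 is the first write to the table; this assumes the cell above has run at least once):

```python
# Audit history of writes to the table
display(spark.sql(f"DESCRIBE HISTORY {UC_TABLE_NAME}"))

# Time travel: read the table as of an earlier version
df_v0 = spark.read.option("versionAsOf", 0).table(UC_TABLE_NAME)
display(df_v0.limit(5))
```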

Query the Delta table

The following cell verifies the Delta table was created and queries it using SQL.

Python
# Query the Delta table using SQL
print(f"📊 Querying Delta table: {UC_TABLE_NAME}\n")

# Get table info
display(spark.sql(f"DESCRIBE TABLE {UC_TABLE_NAME}"))

# Query sample results
print("\nSample Results from Delta Table:")
display(spark.sql(f"""
    SELECT
        prompt,
        generated_text,
        model,
        timestamp
    FROM {UC_TABLE_NAME}
    LIMIT 10
"""))

# Get row count
row_count = spark.sql(f"SELECT COUNT(*) as count FROM {UC_TABLE_NAME}").collect()[0]["count"]
print(f"\n✓ Total rows in Delta table: {row_count}")

# Assert the expected row count: prompts are replicated in whole multiples of
# the 3 base prompts, so the total rounds down (e.g. 1000 // 3 * 3 = 999)
expected_rows = (NUM_PROMPTS // 3) * 3
assert row_count == expected_rows, f"Expected {expected_rows} rows, but got {row_count}"
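The rounding in the assertion above comes from integer division when the base prompts are replicated: the requested prompt count is truncated down to a multiple of the number of base prompts. In plain Python, with the default widget value:

```python
NUM_PROMPTS = 1000  # widget default
num_base = 3        # base prompts defined in the inference task

# 1000 // 3 = 333 full replications of the 3 base prompts
replicated = num_base * (NUM_PROMPTS // num_base)
print(replicated)  # 999, not 1000
```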

Next steps

This notebook successfully demonstrated distributed LLM inference using Ray Data and SGLang on Databricks serverless GPU, with results saved to a Delta table.

What was accomplished

  • Ran distributed LLM inference across multiple A10 GPUs
  • Used SGLang Runtime for optimized batch processing
  • Saved results to Parquet in Unity Catalog Volumes
  • Converted Parquet to a governed Delta table

Customization options

  • Change the model: Update model_name widget to use different Hugging Face models
  • Scale up: Increase num_gpus for higher throughput
  • Adjust batch size: Modify batch_size in map_batches based on memory constraints
  • Tune generation: Adjust max_tokens, temperature in the SGLang function
  • Structured output: Use SGLang's regex/grammar constraints for JSON output
  • Multi-turn chat: Chain multiple user/assistant calls in the prompt function
  • Append mode: Change Delta write to mode("append") for incremental updates

SGLang advanced features

Python
# Structured JSON output with regex constraint
@sgl.function
def json_output(s, prompt):
    s += sgl.user(prompt)
    s += sgl.assistant(sgl.gen("response", regex=r'\{"name": "\w+", "age": \d+\}'))

# Multi-turn conversation
@sgl.function
def multi_turn(s, question1, question2):
    s += sgl.user(question1)
    s += sgl.assistant(sgl.gen("answer1"))
    s += sgl.user(question2)
    s += sgl.assistant(sgl.gen("answer2"))
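The regex constraint in `json_output` guarantees the decoded string matches the pattern, so downstream parsing is safe. A quick plain-Python sanity check (no SGLang needed) that strings admitted by the pattern are valid JSON:

```python
import json
import re

# Same pattern as the regex constraint above
pattern = r'\{"name": "\w+", "age": \d+\}'
sample = '{"name": "Ada", "age": 36}'

# Any constrained output matches the pattern...
assert re.fullmatch(pattern, sample)
# ...and anything matching the pattern is parseable JSON.
record = json.loads(sample)
print(record["name"], record["age"])
```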

Alternative: write_databricks_table

For Unity Catalog enabled workspaces, use ray.data.Dataset.write_databricks_table() to write directly to a Unity Catalog table:

Python
# Set the temp directory environment variable
os.environ["_RAY_UC_VOLUMES_FUSE_TEMP_DIR"] = "/Volumes/catalog/schema/volume/ray_temp"

# Write directly to Unity Catalog table
ds.write_databricks_table(table_name="catalog.schema.table_name")

Cleanup

GPU resources are automatically cleaned up when the notebook disconnects. To manually disconnect:

  1. Click Connected in the compute dropdown
  2. Hover over Serverless
  3. Select Terminate from the dropdown menu
