Distributed LLM Inference with Ray Data and vLLM on Serverless GPU
This notebook demonstrates how to run large language model (LLM) inference at scale using Ray Data and vLLM on Databricks Serverless GPU. It leverages the Distributed Serverless GPU API to automatically provision and manage multi-node A10 GPUs for distributed inference.
What you'll learn:
- Set up Ray and vLLM for distributed LLM inference using map_batches
- Use Ray Data to efficiently batch and process prompts across multiple GPUs
- Save inference results to Parquet in Unity Catalog Volumes
- Convert Parquet to Delta tables for governance and efficient querying
- Monitor Ray cluster resources and GPU allocation
Use case: Batch inference on thousands of prompts with efficient GPU utilization, persistent storage, and Delta Lake integration.
Connect to serverless GPU compute
- Click the Connect dropdown at the top.
- Select Serverless GPU.
- Open the Environment side panel on the right side of the notebook.
- Set Accelerator to A10 for this demo.
- Click Apply and Confirm to apply this environment to your notebook.
Note: The distributed function will launch remote A10 GPUs for multi-node inference. The notebook itself runs on a single A10 for orchestration.
Install dependencies
Install all required packages for distributed Ray and vLLM inference:
- Flash Attention: Optimized attention for faster inference (CUDA 12, PyTorch 2.6, A10 compatible)
- vLLM: High-throughput LLM inference engine
- Ray Data: Distributed data processing with LLM support (ray.data.llm API)
- Transformers: Hugging Face model loading utilities
%pip install --force-reinstall --no-cache-dir --no-deps "https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_x86_64.whl"
%pip install "transformers<4.54.0"
%pip install "vllm==0.8.5.post1"
%pip install "ray[data]>=2.47.1" # Required for ray.data.llm API
%pip install "opentelemetry-exporter-prometheus"
%pip install "optree>=0.13.0"
%pip install hf_transfer
%pip install "numpy==1.26.4"
%restart_python
Verify package versions
Confirm that all required packages are installed with compatible versions.
from packaging.version import Version
import torch
import flash_attn
import vllm
import ray
import transformers
print(f"PyTorch: {torch.__version__}")
print(f"Flash Attention: {flash_attn.__version__}")
print(f"vLLM: {vllm.__version__}")
print(f"Ray: {ray.__version__}")
print(f"Transformers: {transformers.__version__}")
assert Version(ray.__version__) >= Version("2.47.1"), "Ray version must be at least 2.47.1"
print("\n✓ All version checks passed!")
Configuration
Use widgets to configure inference parameters and optional Hugging Face authentication.
Security Note: Store your Hugging Face token in Databricks Secrets for production use. See Databricks Secrets documentation.
# Widget configuration
dbutils.widgets.text("hf_secret_scope", "")
dbutils.widgets.text("hf_secret_key", "")
dbutils.widgets.text("model_name", "Qwen/Qwen3-4B-Instruct-2507")
dbutils.widgets.text("num_gpus", "5")
dbutils.widgets.text("num_prompts", "1000")
# Unity Catalog configuration for output storage
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("uc_volume", "ray_data")
dbutils.widgets.text("uc_table", "llm_inference_results")
# Retrieve widget values
HF_SECRET_SCOPE = dbutils.widgets.get("hf_secret_scope")
HF_SECRET_KEY = dbutils.widgets.get("hf_secret_key")
MODEL_NAME = dbutils.widgets.get("model_name")
NUM_GPUS = int(dbutils.widgets.get("num_gpus"))
NUM_PROMPTS = int(dbutils.widgets.get("num_prompts"))
# Unity Catalog paths
UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_VOLUME = dbutils.widgets.get("uc_volume")
UC_TABLE = dbutils.widgets.get("uc_table")
# Construct paths
UC_VOLUME_PATH = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}"
UC_TABLE_NAME = f"{UC_CATALOG}.{UC_SCHEMA}.{UC_TABLE}"
PARQUET_OUTPUT_PATH = f"{UC_VOLUME_PATH}/inference_output"
print(f"Model: {MODEL_NAME}")
print(f"Number of GPUs: {NUM_GPUS}")
print(f"Number of prompts: {NUM_PROMPTS}")
print(f"\nUnity Catalog Configuration:")
print(f" Volume Path: {UC_VOLUME_PATH}")
print(f" Table Name: {UC_TABLE_NAME}")
print(f" Parquet Output: {PARQUET_OUTPUT_PATH}")
Authenticate with Hugging Face (Optional)
If using gated models (like Llama), authenticate with Hugging Face.
Option 1: Use Databricks Secrets (recommended for production)
hf_token = dbutils.secrets.get(scope=HF_SECRET_SCOPE, key=HF_SECRET_KEY)
Option 2: Interactive login (for development)
from huggingface_hub import login
# Choose ONE of the following options:
# Option 1: Use Databricks Secrets (recommended; uncomment to enable)
# if HF_SECRET_SCOPE and HF_SECRET_KEY:
#     hf_token = dbutils.secrets.get(scope=HF_SECRET_SCOPE, key=HF_SECRET_KEY)
#     login(token=hf_token)
#     print("✓ Logged in using Databricks Secrets")
# Option 2: Interactive login (active by default; comment out if using Option 1)
login()
print("✓ Hugging Face authentication complete")
Ray cluster resource monitoring
Utility function to inspect Ray cluster resources and verify GPU allocation across nodes.
import json
import ray
def print_ray_resources():
    """Print Ray cluster resources and GPU allocation per node."""
    try:
        cluster_resources = ray.cluster_resources()
        print("Ray Cluster Resources:")
        print(json.dumps(cluster_resources, indent=2))
        nodes = ray.nodes()
        print(f"\nDetected {len(nodes)} Ray node(s):")
        for node in nodes:
            node_id = node.get("NodeID", "N/A")[:8]  # Truncate for readability
            ip_address = node.get("NodeManagerAddress", "N/A")
            resources = node.get("Resources", {})
            num_gpus = int(resources.get("GPU", 0))
            print(f"  • Node {node_id}... | IP: {ip_address} | GPUs: {num_gpus}")
            # Show specific GPU IDs if available
            gpu_ids = [k for k in resources.keys() if k.startswith("GPU_ID_")]
            if gpu_ids:
                print(f"    GPU IDs: {', '.join(gpu_ids)}")
    except Exception as e:
        print(f"Error querying Ray cluster: {e}")

# Uncomment to display current resources
# print_ray_resources()
Define the distributed inference task
The LLMPredictor class wraps vLLM for efficient batch inference. Ray Data distributes the workload across multiple GPU workers using map_batches.
Architecture overview
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
│ Ray Data │───▶│ LLMPredictor │───▶│ Parquet │
│ (Prompts) │ │ (vLLM Engine) │ │ (UC Volume) │
└─────────────────┘ └─────────────────┘ └─────────────────┘
│ │ │
▼ ▼ ▼
Distributed GPU Workers (A10) Delta Table
across nodes with vLLM instances (UC Table)
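Before reading the full implementation, it helps to see the contract that map_batches expects from a callable class: a batch arrives as a dict of columns and must be returned as a dict of equal-length columns. A minimal sketch with a hypothetical EchoPredictor stub in place of the vLLM engine (plain lists stand in for the NumPy arrays Ray Data actually passes; no Ray or GPU required):

```python
from typing import Dict, List

class EchoPredictor:
    """Stub with the same interface as the LLMPredictor defined below."""

    def __call__(self, batch: Dict[str, List[str]]) -> Dict[str, list]:
        prompts = list(batch["item"])              # "item" is the column ray.data.from_items creates
        generated = [p + " ..." for p in prompts]  # stand-in for self.llm.generate(...)
        return {"prompt": prompts, "generated_text": generated}

# One batch, shaped the way Ray Data hands it to the callable
batch = {"item": ["Hello, my name is", "The future of AI is"]}
out = EchoPredictor()(batch)
print(out["generated_text"][0])  # → Hello, my name is ...
```

Because the class is instantiated once per worker, expensive setup (like loading the model in `__init__`) is paid once per GPU rather than once per batch.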
from serverless_gpu.ray import ray_launch
import os
# Set Ray temp directory
os.environ['RAY_TEMP_DIR'] = '/tmp/ray'
@ray_launch(gpus=NUM_GPUS, gpu_type='a10', remote=True)
def run_distributed_inference():
    """Run distributed LLM inference using Ray Data and vLLM with map_batches."""
    from typing import Dict, List
    from datetime import datetime
    import numpy as np
    import ray
    from vllm import LLM, SamplingParams

    # Sample prompts for inference
    base_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The future of AI is",
    ]
    # Scale up prompts for distributed processing (rounds down to a multiple of 3)
    prompts = base_prompts * (NUM_PROMPTS // len(base_prompts))
    ds = ray.data.from_items(prompts)
    print(f"✓ Created Ray dataset with {ds.count()} prompts")

    # Sampling parameters for text generation
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        max_tokens=100
    )

    class LLMPredictor:
        """vLLM-based predictor for batch inference."""

        def __init__(self):
            self.llm = LLM(
                model=MODEL_NAME,
                tensor_parallel_size=1,
                dtype="bfloat16",
                trust_remote_code=True,
                gpu_memory_utilization=0.90,
                max_model_len=8192,
                enable_prefix_caching=True,
                enable_chunked_prefill=True,
                max_num_batched_tokens=8192,
            )
            self.model_name = MODEL_NAME
            print(f"✓ vLLM engine initialized with model: {MODEL_NAME}")

        def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
            """Process a batch of prompts."""
            outputs = self.llm.generate(batch["item"].tolist(), sampling_params)
            prompt_list: List[str] = []
            generated_text_list: List[str] = []
            model_list: List[str] = []
            timestamp_list: List[str] = []
            for output in outputs:
                prompt_list.append(output.prompt)
                generated_text_list.append(
                    ' '.join([o.text for o in output.outputs])
                )
                model_list.append(self.model_name)
                timestamp_list.append(datetime.now().isoformat())
            return {
                "prompt": prompt_list,
                "generated_text": generated_text_list,
                "model": model_list,
                "timestamp": timestamp_list,
            }

    # Configure number of parallel vLLM instances (one per GPU)
    num_instances = NUM_GPUS

    # Apply the predictor across the dataset using map_batches
    ds = ds.map_batches(
        LLMPredictor,
        concurrency=num_instances,
        batch_size=32,
        num_gpus=1,
        num_cpus=12
    )

    # =========================================================================
    # Write results to Parquet (stored in Unity Catalog Volume)
    # =========================================================================
    print(f"\n📦 Writing results to Parquet: {PARQUET_OUTPUT_PATH}")
    ds.write_parquet(PARQUET_OUTPUT_PATH, mode="overwrite")
    print("✓ Parquet files written successfully")

    # Collect sample outputs for display
    sample_outputs = ray.data.read_parquet(PARQUET_OUTPUT_PATH).take(limit=10)
    print("\n" + "=" * 60)
    print("SAMPLE INFERENCE RESULTS")
    print("=" * 60 + "\n")
    for i, output in enumerate(sample_outputs):
        prompt = output.get("prompt", "N/A")
        generated_text = output.get("generated_text", "")
        display_text = generated_text[:100] if generated_text else "N/A"
        print(f"[{i+1}] Prompt: {prompt!r}")
        print(f"    Generated: {display_text!r}...\n")

    return PARQUET_OUTPUT_PATH
Run distributed inference
Launch the distributed inference task across multiple A10 GPUs. This will:
- Provision remote A10 GPU workers
- Initialize vLLM engines on each worker
- Distribute prompts across workers using Ray Data
- Collect and return the generated outputs
Note: Initial startup may take a few minutes as GPU nodes are provisioned and models are loaded.
result = run_distributed_inference.distributed()
parquet_path = result[0] if NUM_GPUS > 1 else result
print(f"\n✓ Inference complete! Results saved to: {parquet_path}")
Load Parquet and preview results
Load the Parquet output from Unity Catalog Volumes using Spark and preview the inference results.
# Load Parquet data using Spark
print(f"📖 Loading Parquet from: {PARQUET_OUTPUT_PATH}")
df_spark = spark.read.parquet(PARQUET_OUTPUT_PATH)
# Show schema and row count
print(f"\n✓ Loaded {df_spark.count()} rows")
print("\nSchema:")
df_spark.printSchema()
# Display sample rows
print("\nSample Results:")
display(df_spark.limit(10))
Save as Delta table in Unity Catalog
Write the inference results to a Unity Catalog Delta table for:
- Governance: Track data lineage and access controls
- Performance: Optimized queries with Delta Lake
- Versioning: Time travel and audit history
# Write to Unity Catalog Delta table
print(f"💾 Writing to Delta table: {UC_TABLE_NAME}")
# Write the DataFrame as a Delta table (overwrite mode)
df_spark.write \
    .format("delta") \
    .mode("overwrite") \
    .option("overwriteSchema", "true") \
    .saveAsTable(UC_TABLE_NAME)
print(f"✓ Delta table created successfully: {UC_TABLE_NAME}")
Query the Delta table
Verify the Delta table was created and query it using SQL.
# Query the Delta table using SQL
print(f"📊 Querying Delta table: {UC_TABLE_NAME}\n")
# Get table info
display(spark.sql(f"DESCRIBE TABLE {UC_TABLE_NAME}"))
# Query sample results
print("\nSample Results from Delta Table:")
display(spark.sql(f"""
SELECT
prompt,
generated_text,
model,
timestamp
FROM {UC_TABLE_NAME}
LIMIT 10
"""))
# Get row count and verify correctness
row_count = spark.sql(f"SELECT COUNT(*) as count FROM {UC_TABLE_NAME}").collect()[0]["count"]
print(f"\n✓ Total rows in Delta table: {row_count}")
# Assert expected row count: prompts were built as base_prompts * (NUM_PROMPTS // 3),
# so the total rounds down to the nearest multiple of 3 (e.g., 1000 → 999)
expected_rows = (NUM_PROMPTS // 3) * 3
assert row_count == expected_rows, f"Expected {expected_rows} rows, but got {row_count}"
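The assertion above follows directly from how the prompt list was constructed: repeating the 3 base prompts NUM_PROMPTS // 3 times rounds the total down to the nearest multiple of 3. A quick sanity check in plain Python (the helper name is illustrative, not part of the notebook's API):

```python
def expected_row_count(num_prompts: int, num_base: int = 3) -> int:
    """Rows produced when the base prompts are tiled num_prompts // num_base times."""
    return (num_prompts // num_base) * num_base

print(expected_row_count(1000))  # → 999
print(expected_row_count(1200))  # → 1200
```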
Next steps
You've successfully run distributed LLM inference using Ray Data and vLLM on Databricks Serverless GPU, and saved the results to a Delta table!
What you accomplished
- ✅ Ran distributed LLM inference across multiple A10 GPUs
- ✅ Used map_batches with a custom LLMPredictor class for batch processing
- ✅ Saved results to Parquet in Unity Catalog Volumes
- ✅ Converted Parquet to a governed Delta table
Customization options
- Change the model: Update the model_name widget to use different Hugging Face models
- Scale up: Increase num_gpus for higher throughput
- Adjust batch size: Modify batch_size in map_batches() based on your memory constraints
- Tune generation: Adjust SamplingParams for different output characteristics
- Append mode: Change the Delta write to mode("append") for incremental updates
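The last option matters for recurring jobs: overwrite replaces the table on every run, while append accumulates rows across runs. The two modes' semantics can be illustrated with a toy in-memory sink (plain Python, purely illustrative; the real write uses Spark's df.write.mode(...)):

```python
def write(table, rows, mode="overwrite"):
    """Mimic Delta write modes on a plain list of rows (illustration only)."""
    if mode == "overwrite":
        return list(rows)          # replace everything already in the table
    if mode == "append":
        return table + list(rows)  # keep existing rows, add the new batch
    raise ValueError(f"unknown mode: {mode}")

table = write([], [{"prompt": "run 1"}])                    # first run
table = write(table, [{"prompt": "run 2"}], mode="append")  # incremental run
print(len(table))  # → 2
```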
Resources
- Serverless GPU API Documentation
- Ray Data Documentation
- vLLM Documentation
- Unity Catalog Documentation
Cleanup
GPU resources are automatically cleaned up when the notebook disconnects. To manually disconnect:
- Click Connected in the compute dropdown
- Hover over Serverless
- Select Terminate from the dropdown menu