
Distributed LLM Inference with Ray Data and vLLM on Serverless GPU

This notebook demonstrates how to run large language model (LLM) inference at scale using Ray Data and vLLM on Databricks Serverless GPU. It leverages the Distributed Serverless GPU API to automatically provision and manage multi-node A10 GPUs for distributed inference.

What you'll learn:

  • Set up Ray and vLLM for distributed LLM inference using map_batches
  • Use Ray Data to efficiently batch and process prompts across multiple GPUs
  • Save inference results to Parquet in Unity Catalog Volumes
  • Convert Parquet to Delta tables for governance and efficient querying
  • Monitor Ray cluster resources and GPU allocation

Use case: Batch inference on thousands of prompts with efficient GPU utilization, persistent storage, and Delta Lake integration.

Connect to serverless GPU compute

  1. Click the Connect dropdown at the top.
  2. Select Serverless GPU.
  3. Open the Environment side panel on the right side of the notebook.
  4. Set Accelerator to A10 for this demo.
  5. Click Apply and Confirm to apply this environment to your notebook.

Note: The distributed function will launch remote A10 GPUs for multi-node inference. The notebook itself runs on a single A10 for orchestration.

Install dependencies

Install all required packages for distributed Ray and vLLM inference:

  • Flash Attention: Optimized attention for faster inference (CUDA 12, PyTorch 2.6, A10 compatible)
  • vLLM: High-throughput LLM inference engine
  • Ray Data: Distributed data processing with LLM support (ray.data.llm API)
  • Transformers: Hugging Face model loading utilities
Python
%pip install --force-reinstall --no-cache-dir --no-deps "https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_x86_64.whl"
%pip install "transformers<4.54.0"
%pip install "vllm==0.8.5.post1"
%pip install "ray[data]>=2.47.1" # Required for ray.data.llm API
%pip install "opentelemetry-exporter-prometheus"
%pip install "optree>=0.13.0"
%pip install hf_transfer
%pip install "numpy==1.26.4"
%restart_python

Verify package versions

Confirm that all required packages are installed with compatible versions.

Python
from packaging.version import Version

import torch
import flash_attn
import vllm
import ray
import transformers

print(f"PyTorch: {torch.__version__}")
print(f"Flash Attention: {flash_attn.__version__}")
print(f"vLLM: {vllm.__version__}")
print(f"Ray: {ray.__version__}")
print(f"Transformers: {transformers.__version__}")

assert Version(ray.__version__) >= Version("2.47.1"), "Ray version must be at least 2.47.1"
print("\n✓ All version checks passed!")


Configuration

Use widgets to configure inference parameters and optional Hugging Face authentication.

Security Note: Store your Hugging Face token in Databricks Secrets for production use. See Databricks Secrets documentation.

Python
# Widget configuration
dbutils.widgets.text("hf_secret_scope", "")
dbutils.widgets.text("hf_secret_key", "")
dbutils.widgets.text("model_name", "Qwen/Qwen3-4B-Instruct-2507")
dbutils.widgets.text("num_gpus", "5")
dbutils.widgets.text("num_prompts", "1000")

# Unity Catalog configuration for output storage
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("uc_volume", "ray_data")
dbutils.widgets.text("uc_table", "llm_inference_results")

# Retrieve widget values
HF_SECRET_SCOPE = dbutils.widgets.get("hf_secret_scope")
HF_SECRET_KEY = dbutils.widgets.get("hf_secret_key")
MODEL_NAME = dbutils.widgets.get("model_name")
NUM_GPUS = int(dbutils.widgets.get("num_gpus"))
NUM_PROMPTS = int(dbutils.widgets.get("num_prompts"))

# Unity Catalog paths
UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_VOLUME = dbutils.widgets.get("uc_volume")
UC_TABLE = dbutils.widgets.get("uc_table")

# Construct paths
UC_VOLUME_PATH = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}"
UC_TABLE_NAME = f"{UC_CATALOG}.{UC_SCHEMA}.{UC_TABLE}"
PARQUET_OUTPUT_PATH = f"{UC_VOLUME_PATH}/inference_output"

print(f"Model: {MODEL_NAME}")
print(f"Number of GPUs: {NUM_GPUS}")
print(f"Number of prompts: {NUM_PROMPTS}")
print(f"\nUnity Catalog Configuration:")
print(f" Volume Path: {UC_VOLUME_PATH}")
print(f" Table Name: {UC_TABLE_NAME}")
print(f" Parquet Output: {PARQUET_OUTPUT_PATH}")

Authenticate with Hugging Face (Optional)

If using gated models (like Llama), authenticate with Hugging Face.

Option 1: Use Databricks Secrets (recommended for production)

Python
hf_token = dbutils.secrets.get(scope=HF_SECRET_SCOPE, key=HF_SECRET_KEY)

Option 2: Interactive login (for development)

Python
from huggingface_hub import login

# Uncomment ONE of the following options:

# Option 1: Use Databricks Secrets (recommended)
# if HF_SECRET_SCOPE and HF_SECRET_KEY:
#     hf_token = dbutils.secrets.get(scope=HF_SECRET_SCOPE, key=HF_SECRET_KEY)
#     login(token=hf_token)
#     print("✓ Logged in using Databricks Secrets")

# Option 2: Interactive login
login()
print("✓ Hugging Face authentication complete")


Ray cluster resource monitoring

Utility function to inspect Ray cluster resources and verify GPU allocation across nodes.

Python
import json
import ray

def print_ray_resources():
    """Print Ray cluster resources and GPU allocation per node."""
    try:
        cluster_resources = ray.cluster_resources()
        print("Ray Cluster Resources:")
        print(json.dumps(cluster_resources, indent=2))

        nodes = ray.nodes()
        print(f"\nDetected {len(nodes)} Ray node(s):")

        for node in nodes:
            node_id = node.get("NodeID", "N/A")[:8]  # Truncate for readability
            ip_address = node.get("NodeManagerAddress", "N/A")
            resources = node.get("Resources", {})
            num_gpus = int(resources.get("GPU", 0))

            print(f"  • Node {node_id}... | IP: {ip_address} | GPUs: {num_gpus}")

            # Show specific GPU IDs if available
            gpu_ids = [k for k in resources.keys() if k.startswith("GPU_ID_")]
            if gpu_ids:
                print(f"    GPU IDs: {', '.join(gpu_ids)}")

    except Exception as e:
        print(f"Error querying Ray cluster: {e}")

# Display current resources
# print_ray_resources()

Define the distributed inference task

The LLMPredictor class wraps vLLM for efficient batch inference. Ray Data distributes the workload across multiple GPU workers using map_batches.

Architecture overview

┌─────────────────┐    ┌─────────────────┐    ┌─────────────────┐
│    Ray Data     │───▶│  LLMPredictor   │───▶│     Parquet     │
│    (Prompts)    │    │  (vLLM Engine)  │    │   (UC Volume)   │
└─────────────────┘    └─────────────────┘    └─────────────────┘
         │                      │                      │
         ▼                      ▼                      ▼
   Distributed          GPU Workers (A10)         Delta Table
  across nodes         with vLLM instances         (UC Table)
Python
from serverless_gpu.ray import ray_launch
import os

# Set Ray temp directory
os.environ['RAY_TEMP_DIR'] = '/tmp/ray'

@ray_launch(gpus=NUM_GPUS, gpu_type='a10', remote=True)
def run_distributed_inference():
    """Run distributed LLM inference using Ray Data and vLLM with map_batches."""
    from typing import Dict, List
    from datetime import datetime
    import numpy as np
    import ray
    from vllm import LLM, SamplingParams

    # Sample prompts for inference
    base_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The future of AI is",
    ]

    # Scale up prompts for distributed processing
    prompts = base_prompts * (NUM_PROMPTS // len(base_prompts))
    ds = ray.data.from_items(prompts)

    print(f"✓ Created Ray dataset with {ds.count()} prompts")

    # Sampling parameters for text generation
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        max_tokens=100
    )

    class LLMPredictor:
        """vLLM-based predictor for batch inference."""

        def __init__(self):
            self.llm = LLM(
                model=MODEL_NAME,
                tensor_parallel_size=1,
                dtype="bfloat16",
                trust_remote_code=True,
                gpu_memory_utilization=0.90,
                max_model_len=8192,
                enable_prefix_caching=True,
                enable_chunked_prefill=True,
                max_num_batched_tokens=8192,
            )
            self.model_name = MODEL_NAME
            print(f"✓ vLLM engine initialized with model: {MODEL_NAME}")

        def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
            """Process a batch of prompts."""
            outputs = self.llm.generate(batch["item"], sampling_params)

            prompt_list: List[str] = []
            generated_text_list: List[str] = []
            model_list: List[str] = []
            timestamp_list: List[str] = []

            for output in outputs:
                prompt_list.append(output.prompt)
                generated_text_list.append(
                    ' '.join([o.text for o in output.outputs])
                )
                model_list.append(self.model_name)
                timestamp_list.append(datetime.now().isoformat())

            return {
                "prompt": prompt_list,
                "generated_text": generated_text_list,
                "model": model_list,
                "timestamp": timestamp_list,
            }

    # Configure number of parallel vLLM instances
    num_instances = NUM_GPUS

    # Apply the predictor across the dataset using map_batches
    ds = ds.map_batches(
        LLMPredictor,
        concurrency=num_instances,
        batch_size=32,
        num_gpus=1,
        num_cpus=12
    )

    # =========================================================================
    # Write results to Parquet (stored in Unity Catalog Volume)
    # =========================================================================
    print(f"\n📦 Writing results to Parquet: {PARQUET_OUTPUT_PATH}")
    ds.write_parquet(PARQUET_OUTPUT_PATH, mode="overwrite")
    print("✓ Parquet files written successfully")

    # Collect sample outputs for display
    sample_outputs = ray.data.read_parquet(PARQUET_OUTPUT_PATH).take(limit=10)

    print("\n" + "=" * 60)
    print("SAMPLE INFERENCE RESULTS")
    print("=" * 60 + "\n")

    for i, output in enumerate(sample_outputs):
        prompt = output.get("prompt", "N/A")
        generated_text = output.get("generated_text", "")
        display_text = generated_text[:100] if generated_text else "N/A"
        print(f"[{i+1}] Prompt: {prompt!r}")
        print(f"    Generated: {display_text!r}...\n")

    return PARQUET_OUTPUT_PATH


Run distributed inference

Launch the distributed inference task across multiple A10 GPUs. This will:

  1. Provision remote A10 GPU workers
  2. Initialize vLLM engines on each worker
  3. Distribute prompts across workers using Ray Data
  4. Collect and return the generated outputs

Note: Initial startup may take a few minutes as GPU nodes are provisioned and models are loaded.

Python
result = run_distributed_inference.distributed()
parquet_path = result[0] if NUM_GPUS > 1 else result
print(f"\n✓ Inference complete! Results saved to: {parquet_path}")

Load Parquet and preview results

Load the Parquet output from Unity Catalog Volumes using Spark and preview the inference results.

Python
# Load Parquet data using Spark
print(f"📖 Loading Parquet from: {PARQUET_OUTPUT_PATH}")
df_spark = spark.read.parquet(PARQUET_OUTPUT_PATH)

# Show schema and row count
print(f"\n✓ Loaded {df_spark.count()} rows")
print("\nSchema:")
df_spark.printSchema()

# Display sample rows
print("\nSample Results:")
display(df_spark.limit(10))

Save as Delta table in Unity Catalog

Write the inference results to a Unity Catalog Delta table for:

  • Governance: Track data lineage and access controls
  • Performance: Optimized queries with Delta Lake
  • Versioning: Time travel and audit history
Python
# Write to Unity Catalog Delta table
print(f"💾 Writing to Delta table: {UC_TABLE_NAME}")

# Write the DataFrame as a Delta table (overwrite mode)
df_spark.write \
    .format("delta") \
    .mode("overwrite") \
    .option("overwriteSchema", "true") \
    .saveAsTable(UC_TABLE_NAME)

print(f"✓ Delta table created successfully: {UC_TABLE_NAME}")

Query the Delta table

Verify the Delta table was created and query it using SQL.

Python
# Query the Delta table using SQL
print(f"📊 Querying Delta table: {UC_TABLE_NAME}\n")

# Get table info
display(spark.sql(f"DESCRIBE TABLE {UC_TABLE_NAME}"))

# Query sample results
print("\nSample Results from Delta Table:")
display(spark.sql(f"""
    SELECT
        prompt,
        generated_text,
        model,
        timestamp
    FROM {UC_TABLE_NAME}
    LIMIT 10
"""))

# Get row count and verify correctness
row_count = spark.sql(f"SELECT COUNT(*) as count FROM {UC_TABLE_NAME}").collect()[0]["count"]
print(f"\n✓ Total rows in Delta table: {row_count}")

# Assert expected row count. The prompt list was truncated to the nearest
# multiple of the 3 base prompts, so NUM_PROMPTS=1000 yields (1000 // 3) * 3 = 999 rows.
expected_rows = (NUM_PROMPTS // 3) * 3
assert row_count == expected_rows, f"Expected {expected_rows} rows, but got {row_count}"
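The truncation arithmetic behind this check can be reproduced in plain Python outside the notebook (the names mirror the inference cell; `NUM_PROMPTS = 1000` is the widget default):

```python
# The inference task repeats the 3 base prompts NUM_PROMPTS // 3 times,
# so the resulting row count rounds down to the nearest multiple of 3.
base_prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The future of AI is",
]
NUM_PROMPTS = 1000  # widget default

prompts = base_prompts * (NUM_PROMPTS // len(base_prompts))
expected_rows = (NUM_PROMPTS // 3) * 3

print(len(prompts), expected_rows)  # → 999 999
```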

Next steps

You've successfully run distributed LLM inference using Ray Data and vLLM on Databricks Serverless GPU, and saved the results to a Delta table!

What you accomplished

  1. ✅ Ran distributed LLM inference across multiple A10 GPUs
  2. ✅ Used map_batches with a custom LLMPredictor class for batch processing
  3. ✅ Saved results to Parquet in Unity Catalog Volumes
  4. ✅ Converted Parquet to a governed Delta table

Customization options

  • Change the model: Update model_name widget to use different Hugging Face models
  • Scale up: Increase num_gpus for higher throughput
  • Adjust batch size: Modify batch_size in map_batches() based on your memory constraints
  • Tune generation: Adjust SamplingParams for different output characteristics
  • Append mode: Change Delta write to mode("append") for incremental updates
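As a sketch of the generation-tuning option: the values below are hypothetical starting points, not recommendations, showing how the fields passed to the `SamplingParams(...)` constructor used in the inference task might be varied.

```python
# Hypothetical tuning sketch: alternative keyword arguments for the
# SamplingParams(...) call in the inference cell (values are illustrative).
deterministic_kwargs = dict(
    temperature=0.0,   # greedy decoding: always pick the most likely token
    top_p=1.0,         # effectively disable nucleus sampling
    max_tokens=256,    # allow longer completions
)
creative_kwargs = dict(
    temperature=1.0,   # more randomness in token selection
    top_p=0.9,         # tighter nucleus sampling
    max_tokens=100,
)
# In the notebook these would be used as SamplingParams(**deterministic_kwargs)
# or SamplingParams(**creative_kwargs).
print(deterministic_kwargs["temperature"], creative_kwargs["temperature"])
```

Lower temperature trades diversity for reproducibility, which is often preferable for batch pipelines whose outputs feed downstream tables.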

Cleanup

GPU resources are automatically cleaned up when the notebook disconnects. To manually disconnect:

  1. Click Connected in the compute dropdown
  2. Hover over Serverless
  3. Select Terminate from the dropdown menu
