
Batch inference with Ray Data and vLLM

Beta

This feature is in Beta. Workspace admins can control access to this feature from the Previews page. See Manage Databricks previews.

This example runs offline LLM batch inference with Ray Data and vLLM across 8 H100 GPUs on a single node. A bootstrap script starts a Ray cluster on the node, then the driver uses Ray Data's LLM API (ray.data.llm) to launch one vLLM replica per GPU and stream a dataset of prompts through them, writing the generated text to a Unity Catalog volume as Parquet.

It uses a public model (Qwen2.5-7B-Instruct), so it runs as-is without a Hugging Face token.

The workload does the following:

  • Uploads the local project with code_source: snapshot.
  • Starts a Ray head with all 8 GPUs, then runs the batch inference driver.
  • Uses ray.data.llm to run one vLLM replica per GPU and process prompts in parallel.
  • Writes the prompts and generated outputs to a Unity Catalog volume as Parquet.

Prerequisites

  • The air CLI installed and authenticated. See Install the AI Runtime CLI.
  • A Unity Catalog volume you can write to. You set its path in the workload YAML below.
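Unity Catalog volume paths follow the layout /Volumes/<catalog>/<schema>/<volume>/... . If you want the driver to fail fast on a mistyped OUTPUT_PATH, you could run a check like the one below before any inference starts. This is a minimal sketch. check_volume_path is a hypothetical helper, not part of Databricks, air, or Ray. It validates only the shape of the path, not whether the volume exists or whether you can write to it.

```python
import re

# Hypothetical helper: checks only the shape /Volumes/<catalog>/<schema>/<volume>[/...].
# It does not verify that the volume exists or that you have write access.
_VOLUME_PATH = re.compile(r"^/Volumes/[^/]+/[^/]+/[^/]+(/.*)?$")


def check_volume_path(path: str) -> str:
    """Return the path unchanged, or raise ValueError if it isn't a volume path."""
    if not _VOLUME_PATH.match(path):
        raise ValueError(
            f"OUTPUT_PATH must look like /Volumes/<catalog>/<schema>/<volume>/..., got {path!r}"
        )
    return path
```

For example, check_volume_path(os.environ["OUTPUT_PATH"]) at the top of the driver would stop the run before the model loads if the path is malformed.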

Project layout

Create a directory with the following files.

Text
ray_batch_inference/
├── train.yaml # air workload config (inline dependencies + Ray bootstrap)
└── batch_inference.py # Ray Data + vLLM batch inference driver

Step 1: Write the workload YAML

train.yaml requests a single GPU_8xH100 node. Dependencies are declared inline under environment, along with the client image version. The command starts a Ray cluster on the node and then runs the driver, so the workload doesn't need a separate dependency file or launcher script.

vLLM isn't in the base image, so it's installed inline. Three other packages are pinned because the GPU nodes need them:

  • hf_transfer: the base image enables fast Hugging Face downloads and expects this package.
  • A newer fsspec: the base image ships an old version that breaks downloads.
  • A pinned opencv-python-headless: vLLM pulls in OpenCV, and its default wheel crashes the OpenSSL FIPS self-test on the GPU nodes.

Set OUTPUT_PATH to a Unity Catalog volume you can write to.

YAML
experiment_name: air-ray-batch-inference

environment:
  version: '4'
  dependencies:
    - ray[data]>=2.44
    - vllm
    - datasets>=3.0
    - huggingface_hub>=0.34
    # The base image sets HF_HUB_ENABLE_HF_TRANSFER=1; install the package it expects
    # so model and dataset downloads don't error out.
    - hf_transfer
    # The base image ships fsspec 2023.5.0, which is too old for modern
    # huggingface_hub and breaks dataset/model downloads. Pin a newer fsspec.
    - fsspec>=2024.6.1
    # vLLM pulls in opencv; its default wheel crashes the OpenSSL FIPS self-test
    # on the GPU nodes. This pinned headless build avoids the crash.
    - opencv-python-headless==4.12.0.88

# 8 H100 on a single node. Ray Data runs one vLLM replica per GPU.
compute:
  num_accelerators: 8
  accelerator_type: GPU_8xH100

code_source:
  type: snapshot
  snapshot:
    root_path: .

command: |
  cd $CODE_SOURCE_PATH
  RAY_HEAD_PORT=6379
  GPUS_PER_NODE=${LOCAL_WORLD_SIZE:-8}
  if [ "${NODE_RANK:-0}" = "0" ]; then
    echo "NODE_RANK=0: starting Ray head with $GPUS_PER_NODE GPU(s)..."
    ray start --head --port=$RAY_HEAD_PORT --num-gpus="$GPUS_PER_NODE" --dashboard-host=0.0.0.0
    python batch_inference.py
    # Keep the driver's exit code so a failed run isn't reported as a success after ray stop.
    STATUS=$?
    ray stop
    exit $STATUS
  else
    echo "NODE_RANK=$NODE_RANK: connecting to Ray head at $MASTER_ADDR:$RAY_HEAD_PORT..."
    for i in $(seq 1 12); do
      if ray start --address="$MASTER_ADDR:$RAY_HEAD_PORT" --num-gpus="$GPUS_PER_NODE" --block 2>/dev/null; then
        break
      fi
      echo "Attempt $i failed, retrying in 5s..."
      sleep 5
    done
  fi

max_retries: 0
timeout_minutes: 60
env_variables:
  NCCL_SOCKET_IFNAME: eth0
  # Unity Catalog volume where results land as Parquet. Replace with your volume.
  OUTPUT_PATH: /Volumes/main/default/air_examples/ray_batch_inference

The inline command starts a Ray head with all GPUs on the node, runs the driver with python batch_inference.py, then stops the cluster. It also includes a worker branch that joins the head, so the same command keeps working if you scale the job to multiple nodes.
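The branching in the command depends only on three environment variables that it reads: NODE_RANK, LOCAL_WORLD_SIZE, and MASTER_ADDR. The following Python sketch mirrors that decision, so you can see which ray start arguments each node ends up with. ray_start_args is a hypothetical helper for illustration only. The workload runs the shell command, not this function.

```python
def ray_start_args(env: dict, head_port: int = 6379) -> list:
    """Return the `ray start` argv a node would run under the workload command.

    Hypothetical helper that mirrors the shell logic for illustration only.
    """
    # ${VAR:-default} falls back when the variable is unset *or* empty, hence `or`.
    gpus_per_node = env.get("LOCAL_WORLD_SIZE") or "8"
    node_rank = env.get("NODE_RANK") or "0"
    if node_rank == "0":
        # Rank 0 starts the head; the driver then runs on this node.
        return ["ray", "start", "--head", f"--port={head_port}",
                f"--num-gpus={gpus_per_node}", "--dashboard-host=0.0.0.0"]
    # Any other rank joins the head at MASTER_ADDR and stays attached via --block.
    return ["ray", "start", f"--address={env['MASTER_ADDR']}:{head_port}",
            f"--num-gpus={gpus_per_node}", "--block"]
```

On the single-node run in this example, NODE_RANK is unset or 0, so only the head branch runs. A second node with NODE_RANK=1 and an illustrative MASTER_ADDR of 10.0.0.1 would instead join with --address=10.0.0.1:6379.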

Step 2: Define the batch inference driver

batch_inference.py builds a Ray Dataset of prompts, configures a vLLM processor with ray.data.llm, and writes the results. concurrency is the number of vLLM replicas Ray Data runs in parallel. Setting it to the cluster's GPU count gives one replica per GPU, so the prompts are processed across every GPU at once and the example scales as you add nodes:

Python
import os

import ray
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig

ray.init(address="auto")  # connect to the Ray head started by the workload command
OUTPUT_PATH = os.environ["OUTPUT_PATH"]  # set in train.yaml

# Read the GPU count from the live Ray cluster so concurrency scales with the cluster.
total_gpus = int(ray.cluster_resources().get("GPU", 0))

config = vLLMEngineProcessorConfig(
    model_source="Qwen/Qwen2.5-7B-Instruct",
    engine_kwargs={"max_model_len": 4096, "tensor_parallel_size": 1},
    concurrency=total_gpus,  # one vLLM replica per GPU in the cluster
    batch_size=64,
)

processor = build_llm_processor(
    config,
    preprocess=lambda row: dict(
        messages=[{"role": "user", "content": row["instruction"]}],
        sampling_params=dict(max_tokens=256, temperature=0.7),
    ),
    postprocess=lambda row: dict(instruction=row["instruction"], output=row["generated_text"]),
)

out = processor(ds)  # ds is a Ray Dataset with an "instruction" column
out.write_parquet(OUTPUT_PATH)

preprocess turns each input row into a chat request, and postprocess keeps the columns to persist. Ray Data adds a generated_text column with the model's output. The complete script is in Full driver script at the end of this page.
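To see what each stage receives, you can call equivalent preprocess and postprocess functions on a plain dict, without Ray or a GPU. This is a minimal sketch. The two functions have the same bodies as the lambdas in the excerpt. The generated_text value is a made-up placeholder for what vLLM would return. In the real pipeline, Ray Data supplies that column.

```python
def preprocess(row: dict) -> dict:
    # Same shape as the preprocess lambda: one chat request per input row.
    return dict(
        messages=[{"role": "user", "content": row["instruction"]}],
        sampling_params=dict(max_tokens=256, temperature=0.7),
    )


def postprocess(row: dict) -> dict:
    # Keep only the columns to persist; generated_text is added by ray.data.llm.
    return dict(instruction=row["instruction"], output=row["generated_text"])


request = preprocess({"instruction": "Name three primary colors."})
print(request["messages"])
# [{'role': 'user', 'content': 'Name three primary colors.'}]

# Placeholder model output; in the real pipeline this text comes from vLLM.
result = postprocess({"instruction": "Name three primary colors.",
                      "generated_text": "Red, blue, and yellow."})
print(result)
# {'instruction': 'Name three primary colors.', 'output': 'Red, blue, and yellow.'}
```

Named functions like these can make the transformations easier to unit-test than inline lambdas. The shapes stay the same either way.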

For larger models, set tensor_parallel_size to shard one replica across several GPUs, and divide total_gpus by that value so the replicas still fill the cluster, for example concurrency=total_gpus // 2 with tensor_parallel_size=2.
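The rule is integer division of the GPU count by the tensor-parallel size. The hypothetical replica_count helper below keeps the two values in sync and fails early if the cluster can't hold even one replica. It is a sketch for illustration, not part of ray.data.llm.

```python
def replica_count(total_gpus: int, tensor_parallel_size: int = 1) -> int:
    """Number of vLLM replicas that fill the cluster when each replica spans
    tensor_parallel_size GPUs. Hypothetical helper for illustration."""
    if tensor_parallel_size < 1:
        raise ValueError("tensor_parallel_size must be >= 1")
    replicas = total_gpus // tensor_parallel_size
    if replicas == 0:
        raise ValueError(
            f"{total_gpus} GPU(s) can't hold a replica with tensor_parallel_size={tensor_parallel_size}"
        )
    leftover = total_gpus % tensor_parallel_size
    if leftover:
        # GPUs that don't divide evenly sit idle; warn instead of failing.
        print(f"warning: {leftover} GPU(s) left unused")
    return replicas
```

On the 8-GPU node, replica_count(total_gpus, 2) gives 4 replicas. You would pair it with "tensor_parallel_size": 2 in engine_kwargs, so each replica is sharded across two GPUs.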

Step 3: Submit the run

Bash
air run -f train.yaml --dry-run
air run -f train.yaml --watch

Step 4: Inspect the run

Bash
air get run <run-id>
air logs <run-id>

The logs show the vLLM engine's prompt and generation throughput as the batch runs. After the output is written, the driver prints a line of the form Wrote <n> rows to <OUTPUT_PATH>, followed by two sample INSTRUCTION/OUTPUT pairs.

Where results land

The driver writes one Parquet dataset (a directory of one or more Parquet files) to the OUTPUT_PATH volume, with an instruction column and an output column. Read it back with Spark or pandas, for example spark.read.parquet(OUTPUT_PATH).

Full driver script

The complete batch_inference.py for copy-paste:

Python
#!/usr/bin/env python3
"""Offline batch inference with Ray Data + vLLM on a single 8x H100 node.

The workload `command` starts a Ray head with 8 GPUs and runs this script. Ray Data's
LLM API (`ray.data.llm`) launches one vLLM replica per GPU and streams a dataset of
prompts through them, then writes the generated text to a Unity Catalog volume as
Parquet.

Uses a public model (no Hugging Face token required) so the example runs as-is.
"""

import os

import ray
from datasets import load_dataset
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig

MODEL_SOURCE = "Qwen/Qwen2.5-7B-Instruct"
NUM_PROMPTS = 1000
# Unity Catalog volume path where results land as Parquet. Set this in train.yaml.
OUTPUT_PATH = os.environ.get("OUTPUT_PATH", "/Volumes/main/default/air_examples/ray_batch_inference")


def build_prompts():
    """Build a Ray Dataset of prompts from a public instruction dataset."""
    raw = load_dataset("tatsu-lab/alpaca", split=f"train[:{NUM_PROMPTS}]")
    items = []
    for row in raw:
        instruction = row["instruction"]
        # Some Alpaca rows carry extra context in "input"; append it to the prompt.
        if row.get("input"):
            instruction = f"{instruction}\n\n{row['input']}"
        items.append({"instruction": instruction})
    return ray.data.from_items(items)


def main():
    ray.init(address="auto")
    # Derive replicas from the live cluster so the example scales when nodes are added.
    total_gpus = int(ray.cluster_resources().get("GPU", 0))
    if total_gpus == 0:
        raise RuntimeError("No GPUs registered with the Ray cluster; check the `ray start --num-gpus` call.")
    print(f"Ray cluster ready: {total_gpus} GPU(s)", flush=True)

    ds = build_prompts()

    # vLLM engine config. concurrency = number of replicas Ray Data runs in parallel;
    # one per GPU in the cluster here. engine_kwargs are passed through to the vLLM engine.
    config = vLLMEngineProcessorConfig(
        model_source=MODEL_SOURCE,
        engine_kwargs={
            "max_model_len": 4096,
            "tensor_parallel_size": 1,
            "enable_chunked_prefill": True,
        },
        concurrency=total_gpus,
        batch_size=64,
    )

    # preprocess maps each input row to a chat request; postprocess keeps the columns
    # we want to persist. ray.data.llm adds a `generated_text` column.
    processor = build_llm_processor(
        config,
        preprocess=lambda row: dict(
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": row["instruction"]},
            ],
            sampling_params=dict(max_tokens=256, temperature=0.7),
        ),
        postprocess=lambda row: dict(
            instruction=row["instruction"],
            output=row["generated_text"],
        ),
    )

    # Materialize once so the write and the sample print don't re-run inference.
    out = processor(ds).materialize()
    out.write_parquet(OUTPUT_PATH)
    print(f"Wrote {out.count()} rows to {OUTPUT_PATH}", flush=True)

    for row in out.take(2):
        print("INSTRUCTION:", row["instruction"][:120], flush=True)
        print("OUTPUT:", row["output"][:200], flush=True)

    ray.shutdown()


if __name__ == "__main__":
    main()

Next steps