Ray Data と vLLM を使用したバッチ推論
ベータ版
この機能はベータ版です。ワークスペース管理者は、 プレビュー ページからこの機能へのアクセスを制御できます。Databricksのプレビューを管理するを参照してください。
この例では、Ray Data および vLLM を使用して、単一ノード上の 8 個の H100 GPU でオフライン LLM バッチ推論が実行されます。ブートストラップ スクリプトがノード上で Ray クラスターを起動し、その後、ドライバーは Ray Data の LLM API (ray.data.llm) を使用して、GPU ごとに 1 つの vLLM レプリカを起動し、プロンプトのデータセットをそれらを介してストリームし、生成されたテキストを Parquet として Unity Catalog ボリュームに書き込みます。
パブリック モデル (Qwen2.5-7B-Instruct) を使用します。そのため、Hugging Face トークンなしでそのまま実行できます。
ワークロードでは、次の処理を行います:
- ローカルプロジェクトを
code_source: snapshotでアップロードします。 - 8 個の GPU すべてを使用して Ray ヘッドを起動し、その後バッチ推論ドライバーを実行します。
ray.data.llmを使用して、GPU ごとに 1 つの vLLM レプリカを実行し、プロンプトを並列処理します。- プロンプトと生成された出力を Parquet として Unity Catalog ボリュームに書き込みます。
前提条件
airCLI がインストールされ、認証されています。「AI Runtime CLI のインストール」を参照してください。- 書き込み可能な Unity Catalog ボリューム。以下でワークロード YAML 内のそのパスを設定します。
プロジェクト レイアウト
次のファイルを含むディレクトリを作成します。
ray_batch_inference/
├── train.yaml # air workload config (inline dependencies + Ray bootstrap)
└── batch_inference.py # Ray Data + vLLM batch inference driver
ステップ 1: ワークロード YAML を記述する
train.yaml 単一のGPU_8xH100ノードを必要とします。environment の下にインラインで依存関係が宣言され(クライアントイメージ version を使用)、command がノード上で Ray クラスターを起動してドライバーを実行するため、ワークロードには個別の依存関係ファイルやランチャースクリプトは不要です。
vLLM はベース イメージに含まれていないため、GPU ノードに必要な 3 つのピン留め対象と一緒にインラインでインストールされます:hf_transfer(ベース イメージは Hugging Face の高速ダウンロードを有効にし、このパッケージを想定しています)、新しい fsspec(ベース イメージには古いバージョンが含まれており、ダウンロードが失敗します)、およびピン留めされた opencv-python-headless(vLLM により OpenCV が導入されますが、そのデフォルトの wheel は GPU ノード上で OpenSSL FIPS 自己テストをクラッシュさせます)。
OUTPUT_PATH を書き込み可能な Unity Catalog ボリュームに設定します。
experiment_name: air-ray-batch-inference
environment:
version: '4'
dependencies:
- ray[data]>=2.44
- vllm
- datasets>=3.0
- huggingface_hub>=0.34
# The base image sets HF_HUB_ENABLE_HF_TRANSFER=1; install the package it expects
# so model and dataset downloads don't error out.
- hf_transfer
# The base image ships fsspec 2023.5.0, which is too old for modern
# huggingface_hub and breaks dataset/model downloads. Pin a newer fsspec.
- fsspec>=2024.6.1
# vLLM pulls in opencv; its default wheel crashes the OpenSSL FIPS self-test
# on the GPU nodes. This pinned headless build avoids the crash.
- opencv-python-headless==4.12.0.88
# 8 H100 on a single node. Ray Data runs one vLLM replica per GPU.
compute:
num_accelerators: 8
accelerator_type: GPU_8xH100
code_source:
type: snapshot
snapshot:
root_path: .
command: |
cd $CODE_SOURCE_PATH
RAY_HEAD_PORT=6379
GPUS_PER_NODE=${LOCAL_WORLD_SIZE:-8}
if [ "${NODE_RANK:-0}" = "0" ]; then
echo "NODE_RANK=0: starting Ray head with $GPUS_PER_NODE GPU(s)..."
ray start --head --port=$RAY_HEAD_PORT --num-gpus="$GPUS_PER_NODE" --dashboard-host=0.0.0.0
python batch_inference.py
ray stop
else
echo "NODE_RANK=$NODE_RANK: connecting to Ray head at $MASTER_ADDR:$RAY_HEAD_PORT..."
for i in $(seq 1 12); do
if ray start --address="$MASTER_ADDR:$RAY_HEAD_PORT" --num-gpus="$GPUS_PER_NODE" --block 2>/dev/null; then
break
fi
echo "Attempt $i failed, retrying in 5s..."
sleep 5
done
fi
max_retries: 0
timeout_minutes: 60
env_variables:
NCCL_SOCKET_IFNAME: eth0
# Unity Catalog volume where results land as Parquet. Replace with your volume.
OUTPUT_PATH: /Volumes/main/default/air_examples/ray_batch_inference
インラインcommandは、ノード上のすべてのGPUでRayヘッドを起動し、python batch_inference.pyでドライバーを実行した後、クラスターを停止します。ヘッドを結合するワーカーブランチも含まれているため、ジョブを複数のノードに拡張した場合でも、同じコマンドが動作し続けます。
ステップ2: バッチ推論ドライバーを定義する
batch_inference.py プロンプトのRayデータセットを構築し、ray.data.llmでvLLMプロセッサを設定し、結果を書き込みます。concurrencyは、Ray Dataが並行して実行するvLLMレプリカの数です。これをクラスターのGPU数に設定すると、GPUごとに1つのレプリカが提供されます。これにより、プロンプトはすべてのGPUで一度に処理され、ノードを追加するにつれて例はスケーリングされます。
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig
# Read the GPU count from the live Ray cluster so concurrency scales with the cluster.
total_gpus = int(ray.cluster_resources().get("GPU", 0))
config = vLLMEngineProcessorConfig(
model_source="Qwen/Qwen2.5-7B-Instruct",
engine_kwargs={"max_model_len": 4096, "tensor_parallel_size": 1},
concurrency=total_gpus, # one vLLM replica per GPU in the cluster
batch_size=64,
)
processor = build_llm_processor(
config,
preprocess=lambda row: dict(
messages=[{"role": "user", "content": row["instruction"]}],
sampling_params=dict(max_tokens=256, temperature=0.7),
),
postprocess=lambda row: dict(instruction=row["instruction"], output=row["generated_text"]),
)
out = processor(ds) # ds is a Ray Dataset with an "instruction" column
out.write_parquet(OUTPUT_PATH)
preprocess 各入力行をチャットリクエストに変換し、postprocess が保持する列を維持します。Ray Data はモデルの出力を含む generated_text 列を追加します。完全なスクリプトは、このページの最後にある「Full driver script」にあります。
大規模なモデルの場合、tensor_parallel_size を設定して1つのレプリカを複数のGPUにシャーディングし、レプリカが引き続きクラスターを埋めるように、その値で total_gpus を除算します(たとえば、tensor_parallel_size=2 を使用して concurrency=total_gpus // 2)。
ステップ3:実行を送信する
air run -f train.yaml --dry-run
air run -f train.yaml --watch
ステップ 4: 実行を確認する
air get run <run-id>
air logs <run-id>
ログには、バッチの実行中に vLLM エンジンのプロンプトと生成スループットが表示され、その後出力が書き込まれると Wrote <n> rows 行が表示されます。
結果の保存場所
ドライバーは、OUTPUT_PATH ボリュームに 1 つの Parquet データセットを書き込み、instruction 列と output 列を含めます。たとえば、Spark または pandas で読み戻します。
spark.read.parquet(OUTPUT_PATH)。
完全なドライバー スクリプト
コピー/貼り付けの完全なbatch_inference.py
#!/usr/bin/env python3
"""Offline batch inference with Ray Data + vLLM on a single 8x H100 node.
The workload `command` starts a Ray head with 8 GPUs and runs this script. Ray Data's
LLM API (`ray.data.llm`) launches one vLLM replica per GPU and streams a dataset of
prompts through them, then writes the generated text to a Unity Catalog volume as
Parquet.
Uses a public model (no Hugging Face token required) so the example runs as-is.
"""
import os
import ray
from datasets import load_dataset
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig
MODEL_SOURCE = "Qwen/Qwen2.5-7B-Instruct"
NUM_PROMPTS = 1000
# Unity Catalog volume path where results land as Parquet. Set this in train.yaml.
OUTPUT_PATH = os.environ.get("OUTPUT_PATH", "/Volumes/main/default/air_examples/ray_batch_inference")
def build_prompts():
"""Build a Ray Dataset of prompts from a public instruction dataset."""
raw = load_dataset("tatsu-lab/alpaca", split=f"train[:{NUM_PROMPTS}]")
items = []
for row in raw:
instruction = row["instruction"]
if row.get("input"):
instruction = f"{instruction}\n\n{row['input']}"
items.append({"instruction": instruction})
return ray.data.from_items(items)
def main():
ray.init(address="auto")
# Derive replicas from the live cluster so the example scales when nodes are added.
total_gpus = int(ray.cluster_resources().get("GPU", 0))
print(f"Ray cluster ready: {total_gpus} GPU(s)", flush=True)
ds = build_prompts()
# vLLM engine config. concurrency = number of replicas Ray Data runs in parallel;
# one per GPU in the cluster here. engine_kwargs are passed through to the vLLM engine.
config = vLLMEngineProcessorConfig(
model_source=MODEL_SOURCE,
engine_kwargs={
"max_model_len": 4096,
"tensor_parallel_size": 1,
"enable_chunked_prefill": True,
},
concurrency=total_gpus,
batch_size=64,
)
# preprocess maps each input row to a chat request; postprocess keeps the columns
# we want to persist. ray.data.llm adds a `generated_text` column.
processor = build_llm_processor(
config,
preprocess=lambda row: dict(
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": row["instruction"]},
],
sampling_params=dict(max_tokens=256, temperature=0.7),
),
postprocess=lambda row: dict(
instruction=row["instruction"],
output=row["generated_text"],
),
)
# materialize once so the write and the sample print don't re-run inference.
out = processor(ds).materialize()
out.write_parquet(OUTPUT_PATH)
print(f"Wrote {out.count()} rows to {OUTPUT_PATH}", flush=True)
for row in out.take(2):
print("INSTRUCTION:", row["instruction"][:120], flush=True)
print("OUTPUT:", row["output"][:200], flush=True)
ray.shutdown()
if __name__ == "__main__":
main()