メインコンテンツまでスキップ

Ray Data と vLLM を使用したバッチ推論

備考

ベータ版

この機能はベータ版です。ワークスペース管理者は、 プレビュー ページからこの機能へのアクセスを制御できます。Databricksのプレビューを管理するを参照してください。

この例では、Ray Data および vLLM を使用して、単一ノード上の 8 個の H100 GPU でオフライン LLM バッチ推論が実行されます。ブートストラップ スクリプトがノード上で Ray クラスターを起動し、その後、ドライバーは Ray Data の LLM API (ray.data.llm) を使用して、GPU ごとに 1 つの vLLM レプリカを起動し、プロンプトのデータセットをそれらを介してストリームし、生成されたテキストを Parquet として Unity Catalog ボリュームに書き込みます。

パブリック モデル (Qwen2.5-7B-Instruct) を使用します。そのため、Hugging Face トークンなしでそのまま実行できます。

ワークロードでは、次の処理を行います:

  • ローカルプロジェクトを code_source: snapshot でアップロードします。
  • 8 個の GPU すべてを使用して Ray ヘッドを起動し、その後バッチ推論ドライバーを実行します。
  • ray.data.llm を使用して、GPU ごとに 1 つの vLLM レプリカを実行し、プロンプトを並列処理します。
  • プロンプトと生成された出力を Parquet として Unity Catalog ボリュームに書き込みます。

前提条件

  • air CLI がインストールされ、認証されています。「AI Runtime CLI のインストール」を参照してください。
  • 書き込み可能な Unity Catalog ボリューム。以下でワークロード YAML 内のそのパスを設定します。

プロジェクト レイアウト

次のファイルを含むディレクトリを作成します。

Text
ray_batch_inference/
├── train.yaml # air workload config (inline dependencies + Ray bootstrap)
└── batch_inference.py # Ray Data + vLLM batch inference driver

ステップ 1: ワークロード YAML を記述する

train.yaml 単一のGPU_8xH100ノードを必要とします。environment の下にインラインで依存関係が宣言され(クライアントイメージ version を使用)、command がノード上で Ray クラスターを起動してドライバーを実行するため、ワークロードには個別の依存関係ファイルやランチャースクリプトは不要です。

vLLM はベース イメージに含まれていないため、GPU ノードに必要な 3 つのピン留め対象と一緒にインラインでインストールされます:hf_transfer(ベース イメージは Hugging Face の高速ダウンロードを有効にし、このパッケージを想定しています)、新しい fsspec(ベース イメージには古いバージョンが含まれており、ダウンロードが失敗します)、およびピン留めされた opencv-python-headless(vLLM により OpenCV が導入されますが、そのデフォルトの wheel は GPU ノード上で OpenSSL FIPS 自己テストをクラッシュさせます)。

OUTPUT_PATH を書き込み可能な Unity Catalog ボリュームに設定します。

YAML
experiment_name: air-ray-batch-inference

environment:
version: '4'
dependencies:
- ray[data]>=2.44
- vllm
- datasets>=3.0
- huggingface_hub>=0.34
# The base image sets HF_HUB_ENABLE_HF_TRANSFER=1; install the package it expects
# so model and dataset downloads don't error out.
- hf_transfer
# The base image ships fsspec 2023.5.0, which is too old for modern
# huggingface_hub and breaks dataset/model downloads. Pin a newer fsspec.
- fsspec>=2024.6.1
# vLLM pulls in opencv; its default wheel crashes the OpenSSL FIPS self-test
# on the GPU nodes. This pinned headless build avoids the crash.
- opencv-python-headless==4.12.0.88

# 8 H100 on a single node. Ray Data runs one vLLM replica per GPU.
compute:
num_accelerators: 8
accelerator_type: GPU_8xH100

code_source:
type: snapshot
snapshot:
root_path: .

command: |
cd $CODE_SOURCE_PATH
RAY_HEAD_PORT=6379
GPUS_PER_NODE=${LOCAL_WORLD_SIZE:-8}
if [ "${NODE_RANK:-0}" = "0" ]; then
echo "NODE_RANK=0: starting Ray head with $GPUS_PER_NODE GPU(s)..."
ray start --head --port=$RAY_HEAD_PORT --num-gpus="$GPUS_PER_NODE" --dashboard-host=0.0.0.0
python batch_inference.py
ray stop
else
echo "NODE_RANK=$NODE_RANK: connecting to Ray head at $MASTER_ADDR:$RAY_HEAD_PORT..."
for i in $(seq 1 12); do
if ray start --address="$MASTER_ADDR:$RAY_HEAD_PORT" --num-gpus="$GPUS_PER_NODE" --block 2>/dev/null; then
break
fi
echo "Attempt $i failed, retrying in 5s..."
sleep 5
done
fi

max_retries: 0
timeout_minutes: 60
env_variables:
NCCL_SOCKET_IFNAME: eth0
# Unity Catalog volume where results land as Parquet. Replace with your volume.
OUTPUT_PATH: /Volumes/main/default/air_examples/ray_batch_inference

インラインcommandは、ノード上のすべてのGPUでRayヘッドを起動し、python batch_inference.pyでドライバーを実行した後、クラスターを停止します。ヘッドを結合するワーカーブランチも含まれているため、ジョブを複数のノードに拡張した場合でも、同じコマンドが動作し続けます。

ステップ2: バッチ推論ドライバーを定義する

batch_inference.py プロンプトのRayデータセットを構築し、ray.data.llmでvLLMプロセッサを設定し、結果を書き込みます。concurrencyは、Ray Dataが並行して実行するvLLMレプリカの数です。これをクラスターのGPU数に設定すると、GPUごとに1つのレプリカが提供されます。これにより、プロンプトはすべてのGPUで一度に処理され、ノードを追加するにつれて例はスケーリングされます。

Python
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig

# Read the GPU count from the live Ray cluster so concurrency scales with the cluster.
total_gpus = int(ray.cluster_resources().get("GPU", 0))

config = vLLMEngineProcessorConfig(
model_source="Qwen/Qwen2.5-7B-Instruct",
engine_kwargs={"max_model_len": 4096, "tensor_parallel_size": 1},
concurrency=total_gpus, # one vLLM replica per GPU in the cluster
batch_size=64,
)

processor = build_llm_processor(
config,
preprocess=lambda row: dict(
messages=[{"role": "user", "content": row["instruction"]}],
sampling_params=dict(max_tokens=256, temperature=0.7),
),
postprocess=lambda row: dict(instruction=row["instruction"], output=row["generated_text"]),
)

out = processor(ds) # ds is a Ray Dataset with an "instruction" column
out.write_parquet(OUTPUT_PATH)

preprocess 各入力行をチャットリクエストに変換し、postprocess が保持する列を維持します。Ray Data はモデルの出力を含む generated_text 列を追加します。完全なスクリプトは、このページの最後にある「Full driver script」にあります。

大規模なモデルの場合、tensor_parallel_size を設定して1つのレプリカを複数のGPUにシャーディングし、レプリカが引き続きクラスターを埋めるように、その値で total_gpus を除算します(たとえば、tensor_parallel_size=2 を使用して concurrency=total_gpus // 2)。

ステップ3:実行を送信する

Bash
air run -f train.yaml --dry-run
air run -f train.yaml --watch

ステップ 4: 実行を確認する

Bash
air get run <run-id>
air logs <run-id>

ログには、バッチの実行中に vLLM エンジンのプロンプトと生成スループットが表示され、その後出力が書き込まれると Wrote <n> rows 行が表示されます。

結果の保存場所

ドライバーは、OUTPUT_PATH ボリュームに 1 つの Parquet データセットを書き込み、instruction 列と output 列を含めます。たとえば、Spark または pandas で読み戻します。 spark.read.parquet(OUTPUT_PATH)

完全なドライバー スクリプト

コピー/貼り付けの完全なbatch_inference.py

Python
#!/usr/bin/env python3
"""Offline batch inference with Ray Data + vLLM on a single 8x H100 node.

The workload `command` starts a Ray head with 8 GPUs and runs this script. Ray Data's
LLM API (`ray.data.llm`) launches one vLLM replica per GPU and streams a dataset of
prompts through them, then writes the generated text to a Unity Catalog volume as
Parquet.

Uses a public model (no Hugging Face token required) so the example runs as-is.
"""

import os

import ray
from datasets import load_dataset
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig

MODEL_SOURCE = "Qwen/Qwen2.5-7B-Instruct"
NUM_PROMPTS = 1000
# Unity Catalog volume path where results land as Parquet. Set this in train.yaml.
OUTPUT_PATH = os.environ.get("OUTPUT_PATH", "/Volumes/main/default/air_examples/ray_batch_inference")


def build_prompts():
"""Build a Ray Dataset of prompts from a public instruction dataset."""
raw = load_dataset("tatsu-lab/alpaca", split=f"train[:{NUM_PROMPTS}]")
items = []
for row in raw:
instruction = row["instruction"]
if row.get("input"):
instruction = f"{instruction}\n\n{row['input']}"
items.append({"instruction": instruction})
return ray.data.from_items(items)


def main():
ray.init(address="auto")
# Derive replicas from the live cluster so the example scales when nodes are added.
total_gpus = int(ray.cluster_resources().get("GPU", 0))
print(f"Ray cluster ready: {total_gpus} GPU(s)", flush=True)

ds = build_prompts()

# vLLM engine config. concurrency = number of replicas Ray Data runs in parallel;
# one per GPU in the cluster here. engine_kwargs are passed through to the vLLM engine.
config = vLLMEngineProcessorConfig(
model_source=MODEL_SOURCE,
engine_kwargs={
&quot;max_model_len&quot;: 4096,
&quot;tensor_parallel_size&quot;: 1,
&quot;enable_chunked_prefill&quot;: True,
},
concurrency=total_gpus,
batch_size=64,
)

# preprocess maps each input row to a chat request; postprocess keeps the columns
# we want to persist. ray.data.llm adds a `generated_text` column.
processor = build_llm_processor(
config,
preprocess=lambda row: dict(
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": row["instruction"]},
],
sampling_params=dict(max_tokens=256, temperature=0.7),
),
postprocess=lambda row: dict(
instruction=row["instruction"],
output=row["generated_text"],
),
)

# materialize once so the write and the sample print don't re-run inference.
out = processor(ds).materialize()
out.write_parquet(OUTPUT_PATH)
print(f"Wrote {out.count()} rows to {OUTPUT_PATH}", flush=True)

for row in out.take(2):
print("INSTRUCTION:", row["instruction"][:120], flush=True)
print("OUTPUT:", row["output"][:200], flush=True)

ray.shutdown()


if __name__ == "__main__":
main()

次のステップ