Inferência em lote com Ray Data e vLLM
Beta
Este recurso está em Beta. Os administradores do espaço de trabalho podem controlar o acesso a esse recurso na página Pré-visualizações . Consulte Gerenciar prévias do Databricks.
Este exemplo executa inferência de lotes de LLM offline com Ray Data e vLLM em 8 GPUs H100 em um único nó. Um script de bootstrap inicia um cluster Ray no nó, então o driver usa a API LLM do Ray Data (ray.data.llm) para iniciar uma réplica vLLM por GPU e transmitir um dataset de prompts por meio deles, gravando o texto gerado em um volume do Unity Catalog como Parquet.
Ele usa um modelo público (Qwen2.5-7B-Instruct), assim, ele é executado como está sem um token do Hugging Face.
A carga de trabalho executa as seguintes ações:
- Faz upload do projeto local com
code_source: snapshot. - Inicia um cabeçalho Ray com todas as 8 GPUs e, em seguida, executa o driver de inferência em lote.
- Utiliza
ray.data.llmpara executar uma réplica de vLLM por GPU e processar prompts em paralelo. - Grava os prompts e as saídas geradas em um volume do Unity Catalog como Parquet.
Pré-requisitos
airCLI instalada e autenticada. Consulte Instalar a CLI do Runtime de AI.- Um volume do Unity Catalog gravável. Você define o caminho no YAML de carga de trabalho abaixo.
Disposição do projeto
Criar um diretório com os seguintes arquivos.
ray_batch_inference/
├── train.yaml # air workload config (inline dependencies + Ray bootstrap)
└── batch_inference.py # Ray Data + vLLM batch inference driver
O passo 1: Escreva a carga de trabalho YAML
train.yaml solicita um nó único de GPU_8xH100. Dependências são declaradas em linha em environment (com a imagem do cliente version), e o command inicia um cluster Ray no nó e, em seguida, executa o driver, assim, a carga de trabalho não precisa de um arquivo de dependência ou script iniciador separado.
O vLLM não está na imagem base, então é instalado em linha junto com três pins que os nós da GPU
precisam: hf_transfer (a imagem base permite downloads rápidos do Hugging Face e espera este
pacote), um fsspec mais recente (a imagem base vem com um antigo que impede downloads) e um
opencv-python-headless fixado (o vLLM incorpora o OpenCV, cujo wheel default causa a falha do
auto-teste OpenSSL FIPS nos nós da GPU).
Defina OUTPUT_PATH para um volume do Unity Catalog no qual você pode escrever.
experiment_name: air-ray-batch-inference
environment:
version: '4'
dependencies:
- ray[data]>=2.44
- vllm
- datasets>=3.0
- huggingface_hub>=0.34
# The base image sets HF_HUB_ENABLE_HF_TRANSFER=1; install the package it expects
# so model and dataset downloads don't error out.
- hf_transfer
# The base image ships fsspec 2023.5.0, which is too old for modern
# huggingface_hub and breaks dataset/model downloads. Pin a newer fsspec.
- fsspec>=2024.6.1
# vLLM pulls in opencv; its default wheel crashes the OpenSSL FIPS self-test
# on the GPU nodes. This pinned headless build avoids the crash.
- opencv-python-headless==4.12.0.88
# 8 H100 on a single node. Ray Data runs one vLLM replica per GPU.
compute:
num_accelerators: 8
accelerator_type: GPU_8xH100
code_source:
type: snapshot
snapshot:
root_path: .
command: |
cd $CODE_SOURCE_PATH
RAY_HEAD_PORT=6379
GPUS_PER_NODE=${LOCAL_WORLD_SIZE:-8}
if [ "${NODE_RANK:-0}" = "0" ]; then
echo "NODE_RANK=0: starting Ray head with $GPUS_PER_NODE GPU(s)..."
ray start --head --port=$RAY_HEAD_PORT --num-gpus="$GPUS_PER_NODE" --dashboard-host=0.0.0.0
python batch_inference.py
ray stop
else
echo "NODE_RANK=$NODE_RANK: connecting to Ray head at $MASTER_ADDR:$RAY_HEAD_PORT..."
for i in $(seq 1 12); do
if ray start --address="$MASTER_ADDR:$RAY_HEAD_PORT" --num-gpus="$GPUS_PER_NODE" --block 2>/dev/null; then
break
fi
echo "Attempt $i failed, retrying in 5s..."
sleep 5
done
fi
max_retries: 0
timeout_minutes: 60
env_variables:
NCCL_SOCKET_IFNAME: eth0
# Unity Catalog volume where results land as Parquet. Replace with your volume.
OUTPUT_PATH: /Volumes/main/default/air_examples/ray_batch_inference
O command em linha inicia um head do Ray com todas as GPUs no nó, executa o driver com python batch_inference.py e para o cluster. Também inclui um branch worker que une o cabeçalho, então o mesmo comando continua funcionando se você dimensionar o Job para vários nós.
Passo 2: Definir o driver de inferência em lote
batch_inference.py cria um Dataset Ray de prompts, configura um processador vLLM com ray.data.llm e grava os resultados. concurrency é o número de réplicas vLLM que o Ray Data executa em paralelo. Defini-lo como a contagem de GPU do cluster dá uma réplica por GPU, então os prompts são processados em todas as GPUs de uma só vez e o exemplo é escalado conforme você adiciona nós:
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig
# Read the GPU count from the live Ray cluster so concurrency scales with the cluster.
total_gpus = int(ray.cluster_resources().get("GPU", 0))
config = vLLMEngineProcessorConfig(
model_source="Qwen/Qwen2.5-7B-Instruct",
engine_kwargs={"max_model_len": 4096, "tensor_parallel_size": 1},
concurrency=total_gpus, # one vLLM replica per GPU in the cluster
batch_size=64,
)
processor = build_llm_processor(
config,
preprocess=lambda row: dict(
messages=[{"role": "user", "content": row["instruction"]}],
sampling_params=dict(max_tokens=256, temperature=0.7),
),
postprocess=lambda row: dict(instruction=row["instruction"], output=row["generated_text"]),
)
out = processor(ds) # ds is a Ray Dataset with an "instruction" column
out.write_parquet(OUTPUT_PATH)
preprocess transforma cada linha de entrada em uma solicitação de chat, e postprocess mantém as colunas a serem persistidas. Ray Data adiciona uma coluna generated_text com a saída do modelo. O script completo está em Script completo do driver no final desta página.
Para modelos maiores, defina tensor_parallel_size para fragmentar uma réplica entre várias GPUs e divida total_gpus por esse valor para que as réplicas ainda preencham o cluster, por exemplo, concurrency=total_gpus // 2 com tensor_parallel_size=2.
Passo 3: Enviar a execução
air run -f train.yaml --dry-run
air run -f train.yaml --watch
o passo 4: inspeção da execução
air get run <run-id>
air logs <run-id>
Os logs mostram o prompt e a taxa de transferência de geração do motor vLLM enquanto o lote é executado, então uma linha Wrote <n> rows quando a saída é gravada.
Onde os resultados são armazenados
O driver grava um dataset Parquet no volume OUTPUT_PATH, com uma coluna instruction
e uma coluna output. Leia de volta com Spark ou Pandas, por exemplo
spark.read.parquet(OUTPUT_PATH).
Script de driver completo
O batch_inference.py completo para copiar e colar:
#!/usr/bin/env python3
"""Offline batch inference with Ray Data + vLLM on a single 8x H100 node.
The workload `command` starts a Ray head with 8 GPUs and runs this script. Ray Data's
LLM API (`ray.data.llm`) launches one vLLM replica per GPU and streams a dataset of
prompts through them, then writes the generated text to a Unity Catalog volume as
Parquet.
Uses a public model (no Hugging Face token required) so the example runs as-is.
"""
import os
import ray
from datasets import load_dataset
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig
MODEL_SOURCE = "Qwen/Qwen2.5-7B-Instruct"
NUM_PROMPTS = 1000
# Unity Catalog volume path where results land as Parquet. Set this in train.yaml.
OUTPUT_PATH = os.environ.get("OUTPUT_PATH", "/Volumes/main/default/air_examples/ray_batch_inference")
def build_prompts():
"""Build a Ray Dataset of prompts from a public instruction dataset."""
raw = load_dataset("tatsu-lab/alpaca", split=f"train[:{NUM_PROMPTS}]")
items = []
for row in raw:
instruction = row["instruction"]
if row.get("input"):
instruction = f"{instruction}\n\n{row['input']}"
items.append({"instruction": instruction})
return ray.data.from_items(items)
def main():
ray.init(address="auto")
# Derive replicas from the live cluster so the example scales when nodes are added.
total_gpus = int(ray.cluster_resources().get("GPU", 0))
print(f"Ray cluster ready: {total_gpus} GPU(s)", flush=True)
ds = build_prompts()
# vLLM engine config. concurrency = number of replicas Ray Data runs in parallel;
# one per GPU in the cluster here. engine_kwargs are passed through to the vLLM engine.
config = vLLMEngineProcessorConfig(
model_source=MODEL_SOURCE,
engine_kwargs={
"max_model_len": 4096,
"tensor_parallel_size": 1,
"enable_chunked_prefill": True,
},
concurrency=total_gpus,
batch_size=64,
)
# preprocess maps each input row to a chat request; postprocess keeps the columns
# we want to persist. ray.data.llm adds a `generated_text` column.
processor = build_llm_processor(
config,
preprocess=lambda row: dict(
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": row["instruction"]},
],
sampling_params=dict(max_tokens=256, temperature=0.7),
),
postprocess=lambda row: dict(
instruction=row["instruction"],
output=row["generated_text"],
),
)
# materialize once so the write and the sample print don't re-run inference.
out = processor(ds).materialize()
out.write_parquet(OUTPUT_PATH)
print(f"Wrote {out.count()} rows to {OUTPUT_PATH}", flush=True)
for row in out.take(2):
print("INSTRUCTION:", row["instruction"][:120], flush=True)
print("OUTPUT:", row["output"][:200], flush=True)
ray.shutdown()
if __name__ == "__main__":
main()