Pular para o conteúdo principal

Inferência em lote com Ray Data e vLLM

info

Beta

Este recurso está em Beta. Os administradores do espaço de trabalho podem controlar o acesso a esse recurso na página Pré-visualizações . Consulte Gerenciar prévias do Databricks.

Este exemplo executa inferência de lotes de LLM offline com Ray Data e vLLM em 8 GPUs H100 em um único nó. Um script de bootstrap inicia um cluster Ray no nó, então o driver usa a API LLM do Ray Data (ray.data.llm) para iniciar uma réplica vLLM por GPU e transmitir um dataset de prompts por meio deles, gravando o texto gerado em um volume do Unity Catalog como Parquet.

Ele usa um modelo público (Qwen2.5-7B-Instruct), assim, ele é executado como está sem um token do Hugging Face.

A carga de trabalho executa as seguintes ações:

  • Faz upload do projeto local com code_source: snapshot.
  • Inicia um cabeçalho Ray com todas as 8 GPUs e, em seguida, executa o driver de inferência em lote.
  • Utiliza ray.data.llm para executar uma réplica de vLLM por GPU e processar prompts em paralelo.
  • Grava os prompts e as saídas geradas em um volume do Unity Catalog como Parquet.

Pré-requisitos

  • air CLI instalada e autenticada. Consulte Instalar a CLI do Runtime de AI.
  • Um volume do Unity Catalog gravável. Você define o caminho no YAML de carga de trabalho abaixo.

Disposição do projeto

Criar um diretório com os seguintes arquivos.

Text
ray_batch_inference/
├── train.yaml # air workload config (inline dependencies + Ray bootstrap)
└── batch_inference.py # Ray Data + vLLM batch inference driver

O passo 1: Escreva a carga de trabalho YAML

train.yaml solicita um nó único de GPU_8xH100. Dependências são declaradas em linha em environment (com a imagem do cliente version), e o command inicia um cluster Ray no nó e, em seguida, executa o driver, assim, a carga de trabalho não precisa de um arquivo de dependência ou script iniciador separado.

O vLLM não está na imagem base, então é instalado em linha junto com três pins que os nós da GPU precisam: hf_transfer (a imagem base permite downloads rápidos do Hugging Face e espera este pacote), um fsspec mais recente (a imagem base vem com um antigo que impede downloads) e um opencv-python-headless fixado (o vLLM incorpora o OpenCV, cujo wheel default causa a falha do auto-teste OpenSSL FIPS nos nós da GPU).

Defina OUTPUT_PATH para um volume do Unity Catalog no qual você pode escrever.

YAML
experiment_name: air-ray-batch-inference

environment:
version: '4'
dependencies:
- ray[data]>=2.44
- vllm
- datasets>=3.0
- huggingface_hub>=0.34
# The base image sets HF_HUB_ENABLE_HF_TRANSFER=1; install the package it expects
# so model and dataset downloads don't error out.
- hf_transfer
# The base image ships fsspec 2023.5.0, which is too old for modern
# huggingface_hub and breaks dataset/model downloads. Pin a newer fsspec.
- fsspec>=2024.6.1
# vLLM pulls in opencv; its default wheel crashes the OpenSSL FIPS self-test
# on the GPU nodes. This pinned headless build avoids the crash.
- opencv-python-headless==4.12.0.88

# 8 H100 on a single node. Ray Data runs one vLLM replica per GPU.
compute:
num_accelerators: 8
accelerator_type: GPU_8xH100

code_source:
type: snapshot
snapshot:
root_path: .

command: |
cd $CODE_SOURCE_PATH
RAY_HEAD_PORT=6379
GPUS_PER_NODE=${LOCAL_WORLD_SIZE:-8}
if [ "${NODE_RANK:-0}" = "0" ]; then
echo "NODE_RANK=0: starting Ray head with $GPUS_PER_NODE GPU(s)..."
ray start --head --port=$RAY_HEAD_PORT --num-gpus="$GPUS_PER_NODE" --dashboard-host=0.0.0.0
python batch_inference.py
ray stop
else
echo "NODE_RANK=$NODE_RANK: connecting to Ray head at $MASTER_ADDR:$RAY_HEAD_PORT..."
for i in $(seq 1 12); do
if ray start --address="$MASTER_ADDR:$RAY_HEAD_PORT" --num-gpus="$GPUS_PER_NODE" --block 2>/dev/null; then
break
fi
echo "Attempt $i failed, retrying in 5s..."
sleep 5
done
fi

max_retries: 0
timeout_minutes: 60
env_variables:
NCCL_SOCKET_IFNAME: eth0
# Unity Catalog volume where results land as Parquet. Replace with your volume.
OUTPUT_PATH: /Volumes/main/default/air_examples/ray_batch_inference

O command em linha inicia um head do Ray com todas as GPUs no nó, executa o driver com python batch_inference.py e para o cluster. Também inclui um branch worker que une o cabeçalho, então o mesmo comando continua funcionando se você dimensionar o Job para vários nós.

Passo 2: Definir o driver de inferência em lote

batch_inference.py cria um Dataset Ray de prompts, configura um processador vLLM com ray.data.llm e grava os resultados. concurrency é o número de réplicas vLLM que o Ray Data executa em paralelo. Defini-lo como a contagem de GPU do cluster dá uma réplica por GPU, então os prompts são processados em todas as GPUs de uma só vez e o exemplo é escalado conforme você adiciona nós:

Python
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig

# Read the GPU count from the live Ray cluster so concurrency scales with the cluster.
total_gpus = int(ray.cluster_resources().get("GPU", 0))

config = vLLMEngineProcessorConfig(
model_source="Qwen/Qwen2.5-7B-Instruct",
engine_kwargs={"max_model_len": 4096, "tensor_parallel_size": 1},
concurrency=total_gpus, # one vLLM replica per GPU in the cluster
batch_size=64,
)

processor = build_llm_processor(
config,
preprocess=lambda row: dict(
messages=[{"role": "user", "content": row["instruction"]}],
sampling_params=dict(max_tokens=256, temperature=0.7),
),
postprocess=lambda row: dict(instruction=row["instruction"], output=row["generated_text"]),
)

out = processor(ds) # ds is a Ray Dataset with an "instruction" column
out.write_parquet(OUTPUT_PATH)

preprocess transforma cada linha de entrada em uma solicitação de chat, e postprocess mantém as colunas a serem persistidas. Ray Data adiciona uma coluna generated_text com a saída do modelo. O script completo está em Script completo do driver no final desta página.

Para modelos maiores, defina tensor_parallel_size para fragmentar uma réplica entre várias GPUs e divida total_gpus por esse valor para que as réplicas ainda preencham o cluster, por exemplo, concurrency=total_gpus // 2 com tensor_parallel_size=2.

Passo 3: Enviar a execução

Bash
air run -f train.yaml --dry-run
air run -f train.yaml --watch

o passo 4: inspeção da execução

Bash
air get run <run-id>
air logs <run-id>

Os logs mostram o prompt e a taxa de transferência de geração do motor vLLM enquanto o lote é executado, então uma linha Wrote <n> rows quando a saída é gravada.

Onde os resultados são armazenados

O driver grava um dataset Parquet no volume OUTPUT_PATH, com uma coluna instruction e uma coluna output. Leia de volta com Spark ou Pandas, por exemplo spark.read.parquet(OUTPUT_PATH).

Script de driver completo

O batch_inference.py completo para copiar e colar:

Python
#!/usr/bin/env python3
"""Offline batch inference with Ray Data + vLLM on a single 8x H100 node.

The workload `command` starts a Ray head with 8 GPUs and runs this script. Ray Data's
LLM API (`ray.data.llm`) launches one vLLM replica per GPU and streams a dataset of
prompts through them, then writes the generated text to a Unity Catalog volume as
Parquet.

Uses a public model (no Hugging Face token required) so the example runs as-is.
"""

import os

import ray
from datasets import load_dataset
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig

MODEL_SOURCE = "Qwen/Qwen2.5-7B-Instruct"
NUM_PROMPTS = 1000
# Unity Catalog volume path where results land as Parquet. Set this in train.yaml.
OUTPUT_PATH = os.environ.get("OUTPUT_PATH", "/Volumes/main/default/air_examples/ray_batch_inference")


def build_prompts():
"""Build a Ray Dataset of prompts from a public instruction dataset."""
raw = load_dataset("tatsu-lab/alpaca", split=f"train[:{NUM_PROMPTS}]")
items = []
for row in raw:
instruction = row["instruction"]
if row.get("input"):
instruction = f"{instruction}\n\n{row['input']}"
items.append({"instruction": instruction})
return ray.data.from_items(items)


def main():
ray.init(address="auto")
# Derive replicas from the live cluster so the example scales when nodes are added.
total_gpus = int(ray.cluster_resources().get("GPU", 0))
print(f"Ray cluster ready: {total_gpus} GPU(s)", flush=True)

ds = build_prompts()

# vLLM engine config. concurrency = number of replicas Ray Data runs in parallel;
# one per GPU in the cluster here. engine_kwargs are passed through to the vLLM engine.
config = vLLMEngineProcessorConfig(
model_source=MODEL_SOURCE,
engine_kwargs={
&quot;max_model_len&quot;: 4096,
&quot;tensor_parallel_size&quot;: 1,
&quot;enable_chunked_prefill&quot;: True,
},
concurrency=total_gpus,
batch_size=64,
)

# preprocess maps each input row to a chat request; postprocess keeps the columns
# we want to persist. ray.data.llm adds a `generated_text` column.
processor = build_llm_processor(
config,
preprocess=lambda row: dict(
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": row["instruction"]},
],
sampling_params=dict(max_tokens=256, temperature=0.7),
),
postprocess=lambda row: dict(
instruction=row["instruction"],
output=row["generated_text"],
),
)

# materialize once so the write and the sample print don't re-run inference.
out = processor(ds).materialize()
out.write_parquet(OUTPUT_PATH)
print(f"Wrote {out.count()} rows to {OUTPUT_PATH}", flush=True)

for row in out.take(2):
print("INSTRUCTION:", row["instruction"][:120], flush=True)
print("OUTPUT:", row["output"][:200], flush=True)

ray.shutdown()


if __name__ == "__main__":
main()

Passos seguintes