Inferência LLM distribuída com Ray Data e SGLang em GPU serverless
Este notebook demonstra como executar inferência de modelos de linguagem de grande porte (LLM) em escala usando Ray Data e SGLang em uma GPU serverless Databricks . Ele utiliza a APIde GPU distribuída sem servidor para provisionar e gerenciar automaticamente GPUs A10 em vários nós para inferência distribuída.
Conteúdo deste caderno:
- Configure o Ray e o SGLang para inferência LLM distribuída.
- Utilize o Ray Data para agrupar e processar solicitações de forma eficiente em várias GPUs.
- Defina funções de prompt SGLang com geração estruturada.
- Salvar resultados de inferência em Parquet nos volumes Unity Catalog
- Converter tabelas Parquet em tabelas Delta para governança e consultas eficientes.
Caso de uso: inferência em lotes com milhares de solicitações, utilizando GPUs de forma eficiente, o ambiente de execução otimizado do SGLang e a integração com Delta Lake .
Requisitos
compute GPU sem servidor com acelerador A10
Conecte o notebook à compute GPU serverless :
- Clique no dropdown "Conectar" na parte superior.
- Selecione GPU sem servidor .
- Abra o painel lateral Ambiente , localizado no lado direito do Notebook.
- Defina o acelerador para A10 .
- Clique em Aplicar e Confirmar .
Nota: A função distribuída inicia GPUs A10 remotas para inferência em vários nós. Execução do próprio Notebook em um único A10 para orquestração.
Instalar dependências
A célula a seguir instala todos os pacotes necessários para inferência distribuída de Ray e SGLang:
- Flash Attention : Atenção otimizada para inferência mais rápida (compatível com CUDA 12, PyTorch 2.6 e A10)
- SGLang : Framework de inferência e disponibilização de LLM de alto desempenho
- Ray Data : Processamento de dados distribuídos para inferência em lotes
- hf_transfer : downloadsrápidos do modelo Hugging Face
# Pre-compiled Flash Attention for A10s (Essential for speed/compilation)
%pip install --no-cache-dir "torch==2.9.1+cu128" --index-url https://download.pytorch.org/whl/cu128
%pip install -U --no-cache-dir wheel ninja packaging
%pip install --force-reinstall --no-cache-dir --no-build-isolation flash-attn
%pip install hf_transfer
%pip install "ray[data]>=2.47.1"
# SGLang with all dependencies (handles vLLM/Torch automatically)
%pip install "sglang[all]>=0.4.7"
%restart_python
Verificar versões do pacote
A célula seguinte confirma que todos os pacotes necessários estão instalados em versões compatíveis.
from packaging.version import Version
import torch
import flash_attn
import sglang
import ray
print(f"PyTorch: {torch.__version__}")
print(f"Flash Attention: {flash_attn.__version__}")
print(f"SGLang: {sglang.__version__}")
print(f"Ray: {ray.__version__}")
assert Version(ray.__version__) >= Version("2.47.1"), "Ray version must be at least 2.47.1"
print("\n✓ All version checks passed!")
Configuração
Utilize widgets para configurar parâmetros de inferência e autenticação opcional do Hugging Face.
Nota de segurança: Armazene os tokens da Hugging Face em Segredos do Databricks para uso em produção. Consulte a documentação do Databricks Secrets.
# Widget configuration
dbutils.widgets.text("hf_secret_scope", "")
dbutils.widgets.text("hf_secret_key", "")
dbutils.widgets.text("model_name", "Qwen/Qwen3-4B-Instruct-2507")
dbutils.widgets.text("num_gpus", "5")
dbutils.widgets.text("num_prompts", "1000")
# Unity Catalog configuration for output storage
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("uc_volume", "ray_data")
dbutils.widgets.text("uc_table", "sglang_inference_results")
# Retrieve widget values
HF_SECRET_SCOPE = dbutils.widgets.get("hf_secret_scope")
HF_SECRET_KEY = dbutils.widgets.get("hf_secret_key")
MODEL_NAME = dbutils.widgets.get("model_name")
NUM_GPUS = int(dbutils.widgets.get("num_gpus"))
NUM_PROMPTS = int(dbutils.widgets.get("num_prompts"))
# Unity Catalog paths
UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_VOLUME = dbutils.widgets.get("uc_volume")
UC_TABLE = dbutils.widgets.get("uc_table")
# Construct paths
UC_VOLUME_PATH = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}"
UC_TABLE_NAME = f"{UC_CATALOG}.{UC_SCHEMA}.{UC_TABLE}"
PARQUET_OUTPUT_PATH = f"{UC_VOLUME_PATH}/sglang_inference_output"
print(f"Model: {MODEL_NAME}")
print(f"Number of GPUs: {NUM_GPUS}")
print(f"Number of prompts: {NUM_PROMPTS}")
print(f"\nUnity Catalog Configuration:")
print(f" Volume Path: {UC_VOLUME_PATH}")
print(f" Table Name: {UC_TABLE_NAME}")
print(f" Parquet Output: {PARQUET_OUTPUT_PATH}")
Autentique com Hugging Face (opcional)
Se estiver usando modelos com acesso restrito (como o Llama), autentique-se com o Hugging Face.
Opção 1: Usar Segredos do Databricks (recomendado para produção)
hf_token = dbutils.secrets.get(scope=HF_SECRET_SCOPE, key=HF_SECRET_KEY)
login(token=hf_token)
Opção 2: Login interativo (para desenvolvimento)
from huggingface_hub import login
# Uncomment ONE of the following options:
# Option 1: Use Databricks Secrets (recommended)
# if HF_SECRET_SCOPE and HF_SECRET_KEY:
# hf_token = dbutils.secrets.get(scope=HF_SECRET_SCOPE, key=HF_SECRET_KEY)
# login(token=hf_token)
# print("✓ Logged in using Databricks Secrets")
# Option 2: Interactive login
login()
print("✓ Hugging Face authentication complete")
Configurar recurso Unity Catalog
A célula a seguir cria o recurso necessário Unity Catalog (catálogo, esquema e volume) para armazenar os resultados da inferência. Esses recursos fornecem governança, acompanhamento de linhagem e armazenamento centralizado para os resultados gerados.
# Unity Catalog Setup and Dataset Download
# ⚠️ IMPORTANT: Run this cell BEFORE the dataset processing cell
# to set up Unity Catalog resources and download the raw datasets.
import os
from databricks.sdk import WorkspaceClient
from databricks.sdk.service import catalog
import requests
# Create Unity Catalog resources
w = WorkspaceClient()
# Create catalog if it doesn't exist
try:
created_catalog = w.catalogs.create(name=UC_CATALOG)
print(f"Catalog '{created_catalog.name}' created successfully")
except Exception as e:
print(f"Catalog '{UC_CATALOG}' already exists or error: {e}")
# Create schema if it doesn't exist
try:
created_schema = w.schemas.create(name=UC_SCHEMA, catalog_name=UC_CATALOG)
print(f"Schema '{created_schema.name}' created successfully")
except Exception as e:
print(f"Schema '{UC_SCHEMA}' already exists in catalog '{UC_CATALOG}' or error: {e}")
# Create volume if it doesn't exist
volume_path = f'/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}'
if not os.path.exists(volume_path):
try:
created_volume = w.volumes.create(
catalog_name=UC_CATALOG,
schema_name=UC_SCHEMA,
name=UC_VOLUME,
volume_type=catalog.VolumeType.MANAGED
)
print(f"Volume '{created_volume.name}' created successfully")
except Exception as e:
print(f"Volume '{UC_VOLUME}' already exists or error: {e}")
else:
print(f"Volume {volume_path} already exists")
Ray cluster recurso monitoramento
A célula a seguir define uma função utilitária para inspecionar os recursos cluster Ray e verificar a alocação de GPUs entre os nós.
import json
import ray
def print_ray_resources():
"""Print Ray cluster resources and GPU allocation per node."""
try:
cluster_resources = ray.cluster_resources()
print("Ray Cluster Resources:")
print(json.dumps(cluster_resources, indent=2))
nodes = ray.nodes()
print(f"\nDetected {len(nodes)} Ray node(s):")
for node in nodes:
node_id = node.get("NodeID", "N/A")[:8] # Truncate for readability
ip_address = node.get("NodeManagerAddress", "N/A")
resources = node.get("Resources", {})
num_gpus = int(resources.get("GPU", 0))
print(f" • Node {node_id}... | IP: {ip_address} | GPUs: {num_gpus}")
# Show specific GPU IDs if available
gpu_ids = [k for k in resources.keys() if k.startswith("GPU_ID_")]
if gpu_ids:
print(f" GPU IDs: {', '.join(gpu_ids)}")
except Exception as e:
print(f"Error querying Ray cluster: {e}")
# Uncomment to display resources after cluster initialization
# print_ray_resources()
Defina a tarefa de inferência distribuída.
Este notebook utiliza o SGLang Runtime com Ray Data para inferência LLM distribuída. O SGLang oferece um ambiente de execução otimizado com recursos como o RadixAttention para um cache de prefixos eficiente.
Visão geral da arquitetura
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
│ Ray Data │───▶│ SGLang Runtime │───▶│ Generated │
│ (Prompts) │ │ (map_batches) │ │ Outputs │
└─────────────────┘ └─────────────────┘ └─────────────────┘
│ │
▼ ▼
Distributed GPU Workers (A10)
across nodes with SGLang engines
Principais benefícios do SGLang
- RadixAttention : Reutilização eficiente do cache KV entre requisições
- Geração estruturada : Decodificação guiada por restrições com regex/gramática
- Tempo de execução otimizado : Alta taxa de transferência com lotes automáticos
- Suporte a múltiplas interações : Manipulação eficiente de instruções conversacionais.
from serverless_gpu.ray import ray_launch
import os
# Set Ray temp directory
os.environ['RAY_TEMP_DIR'] = '/tmp/ray'
# Set the UC Volumes temp directory for write_databricks_table
os.environ['_RAY_UC_VOLUMES_FUSE_TEMP_DIR'] = f"{UC_VOLUME_PATH}/ray_temp"
@ray_launch(gpus=NUM_GPUS, gpu_type='a10', remote=True)
def run_distributed_inference():
"""Run distributed LLM inference using Ray Data and SGLang Runtime."""
import ray
import numpy as np
from typing import Dict
import sglang as sgl
from datetime import datetime
# Sample prompts for inference
base_prompts = [
"Hello, my name is",
"The president of the United States is",
"The future of AI is",
]
# Scale up prompts for distributed processing
prompts = [{"prompt": p} for p in base_prompts * (NUM_PROMPTS // len(base_prompts))]
ds = ray.data.from_items(prompts)
print(f"✓ Created Ray dataset with {ds.count()} prompts")
# Define SGLang Predictor class for Ray Data map_batches
class SGLangPredictor:
"""SGLang-based predictor for batch inference with Ray Data."""
def __init__(self):
# Initialize SGLang Runtime inside the actor process
self.runtime = sgl.Runtime(
model_path=MODEL_NAME,
dtype="bfloat16",
trust_remote_code=True,
mem_fraction_static=0.85,
tp_size=1, # Tensor parallelism (1 GPU per worker)
)
# Set as default backend for the current process
sgl.set_default_backend(self.runtime)
print(f"✓ SGLang runtime initialized with model: {MODEL_NAME}")
def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
"""Process a batch of prompts using SGLang."""
input_prompts = batch["prompt"].tolist()
# Define SGLang prompt function
@sgl.function
def chat_completion(s, prompt):
s += sgl.system("You are a helpful AI assistant.")
s += sgl.user(prompt)
s += sgl.assistant(sgl.gen("response", max_tokens=100, temperature=0.8))
# Run batch inference
results = chat_completion.run_batch(
[{"prompt": p} for p in input_prompts],
progress_bar=False
)
generated_text = [r["response"] for r in results]
return {
"prompt": input_prompts,
"generated_text": generated_text,
"model": [MODEL_NAME] * len(input_prompts),
"timestamp": [datetime.now().isoformat()] * len(input_prompts),
}
def __del__(self):
"""Clean up runtime when actor dies."""
try:
self.runtime.shutdown()
except:
pass
print(f"✓ SGLang predictor configured with model: {MODEL_NAME}")
# Apply map_batches with SGLang predictor
ds = ds.map_batches(
SGLangPredictor,
concurrency=NUM_GPUS, # Number of parallel SGLang instances
batch_size=32,
num_gpus=1,
num_cpus=12,
)
# =========================================================================
# Write results to Parquet (stored in Unity Catalog Volume)
# =========================================================================
print(f"\n📦 Writing results to Parquet: {PARQUET_OUTPUT_PATH}")
ds.write_parquet(PARQUET_OUTPUT_PATH, mode="overwrite")
print(f"✓ Parquet files written successfully")
# Collect sample outputs for display
sample_outputs = ray.data.read_parquet(PARQUET_OUTPUT_PATH).take(limit=10)
print("\n" + "="*60)
print("SAMPLE INFERENCE RESULTS")
print("="*60 + "\n")
for i, output in enumerate(sample_outputs):
prompt = output.get("prompt", "N/A")
generated_text = output.get("generated_text", "")
display_text = generated_text[:100] if generated_text else "N/A"
print(f"[{i+1}] Prompt: {prompt!r}")
print(f" Generated: {display_text!r}...\n")
return PARQUET_OUTPUT_PATH
inferência distribuída de execução
A célula seguinte inicia a tarefa de inferência distribuída em várias GPUs A10. Isto irá:
- provisionamento trabalhador GPU A10 remoto
- Inicialize os mecanismos SGLang em cada worker.
- Distribua lembretes entre os trabalhadores usando dados Ray.
- Salvar os resultados gerados em formato Parquet.
Observação: A startup pode levar alguns minutos, pois os nós da GPU estão sendo provisionados e os modelos carregados.
result = run_distributed_inference.distributed()
parquet_path = result[0] if NUM_GPUS > 1 else result
print(f"\n✓ Inference complete! Results saved to: {parquet_path}")
Carregar Parquet e converter para tabela Delta
A célula a seguir carrega a saída Parquet dos Volumes Unity Catalog usando Spark e a salva como uma tabela Delta para consultas e governança eficientes.
# Load Parquet data using Spark
print(f"📖 Loading Parquet from: {PARQUET_OUTPUT_PATH}")
df_spark = spark.read.parquet(PARQUET_OUTPUT_PATH)
# Show schema and row count
print(f"\n✓ Loaded {df_spark.count()} rows")
print("\nSchema:")
df_spark.printSchema()
# Display sample rows
print("\nSample Results:")
display(df_spark.limit(10))
Salvar como tabela Delta no Unity Catalog
A célula a seguir grava os resultados da inferência em uma tabela Delta Unity Catalog para:
- Governança : Rastrear a linhagem de dados e os controles de acesso.
- Desempenho : Consultas otimizadas com Delta Lake
- Versionamento : viagem do tempo e auditoria história
# Write to Unity Catalog Delta table
print(f"💾 Writing to Delta table: {UC_TABLE_NAME}")
# Write the DataFrame as a Delta table (overwrite mode)
df_spark.write \
.format("delta") \
.mode("overwrite") \
.option("overwriteSchema", "true") \
.saveAsTable(UC_TABLE_NAME)
print(f"✓ Delta table created successfully: {UC_TABLE_NAME}")
Consultar a tabela Delta
A célula seguinte verifica se a tabela Delta foi criada e a consulta usando SQL.
# Query the Delta table using SQL
print(f"📊 Querying Delta table: {UC_TABLE_NAME}\n")
# Get table info
display(spark.sql(f"DESCRIBE TABLE {UC_TABLE_NAME}"))
# Query sample results
print("\nSample Results from Delta Table:")
display(spark.sql(f"""
SELECT
prompt,
generated_text,
model,
timestamp
FROM {UC_TABLE_NAME}
LIMIT 10
"""))
# Get row count
row_count = spark.sql(f"SELECT COUNT(*) as count FROM {UC_TABLE_NAME}").collect()[0]["count"]
print(f"\n✓ Total rows in Delta table: {row_count}")
# Assert expected row count (NUM_PROMPTS should result in 999 rows: 1000 // 3 * 3 = 999)
expected_rows = (NUM_PROMPTS // 3) * 3 # Rounds down to nearest multiple of 3 base prompts
assert row_count == expected_rows, f"Expected {expected_rows} rows, but got {row_count}"
Próximos passos
Este notebook demonstrou com sucesso a inferência LLM distribuída usando Ray Data e SGLang em uma GPU serverless Databricks , com os resultados salvos em uma tabela Delta .
O que foi realizado
- Executou inferência LLM distribuída em várias GPUs A10.
- Utilização do SGLang Runtime para processamento otimizado de lotes.
- Resultados salvos em Parquet nos volumes Unity Catalog
- Converteu o formato Parquet em uma tabela Delta controlada.
Opções de personalização
- Alterar o modelo : Atualizar o widget
model_namepara usar modelos diferentes do Hugging Face. - escala up : Aumente
num_gpuspara Taxa de transferência mais alta - Ajustar o tamanho dos lotes : Modificar
batch_sizeemmap_batchescom base nas restrições de memória. - Geração de sintonia : Ajuste
max_tokens,temperaturena função SGLang. - Saída estruturada : Utilize as restrições de regex/gramática do SGLang para saída JSON.
- Chat com múltiplas interações : Encadeie várias chamadas
user/assistantna função de prompt. - Modo de anexação : Altere a gravação Delta para
mode("append")para atualizações incrementais.
Recurso avançado SGLang
# Structured JSON output with regex constraint
@sgl.function
def json_output(s, prompt):
s += sgl.user(prompt)
s += sgl.assistant(sgl.gen("response", regex=r'\{"name": "\w+", "age": \d+\}'))
# Multi-turn conversation
@sgl.function
def multi_turn(s, question1, question2):
s += sgl.user(question1)
s += sgl.assistant(sgl.gen("answer1"))
s += sgl.user(question2)
s += sgl.assistant(sgl.gen("answer2"))
Alternativa: write_databricks_table
Para espaços de trabalho Unity Catalog ativado, use ray.data.Dataset.write_databricks_table() para escrever diretamente em uma tabela Unity Catalog :
# Set the temp directory environment variable
os.environ["_RAY_UC_VOLUMES_FUSE_TEMP_DIR"] = "/Volumes/catalog/schema/volume/ray_temp"
# Write directly to Unity Catalog table
ds.write_databricks_table(table_name="catalog.schema.table_name")
recurso
- Documentação API de GPU sem servidor
- Documentação SGLang
- SGLang GitHub
- Documentação de dados Ray
- Documentação Unity Catalog
Limpar
Os recursos da GPU são limpos automaticamente quando o notebook é desconectado. Para desconectar manualmente:
- Clique em Conectado no dropdown compute
- Passe o cursor sobre "sem servidor"
- Selecione "Encerrar" no menu dropdown .