Pular para o conteúdo principal

Inferência LLM distribuída com Ray Data e SGLang em GPU serverless

Este notebook demonstra como executar inferência de modelos de linguagem de grande porte (LLM) em escala usando Ray Data e SGLang em uma GPU serverless Databricks . Ele utiliza a APIde GPU distribuída sem servidor para provisionar e gerenciar automaticamente GPUs A10 em vários nós para inferência distribuída.

Conteúdo deste caderno:

  • Configure o Ray e o SGLang para inferência LLM distribuída.
  • Utilize o Ray Data para agrupar e processar solicitações de forma eficiente em várias GPUs.
  • Defina funções de prompt SGLang com geração estruturada.
  • Salvar resultados de inferência em Parquet nos volumes Unity Catalog
  • Converter tabelas Parquet em tabelas Delta para governança e consultas eficientes.

Caso de uso: inferência em lotes com milhares de solicitações, utilizando GPUs de forma eficiente, o ambiente de execução otimizado do SGLang e a integração com Delta Lake .

Requisitos

compute GPU sem servidor com acelerador A10

Conecte o notebook à compute GPU serverless :

  1. Clique no dropdown "Conectar" na parte superior.
  2. Selecione GPU sem servidor .
  3. Abra o painel lateral Ambiente , localizado no lado direito do Notebook.
  4. Defina o acelerador para A10 .
  5. Clique em Aplicar e Confirmar .

Nota: A função distribuída inicia GPUs A10 remotas para inferência em vários nós. Execução do próprio Notebook em um único A10 para orquestração.

Instalar dependências

A célula a seguir instala todos os pacotes necessários para inferência distribuída de Ray e SGLang:

  • Flash Attention : Atenção otimizada para inferência mais rápida (compatível com CUDA 12, PyTorch 2.6 e A10)
  • SGLang : Framework de inferência e disponibilização de LLM de alto desempenho
  • Ray Data : Processamento de dados distribuídos para inferência em lotes
  • hf_transfer : downloadsrápidos do modelo Hugging Face
Python
# Pre-compiled Flash Attention for A10s (Essential for speed/compilation)
%pip install --no-cache-dir "torch==2.9.1+cu128" --index-url https://download.pytorch.org/whl/cu128
%pip install -U --no-cache-dir wheel ninja packaging
%pip install --force-reinstall --no-cache-dir --no-build-isolation flash-attn
%pip install hf_transfer
%pip install "ray[data]>=2.47.1"

# SGLang with all dependencies (handles vLLM/Torch automatically)
%pip install "sglang[all]>=0.4.7"

%restart_python

Verificar versões do pacote

A célula seguinte confirma que todos os pacotes necessários estão instalados em versões compatíveis.

Python
from packaging.version import Version

import torch
import flash_attn
import sglang
import ray

print(f"PyTorch: {torch.__version__}")
print(f"Flash Attention: {flash_attn.__version__}")
print(f"SGLang: {sglang.__version__}")
print(f"Ray: {ray.__version__}")

assert Version(ray.__version__) >= Version("2.47.1"), "Ray version must be at least 2.47.1"
print("\n✓ All version checks passed!")

Configuração

Utilize widgets para configurar parâmetros de inferência e autenticação opcional do Hugging Face.

Nota de segurança: Armazene os tokens da Hugging Face em Segredos do Databricks para uso em produção. Consulte a documentação do Databricks Secrets.

Python
# Widget configuration
dbutils.widgets.text("hf_secret_scope", "")
dbutils.widgets.text("hf_secret_key", "")
dbutils.widgets.text("model_name", "Qwen/Qwen3-4B-Instruct-2507")
dbutils.widgets.text("num_gpus", "5")
dbutils.widgets.text("num_prompts", "1000")

# Unity Catalog configuration for output storage
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("uc_volume", "ray_data")
dbutils.widgets.text("uc_table", "sglang_inference_results")

# Retrieve widget values
HF_SECRET_SCOPE = dbutils.widgets.get("hf_secret_scope")
HF_SECRET_KEY = dbutils.widgets.get("hf_secret_key")
MODEL_NAME = dbutils.widgets.get("model_name")
NUM_GPUS = int(dbutils.widgets.get("num_gpus"))
NUM_PROMPTS = int(dbutils.widgets.get("num_prompts"))

# Unity Catalog paths
UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_VOLUME = dbutils.widgets.get("uc_volume")
UC_TABLE = dbutils.widgets.get("uc_table")

# Construct paths
UC_VOLUME_PATH = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}"
UC_TABLE_NAME = f"{UC_CATALOG}.{UC_SCHEMA}.{UC_TABLE}"
PARQUET_OUTPUT_PATH = f"{UC_VOLUME_PATH}/sglang_inference_output"

print(f"Model: {MODEL_NAME}")
print(f"Number of GPUs: {NUM_GPUS}")
print(f"Number of prompts: {NUM_PROMPTS}")
print(f"\nUnity Catalog Configuration:")
print(f" Volume Path: {UC_VOLUME_PATH}")
print(f" Table Name: {UC_TABLE_NAME}")
print(f" Parquet Output: {PARQUET_OUTPUT_PATH}")

Autentique com Hugging Face (opcional)

Se estiver usando modelos com acesso restrito (como o Llama), autentique-se com o Hugging Face.

Opção 1: Usar Segredos do Databricks (recomendado para produção)

Python
hf_token = dbutils.secrets.get(scope=HF_SECRET_SCOPE, key=HF_SECRET_KEY)
login(token=hf_token)

Opção 2: Login interativo (para desenvolvimento)

Python
from huggingface_hub import login

# Uncomment ONE of the following options:

# Option 1: Use Databricks Secrets (recommended)
# if HF_SECRET_SCOPE and HF_SECRET_KEY:
# hf_token = dbutils.secrets.get(scope=HF_SECRET_SCOPE, key=HF_SECRET_KEY)
# login(token=hf_token)
# print("✓ Logged in using Databricks Secrets")

# Option 2: Interactive login
login()
print("✓ Hugging Face authentication complete")

Configurar recurso Unity Catalog

A célula a seguir cria o recurso necessário Unity Catalog (catálogo, esquema e volume) para armazenar os resultados da inferência. Esses recursos fornecem governança, acompanhamento de linhagem e armazenamento centralizado para os resultados gerados.

Python
# Unity Catalog Setup and Dataset Download
# ⚠️ IMPORTANT: Run this cell BEFORE the dataset processing cell
# to set up Unity Catalog resources and download the raw datasets.

import os
from databricks.sdk import WorkspaceClient
from databricks.sdk.service import catalog
import requests

# Create Unity Catalog resources
w = WorkspaceClient()

# Create catalog if it doesn't exist
try:
created_catalog = w.catalogs.create(name=UC_CATALOG)
print(f"Catalog '{created_catalog.name}' created successfully")
except Exception as e:
print(f"Catalog '{UC_CATALOG}' already exists or error: {e}")

# Create schema if it doesn't exist
try:
created_schema = w.schemas.create(name=UC_SCHEMA, catalog_name=UC_CATALOG)
print(f"Schema '{created_schema.name}' created successfully")
except Exception as e:
print(f"Schema '{UC_SCHEMA}' already exists in catalog '{UC_CATALOG}' or error: {e}")

# Create volume if it doesn't exist
volume_path = f'/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}'
if not os.path.exists(volume_path):
try:
created_volume = w.volumes.create(
catalog_name=UC_CATALOG,
schema_name=UC_SCHEMA,
name=UC_VOLUME,
volume_type=catalog.VolumeType.MANAGED
)
print(f"Volume '{created_volume.name}' created successfully")
except Exception as e:
print(f"Volume '{UC_VOLUME}' already exists or error: {e}")
else:
print(f"Volume {volume_path} already exists")

Ray cluster recurso monitoramento

A célula a seguir define uma função utilitária para inspecionar os recursos cluster Ray e verificar a alocação de GPUs entre os nós.

Python
import json
import ray

def print_ray_resources():
"""Print Ray cluster resources and GPU allocation per node."""
try:
cluster_resources = ray.cluster_resources()
print("Ray Cluster Resources:")
print(json.dumps(cluster_resources, indent=2))

nodes = ray.nodes()
print(f"\nDetected {len(nodes)} Ray node(s):")

for node in nodes:
node_id = node.get("NodeID", "N/A")[:8] # Truncate for readability
ip_address = node.get("NodeManagerAddress", "N/A")
resources = node.get("Resources", {})
num_gpus = int(resources.get("GPU", 0))

print(f" • Node {node_id}... | IP: {ip_address} | GPUs: {num_gpus}")

# Show specific GPU IDs if available
gpu_ids = [k for k in resources.keys() if k.startswith("GPU_ID_")]
if gpu_ids:
print(f" GPU IDs: {', '.join(gpu_ids)}")

except Exception as e:
print(f"Error querying Ray cluster: {e}")

# Uncomment to display resources after cluster initialization
# print_ray_resources()

Defina a tarefa de inferência distribuída.

Este notebook utiliza o SGLang Runtime com Ray Data para inferência LLM distribuída. O SGLang oferece um ambiente de execução otimizado com recursos como o RadixAttention para um cache de prefixos eficiente.

Visão geral da arquitetura

┌─────────────────┐    ┌─────────────────┐    ┌─────────────────┐
│ Ray Data │───▶│ SGLang Runtime │───▶│ Generated │
│ (Prompts) │ │ (map_batches) │ │ Outputs │
└─────────────────┘ └─────────────────┘ └─────────────────┘
│ │
▼ ▼
Distributed GPU Workers (A10)
across nodes with SGLang engines

Principais benefícios do SGLang

  • RadixAttention : Reutilização eficiente do cache KV entre requisições
  • Geração estruturada : Decodificação guiada por restrições com regex/gramática
  • Tempo de execução otimizado : Alta taxa de transferência com lotes automáticos
  • Suporte a múltiplas interações : Manipulação eficiente de instruções conversacionais.
Python
from serverless_gpu.ray import ray_launch
import os

# Set Ray temp directory
os.environ['RAY_TEMP_DIR'] = '/tmp/ray'

# Set the UC Volumes temp directory for write_databricks_table
os.environ['_RAY_UC_VOLUMES_FUSE_TEMP_DIR'] = f"{UC_VOLUME_PATH}/ray_temp"

@ray_launch(gpus=NUM_GPUS, gpu_type='a10', remote=True)
def run_distributed_inference():
"""Run distributed LLM inference using Ray Data and SGLang Runtime."""
import ray
import numpy as np
from typing import Dict
import sglang as sgl
from datetime import datetime

# Sample prompts for inference
base_prompts = [
"Hello, my name is",
"The president of the United States is",
"The future of AI is",
]

# Scale up prompts for distributed processing
prompts = [{"prompt": p} for p in base_prompts * (NUM_PROMPTS // len(base_prompts))]
ds = ray.data.from_items(prompts)

print(f"✓ Created Ray dataset with {ds.count()} prompts")

# Define SGLang Predictor class for Ray Data map_batches
class SGLangPredictor:
"""SGLang-based predictor for batch inference with Ray Data."""

def __init__(self):
# Initialize SGLang Runtime inside the actor process
self.runtime = sgl.Runtime(
model_path=MODEL_NAME,
dtype="bfloat16",
trust_remote_code=True,
mem_fraction_static=0.85,
tp_size=1, # Tensor parallelism (1 GPU per worker)
)
# Set as default backend for the current process
sgl.set_default_backend(self.runtime)
print(f"✓ SGLang runtime initialized with model: {MODEL_NAME}")

def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
"""Process a batch of prompts using SGLang."""
input_prompts = batch["prompt"].tolist()

# Define SGLang prompt function
@sgl.function
def chat_completion(s, prompt):
s += sgl.system("You are a helpful AI assistant.")
s += sgl.user(prompt)
s += sgl.assistant(sgl.gen("response", max_tokens=100, temperature=0.8))

# Run batch inference
results = chat_completion.run_batch(
[{"prompt": p} for p in input_prompts],
progress_bar=False
)

generated_text = [r["response"] for r in results]

return {
"prompt": input_prompts,
"generated_text": generated_text,
"model": [MODEL_NAME] * len(input_prompts),
"timestamp": [datetime.now().isoformat()] * len(input_prompts),
}

def __del__(self):
"""Clean up runtime when actor dies."""
try:
self.runtime.shutdown()
except:
pass

print(f"✓ SGLang predictor configured with model: {MODEL_NAME}")

# Apply map_batches with SGLang predictor
ds = ds.map_batches(
SGLangPredictor,
concurrency=NUM_GPUS, # Number of parallel SGLang instances
batch_size=32,
num_gpus=1,
num_cpus=12,
)

# =========================================================================
# Write results to Parquet (stored in Unity Catalog Volume)
# =========================================================================
print(f"\n📦 Writing results to Parquet: {PARQUET_OUTPUT_PATH}")
ds.write_parquet(PARQUET_OUTPUT_PATH, mode="overwrite")
print(f"✓ Parquet files written successfully")

# Collect sample outputs for display
sample_outputs = ray.data.read_parquet(PARQUET_OUTPUT_PATH).take(limit=10)

print("\n" + "="*60)
print("SAMPLE INFERENCE RESULTS")
print("="*60 + "\n")

for i, output in enumerate(sample_outputs):
prompt = output.get("prompt", "N/A")
generated_text = output.get("generated_text", "")
display_text = generated_text[:100] if generated_text else "N/A"
print(f"[{i+1}] Prompt: {prompt!r}")
print(f" Generated: {display_text!r}...\n")

return PARQUET_OUTPUT_PATH

inferência distribuída de execução

A célula seguinte inicia a tarefa de inferência distribuída em várias GPUs A10. Isto irá:

  1. provisionamento trabalhador GPU A10 remoto
  2. Inicialize os mecanismos SGLang em cada worker.
  3. Distribua lembretes entre os trabalhadores usando dados Ray.
  4. Salvar os resultados gerados em formato Parquet.

Observação: A startup pode levar alguns minutos, pois os nós da GPU estão sendo provisionados e os modelos carregados.

Python
result = run_distributed_inference.distributed()
parquet_path = result[0] if NUM_GPUS > 1 else result
print(f"\n✓ Inference complete! Results saved to: {parquet_path}")

Carregar Parquet e converter para tabela Delta

A célula a seguir carrega a saída Parquet dos Volumes Unity Catalog usando Spark e a salva como uma tabela Delta para consultas e governança eficientes.

Python
# Load Parquet data using Spark
print(f"📖 Loading Parquet from: {PARQUET_OUTPUT_PATH}")
df_spark = spark.read.parquet(PARQUET_OUTPUT_PATH)

# Show schema and row count
print(f"\n✓ Loaded {df_spark.count()} rows")
print("\nSchema:")
df_spark.printSchema()

# Display sample rows
print("\nSample Results:")
display(df_spark.limit(10))

Salvar como tabela Delta no Unity Catalog

A célula a seguir grava os resultados da inferência em uma tabela Delta Unity Catalog para:

  • Governança : Rastrear a linhagem de dados e os controles de acesso.
  • Desempenho : Consultas otimizadas com Delta Lake
  • Versionamento : viagem do tempo e auditoria história
Python
# Write to Unity Catalog Delta table
print(f"💾 Writing to Delta table: {UC_TABLE_NAME}")

# Write the DataFrame as a Delta table (overwrite mode)
df_spark.write \
.format("delta") \
.mode("overwrite") \
.option("overwriteSchema", "true") \
.saveAsTable(UC_TABLE_NAME)

print(f"✓ Delta table created successfully: {UC_TABLE_NAME}")

Consultar a tabela Delta

A célula seguinte verifica se a tabela Delta foi criada e a consulta usando SQL.

Python
# Query the Delta table using SQL
print(f"📊 Querying Delta table: {UC_TABLE_NAME}\n")

# Get table info
display(spark.sql(f"DESCRIBE TABLE {UC_TABLE_NAME}"))

# Query sample results
print("\nSample Results from Delta Table:")
display(spark.sql(f"""
SELECT
prompt,
generated_text,
model,
timestamp
FROM {UC_TABLE_NAME}
LIMIT 10
"""))

# Get row count
row_count = spark.sql(f"SELECT COUNT(*) as count FROM {UC_TABLE_NAME}").collect()[0]["count"]
print(f"\n✓ Total rows in Delta table: {row_count}")

# Assert expected row count (NUM_PROMPTS should result in 999 rows: 1000 // 3 * 3 = 999)
expected_rows = (NUM_PROMPTS // 3) * 3 # Rounds down to nearest multiple of 3 base prompts
assert row_count == expected_rows, f"Expected {expected_rows} rows, but got {row_count}"

Próximos passos

Este notebook demonstrou com sucesso a inferência LLM distribuída usando Ray Data e SGLang em uma GPU serverless Databricks , com os resultados salvos em uma tabela Delta .

O que foi realizado

  • Executou inferência LLM distribuída em várias GPUs A10.
  • Utilização do SGLang Runtime para processamento otimizado de lotes.
  • Resultados salvos em Parquet nos volumes Unity Catalog
  • Converteu o formato Parquet em uma tabela Delta controlada.

Opções de personalização

  • Alterar o modelo : Atualizar o widget model_name para usar modelos diferentes do Hugging Face.
  • escala up : Aumente num_gpus para Taxa de transferência mais alta
  • Ajustar o tamanho dos lotes : Modificar batch_size em map_batches com base nas restrições de memória.
  • Geração de sintonia : Ajuste max_tokens, temperature na função SGLang.
  • Saída estruturada : Utilize as restrições de regex/gramática do SGLang para saída JSON.
  • Chat com múltiplas interações : Encadeie várias chamadas user/assistant na função de prompt.
  • Modo de anexação : Altere a gravação Delta para mode("append") para atualizações incrementais.

Recurso avançado SGLang

Python
# Structured JSON output with regex constraint
@sgl.function
def json_output(s, prompt):
s += sgl.user(prompt)
s += sgl.assistant(sgl.gen("response", regex=r'\{"name": "\w+", "age": \d+\}'))

# Multi-turn conversation
@sgl.function
def multi_turn(s, question1, question2):
s += sgl.user(question1)
s += sgl.assistant(sgl.gen("answer1"))
s += sgl.user(question2)
s += sgl.assistant(sgl.gen("answer2"))

Alternativa: write_databricks_table

Para espaços de trabalho Unity Catalog ativado, use ray.data.Dataset.write_databricks_table() para escrever diretamente em uma tabela Unity Catalog :

Python
# Set the temp directory environment variable
os.environ["_RAY_UC_VOLUMES_FUSE_TEMP_DIR"] = "/Volumes/catalog/schema/volume/ray_temp"

# Write directly to Unity Catalog table
ds.write_databricks_table(table_name="catalog.schema.table_name")

recurso

Limpar

Os recursos da GPU são limpos automaticamente quando o notebook é desconectado. Para desconectar manualmente:

  1. Clique em Conectado no dropdown compute
  2. Passe o cursor sobre "sem servidor"
  3. Selecione "Encerrar" no menu dropdown .

Exemplo de caderno

Inferência LLM distribuída com Ray Data e SGLang em GPU serverless

Abrir notebook em uma nova aba