サーバーレス GPU 上で Ray Data と SGLang を使用した分散LLM推論
このノートブックでは、 Databricksサーバレス GPU 上で Ray Data と SGLang を使用して大規模言語モデル ( LLM ) 推論を大規模に実行する方法を示します。 分散サーバレス GPU APIを利用して、分散推論用のマルチノード A10 GPU を自動的にプロビジョニングおよび管理します。
このノートブックの内容:
- 分散LLM推論用にRayとSGLangを設定する
- Ray Data を使用して、複数の GPU 間でプロンプトを効率的にバッチ処理します。
- 構造化された生成によるSGLangプロンプト関数の定義
- Unity CatalogボリュームのParquetに推論結果を保存する
- ガバナンスと効率的なクエリのために Parquet テーブルを Delta テーブルに変換する
ユースケース: 効率的な GPU 使用率、SGLang の最適化されたランタイム、Delta Lake 統合により、数千のプロンプトをバッチ推論します。
要件
A10アクセラレータによるサーバレスGPUコンピュート
ノートブックをサーバレス GPU コンピュートに接続します。
- 上部の 「接続」 ドロップダウンをクリックします。
- サーバレス GPU を選択します。
- ノートブックの右側にある 環境 サイドパネルを開きます。
- アクセラレータを A10 に設定します。
- [適用] と [確認] をクリックします。
注: 分散関数は、マルチノード推論のためにリモート A10 GPU を起動します。ノートブック自体は、オーケストレーション用の単一の A10 上で実行されます。
依存関係をインストールする
次のセルは、分散 Ray および SGLang 推論に必要なすべてのパッケージをインストールします。
- Flash Attention : 推論を高速化するための最適化されたアテンション (CUDA 12、PyTorch 2.6、A10 互換)
- SGLang : 高性能 LLM 推論およびサービス フレームワーク
- Ray Data : バッチ推論のための分散データ処理
- hf_transfer : 高速Hugging Faceモデルのダウンロード
# Pre-compiled Flash Attention for A10s (Essential for speed/compilation)
%pip install --no-cache-dir "torch==2.9.1+cu128" --index-url https://download.pytorch.org/whl/cu128
%pip install -U --no-cache-dir wheel ninja packaging
%pip install --force-reinstall --no-cache-dir --no-build-isolation flash-attn
%pip install hf_transfer
%pip install "ray[data]>=2.47.1"
# SGLang with all dependencies (handles vLLM/Torch automatically)
%pip install "sglang[all]>=0.4.7"
%restart_python
パッケージのバージョンを確認する
次のセルは、必要なすべてのパッケージが互換性のあるバージョンでインストールされていることを確認します。
from packaging.version import Version
import torch
import flash_attn
import sglang
import ray
print(f"PyTorch: {torch.__version__}")
print(f"Flash Attention: {flash_attn.__version__}")
print(f"SGLang: {sglang.__version__}")
print(f"Ray: {ray.__version__}")
assert Version(ray.__version__) >= Version("2.47.1"), "Ray version must be at least 2.47.1"
print("\n✓ All version checks passed!")
構成
ウィジェットを使用して推論とオプションのHugging Face認証を設定します。
セキュリティに関する注意: 本番運用で使用するために、 Hugging FaceをDatabricks Secrets に保存します。 Databricks Secrets のドキュメントを参照してください。
# Widget configuration
dbutils.widgets.text("hf_secret_scope", "")
dbutils.widgets.text("hf_secret_key", "")
dbutils.widgets.text("model_name", "Qwen/Qwen3-4B-Instruct-2507")
dbutils.widgets.text("num_gpus", "5")
dbutils.widgets.text("num_prompts", "1000")
# Unity Catalog configuration for output storage
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("uc_volume", "ray_data")
dbutils.widgets.text("uc_table", "sglang_inference_results")
# Retrieve widget values
HF_SECRET_SCOPE = dbutils.widgets.get("hf_secret_scope")
HF_SECRET_KEY = dbutils.widgets.get("hf_secret_key")
MODEL_NAME = dbutils.widgets.get("model_name")
NUM_GPUS = int(dbutils.widgets.get("num_gpus"))
NUM_PROMPTS = int(dbutils.widgets.get("num_prompts"))
# Unity Catalog paths
UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_VOLUME = dbutils.widgets.get("uc_volume")
UC_TABLE = dbutils.widgets.get("uc_table")
# Construct paths
UC_VOLUME_PATH = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}"
UC_TABLE_NAME = f"{UC_CATALOG}.{UC_SCHEMA}.{UC_TABLE}"
PARQUET_OUTPUT_PATH = f"{UC_VOLUME_PATH}/sglang_inference_output"
print(f"Model: {MODEL_NAME}")
print(f"Number of GPUs: {NUM_GPUS}")
print(f"Number of prompts: {NUM_PROMPTS}")
print(f"\nUnity Catalog Configuration:")
print(f" Volume Path: {UC_VOLUME_PATH}")
print(f" Table Name: {UC_TABLE_NAME}")
print(f" Parquet Output: {PARQUET_OUTPUT_PATH}")
Hugging Faceで認証する(オプション)
ゲート付きモデル (Llama など) を使用する場合は、Hugging Face で認証します。
オプション 1: Databricks Secret を使用する (本番運用に推奨)
hf_token = dbutils.secrets.get(scope=HF_SECRET_SCOPE, key=HF_SECRET_KEY)
login(token=hf_token)
オプション2: 対話型ログイン(開発用)
from huggingface_hub import login
# Uncomment ONE of the following options:
# Option 1: Use Databricks Secrets (recommended)
# if HF_SECRET_SCOPE and HF_SECRET_KEY:
# hf_token = dbutils.secrets.get(scope=HF_SECRET_SCOPE, key=HF_SECRET_KEY)
# login(token=hf_token)
# print("✓ Logged in using Databricks Secrets")
# Option 2: Interactive login
login()
print("✓ Hugging Face authentication complete")
Unity Catalogリソースをセットアップする
次のセルは、推論結果を保存するために必要なUnity Catalogリソース (カタログ、スキーマ、ボリューム) を作成します。 これらのリソースは、ガバナンス、リネージ追跡、生成された出力の集中ストレージを提供します。
# Unity Catalog Setup and Dataset Download
# ⚠️ IMPORTANT: Run this cell BEFORE the dataset processing cell
# to set up Unity Catalog resources and download the raw datasets.
import os
from databricks.sdk import WorkspaceClient
from databricks.sdk.service import catalog
import requests
# Create Unity Catalog resources
w = WorkspaceClient()
# Create catalog if it doesn't exist
try:
created_catalog = w.catalogs.create(name=UC_CATALOG)
print(f"Catalog '{created_catalog.name}' created successfully")
except Exception as e:
print(f"Catalog '{UC_CATALOG}' already exists or error: {e}")
# Create schema if it doesn't exist
try:
created_schema = w.schemas.create(name=UC_SCHEMA, catalog_name=UC_CATALOG)
print(f"Schema '{created_schema.name}' created successfully")
except Exception as e:
print(f"Schema '{UC_SCHEMA}' already exists in catalog '{UC_CATALOG}' or error: {e}")
# Create volume if it doesn't exist
volume_path = f'/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}'
if not os.path.exists(volume_path):
try:
created_volume = w.volumes.create(
catalog_name=UC_CATALOG,
schema_name=UC_SCHEMA,
name=UC_VOLUME,
volume_type=catalog.VolumeType.MANAGED
)
print(f"Volume '{created_volume.name}' created successfully")
except Exception as e:
print(f"Volume '{UC_VOLUME}' already exists or error: {e}")
else:
print(f"Volume {volume_path} already exists")
レイクラスターリソースモニタリング
次のセルは、Ray クラスターのリソースを検査し、ノード間の GPU 割り当てを確認するユーティリティ関数を定義します。
import json
import ray
def print_ray_resources():
"""Print Ray cluster resources and GPU allocation per node."""
try:
cluster_resources = ray.cluster_resources()
print("Ray Cluster Resources:")
print(json.dumps(cluster_resources, indent=2))
nodes = ray.nodes()
print(f"\nDetected {len(nodes)} Ray node(s):")
for node in nodes:
node_id = node.get("NodeID", "N/A")[:8] # Truncate for readability
ip_address = node.get("NodeManagerAddress", "N/A")
resources = node.get("Resources", {})
num_gpus = int(resources.get("GPU", 0))
print(f" • Node {node_id}... | IP: {ip_address} | GPUs: {num_gpus}")
# Show specific GPU IDs if available
gpu_ids = [k for k in resources.keys() if k.startswith("GPU_ID_")]
if gpu_ids:
print(f" GPU IDs: {', '.join(gpu_ids)}")
except Exception as e:
print(f"Error querying Ray cluster: {e}")
# Uncomment to display resources after cluster initialization
# print_ray_resources()
分散推論タスクを定義する
このノートブックでは、分散LLM推論に SGLang Runtime と Ray Data を使用しています。 SGLang は、効率的なプレフィックス キャッシュを実現する RadixAttention などの機能を備えた最適化されたランタイムを提供します。
アーキテクチャの概要
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
│ Ray Data │───▶│ SGLang Runtime │───▶│ Generated │
│ (Prompts) │ │ (map_batches) │ │ Outputs │
└─────────────────┘ └─────────────────┘ └─────────────────┘
│ │
▼ ▼
Distributed GPU Workers (A10)
across nodes with SGLang engines
SGLangの主な利点
- RadixAttention : リクエスト間での効率的なKVキャッシュの再利用
- 構造化生成 : 正規表現/文法による制約誘導デコード
- 最適化されたランタイム : 自動バッチ処理による高スループット
- マルチターンサポート :会話プロンプトの効率的な処理
from serverless_gpu.ray import ray_launch
import os
# Set Ray temp directory
os.environ['RAY_TEMP_DIR'] = '/tmp/ray'
# Set the UC Volumes temp directory for write_databricks_table
os.environ['_RAY_UC_VOLUMES_FUSE_TEMP_DIR'] = f"{UC_VOLUME_PATH}/ray_temp"
@ray_launch(gpus=NUM_GPUS, gpu_type='a10', remote=True)
def run_distributed_inference():
"""Run distributed LLM inference using Ray Data and SGLang Runtime."""
import ray
import numpy as np
from typing import Dict
import sglang as sgl
from datetime import datetime
# Sample prompts for inference
base_prompts = [
"Hello, my name is",
"The president of the United States is",
"The future of AI is",
]
# Scale up prompts for distributed processing
prompts = [{"prompt": p} for p in base_prompts * (NUM_PROMPTS // len(base_prompts))]
ds = ray.data.from_items(prompts)
print(f"✓ Created Ray dataset with {ds.count()} prompts")
# Define SGLang Predictor class for Ray Data map_batches
class SGLangPredictor:
"""SGLang-based predictor for batch inference with Ray Data."""
def __init__(self):
# Initialize SGLang Runtime inside the actor process
self.runtime = sgl.Runtime(
model_path=MODEL_NAME,
dtype="bfloat16",
trust_remote_code=True,
mem_fraction_static=0.85,
tp_size=1, # Tensor parallelism (1 GPU per worker)
)
# Set as default backend for the current process
sgl.set_default_backend(self.runtime)
print(f"✓ SGLang runtime initialized with model: {MODEL_NAME}")
def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
"""Process a batch of prompts using SGLang."""
input_prompts = batch["prompt"].tolist()
# Define SGLang prompt function
@sgl.function
def chat_completion(s, prompt):
s += sgl.system("You are a helpful AI assistant.")
s += sgl.user(prompt)
s += sgl.assistant(sgl.gen("response", max_tokens=100, temperature=0.8))
# Run batch inference
results = chat_completion.run_batch(
[{"prompt": p} for p in input_prompts],
progress_bar=False
)
generated_text = [r["response"] for r in results]
return {
"prompt": input_prompts,
"generated_text": generated_text,
"model": [MODEL_NAME] * len(input_prompts),
"timestamp": [datetime.now().isoformat()] * len(input_prompts),
}
def __del__(self):
"""Clean up runtime when actor dies."""
try:
self.runtime.shutdown()
except:
pass
print(f"✓ SGLang predictor configured with model: {MODEL_NAME}")
# Apply map_batches with SGLang predictor
ds = ds.map_batches(
SGLangPredictor,
concurrency=NUM_GPUS, # Number of parallel SGLang instances
batch_size=32,
num_gpus=1,
num_cpus=12,
)
# =========================================================================
# Write results to Parquet (stored in Unity Catalog Volume)
# =========================================================================
print(f"\n📦 Writing results to Parquet: {PARQUET_OUTPUT_PATH}")
ds.write_parquet(PARQUET_OUTPUT_PATH, mode="overwrite")
print(f"✓ Parquet files written successfully")
# Collect sample outputs for display
sample_outputs = ray.data.read_parquet(PARQUET_OUTPUT_PATH).take(limit=10)
print("\n" + "="*60)
print("SAMPLE INFERENCE RESULTS")
print("="*60 + "\n")
for i, output in enumerate(sample_outputs):
prompt = output.get("prompt", "N/A")
generated_text = output.get("generated_text", "")
display_text = generated_text[:100] if generated_text else "N/A"
print(f"[{i+1}] Prompt: {prompt!r}")
print(f" Generated: {display_text!r}...\n")
return PARQUET_OUTPUT_PATH
分散推論を実行する
次のセルは、複数の A10 GPU にわたって分散推論タスクを起動します。これにより、次のようになります。
- リモート A10 GPU ワーカーのプロビジョニング
- 各ワーカーのSGLangエンジンを初期化する
- Ray Data を使用してワーカー全体にプロンプトを配布する
- 生成された出力をParquetに保存する
注: GPU ノードがプロビジョニングされ、モデルがロードされるため、最初の起動には数分かかる場合があります。
result = run_distributed_inference.distributed()
parquet_path = result[0] if NUM_GPUS > 1 else result
print(f"\n✓ Inference complete! Results saved to: {parquet_path}")
Parquet をロードして Delta テーブルに変換する
次のセルは、 Sparkを使用してUnity CatalogボリュームからParquet出力を読み込み、効率的なクエリとガバナンスのためにDeltaテーブルとして保存します。
# Load Parquet data using Spark
print(f"📖 Loading Parquet from: {PARQUET_OUTPUT_PATH}")
df_spark = spark.read.parquet(PARQUET_OUTPUT_PATH)
# Show schema and row count
print(f"\n✓ Loaded {df_spark.count()} rows")
print("\nSchema:")
df_spark.printSchema()
# Display sample rows
print("\nSample Results:")
display(df_spark.limit(10))
Unity CatalogにDeltaテーブルとして保存
次のセルは、次の推論結果を Unity Catalog Delta テーブルに書き込みます。
- ガバナンス : データ改修とアクセス制御の追跡
- パフォーマンス : Delta Lakeによるクエリの最適化
- バージョン管理 : タイムトラベルと監査履歴
# Write to Unity Catalog Delta table
print(f"💾 Writing to Delta table: {UC_TABLE_NAME}")
# Write the DataFrame as a Delta table (overwrite mode)
df_spark.write \
.format("delta") \
.mode("overwrite") \
.option("overwriteSchema", "true") \
.saveAsTable(UC_TABLE_NAME)
print(f"✓ Delta table created successfully: {UC_TABLE_NAME}")
Deltaテーブルをクエリする
次のセルは、Delta テーブルが作成されたことを確認し、SQL を使用してクエリを実行します。
# Query the Delta table using SQL
print(f"📊 Querying Delta table: {UC_TABLE_NAME}\n")
# Get table info
display(spark.sql(f"DESCRIBE TABLE {UC_TABLE_NAME}"))
# Query sample results
print("\nSample Results from Delta Table:")
display(spark.sql(f"""
SELECT
prompt,
generated_text,
model,
timestamp
FROM {UC_TABLE_NAME}
LIMIT 10
"""))
# Get row count
row_count = spark.sql(f"SELECT COUNT(*) as count FROM {UC_TABLE_NAME}").collect()[0]["count"]
print(f"\n✓ Total rows in Delta table: {row_count}")
# Assert expected row count (NUM_PROMPTS should result in 999 rows: 1000 // 3 * 3 = 999)
expected_rows = (NUM_PROMPTS // 3) * 3 # Rounds down to nearest multiple of 3 base prompts
assert row_count == expected_rows, f"Expected {expected_rows} rows, but got {row_count}"
次のステップ
このノートブックは、 Databricksサーバレス GPU 上で Ray Data と SGLang を使用した分散LLM推論のデモンストレーションに成功し、結果はDeltaテーブルに保存されました。
何が達成されたか
- 複数のA10 GPUにわたって分散LLM推論を実行
- 最適化されたバッチ処理のために SGLang Runtimeを使用
- Unity CatalogボリュームのParquetに保存された結果
- ParquetをガバナンスされたDeltaテーブルに変換しました
カスタマイズオプション
- モデルの変更 : 別の Hugging Face モデルを使用するように
model_nameウィジェットを更新します - スケールアップ : スループットを向上させるには
num_gpusを増やします - バッチサイズを調整 : メモリ制約に基づいて
map_batchesのbatch_sizeを変更します - 生成の調整 : SGLang関数の
max_tokens、temperatureを調整します - 構造化出力 : JSON出力にSGLangの正規表現/文法制約を使用する
- マルチターンチャット : プロンプト関数で複数の
user/assistant呼び出しを連鎖します - 追加モード : 増分更新のDelta書き込みを
mode("append")に変更します
SGLangの高度な機能
# Structured JSON output with regex constraint
@sgl.function
def json_output(s, prompt):
s += sgl.user(prompt)
s += sgl.assistant(sgl.gen("response", regex=r'\{"name": "\w+", "age": \d+\}'))
# Multi-turn conversation
@sgl.function
def multi_turn(s, question1, question2):
s += sgl.user(question1)
s += sgl.assistant(sgl.gen("answer1"))
s += sgl.user(question2)
s += sgl.assistant(sgl.gen("answer2"))
代替: write_databricks_table
Unity Catalog 対応ワークスペースの場合は、 ray.data.Dataset.write_databricks_table()を使用して Unity Catalog テーブルに直接書き込みます。
# Set the temp directory environment variable
os.environ["_RAY_UC_VOLUMES_FUSE_TEMP_DIR"] = "/Volumes/catalog/schema/volume/ray_temp"
# Write directly to Unity Catalog table
ds.write_databricks_table(table_name="catalog.schema.table_name")
リソース
掃除
ノートブックが切断されると、GPU リソースは自動的にクリーンアップされます。手動で切断するには:
- 「コンピュート」ドロップダウンで 「接続済み」 をクリックします。
- サーバレス の上にマウスを移動します
- ドロップダウンメニューから 「終了」 を選択します