メインコンテンツまでスキップ

サーバーレス GPU 上で Ray Data と SGLang を使用した分散LLM推論

このノートブックでは、 Databricksサーバレス GPU 上で Ray DataSGLang を使用して大規模言語モデル ( LLM ) 推論を大規模に実行する方法を示します。 分散サーバレス GPU APIを利用して、分散推論用のマルチノード A10 GPU を自動的にプロビジョニングおよび管理します。

このノートブックの内容:

  • 分散LLM推論用にRayとSGLangを設定する
  • Ray Data を使用して、複数の GPU 間でプロンプトを効率的にバッチ処理します。
  • 構造化された生成によるSGLangプロンプト関数の定義
  • Unity CatalogボリュームのParquetに推論結果を保存する
  • ガバナンスと効率的なクエリのために Parquet テーブルを Delta テーブルに変換する

ユースケース: 効率的な GPU 使用率、SGLang の最適化されたランタイム、Delta Lake 統合により、数千のプロンプトをバッチ推論します。

要件

A10アクセラレータによるサーバレスGPUコンピュート

ノートブックをサーバレス GPU コンピュートに接続します。

  1. 上部の 「接続」 ドロップダウンをクリックします。
  2. サーバレス GPU を選択します。
  3. ノートブックの右側にある 環境 サイドパネルを開きます。
  4. アクセラレータを A10 に設定します。
  5. [適用][確認] をクリックします。

注: 分散関数は、マルチノード推論のためにリモート A10 GPU を起動します。ノートブック自体は、オーケストレーション用の単一の A10 上で実行されます。

依存関係をインストールする

次のセルは、分散 Ray および SGLang 推論に必要なすべてのパッケージをインストールします。

  • Flash Attention : 推論を高速化するための最適化されたアテンション (CUDA 12、PyTorch 2.6、A10 互換)
  • SGLang : 高性能 LLM 推論およびサービス フレームワーク
  • Ray Data : バッチ推論のための分散データ処理
  • hf_transfer : 高速Hugging Faceモデルのダウンロード
Python
# Pre-compiled Flash Attention for A10s (Essential for speed/compilation)
%pip install --no-cache-dir "torch==2.9.1+cu128" --index-url https://download.pytorch.org/whl/cu128
%pip install -U --no-cache-dir wheel ninja packaging
%pip install --force-reinstall --no-cache-dir --no-build-isolation flash-attn
%pip install hf_transfer
%pip install "ray[data]>=2.47.1"

# SGLang with all dependencies (handles vLLM/Torch automatically)
%pip install "sglang[all]>=0.4.7"

%restart_python

パッケージのバージョンを確認する

次のセルは、必要なすべてのパッケージが互換性のあるバージョンでインストールされていることを確認します。

Python
from packaging.version import Version

import torch
import flash_attn
import sglang
import ray

print(f"PyTorch: {torch.__version__}")
print(f"Flash Attention: {flash_attn.__version__}")
print(f"SGLang: {sglang.__version__}")
print(f"Ray: {ray.__version__}")

assert Version(ray.__version__) >= Version("2.47.1"), "Ray version must be at least 2.47.1"
print("\n✓ All version checks passed!")

構成

ウィジェットを使用して推論とオプションのHugging Face認証を設定します。

セキュリティに関する注意: 本番運用で使用するために、 Hugging FaceをDatabricks Secrets に保存します。 Databricks Secrets のドキュメントを参照してください。

Python
# Widget configuration
dbutils.widgets.text("hf_secret_scope", "")
dbutils.widgets.text("hf_secret_key", "")
dbutils.widgets.text("model_name", "Qwen/Qwen3-4B-Instruct-2507")
dbutils.widgets.text("num_gpus", "5")
dbutils.widgets.text("num_prompts", "1000")

# Unity Catalog configuration for output storage
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("uc_volume", "ray_data")
dbutils.widgets.text("uc_table", "sglang_inference_results")

# Retrieve widget values
HF_SECRET_SCOPE = dbutils.widgets.get("hf_secret_scope")
HF_SECRET_KEY = dbutils.widgets.get("hf_secret_key")
MODEL_NAME = dbutils.widgets.get("model_name")
NUM_GPUS = int(dbutils.widgets.get("num_gpus"))
NUM_PROMPTS = int(dbutils.widgets.get("num_prompts"))

# Unity Catalog paths
UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_VOLUME = dbutils.widgets.get("uc_volume")
UC_TABLE = dbutils.widgets.get("uc_table")

# Construct paths
UC_VOLUME_PATH = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}"
UC_TABLE_NAME = f"{UC_CATALOG}.{UC_SCHEMA}.{UC_TABLE}"
PARQUET_OUTPUT_PATH = f"{UC_VOLUME_PATH}/sglang_inference_output"

print(f"Model: {MODEL_NAME}")
print(f"Number of GPUs: {NUM_GPUS}")
print(f"Number of prompts: {NUM_PROMPTS}")
print(f"\nUnity Catalog Configuration:")
print(f" Volume Path: {UC_VOLUME_PATH}")
print(f" Table Name: {UC_TABLE_NAME}")
print(f" Parquet Output: {PARQUET_OUTPUT_PATH}")

Hugging Faceで認証する(オプション)

ゲート付きモデル (Llama など) を使用する場合は、Hugging Face で認証します。

オプション 1: Databricks Secret を使用する (本番運用に推奨)

Python
hf_token = dbutils.secrets.get(scope=HF_SECRET_SCOPE, key=HF_SECRET_KEY)
login(token=hf_token)

オプション2: 対話型ログイン(開発用)

Python
from huggingface_hub import login

# Uncomment ONE of the following options:

# Option 1: Use Databricks Secrets (recommended)
# if HF_SECRET_SCOPE and HF_SECRET_KEY:
# hf_token = dbutils.secrets.get(scope=HF_SECRET_SCOPE, key=HF_SECRET_KEY)
# login(token=hf_token)
# print("✓ Logged in using Databricks Secrets")

# Option 2: Interactive login
login()
print("✓ Hugging Face authentication complete")

Unity Catalogリソースをセットアップする

次のセルは、推論結果を保存するために必要なUnity Catalogリソース (カタログ、スキーマ、ボリューム) を作成します。 これらのリソースは、ガバナンス、リネージ追跡、生成された出力の集中ストレージを提供します。

Python
# Unity Catalog Setup and Dataset Download
# ⚠️ IMPORTANT: Run this cell BEFORE the dataset processing cell
# to set up Unity Catalog resources and download the raw datasets.

import os
from databricks.sdk import WorkspaceClient
from databricks.sdk.service import catalog
import requests

# Create Unity Catalog resources
w = WorkspaceClient()

# Create catalog if it doesn't exist
try:
created_catalog = w.catalogs.create(name=UC_CATALOG)
print(f"Catalog '{created_catalog.name}' created successfully")
except Exception as e:
print(f"Catalog '{UC_CATALOG}' already exists or error: {e}")

# Create schema if it doesn't exist
try:
created_schema = w.schemas.create(name=UC_SCHEMA, catalog_name=UC_CATALOG)
print(f"Schema '{created_schema.name}' created successfully")
except Exception as e:
print(f"Schema '{UC_SCHEMA}' already exists in catalog '{UC_CATALOG}' or error: {e}")

# Create volume if it doesn't exist
volume_path = f'/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}'
if not os.path.exists(volume_path):
try:
created_volume = w.volumes.create(
catalog_name=UC_CATALOG,
schema_name=UC_SCHEMA,
name=UC_VOLUME,
volume_type=catalog.VolumeType.MANAGED
)
print(f"Volume '{created_volume.name}' created successfully")
except Exception as e:
print(f"Volume '{UC_VOLUME}' already exists or error: {e}")
else:
print(f"Volume {volume_path} already exists")

レイクラスターリソースモニタリング

次のセルは、Ray クラスターのリソースを検査し、ノード間の GPU 割り当てを確認するユーティリティ関数を定義します。

Python
import json
import ray

def print_ray_resources():
"""Print Ray cluster resources and GPU allocation per node."""
try:
cluster_resources = ray.cluster_resources()
print("Ray Cluster Resources:")
print(json.dumps(cluster_resources, indent=2))

nodes = ray.nodes()
print(f"\nDetected {len(nodes)} Ray node(s):")

for node in nodes:
node_id = node.get("NodeID", "N/A")[:8] # Truncate for readability
ip_address = node.get("NodeManagerAddress", "N/A")
resources = node.get("Resources", {})
num_gpus = int(resources.get("GPU", 0))

print(f" • Node {node_id}... | IP: {ip_address} | GPUs: {num_gpus}")

# Show specific GPU IDs if available
gpu_ids = [k for k in resources.keys() if k.startswith("GPU_ID_")]
if gpu_ids:
print(f" GPU IDs: {', '.join(gpu_ids)}")

except Exception as e:
print(f"Error querying Ray cluster: {e}")

# Uncomment to display resources after cluster initialization
# print_ray_resources()

分散推論タスクを定義する

このノートブックでは、分散LLM推論に SGLang RuntimeRay Data を使用しています。 SGLang は、効率的なプレフィックス キャッシュを実現する RadixAttention などの機能を備えた最適化されたランタイムを提供します。

アーキテクチャの概要

┌─────────────────┐    ┌─────────────────┐    ┌─────────────────┐
│ Ray Data │───▶│ SGLang Runtime │───▶│ Generated │
│ (Prompts) │ │ (map_batches) │ │ Outputs │
└─────────────────┘ └─────────────────┘ └─────────────────┘
│ │
▼ ▼
Distributed GPU Workers (A10)
across nodes with SGLang engines

SGLangの主な利点

  • RadixAttention : リクエスト間での効率的なKVキャッシュの再利用
  • 構造化生成 : 正規表現/文法による制約誘導デコード
  • 最適化されたランタイム : 自動バッチ処理による高スループット
  • マルチターンサポート :会話プロンプトの効率的な処理
Python
from serverless_gpu.ray import ray_launch
import os

# Set Ray temp directory
os.environ['RAY_TEMP_DIR'] = '/tmp/ray'

# Set the UC Volumes temp directory for write_databricks_table
os.environ['_RAY_UC_VOLUMES_FUSE_TEMP_DIR'] = f"{UC_VOLUME_PATH}/ray_temp"

@ray_launch(gpus=NUM_GPUS, gpu_type='a10', remote=True)
def run_distributed_inference():
"""Run distributed LLM inference using Ray Data and SGLang Runtime."""
import ray
import numpy as np
from typing import Dict
import sglang as sgl
from datetime import datetime

# Sample prompts for inference
base_prompts = [
"Hello, my name is",
"The president of the United States is",
"The future of AI is",
]

# Scale up prompts for distributed processing
prompts = [{"prompt": p} for p in base_prompts * (NUM_PROMPTS // len(base_prompts))]
ds = ray.data.from_items(prompts)

print(f"✓ Created Ray dataset with {ds.count()} prompts")

# Define SGLang Predictor class for Ray Data map_batches
class SGLangPredictor:
"""SGLang-based predictor for batch inference with Ray Data."""

def __init__(self):
# Initialize SGLang Runtime inside the actor process
self.runtime = sgl.Runtime(
model_path=MODEL_NAME,
dtype="bfloat16",
trust_remote_code=True,
mem_fraction_static=0.85,
tp_size=1, # Tensor parallelism (1 GPU per worker)
)
# Set as default backend for the current process
sgl.set_default_backend(self.runtime)
print(f"✓ SGLang runtime initialized with model: {MODEL_NAME}")

def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
"""Process a batch of prompts using SGLang."""
input_prompts = batch["prompt"].tolist()

# Define SGLang prompt function
@sgl.function
def chat_completion(s, prompt):
s += sgl.system("You are a helpful AI assistant.")
s += sgl.user(prompt)
s += sgl.assistant(sgl.gen("response", max_tokens=100, temperature=0.8))

# Run batch inference
results = chat_completion.run_batch(
[{"prompt": p} for p in input_prompts],
progress_bar=False
)

generated_text = [r["response"] for r in results]

return {
"prompt": input_prompts,
"generated_text": generated_text,
"model": [MODEL_NAME] * len(input_prompts),
"timestamp": [datetime.now().isoformat()] * len(input_prompts),
}

def __del__(self):
"""Clean up runtime when actor dies."""
try:
self.runtime.shutdown()
except:
pass

print(f"✓ SGLang predictor configured with model: {MODEL_NAME}")

# Apply map_batches with SGLang predictor
ds = ds.map_batches(
SGLangPredictor,
concurrency=NUM_GPUS, # Number of parallel SGLang instances
batch_size=32,
num_gpus=1,
num_cpus=12,
)

# =========================================================================
# Write results to Parquet (stored in Unity Catalog Volume)
# =========================================================================
print(f"\n📦 Writing results to Parquet: {PARQUET_OUTPUT_PATH}")
ds.write_parquet(PARQUET_OUTPUT_PATH, mode="overwrite")
print(f"✓ Parquet files written successfully")

# Collect sample outputs for display
sample_outputs = ray.data.read_parquet(PARQUET_OUTPUT_PATH).take(limit=10)

print("\n" + "="*60)
print("SAMPLE INFERENCE RESULTS")
print("="*60 + "\n")

for i, output in enumerate(sample_outputs):
prompt = output.get("prompt", "N/A")
generated_text = output.get("generated_text", "")
display_text = generated_text[:100] if generated_text else "N/A"
print(f"[{i+1}] Prompt: {prompt!r}")
print(f" Generated: {display_text!r}...\n")

return PARQUET_OUTPUT_PATH

分散推論を実行する

次のセルは、複数の A10 GPU にわたって分散推論タスクを起動します。これにより、次のようになります。

  1. リモート A10 GPU ワーカーのプロビジョニング
  2. 各ワーカーのSGLangエンジンを初期化する
  3. Ray Data を使用してワーカー全体にプロンプトを配布する
  4. 生成された出力をParquetに保存する

注: GPU ノードがプロビジョニングされ、モデルがロードされるため、最初の起動には数分かかる場合があります。

Python
result = run_distributed_inference.distributed()
parquet_path = result[0] if NUM_GPUS > 1 else result
print(f"\n✓ Inference complete! Results saved to: {parquet_path}")

Parquet をロードして Delta テーブルに変換する

次のセルは、 Sparkを使用してUnity CatalogボリュームからParquet出力を読み込み、効率的なクエリとガバナンスのためにDeltaテーブルとして保存します。

Python
# Load Parquet data using Spark
print(f"📖 Loading Parquet from: {PARQUET_OUTPUT_PATH}")
df_spark = spark.read.parquet(PARQUET_OUTPUT_PATH)

# Show schema and row count
print(f"\n✓ Loaded {df_spark.count()} rows")
print("\nSchema:")
df_spark.printSchema()

# Display sample rows
print("\nSample Results:")
display(df_spark.limit(10))

Unity CatalogにDeltaテーブルとして保存

次のセルは、次の推論結果を Unity Catalog Delta テーブルに書き込みます。

  • ガバナンス : データ改修とアクセス制御の追跡
  • パフォーマンス : Delta Lakeによるクエリの最適化
  • バージョン管理 : タイムトラベルと監査履歴
Python
# Write to Unity Catalog Delta table
print(f"💾 Writing to Delta table: {UC_TABLE_NAME}")

# Write the DataFrame as a Delta table (overwrite mode)
df_spark.write \
.format("delta") \
.mode("overwrite") \
.option("overwriteSchema", "true") \
.saveAsTable(UC_TABLE_NAME)

print(f"✓ Delta table created successfully: {UC_TABLE_NAME}")

Deltaテーブルをクエリする

次のセルは、Delta テーブルが作成されたことを確認し、SQL を使用してクエリを実行します。

Python
# Query the Delta table using SQL
print(f"📊 Querying Delta table: {UC_TABLE_NAME}\n")

# Get table info
display(spark.sql(f"DESCRIBE TABLE {UC_TABLE_NAME}"))

# Query sample results
print("\nSample Results from Delta Table:")
display(spark.sql(f"""
SELECT
prompt,
generated_text,
model,
timestamp
FROM {UC_TABLE_NAME}
LIMIT 10
"""))

# Get row count
row_count = spark.sql(f"SELECT COUNT(*) as count FROM {UC_TABLE_NAME}").collect()[0]["count"]
print(f"\n✓ Total rows in Delta table: {row_count}")

# Assert expected row count (NUM_PROMPTS should result in 999 rows: 1000 // 3 * 3 = 999)
expected_rows = (NUM_PROMPTS // 3) * 3 # Rounds down to nearest multiple of 3 base prompts
assert row_count == expected_rows, f"Expected {expected_rows} rows, but got {row_count}"

次のステップ

このノートブックは、 Databricksサーバレス GPU 上で Ray Data と SGLang を使用した分散LLM推論のデモンストレーションに成功し、結果はDeltaテーブルに保存されました。

何が達成されたか

  • 複数のA10 GPUにわたって分散LLM推論を実行
  • 最適化されたバッチ処理のために SGLang Runtimeを使用
  • Unity CatalogボリュームのParquetに保存された結果
  • ParquetをガバナンスされたDeltaテーブルに変換しました

カスタマイズオプション

  • モデルの変更 : 別の Hugging Face モデルを使用するようにmodel_nameウィジェットを更新します
  • スケールアップ : スループットを向上させるにはnum_gpusを増やします
  • バッチサイズを調整 : メモリ制約に基づいてmap_batchesbatch_sizeを変更します
  • 生成の調整 : SGLang関数のmax_tokenstemperatureを調整します
  • 構造化出力 : JSON出力にSGLangの正規表現/文法制約を使用する
  • マルチターンチャット : プロンプト関数で複数のuser / assistant呼び出しを連鎖します
  • 追加モード : 増分更新のDelta書き込みをmode("append")に変更します

SGLangの高度な機能

Python
# Structured JSON output with regex constraint
@sgl.function
def json_output(s, prompt):
s += sgl.user(prompt)
s += sgl.assistant(sgl.gen("response", regex=r'\{"name": "\w+", "age": \d+\}'))

# Multi-turn conversation
@sgl.function
def multi_turn(s, question1, question2):
s += sgl.user(question1)
s += sgl.assistant(sgl.gen("answer1"))
s += sgl.user(question2)
s += sgl.assistant(sgl.gen("answer2"))

代替: write_databricks_table

Unity Catalog 対応ワークスペースの場合は、 ray.data.Dataset.write_databricks_table()を使用して Unity Catalog テーブルに直接書き込みます。

Python
# Set the temp directory environment variable
os.environ["_RAY_UC_VOLUMES_FUSE_TEMP_DIR"] = "/Volumes/catalog/schema/volume/ray_temp"

# Write directly to Unity Catalog table
ds.write_databricks_table(table_name="catalog.schema.table_name")

リソース

掃除

ノートブックが切断されると、GPU リソースは自動的にクリーンアップされます。手動で切断するには:

  1. 「コンピュート」ドロップダウンで 「接続済み」 をクリックします。
  2. サーバレス の上にマウスを移動します
  3. ドロップダウンメニューから 「終了」 を選択します

サンプルノートブック

サーバーレス GPU 上で Ray Data と SGLang を使用した分散LLM推論

ノートブックを新しいタブで開く