サーバーレスGPU上でRay DataとvLLMを使用した分散LLM推論
このノートブックでは、 Databricksサーバレス GPU 上で Ray Data と vLLM を使用して大規模言語モデル ( LLM ) 推論を大規模に実行する方法を示します。 分散サーバレス GPU APIを利用して、分散推論用のマルチノード A10 GPU を自動的にプロビジョニングおよび管理します。
学習内容:
- 分散LLM推論用にRayとvLLMを設定する
map_batches - Ray Data を使用して、複数の GPU 間でプロンプトを効率的にバッチ処理します。
- Unity CatalogボリュームのParquetに推論結果を保存する
- ガバナンスと効率的なクエリのために Parquet テーブルを Delta テーブルに変換する
- Ray クラスターのリソースと GPU の割り当てを監視する
ユースケース: 効率的な GPU 使用率、永続ストレージ、Delta Lake 統合により、数千のプロンプトをバッチ推論します。
サーバレスGPUコンピュートに接続する
- 上部の 「接続」 ドロップダウンをクリックします。
- サーバレス GPU を選択します。
- ノートブックの右側にある 環境 サイドパネルを開きます。
- このデモでは、 アクセラレータ を A10 に設定します。
- 「適用」 と 「確認」 をクリックして、この環境をノートブックに適用します。
注: 分散関数は、マルチノード推論のためにリモート A10 GPU を起動します。ノートブック自体は、オーケストレーション用の単一の A10 上で実行されます。
依存関係をインストールする
分散 Ray および vLLM 推論に必要なすべてのパッケージをインストールします。
- Flash Attention : 推論を高速化するための最適化されたアテンション (CUDA 12、PyTorch 2.6、A10 互換)
- vLLM : ハイスループットLLM推論エンジン
- Ray Data : LLM サポートによる分散データ処理 (
ray.data.llmAPI) - トランスフォーマー : Hugging Faceモデル読み込みユーティリティ
%pip install --force-reinstall --no-cache-dir --no-deps "https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_x86_64.whl"
%pip install "transformers<4.54.0"
%pip install "vllm==0.8.5.post1"
%pip install "ray[data]>=2.47.1" # Required for ray.data.llm API
%pip install "opentelemetry-exporter-prometheus"
%pip install "optree>=0.13.0"
%pip install hf_transfer
%pip install "numpy==1.26.4"
%restart_python
パッケージのバージョンを確認する
必要なすべてのパッケージが互換性のあるバージョンでインストールされていることを確認します。
from packaging.version import Version
import torch
import flash_attn
import vllm
import ray
import transformers
print(f"PyTorch: {torch.__version__}")
print(f"Flash Attention: {flash_attn.__version__}")
print(f"vLLM: {vllm.__version__}")
print(f"Ray: {ray.__version__}")
print(f"Transformers: {transformers.__version__}")
assert Version(ray.__version__) >= Version("2.47.1"), "Ray version must be at least 2.47.1"
print("\n✓ All version checks passed!")
構成
ウィジェットを使用して推論とオプションのHugging Face認証を設定します。
セキュリティに関する注意: 本番運用で使用するために、 Hugging FaceをDatabricks Secrets に保存します。 Databricks Secrets のドキュメントを参照してください。
# Widget configuration
dbutils.widgets.text("hf_secret_scope", "")
dbutils.widgets.text("hf_secret_key", "")
dbutils.widgets.text("model_name", "Qwen/Qwen3-4B-Instruct-2507")
dbutils.widgets.text("num_gpus", "5")
dbutils.widgets.text("num_prompts", "1000")
# Unity Catalog configuration for output storage
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("uc_volume", "ray_data")
dbutils.widgets.text("uc_table", "llm_inference_results")
# Retrieve widget values
HF_SECRET_SCOPE = dbutils.widgets.get("hf_secret_scope")
HF_SECRET_KEY = dbutils.widgets.get("hf_secret_key")
MODEL_NAME = dbutils.widgets.get("model_name")
NUM_GPUS = int(dbutils.widgets.get("num_gpus"))
NUM_PROMPTS = int(dbutils.widgets.get("num_prompts"))
# Unity Catalog paths
UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_VOLUME = dbutils.widgets.get("uc_volume")
UC_TABLE = dbutils.widgets.get("uc_table")
# Construct paths
UC_VOLUME_PATH = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}"
UC_TABLE_NAME = f"{UC_CATALOG}.{UC_SCHEMA}.{UC_TABLE}"
PARQUET_OUTPUT_PATH = f"{UC_VOLUME_PATH}/inference_output"
print(f"Model: {MODEL_NAME}")
print(f"Number of GPUs: {NUM_GPUS}")
print(f"Number of prompts: {NUM_PROMPTS}")
print(f"\nUnity Catalog Configuration:")
print(f" Volume Path: {UC_VOLUME_PATH}")
print(f" Table Name: {UC_TABLE_NAME}")
print(f" Parquet Output: {PARQUET_OUTPUT_PATH}")
Hugging Faceで認証する(オプション)
ゲート付きモデル (Llama など) を使用する場合は、Hugging Face で認証します。
オプション 1: Databricks Secret を使用する (本番運用に推奨)
hf_token = dbutils.secrets.get(scope=HF_SECRET_SCOPE, key=HF_SECRET_KEY)
オプション2: 対話型ログイン(開発用)
from huggingface_hub import login
# Uncomment ONE of the following options:
# Option 1: Use Databricks Secrets (recommended)
# if HF_SECRET_SCOPE and HF_SECRET_KEY:
# hf_token = dbutils.secrets.get(scope=HF_SECRET_SCOPE, key=HF_SECRET_KEY)
# login(token=hf_token)
# print("✓ Logged in using Databricks Secrets")
# Option 2: Interactive login
login()
print("✓ Hugging Face authentication complete")
レイクラスターリソースモニタリング
Ray クラスターのリソースを検査し、ノード間の GPU 割り当てを確認するユーティリティ関数。
import json
import ray
def print_ray_resources():
"""Print Ray cluster resources and GPU allocation per node."""
try:
cluster_resources = ray.cluster_resources()
print("Ray Cluster Resources:")
print(json.dumps(cluster_resources, indent=2))
nodes = ray.nodes()
print(f"\nDetected {len(nodes)} Ray node(s):")
for node in nodes:
node_id = node.get("NodeID", "N/A")[:8] # Truncate for readability
ip_address = node.get("NodeManagerAddress", "N/A")
resources = node.get("Resources", {})
num_gpus = int(resources.get("GPU", 0))
print(f" • Node {node_id}... | IP: {ip_address} | GPUs: {num_gpus}")
# Show specific GPU IDs if available
gpu_ids = [k for k in resources.keys() if k.startswith("GPU_ID_")]
if gpu_ids:
print(f" GPU IDs: {', '.join(gpu_ids)}")
except Exception as e:
print(f"Error querying Ray cluster: {e}")
# Display current resources
# print_ray_resources()
分散推論タスクを定義する
LLMPredictorクラスは、効率的なバッチ推論のために vLLM をラップします。Ray Data はmap_batchesを使用して、ワークロードを複数の GPU ワーカーに分散します。
アーキテクチャの概要
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
│ Ray Data │───▶│ LLMPredictor │───▶│ Parquet │
│ (Prompts) │ │ (vLLM Engine) │ │ (UC Volume) │
└─────────────────┘ └─────────────────┘ └─────────────────┘
│ │ │
▼ ▼ ▼
Distributed GPU Workers (A10) Delta Table
across nodes with vLLM instances (UC Table)
from serverless_gpu.ray import ray_launch
import os
# Set Ray temp directory
os.environ['RAY_TEMP_DIR'] = '/tmp/ray'
@ray_launch(gpus=NUM_GPUS, gpu_type='a10', remote=True)
def run_distributed_inference():
"""Run distributed LLM inference using Ray Data and vLLM with map_batches."""
from typing import Dict, List
from datetime import datetime
import numpy as np
import ray
from vllm import LLM, SamplingParams
# Sample prompts for inference
base_prompts = [
"Hello, my name is",
"The president of the United States is",
"The future of AI is",
]
# Scale up prompts for distributed processing
prompts = base_prompts * (NUM_PROMPTS // len(base_prompts))
ds = ray.data.from_items(prompts)
print(f"✓ Created Ray dataset with {ds.count()} prompts")
# Sampling parameters for text generation
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=100
)
class LLMPredictor:
"""vLLM-based predictor for batch inference."""
def __init__(self):
self.llm = LLM(
model=MODEL_NAME,
tensor_parallel_size=1,
dtype="bfloat16",
trust_remote_code=True,
gpu_memory_utilization=0.90,
max_model_len=8192,
enable_prefix_caching=True,
enable_chunked_prefill=True,
max_num_batched_tokens=8192,
)
self.model_name = MODEL_NAME
print(f"✓ vLLM engine initialized with model: {MODEL_NAME}")
def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
"""Process a batch of prompts."""
outputs = self.llm.generate(batch["item"], sampling_params)
prompt_list: List[str] = []
generated_text_list: List[str] = []
model_list: List[str] = []
timestamp_list: List[str] = []
for output in outputs:
prompt_list.append(output.prompt)
generated_text_list.append(
' '.join([o.text for o in output.outputs])
)
model_list.append(self.model_name)
timestamp_list.append(datetime.now().isoformat())
return {
"prompt": prompt_list,
"generated_text": generated_text_list,
"model": model_list,
"timestamp": timestamp_list,
}
# Configure number of parallel vLLM instances
num_instances = NUM_GPUS
# Apply the predictor across the dataset using map_batches
ds = ds.map_batches(
LLMPredictor,
concurrency=num_instances,
batch_size=32,
num_gpus=1,
num_cpus=12
)
# =========================================================================
# Write results to Parquet (stored in Unity Catalog Volume)
# =========================================================================
print(f"\n📦 Writing results to Parquet: {PARQUET_OUTPUT_PATH}")
ds.write_parquet(PARQUET_OUTPUT_PATH, mode="overwrite")
print(f"✓ Parquet files written successfully")
# Collect sample outputs for display
sample_outputs = ray.data.read_parquet(PARQUET_OUTPUT_PATH).take(limit=10)
print("\n" + "="*60)
print("SAMPLE INFERENCE RESULTS")
print("="*60 + "\n")
for i, output in enumerate(sample_outputs):
prompt = output.get("prompt", "N/A")
generated_text = output.get("generated_text", "")
display_text = generated_text[:100] if generated_text else "N/A"
print(f"[{i+1}] Prompt: {prompt!r}")
print(f" Generated: {display_text!r}...\n")
return PARQUET_OUTPUT_PATH
分散推論を実行する
複数の A10 GPU にわたって分散推論タスクを起動します。これにより、次のようになります。
- リモート A10 GPU ワーカーのプロビジョニング
- 各ワーカーのvLLMエンジンを初期化する
- Ray Data を使用してワーカー全体にプロンプトを配布する
- 生成された出力を収集して返す
注: GPU ノードがプロビジョニングされ、モデルがロードされるため、最初の起動には数分かかる場合があります。
result = run_distributed_inference.distributed()
parquet_path = result[0] if NUM_GPUS > 1 else result
print(f"\n✓ Inference complete! Results saved to: {parquet_path}")
Parquet をロードして結果をプレビューする
Sparkを使用してUnity CatalogボリュームからParquet出力を読み込み、推論結果をプレビューします。
# Load Parquet data using Spark
print(f"📖 Loading Parquet from: {PARQUET_OUTPUT_PATH}")
df_spark = spark.read.parquet(PARQUET_OUTPUT_PATH)
# Show schema and row count
print(f"\n✓ Loaded {df_spark.count()} rows")
print("\nSchema:")
df_spark.printSchema()
# Display sample rows
print("\nSample Results:")
display(df_spark.limit(10))
Unity CatalogにDeltaテーブルとして保存
次の推論結果を Unity Catalog Delta テーブルに書き込みます。
- ガバナンス : データ改修とアクセス制御の追跡
- パフォーマンス : Delta Lakeによるクエリの最適化
- バージョン管理 : タイムトラベルと監査履歴
# Write to Unity Catalog Delta table
print(f"💾 Writing to Delta table: {UC_TABLE_NAME}")
# Write the DataFrame as a Delta table (overwrite mode)
df_spark.write \
.format("delta") \
.mode("overwrite") \
.option("overwriteSchema", "true") \
.saveAsTable(UC_TABLE_NAME)
print(f"✓ Delta table created successfully: {UC_TABLE_NAME}")
Deltaテーブルをクエリする
Delta テーブルが作成されたことを確認し、SQL を使用してクエリを実行します。
# Query the Delta table using SQL
print(f"📊 Querying Delta table: {UC_TABLE_NAME}\n")
# Get table info
display(spark.sql(f"DESCRIBE TABLE {UC_TABLE_NAME}"))
# Query sample results
print("\nSample Results from Delta Table:")
display(spark.sql(f"""
SELECT
prompt,
generated_text,
model,
timestamp
FROM {UC_TABLE_NAME}
LIMIT 10
"""))
# Get row count and verify correctness
row_count = spark.sql(f"SELECT COUNT(*) as count FROM {UC_TABLE_NAME}").collect()[0]["count"]
print(f"\n✓ Total rows in Delta table: {row_count}")
# Assert expected row count (NUM_PROMPTS should result in 999 rows: 1000 // 3 * 3 = 999)
expected_rows = (NUM_PROMPTS // 3) * 3 # Rounds down to nearest multiple of 3 base prompts
assert row_count == expected_rows, f"Expected {expected_rows} rows, but got {row_count}"
次のステップ
Databricks GPU 上で Ray Data と vLLM を使用して分散LLM推論を正常に実行し、結果をDeltaテーブルに保存しました。
達成したこと
- ✅ 複数の A10 GPU にわたって分散 LLM 推論を実行しました
- ✅ バッチ処理のためにカスタム
LLMPredictorクラスでmap_batchesを使用しました - ✅ Unity CatalogボリュームのParquetに結果を保存しました
- ✅ Parquet をガバナンスされた Delta テーブルに変換しました
カスタマイズオプション
- モデルの変更 : 別の Hugging Face モデルを使用するように
model_nameウィジェットを更新します - スケールアップ : スループットを向上させるには
num_gpusを増やします - バッチサイズの調整 : メモリ制約に基づいて
map_batches()のbatch_sizeを変更します - チューニング生成 : 異なる出力特性に合わせて
SamplingParamsを調整します - 追加モード : 増分更新のDelta書き込みを
mode("append")に変更します
リソース
掃除
ノートブックが切断されると、GPU リソースは自動的にクリーンアップされます。手動で切断するには:
- 「コンピュート」ドロップダウンで 「接続済み」 をクリックします。
- サーバレス の上にマウスを移動します
- ドロップダウンメニューから 「終了」 を選択します