メインコンテンツまでスキップ

サーバレスGPUコンピュート上のPyTorch FSDPを使用した分散トレーニング

このノートブックでは、 Databricksサーバーレス GPU コンピュート上でPyTorchの Fully Sharded Data Parallel (FSDP) による分散トレーニングを使用してTransformerモデルをトレーニングする方法を示します。 FSDP は、モデルの分散、勾配、およびオプティマイザーの状態を複数の GPU に分割するデータ並列技術で、単一の GPU に収まらない大規模なモデルの効率的なトレーニングを可能にします。

この例では、次の方法を学習します。

  • サーバレス GPU 分散トレーニングAPIを使用して分散トレーニングをセットアップする
  • FSDP を使用して 10M の損失Transformerモデルを定義およびトレーニングする
  • トレーニング中に分散チェックポイントを保存する
  • MLflowを使用してエクスペリメントを追跡する
  • 推論または継続トレーニングのためのチェックポイントをロードする

このノートブックは合成データを使用して自己完結性を維持していますが、独自のデータセットで動作するように適応させることができます。

重要な概念:

  • FSDP (Fully Sharded Data Parallel) : GPU 全体でモデルをシャーディングしてメモリ使用量を削減し、より大規模なモデルのトレーニングを可能にするPyTorch分散トレーニング戦略。
  • サーバーレス GPU コンピュート : ワークロードに合わせて自動的にスケーリングし、プロビジョニング リソースを提供するDatabricksマネージド GPU コンピュート。

詳細については、 「マルチ GPU およびマルチノード分散トレーニング」を参照してください。

依存関係をインストールする

エクスペリメントの追跡とモデルのログ記録のために、 MLflowの最新バージョンをインストールします。

Python
%pip install -U mlflow
%restart_python

Unity Catalog場所を構成する

モデルとチェックポイントが保存されるUnity Catalog場所を設定します。 ワークスペースの構成に合わせてこれらの値を更新します。指定されたカタログとスキーマに対するUSE CATALOGおよびUSE SCHEMA権限が必要です。

Python
# You must have `USE CATALOG` privileges on the catalog, and you must have `USE SCHEMA` privileges on the schema.
# If necessary, change the catalog and schema name here.
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("model_name", "transformer_fsdp")
dbutils.widgets.text("uc_volume", "checkpoints")

UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_VOLUME = dbutils.widgets.get("uc_volume")
MODEL_NAME = dbutils.widgets.get("model_name")
UC_MODEL_NAME = f"{UC_CATALOG}.{UC_SCHEMA}.{MODEL_NAME}"

print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_VOLUME: {UC_VOLUME}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")

ヘルパー関数と合成データセットを定義する

このセクションでは、分散トレーニング セットアップのユーティリティ関数と、デモンストレーション用の合成データセット クラスを定義します。本番運用では、 SyntheticDatasetを独自のデータ読み込みロジックに置き換えます。

主なコンポーネント:

  • setup(): 分散トレーニングプロセスグループを初期化し、GPUデバイスを構成します
  • cleanup(): トレーニング後に分散プロセスグループをクリーンアップします
  • AppState: PyTorch の分散チェックポイント API と互換性のある、チェックポイント モデルとオプティマイザーの状態のラッパー クラス
  • SyntheticDataset: トレーニングデモンストレーション用のランダムデータを生成します
Python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint import FileSystemWriter as StorageWriter
import torch.multiprocessing as mp
from torch.distributed.fsdp import fully_shard
from torch.utils.data import Dataset, DataLoader, DistributedSampler
import numpy as np
import os
import time

# Below is an example of distributed checkpoint based on
# https://docs.pytorch.org/tutorials/recipes/distributed_async_checkpoint_recipe.html
class AppState(Stateful):
"""This is a useful wrapper for checkpointing the Application State. Since this object is compliant
with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
dcp.save/load APIs.

Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
and optimizer.
"""

def __init__(self, model, optimizer=None):
self.model = model
self.optimizer = optimizer

def state_dict(self):
# this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
return {
"model": model_state_dict,
"optim": optimizer_state_dict
}

def load_state_dict(self, state_dict):
# sets our state dicts on the model and optimizer, now that we've loaded
set_state_dict(
self.model,
self.optimizer,
model_state_dict=state_dict["model"],
optim_state_dict=state_dict["optim"]
)

def setup():
"""Initialize the distributed training process group"""
# Check if we're in a distributed environment
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
rank = int(os.environ['RANK'])
world_size = int(os.environ['WORLD_SIZE'])
local_rank = int(os.environ.get('LOCAL_RANK', 0))
else:
# Fallback for single GPU
rank = 0
world_size = 1
local_rank = 0

# Initialize process group
if world_size > 1:
if not dist.is_initialized():
dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)

# Set device
if torch.cuda.is_available():
device = torch.device(f'cuda:{local_rank}')
torch.cuda.set_device(device)
else:
device = torch.device('cpu')

return rank, world_size, device

def cleanup():
"""Clean up the distributed training process group"""
if dist.is_initialized():
dist.destroy_process_group()

class SyntheticDataset(Dataset):
"""Simple synthetic dataset for demo purposes"""
def __init__(self, size=10000, input_dim=512, num_classes=10):
self.size = size
self.input_dim = input_dim
self.num_classes = num_classes

# Generate synthetic data
np.random.seed(42) # For reproducible results
self.data = torch.randn(size, input_dim)
# Create labels with some pattern
self.labels = torch.randint(0, num_classes, (size,))

def __len__(self):
return self.size

def __getitem__(self, idx):
return self.data[idx], self.labels[idx]

FSDPでTransformerモデルを定義する

このセクションでは、分類のためのシンプルなTransformerモデルと、FSDPシャーディングを適用するロジックを定義します。FSDP は通常、7B 以上の大規模言語モデルに使用されますが、この例では、複数の H100 GPU にわたってシャード化された小規模な 10M プロセッサ モデルを使用した手法を示します。

モデルアーキテクチャ:

  • TransformerBlock: マルチヘッドアテンションとMLPを備えた単一のトランスフォーマー層
  • SimpleTransformer: 入力投影と分類ヘッドを備えた変換ブロックのスタック
  • apply_fsdp(): 分散トレーニングのためにモデルレイヤーをFSDPでラップする

FSDP は、GPU 全体でモデルの不安、勾配、オプティマイザーの状態をシャード化し、GPU ごとのメモリ要件を削減し、より大規模なモデルのトレーニングを可能にします。

Python
class TransformerBlock(nn.Module):
"""Simple transformer block for testing FSDP"""
def __init__(self, dim=512, num_heads=8, mlp_ratio=4):
super().__init__()
self.attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)

mlp_dim = int(dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_dim),
nn.GELU(),
nn.Linear(mlp_dim, dim),
)

def forward(self, x):
# Self-attention
attn_out, _ = self.attention(x, x, x)
x = self.norm1(x + attn_out)

# MLP
mlp_out = self.mlp(x)
x = self.norm2(x + mlp_out)

return x

class SimpleTransformer(nn.Module):
"""Simple transformer model for classification with FSDP"""
def __init__(self, input_dim=512, num_layers=64, num_classes=10):
super().__init__()
self.input_projection = nn.Linear(input_dim, input_dim)
self.layers = nn.ModuleList([
TransformerBlock(dim=input_dim) for _ in range(num_layers)
])
self.norm = nn.LayerNorm(input_dim)
self.classifier = nn.Linear(input_dim, num_classes)

def forward(self, x):
# Add sequence dimension for transformer
x = x.unsqueeze(1) # [batch, 1, input_dim]

x = self.input_projection(x)

for layer in self.layers:
x = layer(x)

x = self.norm(x)
# Global average pooling
x = x.mean(dim=1) # [batch, input_dim]

return self.classifier(x)

def apply_fsdp(model, world_size):
"""Apply FSDP to the model"""
if world_size > 1:
print("Applying FSDP to model layers...")
# Apply fsdp to each transformer layer
for i, layer in enumerate(model.layers):
fully_shard(layer)
print(f"Applied FSDP to layer {i}")

# Apply FSDP to the entire model
fully_shard(model)
print("Applied FSDP to entire model")
else:
print("Single GPU detected, skipping FSDP setup")

return model

分散トレーニング関数を定義する

トレーニング関数は、サーバレス GPU APIの@distributedデコレータでラップされています。 このデコレータは以下を処理します。

  • 指定された数のGPU(この例では8個のH100 GPU)をプロビジョニングします。
  • 分散トレーニング環境の設定
  • リモート コンピュート リソースのライフサイクルの管理

トレーニング機能には以下が含まれます。

  • モデルの初期化とFSDPラッピング
  • 並列データ処理のためにDistributedSamplerでデータをロードします
  • 勾配更新を伴うトレーニングループ
  • PyTorch の分散チェックポイント API を使用した定期的なチェックポイントの保存
  • エクスペリメント追跡のためのMLflowログ記録

チェックポイントはUnity Catalogボリュームに保存され、バージョン管理と再現性のためにMLflowアーティファクトとして記録されます。

Python
from serverless_gpu import distributed
from serverless_gpu.compute import GPUType

NUM_WORKERS = 8
CHECKPOINT_DIR = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/{MODEL_NAME}"
@distributed(gpus=NUM_WORKERS, gpu_type=GPUType.H100)
def run_fsdp_training(num_workers=NUM_WORKERS):
"""
Self-contained FSDP training demo using PyTorch 2.0+
Trains a simple neural network on synthetic data using FSDP
"""
import mlflow
mlflow.start_run(run_name='fsdp_example')
def main_training():
"""Main training function"""
print("Starting FSDP Training Demo...")

# Setup distributed training
rank, world_size, device = setup()

print(f"Rank: {rank}, World Size: {world_size}, Device: {device}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"CUDA device count: {torch.cuda.device_count()}")
print(f"Current CUDA device: {torch.cuda.current_device()}")

# Create dataset and data loader
dataset = SyntheticDataset(size=10000, input_dim=512, num_classes=10)

# Use DistributedSampler if we have multiple processes
if world_size > 1:
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
shuffle = False
else:
sampler = None
shuffle = True

dataloader = DataLoader(
dataset,
batch_size=32,
shuffle=shuffle,
sampler=sampler,
num_workers=num_workers,
pin_memory=True
)

# Create model
model = SimpleTransformer(input_dim=512, num_layers=4, num_classes=10).to(device)

# Apply FSDP
model = apply_fsdp(model, world_size)

print(f"Model created and moved to device: {device}")
if rank == 0:
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)

# Training loop
num_epochs = 5
loss_history = []

print(f"Training for {num_epochs} epochs...")
writer = StorageWriter(cache_staged_state_dict=False, path=CHECKPOINT_DIR)
for epoch in range(num_epochs):
if sampler:
sampler.set_epoch(epoch)

model.train()
total_loss = 0.0
num_batches = 0

epoch_start_time = time.time()

for batch_idx, (data, target) in enumerate(dataloader):
data, target = data.to(device), target.to(device)

# Zero gradients
optimizer.zero_grad()

# Forward pass
output = model(data)
loss = criterion(output, target)

# Backward pass
loss.backward()
mlflow.log_metric(
key='loss',
value=loss.item(),
step=batch_idx,
)
# Update weights
optimizer.step()

total_loss += loss.item()

num_batches += 1

if batch_idx % 10 == 0:
print(f'Saving checkpoint to {CHECKPOINT_DIR}/step{batch_idx}')
state_dict = { 'app': AppState(model, optimizer) }
ckpt_start_time = time.time()
dcp.save(state_dict, storage_writer=writer, checkpoint_id=f"{CHECKPOINT_DIR}/step{batch_idx}")
ckpt_time = time.time() - ckpt_start_time
print(f'Checkpointing took {ckpt_time:.2f}s')
mlflow.log_artifacts(f'{CHECKPOINT_DIR}/step{batch_idx}', artifact_path=f'checkpoints/step{batch_idx}')
if rank == 0:
print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}, Loss: {loss.item():.6f}')
# Calculate average loss for this epoch
avg_loss = total_loss / num_batches
mlflow.log_metric(key='avg_loss', value=avg_loss)

loss_history.append(avg_loss)

epoch_time = time.time() - epoch_start_time

if rank == 0:
print(f'Epoch {epoch+1}/{num_epochs} with {num_batches} completed in {epoch_time:.2f}s. Average Loss: {avg_loss:.6f}')

# Verify loss is decreasing
if rank == 0:
print("\n=== FSDP Training Results ===")
print("Loss history:")
for i, loss in enumerate(loss_history):
print(f"Epoch {i+1}: {loss:.6f}")

# Check if loss is generally decreasing
initial_loss = loss_history[0]
final_loss = loss_history[-1]
loss_reduction = ((initial_loss - final_loss) / initial_loss) * 100

print(f"\nInitial Loss: {initial_loss:.6f}")
print(f"Final Loss: {final_loss:.6f}")
print(f"Loss Reduction: {loss_reduction:.2f}%")

if final_loss < initial_loss:
print("✅ SUCCESS: FSDP training is working! Loss is decreasing.")
else:
print("❌ WARNING: Loss did not decrease. Check training configuration.")

print(f"\nFSDP training completed successfully on {world_size} GPU(s)")

# Cleanup
cleanup()
mlflow.end_run()

return {
'initial_loss': loss_history[0] if loss_history else None,
'final_loss': loss_history[-1] if loss_history else None,
'loss_history': loss_history,
'world_size': world_size,
'device': str(device),
'fsdp_enabled': world_size > 1
}

# Run the training
return main_training()

分散トレーニングを実行する

トレーニング関数を実行して、8つのH100 GPUに分散したトレーニングを開始します。.distributed()メソッドは、サーバレス GPU コンピュートでのリモート実行をトリガーします。 トレーニングの進行状況、損失メトリクス、チェックポイントがMLflowに記録されます。

このセルは、GPU リソースをプロビジョニングし、5 エポックの間モデルをトレーニングし、チェックポイントを保存するため、完了するまでに数分かかる場合があります。

Python
print("Starting FSDP Demo on Databricks Serverless GPU...")
result = run_fsdp_training.distributed()
print("FSDP Demo completed!")
print(f"Training Results: {result}")

モデルチェックポイントをロードする

このセクションでは、推論または継続的なトレーニングのために保存されたチェックポイントを読み込む方法を説明します。チェックポイントには、トレーニング中に保存されたモデルの重みとオプティマイザーの状態が含まれます。

分散トレーニング コンテキスト (プロセス グループが初期化されていない) の外部でチェックポイントをロードする場合、PyTorch の分散チェックポイント API は自動的に集合操作を無効にし、チェックポイントを単一のデバイスにロードすることに注意してください。

Python
def run_checkpoint_load_example():
# create the non FSDP-wrapped toy model
model = SimpleTransformer(input_dim=512, num_layers=4, num_classes=10)
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
state_dict = { 'app': AppState(model, optimizer)}

# print(state_dict)
# since no progress group is initialized, DCP will disable any collectives.
dcp.load(
state_dict=state_dict,
checkpoint_id=f'{CHECKPOINT_DIR}/step0',
)
model.load_state_dict(state_dict['app'].state_dict()['model'])

run_checkpoint_load_example()

次のステップ

サーバレス GPU コンピュートでの分散トレーニングにPyTorch FSDP を使用する方法を学習したので、次のリソースを参照して詳細を学習してください。

サンプルノートブック

サーバレスGPUコンピュート上のPyTorch FSDPを使用した分散トレーニング

ノートブックを新しいタブで開く