Treinamento distribuído usando PyTorch FSDP em computeGPU serverless
Este Notebook demonstra como ensinar um modelo Transformer usando treinamento distribuído com o Fully Sharded Data Parallel (FSDP) do PyTorch em compute GPU serverless Databricks . FSDP é uma técnica de paralelismo de dados que distribui parâmetros do modelo, gradientes e estados do otimizador entre várias GPUs, permitindo o treinamento eficiente de modelos grandes que não cabem em uma única GPU.
Neste exemplo, você aprenderá como:
- Configure o treinamento distribuído com a APIde treinamento distribuído de GPUserverless
- Defina e ensine um modelo Transformer de 10 milhões de parâmetros usando FSDP.
- Salvar pontos de verificação distribuídos durante o treinamento
- Acompanhe os experimentos com o MLflow.
- Carregar pontos de verificação para inferência ou treinamento contínuo.
Este notebook utiliza dados sintéticos para manter sua autossuficiência, mas você pode adaptá-lo para funcionar com seu próprio conjunto de dados.
conceitos-chave:
- FSDP (Fully Sharded Data Parallel) : Uma estratégia de treinamento distribuído do PyTorch que fragmenta os parâmetros do modelo entre GPUs para reduzir o uso de memória e permitir o treinamento de modelos maiores.
- computeGPU sem servidor : Databricks gerencia compute GPU que escala e provisiona recursos automaticamente para suas cargas de trabalho.
Para obter mais informações, consulte Treinamento distribuído multi-GPU e multinó.
Instalar dependências
Instale a versão mais recente do MLflow para acompanhamento de experimentos e registro de modelos.
%pip install -U mlflow
%restart_python
Configurar locais Unity Catalog
Configure os locais Unity Catalog onde o modelo e os pontos de verificação serão armazenados. Atualize esses valores para que correspondam à configuração do seu workspace . Você precisa de privilégios USE CATALOG e USE SCHEMA no catálogo e esquema especificados.
# You must have `USE CATALOG` privileges on the catalog, and you must have `USE SCHEMA` privileges on the schema.
# If necessary, change the catalog and schema name here.
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("model_name", "transformer_fsdp")
dbutils.widgets.text("uc_volume", "checkpoints")
UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_VOLUME = dbutils.widgets.get("uc_volume")
MODEL_NAME = dbutils.widgets.get("model_name")
UC_MODEL_NAME = f"{UC_CATALOG}.{UC_SCHEMA}.{MODEL_NAME}"
print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_VOLUME: {UC_VOLUME}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")
Defina funções auxiliares e datasetsintéticos.
Esta seção define funções utilitárias para configuração de treinamento distribuído e uma classe de dataset sintéticos para fins de demonstração. Em produção, você substituiria o SyntheticDataset pela sua própria lógica de carregamento de dados.
componentes principais:
setup()Inicializa o grupo de processos de treinamento distribuído e configura dispositivos de GPU.cleanup()Limpa o grupo de processos distribuídos após o treinamento.AppStateUma classe wrapper para o checkpoint do modelo e do estado do otimizador, compatível com a API de checkpoint distribuído do PyTorch.SyntheticDatasetGera dados aleatórios para demonstração de treinamento.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint import FileSystemWriter as StorageWriter
import torch.multiprocessing as mp
from torch.distributed.fsdp import fully_shard
from torch.utils.data import Dataset, DataLoader, DistributedSampler
import numpy as np
import os
import time
# Below is an example of distributed checkpoint based on
# https://docs.pytorch.org/tutorials/recipes/distributed_async_checkpoint_recipe.html
class AppState(Stateful):
"""This is a useful wrapper for checkpointing the Application State. Since this object is compliant
with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
dcp.save/load APIs.
Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
and optimizer.
"""
def __init__(self, model, optimizer=None):
self.model = model
self.optimizer = optimizer
def state_dict(self):
# this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
return {
"model": model_state_dict,
"optim": optimizer_state_dict
}
def load_state_dict(self, state_dict):
# sets our state dicts on the model and optimizer, now that we've loaded
set_state_dict(
self.model,
self.optimizer,
model_state_dict=state_dict["model"],
optim_state_dict=state_dict["optim"]
)
def setup():
"""Initialize the distributed training process group"""
# Check if we're in a distributed environment
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
rank = int(os.environ['RANK'])
world_size = int(os.environ['WORLD_SIZE'])
local_rank = int(os.environ.get('LOCAL_RANK', 0))
else:
# Fallback for single GPU
rank = 0
world_size = 1
local_rank = 0
# Initialize process group
if world_size > 1:
if not dist.is_initialized():
dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
# Set device
if torch.cuda.is_available():
device = torch.device(f'cuda:{local_rank}')
torch.cuda.set_device(device)
else:
device = torch.device('cpu')
return rank, world_size, device
def cleanup():
"""Clean up the distributed training process group"""
if dist.is_initialized():
dist.destroy_process_group()
class SyntheticDataset(Dataset):
"""Simple synthetic dataset for demo purposes"""
def __init__(self, size=10000, input_dim=512, num_classes=10):
self.size = size
self.input_dim = input_dim
self.num_classes = num_classes
# Generate synthetic data
np.random.seed(42) # For reproducible results
self.data = torch.randn(size, input_dim)
# Create labels with some pattern
self.labels = torch.randint(0, num_classes, (size,))
def __len__(self):
return self.size
def __getitem__(self, idx):
return self.data[idx], self.labels[idx]
Defina o modelo Transformer com FSDP
Esta seção define um modelo Transformer simples para classificação e a lógica para aplicar o particionamento FSDP. Embora o FSDP seja normalmente usado para grandes modelos de linguagem com mais de 7 bilhões de parâmetros, este exemplo demonstra a técnica com um modelo menor, de 10 milhões de parâmetros, distribuído por várias GPUs H100.
Arquitetura do modelo:
TransformerBlockUma única camada de transformador com atenção multi-cabeça e MLP (Multiple-Heads Perceptron).SimpleTransformerUma pilha de blocos Transformer com cabeça de projeção e classificação de entrada.apply_fsdp(): Envolve as camadas do modelo com FSDP para treinamento distribuído
O FSDP fragmenta os parâmetros do modelo, os gradientes e os estados do otimizador entre as GPUs, reduzindo os requisitos de memória por GPU e permitindo o treinamento de modelos maiores.
class TransformerBlock(nn.Module):
"""Simple transformer block for testing FSDP"""
def __init__(self, dim=512, num_heads=8, mlp_ratio=4):
super().__init__()
self.attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
mlp_dim = int(dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_dim),
nn.GELU(),
nn.Linear(mlp_dim, dim),
)
def forward(self, x):
# Self-attention
attn_out, _ = self.attention(x, x, x)
x = self.norm1(x + attn_out)
# MLP
mlp_out = self.mlp(x)
x = self.norm2(x + mlp_out)
return x
class SimpleTransformer(nn.Module):
"""Simple transformer model for classification with FSDP"""
def __init__(self, input_dim=512, num_layers=64, num_classes=10):
super().__init__()
self.input_projection = nn.Linear(input_dim, input_dim)
self.layers = nn.ModuleList([
TransformerBlock(dim=input_dim) for _ in range(num_layers)
])
self.norm = nn.LayerNorm(input_dim)
self.classifier = nn.Linear(input_dim, num_classes)
def forward(self, x):
# Add sequence dimension for transformer
x = x.unsqueeze(1) # [batch, 1, input_dim]
x = self.input_projection(x)
for layer in self.layers:
x = layer(x)
x = self.norm(x)
# Global average pooling
x = x.mean(dim=1) # [batch, input_dim]
return self.classifier(x)
def apply_fsdp(model, world_size):
"""Apply FSDP to the model"""
if world_size > 1:
print("Applying FSDP to model layers...")
# Apply fsdp to each transformer layer
for i, layer in enumerate(model.layers):
fully_shard(layer)
print(f"Applied FSDP to layer {i}")
# Apply FSDP to the entire model
fully_shard(model)
print("Applied FSDP to entire model")
else:
print("Single GPU detected, skipping FSDP setup")
return model
Defina a função de treinamento distribuído
A função de treinamento é envolvida com o decorador @distributed da API de GPU serverless . Este decorador cuida de:
- provisionamento do número especificado de GPUs (8 GPUs H100 neste exemplo)
- Configurando o ambiente de treinamento distribuído
- Gerenciando o ciclo de vida de recursos compute remota.
A função de treinamento inclui:
- Inicialização do modelo e encapsulamento FSDP
- Carregamento de dados com
DistributedSamplerpara processamento de dados paralelo - Loop de treinamento com atualizações de gradiente
- Salvamento periódico de pontos de verificação usando a API de pontos de verificação distribuídos do PyTorch.
- Registro de MLflow para acompanhamento de experimento
Os pontos de verificação são salvos em um volume Unity Catalog e os registros são criados como artefatos MLflow para controle de versão e reprodução.
from serverless_gpu import distributed
from serverless_gpu.compute import GPUType
NUM_WORKERS = 8
CHECKPOINT_DIR = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/{MODEL_NAME}"
@distributed(gpus=NUM_WORKERS, gpu_type=GPUType.H100)
def run_fsdp_training(num_workers=NUM_WORKERS):
"""
Self-contained FSDP training demo using PyTorch 2.0+
Trains a simple neural network on synthetic data using FSDP
"""
import mlflow
mlflow.start_run(run_name='fsdp_example')
def main_training():
"""Main training function"""
print("Starting FSDP Training Demo...")
# Setup distributed training
rank, world_size, device = setup()
print(f"Rank: {rank}, World Size: {world_size}, Device: {device}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"CUDA device count: {torch.cuda.device_count()}")
print(f"Current CUDA device: {torch.cuda.current_device()}")
# Create dataset and data loader
dataset = SyntheticDataset(size=10000, input_dim=512, num_classes=10)
# Use DistributedSampler if we have multiple processes
if world_size > 1:
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
shuffle = False
else:
sampler = None
shuffle = True
dataloader = DataLoader(
dataset,
batch_size=32,
shuffle=shuffle,
sampler=sampler,
num_workers=num_workers,
pin_memory=True
)
# Create model
model = SimpleTransformer(input_dim=512, num_layers=4, num_classes=10).to(device)
# Apply FSDP
model = apply_fsdp(model, world_size)
print(f"Model created and moved to device: {device}")
if rank == 0:
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
# Training loop
num_epochs = 5
loss_history = []
print(f"Training for {num_epochs} epochs...")
writer = StorageWriter(cache_staged_state_dict=False, path=CHECKPOINT_DIR)
for epoch in range(num_epochs):
if sampler:
sampler.set_epoch(epoch)
model.train()
total_loss = 0.0
num_batches = 0
epoch_start_time = time.time()
for batch_idx, (data, target) in enumerate(dataloader):
data, target = data.to(device), target.to(device)
# Zero gradients
optimizer.zero_grad()
# Forward pass
output = model(data)
loss = criterion(output, target)
# Backward pass
loss.backward()
mlflow.log_metric(
key='loss',
value=loss.item(),
step=batch_idx,
)
# Update weights
optimizer.step()
total_loss += loss.item()
num_batches += 1
if batch_idx % 10 == 0:
print(f'Saving checkpoint to {CHECKPOINT_DIR}/step{batch_idx}')
state_dict = { 'app': AppState(model, optimizer) }
ckpt_start_time = time.time()
dcp.save(state_dict, storage_writer=writer, checkpoint_id=f"{CHECKPOINT_DIR}/step{batch_idx}")
ckpt_time = time.time() - ckpt_start_time
print(f'Checkpointing took {ckpt_time:.2f}s')
mlflow.log_artifacts(f'{CHECKPOINT_DIR}/step{batch_idx}', artifact_path=f'checkpoints/step{batch_idx}')
if rank == 0:
print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}, Loss: {loss.item():.6f}')
# Calculate average loss for this epoch
avg_loss = total_loss / num_batches
mlflow.log_metric(key='avg_loss', value=avg_loss)
loss_history.append(avg_loss)
epoch_time = time.time() - epoch_start_time
if rank == 0:
print(f'Epoch {epoch+1}/{num_epochs} with {num_batches} completed in {epoch_time:.2f}s. Average Loss: {avg_loss:.6f}')
# Verify loss is decreasing
if rank == 0:
print("\n=== FSDP Training Results ===")
print("Loss history:")
for i, loss in enumerate(loss_history):
print(f"Epoch {i+1}: {loss:.6f}")
# Check if loss is generally decreasing
initial_loss = loss_history[0]
final_loss = loss_history[-1]
loss_reduction = ((initial_loss - final_loss) / initial_loss) * 100
print(f"\nInitial Loss: {initial_loss:.6f}")
print(f"Final Loss: {final_loss:.6f}")
print(f"Loss Reduction: {loss_reduction:.2f}%")
if final_loss < initial_loss:
print("✅ SUCCESS: FSDP training is working! Loss is decreasing.")
else:
print("❌ WARNING: Loss did not decrease. Check training configuration.")
print(f"\nFSDP training completed successfully on {world_size} GPU(s)")
# Cleanup
cleanup()
mlflow.end_run()
return {
'initial_loss': loss_history[0] if loss_history else None,
'final_loss': loss_history[-1] if loss_history else None,
'loss_history': loss_history,
'world_size': world_size,
'device': str(device),
'fsdp_enabled': world_size > 1
}
# Run the training
return main_training()
execução do treinamento distribuído
Execute a função de treinamento para começar o treinamento distribuído em 8 GPUs H100. O método .distributed() aciona a execução remota em compute GPU serverless . O progresso do treinamento, as métricas de perda e os pontos de verificação serão registrados no MLflow.
Esta célula pode levar vários minutos para ser concluída, pois provisiona recursos de GPU, treina o modelo por 5 épocas e salva os pontos de verificação.
print("Starting FSDP Demo on Databricks Serverless GPU...")
result = run_fsdp_training.distributed()
print("FSDP Demo completed!")
print(f"Training Results: {result}")
Carregar um ponto de verificação do modelo
Esta seção demonstra como carregar um ponto de verificação salvo para inferência ou treinamento contínuo. O ponto de verificação contém os pesos do modelo e o estado do otimizador salvos durante o treinamento.
Observe que, ao carregar pontos de verificação fora de um contexto de treinamento distribuído (nenhum grupo de processos inicializado), a API de pontos de verificação distribuídos do PyTorch desativa automaticamente as operações coletivas e carrega o ponto de verificação em um único dispositivo.
def run_checkpoint_load_example():
# create the non FSDP-wrapped toy model
model = SimpleTransformer(input_dim=512, num_layers=4, num_classes=10)
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
state_dict = { 'app': AppState(model, optimizer)}
# print(state_dict)
# since no progress group is initialized, DCP will disable any collectives.
dcp.load(
state_dict=state_dict,
checkpoint_id=f'{CHECKPOINT_DIR}/step0',
)
model.load_state_dict(state_dict['app'].state_dict()['model'])
run_checkpoint_load_example()
Próximos passos
Agora que você aprendeu como usar PyTorch FSDP para treinamento distribuído em compute GPU serverless , explore estes recursos para saber mais:
- Treinamento distribuído com múltiplas GPUs e múltiplos nós - Aprenda sobre diferentes estratégias de treinamento distribuído.
- Melhores práticas para computeGPU serverless - Otimize suas cargas de trabalho de GPU
- Solução de problemas em computeGPU serverless - Problemas comuns e soluções
- DocumentaçãoPyTorch FSDP - Análise detalhada do recurso FSDP e sua configuração.