Pular para o conteúdo principal

Ajuste fino de LLM em múltiplos nós com FSDP

info

Beta

Este recurso está em Beta. Os administradores do espaço de trabalho podem controlar o acesso a esse recurso na página Pré-visualizações . Consulte Gerenciar prévias do Databricks.

Este exemplo executa o ajuste fino supervisionado (SFT) de Llama-3.1-8B em 16 GPUs H100 distribuídas em 2 nós usando torchrun e PyTorch Fully Sharded Data Parallel (FSDP). O FSDP fragmenta os parâmetros do modelo, gradientes e estados do otimizador entre todos os 16 ranks para que o modelo de 8B de parâmetros e seu estado do otimizador caibam confortavelmente na memória da GPU.

A carga de trabalho executa as seguintes ações:

  • Faz upload do projeto local com code_source: snapshot.
  • Inicia um processo por GPU com torchrun, usando as variáveis de ambiente de rendezvous que o Runtime de AI define em cada nó.
  • Lê um modelo com acesso restrito do Hugging Face usando um segredo do Databricks.
  • Registra métricas no MLflow e escreve o checkpoint consolidado em um volume do Unity Catalog.

Pré-requisitos

  • air CLI instalada e autenticada. Consulte Instalar a CLI do Runtime de AI.
  • Um volume do Unity Catalog no qual é possível gravar para o ponto de verificação de saída.
  • Acesso ao modelo fechado no Hugging Face, mais um access token armazenado como um segredo do Databricks (veja abaixo).

Obter acesso ao modelo no Hugging Face

Llama-3.1-8B é um modelo restrito, então é preciso solicitar acesso e fornecer um token para baixá-lo:

  1. Acesse a página do modelo em meta-llama/Llama-3.1-8B e aceite a licença para solicitar acesso. Espere até que o acesso seja concedido.
  2. Crie um access token do Hugging Face com permissão de leitura .

Armazene o token como um segredo do Databricks

A carga de trabalho lê o token de um segredo do Databricks em vez de codificá-lo diretamente. Crie um Secret Scope e adicione seu token:

Bash
databricks secrets create-scope my_scope
databricks secrets put-secret my_scope hf_token

train.yaml o referencia como my_scope/hf_token. Substitua o escopo e a key pelos seus próprios valores.

Disposição do projeto

Criar um diretório com os seguintes arquivos.

Text
multinode_llm_sft/
├── train.yaml # air workload config (inline dependencies + torchrun launcher)
└── train.py # FSDP fine-tuning script

O passo 1: Escreva a carga de trabalho YAML

train.yaml Solicita 16 GPUs como dois nós GPU_8xH100, monta o token Hugging Face como um segredo e passa hiperparâmetros para o script através do bloco parameters. Dependências são declaradas em linha sob environment (com a imagem do cliente version). O pacote torch é enviado na imagem base do AI Runtime, então apenas os extras são listados:

YAML
experiment_name: air-multinode-llama-sft

environment:
version: '4'
dependencies:
- transformers>=4.45
- datasets>=3.0
- huggingface_hub>=0.34
- accelerate>=0.34
# The base image ships fsspec 2023.5.0, which is too old for modern
# huggingface_hub and breaks dataset/model downloads. Pin a newer fsspec.
- fsspec>=2024.6.1

# 16 GPUs across 2 nodes (GPU_8xH100 = 8 H100 per node).
compute:
num_accelerators: 16
accelerator_type: GPU_8xH100

code_source:
type: snapshot
snapshot:
root_path: .

command: |
cd $CODE_SOURCE_PATH
# air sets NUM_NODES, NODE_RANK, LOCAL_WORLD_SIZE, MASTER_ADDR, and MASTER_PORT on each node.
torchrun \
--nnodes="$NUM_NODES" \
--node_rank="$NODE_RANK" \
--nproc_per_node="${LOCAL_WORLD_SIZE:-8}" \
--master_addr="$MASTER_ADDR" \
--master_port="$MASTER_PORT" \
train.py

# Pin NCCL control-plane traffic to eth0 so cross-node rendezvous works.
env_variables:
NCCL_SOCKET_IFNAME: eth0
HF_HOME: /tmp/hf

# Gated model download needs a Hugging Face token. Replace with your own
# Databricks secret in the form "scope/key".
secrets:
HF_TOKEN: 'my_scope/hf_token'

max_retries: 1
timeout_minutes: 120

# Surfaced to train.py via HYPERPARAMETERS_PATH.
parameters:
model_name: meta-llama/Llama-3.1-8B
dataset_name: tatsu-lab/alpaca
max_seq_len: 1024
per_device_batch_size: 4
gradient_accumulation_steps: 2
learning_rate: 0.00002
max_steps: 100
output_dir: /Volumes/main/default/air_checkpoints/llama31-8b-sft

O AI Runtime executa command uma vez por nó e define as variáveis de ambiente de rendezvous (NUM_NODES, NODE_RANK, LOCAL_WORLD_SIZE, MASTER_ADDR e MASTER_PORT) em cada nó. torchrun os lê para iniciar um processo por GPU; assim, o comando em linha é o inicializador completo. Nenhum script de inicialização separado é necessário.

passo 2: Escrever o script de treinamento FSDP

train.py Inicializa o grupo de processo, envolve cada bloco de transformador em FSDP, treina em um dataset de instrução tokenizado e salva um checkpoint consolidado do rank 0. Os key pieces:

Python
# Shard each transformer block independently so no single GPU holds the full model.
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy, transformer_layer_cls={LlamaDecoderLayer}
)
model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
),
device_id=local_rank,
use_orig_params=True,
)

O Rank 0 coleta o dicionário de estado completo (descarregado para a CPU) e o grava no volume do Unity Catalog:

Python
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
cpu_state = model.state_dict()
if rank == 0:
model.module.save_pretrained(output_dir, state_dict=cpu_state)
tokenizer.save_pretrained(output_dir)

O script completo está listado em Script de treinamento completo no final desta página.

Passo 3: Enviar a execução

Validar a configuração, em seguida, enviar e acompanhar logs:

Bash
air run -f train.yaml --dry-run
air run -f train.yaml --watch

o passo 4: inspeção da execução

Execuções distribuídas abrangem múltiplos nós. Use --node para ler logs de um nó específico:

Bash
air get run <run-id>
air logs <run-id> --node 0
air logs <run-id> --node 1

Onde os resultados são armazenados

  • Métricas e parâmetros : registrados no experimento MLflow nomeado em experiment_name. Visualizá-los na interface do usuário do MLflow no workspace.
  • Ponto de verificação ajustado : gravado no volume do Unity Catalog em parameters.output_dir.

Script de treinamento completo

O train.py completo para copiar e colar:

Python
#!/usr/bin/env python3
"""Multi-node FSDP supervised fine-tuning of Llama-3.1-8B.

Launched via ``torchrun`` from the workload YAML ``command`` across 2 nodes x 8 H100 (16 ranks). Each rank
owns one GPU. The model is sharded with PyTorch FSDP (full shard + bf16), trained on
an instruction dataset, and the consolidated checkpoint is written to a Unity Catalog
Volume by rank 0. Metrics are logged to MLflow.

Hyperparameters are read from the YAML block passed by ``air`` via HYPERPARAMETERS_PATH.
"""

import functools
import os

import mlflow
import torch
import torch.distributed as dist
import yaml
from datasets import load_dataset
from torch.distributed.fsdp import FullStateDictConfig, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer


def load_params() -> dict:
"""Read the hyperparameters block that `air` materializes from the YAML `parameters:`."""
path = os.environ.get("HYPERPARAMETERS_PATH")
if path and os.path.exists(path):
with open(path) as f:
return yaml.safe_load(f) or {}
return {}


def build_dataset(tokenizer, dataset_name: str, max_seq_len: int):
"""Tokenize an instruction dataset into fixed-length causal-LM examples."""
raw = load_dataset(dataset_name, split="train")

def format_example(row):
instruction = row.get("instruction", "")
context = row.get("input", "")
response = row.get("output", "")
prompt = f"### Instruction:\n{instruction}\n\n"
if context:
prompt += f"### Input:\n{context}\n\n"
text = f"{prompt}### Response:\n{response}{tokenizer.eos_token}"
out = tokenizer(text, truncation=True, max_length=max_seq_len, padding="max_length")
out["labels"] = out["input_ids"].copy()
return out

cols = raw.column_names
tokenized = raw.map(format_example, remove_columns=cols)
# Emit torch tensors so the default DataLoader collate stacks them into [B, L] batches.
tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
return tokenized


def main():
rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])

dist.init_process_group(backend="nccl")
torch.cuda.set_device(local_rank)
device = torch.device(f"cuda:{local_rank}")

p = load_params()
model_name = p.get("model_name", "meta-llama/Llama-3.1-8B")
dataset_name = p.get("dataset_name", "tatsu-lab/alpaca")
max_seq_len = int(p.get("max_seq_len", 1024))
batch_size = int(p.get("per_device_batch_size", 4))
grad_accum = int(p.get("gradient_accumulation_steps", 2))
lr = float(p.get("learning_rate", 2e-5))
max_steps = int(p.get("max_steps", 100))
output_dir = p.get("output_dir", "/tmp/llama-sft")

if rank == 0:
print(f"World size={world_size} | model={model_name} | dataset={dataset_name}", flush=True)

# --- Model & data --------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
model.config.use_cache = False # incompatible with gradient checkpointing / FSDP training
model.gradient_checkpointing_enable()

# Shard each transformer block independently so no single GPU holds the full model.
auto_wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={LlamaDecoderLayer})
model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
),
device_id=local_rank,
use_orig_params=True,
)

dataset = build_dataset(tokenizer, dataset_name, max_seq_len)
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, drop_last=True)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

# --- MLflow (rank 0 only) ------------------------------------------------
# AI Runtime injects MLFLOW_RUN_ID and configures the databricks tracking URI on
# the node, so logging works without DATABRICKS_HOST/TOKEN. Gate on MLFLOW_RUN_ID
# so the script also runs cleanly off-platform (e.g. locally) where it is unset.
use_mlflow = rank == 0 and bool(os.environ.get("MLFLOW_RUN_ID"))
if use_mlflow:
mlflow.start_run(run_id=os.environ.get("MLFLOW_RUN_ID"))
mlflow.log_params({"model_name": model_name, "lr": lr, "batch_size": batch_size, "world_size": world_size})

# --- Training loop -------------------------------------------------------
model.train()
sampler.set_epoch(0)
step = 0
optimizer.zero_grad()
for micro_step, batch in enumerate(loader):
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
labels = batch["labels"].to(device)

out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
(out.loss / grad_accum).backward()

if (micro_step + 1) % grad_accum == 0:
model.clip_grad_norm_(1.0)
optimizer.step()
optimizer.zero_grad()
step += 1
if rank == 0:
print(f"step={step}/{max_steps} loss={out.loss.item():.4f}", flush=True)
if use_mlflow:
mlflow.log_metric("train_loss", out.loss.item(), step=step)
if step >= max_steps:
break

# --- Save consolidated checkpoint to the UC Volume (rank 0) --------------
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
cpu_state = model.state_dict()
if rank == 0:
os.makedirs(output_dir, exist_ok=True)
model.module.save_pretrained(output_dir, state_dict=cpu_state)
tokenizer.save_pretrained(output_dir)
print(f"Saved checkpoint to {output_dir}", flush=True)
if use_mlflow:
mlflow.end_run()

dist.barrier()
dist.destroy_process_group()


if __name__ == "__main__":
main()

Passos seguintes