FSDP を使用した複数ノード LLM ファインチューニング
ベータ版
この機能はベータ版です。ワークスペース管理者は、 プレビュー ページからこの機能へのアクセスを制御できます。Databricksのプレビューを管理するを参照してください。
この例では、Llama-3.1-8B の教師ありファインチューニング(SFT)を実行しますtorchrunおよびPyTorchのFully Sharded Data Parallel (FSDP)を使用し、2つのノードにわたって分散された16個のH100 GPUでFSDP は、モデルのパラメーター、勾配、およびオプティマイザーの状態を 16 のすべてのランクにわたってシャードするため、8B のパラメーター モデルとそのオプティマイザーの状態が GPU メモリに快適に収まります。
ワークロードでは、次の処理を行います:
- ローカルプロジェクトを
code_source: snapshotでアップロードします。 - AI Runtime が各ノードで設定する rendezvous 環境変数を使用して、
torchrunと共に GPU ごとに 1 つのプロセスを起動します。 - Databricksシークレットを使用して、Hugging Faceからゲート付きモデルを読み取ります。
- MLflowにメトリクスをログに記録し、統合されたチェックポイントをUnity Catalogボリュームに書き込みます。
前提条件
airCLI がインストールされ、認証されています。「AI Runtime CLI のインストール」を参照してください。- 出力チェックポイント用に書き込み可能なUnity Catalogボリューム。
- Hugging Face上のゲート付きモデルへのアクセスと、Databricksシークレットとして保存されたアクセストークン(下記を参照してください)。
Hugging Face上のモデルへのアクセス
Llama-3.1-8Bはゲートモデルであるため、アクセスをリクエストし、トークンを提供していただく必要があります:
- meta-llama/Llama-3.1-8B のモデルページを開き、ライセンスに同意してアクセスをリクエストしてください。アクセス権の付与までお待ちください。
- Hugging Face アクセストークン を**読取**権限で作成します。
Databricks シークレットにトークンを格納します
ワークロードは、ハードコーディングする代わりに、Databricksシークレットからトークンを読み取ります。シークレットスコープを作成し、トークンを追加します
databricks secrets create-scope my_scope
databricks secrets put-secret my_scope hf_token
train.yaml それをmy_scope/hf_tokenとして参照します。スコープとキーをご自身のものに置き換えてください。
プロジェクト レイアウト
次のファイルを含むディレクトリを作成します。
multinode_llm_sft/
├── train.yaml # air workload config (inline dependencies + torchrun launcher)
└── train.py # FSDP fine-tuning script
ステップ 1: ワークロード YAML を記述する
train.yaml 2 つの GPU_8xH100 ノードとして 16 GPU を要求し、Hugging Face トークンをシークレットとしてマウントし、parameters ブロック経由でスクリプトにハイパーパラメーターを渡します。依存関係は、environment の下にインラインで宣言されます(クライアントイメージ version とともに)。torch パッケージはAIランタイムベースイメージに同梱されており、エキストラのみがリストされています。
experiment_name: air-multinode-llama-sft
environment:
version: '4'
dependencies:
- transformers>=4.45
- datasets>=3.0
- huggingface_hub>=0.34
- accelerate>=0.34
# The base image ships fsspec 2023.5.0, which is too old for modern
# huggingface_hub and breaks dataset/model downloads. Pin a newer fsspec.
- fsspec>=2024.6.1
# 16 GPUs across 2 nodes (GPU_8xH100 = 8 H100 per node).
compute:
num_accelerators: 16
accelerator_type: GPU_8xH100
code_source:
type: snapshot
snapshot:
root_path: .
command: |
cd $CODE_SOURCE_PATH
# air sets NUM_NODES, NODE_RANK, LOCAL_WORLD_SIZE, MASTER_ADDR, and MASTER_PORT on each node.
torchrun \
--nnodes="$NUM_NODES" \
--node_rank="$NODE_RANK" \
--nproc_per_node="${LOCAL_WORLD_SIZE:-8}" \
--master_addr="$MASTER_ADDR" \
--master_port="$MASTER_PORT" \
train.py
# Pin NCCL control-plane traffic to eth0 so cross-node rendezvous works.
env_variables:
NCCL_SOCKET_IFNAME: eth0
HF_HOME: /tmp/hf
# Gated model download needs a Hugging Face token. Replace with your own
# Databricks secret in the form "scope/key".
secrets:
HF_TOKEN: 'my_scope/hf_token'
max_retries: 1
timeout_minutes: 120
# Surfaced to train.py via HYPERPARAMETERS_PATH.
parameters:
model_name: meta-llama/Llama-3.1-8B
dataset_name: tatsu-lab/alpaca
max_seq_len: 1024
per_device_batch_size: 4
gradient_accumulation_steps: 2
learning_rate: 0.00002
max_steps: 100
output_dir: /Volumes/main/default/air_checkpoints/llama31-8b-sft
AI Runtime はノードごとに command を実行し、各ノードにランデブー環境変数(NUM_NODES、NODE_RANK、LOCAL_WORLD_SIZE、MASTER_ADDR、および MASTER_PORT)を設定します。torchrun は、GPU ごとに 1 つのプロセスを起動するためにそれらを読み取るため、インライン コマンド全体がランチャーとなります。別途ランチャースクリプトは必要ありません。
ステップ 2: FSDP トレーニング スクリプトを記述する
train.py プロセスグループを初期化し、各トランスフォーマーブロックをFSDPでラップし、トークン化された命令データセットでトレーニングし、ランク0から統合されたチェックポイントを保存します。主な要素は次のとおりです:
# Shard each transformer block independently so no single GPU holds the full model.
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy, transformer_layer_cls={LlamaDecoderLayer}
)
model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
),
device_id=local_rank,
use_orig_params=True,
)
ランク 0 は完全なステートディク (CPU にオフロード済み) を収集し、それを Unity Catalog ボリュームに書き込みます。
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
cpu_state = model.state_dict()
if rank == 0:
model.module.save_pretrained(output_dir, state_dict=cpu_state)
tokenizer.save_pretrained(output_dir)
完全なスクリプトは、このページの最後にある「完全なトレーニングスクリプト」に記載されています。
ステップ3:実行を送信する
設定を検証し、送信してログを監視します:
air run -f train.yaml --dry-run
air run -f train.yaml --watch
ステップ 4: 実行を確認する
分散実行が複数のノードにまたがります。特定のノードからログを読み取るには、「--node」を使用してください。
air get run <run-id>
air logs <run-id> --node 0
air logs <run-id> --node 1
結果の保存場所
- メトリクスとパラメーター :
experiment_nameで指定されたMLflowエクスペリメントに記録されます。ワークスペースの MLflow UI で表示できます。 - 微調整済みチェックポイント :
parameters.output_dirの Unity Catalog ボリュームに書き込まれました。
完全なトレーニングスクリプト
コピー/貼り付けの完全なtrain.py
#!/usr/bin/env python3
"""Multi-node FSDP supervised fine-tuning of Llama-3.1-8B.
Launched via ``torchrun`` from the workload YAML ``command`` across 2 nodes x 8 H100 (16 ranks). Each rank
owns one GPU. The model is sharded with PyTorch FSDP (full shard + bf16), trained on
an instruction dataset, and the consolidated checkpoint is written to a Unity Catalog
Volume by rank 0. Metrics are logged to MLflow.
Hyperparameters are read from the YAML block passed by ``air`` via HYPERPARAMETERS_PATH.
"""
import functools
import os
import mlflow
import torch
import torch.distributed as dist
import yaml
from datasets import load_dataset
from torch.distributed.fsdp import FullStateDictConfig, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
def load_params() -> dict:
"""Read the hyperparameters block that `air` materializes from the YAML `parameters:`."""
path = os.environ.get("HYPERPARAMETERS_PATH")
if path and os.path.exists(path):
with open(path) as f:
return yaml.safe_load(f) or {}
return {}
def build_dataset(tokenizer, dataset_name: str, max_seq_len: int):
"""Tokenize an instruction dataset into fixed-length causal-LM examples."""
raw = load_dataset(dataset_name, split="train")
def format_example(row):
instruction = row.get("instruction", "")
context = row.get("input", "")
response = row.get("output", "")
prompt = f"### Instruction:\n{instruction}\n\n"
if context:
prompt += f"### Input:\n{context}\n\n"
text = f"{prompt}### Response:\n{response}{tokenizer.eos_token}"
out = tokenizer(text, truncation=True, max_length=max_seq_len, padding="max_length")
out["labels"] = out["input_ids"].copy()
return out
cols = raw.column_names
tokenized = raw.map(format_example, remove_columns=cols)
# Emit torch tensors so the default DataLoader collate stacks them into [B, L] batches.
tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
return tokenized
def main():
rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
dist.init_process_group(backend="nccl")
torch.cuda.set_device(local_rank)
device = torch.device(f"cuda:{local_rank}")
p = load_params()
model_name = p.get("model_name", "meta-llama/Llama-3.1-8B")
dataset_name = p.get("dataset_name", "tatsu-lab/alpaca")
max_seq_len = int(p.get("max_seq_len", 1024))
batch_size = int(p.get("per_device_batch_size", 4))
grad_accum = int(p.get("gradient_accumulation_steps", 2))
lr = float(p.get("learning_rate", 2e-5))
max_steps = int(p.get("max_steps", 100))
output_dir = p.get("output_dir", "/tmp/llama-sft")
if rank == 0:
print(f"World size={world_size} | model={model_name} | dataset={dataset_name}", flush=True)
# --- Model & data --------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
model.config.use_cache = False # incompatible with gradient checkpointing / FSDP training
model.gradient_checkpointing_enable()
# Shard each transformer block independently so no single GPU holds the full model.
auto_wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={LlamaDecoderLayer})
model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
),
device_id=local_rank,
use_orig_params=True,
)
dataset = build_dataset(tokenizer, dataset_name, max_seq_len)
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, drop_last=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
# --- MLflow (rank 0 only) ------------------------------------------------
# AI Runtime injects MLFLOW_RUN_ID and configures the databricks tracking URI on
# the node, so logging works without DATABRICKS_HOST/TOKEN. Gate on MLFLOW_RUN_ID
# so the script also runs cleanly off-platform (e.g. locally) where it is unset.
use_mlflow = rank == 0 and bool(os.environ.get("MLFLOW_RUN_ID"))
if use_mlflow:
mlflow.start_run(run_id=os.environ.get("MLFLOW_RUN_ID"))
mlflow.log_params({"model_name": model_name, "lr": lr, "batch_size": batch_size, "world_size": world_size})
# --- Training loop -------------------------------------------------------
model.train()
sampler.set_epoch(0)
step = 0
optimizer.zero_grad()
for micro_step, batch in enumerate(loader):
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
labels = batch["labels"].to(device)
out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
(out.loss / grad_accum).backward()
if (micro_step + 1) % grad_accum == 0:
model.clip_grad_norm_(1.0)
optimizer.step()
optimizer.zero_grad()
step += 1
if rank == 0:
print(f"step={step}/{max_steps} loss={out.loss.item():.4f}", flush=True)
if use_mlflow:
mlflow.log_metric("train_loss", out.loss.item(), step=step)
if step >= max_steps:
break
# --- Save consolidated checkpoint to the UC Volume (rank 0) --------------
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
cpu_state = model.state_dict()
if rank == 0:
os.makedirs(output_dir, exist_ok=True)
model.module.save_pretrained(output_dir, state_dict=cpu_state)
tokenizer.save_pretrained(output_dir)
print(f"Saved checkpoint to {output_dir}", flush=True)
if use_mlflow:
mlflow.end_run()
dist.barrier()
dist.destroy_process_group()
if __name__ == "__main__":
main()