メインコンテンツまでスキップ

FSDP を使用した複数ノード LLM ファインチューニング

備考

ベータ版

この機能はベータ版です。ワークスペース管理者は、 プレビュー ページからこの機能へのアクセスを制御できます。Databricksのプレビューを管理するを参照してください。

この例では、Llama-3.1-8B の教師ありファインチューニング(SFT)を実行しますtorchrunおよびPyTorchのFully Sharded Data Parallel (FSDP)を使用し、2つのノードにわたって分散された16個のH100 GPUでFSDP は、モデルのパラメーター、勾配、およびオプティマイザーの状態を 16 のすべてのランクにわたってシャードするため、8B のパラメーター モデルとそのオプティマイザーの状態が GPU メモリに快適に収まります。

ワークロードでは、次の処理を行います:

  • ローカルプロジェクトを code_source: snapshot でアップロードします。
  • AI Runtime が各ノードで設定する rendezvous 環境変数を使用して、torchrun と共に GPU ごとに 1 つのプロセスを起動します。
  • Databricksシークレットを使用して、Hugging Faceからゲート付きモデルを読み取ります。
  • MLflowにメトリクスをログに記録し、統合されたチェックポイントをUnity Catalogボリュームに書き込みます。

前提条件

  • air CLI がインストールされ、認証されています。「AI Runtime CLI のインストール」を参照してください。
  • 出力チェックポイント用に書き込み可能なUnity Catalogボリューム。
  • Hugging Face上のゲート付きモデルへのアクセスと、Databricksシークレットとして保存されたアクセストークン(下記を参照してください)。

Hugging Face上のモデルへのアクセス

Llama-3.1-8Bはゲートモデルであるため、アクセスをリクエストし、トークンを提供していただく必要があります:

  1. meta-llama/Llama-3.1-8B のモデルページを開き、ライセンスに同意してアクセスをリクエストしてください。アクセス権の付与までお待ちください。
  2. Hugging Face アクセストークン を**読取**権限で作成します。

Databricks シークレットにトークンを格納します

ワークロードは、ハードコーディングする代わりに、Databricksシークレットからトークンを読み取ります。シークレットスコープを作成し、トークンを追加します

Bash
databricks secrets create-scope my_scope
databricks secrets put-secret my_scope hf_token

train.yaml それをmy_scope/hf_tokenとして参照します。スコープとキーをご自身のものに置き換えてください。

プロジェクト レイアウト

次のファイルを含むディレクトリを作成します。

Text
multinode_llm_sft/
├── train.yaml # air workload config (inline dependencies + torchrun launcher)
└── train.py # FSDP fine-tuning script

ステップ 1: ワークロード YAML を記述する

train.yaml 2 つの GPU_8xH100 ノードとして 16 GPU を要求し、Hugging Face トークンをシークレットとしてマウントし、parameters ブロック経由でスクリプトにハイパーパラメーターを渡します。依存関係は、environment の下にインラインで宣言されます(クライアントイメージ version とともに)。torch パッケージはAIランタイムベースイメージに同梱されており、エキストラのみがリストされています。

YAML
experiment_name: air-multinode-llama-sft

environment:
version: '4'
dependencies:
- transformers>=4.45
- datasets>=3.0
- huggingface_hub>=0.34
- accelerate>=0.34
# The base image ships fsspec 2023.5.0, which is too old for modern
# huggingface_hub and breaks dataset/model downloads. Pin a newer fsspec.
- fsspec>=2024.6.1

# 16 GPUs across 2 nodes (GPU_8xH100 = 8 H100 per node).
compute:
num_accelerators: 16
accelerator_type: GPU_8xH100

code_source:
type: snapshot
snapshot:
root_path: .

command: |
cd $CODE_SOURCE_PATH
# air sets NUM_NODES, NODE_RANK, LOCAL_WORLD_SIZE, MASTER_ADDR, and MASTER_PORT on each node.
torchrun \
--nnodes="$NUM_NODES" \
--node_rank="$NODE_RANK" \
--nproc_per_node="${LOCAL_WORLD_SIZE:-8}" \
--master_addr="$MASTER_ADDR" \
--master_port="$MASTER_PORT" \
train.py

# Pin NCCL control-plane traffic to eth0 so cross-node rendezvous works.
env_variables:
NCCL_SOCKET_IFNAME: eth0
HF_HOME: /tmp/hf

# Gated model download needs a Hugging Face token. Replace with your own
# Databricks secret in the form "scope/key".
secrets:
HF_TOKEN: 'my_scope/hf_token'

max_retries: 1
timeout_minutes: 120

# Surfaced to train.py via HYPERPARAMETERS_PATH.
parameters:
model_name: meta-llama/Llama-3.1-8B
dataset_name: tatsu-lab/alpaca
max_seq_len: 1024
per_device_batch_size: 4
gradient_accumulation_steps: 2
learning_rate: 0.00002
max_steps: 100
output_dir: /Volumes/main/default/air_checkpoints/llama31-8b-sft

AI Runtime はノードごとに command を実行し、各ノードにランデブー環境変数(NUM_NODESNODE_RANKLOCAL_WORLD_SIZEMASTER_ADDR、および MASTER_PORT)を設定します。torchrun は、GPU ごとに 1 つのプロセスを起動するためにそれらを読み取るため、インライン コマンド全体がランチャーとなります。別途ランチャースクリプトは必要ありません。

ステップ 2: FSDP トレーニング スクリプトを記述する

train.py プロセスグループを初期化し、各トランスフォーマーブロックをFSDPでラップし、トークン化された命令データセットでトレーニングし、ランク0から統合されたチェックポイントを保存します。主な要素は次のとおりです:

Python
# Shard each transformer block independently so no single GPU holds the full model.
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy, transformer_layer_cls={LlamaDecoderLayer}
)
model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
),
device_id=local_rank,
use_orig_params=True,
)

ランク 0 は完全なステートディク (CPU にオフロード済み) を収集し、それを Unity Catalog ボリュームに書き込みます。

Python
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
cpu_state = model.state_dict()
if rank == 0:
model.module.save_pretrained(output_dir, state_dict=cpu_state)
tokenizer.save_pretrained(output_dir)

完全なスクリプトは、このページの最後にある「完全なトレーニングスクリプト」に記載されています。

ステップ3:実行を送信する

設定を検証し、送信してログを監視します:

Bash
air run -f train.yaml --dry-run
air run -f train.yaml --watch

ステップ 4: 実行を確認する

分散実行が複数のノードにまたがります。特定のノードからログを読み取るには、「--node」を使用してください。

Bash
air get run <run-id>
air logs <run-id> --node 0
air logs <run-id> --node 1

結果の保存場所

  • メトリクスとパラメーターexperiment_nameで指定されたMLflowエクスペリメントに記録されます。ワークスペースの MLflow UI で表示できます。
  • 微調整済みチェックポイント : parameters.output_dir の Unity Catalog ボリュームに書き込まれました。

完全なトレーニングスクリプト

コピー/貼り付けの完全なtrain.py

Python
#!/usr/bin/env python3
"""Multi-node FSDP supervised fine-tuning of Llama-3.1-8B.

Launched via ``torchrun`` from the workload YAML ``command`` across 2 nodes x 8 H100 (16 ranks). Each rank
owns one GPU. The model is sharded with PyTorch FSDP (full shard + bf16), trained on
an instruction dataset, and the consolidated checkpoint is written to a Unity Catalog
Volume by rank 0. Metrics are logged to MLflow.

Hyperparameters are read from the YAML block passed by ``air`` via HYPERPARAMETERS_PATH.
"""

import functools
import os

import mlflow
import torch
import torch.distributed as dist
import yaml
from datasets import load_dataset
from torch.distributed.fsdp import FullStateDictConfig, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer


def load_params() -> dict:
"""Read the hyperparameters block that `air` materializes from the YAML `parameters:`."""
path = os.environ.get("HYPERPARAMETERS_PATH")
if path and os.path.exists(path):
with open(path) as f:
return yaml.safe_load(f) or {}
return {}


def build_dataset(tokenizer, dataset_name: str, max_seq_len: int):
"""Tokenize an instruction dataset into fixed-length causal-LM examples."""
raw = load_dataset(dataset_name, split="train")

def format_example(row):
instruction = row.get("instruction", "")
context = row.get("input", "")
response = row.get("output", "")
prompt = f"### Instruction:\n{instruction}\n\n"
if context:
prompt += f"### Input:\n{context}\n\n"
text = f"{prompt}### Response:\n{response}{tokenizer.eos_token}"
out = tokenizer(text, truncation=True, max_length=max_seq_len, padding="max_length")
out["labels"] = out["input_ids"].copy()
return out

cols = raw.column_names
tokenized = raw.map(format_example, remove_columns=cols)
# Emit torch tensors so the default DataLoader collate stacks them into [B, L] batches.
tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
return tokenized


def main():
rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])

dist.init_process_group(backend="nccl")
torch.cuda.set_device(local_rank)
device = torch.device(f"cuda:{local_rank}")

p = load_params()
model_name = p.get("model_name", "meta-llama/Llama-3.1-8B")
dataset_name = p.get("dataset_name", "tatsu-lab/alpaca")
max_seq_len = int(p.get("max_seq_len", 1024))
batch_size = int(p.get("per_device_batch_size", 4))
grad_accum = int(p.get("gradient_accumulation_steps", 2))
lr = float(p.get("learning_rate", 2e-5))
max_steps = int(p.get("max_steps", 100))
output_dir = p.get("output_dir", "/tmp/llama-sft")

if rank == 0:
print(f"World size={world_size} | model={model_name} | dataset={dataset_name}", flush=True)

# --- Model & data --------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
model.config.use_cache = False # incompatible with gradient checkpointing / FSDP training
model.gradient_checkpointing_enable()

# Shard each transformer block independently so no single GPU holds the full model.
auto_wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={LlamaDecoderLayer})
model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
),
device_id=local_rank,
use_orig_params=True,
)

dataset = build_dataset(tokenizer, dataset_name, max_seq_len)
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, drop_last=True)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

# --- MLflow (rank 0 only) ------------------------------------------------
# AI Runtime injects MLFLOW_RUN_ID and configures the databricks tracking URI on
# the node, so logging works without DATABRICKS_HOST/TOKEN. Gate on MLFLOW_RUN_ID
# so the script also runs cleanly off-platform (e.g. locally) where it is unset.
use_mlflow = rank == 0 and bool(os.environ.get("MLFLOW_RUN_ID"))
if use_mlflow:
mlflow.start_run(run_id=os.environ.get("MLFLOW_RUN_ID"))
mlflow.log_params({"model_name": model_name, "lr": lr, "batch_size": batch_size, "world_size": world_size})

# --- Training loop -------------------------------------------------------
model.train()
sampler.set_epoch(0)
step = 0
optimizer.zero_grad()
for micro_step, batch in enumerate(loader):
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
labels = batch["labels"].to(device)

out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
(out.loss / grad_accum).backward()

if (micro_step + 1) % grad_accum == 0:
model.clip_grad_norm_(1.0)
optimizer.step()
optimizer.zero_grad()
step += 1
if rank == 0:
print(f"step={step}/{max_steps} loss={out.loss.item():.4f}", flush=True)
if use_mlflow:
mlflow.log_metric("train_loss", out.loss.item(), step=step)
if step >= max_steps:
break

# --- Save consolidated checkpoint to the UC Volume (rank 0) --------------
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
cpu_state = model.state_dict()
if rank == 0:
os.makedirs(output_dir, exist_ok=True)
model.module.save_pretrained(output_dir, state_dict=cpu_state)
tokenizer.save_pretrained(output_dir)
print(f"Saved checkpoint to {output_dir}", flush=True)
if use_mlflow:
mlflow.end_run()

dist.barrier()
dist.destroy_process_group()


if __name__ == "__main__":
main()

次のステップ