メインコンテンツまでスキップ

分散トレーニングを使用して OpenAI の GPT-OSS 120B モデルを微調整する

このノートブックでは、 Databricksサーバーレス GPU を使用した 8 つの H100 GPU 上の大規模な 120B 問題 GPT-OSS モデルの教師ありファインチューニング (SFT) を示します。 トレーニングでは以下を活用します:

  • FSDP (Fully Sharded Data Parallel) : モデルのシャード、勾配、オプティマイザーの状態を GPU 全体で共有し、単一の GPU に収まらない大規模なモデルのトレーニングを可能にします。
  • DDP (分散データ並列) : トレーニングを複数の GPU に分散して、トレーニングを高速化します。
  • LoRA (Low-Rank Adaptation) : 小さなアダプター層を追加することでトレーニング可能な懸念の数を減らし、微調整をより効率的にします。
  • TRL (Transformers Reinforcement Learning) : 教師ありファインチューニング用の SFTTrainer を提供します。

remote=Falseを設定し、16 個の GPU を指定すると、16 個の GPU にわたるマルチノード トレーニングに拡張できます。

必要なパッケージをインストールする

分散トレーニングとモデルのファインチューニングに必要なライブラリをインストールします。

  • trl: SFTトレーニング用トランスフォーマー強化学習ライブラリ
  • peft: 不利 - LoRA アダプターの効率的なファインチューニング
  • transformers: Hugging Faceトランスフォーマーライブラリ
  • datasets: トレーニングデータセットの読み込み用
  • accelerate: 分散トレーニングオーケストレーション用
  • hf_transfer: Hugging Faceからのモデルのダウンロードを高速化
Python
%pip install "trl>=0.20.0" "peft>=0.17.0" "transformers>=4.55.0"
%pip install datasets accelerate
%pip install hf_transfer

Python環境を再起動します

新しくインストールされたパッケージが利用可能であることを確認するために、Python カーネルを再起動します。

Python
dbutils.library.restartPython()

FSDPで分散トレーニング機能を定義する

このセルは、 @distributedデコレータを使用して 8 つの H100 GPU で実行されるトレーニング関数を定義します。機能には以下が含まれます。

  • モデルの読み込み : 120B 問題 GPT-OSS モデルを bfloat16 精度で読み込みます
  • LoRA 構成 : ランク 16 で低ランク適応を適用し、トレーニング可能な論点を削減します
  • FSDP セットアップ : 自動レイヤー ラッピングとアクティベーション チェックポイントを備えた完全シャーディング データ パラレルを構成します。
  • トレーニング設定 : バッチサイズ、学習率、勾配累積、その他のハイパーパラメータを設定します
  • データセット : ファインチューニングには HuggingFaceH4/Multilingual-Thinking データセットを使用します

この関数は、FSDP ラッピングのトランスフォーマー ブロック クラスを自動的に検出し、すべての GPU にわたる分散トレーニング調整を処理します。

Python
from serverless_gpu import distributed



@distributed(gpus=8, gpu_type='H100', remote=True)
def train_gpt_oss_fsdp_120b():
"""
Fine-tune a 120B-class model with TRL SFTTrainer + FSDP on H100s.
Uses LoRA + activation ckpt + full_shard auto_wrap.
"""

# --- imports inside for pickle safety ---
import os, torch, torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset
from peft import LoraConfig, get_peft_model

# ---------- DDP / CUDA binding ----------
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
torch.cuda.set_device(local_rank)
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("NCCL_DEBUG", "WARN")
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "0")

os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1") # replaces NCCL_ASYNC_ERROR_HANDLING

# ---------- Config ----------
MODEL_NAME = "openai/gpt-oss-120b"
MAX_LENGTH = 2048
PER_DEVICE_BATCH = 1 # start conservative for 120B
GRAD_ACCUM = 4 # tune for throughput
LR = 1.5e-4
EPOCHS = 1
OUTPUT_DIR = f"/tmp/gpt-oss-120b-finetune_2"

is_main = int(os.environ.get("RANK", "0")) == 0
world_size = int(os.environ.get("WORLD_SIZE", "1"))

if is_main:
print("=" * 60)
print("FSDP (full_shard) launch for 120B")
print(f"WORLD_SIZE={world_size} | LOCAL_RANK={local_rank}")
print("=" * 60)

# ---------- Tokenizer ----------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.model_max_length = MAX_LENGTH
tokenizer.truncation_side = "right"

# ---------- Model ----------
# IMPORTANT: no device_map, no .to(device) — let Trainer/Accelerate+FSDP handle placement
# low_cpu_mem_usage helps with massive checkpoints (still needs decent host RAM)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.bfloat16,
attn_implementation="eager",
use_cache=False, # needed for grad ckpt
low_cpu_mem_usage=True,
)

# ---------- LoRA ----------
peft_config = LoraConfig(
r=16, lora_alpha=16, target_modules="all-linear",
lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
)
model = get_peft_model(model, peft_config)
if is_main:
model.print_trainable_parameters()

# ---------- Data ----------
dataset = load_dataset("HuggingFaceH4/Multilingual-Thinking", split="train")
if is_main:
print(f"Dataset size: {len(dataset)}")

# ---------- FSDP settings ----------
def infer_transformer_blocks_for_fsdp(model):
COMMON = {
"LlamaDecoderLayer", "MistralDecoderLayer", "MixtralDecoderLayer",
"Qwen2DecoderLayer", "Gemma2DecoderLayer", "Phi3DecoderLayer",
"GPTNeoXLayer", "MPTBlock", "BloomBlock", "FalconDecoderLayer",
"DecoderLayer", "GPTJBlock", "OPTDecoderLayer"
}
hits = set()
for _, m in model.named_modules():
name = m.__class__.__name__
if name in COMMON:
hits.add(name)
# Fallback: grab anything that *looks* like a decoder block
if not hits:
for _, m in model.named_modules():
name = m.__class__.__name__
if any(s in name for s in ["Block", "DecoderLayer", "Layer"]) and "Embedding" not in name:
hits.add(name)
return sorted(hits)


fsdp_wrap_classes = infer_transformer_blocks_for_fsdp(model)
if not fsdp_wrap_classes:
raise RuntimeError("Could not infer transformer block classes for FSDP wrapping; "
"print(model) and add the block class explicitly.")


training_args = SFTConfig(
output_dir=OUTPUT_DIR,
overwrite_output_dir=True,
num_train_epochs=EPOCHS,
per_device_train_batch_size=PER_DEVICE_BATCH,
gradient_accumulation_steps=GRAD_ACCUM,
learning_rate=LR,
warmup_ratio=0.03,
lr_scheduler_type="cosine",
bf16=True,
logging_steps=5,
logging_strategy="steps",
save_strategy="no",
report_to="none",
ddp_find_unused_parameters=False,
dataloader_pin_memory=True,
max_length=MAX_LENGTH,
gradient_checkpointing=False,

# ---- FSDP knobs ----
fsdp="full_shard auto_wrap",
fsdp_config={
"fsdp_transformer_layer_cls_to_wrap": fsdp_wrap_classes,
"activation_checkpointing": True, # <- use activation ckpt (not gradient)
"activation_checkpointing_reentrant": False,
"xla": False,
"limit_all_gathers": True,
"use_orig_params": True, # recommended in recent torch/transformers
"sync_module_states": True, # helps with fresh init consistency
},
)

# ---------- Trainer ----------
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=dataset,
processing_class=tokenizer,
)

# verify distributed init & FSDP
rank = int(os.getenv("RANK", "0"))
print(f"[rank {rank}] dist.is_initialized() -> {dist.is_initialized()}")
acc = getattr(trainer, "accelerator", None)
print(f"[rank {rank}] accelerator.distributed_type = {getattr(getattr(acc,'state',None),'distributed_type','n/a')}")
print(f"[rank {rank}] accelerator.num_processes = {getattr(acc, 'num_processes', 'n/a')}")

# ---------- Train ----------
result = trainer.train()

if is_main:
print("\nTraining complete (FSDP).")
print(result.metrics)

分散トレーニングジョブを実行する

8 つの H100 GPU でトレーニング機能を実行します。@distributedデコレータは、適切な分散設定を使用してすべての GPU にわたってトレーニングを起動するオーケストレーションを処理します。

Python
train_gpt_oss_fsdp_120b.distributed()

次のステップ

サンプルノートブック

分散トレーニングを使用して OpenAI の GPT-OSS 120B モデルを微調整する

ノートブックを新しいタブで開く