Fine-tune OpenAI's GPT-OSS 120B model using distributed training
This notebook demonstrates supervised fine-tuning (SFT) of the 120B-parameter GPT-OSS model on 8 H100 GPUs using Databricks Serverless GPU compute. The training leverages:
- FSDP (Fully Sharded Data Parallel): Shards model parameters, gradients, and optimizer states across GPUs to enable training of large models that don't fit on a single GPU.
- DDP (Distributed Data Parallel): Distributes training across multiple GPUs for faster training.
- LoRA (Low-Rank Adaptation): Reduces the number of trainable parameters by adding small adapter layers, making fine-tuning more efficient.
- TRL (Transformers Reinforcement Learning): Provides the SFTTrainer for supervised fine-tuning.
By keeping remote=True and setting gpus=16 in the @distributed decorator, the same function can be extended to multi-node training across 16 GPUs.
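As a sketch of that change, only the decorator arguments differ; the training body stays the same (this assumes `gpus=16` with `gpu_type='H100'` schedules two 8-GPU nodes, as described above):

```python
from serverless_gpu import distributed

# Same training function as below, scaled from 8 GPUs (one node)
# to 16 GPUs (assumed: two 8-GPU H100 nodes).
@distributed(gpus=16, gpu_type='H100', remote=True)
def train_gpt_oss_fsdp_120b_multinode():
    ...  # identical training body to the single-node version below
```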
Install required packages
Install the necessary libraries for distributed training and model fine-tuning:
- trl: Transformers Reinforcement Learning library for SFT training
- peft: Parameter-Efficient Fine-Tuning for LoRA adapters
- transformers: Hugging Face Transformers library
- datasets: for loading training datasets
- accelerate: for distributed training orchestration
- hf_transfer: for faster model downloads from Hugging Face
%pip install "trl>=0.20.0" "peft>=0.17.0" "transformers>=4.55.0"
%pip install datasets accelerate
%pip install hf_transfer
Restart the Python environment
Restart the Python kernel to ensure the newly installed packages are available.
dbutils.library.restartPython()
Define the distributed training function with FSDP
This cell defines the training function that will run on 8 H100 GPUs using the @distributed decorator. The function includes:
- Model loading: Loads the 120B parameter GPT-OSS model in bfloat16 precision
- LoRA configuration: Applies Low-Rank Adaptation with rank 16 to reduce trainable parameters
- FSDP setup: Configures Fully Sharded Data Parallel with automatic layer wrapping and activation checkpointing
- Training configuration: Sets batch size, learning rate, gradient accumulation, and other hyperparameters
- Dataset: Uses the HuggingFaceH4/Multilingual-Thinking dataset for fine-tuning
The function automatically detects transformer block classes for FSDP wrapping and handles distributed training coordination across all GPUs.
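To see why rank-16 LoRA shrinks the trainable-parameter count so dramatically, a back-of-the-envelope calculation helps (the 4096-dimensional projection below is purely illustrative, not the actual GPT-OSS layer shape):

```python
# For a frozen linear layer of shape (d_in, d_out), LoRA with rank r adds
# two small matrices A (d_in x r) and B (r x d_out), so only
# r * (d_in + d_out) parameters are trained instead of d_in * d_out.
def lora_params(d_in: int, d_out: int, r: int) -> int:
    return r * (d_in + d_out)

def full_params(d_in: int, d_out: int) -> int:
    return d_in * d_out

# Illustrative shapes only (not the real GPT-OSS architecture):
# a 4096 x 4096 projection with r=16 trains under 1% of the layer.
ratio = lora_params(4096, 4096, 16) / full_params(4096, 4096)
print(f"{ratio:.4%}")
```

At rank 16 the adapter adds 131,072 trainable parameters against the layer's 16,777,216 frozen ones, which is why `print_trainable_parameters()` in the code below reports such a small trainable fraction.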
from serverless_gpu import distributed
@distributed(gpus=8, gpu_type='H100', remote=True)
def train_gpt_oss_fsdp_120b():
    """
    Fine-tune a 120B-class model with the TRL SFTTrainer + FSDP on H100s.
    Uses LoRA + activation checkpointing + full_shard auto_wrap.
    """
    # --- imports inside the function for pickle safety ---
    import os
    import torch
    import torch.distributed as dist
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import SFTTrainer, SFTConfig
    from datasets import load_dataset
    from peft import LoraConfig, get_peft_model

    # ---------- DDP / CUDA binding ----------
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
    os.environ.setdefault("NCCL_DEBUG", "WARN")
    os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "0")
    os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1")  # replaces NCCL_ASYNC_ERROR_HANDLING

    # ---------- Config ----------
    MODEL_NAME = "openai/gpt-oss-120b"
    MAX_LENGTH = 2048
    PER_DEVICE_BATCH = 1  # start conservative for 120B
    GRAD_ACCUM = 4        # tune for throughput
    LR = 1.5e-4
    EPOCHS = 1
    OUTPUT_DIR = "/tmp/gpt-oss-120b-finetune_2"

    is_main = int(os.environ.get("RANK", "0")) == 0
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    if is_main:
        print("=" * 60)
        print("FSDP (full_shard) launch for 120B")
        print(f"WORLD_SIZE={world_size} | LOCAL_RANK={local_rank}")
        print("=" * 60)

    # ---------- Tokenizer ----------
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = MAX_LENGTH
    tokenizer.truncation_side = "right"

    # ---------- Model ----------
    # IMPORTANT: no device_map, no .to(device) -- let Trainer/Accelerate+FSDP handle placement.
    # low_cpu_mem_usage helps with massive checkpoints (still needs decent host RAM).
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
        use_cache=False,  # required with activation checkpointing
        low_cpu_mem_usage=True,
    )

    # ---------- LoRA ----------
    peft_config = LoraConfig(
        r=16, lora_alpha=16, target_modules="all-linear",
        lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, peft_config)
    if is_main:
        model.print_trainable_parameters()

    # ---------- Data ----------
    dataset = load_dataset("HuggingFaceH4/Multilingual-Thinking", split="train")
    if is_main:
        print(f"Dataset size: {len(dataset)}")

    # ---------- FSDP settings ----------
    def infer_transformer_blocks_for_fsdp(model):
        COMMON = {
            "LlamaDecoderLayer", "MistralDecoderLayer", "MixtralDecoderLayer",
            "Qwen2DecoderLayer", "Gemma2DecoderLayer", "Phi3DecoderLayer",
            "GPTNeoXLayer", "MPTBlock", "BloomBlock", "FalconDecoderLayer",
            "DecoderLayer", "GPTJBlock", "OPTDecoderLayer",
        }
        hits = set()
        for _, m in model.named_modules():
            name = m.__class__.__name__
            if name in COMMON:
                hits.add(name)
        # Fallback: grab anything that *looks* like a decoder block
        if not hits:
            for _, m in model.named_modules():
                name = m.__class__.__name__
                if any(s in name for s in ["Block", "DecoderLayer", "Layer"]) and "Embedding" not in name:
                    hits.add(name)
        return sorted(hits)

    fsdp_wrap_classes = infer_transformer_blocks_for_fsdp(model)
    if not fsdp_wrap_classes:
        raise RuntimeError("Could not infer transformer block classes for FSDP wrapping; "
                           "print(model) and add the block class explicitly.")

    training_args = SFTConfig(
        output_dir=OUTPUT_DIR,
        overwrite_output_dir=True,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=PER_DEVICE_BATCH,
        gradient_accumulation_steps=GRAD_ACCUM,
        learning_rate=LR,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        bf16=True,
        logging_steps=5,
        logging_strategy="steps",
        save_strategy="no",
        report_to="none",
        ddp_find_unused_parameters=False,
        dataloader_pin_memory=True,
        max_length=MAX_LENGTH,
        gradient_checkpointing=False,
        # ---- FSDP knobs ----
        fsdp="full_shard auto_wrap",
        fsdp_config={
            "fsdp_transformer_layer_cls_to_wrap": fsdp_wrap_classes,
            "activation_checkpointing": True,  # <- use activation ckpt (not gradient)
            "activation_checkpointing_reentrant": False,
            "xla": False,
            "limit_all_gathers": True,
            "use_orig_params": True,     # recommended in recent torch/transformers
            "sync_module_states": True,  # helps with fresh init consistency
        },
    )

    # ---------- Trainer ----------
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
    )

    # verify distributed init & FSDP
    rank = int(os.getenv("RANK", "0"))
    print(f"[rank {rank}] dist.is_initialized() -> {dist.is_initialized()}")
    acc = getattr(trainer, "accelerator", None)
    print(f"[rank {rank}] accelerator.distributed_type = {getattr(getattr(acc, 'state', None), 'distributed_type', 'n/a')}")
    print(f"[rank {rank}] accelerator.num_processes = {getattr(acc, 'num_processes', 'n/a')}")

    # ---------- Train ----------
    result = trainer.train()
    if is_main:
        print("\nTraining complete (FSDP).")
        print(result.metrics)
Run the distributed training job
Execute the training function on 8 H100 GPUs. The @distributed decorator orchestrates the launch, starting the function on every GPU with the distributed environment (RANK, LOCAL_RANK, WORLD_SIZE) already configured.
train_gpt_oss_fsdp_120b.distributed()
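One consequence of the hyperparameters above worth keeping in mind: the effective global batch size is the per-device batch multiplied by the gradient accumulation steps and the world size. A quick sanity check with the values used in this notebook:

```python
# Effective global batch size for the run above:
# per-device batch * gradient accumulation steps * number of GPUs.
PER_DEVICE_BATCH = 1
GRAD_ACCUM = 4
WORLD_SIZE = 8  # 8 H100 GPUs

effective_batch = PER_DEVICE_BATCH * GRAD_ACCUM * WORLD_SIZE
print(effective_batch)  # -> 32 sequences per optimizer step
```

Scaling to 16 GPUs doubles this figure, so you may want to halve GRAD_ACCUM (or adjust the learning rate) to keep the optimization dynamics comparable.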
Next steps
- Multi-GPU and multi-node distributed training
- Best practices for Serverless GPU compute
- Troubleshoot issues on serverless GPU compute
- PEFT documentation
- TRL documentation