Pular para o conteúdo principal

Ajustar o modelo OpenAI gpt-oss 120B

Este notebook demonstra como usar a GPU sem servidor Databricks para executar o ajuste fino supervisionado (SFT) no modelo gpt-oss de grande porte usando o FSDP em 8 GPUs H100.

  • Conecte-se ao 8XH100
  • Selecione o Ambiente Padrão V4
Python
%pip install "trl>=0.20.0" "peft>=0.17.0" "transformers==4.56.1"
%pip install datasets
%pip install hf_transfer==0.1.9
%pip install mlflow>=3.7.0
Python
dbutils.library.restartPython()

Python
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("uc_model_name", "gpt-oss-120b-peft")
dbutils.widgets.text("uc_volume", "checkpoints")
dbutils.widgets.text("model", "openai/gpt-oss-120b")
dbutils.widgets.text("dataset_path", "HuggingFaceH4/Multilingual-Thinking")

UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_MODEL_NAME = dbutils.widgets.get("uc_model_name")
UC_VOLUME = dbutils.widgets.get("uc_volume")
HF_MODEL_NAME = dbutils.widgets.get("model")
DATASET_PATH = dbutils.widgets.get("dataset_path")

print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")
print(f"UC_VOLUME: {UC_VOLUME}")
print(f"HF_MODEL_NAME: {HF_MODEL_NAME}")
print(f"DATASET_PATH: {DATASET_PATH}")

OUTPUT_DIR = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/{UC_MODEL_NAME}"
# OUTPUT_DIR = "./tmp/output"
print(f"OUTPUT_DIR: {OUTPUT_DIR}")
Python
from serverless_gpu import distributed
from typing import List, Optional


@distributed(gpus=8, gpu_type='H100')
def run_train(num_epochs: int = 1, learning_rate: float = 1.5e-4) -> Optional[str]:
    """
    Fine-tune a 120B-class model with TRL SFTTrainer + FSDP on H100s.
    Uses LoRA + activation checkpointing + full_shard auto_wrap.

    Args:
        num_epochs: Number of training epochs.
        learning_rate: Peak learning rate for the cosine schedule.

    Returns:
        The MLflow run id of the training run (each rank returns its own value;
        rank 0's is the one of interest), or ``None`` if no MLflow run was active.
    """
    import logging
    import os
    import torch

    # Distributed topology comes from the launcher's environment variables.
    rank = int(os.environ.get("RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)
    world_size = int(os.environ.get("WORLD_SIZE", str(torch.cuda.device_count())))
    is_main = rank == 0

    # Avoid tokenizer thread pools interacting badly with forked dataloader workers.
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

    if is_main:
        logging.info("FSDP (full_shard) launch for 120B")
        logging.info(f"\tWORLD_SIZE={world_size} | LOCAL_RANK={local_rank}")
        logging.info(f"\tCUDA device count (this node): {torch.cuda.device_count()}")

    # Load the dataset inside the distributed context so every rank has a copy.
    from datasets import load_dataset
    dataset = load_dataset(DATASET_PATH, split="train")

    # Tokenizer for the base model.
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME)

    # Base model in bf16.
    from transformers import AutoModelForCausalLM

    model_kwargs = dict(
        attn_implementation="eager",  # eager attention — no flash-attn kernels assumed
        dtype=torch.bfloat16,
        use_cache=False,  # KV cache is incompatible with activation checkpointing
        low_cpu_mem_usage=True,  # helps with massive checkpoints
    )

    model = AutoModelForCausalLM.from_pretrained(HF_MODEL_NAME, **model_kwargs)

    # LoRA adapter configuration.
    from peft import LoraConfig, get_peft_model

    peft_config = LoraConfig(
        r=16,
        lora_alpha=16,
        target_modules="all-linear",
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    peft_model = get_peft_model(model, peft_config)
    if is_main:
        peft_model.print_trainable_parameters()

    # ---- FSDP settings ----
    def infer_transformer_blocks_for_fsdp(model):
        """Return sorted class names of the model's decoder blocks for FSDP auto-wrap."""
        COMMON = {
            "LlamaDecoderLayer", "MistralDecoderLayer", "MixtralDecoderLayer",
            "Qwen2DecoderLayer", "Gemma2DecoderLayer", "Phi3DecoderLayer",
            "GPTNeoXLayer", "MPTBlock", "BloomBlock", "FalconDecoderLayer",
            "DecoderLayer", "GPTJBlock", "OPTDecoderLayer"
        }
        hits = set()
        for _, m in model.named_modules():
            name = m.__class__.__name__
            if name in COMMON:
                hits.add(name)
        # Fallback: grab anything that *looks* like a decoder block.
        if not hits:
            for _, m in model.named_modules():
                name = m.__class__.__name__
                if any(s in name for s in ["Block", "DecoderLayer", "Layer"]) and "Embedding" not in name:
                    hits.add(name)
        return sorted(hits)

    fsdp_wrap_classes = infer_transformer_blocks_for_fsdp(model)
    if not fsdp_wrap_classes:
        raise RuntimeError("Could not infer transformer block classes for FSDP wrapping; "
                           "print(model) and add the block class explicitly.")

    from trl import SFTConfig

    training_args = SFTConfig(
        output_dir=OUTPUT_DIR,
        overwrite_output_dir=True,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        learning_rate=learning_rate,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        bf16=True,
        logging_steps=5,
        logging_strategy="steps",
        save_strategy="no",
        report_to="mlflow",
        run_name="gpt-oss-120b-lora-fsdp",
        ddp_find_unused_parameters=False,
        dataloader_pin_memory=True,
        max_length=2048,
        gradient_checkpointing=False,  # FSDP activation checkpointing is used instead

        # ---- FSDP knobs ----
        fsdp="full_shard auto_wrap",
        fsdp_config={
            # Avoid FULL state-dict all-gathers on save (prevents the save-time
            # _ALLGATHER_BASE hang).
            "fsdp_state_dict_type": "SHARDED_STATE_DICT",  # recommended for checkpoints
            "fsdp_transformer_layer_cls_to_wrap": fsdp_wrap_classes,
            "activation_checkpointing": True,  # use activation ckpt (not gradient)
            "activation_checkpointing_reentrant": False,
            "xla": False,
            "limit_all_gathers": True,
            "use_orig_params": True,  # recommended in recent torch/transformers
            "sync_module_states": True,  # helps with fresh init consistency
        },
    )

    # Trainer
    from trl import SFTTrainer
    trainer = SFTTrainer(
        model=peft_model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
    )

    result = trainer.train()
    if trainer.accelerator.is_main_process:
        logging.info("Training complete!")
        logging.info(f"Final training loss: {result.training_loss:.4f}")
        logging.info(f"Train runtime (s): {result.metrics.get('train_runtime', 'N/A')}")
        logging.info(f"Samples/sec: {result.metrics.get('train_samples_per_second', 'N/A')}")

    # --- IMPORTANT: all ranks must participate in save_state/barriers ---
    trainer.accelerator.wait_for_everyone()
    logging.info("Saving sharded FSDP checkpoint...")
    # Resumable FSDP/Accelerate checkpoint shards (model/optim/sched/RNG).
    trainer.accelerator.save_state(OUTPUT_DIR)

    trainer.accelerator.wait_for_everyone()

    if trainer.accelerator.is_main_process:
        tokenizer.save_pretrained(OUTPUT_DIR)
        logging.info(f"Saved tokenizer to: {OUTPUT_DIR}")
        logging.info("✓ Tokenizer saved with model")
        logging.info(f"\n🎉 All artifacts saved to: {OUTPUT_DIR}")

    # Hand back the MLflow run id so the driver can query the run afterwards.
    import mlflow
    mlflow_run_id = None
    if mlflow.last_active_run() is not None:
        mlflow_run_id = mlflow.last_active_run().info.run_id

    return mlflow_run_id
Python
# Launch the 8-GPU job; every rank returns a value — keep rank 0's MLflow run id.
run_id = run_train.distributed(num_epochs=1, learning_rate=1.5e-4)[0]

Python
import mlflow

run = mlflow.get_run(run_id)
status = run.info.status
display(status)

Exemplo de caderno

Ajustar o modelo OpenAI gpt-oss 120B

Abrir notebook em uma nova aba