Distributed fine-tuning of OpenAI gpt-oss-20b
This notebook demonstrates how to fine-tune OpenAI's gpt-oss-20b model using distributed training on serverless GPU compute. You'll learn how to:
- Apply LoRA (Low-Rank Adaptation) to efficiently fine-tune a 20B parameter model
- Use MXFP4 quantization to reduce memory requirements during training
- Leverage distributed data parallelism across 8 H100 GPUs
- Register the fine-tuned model in Unity Catalog for deployment
Key concepts:
- gpt-oss-20b: OpenAI's 20-billion-parameter open-weight language model
- LoRA: Parameter-efficient fine-tuning that trains small adapter layers while freezing the base model
- MXFP4 quantization: Microscaling 4-bit floating point format that reduces memory usage
- TRL: Transformer Reinforcement Learning library for supervised fine-tuning
- Serverless GPU compute: Databricks managed compute that automatically scales GPU resources
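To build intuition for why LoRA is parameter-efficient, the following back-of-the-envelope sketch compares full fine-tuning with a rank-8 adapter on a single linear layer. The layer dimensions are illustrative, not the actual gpt-oss-20b sizes:

```python
# Rough parameter comparison: full fine-tuning vs. a LoRA adapter on one
# linear layer. Dimensions are illustrative, not gpt-oss-20b's real sizes.
d_in, d_out, r = 4096, 4096, 8

full_params = d_in * d_out        # every weight is trainable
lora_params = r * (d_in + d_out)  # low-rank factors A (r x d_in) and B (d_out x r)

print(f"Full fine-tuning: {full_params:,} trainable parameters")
print(f"LoRA (r={r}):     {lora_params:,} trainable parameters")
print(f"Reduction:        {full_params // lora_params}x")
```

The same ratio applies per targeted layer, which is why LoRA with `target_modules="all-linear"` still trains well under 1% of a 20B model's parameters.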
Connect to serverless GPU compute
This notebook requires serverless GPU compute. To connect:
- Click the notebook's compute selector in the top right and select Serverless GPU
- On the right side, click the environment button
- Select 8xH100 as the Accelerator
- Click Apply
The training function will automatically provision 8 H100 GPUs for distributed training.
Install required libraries
Install the necessary libraries for distributed training, including TRL for supervised fine-tuning, PEFT for LoRA adapters, and MLflow for model tracking.
%pip install "trl>=0.20.0" "peft>=0.17.0" "transformers>=4.56.1"
%pip install mlflow>=3.6
%pip install hf_transfer==0.1.9
%restart_python
dbutils.library.restartPython()
Configure Unity Catalog and model parameters
Set up the configuration parameters for Unity Catalog registration and model training. You can customize these parameters using the widgets above:
- uc_catalog, uc_schema, uc_model_name: Unity Catalog location for model registration
- uc_volume: Volume name for storing model checkpoints
- model: Hugging Face model identifier (default: openai/gpt-oss-20b)
- dataset_path: Dataset to use for fine-tuning (default: HuggingFaceH4/Multilingual-Thinking)
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("uc_model_name", "gpt-oss-20b-peft")
dbutils.widgets.text("uc_volume", "checkpoints")
dbutils.widgets.text("model", "openai/gpt-oss-20b")
dbutils.widgets.text("dataset_path", "HuggingFaceH4/Multilingual-Thinking")
UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_MODEL_NAME = dbutils.widgets.get("uc_model_name")
UC_VOLUME = dbutils.widgets.get("uc_volume")
HF_MODEL_NAME = dbutils.widgets.get("model")
DATASET_PATH = dbutils.widgets.get("dataset_path")
print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")
print(f"UC_VOLUME: {UC_VOLUME}")
print(f"HF_MODEL_NAME: {HF_MODEL_NAME}")
print(f"DATASET_PATH: {DATASET_PATH}")
OUTPUT_DIR = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/{UC_MODEL_NAME}"
print(f"OUTPUT_DIR: {OUTPUT_DIR}")
Choose your dataset
By default, this notebook uses 'HuggingFaceH4/Multilingual-Thinking', which has been specifically curated with translated chain-of-thoughts in multiple languages. You can edit the "Dataset Path" parameter above to use another dataset.
Define GPU memory logging utility
This utility function helps monitor GPU memory usage during distributed training. It logs allocated and reserved memory for each GPU rank, which is useful for debugging memory issues.
import os
import torch
import torch.distributed as dist
def log_gpu_memory(tag=""):
    if not torch.cuda.is_available():
        return

    # Rank info (if distributed is initialized)
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1

    device = torch.cuda.current_device()  # current GPU for this process
    torch.cuda.synchronize(device)
    allocated = torch.cuda.memory_allocated(device) / 1024**2
    reserved = torch.cuda.memory_reserved(device) / 1024**2
    print(
        f"[{tag}] rank={rank}/{world_size - 1}, "
        f"device={device}, "
        f"allocated={allocated:.1f} MB, reserved={reserved:.1f} MB"
    )
Define the distributed training function
The following cell defines the training function using the @distributed decorator from the serverless_gpu library. This decorator:
- Provisions 8 H100 GPUs on-demand for distributed training
- Handles data parallelism across multiple GPUs automatically
The function includes:
- Dataset loading and tokenization
- Model initialization with MXFP4 quantization
- LoRA adapter configuration
- Training with gradient checkpointing and mixed precision
- Model saving to Unity Catalog volumes
from serverless_gpu import distributed
@distributed(gpus=8, gpu_type="h100", remote=False)
def run_train():
    import logging
    import os

    import torch

    rank = int(os.environ.get("RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)
    world_size = int(os.environ.get("WORLD_SIZE", str(torch.cuda.device_count())))
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
    is_main = rank == 0

    if is_main:
        logging.info("DDP environment")
        logging.info(f"\tWORLD_SIZE={world_size} RANK={rank} LOCAL_RANK={local_rank}")
        logging.info(f"\tCUDA device count (this node): {torch.cuda.device_count()}")

    from datasets import load_dataset

    dataset = load_dataset(DATASET_PATH, split="train")

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME)

    from transformers import AutoModelForCausalLM, Mxfp4Config

    quantization_config = Mxfp4Config(dequantize=True)
    model_kwargs = dict(
        attn_implementation="eager",  # Eager attention for compatibility with this setup
        dtype=torch.bfloat16,
        quantization_config=quantization_config,
        use_cache=False,  # KV cache is incompatible with gradient checkpointing
    )
    model = AutoModelForCausalLM.from_pretrained(HF_MODEL_NAME, **model_kwargs)

    from peft import LoraConfig, get_peft_model

    peft_config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules="all-linear",
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    peft_model = get_peft_model(model, peft_config)
    if is_main:
        peft_model.print_trainable_parameters()

    from trl import SFTConfig

    training_args = SFTConfig(
        learning_rate=2e-4,
        num_train_epochs=1,
        logging_steps=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=2,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        max_length=2048,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine_with_min_lr",
        lr_scheduler_kwargs={"min_lr_rate": 0.1},
        output_dir=OUTPUT_DIR,
        report_to="mlflow",  # Log training metrics to MLflow
        push_to_hub=False,  # Disable push to hub to avoid authentication issues
        logging_dir=None,  # Disable TensorBoard logging
        disable_tqdm=False,  # Keep progress bars for monitoring
        ddp_find_unused_parameters=False,
    )

    from trl import SFTTrainer

    trainer = SFTTrainer(
        model=peft_model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
    )

    log_gpu_memory("before training")
    result = trainer.train()
    log_gpu_memory("after training")

    if is_main:
        logging.info("Training complete!")
        logging.info(f"Final training loss: {result.training_loss:.4f}")
        logging.info(f"Train runtime (s): {result.metrics.get('train_runtime', 'N/A')}")
        logging.info(f"Samples/sec: {result.metrics.get('train_samples_per_second', 'N/A')}")

        logging.info("\nSaving trained model...")
        trainer.save_model(OUTPUT_DIR)
        logging.info("✓ LoRA adapters saved - use with base model for inference")
        tokenizer.save_pretrained(OUTPUT_DIR)
        logging.info("✓ Tokenizer saved with model")
        logging.info(f"\n🎉 All artifacts saved to: {OUTPUT_DIR}")

    import mlflow

    mlflow_run_id = None
    if mlflow.last_active_run() is not None:
        mlflow_run_id = mlflow.last_active_run().info.run_id
    return mlflow_run_id
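With the configuration above, the effective global batch size is the per-device batch size times the gradient-accumulation steps times the number of GPUs. A quick sanity check of the arithmetic:

```python
# Effective global batch size for the SFTConfig above.
per_device_train_batch_size = 1
gradient_accumulation_steps = 2
world_size = 8  # 8 H100 GPUs

effective_batch_size = (
    per_device_train_batch_size * gradient_accumulation_steps * world_size
)
print(f"Effective global batch size: {effective_batch_size}")  # 16
```

If you change the GPU count or accumulation steps, you may want to rescale the learning rate accordingly.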
Execute the distributed training
This cell runs the training function on 8 H100 GPUs. The training typically takes 30-60 minutes depending on dataset size and compute availability. The function returns the MLflow run ID for model registration.
run_id = run_train.distributed()[0]
Register model in Unity Catalog
Now you can register the fine-tuned model with MLflow and Unity Catalog for deployment.
Important: Given the size of the model (20B parameters), reconnect the notebook to an H100 accelerator before running the registration cells.
The registration process will:
- Load the base model and merge it with the fine-tuned LoRA adapters
- Create a text generation pipeline
- Log the model to MLflow with Unity Catalog registration
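The merge step folds the LoRA update into the base weights as W' = W + (alpha / r) * B @ A, after which the adapter matrices can be discarded. A toy NumPy sketch (illustrative dimensions, assuming NumPy is available) shows that the merged weight reproduces the adapter forward pass exactly:

```python
import numpy as np

# Toy illustration of what merge_and_unload does: fold the low-rank update
# into the frozen base weight so inference needs only one matrix.
rng = np.random.default_rng(0)
d, r, alpha = 16, 8, 16

W = rng.normal(size=(d, d))          # frozen base weight
A = rng.normal(size=(r, d)) * 0.01   # LoRA down-projection
B = rng.normal(size=(d, r)) * 0.01   # LoRA up-projection
x = rng.normal(size=(d,))

adapter_out = W @ x + (alpha / r) * (B @ (A @ x))   # base path + adapter path
merged_out = (W + (alpha / r) * B @ A) @ x          # single merged weight

assert np.allclose(adapter_out, merged_out)
print("Merged weight reproduces the adapter forward pass")
```

Merging before logging means the registered model serves like an ordinary dense checkpoint, with no PEFT dependency at inference time.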
dbutils.widgets.dropdown("register_model", "False", ["True", "False"])
register_model = dbutils.widgets.get("register_model")
if register_model == "False":
    dbutils.notebook.exit("Skipping model registration...")
Check registration parameter
This cell checks the register_model parameter. If set to False, the notebook will skip model registration. You can change this parameter using the widget at the top of the notebook.
print("\nRegistering model with MLflow and Unity Catalog...")
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from peft import PeftModel
import mlflow
from mlflow import transformers as mlflow_transformers
import torch
torch.cuda.empty_cache()
# Load the trained model for registration
print("Loading LoRA model for registration...")
# For LoRA models, we need both base model and adapter
base_model = AutoModelForCausalLM.from_pretrained(
HF_MODEL_NAME,
trust_remote_code=True
)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME)
adapter_dir = OUTPUT_DIR
peft_model = PeftModel.from_pretrained(base_model, adapter_dir)
# Merge LoRA into base and drop PEFT wrappers
merged_model = peft_model.merge_and_unload()
components = {
"model": merged_model,
"tokenizer": tokenizer,
}
# Create Unity Catalog model name
full_model_name = f"{UC_CATALOG}.{UC_SCHEMA}.{UC_MODEL_NAME}"
print(f"Registering model as: {full_model_name}")
text_gen_pipe = pipeline(
    task="text-generation",
    model=merged_model,
    tokenizer=tokenizer,
)

input_example = ["Hello, world!"]

# Use the Unity Catalog model registry
mlflow.set_registry_uri("databricks-uc")

with mlflow.start_run():
    model_info = mlflow.transformers.log_model(
        transformers_model=text_gen_pipe,  # Pass the full pipeline, not just the model
        name="model",
        input_example=input_example,
        registered_model_name=full_model_name,
        # Optional: save_pretrained=False for reference-only PEFT logging
        # save_pretrained=False,
    )
print(f"✓ Model successfully registered in Unity Catalog: {full_model_name}")
print(f"✓ MLflow model URI: {model_info.model_uri}")
print(f"✓ Model version: {model_info.registered_model_version}")

# Print deployment information
print(f"\n📦 Model Registration Complete!")
print(f"Unity Catalog Path: {full_model_name}")
print(f"Optimization: MXFP4 quantization + LoRA")
Test multilingual reasoning capabilities
The fine-tuned model has been trained on the Multilingual-Thinking dataset, which includes chain-of-thought reasoning in multiple languages.
The following cell demonstrates this capability by:
- Setting the reasoning language to German
- Providing a prompt in Spanish ("What is the capital of Australia?")
- Observing that the model's internal reasoning is performed in German
REASONING_LANGUAGE = "German"
SYSTEM_PROMPT = f"reasoning language: {REASONING_LANGUAGE}"
USER_PROMPT = "¿Cuál es la capital de Australia?"  # Spanish for "What is the capital of Australia?"

messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": USER_PROMPT},
]

input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
).to(merged_model.device)
gen_kwargs = {"max_new_tokens": 512, "do_sample": True, "temperature": 0.6, "top_p": None, "top_k": None}
output_ids = merged_model.generate(input_ids, **gen_kwargs)
response = tokenizer.batch_decode(output_ids)[0]
print(response)
Next steps
Now that you've fine-tuned and tested your model, you can:
- Deploy the model: Serve models with Model Serving
- Learn more about distributed training: Multi-GPU and multi-node distributed training
- Optimize serverless GPU usage: Best practices for Serverless GPU compute
- Troubleshoot issues: Troubleshoot issues on serverless GPU compute
- Learn about OpenAI's gpt-oss: OpenAI fine-tuning cookbook