Full fine-tuning of Qwen3-4B
Fully fine-tune the Qwen3-4B large language model on a single H100 GPU. This walkthrough shows how to:
- Run full fine-tuning, which updates every model parameter for maximum adaptation to your data
- Use the Databricks AI v5 environment without installing any additional libraries
- Leverage TRL (Transformer Reinforcement Learning) for supervised fine-tuning
- Register the fine-tuned model in Unity Catalog for governance and deployment
Key concepts:
- Full fine-tuning: Updates all model weights, giving the model the greatest capacity to learn from your dataset at the cost of higher memory and compute than parameter-efficient methods
- TRL: A library for training language models with reinforcement learning and supervised fine-tuning
- Memory-efficient training: Uses BF16 mixed precision and gradient checkpointing to fit a 4B-parameter full fine-tune on a single H100 GPU
Full fine-tuning vs LoRA decision matrix
This notebook uses full fine-tuning, which updates all model parameters. The alternative, LoRA (Low-Rank Adaptation), freezes the base model and trains only small adapter layers.
| Scenario | Recommendation | Reason |
|---|---|---|
| Major model behavior change | Full fine-tuning | Updates all parameters for fundamental changes to model behavior |
| Highest possible quality on a single task | Full fine-tuning | No low-rank approximation, so the model has full capacity to adapt |
| Limited GPU memory | LoRA | Fits larger models in memory by training only ~1% of parameters |
| Multiple task-specific adapters | LoRA | Swap different adapters on the same base model |
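The "~1%" figure in the table comes from how LoRA is parameterized. Full fine-tuning updates a whole weight matrix of shape d_out × d_in. LoRA instead trains two factors, B (d_out × r) and A (r × d_in), which adds r · (d_out + d_in) trainable parameters per adapted matrix. The dimensions below (a 4096 × 4096 projection at rank 16) are illustrative assumptions, not Qwen3-4B's actual layer sizes. The fraction for a whole model also depends on which modules receive adapters.

```python
def lora_trainable_fraction(d_out: int, d_in: int, rank: int) -> float:
    """Ratio of LoRA adapter parameters to the full weight matrix they adapt."""
    full_params = d_out * d_in            # parameters updated by full fine-tuning
    lora_params = rank * (d_out + d_in)   # B is d_out x rank, A is rank x d_in
    return lora_params / full_params


# Illustrative (hypothetical) projection size and LoRA rank
fraction = lora_trainable_fraction(4096, 4096, rank=16)
print(f"LoRA trains {fraction:.2%} of this matrix's parameters")
```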
Full fine-tuning a 4B-parameter model requires significantly more GPU memory than LoRA because optimizer state and gradients are maintained for every parameter. This notebook uses BF16 mixed precision and gradient checkpointing so the training fits on a single H100 (80 GB) GPU.
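A back-of-envelope estimate shows why gradients and optimizer state dominate memory. The sketch assumes roughly 4.0 billion parameters, BF16 weights and gradients (2 bytes each), and AdamW's two moment buffers per parameter. PyTorch's AdamW allocates those buffers in the parameters' dtype, so they are BF16 when the model is loaded in BF16. The sketch also shows the FP32 (4 bytes) case for comparison. Activations, CUDA workspace, and allocator fragmentation are ignored, so treat the result as a lower bound. `full_finetune_memory_gb` is a hypothetical helper.

```python
def full_finetune_memory_gb(
    num_params: float,
    weight_bytes: int = 2,        # BF16 weights
    grad_bytes: int = 2,          # BF16 gradients
    optim_state_bytes: int = 2,   # bytes per element of each AdamW moment buffer
    num_optim_states: int = 2,    # AdamW keeps two moments (exp_avg, exp_avg_sq)
) -> float:
    """Lower-bound memory (decimal GB) for weights, gradients, and optimizer state only."""
    bytes_per_param = weight_bytes + grad_bytes + num_optim_states * optim_state_bytes
    return num_params * bytes_per_param / 1e9


NUM_PARAMS = 4.0e9  # assumption: roughly 4 billion parameters

bf16_states = full_finetune_memory_gb(NUM_PARAMS)
fp32_states = full_finetune_memory_gb(NUM_PARAMS, optim_state_bytes=4)
print(f"AdamW states in BF16: ~{bf16_states:.0f} GB before activations")
print(f"AdamW states in FP32: ~{fp32_states:.0f} GB before activations")
```

Either figure fits within 80 GB. The remaining headroom must hold the activations saved for the backward pass, and gradient checkpointing reduces exactly that cost.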
Connect to serverless GPU compute
To connect to serverless GPU compute:
- Click the Connect drop-down menu in the notebook and select Serverless GPU.
- Choose a 1x H100 GPU as the accelerator.
- Open the Environment panel and choose AI v5 as the base environment.
- Click Apply.
For more information, see the GPU compute documentation.
Import libraries
The Databricks AI v5 environment already includes all the libraries required for this example (such as trl, transformers, datasets, and mlflow), so no additional installation is needed.
The next cell imports the required libraries for model training, dataset handling, and MLflow tracking.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import (
    SFTConfig,
    SFTTrainer,
    setup_chat_format
)
import torch
import mlflow
Configuration setup
Unity Catalog integration
The next cell configures where your fine-tuned model will be stored and registered:
- Catalog & Schema: Organize models within your Unity Catalog namespace (default: main.default)
- Model Name: The registered model name in Unity Catalog for governance and deployment
- Volume: Unity Catalog volume for storing model checkpoints during training
These widgets allow you to customize the storage location without editing code. The model will be registered as {catalog}.{schema}.{model_name} for easy access and version control.
Training hyperparameters
The cell also defines key training parameters:
- Model & Dataset: Qwen3-4B with the Capybara conversational dataset
- Batch Size (1): Number of examples per GPU per training step, kept small to fit a full fine-tune in memory
- Gradient Accumulation (8): Accumulates gradients over 8 batches before each optimizer update, for an effective batch size of 1 × 8 = 8 examples
- Learning Rate (2e-5): Conservative rate appropriate for full fine-tuning
- Max steps (50): Caps training at 50 steps for a fast demonstration run
- Logging, Evaluation & Checkpointing: Logs metrics every 10 steps, and evaluates and saves a checkpoint every 25 steps
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("uc_model_name", "qwen3_4b_assistant")
dbutils.widgets.text("uc_volume", "checkpoints")
UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_MODEL_NAME = dbutils.widgets.get("uc_model_name")
UC_VOLUME = dbutils.widgets.get("uc_volume")
print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")
print(f"UC_VOLUME: {UC_VOLUME}")
# Model, dataset, and checkpoint output location (inside the Unity Catalog volume)
MODEL_NAME = "Qwen/Qwen3-4B"
DATASET_NAME = "trl-lib/Capybara"
OUTPUT_DIR = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/{UC_MODEL_NAME}"
# Training hyperparameters
BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = 8
LEARNING_RATE = 2e-5
MAX_STEPS = 50
EVAL_STEPS = 25
LOGGING_STEPS = 10
SAVE_STEPS = 25
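The batch-related settings combine multiplicatively. In the Hugging Face Trainer, `max_steps` counts optimizer updates. Each update consumes `gradient_accumulation_steps` micro-batches per device. The sketch below recomputes the effective batch size and the number of training examples a 50-step run consumes. It assumes a single GPU; `NUM_GPUS` is an illustrative name, not a variable the notebook defines.

```python
# Values copied from the configuration cell above
BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = 8
MAX_STEPS = 50
NUM_GPUS = 1  # assumption: single H100, as in this walkthrough

# Examples contributing to each optimizer update
effective_batch_size = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS * NUM_GPUS
# Total examples drawn from the training split over the whole run
examples_seen = effective_batch_size * MAX_STEPS

print(f"Effective batch size: {effective_batch_size}")
print(f"Examples seen in {MAX_STEPS} steps: {examples_seen}")
```

A 50-step run therefore covers only a few hundred examples. For a real training run, raise `MAX_STEPS` or set a number of epochs instead.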
Load and prepare dataset
The next cell loads the training dataset and prepares it for fine-tuning:
- Dataset: trl-lib/Capybara, high-quality multi-turn conversational data for instruction following
- Train/validation split: Creates a 90/10 split if the dataset has no test split
- Conversational format: No manual preprocessing is needed, because SFTTrainer applies the tokenizer's chat template to conversational examples during training
dataset = load_dataset(DATASET_NAME)
print(f"✓ Dataset loaded: {dataset}")
if "test" not in dataset:
    print("Creating validation split from training data...")
    dataset = dataset["train"].train_test_split(test_size=0.1, seed=42)
    print("✓ Data split: 90% train, 10% validation")
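You may want an explicit format check before training. The minimal sketch below can flag malformed rows. It assumes each example stores its conversation in a `messages` column of `{"role", "content"}` dicts, which is the conversational layout SFTTrainer accepts. `validate_conversation` is a hypothetical helper. Datasets that include tool turns would need additional entries in `VALID_ROLES`.

```python
from typing import Any

VALID_ROLES = {"system", "user", "assistant"}


def validate_conversation(example: dict[str, Any], column: str = "messages") -> list[str]:
    """Return the problems found in one conversational example (empty list if it looks valid)."""
    messages = example.get(column)
    if not isinstance(messages, list) or not messages:
        return [f"'{column}' must be a non-empty list of turns"]

    problems = []
    for i, turn in enumerate(messages):
        if not isinstance(turn, dict):
            problems.append(f"turn {i} is not a dict")
            continue
        if turn.get("role") not in VALID_ROLES:
            problems.append(f"turn {i} has unexpected role {turn.get('role')!r}")
        if not isinstance(turn.get("content"), str):
            problems.append(f"turn {i} has non-string content")

    # Supervised fine-tuning needs at least one assistant reply to learn from
    if not any(isinstance(t, dict) and t.get("role") == "assistant" for t in messages):
        problems.append("conversation has no assistant turn")
    return problems
```

In the notebook, `dataset["train"].filter(lambda ex: len(validate_conversation(ex)) > 0)` would collect any offending rows so you can inspect them before training.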
Initialize model and tokenizer
The next cell loads the base model and tokenizer, then configures them for conversational fine-tuning:
- Model loading: Downloads Qwen3-4B from Hugging Face in BF16 precision
- Tokenizer setup: Configures fast tokenizer with proper padding
- Chat formatting: Applies a chat template for structured conversations if the tokenizer doesn't already define one
- Token configuration: Sets padding token to EOS token for proper sequence handling
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    use_fast=True
)
# Chat template formatting for conversational fine-tuning
if tokenizer.chat_template is None:
    print("Adding chat template for proper conversation formatting...")
    model, tokenizer = setup_chat_format(model, tokenizer, format="chatml")
    print("✓ ChatML format applied for structured conversations")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    print("✓ Padding token set to EOS token")
print("✓ Model and tokenizer loaded successfully")
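To see what a chat template does, here is a simplified rendering of the ChatML layout that `setup_chat_format(..., format="chatml")` installs: each turn is wrapped in `<|im_start|>{role}` and `<|im_end|>` markers. This is an illustration only, not Qwen3's actual template. Qwen3's template builds on ChatML but adds extras such as thinking-mode handling. To inspect the real output, call `tokenizer.apply_chat_template(messages, tokenize=False)` on the loaded tokenizer. `render_chatml` is a hypothetical helper.

```python
def render_chatml(messages: list[dict[str, str]], add_generation_prompt: bool = False) -> str:
    """Render role/content turns in the basic ChatML layout (simplified illustration)."""
    text = "".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    )
    if add_generation_prompt:
        # Open an assistant turn so the model continues from here at inference time
        text += "<|im_start|>assistant\n"
    return text


print(render_chatml([
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
]))
```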
Train the model
The next cell configures and executes the full fine-tuning process:
Training configuration
- Batch configuration: 1 sample per device with 8 gradient accumulation steps (effective batch size: 8)
- Optimization: 10 warmup steps, weight decay of 0.01, and best-model selection based on evaluation loss
- Logging: Reports metrics to MLflow for experiment tracking
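This section assumes the Trainer's default linear scheduler (`lr_scheduler_type="linear"`), which this notebook does not override. Under that assumption, with `warmup_steps=10` and `max_steps=50`, the learning rate ramps linearly from 0 to 2e-5 over the first 10 steps. It then decays linearly back to 0 at step 50. The pure-Python function below reproduces that schedule so you can check the value at any step. `linear_warmup_decay_lr` is a hypothetical helper, not a transformers API.

```python
def linear_warmup_decay_lr(step: int, base_lr: float = 2e-5,
                           warmup_steps: int = 10, total_steps: int = 50) -> float:
    """Learning rate at a given optimizer step under linear warmup followed by linear decay."""
    if step < warmup_steps:
        # Warmup: ramp from 0 up to base_lr
        return base_lr * step / max(1, warmup_steps)
    # Decay: fall linearly from base_lr at the end of warmup to 0 at total_steps
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


for step in (0, 5, 10, 30, 50):
    print(step, linear_warmup_decay_lr(step))
```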
Key optimizations enabled
- BF16 mixed precision: Faster computation with lower memory footprint, well suited to H100 GPUs
- Gradient checkpointing: Trades extra compute for a large reduction in activation memory, which is what makes a 4B full fine-tune fit on a single H100
- Gradient accumulation: Simulates larger batch sizes for stable training
- Checkpointing: Saves the model every 25 steps with a limit of 2 checkpoints
The training loop logs progress every 10 steps and evaluates every 25 steps.
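These cadences interact with `load_best_model_at_end=True`. When both strategies are `"steps"`, transformers requires `save_steps` to be a round multiple of `eval_steps`. That way, each checkpoint lines up with an evaluation whose loss can be compared. The hypothetical helper below mirrors that check in plain Python. It also lists the steps at which this run logs, evaluates, and saves.

```python
def training_schedule(max_steps: int, logging_steps: int, eval_steps: int, save_steps: int,
                      load_best_model_at_end: bool = True) -> dict[str, list[int]]:
    """List the optimizer steps at which logging, evaluation, and checkpointing happen."""
    if load_best_model_at_end and save_steps % eval_steps != 0:
        raise ValueError("save_steps must be a round multiple of eval_steps "
                         "when load_best_model_at_end=True")

    def every(interval: int) -> list[int]:
        # Steps interval, 2*interval, ... up to and including max_steps
        return list(range(interval, max_steps + 1, interval))

    return {"log": every(logging_steps), "eval": every(eval_steps), "save": every(save_steps)}


print(training_schedule(max_steps=50, logging_steps=10, eval_steps=25, save_steps=25))
```

This run produces only two checkpoints (steps 25 and 50), so `save_total_limit=2` deletes nothing. In longer runs, the Trainer still retains the best checkpoint alongside the most recent ones when `load_best_model_at_end` is enabled.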
with mlflow.start_run(run_name=f"{MODEL_NAME}_full-fine-tuning", log_system_metrics=True):
    try:
        print(f"Learning rate: {LEARNING_RATE}")
        training_args_dict = {
            "output_dir": OUTPUT_DIR,
            "per_device_train_batch_size": BATCH_SIZE,
            "per_device_eval_batch_size": BATCH_SIZE,
            "gradient_accumulation_steps": GRADIENT_ACCUMULATION_STEPS,
            "learning_rate": LEARNING_RATE,
            "max_steps": MAX_STEPS,
            "eval_steps": EVAL_STEPS,
            "logging_steps": LOGGING_STEPS,
            "save_steps": SAVE_STEPS,
            "save_total_limit": 2,
            "report_to": "mlflow",  # Log to MLflow
            "warmup_steps": 10,
            "weight_decay": 0.01,
            "metric_for_best_model": "eval_loss",
            "greater_is_better": False,
            "eval_strategy": "steps",  # Run evaluation every eval_steps
            "save_strategy": "steps",  # Checkpoint on the same cadence as eval
            "load_best_model_at_end": True,  # Register the best-eval checkpoint, not the last
            "dataloader_pin_memory": False,
            "remove_unused_columns": False,
            "bf16": True,  # Mixed precision training
            "gradient_checkpointing": True,  # Reduce activation memory for full fine-tuning
            "gradient_checkpointing_kwargs": {"use_reentrant": False},
        }
        training_args = SFTConfig(**training_args_dict)
        trainer = SFTTrainer(
            model=model,
            args=training_args,
            train_dataset=dataset["train"],
            eval_dataset=dataset["test"],
            processing_class=tokenizer,
        )
        print("\n" + "="*50)
        print("STARTING TRAINING")
        print("="*50)
        print("🚀 Full fine-tuning Qwen3-4B on a single H100 GPU")
        trainer.train()
        print("\n✓ Training completed successfully!")
    except Exception as e:
        print(f"✗ Training failed: {e}")
        raise
Save model artifacts
The next cell saves the trained model and tokenizer to the Unity Catalog volume:
- Full model weights: Saves the complete fine-tuned model, ready to load directly for inference
- Tokenizer: Saves tokenizer configuration for inference
- Storage location: Saves to /Volumes/{catalog}/{schema}/{volume}/{model_name}
try:
    print("\nSaving trained model...")
    trainer.save_model(training_args.output_dir)
    print("✓ Full model weights saved")
    tokenizer.save_pretrained(training_args.output_dir)
    print("✓ Tokenizer saved with model")
    print(f"\n🎉 All artifacts saved to: {training_args.output_dir}")
except Exception as e:
    print(f"✗ Model saving failed: {e}")
    raise
Register model in Unity Catalog
The next cell registers the fine-tuned model in Unity Catalog for governance and deployment:
Model registration workflow
- Load trained model: Loads the saved full-weight model and tokenizer
- Prepare for logging: Creates a transformers model dictionary with the model and tokenizer
- Register in Unity Catalog: Logs to MLflow and registers in Unity Catalog
- Add metadata: Includes task type, model family, and size information
Benefits of Unity Catalog registration
- Governance: Centralized model registry with access control and lineage tracking
- Versioning: Automatic version management for model lifecycle
- Deployment: Easy deployment to model serving endpoints
- Discoverability: Models are searchable and documented in Unity Catalog
mlflow_run_id = mlflow.last_active_run().info.run_id
print("\nRegistering model with MLflow and Unity Catalog...")
# Point the MLflow model registry at Unity Catalog
mlflow.set_registry_uri("databricks-uc")
with mlflow.start_run(run_id=mlflow_run_id):
    try:
        # Load the trained full-weight model for registration
        print("Loading fine-tuned model for registration...")
        trained_model = AutoModelForCausalLM.from_pretrained(
            training_args.output_dir,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True
        )
        tokenizer = AutoTokenizer.from_pretrained(training_args.output_dir)
        model_type = "Full fine-tuning"
        size_params = "4b"
        # Prepare transformers model dictionary
        transformers_model = {
            "model": trained_model,
            "tokenizer": tokenizer
        }
        # Create Unity Catalog model name
        full_model_name = f"{UC_CATALOG}.{UC_SCHEMA}.{UC_MODEL_NAME}"
        print(f"Registering model as: {full_model_name}")
        # Log the model to the resumed training run and register it in Unity Catalog
        task = "llm/v1/chat"
        model_info = mlflow.transformers.log_model(
            transformers_model=transformers_model,
            task=task,
            registered_model_name=full_model_name,
            metadata={
                "task": task,
                "pretrained_model_name": MODEL_NAME,
                "databricks_model_family": "Qwen3ForCausalLM",
                "databricks_model_size_parameters": size_params,
            },
            repo_type="local",  # Fix: specify repo_type for local path
        )
        print(f"✓ Model successfully registered in Unity Catalog: {full_model_name}")
        print(f"✓ MLflow model URI: {model_info.model_uri}")
        # Print deployment information
        print("\n📦 Model Registration Complete!")
        print(f"Unity Catalog Path: {full_model_name}")
        print(f"Model Type: {model_type}")
    except Exception as e:
        print(f"✗ Model registration failed: {e}")
        print("Model is still saved locally and can be registered manually")
        print(f"Local model path: {training_args.output_dir}")
        raise
Next steps
Your Qwen3-4B model is now fully fine-tuned and registered in Unity Catalog. Next, you can:
- Deploy the model: Serve models with Model Serving
- Learn more about distributed training: Multi-GPU and multi-node distributed training
- Track experiments and monitor GPUs: Experiment tracking and observability
- Troubleshoot issues: Troubleshoot issues on serverless GPU compute