
Full fine-tuning of Qwen3-4B


Fully fine-tune the Qwen3-4B large language model on a single H100 GPU. This walkthrough shows how to:

  • Run full fine-tuning, which updates every model parameter for maximum adaptation to your data
  • Use the Databricks AI v5 environment without installing any additional libraries
  • Leverage TRL (Transformer Reinforcement Learning) for supervised fine-tuning
  • Register the fine-tuned model in Unity Catalog for governance and deployment

Key concepts:

  • Full fine-tuning: Updates all model weights, giving the model the greatest capacity to learn from your dataset at the cost of higher memory and compute than parameter-efficient methods
  • TRL: A library for training language models with reinforcement learning and supervised fine-tuning
  • Memory-efficient training: Uses BF16 mixed precision and gradient checkpointing to fit a 4B-parameter full fine-tune on a single H100 GPU

Full fine-tuning vs LoRA decision matrix

This notebook uses full fine-tuning, which updates all model parameters. The alternative, LoRA (Low-Rank Adaptation), freezes the base model and trains only small adapter layers.

| Scenario | Recommendation | Reason |
|---|---|---|
| Major model behavior change | Full fine-tuning | Updates all parameters for fundamental changes to model behavior |
| Highest possible quality on a single task | Full fine-tuning | No low-rank approximation, so the model has full capacity to adapt |
| Limited GPU memory | LoRA | Fits larger models in memory by training only ~1% of parameters |
| Multiple task-specific adapters | LoRA | Swap different adapters on the same base model |

Full fine-tuning a 4B-parameter model requires significantly more GPU memory than LoRA because optimizer state and gradients are maintained for every parameter. This notebook uses BF16 mixed precision and gradient checkpointing so the training fits on a single H100 (80 GB) GPU.
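The memory gap can be made concrete with a back-of-the-envelope estimate. The following sketch counts the bytes held per parameter for weights, gradients, and the two AdamW moment buffers, and ignores activations (which gradient checkpointing reduces but does not eliminate). The 4-billion parameter count is rounded, and the LoRA line assumes ~1% trainable parameters kept in FP32; both are illustrative assumptions, not measurements. Because this notebook loads the model with `torch_dtype=torch.bfloat16`, the Trainer's default PyTorch AdamW should allocate its moment buffers in BF16 as well, so the all-BF16 line is likely the closest match to this setup.

```python
# Back-of-the-envelope GPU memory for weights + gradients + optimizer state.
# Illustrative assumptions: 4e9 parameters, AdamW (two moment buffers per parameter),
# activations excluded. Uses decimal gigabytes (1 GB = 1e9 bytes).

GB = 1e9
NUM_PARAMS = 4e9  # Qwen3-4B, rounded


def bytes_per_trainable_param(weight: int, grad: int, moment: int, master_copy: int = 0) -> int:
    """Weights + gradients + two AdamW moments (+ an optional FP32 master copy of the weights)."""
    return weight + grad + 2 * moment + master_copy


def to_gb(num_params: float, bytes_per_param: int) -> float:
    return num_params * bytes_per_param / GB


# Full fine-tuning with weights, gradients, and AdamW moments all in BF16 (2 bytes each).
full_bf16 = to_gb(NUM_PARAMS, bytes_per_trainable_param(2, 2, 2))

# Full fine-tuning with BF16 weights/gradients plus FP32 master weights and FP32 moments.
full_fp32_master = to_gb(NUM_PARAMS, bytes_per_trainable_param(2, 2, 4, master_copy=4))

# LoRA: frozen BF16 base weights (no gradients or optimizer state), plus ~1% trainable
# adapter parameters trained entirely in FP32 (4-byte weights, gradients, and moments).
LORA_FRACTION = 0.01
lora = to_gb(NUM_PARAMS, 2) + to_gb(LORA_FRACTION * NUM_PARAMS, bytes_per_trainable_param(4, 4, 4))

for label, gb in [
    ("Full fine-tuning, all-BF16 state", full_bf16),
    ("Full fine-tuning, FP32 master + moments", full_fp32_master),
    ("LoRA (~1% trainable, FP32 adapters)", lora),
]:
    print(f"{label:<42} ~{gb:5.1f} GB before activations")
```

Under these assumptions, full fine-tuning needs roughly 32 GB (all-BF16) to 64 GB (FP32 master weights and moments) before activations, while LoRA needs under 9 GB. The all-BF16 case leaves far more of the H100's 80 GB for activations, which is where gradient checkpointing helps.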

Connect to serverless GPU compute

To connect to serverless GPU compute:

  1. Click the Connect drop-down menu in the notebook and select Serverless GPU.
  2. Choose a 1x H100 GPU as the accelerator.
  3. Open the Environment panel and choose AI v5 as the base environment.
  4. Click Apply.

For more information, see the GPU compute documentation.

Import libraries

The Databricks AI v5 environment already includes all the libraries required for this example (such as trl, transformers, datasets, and mlflow), so no additional installation is needed.

The next cell imports the required libraries for model training, dataset handling, and MLflow tracking.

Python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import (
    SFTConfig,
    SFTTrainer,
    setup_chat_format,
)
import torch
import mlflow

Configuration setup

Unity Catalog integration

The next cell configures where your fine-tuned model will be stored and registered:

  • Catalog & Schema: Organize models within your Unity Catalog namespace (default: main.default)
  • Model Name: The registered model name in Unity Catalog for governance and deployment
  • Volume: Unity Catalog volume for storing model checkpoints during training; the volume must already exist before training starts

These widgets allow you to customize the storage location without editing code. The model will be registered as {catalog}.{schema}.{model_name} for easy access and version control.

Training hyperparameters

The cell also defines key training parameters:

  • Model & Dataset: Qwen3-4B with the Capybara conversational dataset
  • Batch Size (1): Number of examples per GPU in each forward/backward pass, kept small to fit a full fine-tune in memory
  • Gradient Accumulation (8): Accumulates gradients over 8 micro-batches for an effective batch size of 8
  • Learning Rate (2e-5): Conservative rate appropriate for full fine-tuning
  • Max Steps (50): Caps training at 50 optimizer steps (400 examples at an effective batch size of 8) for a fast demonstration run
  • Evaluation, Logging & Checkpointing: Evaluates and saves a checkpoint every 25 steps, logs metrics every 10 steps
Python
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("uc_model_name", "qwen3_4b_assistant")
dbutils.widgets.text("uc_volume", "checkpoints")

UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_MODEL_NAME = dbutils.widgets.get("uc_model_name")
UC_VOLUME = dbutils.widgets.get("uc_volume")

print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")
print(f"UC_VOLUME: {UC_VOLUME}")

# MLflow and Unity Catalog configuration: register models in Unity Catalog
mlflow.set_registry_uri("databricks-uc")

# Model selection
MODEL_NAME = "Qwen/Qwen3-4B"
DATASET_NAME = "trl-lib/Capybara"
OUTPUT_DIR = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/{UC_MODEL_NAME}"

# Training hyperparameters
BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = 8
LEARNING_RATE = 2e-5
MAX_STEPS = 50
EVAL_STEPS = 25
LOGGING_STEPS = 10
SAVE_STEPS = 25
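The configuration cell above derives two locations from the widget values: the three-level registered model name and the checkpoint directory under the volume. Because training writes checkpoints to that volume path, a missing volume will likely make the run fail when it first saves. The following sketch shows how the names compose and builds a `CREATE VOLUME IF NOT EXISTS` statement. `uc_locations` is a hypothetical helper, not a Databricks API, and its identifier check is deliberately stricter than Unity Catalog's own rules.

```python
# Sketch (hypothetical helper, not a Databricks API): derive the Unity Catalog names and
# paths that the configuration cell builds from its widget values, plus a SQL statement
# that creates the checkpoint volume if it is missing.
import re

# Conservative identifier check: letters, digits, and underscores only. Unity Catalog
# accepts more characters when names are backtick-quoted; this sketch does not.
_SIMPLE_IDENTIFIER = re.compile(r"[A-Za-z0-9_]+")


def uc_locations(catalog: str, schema: str, volume: str, model_name: str) -> dict:
    for part in (catalog, schema, volume, model_name):
        if not _SIMPLE_IDENTIFIER.fullmatch(part):
            raise ValueError(f"Unsupported identifier for this sketch: {part!r}")
    return {
        "registered_model": f"{catalog}.{schema}.{model_name}",
        "checkpoint_dir": f"/Volumes/{catalog}/{schema}/{volume}/{model_name}",
        "create_volume_sql": f"CREATE VOLUME IF NOT EXISTS {catalog}.{schema}.{volume}",
    }


# Default widget values from the configuration cell.
locations = uc_locations("main", "default", "checkpoints", "qwen3_4b_assistant")
for key, value in locations.items():
    print(f"{key}: {value}")
```

In the notebook, passing `locations["create_volume_sql"]` to `spark.sql(...)` before training would create the volume, assuming you have permission to create volumes in that schema.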

Load and prepare dataset

The next cell loads the training dataset and prepares it for fine-tuning:

  • Dataset: trl-lib/Capybara, multi-turn conversational data for instruction following
  • Train/validation split: Creates a 90/10 split (seed 42) if the dataset has no test split
  • Conversational format: Each example holds a messages list of role/content turns, which SFTTrainer renders with the tokenizer's chat template during training
Python
dataset = load_dataset(DATASET_NAME)
print(f"✓ Dataset loaded: {dataset}")

if "test" not in dataset:
    print("Creating validation split from training data...")
    dataset = dataset["train"].train_test_split(test_size=0.1, seed=42)
    print("✓ Data split: 90% train, 10% validation")
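If you substitute your own conversational data, a lightweight pre-check can catch malformed records before they reach the trainer. The following sketch assumes the `messages` column of role/content dictionaries used by trl-lib/Capybara. `validate_conversation` is a hypothetical helper, and its rules (allowed roles, non-empty content, ending on an assistant turn) are illustrative choices, not requirements enforced by TRL.

```python
# Hypothetical validation helper for chat-format examples. The rules below are
# illustrative conventions, not checks performed by TRL or datasets.
ALLOWED_ROLES = {"system", "user", "assistant"}


def validate_conversation(example: dict) -> list[str]:
    """Return a list of problems found in one example (empty list = looks fine)."""
    messages = example.get("messages")
    if not isinstance(messages, list) or not messages:
        return ["'messages' must be a non-empty list"]

    problems = []
    for i, msg in enumerate(messages):
        if not isinstance(msg, dict):
            problems.append(f"message {i} is not a dict")
            continue
        if msg.get("role") not in ALLOWED_ROLES:
            problems.append(f"message {i} has unexpected role {msg.get('role')!r}")
        content = msg.get("content")
        if not isinstance(content, str) or not content.strip():
            problems.append(f"message {i} has empty or non-string content")

    # The training target should normally be an assistant reply.
    last = messages[-1]
    if not isinstance(last, dict) or last.get("role") != "assistant":
        problems.append("conversation does not end with an assistant turn")
    return problems


good = {"messages": [{"role": "user", "content": "What is 2 + 2?"},
                     {"role": "assistant", "content": "4"}]}
bad = {"messages": [{"role": "user", "content": ""}]}
print(validate_conversation(good))
print(validate_conversation(bad))
```

With a Hugging Face `DatasetDict`, `dataset.filter(lambda ex: not validate_conversation(ex))` would drop flagged examples from every split.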

Initialize model and tokenizer

The next cell loads the base model and tokenizer, then configures them for conversational fine-tuning:

  • Model loading: Downloads Qwen3-4B from Hugging Face in BF16 precision
  • Tokenizer setup: Configures fast tokenizer with proper padding
  • Chat formatting: Applies a chat template for structured conversations if the tokenizer doesn't already define one
  • Token configuration: Sets the padding token to the EOS token if the tokenizer does not already define one
Python
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    use_fast=True
)

# Chat template formatting for conversational fine-tuning
if tokenizer.chat_template is None:
    print("Adding chat template for proper conversation formatting...")
    model, tokenizer = setup_chat_format(model, tokenizer, format="chatml")
    print("✓ ChatML format applied for structured conversations")

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    print("✓ Padding token set to EOS token")

print("✓ Model and tokenizer loaded successfully")
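Qwen3's built-in chat template and the ChatML fallback both wrap each conversation turn in `<|im_start|>` and `<|im_end|>` markers. The following sketch is a simplified ChatML renderer that shows the general shape of the text the model is trained on. It is not Qwen3's actual template, which also handles system prompts, tool calls, and thinking blocks. To see the exact text, call `tokenizer.apply_chat_template(messages, tokenize=False)` in the notebook.

```python
# Simplified ChatML rendering, for illustration only. The real Qwen3 template
# (used by tokenizer.apply_chat_template) adds handling for system prompts,
# tool calls, and thinking blocks, so its output will not match this exactly.

def render_chatml(messages: list[dict], add_generation_prompt: bool = False) -> str:
    """Wrap each turn in ChatML start/end markers and join them into one string."""
    text = "".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    )
    if add_generation_prompt:
        # Open an assistant turn so the model continues generating from here.
        text += "<|im_start|>assistant\n"
    return text


conversation = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "Paris."},
]
print(render_chatml(conversation))
print(render_chatml(conversation[:1], add_generation_prompt=True))
```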

Train the model

The next cell configures and executes the full fine-tuning process:

Training configuration

  • Batch configuration: 1 sample per device with 8 gradient accumulation steps (effective batch size: 8)
  • Optimization: 10 warmup steps, weight decay of 0.01, and best-model selection based on evaluation loss
  • Logging: Reports metrics to MLflow for experiment tracking

Key optimizations enabled

  • BF16 mixed precision: Faster computation with lower memory footprint, well suited to H100 GPUs
  • Gradient checkpointing: Trades extra compute for a large reduction in activation memory, which is what makes a 4B full fine-tune fit on a single H100
  • Gradient accumulation: Simulates larger batch sizes for stable training
  • Checkpoint saving: Saves the model every 25 steps, with save_total_limit=2 capping the number of retained checkpoints (unrelated to gradient checkpointing above)

The training loop logs progress every 10 steps and evaluates every 25 steps.

Python
with mlflow.start_run(run_name=f"{MODEL_NAME}_full-fine-tuning", log_system_metrics=True):
    try:
        print(f"Learning rate: {LEARNING_RATE}")

        training_args_dict = {
            "output_dir": OUTPUT_DIR,
            "per_device_train_batch_size": BATCH_SIZE,
            "per_device_eval_batch_size": BATCH_SIZE,
            "gradient_accumulation_steps": GRADIENT_ACCUMULATION_STEPS,
            "learning_rate": LEARNING_RATE,
            "max_steps": MAX_STEPS,
            "eval_steps": EVAL_STEPS,
            "logging_steps": LOGGING_STEPS,
            "save_steps": SAVE_STEPS,
            "save_total_limit": 2,
            "report_to": "mlflow",  # Log to MLflow
            "warmup_steps": 10,
            "weight_decay": 0.01,
            "metric_for_best_model": "eval_loss",
            "greater_is_better": False,
            "eval_strategy": "steps",  # Run evaluation every eval_steps
            "save_strategy": "steps",  # Checkpoint on the same cadence as eval
            "load_best_model_at_end": True,  # Register the best-eval checkpoint, not the last
            "dataloader_pin_memory": False,
            "remove_unused_columns": False,
            "bf16": True,  # Mixed precision training
            "gradient_checkpointing": True,  # Reduce activation memory for full fine-tuning
            "gradient_checkpointing_kwargs": {"use_reentrant": False},
        }

        training_args = SFTConfig(**training_args_dict)

        trainer = SFTTrainer(
            model=model,
            args=training_args,
            train_dataset=dataset["train"],
            eval_dataset=dataset["test"],
            processing_class=tokenizer,
        )

        print("\n" + "=" * 50)
        print("STARTING TRAINING")
        print("=" * 50)

        print("🚀 Full fine-tuning Qwen3-4B on a single H100 GPU")

        trainer.train()
        print("\n✓ Training completed successfully!")

    except Exception as e:
        print(f"✗ Training failed: {e}")
        raise
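Two pieces of arithmetic sit behind this configuration: how gradient accumulation turns micro-batches of 1 into an effective batch of 8, and how the learning rate moves over 50 steps with 10 warmup steps. The following pure-Python sketch illustrates both. It assumes a single GPU and the Trainer's default `linear` scheduler (linear warmup from 0, then linear decay to 0). The accumulation part uses a toy one-dimensional linear model with a mean-squared-error loss. For language-model losses averaged per token, the equivalence is only approximate when micro-batches differ in length.

```python
# Pure-Python sketch of two pieces of training arithmetic (no PyTorch needed).
# Assumptions: 1 GPU, warmup_steps=10, and the Trainer's default "linear" scheduler.

BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = 8
NUM_GPUS = 1
LEARNING_RATE = 2e-5
MAX_STEPS = 50
WARMUP_STEPS = 10

effective_batch_size = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS * NUM_GPUS
examples_seen = effective_batch_size * MAX_STEPS


def linear_warmup_decay(step: int) -> float:
    """LR at an optimizer step: ramp up from 0 over WARMUP_STEPS, then decay linearly to 0 at MAX_STEPS."""
    if step < WARMUP_STEPS:
        return LEARNING_RATE * step / max(1, WARMUP_STEPS)
    return LEARNING_RATE * max(0.0, (MAX_STEPS - step) / max(1, MAX_STEPS - WARMUP_STEPS))


# Gradient accumulation on a toy model y ~ w * x with a mean-squared-error loss:
# averaging the gradients of 8 micro-batches of size 1 equals one gradient over all 8 examples.
def grad_mse(w: float, batch: list[tuple[float, float]]) -> float:
    """d/dw of mean((w * x - y) ** 2) over the batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


data = [(float(i), 2.0 * i + 1.0) for i in range(1, effective_batch_size + 1)]
w = 0.5
full_batch_grad = grad_mse(w, data)

accumulated_grad = 0.0
for start in range(0, len(data), BATCH_SIZE):  # one micro-batch per iteration
    micro_batch = data[start:start + BATCH_SIZE]
    # Scale each micro-batch gradient by 1/accumulation_steps before summing.
    accumulated_grad += grad_mse(w, micro_batch) / GRADIENT_ACCUMULATION_STEPS

print(f"effective batch size: {effective_batch_size}, examples seen: {examples_seen}")
print(f"LR at steps 0, 5, 10, 30, 50: {[linear_warmup_decay(s) for s in (0, 5, 10, 30, 50)]}")
print(f"full-batch grad {full_batch_grad:.6f} vs accumulated grad {accumulated_grad:.6f}")
```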

Save model artifacts

The next cell saves the trained model and tokenizer to the Unity Catalog volume:

  • Full model weights: Saves the complete fine-tuned model, ready to load directly for inference. Because load_best_model_at_end is enabled, this is the checkpoint with the lowest evaluation loss
  • Tokenizer: Saves tokenizer configuration for inference
  • Storage location: Saves to /Volumes/{catalog}/{schema}/{volume}/{model_name}
Python
try:
    print("\nSaving trained model...")

    trainer.save_model(training_args.output_dir)
    print("✓ Full model weights saved")

    tokenizer.save_pretrained(training_args.output_dir)
    print("✓ Tokenizer saved with model")
    print(f"\n🎉 All artifacts saved to: {training_args.output_dir}")

except Exception as e:
    print(f"✗ Model saving failed: {e}")
    raise
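Before registering, a quick file check can confirm that the save step produced a loadable directory. The following sketch is a hypothetical helper, not part of transformers or TRL. It looks for `config.json`, `tokenizer_config.json`, and at least one `.safetensors` weight file, on the assumption that `save_pretrained` uses its default safetensors serialization. It runs against a temporary directory so it works on its own.

```python
# Hypothetical post-save sanity check (not part of transformers or TRL).
# Assumes weights are written as .safetensors files, the transformers default.
import tempfile
from pathlib import Path


def missing_artifacts(output_dir: str) -> list[str]:
    """Return the expected files that are absent from output_dir (empty list = all present)."""
    root = Path(output_dir)
    missing = [name for name in ("config.json", "tokenizer_config.json")
               if not (root / name).is_file()]
    # Large models are often sharded into several files, so accept any .safetensors file.
    if not any(root.glob("*.safetensors")):
        missing.append("*.safetensors")
    return missing


# Self-contained demo against a temporary directory standing in for OUTPUT_DIR.
with tempfile.TemporaryDirectory() as tmp:
    before = missing_artifacts(tmp)
    for name in ("config.json", "tokenizer_config.json", "model.safetensors"):
        (Path(tmp) / name).write_text("{}")
    after = missing_artifacts(tmp)

print("before:", before)
print("after:", after)
```

In the notebook, `missing_artifacts(training_args.output_dir)` should return an empty list after the save cell finishes.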

Register model in Unity Catalog

The next cell registers the fine-tuned model in Unity Catalog for governance and deployment:

Model registration workflow

  1. Load trained model: Loads the saved full-weight model and tokenizer
  2. Prepare for logging: Creates a transformers model dictionary with the model and tokenizer
  3. Register in Unity Catalog: Logs the model to the same MLflow run used for training and registers it in Unity Catalog
  4. Add metadata: Includes task type, model family, and size information

Benefits of Unity Catalog registration

  • Governance: Centralized model registry with access control and lineage tracking
  • Versioning: Automatic version management for model lifecycle
  • Deployment: Easy deployment to model serving endpoints
  • Discoverability: Models are searchable and documented in Unity Catalog
Python
mlflow_run_id = mlflow.last_active_run().info.run_id
print("\nRegistering model with MLflow and Unity Catalog...")

with mlflow.start_run(run_id=mlflow_run_id):
    try:
        # Load the trained full-weight model for registration
        print("Loading fine-tuned model for registration...")
        trained_model = AutoModelForCausalLM.from_pretrained(
            training_args.output_dir,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True
        )
        tokenizer = AutoTokenizer.from_pretrained(training_args.output_dir)
        model_type = "Full fine-tuning"
        size_params = "4b"

        # Prepare transformers model dictionary
        transformers_model = {
            "model": trained_model,
            "tokenizer": tokenizer
        }

        # Create Unity Catalog model name
        full_model_name = f"{UC_CATALOG}.{UC_SCHEMA}.{UC_MODEL_NAME}"

        print(f"Registering model as: {full_model_name}")

        # Log the model to the training run and register it in Unity Catalog
        task = "llm/v1/chat"
        model_info = mlflow.transformers.log_model(
            transformers_model=transformers_model,
            task=task,
            registered_model_name=full_model_name,
            metadata={
                "task": task,
                "pretrained_model_name": MODEL_NAME,
                "databricks_model_family": "Qwen3ForCausalLM",
                "databricks_model_size_parameters": size_params,
            },
            repo_type="local",  # Fix: specify repo_type for local path
        )

        print(f"✓ Model successfully registered in Unity Catalog: {full_model_name}")
        print(f"✓ MLflow model URI: {model_info.model_uri}")

        # Print deployment information
        print("\n📦 Model Registration Complete!")
        print(f"Unity Catalog Path: {full_model_name}")
        print(f"Model Type: {model_type}")

    except Exception as e:
        print(f"✗ Model registration failed: {e}")
        print("Model is still saved locally and can be registered manually")
        print(f"Local model path: {training_args.output_dir}")
        raise
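Once registered, the model can be referenced by a `models:/` URI that combines the three-level Unity Catalog name with a version number. The following sketch builds that URI and a request body in the chat shape used by the `llm/v1/chat` task: a `messages` list plus optional generation settings such as `max_tokens` and `temperature`. The version number `1` and the parameter values are assumptions; use the version shown in Catalog Explorer. In the notebook, something like `mlflow.pyfunc.load_model(model_uri).predict(payload)` would run the model, but loading a 4B model needs enough free GPU or CPU memory.

```python
# Sketch: build a Unity Catalog model URI and an llm/v1/chat-style request body.
# The version number and generation parameters below are illustrative assumptions.
import json


def chat_payload(user_message: str, max_tokens: int = 256, temperature: float = 0.7) -> dict:
    """Single-turn request: a list of role/content messages plus optional generation settings."""
    return {
        "messages": [{"role": "user", "content": user_message}],
        "max_tokens": max_tokens,
        "temperature": temperature,
    }


full_model_name = "main.default.qwen3_4b_assistant"  # default widget values
model_version = 1  # assumption: the first registered version
model_uri = f"models:/{full_model_name}/{model_version}"

payload = chat_payload("Explain full fine-tuning in one sentence.")
print(model_uri)
print(json.dumps(payload, indent=2))
```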

Next steps

Your Qwen3-4B model has been fully fine-tuned and registered in Unity Catalog. Next, you can:
