
Fine-tune Llama 3.2 1B with LoRA using Serverless GPU

This notebook demonstrates how to fine-tune a large language model using supervised fine-tuning (SFT) with Low-Rank Adaptation (LoRA) on Databricks Serverless GPU. The notebook uses the Transformers Reinforcement Learning (TRL) library with DeepSpeed ZeRO Stage 3 optimization to efficiently train Llama 3.2 1B on a single node with 4 A10 GPUs.

Key concepts:

  • LoRA (Low-Rank Adaptation): A parameter-efficient fine-tuning technique that reduces the number of trainable parameters by adding small, trainable rank decomposition matrices to the model layers.
  • TRL (Transformers Reinforcement Learning): A library that provides tools for training language models with reinforcement learning and supervised fine-tuning.
  • DeepSpeed ZeRO Stage 3: A memory optimization technique that partitions model parameters, gradients, and optimizer states across GPUs to enable training of large models.
  • Serverless GPU: Databricks managed GPU compute that automatically provisions and scales GPU resources for training workloads.
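The LoRA idea can be written as a single equation. Take a frozen pretrained weight W0 with shape d × k. LoRA learns two small matrices, A (r × k) and B (d × r), and scales their product by alpha / r. Here r and alpha correspond to the `r` and `lora_alpha` settings of PEFT's `LoraConfig` used later in this notebook:

```latex
h = W_0 x + \frac{\alpha}{r} B A x, \qquad
W_0 \in \mathbb{R}^{d \times k},\;
B \in \mathbb{R}^{d \times r},\;
A \in \mathbb{R}^{r \times k},\;
r \ll \min(d, k)
```

Only A and B receive gradients, so each adapted layer trains r(d + k) parameters instead of d·k. PEFT initializes B to zero, which means the adapted model starts out identical to the base model.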

For more information, see Serverless GPU compute.

Requirements

This notebook requires the following:

  • Serverless GPU compute: The notebook uses Databricks Serverless GPU with 4 A10 GPUs for distributed training. No cluster configuration is needed.
  • Unity Catalog: A Unity Catalog catalog and schema to store model checkpoints and register the trained model.
  • HuggingFace token: A HuggingFace access token stored in Databricks secrets to download the base model and dataset.
  • Python packages: The required packages (peft, trl, deepspeed, mlflow, hf_transfer) are installed in the setup section below.

Install required packages

Install the Python packages required for fine-tuning:

  • peft: Provides LoRA implementation for parameter-efficient fine-tuning
  • trl: Transformers Reinforcement Learning library for supervised fine-tuning
  • deepspeed: Enables memory-efficient distributed training with ZeRO optimization
  • mlflow: Tracks experiments and logs trained models
  • hf_transfer: Accelerates model downloads from HuggingFace Hub

After installation, restart the Python kernel to ensure all packages are properly loaded.

Python
%pip install --upgrade transformers==4.56.1
%pip install peft==0.17.1
%pip install trl==0.18.1
%pip install "deepspeed>=0.15.4"
%pip install "mlflow>=3.6.0"
%pip install hf_transfer==0.1.9
Python
%restart_python
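After the restart, you can optionally confirm that each package resolved. The sketch below uses only the standard library's `importlib.metadata`. The helper name `installed_versions` is hypothetical and is not part of the original notebook.

```python
from importlib.metadata import PackageNotFoundError, version

# Packages installed in the cell above (PyPI distribution names).
REQUIRED_PACKAGES = ["transformers", "peft", "trl", "deepspeed", "mlflow", "hf_transfer"]


def installed_versions(packages):
    """Return {package: installed version string, or None if missing}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            # The package is not installed in the current Python environment.
            versions[name] = None
    return versions


for name, ver in installed_versions(REQUIRED_PACKAGES).items():
    print(f"{name}: {ver or 'NOT INSTALLED'}")
```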

Configure Unity Catalog and environment variables

Set up Unity Catalog locations for storing model checkpoints and registering the trained model. The notebook uses widgets (dbutils.widgets) to configure:

  • Catalog and schema: Unity Catalog namespace for organizing models and checkpoints
  • Model name: Name for the registered model in Unity Catalog
  • Volume: Unity Catalog volume for storing model checkpoints during training

The configuration also retrieves the HuggingFace token from Databricks secrets and sets up the MLflow experiment for tracking training metrics.

Python
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("uc_model_name", "llama3_2-1b")
dbutils.widgets.text("uc_volume", "checkpoints")

UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_MODEL_NAME = dbutils.widgets.get("uc_model_name")
UC_VOLUME = dbutils.widgets.get("uc_volume")

# Get HuggingFace token and username.
# Update the secret scope and key to match where your token is stored.
hf_token = dbutils.secrets.get(scope="sgc-nightly-notebook", key="hf_token")
username = spark.sql("SELECT current_user()").collect()[0][0]

REGISTERED_MODEL_NAME = f"{UC_CATALOG}.{UC_SCHEMA}.{UC_MODEL_NAME}"
CHECKPOINT_DIR = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/{UC_MODEL_NAME}"
FINE_TUNED_MODEL_PATH = f"{CHECKPOINT_DIR}/fine-tuned-peft-model"
MLFLOW_EXPERIMENT_NAME = f"/Users/{username}/{UC_MODEL_NAME}"

# Create the Unity Catalog volume if it doesn't exist
spark.sql(f"CREATE VOLUME IF NOT EXISTS {UC_CATALOG}.{UC_SCHEMA}.{UC_VOLUME}")

print(f"👤 Username: {username}")
print("🔑 HuggingFace token configured")
print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")
print(f"UC_VOLUME: {UC_VOLUME}")
print(f"CHECKPOINT_DIR: {CHECKPOINT_DIR}")
print(f"MLFLOW_EXPERIMENT_NAME: {MLFLOW_EXPERIMENT_NAME}")
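The cell above builds every name and path from the four widget values. The standalone sketch below shows what those values resolve to with the widget defaults. The helper `build_uc_paths` and the username `someone@example.com` are hypothetical placeholders.

```python
def build_uc_paths(catalog, schema, model_name, volume, username):
    """Compose the Unity Catalog names and paths used in this notebook."""
    # Volume files are addressed as /Volumes/<catalog>/<schema>/<volume>/...
    checkpoint_dir = f"/Volumes/{catalog}/{schema}/{volume}/{model_name}"
    return {
        # Three-level Unity Catalog name: <catalog>.<schema>.<model>
        "registered_model_name": f"{catalog}.{schema}.{model_name}",
        "checkpoint_dir": checkpoint_dir,
        "fine_tuned_model_path": f"{checkpoint_dir}/fine-tuned-peft-model",
        "mlflow_experiment_name": f"/Users/{username}/{model_name}",
    }


# Widget defaults from the cell above; the username is a placeholder.
paths = build_uc_paths("main", "default", "llama3_2-1b", "checkpoints", "someone@example.com")
for key, value in paths.items():
    print(f"{key}: {value}")
```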

Python
import os
import json
import tempfile
import torch
import mlflow
from huggingface_hub import constants
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTTrainer, SFTConfig, ModelConfig, ScriptArguments, setup_chat_format
from peft import LoraConfig, get_peft_model, PeftModel

if mlflow.get_experiment_by_name(MLFLOW_EXPERIMENT_NAME) is None:
    mlflow.create_experiment(name=MLFLOW_EXPERIMENT_NAME)
mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME)

Create DeepSpeed ZeRO Stage 3 configuration

DeepSpeed ZeRO (Zero Redundancy Optimizer) Stage 3 partitions model parameters, gradients, and optimizer states across all GPUs to reduce memory consumption per GPU. This enables training of large models that wouldn't fit in a single GPU's memory.

Key configuration settings:

  • bf16 enabled: Uses bfloat16 precision for faster training and reduced memory usage
  • Stage 3 optimization: Partitions all model states across GPUs
  • No CPU offloading: Keeps all data on GPUs for maximum performance on A10 hardware
  • Overlap communication: Overlaps gradient communication with computation for efficiency
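To see why Stage 3 helps, the ZeRO paper gives a per-GPU memory accounting for mixed-precision Adam. Each parameter needs 2 bytes for its 16-bit weight, 2 bytes for its 16-bit gradient, and 12 bytes of fp32 optimizer state (master weight, momentum, variance). Each ZeRO stage partitions more of these across the N GPUs. The sketch below applies that accounting. It covers model states only, excludes activations and buffers, and uses an approximate 1.2B parameter count as an illustrative assumption.

```python
def zero_model_state_bytes(num_params, num_gpus, stage):
    """Per-GPU bytes for weights, gradients, and Adam states under ZeRO.

    Mixed-precision Adam accounting: 2 bytes/param (16-bit weights),
    2 bytes/param (16-bit gradients), 12 bytes/param (fp32 optimizer states).
    Activations and temporary buffers are not included.
    """
    weights, grads, optim = 2 * num_params, 2 * num_params, 12 * num_params
    if stage == 0:  # plain data parallelism: everything replicated
        return weights + grads + optim
    if stage == 1:  # optimizer states partitioned
        return weights + grads + optim / num_gpus
    if stage == 2:  # optimizer states and gradients partitioned
        return weights + (grads + optim) / num_gpus
    if stage == 3:  # optimizer states, gradients, and weights partitioned
        return (weights + grads + optim) / num_gpus
    raise ValueError(f"unsupported ZeRO stage: {stage}")


# Illustrative: full fine-tuning of a ~1.2B-parameter model on 4 GPUs.
for stage in range(4):
    gb = zero_model_state_bytes(1.2e9, 4, stage) / 1e9
    print(f"ZeRO stage {stage}: ~{gb:.1f} GB of model states per GPU")
```

These figures describe full fine-tuning. With LoRA, only the small set of adapter parameters needs gradients and fp32 optimizer states, and the frozen base weights need only their 16-bit copy. Actual usage in this notebook is therefore well below these numbers.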
Python
def create_deepspeed_config():
    """Create DeepSpeed ZeRO Stage 3 configuration for single node A10 training."""

    deepspeed_config = {
        "fp16": {
            "enabled": False
        },
        "bf16": {
            "enabled": True
        },
        "zero_optimization": {
            "stage": 3,
            "offload_optimizer": {
                "device": "none"
            },
            "offload_param": {
                "device": "none"
            },
            "overlap_comm": True,
            "contiguous_gradients": True,
            "sub_group_size": 1e9,
            "reduce_bucket_size": "auto",
            "stage3_prefetch_bucket_size": "auto",
            "stage3_param_persistence_threshold": "auto",
            "stage3_max_live_parameters": 1e9,
            "stage3_max_reuse_distance": 1e9,
            "stage3_gather_16bit_weights_on_model_save": True
        },
        "gradient_accumulation_steps": "auto",
        "gradient_clipping": "auto",
        "steps_per_print": 2000,
        "train_batch_size": "auto",
        "train_micro_batch_size_per_gpu": "auto",
        "wall_clock_breakdown": False
    }

    return deepspeed_config


# Create DeepSpeed configuration
deepspeed_config = create_deepspeed_config()
print("⚙️ DeepSpeed ZeRO Stage 3 configuration created")
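Several fields in the configuration are set to "auto". When the config is passed to the Hugging Face Trainer, the Trainer fills them in from the training arguments. DeepSpeed itself requires the batch fields to satisfy train_batch_size = train_micro_batch_size_per_gpu × gradient_accumulation_steps × number of GPUs. The hypothetical helper below only illustrates that identity. It is not the Trainer's actual resolution code, which also handles other "auto" fields.

```python
import copy
import json


def fill_batch_autos(ds_config, per_device_batch_size, grad_accum_steps, world_size):
    """Return a copy of ds_config with the batch-size "auto" fields filled in.

    Illustration only: the Hugging Face Trainer resolves these fields itself.
    """
    config = copy.deepcopy(ds_config)
    resolved = {
        "train_micro_batch_size_per_gpu": per_device_batch_size,
        "gradient_accumulation_steps": grad_accum_steps,
        # DeepSpeed's invariant: global batch = micro batch x accumulation x GPUs
        "train_batch_size": per_device_batch_size * grad_accum_steps * world_size,
    }
    for key, value in resolved.items():
        if config.get(key) == "auto":
            config[key] = value
    return config


example = {
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "zero_optimization": {"stage": 3},
}
resolved = fill_batch_autos(example, per_device_batch_size=2, grad_accum_steps=4, world_size=4)
print(json.dumps(resolved, indent=2))
```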

Define training parameters and LoRA configuration

Configure the supervised fine-tuning parameters:

  • Model: Llama 3.2 1B Instruct, a compact model suitable for A10 GPUs
  • Dataset: Capybara conversational dataset (trl-lib/Capybara) from the Hugging Face Hub
  • Batch size: 2 per device with 4 gradient accumulation steps, for an effective batch size of 32 across 4 GPUs (2 × 4 × 4)
  • Learning rate: 2e-4 with cosine scheduler and warmup
  • Training steps: 60 steps for demonstration (increase for full training)
  • LoRA parameters: Rank 16 with alpha 32, targeting attention and MLP projection layers

The configuration uses bfloat16 precision and gradient checkpointing to optimize memory usage.
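The learning-rate schedule combines linear warmup with cosine decay. The sketch below mirrors the formula used by the `get_cosine_schedule_with_warmup` helper in transformers with its default half cycle, written from scratch so it runs without dependencies. It evaluates the schedule with this notebook's settings: 2e-4 peak, 10 warmup steps, 60 total steps.

```python
import math


def cosine_with_warmup_lr(step, base_lr, warmup_steps, total_steps):
    """Learning rate at `step`: linear warmup to base_lr, then cosine decay to 0."""
    if step < warmup_steps:
        # Linear warmup from 0 up to base_lr.
        return base_lr * step / max(1, warmup_steps)
    # Fraction of the post-warmup phase completed, in [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


# Values from this notebook's training configuration.
for step in (0, 5, 10, 35, 60):
    lr = cosine_with_warmup_lr(step, base_lr=2e-4, warmup_steps=10, total_steps=60)
    print(f"step {step:>2}: lr = {lr:.2e}")
```

With max_steps=60 and 32 sequences per optimizer step, the demo run processes 60 × 32 = 1,920 training examples.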

Python
def create_training_config():
    """Create training configuration for TRL SFT with LoRA."""

    # Model and dataset configuration (not part of TrainingArguments)
    model_config = {
        "model_name": "meta-llama/Llama-3.2-1B-Instruct",  # Small Llama model for A10
        "dataset_name": "trl-lib/Capybara"
    }

    # Training arguments that will be passed directly to TrainingArguments
    training_args_config = {
        "output_dir": CHECKPOINT_DIR,
        "per_device_train_batch_size": 2,
        "per_device_eval_batch_size": 2,
        "gradient_accumulation_steps": 4,
        "learning_rate": 2e-4,
        "max_steps": 60,  # TO DO remove when fine-tuning on full dataset. Demo purposes only.
        # "num_train_epochs": 1,  # TO DO update to >= 1 when fine-tuning on full dataset
        "logging_steps": 10,
        "save_steps": 30,
        "eval_steps": 30,
        "eval_strategy": "steps",
        "warmup_steps": 10,
        "lr_scheduler_type": "cosine",
        "gradient_checkpointing": True,
        "fp16": False,
        "bf16": True,
        "optim": "adamw_torch",
        "remove_unused_columns": False,
        "run_name": "llama3.2-1b-lora",
        "report_to": "mlflow",
        "save_total_limit": 2,
        "load_best_model_at_end": True,
        "metric_for_best_model": "eval_loss",
        "greater_is_better": False,
    }

    # LoRA configuration
    lora_config = {
        "r": 16,
        "lora_alpha": 32,
        "target_modules": ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        "lora_dropout": 0.1,
        "bias": "none",
        "task_type": "CAUSAL_LM"
    }

    return model_config, training_args_config, lora_config

# Create training configuration
model_config, training_args_config, lora_config = create_training_config()

print("📊 Training Configuration:")
print(f" 🤖 Model: {model_config['model_name']}")
print(f" 📚 Dataset: {model_config['dataset_name']}")
print(f" 🎯 Batch size: {training_args_config['per_device_train_batch_size']}")
print(f" 📈 Learning rate: {training_args_config['learning_rate']}")
print(f" 🧠 LoRA rank: {lora_config['r']}")

Define the distributed training function

The @distributed decorator from the serverless_gpu library enables seamless execution of GPU workloads on Databricks Serverless GPU. The decorator provisions 4 A10 GPUs and handles distributed training setup automatically.

Key parameters:

  • gpus=4: Requests 4 GPUs for distributed training
  • gpu_type='A10': Specifies A10 GPU hardware
  • remote=True: Executes on remote serverless GPU compute
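With gpus=4, training runs as four data-parallel processes, and each one works on a different part of the data. The exact splitting is handled by the training stack. The sketch below shows one common scheme, round-robin sharding without shuffling, in the style of PyTorch's `DistributedSampler`. The function `shard_indices` is hypothetical.

```python
def shard_indices(num_examples, world_size, rank):
    """Round-robin split of dataset indices across data-parallel ranks.

    Pads by repeating indices so every rank gets the same number of examples,
    which keeps all ranks taking the same number of steps.
    """
    if num_examples == 0:
        return []
    per_rank = -(-num_examples // world_size)  # ceiling division
    total = per_rank * world_size
    indices = list(range(num_examples))
    padding = total - num_examples
    indices += (indices * (padding // num_examples + 1))[:padding]
    # Rank r takes indices r, r + world_size, r + 2 * world_size, ...
    return indices[rank:total:world_size]


for rank in range(4):
    print(rank, shard_indices(10, world_size=4, rank=rank))
```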

The training function:

  1. Loads the base model and tokenizer from HuggingFace
  2. Sets up chat formatting for conversational AI
  3. Configures LoRA for parameter-efficient fine-tuning
  4. Loads the training dataset
  5. Initializes the TRL SFTTrainer with DeepSpeed optimization
  6. Trains the model and saves checkpoints
  7. Returns training results and MLflow run ID
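Step 2 matters because SFTTrainer renders each conversation (a list of role/content messages) into a single string with the tokenizer's chat template. Llama 3.2 Instruct ships its own template, so the `setup_chat_format(..., format="chatml")` fallback is typically not triggered. When it is, the rendering uses ChatML-style markers. The hand-written sketch below illustrates that format. It is not TRL's implementation.

```python
def to_chatml(messages, add_generation_prompt=False):
    """Render a list of {"role", "content"} messages in ChatML format."""
    text = "".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    )
    if add_generation_prompt:
        # Open an assistant turn for the model to complete.
        text += "<|im_start|>assistant\n"
    return text


conversation = [
    {"role": "user", "content": "What is machine learning?"},
    {"role": "assistant", "content": "A field of AI that learns patterns from data."},
]
print(to_chatml(conversation))
```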

For more information, see the Serverless GPU API documentation.

Python
from serverless_gpu import distributed

os.environ['MLFLOW_EXPERIMENT_NAME'] = MLFLOW_EXPERIMENT_NAME
@distributed(
    gpus=4,
    gpu_type='A10',
    remote=True,  # Set to False to run locally, True for remote GPUs
)
def run_distributed_trl_sft():
    """
    Distributed TRL SFT training function using serverless GPU.

    This function runs on each of the 4 A10 GPUs with DeepSpeed ZeRO Stage 3 optimization.
    """

    # Import dependencies inside the function so they are available on the remote workers
    import os
    import tempfile
    import json
    import mlflow
    from huggingface_hub import constants
    from datasets import load_dataset
    from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
    from trl import SFTTrainer, setup_chat_format
    from peft import LoraConfig, get_peft_model

    # HuggingFace configuration
    os.environ["HUGGING_FACE_HUB_TOKEN"] = hf_token
    os.environ['HF_TOKEN'] = hf_token
    constants.HF_HUB_ENABLE_HF_TRANSFER = True

    # Set up a temporary directory for the DeepSpeed config file
    temp_dir = tempfile.mkdtemp()

    print("🚀 Starting TRL SFT training on A10 GPUs...")

    try:
        # Load tokenizer and model
        print(f"📥 Loading model: {model_config['model_name']}")
        tokenizer = AutoTokenizer.from_pretrained(model_config['model_name'])

        # Add pad token if not present
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            model_config['model_name'],
            dtype="auto"
        )

        # Set up a chat template if the model doesn't have one.
        # TRL SFTTrainer needs a chat template to format conversational datasets.
        if tokenizer.chat_template is None:
            print("🗨️ Setting up chat template...")
            model, tokenizer = setup_chat_format(model, tokenizer, format="chatml")

        # Configure LoRA
        print("🔧 Setting up LoRA configuration...")
        peft_config = LoraConfig(**lora_config)

        # Load dataset
        print(f"📚 Loading dataset: {model_config['dataset_name']}")
        dataset = load_dataset(model_config['dataset_name'])

        # Write the DeepSpeed config to a temporary JSON file
        deepspeed_config_path = os.path.join(temp_dir, "deepspeed_config.json")
        with open(deepspeed_config_path, "w") as f:
            json.dump(deepspeed_config, f, indent=2)

        # Training arguments: pass all config parameters plus the DeepSpeed config file path
        training_args = TrainingArguments(
            **training_args_config,
            deepspeed=deepspeed_config_path,
        )

        # Initialize SFT Trainer
        print("🏋️ Initializing SFT Trainer with DeepSpeed...")
        trainer = SFTTrainer(
            model=model,
            args=training_args,
            train_dataset=dataset["train"],
            eval_dataset=dataset["test"] if "test" in dataset else None,
            processing_class=tokenizer,
            peft_config=peft_config
        )

        # Start training
        print("🎯 Starting training...")
        trainer.train()

        # Save the model (the LoRA adapter) to output_dir
        print("💾 Saving trained model...")
        trainer.save_model()

        # Get training results
        train_results = trainer.state.log_history
        final_loss = train_results[-1].get('train_loss', 'N/A') if train_results else 'N/A'

        print("✅ Training completed successfully!")
        print(f"📊 Final training loss: {final_loss}")

        mlflow_run_id = None
        if mlflow.last_active_run() is not None:
            mlflow_run_id = mlflow.last_active_run().info.run_id

        return {
            "status": "success",
            "final_loss": final_loss,
            "output_dir": training_args_config['output_dir'],
            "model_name": model_config['model_name'],
            "mlflow_run_id": mlflow_run_id,
        }

    except Exception as e:
        print(f"❌ Training failed: {e}")
        import traceback
        traceback.print_exc()
        return {
            "status": "failed",
            "error": str(e)
        }

Run the distributed training job

Execute the training function by calling .distributed() on the decorated function. This provisions the serverless GPU resources, runs the training across 4 A10 GPUs with DeepSpeed optimization, and returns the results.

The training process:

  • Provisions 4 A10 GPUs automatically
  • Downloads the model and dataset from HuggingFace
  • Trains the model with LoRA fine-tuning
  • Saves checkpoints to the Unity Catalog volume
  • Logs metrics to MLflow
  • Returns training status, final loss, and MLflow run ID
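The cell below indexes the return value as a list (`results[0]`). Assuming `.distributed()` returns a list of the dicts produced by the training function, possibly one per process, a small helper can surface every failure instead of checking only the first entry. The helper `summarize_results` and the sample values are hypothetical.

```python
def summarize_results(results):
    """Return (first successful result dict or None, list of error messages)."""
    results = [r for r in (results or []) if r]  # drop None/empty entries
    successes = [r for r in results if r.get("status") == "success"]
    errors = [r.get("error", "unknown error") for r in results if r.get("status") != "success"]
    return (successes[0] if successes else None), errors


# Made-up sample shaped like the training function's return values.
sample = [
    {"status": "success", "final_loss": "N/A", "mlflow_run_id": "hypothetical-run-id"},
    {"status": "failed", "error": "example error message"},
]
first_success, errors = summarize_results(sample)
print(first_success)
print(errors)
```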
Python
# Execute the distributed training
results = run_distributed_trl_sft.distributed()

print("🏁 Training execution completed!")
print(f"📊 Results: {results}")

if results and results[0].get('status') == 'success':
    print("✅ Training completed successfully!")
    print(f"💾 Model saved to: {results[0].get('output_dir', 'N/A')}")
    print(f"📈 Final loss: {results[0].get('final_loss', 'N/A')}")
    print(f"🎉 MLflow run ID: {results[0].get('mlflow_run_id', 'N/A')}")
else:
    print("❌ Training failed!")
    if results and 'error' in results[0]:
        print(f"🔍 Error: {results[0]['error']}")

Save the fine-tuned model and test inference

This optional step loads the trained LoRA adapter, merges it with the base model, and saves the complete fine-tuned model. The merged model can then be tested with sample prompts to verify the fine-tuning results.

The process:

  1. Loads the base Llama 3.2 1B model
  2. Applies the trained LoRA adapter weights
  3. Merges the adapter into the base model
  4. Saves the merged model to the Unity Catalog volume
  5. Tests the model with a sample conversational prompt
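Merging works because the LoRA path is linear. Folding the scaled update into the base weight, W' = W0 + (alpha / r) · B · A, gives the same output as running the base layer and the adapter separately, minus the extra matrix multiplies at inference time. The pure-Python toy below checks that equivalence on small hypothetical matrices. PEFT's `merge_and_unload()` performs this fold on the real layers. LoRA dropout applies only during training, so it plays no role here.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(x * y for x, y in zip(row, v)) for row in m]


def merge_lora(w0, a, b, alpha, r):
    """Return the merged weight W0 + (alpha / r) * B @ A."""
    scale = alpha / r
    delta = matmul(b, a)
    return [[w + scale * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(w0, delta)]


# Toy values (hypothetical): W0 is 3x4, rank r = 2, so A is 2x4 and B is 3x2.
w0 = [[1.0, 0.0, 2.0, -1.0], [0.5, 1.5, 0.0, 1.0], [0.0, -2.0, 1.0, 0.5]]
a = [[0.1, 0.2, 0.0, -0.1], [0.0, 0.3, 0.1, 0.2]]
b = [[1.0, 0.0], [0.5, -1.0], [0.0, 2.0]]
alpha, r = 4, 2
x = [1.0, 2.0, -1.0, 0.5]

# Unmerged: base output plus the scaled adapter path B(Ax).
unmerged = [h + (alpha / r) * d for h, d in zip(matvec(w0, x), matvec(b, matvec(a, x)))]
# Merged: a single matrix-vector product with the folded weight.
merged = matvec(merge_lora(w0, a, b, alpha, r), x)
print(unmerged)
print(merged)
```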
Python
%pip install hf_transfer
Python
def save_and_load_trained_model():
    """Load the base model, apply the trained LoRA adapter, merge it, and save the merged model to the Unity Catalog volume."""

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    # Load base model and tokenizer
    base_model = AutoModelForCausalLM.from_pretrained(
        model_config['model_name'],
        dtype=torch.bfloat16,
        token=hf_token,
        trust_remote_code=True,
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_config['model_name'], token=hf_token, trust_remote_code=True)

    # Load the LoRA adapter weights saved by trainer.save_model()
    model = PeftModel.from_pretrained(base_model, training_args_config['output_dir'])

    # Merge LoRA weights into base model
    model = model.merge_and_unload()

    # Save the merged model
    model.save_pretrained(FINE_TUNED_MODEL_PATH)
    tokenizer.save_pretrained(FINE_TUNED_MODEL_PATH)

    # Return the merged model and tokenizer
    return model, tokenizer

def test_trained_model(model, tokenizer):
    """Test the trained model with simple inference."""

    try:
        import torch

        # Test prompt in the same role/content message format used for training
        conversation = [
            {
                "content": "What is machine learning?",
                "role": "user"
            }
        ]

        # Format the conversation with the tokenizer's chat template (the same
        # template SFTTrainer applied during training) and tokenize it
        inputs = tokenizer.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        ).to(model.device)

        # Generate
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=500,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens, not the prompt
        prompt_length = inputs["input_ids"].shape[-1]
        response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        print("🤖 Model Response:")
        print(response)
        return response

    except Exception as e:
        print(f"❌ Model testing failed: {e}")

# Save and load the trained model
model, tokenizer = save_and_load_trained_model()

# Test the trained model
response = test_trained_model(model, tokenizer)

Register the model in Unity Catalog

Log the fine-tuned model to MLflow and register it in Unity Catalog for deployment and serving. The model is logged with:

  • Model and tokenizer: Both components needed for inference
  • Task type: Configured as llm/v1/chat for conversational AI
  • Input example: Sample chat message format for testing
  • Unity Catalog registration: Automatically registers the model in the configured catalog and schema

Once registered, the model can be deployed to model serving endpoints or used for batch inference.
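Requests to a model logged with the llm/v1/chat task use the same `{"messages": [...]}` shape as the input example above. Before sending traffic, a client-side check of that shape can catch malformed payloads early. The sketch below is a hypothetical helper. The allowed role set (system, user, assistant) is an assumption, not a documented MLflow constraint.

```python
ALLOWED_ROLES = {"system", "user", "assistant"}  # assumption for this sketch


def validate_chat_request(payload):
    """Raise ValueError unless payload looks like {"messages": [{"role", "content"}, ...]}."""
    messages = payload.get("messages") if isinstance(payload, dict) else None
    if not isinstance(messages, list) or not messages:
        raise ValueError('payload must contain a non-empty "messages" list')
    for i, message in enumerate(messages):
        if not isinstance(message, dict):
            raise ValueError(f"message {i} must be a dict")
        if message.get("role") not in ALLOWED_ROLES:
            raise ValueError(f"message {i} has unsupported role: {message.get('role')!r}")
        if not isinstance(message.get("content"), str):
            raise ValueError(f"message {i} content must be a string")
    return payload


# Same shape as the input_example logged with the model.
request = {"messages": [{"role": "user", "content": "What is machine learning?"}]}
validate_chat_request(request)
print("request is well-formed")
```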

Python
run_id = results[0].get('mlflow_run_id')
mlflow.set_registry_uri("databricks-uc")

# Log the model to MLflow using the latest run ID and register it to Unity Catalog
with mlflow.start_run(run_id=run_id) as run:
    components = {
        "model": model,
        "tokenizer": tokenizer
    }
    logged_model = mlflow.transformers.log_model(
        transformers_model=components,
        name="model",
        task="llm/v1/chat",
        input_example={
            "messages": [
                {"role": "user", "content": "What is machine learning?"}
            ]
        },
        registered_model_name=REGISTERED_MODEL_NAME
    )
    print(f"🔍 Model logged to: {logged_model}")

Next steps
