
Distributed finetune Llama-3.2-3B with Unsloth on multiple A10 GPUs

This notebook demonstrates how to fine-tune the Llama-3.2-3B large language model using distributed training across multiple A10 GPUs. It combines the Unsloth library for optimized parameter-efficient fine-tuning with the serverless_gpu library for distributed training orchestration.

The notebook covers:

  • Configuring distributed training across 4 A10 GPUs
  • Loading and fine-tuning Llama-3.2-3B with LoRA adapters
  • Processing training data from the FineTome-100k dataset
  • Training with supervised fine-tuning (SFT) and MLflow tracking
  • Merging adapters and registering the model in Unity Catalog

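Distributed training shortens wall-clock training time by splitting each training batch across multiple GPUs. Each GPU trains a copy of the model on its share of the data, and gradients are synchronized across GPUs at every step, so the model still learns from the combined batch.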
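
As a rough illustration of how data parallelism divides the work, the sketch below splits a dataset's indices across ranks round-robin, in the style of PyTorch's DistributedSampler with shuffling disabled. shard_indices is a hypothetical helper, and the sampling that Trainer and Accelerate actually perform may split batches differently.

```python
import math


def shard_indices(num_examples: int, world_size: int, rank: int) -> list[int]:
    """Return the example indices one rank would see in a round-robin split.

    Modeled on the non-shuffled behavior of torch's DistributedSampler:
    the index list is padded by repeating indices from the start so every
    rank gets the same number of examples, then every world_size-th index
    is assigned to the rank.
    """
    if num_examples <= 0 or not 0 <= rank < world_size:
        raise ValueError("need num_examples > 0 and 0 <= rank < world_size")
    indices = list(range(num_examples))
    per_rank = math.ceil(num_examples / world_size)
    total = per_rank * world_size
    padding = total - num_examples
    # Repeat indices from the start to fill the padding (works even if padding > num_examples)
    indices += (indices * math.ceil(padding / num_examples))[:padding]
    # Rank r takes positions r, r + world_size, r + 2 * world_size, ...
    return indices[rank:total:world_size]


for r in range(4):
    print(r, shard_indices(10, world_size=4, rank=r))
```

With 10 examples on 4 GPUs, each rank receives 3 indices, and two indices are repeated so the ranks stay in lockstep.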

Requirements: Serverless GPU compute with A10 accelerators

This notebook requires GPU compute with A10 accelerators. Select A10 as the accelerator in the environment panel and click Apply.

Note: Compute provisioning can take up to 8 minutes. The training function automatically provisions 4 A10 GPUs when it runs.

Install required libraries

Install the Unsloth library with CUDA 12.4 and PyTorch 2.6.0 support, along with accelerate for distributed training, unsloth_zoo for additional utilities, and MLflow for experiment tracking. The Python runtime restarts after installation to load the new packages.

Python
%pip install "unsloth[cu124-torch260]==2025.9.6"
%pip install accelerate==1.7.0
%pip install unsloth_zoo==2025.9.8
%pip install "mlflow>=3.6"
%restart_python

Configure Unity Catalog and model settings

Define the Unity Catalog locations with notebook widgets so they are easy to customize, then set the model configuration. The configuration includes:

  • Unity Catalog namespace (catalog, schema, model name, volume)
  • Base model selection (Llama-3.2-3B-Instruct from Unsloth)
  • Output directory for saving checkpoints to Unity Catalog volumes
  • Training dataset (FineTome-100k)

Python
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("uc_model_name", "llama-3_2-3b")
dbutils.widgets.text("uc_volume", "checkpoints")

UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_MODEL_NAME = dbutils.widgets.get("uc_model_name")
UC_VOLUME = dbutils.widgets.get("uc_volume")

print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")
print(f"UC_VOLUME: {UC_VOLUME}")

# Model selection - Choose based on your compute constraints
MODEL_NAME = "unsloth/Llama-3.2-3B-Instruct" # or choose "unsloth/Llama-3.2-1B-Instruct"
OUTPUT_DIR = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/{UC_MODEL_NAME}" # Save checkpoint to UC Volume
DATASET_NAME = "mlabonne/FineTome-100k"

print(f"MODEL_NAME: {MODEL_NAME}")
print(f"OUTPUT_DIR: {OUTPUT_DIR}")
print(f"DATASET_NAME: {DATASET_NAME}")
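
The training function writes checkpoints to OUTPUT_DIR, so the Unity Catalog volume must already exist. If it does not, one option is to create it with SQL before training. The helpers below (quote_identifier and create_volume_sql are hypothetical names) only build the statement. On Databricks you would pass the result to spark.sql(...), which assumes you have permission to create volumes in the schema.

```python
def quote_identifier(name: str) -> str:
    """Backtick-quote a Unity Catalog identifier, doubling any embedded backticks."""
    return "`" + name.replace("`", "``") + "`"


def create_volume_sql(catalog: str, schema: str, volume: str) -> str:
    """Build a CREATE VOLUME IF NOT EXISTS statement for a managed volume."""
    full_name = ".".join(quote_identifier(part) for part in (catalog, schema, volume))
    return f"CREATE VOLUME IF NOT EXISTS {full_name}"


print(create_volume_sql("main", "default", "checkpoints"))

# On Databricks, assuming the widget values defined above:
# spark.sql(create_volume_sql(UC_CATALOG, UC_SCHEMA, UC_VOLUME))
```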

Define the distributed training function

Create a training function decorated with @distributed(gpus=4, gpu_type='a10', remote=True) to enable multi-GPU training. This function encapsulates the entire training workflow:

  • Loading the base Llama-3.2-3B model and tokenizer
  • Applying LoRA adapters for parameter-efficient fine-tuning
  • Processing the FineTome-100k dataset with chat templates
  • Configuring the SFTTrainer with distributed training settings
  • Training the model with MLflow tracking
  • Saving the trained adapters and tokenizer to Unity Catalog volumes

The @distributed decorator automatically handles GPU provisioning and distributed training orchestration across 4 A10 GPUs.
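
To make the data-processing step more concrete, here is a simplified, pure-Python approximation of what standardize_sharegpt and the Llama 3.1 chat template do to one FineTome-style record. The helper names are hypothetical. The real template, applied through tokenizer.apply_chat_template, may also add a begin-of-text token and a default system header that this sketch omits. The header strings match the instruction_part and response_part values that the training function passes to train_on_responses_only.

```python
# ShareGPT "from" values mapped to chat-template roles
ROLE_MAP = {"system": "system", "human": "user", "gpt": "assistant"}


def standardize(conversation: list[dict]) -> list[dict]:
    """Convert ShareGPT {"from", "value"} turns to {"role", "content"} messages."""
    return [{"role": ROLE_MAP[turn["from"]], "content": turn["value"]} for turn in conversation]


def render_llama3(messages: list[dict]) -> str:
    """Render messages with Llama 3 header tokens (simplified: no BOS token, no default system prompt)."""
    return "".join(
        f"<|start_header_id|>{msg['role']}<|end_header_id|>\n\n{msg['content']}<|eot_id|>"
        for msg in messages
    )


record = {"conversations": [
    {"from": "human", "value": "What is 2 + 2?"},
    {"from": "gpt", "value": "2 + 2 = 4."},
]}
text = render_llama3(standardize(record["conversations"]))
print(text)
```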

Python
from serverless_gpu import distributed
from serverless_gpu import runtime as rt

@distributed(gpus=4, gpu_type='a10', remote=True)
def run_train():
    from datasets import load_dataset
    import logging
    import mlflow
    import torch

    # IMPORTANT: import unsloth BEFORE trl so Unsloth's patches are applied
    from unsloth import FastLanguageModel, is_bfloat16_supported
    from unsloth.chat_templates import get_chat_template, standardize_sharegpt, train_on_responses_only

    from trl import SFTTrainer
    from transformers import TrainingArguments, DataCollatorForSeq2Seq
    from transformers.integrations import MLflowCallback

    max_seq_length = 2048  # Maximum sequence length in tokens
    dtype = None  # None for auto detection; float16 for Tesla T4 and V100, bfloat16 for Ampere and newer (including A10)
    load_in_4bit = False  # Set to True to load 4-bit quantized weights and reduce memory usage

    # Load model and tokenizer
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=max_seq_length,
        dtype=dtype,
        load_in_4bit=load_in_4bit,
        # token="hf_...",  # Required for gated models such as meta-llama/Llama-2-7b-hf
    )

    # Attach LoRA adapters
    model = FastLanguageModel.get_peft_model(
        model,
        r=16,  # LoRA rank: any value > 0; common choices are 8, 16, 32, 64, 128
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        lora_alpha=16,
        lora_dropout=0,  # Any value is supported, but 0 is optimized
        bias="none",  # Any value is supported, but "none" is optimized
        # "unsloth" gradient checkpointing uses 30% less VRAM and fits 2x larger batch sizes
        use_gradient_checkpointing="unsloth",  # True or "unsloth" for very long context
        random_state=3407,
        use_rslora=False,  # Set to True for rank-stabilized LoRA
        loftq_config=None,  # Optional LoftQ configuration
    )

    # Process data: apply the Llama 3.1 chat template
    tokenizer = get_chat_template(
        tokenizer,
        chat_template="llama-3.1",
    )

    def formatting_prompts_func(examples):
        convos = examples["conversations"]
        texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
        return {"text": texts}

    dataset = load_dataset(DATASET_NAME, split="train")

    # Convert ShareGPT-style "from"/"value" turns to "role"/"content" messages
    dataset = standardize_sharegpt(dataset)
    dataset = dataset.map(formatting_prompts_func, batched=True)

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        dataset_text_field="text",
        max_seq_length=max_seq_length,
        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
        dataset_num_proc=6,
        packing=False,  # Packing can make training 5x faster for short sequences
        args=TrainingArguments(
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            warmup_steps=5,
            # num_train_epochs=1,  # Set this instead of max_steps for one full training run
            max_steps=25,
            learning_rate=2e-4,
            fp16=not is_bfloat16_supported(),
            bf16=is_bfloat16_supported(),
            logging_steps=1,
            optim="adamw_8bit",
            weight_decay=0.01,
            lr_scheduler_type="linear",
            seed=3407,
            output_dir=OUTPUT_DIR,
            report_to="mlflow",  # Use MLflow to track model metrics
            run_name=f"{MODEL_NAME}-finetune-unsloth",
        ),
    )

    # Compute the loss only on assistant responses, not on user turns
    trainer = train_on_responses_only(
        trainer,
        instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
        response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
        num_proc=1,
    )

    trainer.train()

    # Save model from rank 0 only, so the processes do not write the same files concurrently
    if rt.get_global_rank() == 0:
        logging.info("\nSaving trained model...")
        trainer.save_model(OUTPUT_DIR)
        logging.info("✓ LoRA adapters saved - use with base model for inference")
        tokenizer.save_pretrained(OUTPUT_DIR)
        logging.info("✓ Tokenizer saved with model")
        logging.info(f"\n🎉 All artifacts saved to: {OUTPUT_DIR}")

    mlflow_run_id = None
    if mlflow.last_active_run() is not None:
        mlflow_run_id = mlflow.last_active_run().info.run_id

    return mlflow_run_id
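
In the training function, train_on_responses_only restricts the loss to assistant replies. Tokens outside a response get the label -100, which Hugging Face models (through PyTorch's cross-entropy ignore_index) skip when computing the loss. The sketch below approximates that idea with single toy token IDs standing in for the multi-token header strings. mask_non_responses is a hypothetical helper, not Unsloth's implementation.

```python
IGNORE_INDEX = -100  # Label value ignored by torch.nn.CrossEntropyLoss by default


def mask_non_responses(token_ids: list[int], instruction_id: int, response_id: int) -> list[int]:
    """Keep labels only for tokens between a response marker and the next
    instruction marker; every other position (including the markers) is ignored."""
    labels = []
    in_response = False
    for tok in token_ids:
        if tok == response_id:
            in_response = True
            labels.append(IGNORE_INDEX)  # The marker itself is not trained on
        elif tok == instruction_id:
            in_response = False
            labels.append(IGNORE_INDEX)
        else:
            labels.append(tok if in_response else IGNORE_INDEX)
    return labels


# Toy IDs: 1 = user header, 2 = assistant header, everything else = content tokens
ids = [1, 10, 11, 2, 20, 21, 1, 12, 2, 22]
print(mask_non_responses(ids, instruction_id=1, response_id=2))
```

Only the assistant tokens 20, 21, and 22 keep their labels. All other positions become -100.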

Execute the distributed training

Launch the distributed training function across 4 A10 GPUs. The .distributed() method provisions the GPUs, distributes the training workload, and returns the values returned by run_train as a list; indexing with [0] selects the MLflow run ID for tracking. This step may take several minutes because it provisions compute resources before running the training loop.

Python
run_id = run_train.distributed()[0]
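
How much data does this run see? In data-parallel training, the number of examples per optimizer step is the per-device batch size times the gradient accumulation steps times the number of GPUs. The snippet below applies that to the settings in run_train (batch size 2, 4 accumulation steps, 25 steps), assuming one training process per GPU on the 4 GPUs. The dataset size of 100,000 is an assumption based on the dataset's name.

```python
def effective_batch_size(per_device_batch: int, grad_accum_steps: int, num_gpus: int) -> int:
    """Examples contributing to one optimizer step in data-parallel training."""
    return per_device_batch * grad_accum_steps * num_gpus


batch = effective_batch_size(per_device_batch=2, grad_accum_steps=4, num_gpus=4)
max_steps = 25
examples_seen = batch * max_steps
dataset_size = 100_000  # Assumed from the name FineTome-100k
print(f"Effective batch size: {batch}")
print(f"Examples seen in {max_steps} steps: {examples_seen}")
print(f"Fraction of dataset: {examples_seen / dataset_size:.1%}")
```

Under these assumptions, max_steps=25 covers well under 1% of the data, so it works as a quick smoke test. For a full pass, the commented-out num_train_epochs=1 setting in the training arguments is the intended alternative.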

Model registration and deployment strategy

After distributed training completes, register the fine-tuned model for production use:

  • MLflow Tracking - Log model artifacts, training metrics, and metadata for experiment tracking
  • Unity Catalog - Register the model for centralized governance, access control, and lineage tracking
  • Model Versioning - Automatic versioning enables model lifecycle management and rollback capabilities
  • Metadata - Recording the base model, task, and model family alongside the weights supports reproducibility and compliance

Merge adapters and register in Unity Catalog

Load the trained LoRA adapters, merge them with the base Llama-3.2-3B model weights, and register the final model in Unity Catalog. This process:

  • Loads the base model and trained LoRA adapters from the checkpoint directory
  • Merges the adapter weights into the base model to create a single deployable model
  • Logs the merged model to MLflow with appropriate metadata
  • Registers the model in Unity Catalog for governance and deployment

The registered model is ready for deployment to model serving endpoints.
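
Merging works because a LoRA adapter adds a low-rank update to each targeted weight matrix. The merged weight is W + (lora_alpha / r) * B @ A, where A and B are the adapter matrices. With use_rslora=False, as in this notebook, the scale is lora_alpha / r, which is 1 for r = 16 and lora_alpha = 16. The pure-Python sketch below uses tiny, made-up matrices to check that the merged layer gives the same output as the base layer plus the adapter. PEFT's merge_and_unload performs the equivalent update on real tensors.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows (a is m x k, b is k x n)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def add(a, b, scale=1.0):
    """Return a + scale * b element-wise."""
    return [[x + scale * y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]


def merge_lora(w, lora_a, lora_b, lora_alpha, r):
    """Fold a LoRA update into the base weight: W + (lora_alpha / r) * B @ A."""
    return add(w, matmul(lora_b, lora_a), scale=lora_alpha / r)


# Made-up example: 2x2 base weight and a rank-1 adapter (A is 1x2, B is 2x1)
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[0.5, -1.0]]
B = [[2.0], [1.0]]
x = [[3.0], [4.0]]  # Input column vector

alpha, rank = 16, 16  # Values from the notebook, so the scale is 1.0
# Unmerged: base output plus scaled adapter output
separate = add(matmul(W, x), matmul(B, matmul(A, x)), scale=alpha / rank)
# Merged: a single weight matrix applied once
merged = matmul(merge_lora(W, A, B, alpha, rank), x)
print(separate, merged)
```

Both paths produce the same output, which is why the merged model can be served without the PEFT wrappers.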

Python
print("\nRegistering model with MLflow and Unity Catalog...")

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import mlflow

# Load the trained model for registration
print("Loading LoRA model for registration...")
# For LoRA models, we need both base model and adapter
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
)
adapter_dir = OUTPUT_DIR
# Load the tokenizer saved with the adapters so its chat template matches the one used in training
tokenizer = AutoTokenizer.from_pretrained(adapter_dir)
peft_model = PeftModel.from_pretrained(base_model, adapter_dir)
# Merge LoRA into base and drop PEFT wrappers
merged_model = peft_model.merge_and_unload()

components = {
    "model": merged_model,
    "tokenizer": tokenizer,
}

# Create Unity Catalog model name
full_model_name = f"{UC_CATALOG}.{UC_SCHEMA}.{UC_MODEL_NAME}"

print(f"Registering model as: {full_model_name}")

# Log the merged model to the training run (MLflow starts a new run if run_id is None)
task = "llm/v1/chat"
with mlflow.start_run(run_id=run_id):
    model_info = mlflow.transformers.log_model(
        transformers_model=components,
        name="model",
        task=task,
        registered_model_name=full_model_name,
        metadata={
            "task": task,
            "pretrained_model_name": MODEL_NAME,
            "databricks_model_family": "Llama3.2",
        },
    )

print(f"✓ Model successfully registered in Unity Catalog: {full_model_name}")
print(f"✓ MLflow model URI: {model_info.model_uri}")
print(f"✓ Model version: {model_info.registered_model_version}")

# Print deployment information
print("\n📦 Model Registration Complete!")
print(f"Unity Catalog Path: {full_model_name}")
print("Optimization: Unsloth + LoRA")
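
Because the model is logged with task="llm/v1/chat", a serving endpoint created from it is expected to accept chat-style requests. The sketch below only builds such a request body as JSON. The field names (messages, max_tokens, temperature) follow the OpenAI-style chat schema that MLflow's llm/v1/chat task is based on, but check the serving documentation for the exact schema your endpoint accepts. build_chat_request is a hypothetical helper.

```python
import json


def build_chat_request(user_prompt, system_prompt=None, max_tokens=256, temperature=0.7):
    """Build an llm/v1/chat-style request body as a JSON string."""
    messages = []
    if system_prompt is not None:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": user_prompt})
    body = {"messages": messages, "max_tokens": max_tokens, "temperature": temperature}
    return json.dumps(body)


payload = build_chat_request("Summarize LoRA in one sentence.")
print(payload)
```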

Next steps

The fine-tuned model is now registered in Unity Catalog and ready for deployment. Learn more about distributed training and model serving:
