Fine-tune Olmo3 7B with Axolotl on multi-GPU serverless compute
This notebook demonstrates how to fine-tune the Olmo3 7B Instruct model using Axolotl on Databricks serverless GPU compute. Axolotl provides a high-performance framework for LLM post-training with QLoRA (Quantized Low-Rank Adaptation), enabling efficient fine-tuning on multi-GPU infrastructure. The trained model is logged to MLflow and registered in Unity Catalog for deployment.
Install required dependencies
Installs Axolotl with Flash Attention support, MLflow for experiment tracking, and compatible versions of transformers and optimization libraries. The cut-cross-entropy package provides memory-efficient loss computation for large language models.
%pip install -U packaging setuptools wheel ninja
%pip install "mlflow>=3.6"
%pip install --no-build-isolation "axolotl[flash-attn]>=0.12.0"
%pip install transformers==4.57.3
%pip uninstall -y awq autoawq
%pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f4b5712"
dbutils.library.restartPython()
Retrieve HuggingFace token
Retrieves the HuggingFace authentication token from Databricks secrets. This token is required to download the Olmo3 7B base model from the HuggingFace Hub.
HF_TOKEN = dbutils.secrets.get(scope="sgc-nightly-notebook", key="hf_token")
Configure training parameters
Sets up the Axolotl training configuration based on the olmo3-7b-qlora.yaml example. Key modifications include:
- MLflow integration for experiment tracking
- Unity Catalog volume path for checkpoint storage
- SDPA (Scaled Dot Product Attention) instead of Flash Attention for broader GPU compatibility
Define Unity Catalog paths
Creates widgets to specify the Unity Catalog location for storing model checkpoints. The output directory combines the catalog, schema, volume, and model name into a fully qualified path.
dbutils.widgets.text("catalog", "main")
dbutils.widgets.text("schema", "default")
dbutils.widgets.text("volume", "checkpoints")
dbutils.widgets.text("model", "openai/gpt-oss-20b")
UC_CATALOG = dbutils.widgets.get("catalog")
UC_SCHEMA = dbutils.widgets.get("schema")
UC_VOLUME = dbutils.widgets.get("volume")
UC_MODEL_NAME = dbutils.widgets.get("model")
print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_VOLUME: {UC_VOLUME}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")
OUTPUT_DIR = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/{UC_MODEL_NAME}"
print(f"OUTPUT_DIR: {OUTPUT_DIR}")
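The same path composition can be expressed as a small helper to make the volume layout explicit (`uc_volume_path` is a hypothetical name used only for illustration):

```python
# Hypothetical helper mirroring the OUTPUT_DIR construction above: a Unity
# Catalog volume path always has the fixed prefix /Volumes/<catalog>/<schema>/<volume>/.
def uc_volume_path(catalog: str, schema: str, volume: str, *parts: str) -> str:
    """Build a fully qualified Unity Catalog volume path."""
    return "/".join(["/Volumes", catalog, schema, volume, *parts])

print(uc_volume_path("main", "default", "checkpoints", "olmo3-7b-instruct"))
# /Volumes/main/default/checkpoints/olmo3-7b-instruct
```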
Disable telemetry
Disables Axolotl's usage tracking by setting the AXOLOTL_DO_NOT_TRACK environment variable.
import os
os.environ['AXOLOTL_DO_NOT_TRACK'] = '1'
Create Axolotl configuration
Defines the complete training configuration using Axolotl's DictDefault format. This includes model settings (QLoRA with 4-bit quantization), dataset configuration (Alpaca format), LoRA hyperparameters (rank 32, alpha 16), training parameters (1 epoch, batch size 2, gradient accumulation 4), and MLflow integration for experiment tracking.
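As a rough sanity check on what QLoRA actually trains, the LoRA parameter count per adapted projection can be estimated with a short sketch (the 4096 dimensions below are illustrative, not Olmo3's actual layer shapes):

```python
# Back-of-the-envelope count of trainable LoRA parameters: each adapted
# weight W (d_out x d_in) gains two low-rank factors A (r x d_in) and
# B (d_out x r), contributing r * (d_in + d_out) trainable parameters,
# while the 4-bit quantized base weights stay frozen.
def lora_params(d_in: int, d_out: int, r: int) -> int:
    return r * (d_in + d_out)

# Illustrative shape only: a square 4096 x 4096 projection with the
# rank-32 setting from the config below.
print(lora_params(4096, 4096, r=32))  # 262144 trainable parameters
```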
from axolotl.cli.config import load_cfg
from axolotl.utils.dict import DictDefault
# Config is based on the following example, with some changes to fit the available GPU type:
# https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/olmo3/olmo3-7b-qlora.yaml
# Axolotl provides full control and transparency over model and training configuration
config = DictDefault(
    base_model="allenai/Olmo-3-7B-Instruct-SFT",
    plugins=[
        "axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin"
    ],
    load_in_8bit=False,
    load_in_4bit=True,
    datasets=[
        {
            "path": "fozziethebeat/alpaca_messages_2k_test",
            "type": "chat_template"
        }
    ],
    dataset_prepared_path="last_run_prepared",
    val_set_size=0.1,
    output_dir=OUTPUT_DIR,
    adapter="qlora",
    lora_model_dir=None,
    sequence_len=2048,
    sample_packing=True,
    lora_r=32,
    lora_alpha=16,
    lora_dropout=0.05,
    lora_target_linear=True,
    lora_target_modules=[
        "gate_proj",
        "down_proj",
        "up_proj",
        "q_proj",
        "v_proj",
        "k_proj",
        "o_proj"
    ],
    wandb_project=None,
    wandb_entity=None,
    wandb_watch=None,
    wandb_name=None,
    wandb_log_model=None,
    gradient_accumulation_steps=4,
    micro_batch_size=2,
    num_epochs=1,
    optimizer="adamw_bnb_8bit",
    lr_scheduler="cosine",
    learning_rate=0.0002,
    bf16="auto",
    tf32=False,
    gradient_checkpointing=True,
    resume_from_checkpoint=None,
    logging_steps=1,
    flash_attention=False,
    warmup_ratio=0.1,
    evals_per_epoch=1,
    saves_per_epoch=1,
    # Eval dataset is too small to pack
    eval_sample_packing=False,
    # Write metrics to MLflow
    use_mlflow=True,
    mlflow_tracking_uri="databricks",
    mlflow_run_name="olmo3-7b-qlora-axolotl",
    hf_mlflow_log_artifacts=False,
    wandb_mode="disabled",
    attn_implementation="sdpa",
    sdpa_attention=True,
    save_first_step=True,
    device_map=None,
)
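With the settings above, the effective global batch size follows directly from the micro-batch size, gradient accumulation, and GPU count (8 GPUs matches the distributed launch used later in this notebook):

```python
# Effective batch size implied by the training configuration: each optimizer
# step consumes micro_batch_size sequences per GPU, accumulated over
# gradient_accumulation_steps, across all GPUs.
micro_batch_size = 2
gradient_accumulation_steps = 4
num_gpus = 8
sequence_len = 2048

global_batch_sequences = micro_batch_size * gradient_accumulation_steps * num_gpus
# With sample packing, each sequence slot is filled close to sequence_len,
# so this is the approximate upper bound on tokens per optimizer step.
max_tokens_per_step = global_batch_sequences * sequence_len

print(global_batch_sequences, max_tokens_per_step)  # 64 131072
```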
Configure PyTorch CUDA memory allocation
Optimizes GPU memory management for efficient training on multi-GPU setups.
from axolotl.utils import set_pytorch_cuda_alloc_conf
set_pytorch_cuda_alloc_conf()
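`set_pytorch_cuda_alloc_conf` is Axolotl's helper; a manual sketch of the same idea sets PyTorch's allocator options via an environment variable. The exact value Axolotl uses is an implementation detail; `expandable_segments` is shown here as one documented PyTorch allocator option that reduces fragmentation:

```python
import os

# Manual equivalent of tuning the CUDA caching allocator: PyTorch reads
# PYTORCH_CUDA_ALLOC_CONF at CUDA initialization, so it must be set before
# the first CUDA call. setdefault avoids clobbering an existing setting.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
print(os.environ["PYTORCH_CUDA_ALLOC_CONF"])
```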
Run distributed training on serverless GPU compute
Uses the @distributed decorator from the serverless GPU API to distribute the Axolotl training job across 8 H100 GPUs. The decorator handles multi-GPU orchestration, allowing the training function to run in a distributed environment without manual cluster setup.
from serverless_gpu.launcher import distributed
from serverless_gpu.compute import GPUType
@distributed(gpus=8, gpu_type=GPUType.H100, remote=False)
def run_train(cfg: DictDefault):
    import os
    os.environ['HF_TOKEN'] = HF_TOKEN

    from axolotl.common.datasets import load_datasets

    # Validate the configuration, then load, parse, and tokenize the datasets
    # with the model's chat template, dropping samples that overflow the
    # maximum sequence length.
    cfg = load_cfg(cfg)
    dataset_meta = load_datasets(cfg=cfg)

    from axolotl.train import train

    # Train only the first 16 steps for the demo. This is sufficient to align
    # the model because sample packing maximizes the trainable samples per step.
    cfg.max_steps = 16
    model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)

    import mlflow
    mlflow_run_id = None
    if mlflow.last_active_run() is not None:
        mlflow_run_id = mlflow.last_active_run().info.run_id
    return mlflow_run_id
Execute the training job
Launches the distributed training job. The function loads the dataset, validates the configuration, trains the model for 16 steps, and returns the MLflow run ID for tracking.
result = run_train.distributed(config)
Extract MLflow run ID
Retrieves the MLflow run ID from the training results for model registration and experiment tracking.
run_id = result[0]
print(run_id)
Register the fine-tuned model to Unity Catalog
Loads the trained LoRA adapter, merges it with the base model, and registers the combined model to Unity Catalog via MLflow. This makes the model available for deployment and inference.
Note: This step requires H100 GPU compute to load the model checkpoint. Running on smaller GPUs may result in CUDA out-of-memory errors.
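A quick back-of-the-envelope calculation shows why: holding the 7B weights alone requires roughly the following (a sketch that ignores activations, KV cache, and the LoRA adapter, so treat these as lower bounds):

```python
# Rough lower bounds on GPU memory needed just to hold 7B parameters.
params = 7e9                       # ~7 billion parameters
fp32_gib = params * 4 / 1024**3    # 4 bytes per parameter in fp32
bf16_gib = params * 2 / 1024**3    # 2 bytes per parameter in bf16
print(f"fp32: ~{fp32_gib:.0f} GiB, bf16: ~{bf16_gib:.0f} GiB")
```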
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
try:
    from transformers.activations import NewGELUActivation, PytorchGELUTanh, GELUActivation
except ImportError:
    # Fallback for transformers versions where PytorchGELUTanh is unavailable
    from transformers.activations import NewGELUActivation, GELUTanh as PytorchGELUTanh, GELUActivation
from peft import PeftModel
import mlflow
from mlflow import transformers as mlflow_transformers
import torch
HF_MODEL_NAME = "allenai/Olmo-3-7B-Instruct-SFT"
torch.cuda.empty_cache()
# Load the trained model for registration
print("Loading LoRA model for registration...")
# For LoRA models, we need both base model and adapter
base_model = AutoModelForCausalLM.from_pretrained(
    HF_MODEL_NAME,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME)
adapter_dir = OUTPUT_DIR
peft_model = PeftModel.from_pretrained(base_model, adapter_dir)
# Merge LoRA into base and drop PEFT wrappers
merged_model = peft_model.merge_and_unload()
# Create Unity Catalog model name
full_model_name = f"{UC_CATALOG}.{UC_SCHEMA}.{UC_MODEL_NAME}"
print(f"Registering model as: {full_model_name}")
text_gen_pipe = pipeline(
    task="text-generation",
    model=merged_model,
    tokenizer=tokenizer,
)
input_example = ["Hello, world!"]
# Log the merged model to the training run and register it in Unity Catalog
mlflow.set_registry_uri("databricks-uc")
with mlflow.start_run(run_id=run_id):
    model_info = mlflow.transformers.log_model(
        transformers_model=text_gen_pipe,  # pass the full pipeline, not just the model
        name="model",
        input_example=input_example,
        registered_model_name=full_model_name,
    )
print(f"✓ Model successfully registered in Unity Catalog: {full_model_name}")
print(f"✓ MLflow model URI: {model_info.model_uri}")
print(f"✓ Model version: {model_info.registered_model_version}")
# Print deployment information
print(f"\n📦 Model Registration Complete!")
print(f"Unity Catalog Path: {full_model_name}")
print(f"Optimization: Liger Kernels + LoRA")
Next steps
- Best practices for Serverless GPU compute
- Troubleshoot issues on serverless GPU compute
- Multi-GPU and multi-node distributed training