
Fine-tune Olmo3 7B with Axolotl on multi-GPU serverless compute

This notebook demonstrates how to fine-tune the Olmo3 7B Instruct model using Axolotl on Databricks serverless GPU compute. Axolotl provides a high-performance framework for LLM post-training with QLoRA (Quantized Low-Rank Adaptation), enabling efficient fine-tuning on multi-GPU infrastructure. The trained model is logged to MLflow and registered in Unity Catalog for deployment.

Install required dependencies

Installs Axolotl with Flash Attention support, MLflow for experiment tracking, and compatible versions of transformers and optimization libraries. The cut-cross-entropy package provides memory-efficient loss computation for large language models.

Python
%pip install -U packaging setuptools wheel ninja
%pip install "mlflow>=3.6"
%pip install --no-build-isolation "axolotl[flash-attn]>=0.12.0"
%pip install transformers==4.57.3
%pip uninstall -y awq autoawq
%pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f4b5712"
dbutils.library.restartPython()
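The cells above mix exact pins with minimum-version requirements. After the restart, it can help to confirm which versions were actually installed. The following is a minimal sketch that uses only the standard-library `importlib.metadata`. The helper name `installed_versions` is hypothetical, and the package list mirrors the `%pip` lines above.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names taken from the %pip cells above
PACKAGES = ["mlflow", "axolotl", "transformers", "cut-cross-entropy"]


def installed_versions(packages):
    """Return {package: installed version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions(PACKAGES).items():
    print(f"{name}: {ver or 'not installed'}")
```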

Retrieve HuggingFace token

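Retrieves the Hugging Face authentication token from Databricks secrets. The training function exports it as the HF_TOKEN environment variable so that each process can download the Olmo3 7B model from the Hugging Face Hub. Point the scope and key at your own secret.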

Python
HF_TOKEN = dbutils.secrets.get(scope="sgc-nightly-notebook", key="hf_token")

Configure training parameters

Sets up the Axolotl training configuration based on the olmo3-7b-qlora.yaml example. Key modifications include:

  • MLflow integration for experiment tracking
  • Unity Catalog volume path for checkpoint storage
  • SDPA (Scaled Dot Product Attention) instead of Flash Attention for broader GPU compatibility
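The three modifications above amount to overriding a handful of keys in the upstream example. The sketch below shows that layering with a flat dictionary merge. It assumes the upstream YAML has already been parsed into a plain dict. `UPSTREAM_EXCERPT` is a hypothetical, abbreviated stand-in whose values are illustrative rather than copied from the real file, and the volume path is a placeholder.

```python
def apply_overrides(base: dict, overrides: dict) -> dict:
    """Return a new config where override keys replace or extend the base keys."""
    merged = dict(base)
    merged.update(overrides)
    return merged


# Hypothetical, abbreviated stand-in for the parsed upstream example YAML
UPSTREAM_EXCERPT = {
    "adapter": "qlora",
    "flash_attention": True,
    "output_dir": "./outputs/lora-out",
}

NOTEBOOK_OVERRIDES = {
    # MLflow integration for experiment tracking
    "use_mlflow": True,
    "mlflow_tracking_uri": "databricks",
    # Unity Catalog volume path for checkpoint storage (placeholder path)
    "output_dir": "/Volumes/main/default/checkpoints/olmo3_7b_qlora",
    # SDPA instead of Flash Attention
    "flash_attention": False,
    "sdpa_attention": True,
    "attn_implementation": "sdpa",
}

cfg = apply_overrides(UPSTREAM_EXCERPT, NOTEBOOK_OVERRIDES)
print(cfg["flash_attention"], cfg["attn_implementation"], cfg["output_dir"])
```

A flat merge is enough here because all of the notebook's changes are top-level keys. Nested sections such as `datasets` would need a deeper merge.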

Define Unity Catalog paths

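Creates widgets to specify the Unity Catalog location for storing model checkpoints. The output directory combines the catalog, schema, volume, and model name into a fully qualified volume path. The catalog, schema, and model name also form the registered model name used at the end of the notebook.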

Python
dbutils.widgets.text("catalog", "main")
dbutils.widgets.text("schema", "default")
dbutils.widgets.text("volume", "checkpoints")
dbutils.widgets.text("model", "olmo3_7b_qlora")

UC_CATALOG = dbutils.widgets.get("catalog")
UC_SCHEMA = dbutils.widgets.get("schema")
UC_VOLUME = dbutils.widgets.get("volume")
UC_MODEL_NAME = dbutils.widgets.get("model")

print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_VOLUME: {UC_VOLUME}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")

OUTPUT_DIR = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/{UC_MODEL_NAME}"
print(f"OUTPUT_DIR: {OUTPUT_DIR}")
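The same widget values feed two identifiers: the checkpoint directory under `/Volumes/...` and, later, the three-level registered model name `catalog.schema.model`. The sketch below builds both and rejects values that would break the three-level name. The helper `uc_identifiers` is hypothetical. The rule it enforces (no periods, spaces, or forward slashes in a single name part) is an assumption based on Unity Catalog's naming restrictions. Under this rule, a Hugging Face-style name such as `openai/gpt-oss-20b` is rejected.

```python
def uc_identifiers(catalog: str, schema: str, volume: str, model: str) -> tuple[str, str]:
    """Return (volume_output_dir, registered_model_name) for the widget values."""
    # Assumed forbidden characters within a single Unity Catalog name part
    forbidden = {".", " ", "/"}
    parts = {"catalog": catalog, "schema": schema, "volume": volume, "model": model}
    for part_name, part in parts.items():
        if not part:
            raise ValueError(f"{part_name} must not be empty")
        bad = forbidden.intersection(part)
        if bad:
            raise ValueError(f"{part_name}={part!r} contains forbidden characters: {sorted(bad)}")
    output_dir = f"/Volumes/{catalog}/{schema}/{volume}/{model}"
    registered_name = f"{catalog}.{schema}.{model}"
    return output_dir, registered_name


print(uc_identifiers("main", "default", "checkpoints", "olmo3_7b_qlora"))
# ('/Volumes/main/default/checkpoints/olmo3_7b_qlora', 'main.default.olmo3_7b_qlora')
```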

Disable telemetry

Disables Axolotl's usage tracking by setting the AXOLOTL_DO_NOT_TRACK environment variable.

Python
import os
os.environ['AXOLOTL_DO_NOT_TRACK'] = '1'

Create Axolotl configuration

Defines the complete training configuration as an Axolotl DictDefault. This includes:

  • Model settings: QLoRA with 4-bit quantization.
  • Dataset configuration: an Alpaca-derived conversation dataset formatted with the model's chat template.
  • LoRA hyperparameters: rank 32, alpha 16.
  • Training parameters: 1 epoch, micro-batch size 2 per GPU, gradient accumulation 4.
  • MLflow integration for experiment tracking.

The training function below caps the run at 16 steps for this demo.

Python
from axolotl.cli.config import load_cfg
from axolotl.utils.dict import DictDefault

# Config is based on the following example, with some changes to fit the available GPU types:
# https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/olmo3/olmo3-7b-qlora.yaml

# Axolotl provides full control and transparency over model and training configuration
config = DictDefault(
    base_model="allenai/Olmo-3-7B-Instruct-SFT",
    plugins=[
        "axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin"
    ],
    # QLoRA: load the base model weights in 4-bit
    load_in_8bit=False,
    load_in_4bit=True,
    datasets=[
        {
            "path": "fozziethebeat/alpaca_messages_2k_test",
            "type": "chat_template",
        }
    ],
    dataset_prepared_path="last_run_prepared",
    val_set_size=0.1,
    output_dir=OUTPUT_DIR,
    adapter="qlora",
    lora_model_dir=None,
    sequence_len=2048,
    sample_packing=True,
    lora_r=32,
    lora_alpha=16,
    lora_dropout=0.05,
    lora_target_linear=True,
    lora_target_modules=[
        "gate_proj",
        "down_proj",
        "up_proj",
        "q_proj",
        "v_proj",
        "k_proj",
        "o_proj",
    ],
    wandb_project=None,
    wandb_entity=None,
    wandb_watch=None,
    wandb_name=None,
    wandb_log_model=None,
    gradient_accumulation_steps=4,
    micro_batch_size=2,
    num_epochs=1,
    optimizer="adamw_bnb_8bit",
    lr_scheduler="cosine",
    learning_rate=0.0002,
    bf16="auto",
    tf32=False,
    gradient_checkpointing=True,
    resume_from_checkpoint=None,
    logging_steps=1,
    flash_attention=False,
    warmup_ratio=0.1,
    evals_per_epoch=1,
    saves_per_epoch=1,
    # Disable sample packing for evaluation: the eval split is too small to pack
    eval_sample_packing=False,
    # Write metrics to MLflow
    use_mlflow=True,
    mlflow_tracking_uri="databricks",
    mlflow_run_name="olmo3-7b-qlora-axolotl",
    hf_mlflow_log_artifacts=False,
    wandb_mode="disabled",
    # Use PyTorch SDPA instead of Flash Attention for broader GPU compatibility
    attn_implementation="sdpa",
    sdpa_attention=True,
    save_first_step=True,
    device_map=None,
)
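This calculation makes the batch settings concrete. Each optimizer step sees `micro_batch_size × gradient_accumulation_steps × GPUs` sequences. With `micro_batch_size=2`, `gradient_accumulation_steps=4`, and the 8 GPUs requested in the `@distributed` call, that is 2 × 4 × 8 = 64 sequences. Assume each micro-batch entry is one packed bin of up to `sequence_len=2048` tokens, since `sample_packing=True`. A step then processes at most 131,072 tokens. The 16-step cap comes from the `max_steps` override in the training function. The LoRA update is scaled by `lora_alpha / lora_r`, assuming standard (non-rsLoRA) scaling. `BatchSettings` and `lora_scaling` are hypothetical helpers.

```python
from dataclasses import dataclass


@dataclass
class BatchSettings:
    micro_batch_size: int
    gradient_accumulation_steps: int
    num_gpus: int
    sequence_len: int
    max_steps: int

    @property
    def sequences_per_step(self) -> int:
        # Each GPU accumulates gradients over several micro-batches per optimizer step
        return self.micro_batch_size * self.gradient_accumulation_steps * self.num_gpus

    @property
    def max_tokens_per_step(self) -> int:
        # Upper bound: every packed sequence filled to sequence_len tokens
        return self.sequences_per_step * self.sequence_len

    @property
    def max_sequences_total(self) -> int:
        return self.sequences_per_step * self.max_steps


def lora_scaling(lora_alpha: float, lora_r: int) -> float:
    # Standard LoRA scaling applied to the low-rank update B @ A
    return lora_alpha / lora_r


settings = BatchSettings(
    micro_batch_size=2, gradient_accumulation_steps=4, num_gpus=8, sequence_len=2048, max_steps=16
)
print(settings.sequences_per_step)   # 64
print(settings.max_tokens_per_step)  # 131072
print(settings.max_sequences_total)  # 1024
print(lora_scaling(16, 32))          # 0.5
```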
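
Configure PyTorch CUDA memory allocation

Calls Axolotl's set_pytorch_cuda_alloc_conf helper. The helper configures PyTorch's CUDA caching allocator through the PYTORCH_CUDA_ALLOC_CONF environment variable to reduce memory fragmentation during training. Run it before any model is loaded onto a GPU.

Python
from axolotl.utils import set_pytorch_cuda_alloc_conf

set_pytorch_cuda_alloc_conf()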
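For reference, allocator options can also be set by hand through the `PYTORCH_CUDA_ALLOC_CONF` environment variable. PyTorch reads this variable when CUDA is first initialized, so it must be set before any tensor is placed on a GPU. The following is a minimal sketch using the documented `expandable_segments` option. The exact options Axolotl's helper chooses may differ, and `configure_cuda_allocator` is a hypothetical name.

```python
import os


def configure_cuda_allocator(options: str = "expandable_segments:True") -> str:
    """Set PYTORCH_CUDA_ALLOC_CONF unless it is already set; return the active value."""
    # setdefault keeps a value configured earlier (for example by the environment)
    return os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", options)


print(configure_cuda_allocator())
```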

Run distributed training on serverless GPU compute

Uses the @distributed decorator from the serverless GPU API to distribute the Axolotl training job across 8 H100 GPUs. The decorator handles multi-GPU orchestration, allowing the training function to run in a distributed environment without manual cluster setup.

Python
from serverless_gpu.launcher import distributed
from serverless_gpu.compute import GPUType

@distributed(gpus=8, gpu_type=GPUType.H100, remote=False)
def run_train(cfg: DictDefault):
    import os
    # Make the Hugging Face token available to this process
    os.environ['HF_TOKEN'] = HF_TOKEN

    from axolotl.common.datasets import load_datasets

    # Validate and normalize the configuration
    cfg = load_cfg(cfg)

    # Load, parse, and tokenize the datasets using the model's chat template,
    # dropping samples that overflow the maximum sequence length
    dataset_meta = load_datasets(cfg=cfg)

    from axolotl.train import train

    # Train only the first 16 steps for this demo. Sample packing fits as many
    # training samples as possible into each step.
    cfg.max_steps = 16
    model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)

    import mlflow
    # Return the ID of the MLflow run this process logged to, if any
    last_run = mlflow.last_active_run()
    return last_run.info.run_id if last_run is not None else None
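Every GPU executes `run_train`, so side effects beyond training, such as printing, writing files, or logging extra MLflow artifacts, happen once per process unless they are guarded. A common pattern is to gate such work on global rank 0. The sketch below assumes the launcher exposes the global rank through the `RANK` environment variable, as `torchrun`-style launchers do. `is_main_process` and `log_summary` are hypothetical helpers, not part of the serverless GPU API.

```python
import os


def is_main_process(env=None) -> bool:
    """True on global rank 0, or when RANK is unset (for example, a single-process run)."""
    env = os.environ if env is None else env
    return int(env.get("RANK", "0")) == 0


def log_summary(step_count: int, env=None) -> bool:
    # Only the main process reports, so the message is not repeated once per GPU
    if is_main_process(env):
        print(f"Finished {step_count} training steps")
        return True
    return False
```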

Execute the training job

Launches the distributed training job. On each GPU, run_train validates the configuration, loads and tokenizes the dataset, trains the model for 16 steps, and returns the MLflow run ID, or None if no run is active. The per-process results come back as a list, which the next cell indexes into.

Python
result = run_train.distributed(config)

Extract MLflow run ID

Retrieves the MLflow run ID from the training results for model registration and experiment tracking.

Python
run_id = result[0]
print(run_id)
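Taking `result[0]` assumes the first entry belongs to a process that logged to MLflow. A slightly more defensive approach is sketched below. It assumes `result` is a list with one return value per process, and that processes without an active run return `None`, as `run_train` does when `mlflow.last_active_run()` is `None`. `pick_run_id` is a hypothetical helper. If several processes return an ID, it picks the lowest-indexed one, matching `result[0]` whenever that entry is set.

```python
from typing import Optional, Sequence


def pick_run_id(results: Sequence[Optional[str]]) -> str:
    """Return the first non-None run ID, or raise if no process returned one."""
    for candidate in results:
        if candidate is not None:
            return candidate
    raise RuntimeError("No process returned an MLflow run ID")


# Example with a made-up run ID: one process logged, the others returned None
print(pick_run_id(["abc123", None, None]))  # abc123
```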

Register the fine-tuned model to Unity Catalog

Loads the trained LoRA adapter, merges it with the base model, and registers the combined model to Unity Catalog via MLflow. This makes the model available for deployment and inference.

Note: This step requires H100 GPU compute to load the model checkpoint. Running on smaller GPUs may result in CUDA out-of-memory errors.

Python
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
try:
    from transformers.activations import NewGELUActivation, PytorchGELUTanh, GELUActivation
except ImportError:
    # Fallback for transformers versions that expose this class as GELUTanh
    from transformers.activations import NewGELUActivation, GELUTanh as PytorchGELUTanh, GELUActivation

from peft import PeftModel
import mlflow
import torch

HF_MODEL_NAME = "allenai/Olmo-3-7B-Instruct-SFT"

# Register models in Unity Catalog
mlflow.set_registry_uri("databricks-uc")

torch.cuda.empty_cache()

# A LoRA adapter is applied on top of the base model, so load both
print("Loading LoRA model for registration...")
base_model = AutoModelForCausalLM.from_pretrained(
    HF_MODEL_NAME,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME)

# The training job wrote the trained adapter to OUTPUT_DIR
adapter_dir = OUTPUT_DIR
peft_model = PeftModel.from_pretrained(base_model, adapter_dir)

# Merge the LoRA weights into the base model and drop the PEFT wrappers
merged_model = peft_model.merge_and_unload()

# Fully qualified Unity Catalog model name: <catalog>.<schema>.<model>
full_model_name = f"{UC_CATALOG}.{UC_SCHEMA}.{UC_MODEL_NAME}"
print(f"Registering model as: {full_model_name}")

# Wrap the merged model in a text-generation pipeline for logging
text_gen_pipe = pipeline(
    task="text-generation",
    model=merged_model,
    tokenizer=tokenizer,
)

input_example = ["Hello, world!"]

# Log the model to the training run and register it in Unity Catalog
with mlflow.start_run(run_id=run_id):
    model_info = mlflow.transformers.log_model(
        transformers_model=text_gen_pipe,  # 🚨 pass the pipeline, not just the model
        name="model",
        input_example=input_example,
        registered_model_name=full_model_name,
    )

print(f"✓ Model successfully registered in Unity Catalog: {full_model_name}")
print(f"✓ MLflow model URI: {model_info.model_uri}")
print(f"✓ Model version: {model_info.registered_model_version}")

# Print deployment information
print("\n📦 Model Registration Complete!")
print(f"Unity Catalog Path: {full_model_name}")
print("Optimization: Cut Cross Entropy + QLoRA")
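For each adapted linear layer, merging a LoRA adapter folds the low-rank update into the base weight: W' = W + (lora_alpha / r) · B A. Here A has shape r × k and B has shape d × r, so after merging, inference needs no extra adapter matrix multiplications. In this notebook the factor is 16 / 32 = 0.5. The pure-Python sketch below shows that update with made-up 2 × 2 weights and rank 1. It assumes standard (non-rsLoRA) scaling and ignores dropout, which is inactive at merge time. It illustrates the math and is not PEFT's implementation.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [
        [sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def merge_lora(w, lora_a, lora_b, lora_alpha, r):
    """Return W + (lora_alpha / r) * (B @ A)."""
    scaling = lora_alpha / r
    delta = matmul(lora_b, lora_a)  # shape d x k, same as W
    return [[w[i][j] + scaling * delta[i][j] for j in range(len(w[0]))] for i in range(len(w))]


W = [[1.0, 0.0], [0.0, 1.0]]  # base weight, d=2, k=2
A = [[1.0, 2.0]]              # r x k with r=1
B = [[0.5], [1.0]]            # d x r
merged = merge_lora(W, A, B, lora_alpha=2, r=1)
print(merged)  # [[2.0, 2.0], [2.0, 5.0]]
```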

Next steps

Example notebook

Fine-tune Olmo3 7B with Axolotl on multi-GPU serverless compute
