
Fine-tune Llama 3.1 8B using Mosaic LLM Foundry on Databricks Serverless GPU

This notebook demonstrates how to fine-tune a Llama 3.1 8B model using Mosaic LLM Foundry on Databricks Serverless GPU. LLM Foundry is a codebase for training, fine-tuning, evaluating, and deploying large language models with support for distributed training strategies.

The notebook uses:

  • Mosaic LLM Foundry: A framework for training and fine-tuning LLMs with built-in support for FSDP, efficient data loading, and MLflow integration
  • FSDP (Fully Sharded Data Parallel): Distributes model parameters, gradients, and optimizer states across GPUs
  • Databricks Serverless GPU: Automatically provisions and manages GPU compute resources
  • Unity Catalog: Stores model checkpoints and registers trained models
  • MLflow: Tracks experiments and logs training metrics

Install required libraries

Install Mosaic LLM Foundry and dependencies for distributed training:

  • llm-foundry: Core framework for LLM training and fine-tuning
  • mlflow: Experiment tracking and model registry
  • flash-attention: Optimized attention implementation for faster training
  • hf_transfer: Faster model downloads from Hugging Face
  • yamlmagic: Enables YAML configuration in notebook cells
Python
%pip install llm-foundry[gpu]==0.20.0
%pip install mlflow==3.6
%pip install matplotlib==3.10.0
%pip install --force-reinstall --no-cache-dir --no-deps "https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_x86_64.whl"
%pip install hf_transfer
%pip install git+https://github.com/josejg/yamlmagic.git

Restart the Python environment

Restart the Python kernel to ensure the newly installed packages are available.

Python
dbutils.library.restartPython()

Configure Unity Catalog paths for model storage

Set up Unity Catalog locations for storing model checkpoints and registering the trained model. The configuration uses notebook widgets so the values can be changed without editing the code.

Python
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("uc_model_name", "llama3_1-8b")
dbutils.widgets.text("uc_volume", "checkpoints")

UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_MODEL_NAME = dbutils.widgets.get("uc_model_name")
UC_VOLUME = dbutils.widgets.get("uc_volume")

MLFLOW_EXPERIMENT_NAME = '/Workspace/Shared/llm-foundry-sgc' # TODO: update this name

print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")
print(f"UC_VOLUME: {UC_VOLUME}")
print(f"EXPERIMENT_NAME: {MLFLOW_EXPERIMENT_NAME}")

# Model selection - Choose based on your compute constraints
OUTPUT_DIR = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/{UC_MODEL_NAME}" # Save checkpoint to UC Volume

print(f"OUTPUT_DIR: {OUTPUT_DIR}")
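The Volumes path above always follows the fixed `/Volumes/<catalog>/<schema>/<volume>/...` layout. As a minimal sketch of composing and validating such a path (the `build_volume_path` helper is illustrative, not a Databricks API):

```python
def build_volume_path(catalog: str, schema: str, volume: str, *parts: str) -> str:
    """Compose a Unity Catalog Volumes path; every component must be non-empty."""
    components = [catalog, schema, volume, *parts]
    if not all(components):
        raise ValueError("Every path component must be non-empty")
    return "/".join(["/Volumes", *components])

print(build_volume_path("main", "default", "checkpoints", "llama3_1-8b"))
# /Volumes/main/default/checkpoints/llama3_1-8b
```

Validating the components up front surfaces an empty widget value immediately, rather than as a confusing write failure during checkpointing.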

Define training configuration using YAML

Define the fine-tuning configuration in YAML. The configuration specifies:

  • Model architecture and pretrained weights (Llama 3.1 8B)
  • FSDP settings for distributed training
  • Training hyperparameters (learning rate, batch size, optimizer)
  • Dataset configuration (mosaicml/dolly_hhrlhf)
  • MLflow logging and model checkpointing
  • Callbacks for monitoring and optimization
Python
%load_ext yamlmagic
Python
%%yaml config
seed: 17
model:
  name: hf_causal_lm
  pretrained: true
  init_device: mixed
  use_auth_token: true
  use_flash_attention_2: true
  pretrained_model_name_or_path: meta-llama/Llama-3.1-8B
loggers:
  mlflow:
    resume: true
    tracking_uri: databricks
    rename_metrics:
      time/token: time/num_tokens
      lr-DecoupledLionW/group0: learning_rate
    log_system_metrics: true
    experiment_name: "mlflow_experiment_name"
    run_name: llama3_8b-finetune
    model_registry_uri: databricks-uc
    model_registry_prefix: main.linyuan
callbacks:
  lr_monitor: {}
  run_timeout:
    timeout: 7200
  scheduled_gc:
    batch_interval: 1000
  speed_monitor:
    window_size: 10
  memory_monitor: {}
  runtime_estimator: {}
  hf_checkpointer:
    save_folder: "dbfs:/Volumes/main/sgc/checkpoints/llama3_1-8b-hf"
    save_interval: "1ep"
    precision: "bfloat16"
    overwrite: true
    mlflow_registered_model_name: "main.sgc.llama3_1_8b_full_ft"
    mlflow_logging_config:
      task: "llm/v1/completions"
      metadata:
        pretrained_model_name: "meta-llama/Llama-3.1-8B-Instruct"
optimizer:
  lr: 5.0e-07
  name: decoupled_lionw
  betas:
  - 0.9
  - 0.95
  weight_decay: 0
precision: amp_bf16
scheduler:
  name: linear_decay_with_warmup
  alpha_f: 0
  t_warmup: 10ba
tokenizer:
  name: meta-llama/Llama-3.1-8B
  kwargs:
    model_max_length: 1024
algorithms:
  gradient_clipping:
    clipping_type: norm
    clipping_threshold: 1
autoresume: false
log_config: false
fsdp_config:
  verbose: false
  mixed_precision: PURE
  state_dict_type: sharded
  limit_all_gathers: true
  sharding_strategy: FULL_SHARD
  activation_cpu_offload: false
  activation_checkpointing: true
  activation_checkpointing_reentrant: false
max_seq_len: 1024
save_folder: "output_folder"
dist_timeout: 600
max_duration: 20ba
progress_bar: false
train_loader:
  name: finetuning
  dataset:
    split: test
    hf_name: mosaicml/dolly_hhrlhf
    shuffle: true
    safe_load: true
    max_seq_len: 1024
    packing_ratio: auto
    target_prompts: none
    target_responses: all
    allow_pad_trimming: false
    decoder_only_format: true
  timeout: 0
  drop_last: false
  pin_memory: true
  num_workers: 8
  prefetch_factor: 2
  persistent_workers: true
eval_interval: 1
save_interval: 1h
log_to_console: true
save_overwrite: true
python_log_level: debug
save_weights_only: false
console_log_interval: 10ba
device_eval_batch_size: 1
global_train_batch_size: 32
device_train_microbatch_size: 1
save_num_checkpoints_to_keep: 1
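The batch-size settings interact: with `global_train_batch_size: 32`, 16 GPUs, and `device_train_microbatch_size: 1`, each device runs two forward/backward passes per optimizer step via gradient accumulation. A quick sanity check of that arithmetic (plain Python, not an LLM Foundry API):

```python
def grad_accum_steps(global_batch: int, num_gpus: int, microbatch: int) -> int:
    """Number of microbatch passes each device runs per optimizer step."""
    per_device = global_batch // num_gpus  # samples per GPU per optimizer step
    assert per_device * num_gpus == global_batch, "global batch must divide evenly across GPUs"
    assert per_device % microbatch == 0, "per-device batch must divide by microbatch size"
    return per_device // microbatch

print(grad_accum_steps(global_batch=32, num_gpus=16, microbatch=1))  # 2
```

If the global batch size does not divide evenly across devices and microbatches, the check fails early instead of producing a misconfigured run.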
Python
config["loggers"]["mlflow"]["experiment_name"] = MLFLOW_EXPERIMENT_NAME
config["save_folder"] = OUTPUT_DIR
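The same pattern works for any key in the parsed YAML, since `config` is a plain dictionary at this point. For deeper keys, a small dotted-path helper keeps overrides readable (the `set_by_path` helper below is a hypothetical convenience, not part of LLM Foundry):

```python
def set_by_path(cfg: dict, dotted_key: str, value):
    """Set a nested dict value by dotted path, e.g. 'loggers.mlflow.run_name'."""
    keys = dotted_key.split(".")
    node = cfg
    for k in keys[:-1]:
        node = node.setdefault(k, {})  # create intermediate dicts as needed
    node[keys[-1]] = value

cfg = {"loggers": {"mlflow": {"experiment_name": "placeholder"}}}
set_by_path(cfg, "loggers.mlflow.experiment_name", "/Workspace/Shared/llm-foundry-sgc")
set_by_path(cfg, "save_folder", "/Volumes/main/default/checkpoints/llama3_1-8b")
print(cfg["loggers"]["mlflow"]["experiment_name"])
```

This keeps widget-driven overrides in one place instead of scattering bracketed lookups through the notebook.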

Define the distributed training function

This cell defines the training function that will run on 16 A10 GPUs using the @distributed decorator. The function:

  • Configures the Hugging Face token for model access
  • Enables fast model downloads with hf_transfer
  • Calls the LLM Foundry train() function with the YAML configuration
  • Returns the MLflow run ID for tracking the experiment

The @distributed decorator with remote=True automatically provisions serverless GPU compute and handles distributed training orchestration.
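Conceptually, the decorator wraps the function so that calling `.distributed()` fans the work out to the requested number of workers and gathers one result per worker, which is why a single run ID is taken from the returned list below. A toy, single-process sketch of that calling convention (this only mimics the shape; the real `serverless_gpu` decorator provisions remote GPUs and runs each worker on its own node):

```python
import functools

def distributed_toy(gpus: int, gpu_type: str, remote: bool = True):
    """Toy stand-in for @distributed: simulates one return value per worker."""
    def wrap(fn):
        @functools.wraps(fn)
        def run_distributed():
            # The real implementation dispatches to remote GPU workers; here we
            # simply invoke the function once per simulated worker in-process.
            return [fn() for _ in range(gpus)]
        fn.distributed = run_distributed
        return fn
    return wrap

@distributed_toy(gpus=2, gpu_type="a10")
def job():
    return "run-id"

print(job.distributed())  # ['run-id', 'run-id']
```

Indexing the result with `[0]`, as the training cell does, picks the value returned by the first worker.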

Python
from serverless_gpu import distributed
from serverless_gpu import runtime as sgc_runtime
from llmfoundry.command_utils.train import train
from omegaconf import DictConfig
import mlflow
import yaml
from huggingface_hub import constants

HF_TOKEN = dbutils.secrets.get(scope="sgc-nightly-notebook", key="hf_token")

@distributed(gpus=16, gpu_type='a10', remote=True)
def run_llm_foundry():
    import os
    import logging

    os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
    constants.HF_HUB_ENABLE_HF_TRANSFER = True
    train(DictConfig(config))

    logging.info("\n✓ Training completed successfully!")

    mlflow_run_id = None
    if mlflow.last_active_run() is not None:
        mlflow_run_id = mlflow.last_active_run().info.run_id

    return mlflow_run_id

Run the distributed training job

Execute the training function on 16 A10 GPUs. The function returns the MLflow run ID, which can be used to track metrics, view logs, and access the trained model in the MLflow UI.

Python
mlflow_run_id = run_llm_foundry.distributed()[0]
print(mlflow_run_id)
