Fine-tune Llama 3.1 8B using Mosaic LLM Foundry on Databricks Serverless GPU
This notebook demonstrates how to fine-tune a Llama 3.1 8B model using Mosaic LLM Foundry on Databricks Serverless GPU. LLM Foundry is a codebase for training, fine-tuning, evaluating, and deploying large language models with support for distributed training strategies.
The notebook uses:
- Mosaic LLM Foundry: A framework for training and fine-tuning LLMs with built-in support for FSDP, efficient data loading, and MLflow integration
- FSDP (Fully Sharded Data Parallel): Distributes model parameters, gradients, and optimizer states across GPUs
- Databricks Serverless GPU: Automatically provisions and manages GPU compute resources
- Unity Catalog: Stores model checkpoints and registers trained models
- MLflow: Tracks experiments and logs training metrics
Install required libraries
Install Mosaic LLM Foundry and dependencies for distributed training:
- llm-foundry: Core framework for LLM training and fine-tuning
- mlflow: Experiment tracking and model registry
- flash-attention: Optimized attention implementation for faster training
- hf_transfer: Faster model downloads from Hugging Face
- yamlmagic: Enables YAML configuration in notebook cells
%pip install llm-foundry[gpu]==0.20.0
%pip install mlflow==3.6
%pip install matplotlib==3.10.0
%pip install --force-reinstall --no-cache-dir --no-deps "https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_x86_64.whl"
%pip install hf_transfer
%pip install git+https://github.com/josejg/yamlmagic.git
Restart the Python environment
Restart the Python kernel to ensure the newly installed packages are available.
dbutils.library.restartPython()
Configure Unity Catalog paths for model storage
Set up Unity Catalog locations for storing model checkpoints and registering the trained model. The configuration uses notebook widgets, so the values can be customized without editing the code.
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("uc_model_name", "llama3_1-8b")
dbutils.widgets.text("uc_volume", "checkpoints")
UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_MODEL_NAME = dbutils.widgets.get("uc_model_name")
UC_VOLUME = dbutils.widgets.get("uc_volume")
MLFLOW_EXPERIMENT_NAME = '/Workspace/Shared/llm-foundry-sgc' # TODO: update this name
print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")
print(f"UC_VOLUME: {UC_VOLUME}")
print(f"EXPERIMENT_NAME: {MLFLOW_EXPERIMENT_NAME}")
# Model selection - Choose based on your compute constraints
OUTPUT_DIR = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/{UC_MODEL_NAME}" # Save checkpoint to UC Volume
print(f"OUTPUT_DIR: {OUTPUT_DIR}")
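The checkpoint path above is a plain f-string, so a malformed widget value (for example, one containing a `/`) would silently produce an invalid Volumes path. A small sanity-check helper can catch that early (`uc_volume_path` is a hypothetical name, mirroring the f-string above):

```python
def uc_volume_path(catalog: str, schema: str, volume: str, model_name: str) -> str:
    """Build the Unity Catalog Volumes path used as the checkpoint directory."""
    for part in (catalog, schema, volume, model_name):
        # Each path segment must be a single non-empty identifier
        if not part or "/" in part:
            raise ValueError(f"invalid Unity Catalog identifier: {part!r}")
    return f"/Volumes/{catalog}/{schema}/{volume}/{model_name}"

print(uc_volume_path("main", "default", "checkpoints", "llama3_1-8b"))
# /Volumes/main/default/checkpoints/llama3_1-8b
```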
Define training configuration using YAML
Load the fine-tuning configuration from YAML format. The configuration specifies:
- Model architecture and pretrained weights (Llama 3.1 8B)
- FSDP settings for distributed training
- Training hyperparameters (learning rate, batch size, optimizer)
- Dataset configuration (mosaicml/dolly_hhrlhf)
- MLflow logging and model checkpointing
- Callbacks for monitoring and optimization
%load_ext yamlmagic
%%yaml config
seed: 17
model:
  name: hf_causal_lm
  pretrained: true
  init_device: mixed
  use_auth_token: true
  use_flash_attention_2: true
  pretrained_model_name_or_path: meta-llama/Llama-3.1-8B
loggers:
  mlflow:
    resume: true
    tracking_uri: databricks
    rename_metrics:
      time/token: time/num_tokens
      lr-DecoupledLionW/group0: learning_rate
    log_system_metrics: true
    experiment_name: "mlflow_experiment_name"
    run_name: llama3_8b-finetune
    model_registry_uri: databricks-uc
    model_registry_prefix: main.linyuan
callbacks:
  lr_monitor: {}
  run_timeout:
    timeout: 7200
  scheduled_gc:
    batch_interval: 1000
  speed_monitor:
    window_size: 10
  memory_monitor: {}
  runtime_estimator: {}
  hf_checkpointer:
    save_folder: "dbfs:/Volumes/main/sgc/checkpoints/llama3_1-8b-hf"
    save_interval: "1ep"
    precision: "bfloat16"
    overwrite: true
    mlflow_registered_model_name: "main.sgc.llama3_1_8b_full_ft"
    mlflow_logging_config:
      task: "llm/v1/completions"
      metadata:
        pretrained_model_name: "meta-llama/Llama-3.1-8B-Instruct"
optimizer:
  lr: 5.0e-07
  name: decoupled_lionw
  betas:
    - 0.9
    - 0.95
  weight_decay: 0
precision: amp_bf16
scheduler:
  name: linear_decay_with_warmup
  alpha_f: 0
  t_warmup: 10ba
tokenizer:
  name: meta-llama/Llama-3.1-8B
  kwargs:
    model_max_length: 1024
algorithms:
  gradient_clipping:
    clipping_type: norm
    clipping_threshold: 1
autoresume: false
log_config: false
fsdp_config:
  verbose: false
  mixed_precision: PURE
  state_dict_type: sharded
  limit_all_gathers: true
  sharding_strategy: FULL_SHARD
  activation_cpu_offload: false
  activation_checkpointing: true
  activation_checkpointing_reentrant: false
max_seq_len: 1024
save_folder: "output_folder"
dist_timeout: 600
max_duration: 20ba
progress_bar: false
train_loader:
  name: finetuning
  dataset:
    split: test
    hf_name: mosaicml/dolly_hhrlhf
    shuffle: true
    safe_load: true
    max_seq_len: 1024
    packing_ratio: auto
    target_prompts: none
    target_responses: all
    allow_pad_trimming: false
    decoder_only_format: true
  timeout: 0
  drop_last: false
  pin_memory: true
  num_workers: 8
  prefetch_factor: 2
  persistent_workers: true
eval_interval: 1
save_interval: 1h
log_to_console: true
save_overwrite: true
python_log_level: debug
save_weights_only: false
console_log_interval: 10ba
device_eval_batch_size: 1
global_train_batch_size: 32
device_train_microbatch_size: 1
save_num_checkpoints_to_keep: 1
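The batch-size keys in the config interact with the GPU count: assuming Composer's usual batch arithmetic, the global batch is split evenly across ranks, and each rank accumulates gradients over microbatches until it reaches its per-device share. With the values above and 16 GPUs:

```python
gpus = 16
global_train_batch_size = 32
device_train_microbatch_size = 1

# Samples each GPU processes per optimizer step
device_train_batch_size = global_train_batch_size // gpus
# Microbatches accumulated before each optimizer step
grad_accum = device_train_batch_size // device_train_microbatch_size

print(device_train_batch_size, grad_accum)
# 2 2
```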
config["loggers"]["mlflow"]["experiment_name"] = MLFLOW_EXPERIMENT_NAME
config["save_folder"] = OUTPUT_DIR
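Because yamlmagic exposes the cell as a plain Python dict, any value can be overridden this way before training. For nested keys, a small recursive merge avoids clobbering sibling settings; this is a minimal sketch with a toy config dict (the helper name and values are hypothetical):

```python
def apply_overrides(cfg: dict, overrides: dict) -> dict:
    """Recursively merge an overrides dict into a nested config dict."""
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(cfg.get(key), dict):
            apply_overrides(cfg[key], value)  # descend into nested sections
        else:
            cfg[key] = value  # leaf value: replace outright
    return cfg

toy_config = {
    "max_duration": "20ba",
    "loggers": {"mlflow": {"experiment_name": "old", "resume": True}},
}
apply_overrides(toy_config, {
    "max_duration": "100ba",
    "loggers": {"mlflow": {"resume": False}},
})
print(toy_config["max_duration"])
# 100ba
```

Note that sibling keys such as `experiment_name` are left untouched by the merge.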
Define the distributed training function
This cell defines the training function that will run on 16 A10 GPUs using the @distributed decorator. The function:
- Configures the Hugging Face token for model access
- Enables fast model downloads with hf_transfer
- Calls the LLM Foundry train() function with the YAML configuration
- Returns the MLflow run ID for tracking the experiment
The @distributed decorator with remote=True automatically provisions serverless GPU compute and handles distributed training orchestration.
from serverless_gpu import distributed
from serverless_gpu import runtime as sgc_runtime
from llmfoundry.command_utils.train import train
from omegaconf import DictConfig
import mlflow
import yaml
from huggingface_hub import constants
HF_TOKEN = dbutils.secrets.get(scope="sgc-nightly-notebook", key="hf_token")
@distributed(gpus=16, gpu_type='a10', remote=True)
def run_llm_foundry():
    import os
    import logging

    # Authenticate to Hugging Face and enable accelerated downloads
    os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
    constants.HF_HUB_ENABLE_HF_TRANSFER = True

    train(DictConfig(config))
    logging.info("\n✓ Training completed successfully!")

    # Return the MLflow run ID so the driver can look up the experiment
    mlflow_run_id = None
    if mlflow.last_active_run() is not None:
        mlflow_run_id = mlflow.last_active_run().info.run_id
    return mlflow_run_id
Run the distributed training job
Execute the training function on 16 A10 GPUs. The function returns the MLflow run ID, which can be used to track metrics, view logs, and access the trained model in the MLflow UI.
mlflow_run_id = run_llm_foundry.distributed()[0]
print(mlflow_run_id)
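The `.distributed()` call returns one result per participating process, which is why the driver takes element `[0]`. A toy stand-in illustrates the shape of that API (this is not the real serverless_gpu implementation, just a single-process analogy):

```python
from functools import wraps

def toy_distributed(gpus: int):
    """Toy decorator mimicking the call shape: fn.distributed() -> per-rank results."""
    def wrap(fn):
        @wraps(fn)
        def inner(*args, **kwargs):
            return fn(*args, **kwargs)
        # Attach a .distributed() that "runs" the function once per rank locally
        inner.distributed = lambda *args, **kwargs: [fn(*args, **kwargs) for _ in range(gpus)]
        return inner
    return wrap

@toy_distributed(gpus=4)
def job():
    return "run-id-123"

results = job.distributed()  # one entry per simulated rank
print(results[0])
# run-id-123
```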
Next steps
- Multi-GPU and multi-node distributed training
- Best practices for Serverless GPU compute
- Troubleshoot issues on serverless GPU compute
- Mosaic LLM Foundry documentation
- Unity Catalog model registry