Create a training run using the Foundation Model Fine-tuning API
Preview
This feature is in Public Preview in us-east-1 and us-west-2.
This article describes how to create and configure a training run using the Foundation Model Fine-tuning (now part of Mosaic AI Model Training) API, and describes all of the parameters used in the API call. You can also create a run using the UI. For instructions, see Create a training run using the Foundation Model Fine-tuning UI.
Requirements
See Requirements.
Create a training run
To create training runs programmatically, use the create() function. This function trains a model on the provided dataset and converts the final Composer checkpoint to a Hugging Face formatted checkpoint for inference.
The required inputs are the model you want to train, the location of your training dataset, and where to register your model. There are also optional parameters that allow you to perform evaluation and change the hyperparameters of your run.
After the run completes, the completed run and final checkpoint are saved, the model is cloned, and that clone is registered to Unity Catalog as a model version for inference.
The model from the completed run (not the cloned model version in Unity Catalog), together with its Composer and Hugging Face checkpoints, is saved to MLflow. The Composer checkpoints can be used for continued fine-tuning tasks.
See Configure a training run for details about arguments for the create() function.
from databricks.model_training import foundation_model as fm

run = fm.create(
    model='meta-llama/Llama-3.2-3B-Instruct',
    train_data_path='dbfs:/Volumes/main/mydirectory/ift/train.jsonl',  # UC Volume with JSONL formatted data
    # Public HF dataset is also supported
    # train_data_path='mosaicml/dolly_hhrlhf/train'
    register_to='main.mydirectory',  # UC catalog and schema to register the model to
)
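As a sketch of how the three required inputs might be checked before a run is submitted, the helper below assembles the keyword arguments for fm.create(). The helper name build_create_kwargs and its validation rules are illustrative assumptions, not part of the Foundation Model Fine-tuning API. Only the parameter names model, train_data_path, and register_to come from the example above.

```python
def build_create_kwargs(model, train_data_path, register_to, **optional):
    """Assemble keyword arguments for fm.create() (hypothetical helper).

    Checks only that the three required inputs are non-empty strings.
    It does not verify that the model is supported or that paths exist.
    """
    required = {
        "model": model,
        "train_data_path": train_data_path,
        "register_to": register_to,
    }
    for name, value in required.items():
        if not isinstance(value, str) or not value.strip():
            raise ValueError(f"{name} is required and must be a non-empty string")

    # Optional parameters are passed through unchanged; None values are
    # dropped so that the API defaults apply.
    kwargs = dict(required)
    kwargs.update({k: v for k, v in optional.items() if v is not None})
    return kwargs


kwargs = build_create_kwargs(
    model="meta-llama/Llama-3.2-3B-Instruct",
    train_data_path="dbfs:/Volumes/main/mydirectory/ift/train.jsonl",
    register_to="main.mydirectory",
)

# In a Databricks notebook, the run could then be submitted with:
#   from databricks.model_training import foundation_model as fm
#   run = fm.create(**kwargs)
```

Keeping the check separate from the API call means a missing required input surfaces locally, before any training job is requested.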
Configure a training run
The following table summarizes the parameters for the foundation_model.create() function.
| Parameter | Required | Type | Description |
|---|---|---|---|
| model | x | str | The name of the model to use. See Supported models. |
| train_data_path | x | str | The location of your training data. This can be a location in Unity Catalog ( For For |
| register_to | x | str | The Unity Catalog catalog and schema ( |
| | | str | The cluster ID of the cluster to use for Spark data processing. This is required for instruction training tasks where the training data is in a Delta table. For information on how to find the cluster ID, see Get cluster ID. |
| | | str | The path to the MLflow experiment where the training run output (metrics and checkpoints) is saved. Defaults to the run name within the user’s personal workspace (i.e. |
| | | str | The type of task to run. Can be |
| | | str | The remote location of your evaluation data (if any). Must follow the same format as |
| | | List[str] | A list of prompt strings to generate responses during evaluation. Default is |
| custom_weights_path | | str | The remote location of a custom model checkpoint for training. Default is |
| | | str | The total duration of your run. Default is one epoch or |
| | | str | The learning rate for model training. For all models other than Llama 3.1 405B Instruct, the default learning rate is |
| | | str | The maximum sequence length of a data sample. This is used to truncate any data that is too long and to package shorter sequences together for efficiency. The default is 8192 tokens or the maximum context length for the provided model, whichever is lower. You can use this parameter to configure the context length, but configuring beyond each model’s maximum context length is not supported. See Supported models for the maximum supported context length of each model. |
| | | Boolean | Whether to validate the access to input paths before submitting the training job. Default is |
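The default rule for the maximum sequence length (8192 tokens or the model's maximum context length, whichever is lower) can be sketched as simple arithmetic. The function name resolve_context_length and the integer inputs are assumptions made for illustration; the API itself takes this parameter as a string, and the maximum for each model should be taken from Supported models rather than from this sketch.

```python
DEFAULT_CONTEXT_CAP = 8192  # tokens, from the parameter description above


def resolve_context_length(model_max_context, requested=None):
    """Return the sequence length a run would use (hypothetical helper).

    model_max_context: the model's maximum supported context length,
        supplied by the caller.
    requested: an explicit context length, or None to use the default.
    """
    if model_max_context <= 0:
        raise ValueError("model_max_context must be positive")
    if requested is None:
        # Default: 8192 tokens or the model maximum, whichever is lower.
        return min(DEFAULT_CONTEXT_CAP, model_max_context)
    if requested <= 0 or requested > model_max_context:
        # Configuring beyond the model's maximum is not supported.
        raise ValueError(
            f"requested context length must be between 1 and {model_max_context}"
        )
    return requested
```

For example, resolve_context_length(4096) returns 4096, while resolve_context_length(131072) returns 8192 because the default cap applies.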
Build on custom model weights
Foundation Model Fine-tuning supports adding custom weights using the optional parameter custom_weights_path to train and customize a model.
To get started, set custom_weights_path to the Composer checkpoint path from a previous training run. Checkpoint paths can be found in the Artifacts tab of a previous MLflow run. The checkpoint folder name corresponds to the batch and epoch of a particular snapshot, such as ep29-ba30/.
To provide the latest checkpoint from a previous run, set custom_weights_path to the Composer checkpoint. For example, custom_weights_path=dbfs:/databricks/mlflow-tracking/<experiment_id>/<run_id>/artifacts/<run_name>/checkpoints/latest-sharded-rank0.symlink.
To provide an earlier checkpoint, set custom_weights_path to a path to a folder containing .distcp files corresponding to the desired checkpoint, such as custom_weights_path=dbfs:/databricks/mlflow-tracking/<experiment_id>/<run_id>/artifacts/<run_name>/checkpoints/ep#-ba#.
Next, update the model parameter to match the base model of the checkpoint you passed to custom_weights_path.
In the following example, ift-meta-llama-3-1-70b-instruct-ohugkq is a previous run that fine-tunes meta-llama/Meta-Llama-3.1-70B. To fine-tune the latest checkpoint from ift-meta-llama-3-1-70b-instruct-ohugkq, set the model and custom_weights_path variables as follows:
from databricks.model_training import foundation_model as fm

run = fm.create(
    model='meta-llama/Meta-Llama-3.1-70B',
    custom_weights_path='dbfs:/databricks/mlflow-tracking/2948323364469837/d4cd1fcac71b4fb4ae42878cb81d8def/artifacts/ift-meta-llama-3-1-70b-instruct-ohugkq/checkpoints/latest-sharded-rank0.symlink',
    # ... other parameters for your fine-tuning run
)
See Configure a training run for configuring other parameters in your fine-tuning run.
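The two checkpoint path patterns described in this section can be assembled from the experiment ID, run ID, and run name. The following is a minimal sketch: the helper names checkpoint_path and parse_snapshot are hypothetical, and reading "ep" as the epoch and "ba" as the batch is an interpretation of the ep29-ba30/ example rather than a documented guarantee.

```python
import re

CHECKPOINT_ROOT = "dbfs:/databricks/mlflow-tracking"


def checkpoint_path(experiment_id, run_id, run_name, snapshot=None):
    """Build a value for custom_weights_path (hypothetical helper).

    With snapshot=None, point at the latest checkpoint symlink. Otherwise,
    point at a specific snapshot folder such as 'ep29-ba30'.
    """
    base = (
        f"{CHECKPOINT_ROOT}/{experiment_id}/{run_id}"
        f"/artifacts/{run_name}/checkpoints"
    )
    if snapshot is None:
        return f"{base}/latest-sharded-rank0.symlink"
    snapshot = snapshot.rstrip("/")
    if not re.fullmatch(r"ep\d+-ba\d+", snapshot):
        raise ValueError("snapshot is expected to look like 'ep<N>-ba<M>'")
    return f"{base}/{snapshot}"


def parse_snapshot(folder_name):
    """Split a folder name such as 'ep29-ba30/' into (epoch, batch) integers."""
    match = re.fullmatch(r"ep(\d+)-ba(\d+)/?", folder_name)
    if match is None:
        raise ValueError(f"unrecognized checkpoint folder name: {folder_name!r}")
    return int(match.group(1)), int(match.group(2))


# Values taken from the example run above.
latest = checkpoint_path(
    experiment_id="2948323364469837",
    run_id="d4cd1fcac71b4fb4ae42878cb81d8def",
    run_name="ift-meta-llama-3-1-70b-instruct-ohugkq",
)
```

The resulting string can be passed as custom_weights_path. The helper does not check that the checkpoint exists, so confirm the path in the Artifacts tab of the MLflow run.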
Get cluster ID
To retrieve the cluster ID:
In the left nav bar of the Databricks workspace, click Compute.
In the table, click the name of your cluster.
Click the menu icon in the upper-right corner and select View JSON from the drop-down menu.
The Cluster JSON file appears. Copy the cluster ID, which is the first line in the file.
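If you save the cluster JSON to a file or copy it as text, the cluster ID can also be read programmatically. The sketch below assumes the JSON contains a top-level "cluster_id" field. The function name extract_cluster_id and the sample values are made up for illustration.

```python
import json


def extract_cluster_id(cluster_json_text):
    """Return the cluster ID from the text shown by View JSON.

    Assumes a top-level "cluster_id" key; raises KeyError if it is absent.
    """
    cluster = json.loads(cluster_json_text)
    cluster_id = cluster.get("cluster_id")
    if not cluster_id:
        raise KeyError("cluster_id not found in cluster JSON")
    return cluster_id


# Illustrative sample only; real cluster JSON contains many more fields.
sample = '{"cluster_id": "0123-456789-abcdefgh", "cluster_name": "example-cluster"}'
cluster_id = extract_cluster_id(sample)
```

The returned value is what you would supply as the cluster ID parameter described in Configure a training run.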
Get status of a run
You can track the progress of a run using the Experiment page in the Databricks UI or using the API command get_events(). For details, see View, manage, and analyze Foundation Model Fine-tuning runs.
[Figure: Example output from get_events()]
[Figure: Sample run details on the Experiment page]
Next steps
After your training run is complete, you can review metrics in MLflow and deploy your model for inference. See steps 5 through 7 of Tutorial: Create and deploy a Foundation Model Fine-tuning run.
See the Instruction fine-tuning: Named Entity Recognition demo notebook for an instruction fine-tuning example that walks through data preparation, training run configuration, and deployment.