Fine-tune an embedding model using contrastive learning
This notebook demonstrates how to fine-tune a BERT-style embedding model on serverless GPU compute using contrastive learning. You'll fine-tune the gte-large-en-v1.5 model on a single A10G GPU with the MosaicML Composer trainer, which saves checkpoints, supports resuming training, and logs results to MLflow.
Embedding models are widely used for vector databases and retrieval augmented generation (RAG) applications. Fine-tuning an embedding model on your custom data is a powerful way to improve retrieval accuracy for your specific domain.
In this notebook, you'll learn how to:
- Install dependencies and configure your environment
- Load training data from Delta tables
- Convert data to Mosaic Streaming Dataset (MDS) format
- Configure the model, optimizer, and training parameters
- Train the model using contrastive learning with in-batch negatives
- Track experiments with MLflow and register models to Unity Catalog
- Review performance and serve the fine-tuned model
This example uses a preprocessed version of the MS Marco dataset, but you can adapt it to work with your own data.
Contrastive learning overview
Contrastive learning uses contrastive loss to learn data representations where similar instances are closer together in the latent space. For embedding models, this means treating queries like "what is an affenpinscher" and "cute dachshund" as more similar than "what is an affenpinscher" and "how do I use dbsql". The model compares a query to positive passages (relevant texts) and negative passages (irrelevant texts) to learn semantic differences.
There are two approaches to selecting negative passages:
- In-batch negatives: Negative passages are selected randomly from within the batch. For a given query-passage pair, all other passages in the batch become negative examples. With a batch size of 8, you get 7 negative passages and 1 positive passage per query. Larger batch sizes provide more negative examples, making this approach more effective.
- Hard negatives: Pre-defined negative passages that are semantically challenging: they're potentially related to the query but slightly incorrect or irrelevant. The llm-foundry code supports hard negatives for more advanced fine-tuning.
This notebook uses in-batch negatives. If no negative passages are provided in your data, the llm-foundry code automatically infers them by treating positive passages from other queries as negatives.
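To make the in-batch negatives idea concrete, here's a minimal pure-Python sketch of an InfoNCE-style contrastive loss. The function name and toy similarity values are illustrative only; llm-foundry's actual loss is a vectorized PyTorch implementation.

```python
import math

def in_batch_contrastive_loss(sim_matrix, temperature=0.5):
    """InfoNCE-style loss with in-batch negatives (illustrative sketch).

    sim_matrix[i][j] is the similarity between query i and passage j.
    The diagonal holds each query's positive passage; every other
    passage in the batch acts as a negative.
    """
    batch_size = len(sim_matrix)
    total = 0.0
    for i in range(batch_size):
        scaled = [s / temperature for s in sim_matrix[i]]
        # Cross-entropy with the positive (diagonal) entry as the target.
        log_denom = math.log(sum(math.exp(s) for s in scaled))
        total += log_denom - scaled[i]
    return total / batch_size

# Toy batch of 3 query-passage pairs: diagonal similarities are highest,
# so the model already separates positives from in-batch negatives.
sims = [
    [0.9, 0.1, 0.2],
    [0.0, 0.8, 0.1],
    [0.2, 0.1, 0.7],
]
loss = in_batch_contrastive_loss(sims, temperature=0.5)
```

Note that the loss drops as the diagonal (positive) similarities grow relative to the off-diagonal (negative) ones, which is exactly the behavior contrastive training optimizes for.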
Requirements
Before running this notebook, you need to configure several parameters and connect to serverless GPU compute.
Configure notebook parameters
This notebook uses query parameters (widgets) to configure paths and settings. Update the following parameters before running:
- catalog: Unity Catalog catalog name (e.g., main)
- schema: Unity Catalog schema name
- train_delta_table: Training Delta table name (without catalog/schema prefix)
- val_delta_table: Validation Delta table name (optional)
- uc_checkpoint_folder: Unity Catalog volume folder name for checkpoints
- register_to: Model name for Unity Catalog registration
- experiment_name: MLflow experiment path (format: /Users/<username>/<run_name>)
Connect to serverless GPU compute
This notebook requires a single A10G GPU:
- Click the Connect dropdown at the top of the notebook.
- Select Serverless GPU.
- Open the Environment side panel on the right side of the notebook.
- Set Accelerator to A10 for this demo.
- Select Apply and click Confirm to apply this environment to your notebook.
For more information, see Serverless GPU compute.
Install dependencies
To begin, install the necessary libraries and restart Python so the new package versions take effect.
%pip install llm-foundry[gpu]==0.20.0
%pip uninstall flash_attn -y
%pip install transformers==4.46.0
%pip install hf_transfer
%restart_python
Set up environment variables
Configure environment variables for distributed training and temporary file storage.
import os
import tempfile
import mlflow
os.environ["TMPDIR"] = os.path.join(os.getcwd(), tempfile.mkdtemp())
os.environ["NCCL_DEBUG"] = "WARN"
os.environ["WORLD_SIZE"] = "1"
Create configuration widgets
Create input widgets for the notebook parameters. Fill in these values in the widget panel at the top of the notebook before proceeding.
# Create widgets for configuration
dbutils.widgets.text(
"train_delta_table", "ms_marco_v_1_1_train_processed", "Training Delta Table Name"
)
dbutils.widgets.text(
"val_delta_table", "ms_marco_v_1_1_val_processed", "Validation Delta Table Name"
)
dbutils.widgets.text("register_to", "sgc_ft_embedding", "Model Registry Path")
dbutils.widgets.text("experiment_name", "/Users/<EMAIL>/Embedding_finetuning", "MLflow Experiment Name")
dbutils.widgets.text("uc_checkpoint_folder", "checkpoints", "UC Checkpoint Folder")
dbutils.widgets.text("catalog", "main", "catalog")
dbutils.widgets.text("schema", "default", "schema")
# Validate widget inputs
assert dbutils.widgets.get("train_delta_table")
assert dbutils.widgets.get("register_to")
assert dbutils.widgets.get("experiment_name")
assert dbutils.widgets.get("catalog")
assert dbutils.widgets.get("schema")
assert dbutils.widgets.get("uc_checkpoint_folder")
# Build env paths
catalog = dbutils.widgets.get("catalog")
schema = dbutils.widgets.get("schema")
train_delta_table = dbutils.widgets.get("train_delta_table")
val_delta_table = dbutils.widgets.get("val_delta_table")
uc_checkpoint_folder = dbutils.widgets.get("uc_checkpoint_folder")
register_to = dbutils.widgets.get("register_to")
experiment_name = dbutils.widgets.get("experiment_name")
train_delta_table = f"{catalog}.{schema}.{train_delta_table}"
val_delta_table = f"{catalog}.{schema}.{val_delta_table}" if val_delta_table else None
uc_checkpoint_path = f"{catalog}.{schema}.{uc_checkpoint_folder}"
register_to_path = f"{catalog}.{schema}.{register_to}"
Load training data from Delta tables
Load your training and validation data from Delta tables. The data must have the following columns:
- query_text: The query or question text
- positive_passage: Relevant text passage for the query
- negative_passages (optional): Array of pre-defined hard negative passages
This example uses a preprocessed version of the MS Marco dataset. If you don't provide negative_passages, the training code will automatically use in-batch negatives.
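As a quick sanity check on the schema, here's a hypothetical example of what a row should contain. The values and the sample_rows list are purely illustrative; in practice the data lives in your Delta table.

```python
# Hypothetical example rows matching the expected schema. In practice
# these come from your Delta table, not a Python list.
sample_rows = [
    {
        "query_text": "what is an affenpinscher",
        "positive_passage": "The Affenpinscher is a terrier-like toy breed of dog.",
        # Optional: pre-defined hard negatives. Omit this column (or
        # leave it empty) to fall back to in-batch negatives.
        "negative_passages": ["Dachshunds are a badger-hunting breed."],
    },
]

# Verify that the two required columns are present in every row.
required = {"query_text", "positive_passage"}
for row in sample_rows:
    missing = required - row.keys()
    assert not missing, f"row missing columns: {missing}"
```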
df_train = spark.table(train_delta_table)
df_val = None
if val_delta_table:
df_val = spark.table(val_delta_table)
MODEL_REGISTRY_PREFIX = f"{catalog}.{schema}"
REGISTERED_MODEL_NAME = register_to
EXPERIMENT_NAME = experiment_name
UC_CHECKPOINT_PATH = f"{catalog}.{schema}.{uc_checkpoint_folder}"
Configure Databricks credentials
Set up Databricks CLI credentials for accessing Unity Catalog and MLflow during training. This configuration file enables the training process to authenticate with Databricks services.
%sh echo -e "[DEFAULT]\nhost=$DATABRICKS_HOST\ntoken=$DATABRICKS_TOKEN" > ~/.databrickscfg
Convert data to Mosaic Streaming Dataset format
Convert your Delta tables to Mosaic Streaming Dataset (MDS) format, which is optimized for distributed training on serverless GPU compute. MDS provides:
- Faster data loading during training
- Efficient compression and storage
- Seamless integration with the Composer trainer
The conversion function handles the required schema transformation and saves the MDS files to a Unity Catalog volume. If you have different column names in your data, update the convert_x function accordingly.
For more information, see StreamingDataset documentation.
import os
import gc
from streaming import MDSWriter, StreamingDataset
from pyspark.sql import DataFrame
import json
import warnings
warnings.filterwarnings("ignore", module="threadpoolctl")
def process_embedding_data(
df: DataFrame,
output_path: str,
compression: str,
hashes: list[str],
limit: str,
):
def convert_x(x: dict) -> dict:
return {
"query_text": x["query_text"],
"positive_passage": x["positive_passage"],
"negative_passages": (
json.dumps(x["negative_passages"])
if x.get("negative_passages") is not None
else "[]"
),
}
try:
dtypes = {
"query_text": "str",
"positive_passage": "str",
"negative_passages": "str",
}
print(f"Starting conversion to MDS at {output_path}")
# Clear memory before processing
gc.collect()
row_count = 0
with MDSWriter(
out=output_path,
columns=dtypes,
compression=compression,
hashes=hashes,
size_limit=limit,
) as out:
for row in df.toLocalIterator():
record = convert_x(row.asDict())
out.write(record)
row_count += 1
if row_count % 10000 == 0:
print(f"Processed {row_count} records...")
print(f"Successfully wrote {row_count} records to {output_path}")
except Exception as e:
print(f"Error during data conversion: {e}")
raise e
compression = "zstd:7"
hashes = ["sha1"]
limit = "10mb"
import os
# FUSE path for checking existence
uc_train_folder = f"/Volumes/{catalog}/{schema}/embedding_temp_data/train"
uc_val_folder = f"/Volumes/{catalog}/{schema}/embedding_temp_data/val"
# dbfs path for MDSWriter and StreamingDataset
train_folder = f"dbfs:/Volumes/{catalog}/{schema}/embedding_temp_data/train"
val_folder = f"dbfs:/Volumes/{catalog}/{schema}/embedding_temp_data/val"
train_index = os.path.join(uc_train_folder, "index.json")
val_index = os.path.join(uc_val_folder, "index.json")
if os.path.exists(train_index):
print("Train MDS data already exists, skipping conversion.")
else:
process_embedding_data(df_train, train_folder, compression, hashes, limit)
if df_val is not None:
if os.path.exists(val_index):
print("Validation MDS data already exists, skipping conversion.")
else:
process_embedding_data(df_val, val_folder, compression, hashes, limit)
Configure the embedding model
Define the model and tokenizer configuration for fine-tuning. This example uses the gte-large-en-v1.5 model from Hugging Face.
Key configuration parameters:
- temperature: Hyperparameter (0-1) that scales similarity scores in contrastive loss. Tune this if loss values are extremely high or low (default: 0.5)
- pos_step_size: Position step size for negative sampling. Set to 2 for in-batch negatives, or 1 + number of hard negatives if using pre-defined negatives
- vector_representation: How to represent embeddings
  - avg: Average token embeddings (recommended for most models)
  - eos: Use the end-of-sequence token embedding
- gather_in_batch_negatives: Set to true for in-batch negatives, false for pre-defined hard negatives
- pretrained_model_name_or_path: Hugging Face model identifier
The tokenizer configuration sets the maximum sequence length and special tokens for the model.
model_cfg = {
"name": "finetune_embedding_model",
"trust_remote_code": True,
"contrastive_config": {
"temperature": 0.5,
"pos_step_size": 2, # set to 2 when not using predefined hard negatives. Otherwise use 1 + number of hard negatives
"normalize_output": True,
"vector_representation": "avg", # or eos, depending on the model default
"gather_in_batch_negatives": True, # set to true when not using predefined hard negatives
},
"pretrained_model_name_or_path": "Alibaba-NLP/gte-large-en-v1.5",
"loss_fn": "torch_crossentropy"
}
tokenizer_cfg = {
"name": "Alibaba-NLP/gte-large-en-v1.5",
"kwargs": {
"eos_token": "</s>", # this is the standard eos token for gte-large-en-v1.5
"model_max_length": 128,
"trust_remote_code": True,
},
}
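To build intuition for the temperature hyperparameter, the standalone sketch below shows how dividing similarity scores by the temperature before the softmax sharpens or flattens the distribution over the positive and in-batch negative passages. This is an illustration only, not llm-foundry's implementation.

```python
import math

def softmax(scores, temperature):
    # Divide by temperature before normalizing; smaller values sharpen
    # the distribution, larger values flatten it toward uniform.
    scaled = [s / temperature for s in scores]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

sims = [0.9, 0.1, 0.2]  # one positive, two in-batch negatives
sharp = softmax(sims, temperature=0.1)   # near one-hot on the positive
default = softmax(sims, temperature=0.5)
flat = softmax(sims, temperature=10.0)   # close to uniform
```

If training loss is extremely high or low, adjusting the temperature rescales these probabilities and, through the cross-entropy loss, the gradient signal.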
Configure MLflow logging
Set up MLflow logging to track training metrics, parameters, and artifacts. The logger configuration specifies:
- The MLflow experiment where runs will be logged
- Unity Catalog as the model registry destination
- The catalog and schema prefix for model registration
logger_cfg = {
"mlflow": {
"run_name": "finetune_embedding",
"tracking_uri": "databricks",
"experiment_name": EXPERIMENT_NAME,
"model_registry_uri": "databricks-uc",
"model_registry_prefix": MODEL_REGISTRY_PREFIX,
}
}
Configure training callbacks
Callbacks control various aspects of the training process. The most important callback is hf_checkpointer, which:
- Saves Hugging Face-compatible checkpoints to Unity Catalog
- Registers the model in Unity Catalog for serving
- Configures model metadata for provisioned throughput serving
- Saves checkpoints at regular intervals (every 1 hour in this example)
Other callbacks monitor learning rate, memory usage, and perform garbage collection to optimize GPU memory.
callback_cfg = {
"lr_monitor": {},
"scheduled_gc": {"batch_interval": 1000},
"memory_monitor": {},
"hf_checkpointer": {
"precision": "bfloat16",
"save_folder": UC_CHECKPOINT_PATH,
"save_interval": "1h",
"mlflow_logging_config": {
"task": "llm/v1/embeddings",
"metadata": {
"task": "llm/v1/embeddings",
"source": "huggingface",
"pretrained_model_name": "Alibaba-NLP/gte-large-en-v1.5",
"databricks_model_family": "NewModel (gte_v1_5)",
"databricks_model_size_parameters": "434m",
},
},
"mlflow_registered_model_name": REGISTERED_MODEL_NAME,
},
}
Configure training hyperparameters
Define the optimizer, learning rate scheduler, precision, and training algorithms. These are standard machine learning training parameters that you can adjust based on your data and requirements:
- Optimizer: AdamW with decoupled weight decay
- Learning rate: 3e-5 with cosine warmup schedule
- Precision: Automatic mixed precision with bfloat16 for faster training
- Gradient clipping: Prevents exploding gradients during training
optimizer_cfg = {
"lr": 0.00003,
"eps": 1.0e-08,
"name": "decoupled_adamw",
"betas": [0.9, 0.95],
"weight_decay": 0.0001,
}
precision_cfg = "amp_bf16"
scheduler_cfg = {"name": "cosine_with_warmup", "alpha_f": 0.02, "t_warmup": "0.06dur"}
algorithms_cfg = {
"gradient_clipping": {"clipping_type": "norm", "clipping_threshold": 1}
}
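As a rough illustration of the schedule's shape, the sketch below approximates a cosine-with-warmup schedule with warmup over the first 6% of training (t_warmup = "0.06dur") and a final learning rate of alpha_f = 0.02 times the peak. Composer's implementation is authoritative; this is only to visualize the curve.

```python
import math

def cosine_with_warmup(step, total_steps, peak_lr=3e-5,
                       warmup_frac=0.06, alpha_f=0.02):
    """Approximate sketch of a cosine-with-warmup schedule: linear
    warmup for the first warmup_frac of training, then cosine decay
    from peak_lr down to alpha_f * peak_lr."""
    warmup_steps = int(total_steps * warmup_frac)
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    cosine = 0.5 * (1 + math.cos(math.pi * progress))
    return peak_lr * (alpha_f + (1 - alpha_f) * cosine)

total = 200  # matches the "200ba" max_duration used later
lrs = [cosine_with_warmup(s, total) for s in range(total + 1)]
```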
Configure data loaders
Define how the training and evaluation data will be loaded during training. The data loaders:
- Point to the MDS-formatted data in Unity Catalog volumes
- Configure text preprocessing (prepending "query: " and "passage: " prefixes)
- Set the maximum sequence length and batch handling
- Enable shuffling for better training convergence
Ensure the remote path points to your converted MDS data location.
train_loader = {
"name": "contrastive_pairs",
"dataset": {
"local": None,
"split": None,
"remote": train_folder,
"shuffle": True,
"max_seq_len": 128,
"shuffle_seed": 42,
"prepend_query": "query: ",
"prepend_passage": "passage: ",
"append_eos_token": True,
},
"drop_last": True,
"num_workers": 8,
}
eval_loader = {
"name": "contrastive_pairs",
"dataset": {
"local": None,
"split": None,
"remote": val_folder,
"shuffle": True,
"max_seq_len": 128,
"shuffle_seed": 42,
"prepend_query": "query: ",
"prepend_passage": "passage: ",
"append_eos_token": True,
},
"drop_last": True,
"num_workers": 8,
}
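To see what the prepend_query, prepend_passage, and append_eos_token options do to each text before tokenization, here's an approximate standalone sketch. format_pair is a hypothetical helper for illustration, not part of the loader's API.

```python
def format_pair(query, passage, prepend_query="query: ",
                prepend_passage="passage: ", eos_token="</s>",
                append_eos_token=True):
    """Approximate what the contrastive_pairs loader does to each text:
    prepend the role prefix and, optionally, append the EOS token."""
    q = prepend_query + query
    p = prepend_passage + passage
    if append_eos_token:
        q += eos_token
        p += eos_token
    return q, p

q, p = format_pair("what is an affenpinscher",
                   "The Affenpinscher is a toy breed of dog.")
```

The prefixes let the model distinguish queries from passages at embedding time, so the same prefixes should be applied when you later embed text with the fine-tuned model.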
Assemble the complete training configuration
Combine all configuration components into a single training configuration object. This includes the model, tokenizer, data loaders, optimizer, callbacks, and training parameters.
from omegaconf import DictConfig, OmegaConf
config = {
"seed": 42,
"max_seq_len": 128,
"model": model_cfg,
"tokenizer": tokenizer_cfg,
"loggers": logger_cfg,
"callbacks": callback_cfg,
"run_name": "finetune-BERT",
"optimizer": optimizer_cfg,
"precision": precision_cfg,
"scheduler": scheduler_cfg,
"algorithms": algorithms_cfg,
"train_loader": train_loader,
"eval_loader": eval_loader,
"eval_first": True,
"save_folder": 'dbfs:/databricks/mlflow-tracking/{mlflow_experiment_id}/{mlflow_run_id}/artifacts/{run_name}/checkpoints',
"max_duration": "200ba",
"progress_bar": False,
"eval_interval": "50ba",
"save_interval": "50ba",
"log_to_console": True,
"load_weights_only": True,
"console_log_interval": "1ba",
"device_eval_batch_size": 1,
"eval_subset_num_batches": 4,
"global_train_batch_size": 1,
"device_train_microbatch_size": 1
}
cfg = DictConfig(config)
Train the embedding model
Run the training process using the MosaicML LLM-Foundry library, which provides optimized training for embedding models. The training function:
- Uses the @distributed decorator to provision GPU resources
- Trains for 200 batches with evaluation every 50 batches
- Saves checkpoints to Unity Catalog
- Logs metrics and artifacts to MLflow
- Registers the final model to Unity Catalog
Training may take several minutes depending on your dataset size. You can monitor progress in the MLflow experiment.
For more information, see LLM-Foundry on GitHub.
from serverless_gpu import distributed
from llmfoundry.command_utils.train import train
from omegaconf import DictConfig
import mlflow
import torch
@distributed(gpus=1, gpu_type='a10', remote=True)
def run_training():
mlflow.end_run()
trainer = train(cfg)
run = mlflow.active_run()
mlflow_run_id = run.info.run_id
run_name = trainer.state.run_name
del trainer
mlflow.end_run()
return mlflow_run_id, run_name
# Run training
result = run_training.distributed()
mlflow_run_id, run_name = result[0]
print(f"Run ID: {mlflow_run_id}")
# Download and load checkpoint
checkpoint_artifact_path = f"{run_name}/checkpoints/ep0-ba200-rank0.pt"
print(f"Downloading: {checkpoint_artifact_path}")
local_path = mlflow.artifacts.download_artifacts(
artifact_uri=f"runs:/{mlflow_run_id}/{checkpoint_artifact_path}"
)
print(f"Downloaded to: {local_path}")
ckpt = torch.load(local_path, map_location="cpu", weights_only=False)
print("\nTop-level keys:")
print(list(ckpt.keys()))
if "state" in ckpt:
print("\nState keys:")
print(list(ckpt["state"].keys()))
if "model" in ckpt["state"]:
print(f"\nModel state dict: {len(ckpt['state']['model'])} keys")
for k in list(ckpt["state"]["model"].keys())[:10]:
print(f" {k}")
Review training results and serve the model
After training completes, review the results and deploy your fine-tuned model:
Review training metrics:
- Navigate to the MLflow experiment specified in experiment_name. You can also find experiments on the Experiments page in the workspace UI.
- Select your training run to view:
- Training and evaluation metrics in the Metrics tab
- Model parameters in the Parameters tab
- Checkpoints and artifacts in the Artifacts tab
- The Model Details tab shows the registered model in Unity Catalog
Serve the model:
- Navigate to the registered model at the path specified in register_to
- Select the latest model version
- Click Serve this model to deploy using provisioned throughput
- Configure the serving endpoint with your desired throughput settings
Next steps
Now that you've fine-tuned an embedding model using contrastive learning, explore these resources to learn more:
- Serverless GPU compute - Learn about serverless GPU features and capabilities
- Best practices for serverless GPU compute - Optimize your GPU workloads
- Foundation Model APIs - Deploy and serve models with provisioned throughput
- LLM-Foundry documentation - Explore advanced training features and configurations
- StreamingDataset documentation - Learn more about MDS format and optimization
- Unity Catalog model registry - Manage and version your models