Distributed training of two tower recommendation model using Lightning
This notebook demonstrates how to create a two tower recommendation model using the PyTorch Lightning Trainer API with distributed training across 8 H100 GPUs on a single node.
Reminders:
- For this demo, attach to GPU 8xH100 to leverage distributed training across multiple GPUs.
- The
@distributeddecorator from theserverless_gpuPython library will distribute the PyTorch Lightning training function across 8 H100 GPUs.
To get started, configure your notebook to use serverless GPU:
- Click the Connect dropdown at the top to open the compute selector.
- Select Serverless GPU.
- Open the Environment panel on the right side.
- Select 8xH100 as your accelerator.
- No package dependencies need to be configured in your environment panel to continue. Click Apply, then Confirm.
Your notebook is now connected to serverless GPU compute. The @distributed decorator will handle launching your training across all 8 GPUs.
Prerequisites
Before running through this demo, configure the widget variables at the top of this notebook:
catalog: The Unity Catalog catalog where the trained model will be registered.schema: The Unity Catalog schema from the above catalog where the trained model will be registered.
The dataset is downloaded automatically during the notebook run. The model will be saved in <catalog>.<schema>.<model_name_in_registry>.
Two tower recommendation model
For more insight into the two tower recommendation model, view the following resources:
###Instructions: Below, the code walks you through how to:
- Install dependencies
- Download and prepare dataset
- Required training configurations
- Two tower recommendation model definition
- Creating the main training function
- Training the two tower model
- Perform inference
- Register your model to MLflow for serving
##1) Install dependencies To begin, let's first install all the necessary libraries and ensure our environment is ready to go.
Pip install dependencies
Since TorchRec and the associated dependencies are not officially released and not installed as part of serverless GPU compute, manually install these libraries.
%pip install mlflow==3.7
%pip install -q iopath==0.1.10 pyre_extensions
%pip install -q --upgrade --no-deps --force-reinstall fbgemm-gpu==0.8.0 torchrec==0.8.0 torchmetrics==1.0.3 --index-url https://download.pytorch.org/whl/cu124
%pip install lightning==2.6.1
dbutils.library.restartPython()
####Import packages
There is a suite of imports used throughout the code. The following cell consolidates all necessary imports.
# General Imports
import os
import urllib.request
import zipfile
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
# Data Processing Imports
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
# Databricks Specific Imports
import mlflow
from mlflow import MlflowClient
from mlflow.models.signature import infer_signature
from mlflow.pyfunc import PythonModel
# Torch Specific Imports
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchmetrics.classification import AUROC
# PyTorch Lightning
import lightning.pytorch as pl
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint, DeviceStatsMonitor
from lightning.pytorch.loggers import MLFlowLogger
# TorchRec Specific Imports
from torchrec.datasets.utils import Batch
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.mlp import MLP
from torchrec.optim.keyed import KeyedOptimizerWrapper
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
##2) Download and prepare dataset Download the Learning from Sets dataset, preprocess it, and split into train/validation/test sets.
dbutils.widgets.text("catalog", "main")
dbutils.widgets.text("schema", "default")
dbutils.widgets.text("volume", "recsys")
catalog = dbutils.widgets.get("catalog")
schema = dbutils.widgets.get("schema")
volume = dbutils.widgets.get("volume")
DATASET_URL = "https://files.grouplens.org/datasets/learning-from-sets-2019/learning-from-sets-2019.zip"
DATASET_PATH = f"/Volumes/{catalog}/{schema}/{volume}/dataset"
ZIP_PATH = f"{DATASET_PATH}/learning-from-sets-2019.zip"
CSV_PATH = f"{DATASET_PATH}/learning-from-sets-2019/item_ratings.csv"
# Download and extract
if not os.path.exists(CSV_PATH):
os.makedirs(DATASET_PATH, exist_ok=True)
print("Downloading dataset...")
urllib.request.urlretrieve(DATASET_URL, ZIP_PATH)
with zipfile.ZipFile(ZIP_PATH, "r") as zf:
zf.extractall(DATASET_PATH)
print("Download complete.")
# Load and preprocess
df = pd.read_csv(CSV_PATH)
df = df.sort_values(["userId", "movieId"]).head(100_000)
# Encode userId to contiguous integers
user_encoder = LabelEncoder()
df["userId"] = user_encoder.fit_transform(df["userId"])
# Binarize ratings: 1 if >= mean, else 0
mean_rating = df["rating"].mean()
df["label"] = (df["rating"] >= mean_rating).astype(np.int64)
df = df[["userId", "movieId", "label"]]
# Compute embedding table sizes from data
num_users = int(df["userId"].nunique())
num_movies = int(df["movieId"].nunique())
print(f"Dataset: {len(df)} rows, {num_users} users, {num_movies} movies")
# Split: 70% train, 21% validation, 9% test
train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.33, random_state=42)
print(f"Train: {len(train_df)}, Validation: {len(val_df)}, Test: {len(test_df)}")
class RecDataset(Dataset):
"""Wraps a DataFrame with columns [userId, movieId, label] as a PyTorch Dataset."""
def __init__(self, dataframe: pd.DataFrame):
self.users = dataframe["userId"].values.astype(np.int64)
self.movies = dataframe["movieId"].values.astype(np.int64)
self.labels = dataframe["label"].values.astype(np.int64)
def __len__(self) -> int:
return len(self.labels)
def __getitem__(self, idx: int) -> dict:
return {"userId": self.users[idx], "movieId": self.movies[idx], "label": self.labels[idx]}
def get_dataloader(dataframe: pd.DataFrame, batch_size: int = 1024, shuffle: bool = True) -> DataLoader:
return DataLoader(RecDataset(dataframe), batch_size=batch_size, shuffle=shuffle, num_workers=2, pin_memory=True)
##3) Required training configurations
All the arguments and information required for this training example are consolidated into the following cell. All of these can be modified to suit your use case.
@dataclass
class Args:
epochs: int = 3
embedding_dim: int = 128
layer_sizes: List[int] = field(default_factory=lambda: [128, 64])
learning_rate: float = 0.01
batch_size: int = 1024
cat_cols = ["userId", "movieId"]
emb_counts = [num_users, num_movies] # computed from data in section 2
##4) Two tower recommendation model definition
This section defines the model using PyTorch Lightning. For more information, see the documentation:
class TwoTowerModel(nn.Module):
def __init__(
self,
embedding_bag_collection: EmbeddingBagCollection,
layer_sizes: List[int],
device: Optional[torch.device] = None
) -> None:
super().__init__()
assert len(embedding_bag_collection.embedding_bag_configs()) == 2, "Expected two EmbeddingBags in the two tower model"
assert embedding_bag_collection.embedding_bag_configs()[0].embedding_dim == embedding_bag_collection.embedding_bag_configs()[1].embedding_dim, "Both EmbeddingBagConfigs must have the same dimension"
embedding_dim = embedding_bag_collection.embedding_bag_configs()[0].embedding_dim
self._feature_names_query: List[str] = embedding_bag_collection.embedding_bag_configs()[0].feature_names
self._candidate_feature_names: List[str] = embedding_bag_collection.embedding_bag_configs()[1].feature_names
self.ebc = embedding_bag_collection
self.query_proj = MLP(in_size=embedding_dim, layer_sizes=layer_sizes, device=device)
self.candidate_proj = MLP(in_size=embedding_dim, layer_sizes=layer_sizes, device=device)
def forward(self, kjt: KeyedJaggedTensor) -> Tuple[torch.Tensor, torch.Tensor]:
pooled_embeddings = self.ebc(kjt)
query_embedding: torch.Tensor = self.query_proj(
torch.cat(
[pooled_embeddings[feature] for feature in self._feature_names_query],
dim=1,
)
)
candidate_embedding: torch.Tensor = self.candidate_proj(
torch.cat(
[pooled_embeddings[feature] for feature in self._candidate_feature_names],
dim=1,
)
)
return query_embedding, candidate_embedding
class LitTwoTower(pl.LightningModule):
"""
PyTorch Lightning module wrapping a TwoTowerModel.
Uses torchmetrics AUROC for train/val metrics.
"""
def __init__(
self,
two_tower: nn.Module,
device: torch.device,
emb_counts: Optional[List[int]],
cat_cols: List[str],
lr: float = 1e-3,
) -> None:
super().__init__()
self.two_tower = two_tower
self.loss_fn = nn.BCEWithLogitsLoss()
self.train_auroc = AUROC(task="binary")
self.val_auroc = AUROC(task="binary")
self.lr = lr
# Store metadata used in batch transform
self.emb_counts = emb_counts
self.cat_cols = cat_cols
self.save_hyperparameters(ignore=["two_tower", "device"])
def forward(self, batch: Dict[str, Any]) -> torch.Tensor:
kjt_batch = self._transform_to_torchrec_batch(batch, self.emb_counts)
query_embedding, candidate_embedding = self.two_tower(kjt_batch.sparse_features)
logits = (query_embedding * candidate_embedding).sum(dim=1).squeeze()
return logits
def _loss(self, outputs: torch.Tensor, batch: Dict[str, Any]) -> torch.Tensor:
labels = self._get_batch_labels(batch)
return self.loss_fn(outputs, labels)
def _update_metric(self, batch: Dict[str, Any], outputs: Optional[torch.Tensor], metric: AUROC) -> None:
if outputs is None:
outputs = self.forward(batch)
preds = torch.sigmoid(outputs)
labels = self._get_batch_labels(batch)
metric.update(preds, labels)
def training_step(self, batch: Dict[str, Any], batch_idx: int):
logits = self.forward(batch)
loss = self._loss(logits, batch)
# Metric update
self._update_metric(batch, logits, self.train_auroc)
# Log both step and epoch loss series; enable sync_dist for multi-GPU/DDP
self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
self.log("train_auroc", self.train_auroc, on_step=False, on_epoch=True, prog_bar=True,
logger=True, sync_dist=True)
return loss
def validation_step(self, batch: Dict[str, Any], batch_idx: int):
logits = self.forward(batch)
loss = self._loss(logits, batch)
self._update_metric(batch, logits, self.val_auroc)
# Typically only epoch-level val metrics are needed for monitoring
self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
self.log("val_auroc", self.val_auroc, on_step=False, on_epoch=True, prog_bar=True,
logger=True, sync_dist=True)
def configure_optimizers(self):
optimizer = KeyedOptimizerWrapper(
dict(self.two_tower.named_parameters()),
lambda params: torch.optim.Adam(params, lr=self.lr),
)
return optimizer
def _get_batch_labels(self, batch: Dict[str, Any]) -> torch.Tensor:
return batch["label"].to(dtype=torch.float32, device=self.device)
def _transform_to_torchrec_batch(
self,
batch: Dict[str, Any],
num_embeddings_per_feature: Optional[List[int]],
) -> Batch:
kjt_values_list = []
kjt_lengths_list = []
for col_idx, col_name in enumerate(self.cat_cols):
values = batch[col_name]
num_emb = num_embeddings_per_feature[col_idx]
kjt_values_list.append(values % num_emb)
kjt_lengths_list.append(torch.ones(len(values), dtype=torch.int64))
values_t = torch.cat(kjt_values_list).to(dtype=torch.int64, device=self.device)
lengths_t = torch.cat(kjt_lengths_list).to(device=self.device)
sparse_features = KeyedJaggedTensor.from_lengths_sync(
self.cat_cols,
values_t,
lengths_t,
)
labels = batch["label"].to(dtype=torch.int64, device=self.device)
return Batch(
dense_features=torch.zeros(1, device=self.device),
sparse_features=sparse_features,
labels=labels,
)
def create_two_tower_model(args, device, cat_cols, emb_counts) -> LitTwoTower:
eb_configs = [
EmbeddingBagConfig(
name=f"t_{feature_name}",
embedding_dim=args.embedding_dim,
num_embeddings=emb_counts[feature_idx],
feature_names=[feature_name],
)
for feature_idx, feature_name in enumerate(cat_cols)
]
ebc = EmbeddingBagCollection(tables=eb_configs, device=device)
base = TwoTowerModel(
embedding_bag_collection=ebc, layer_sizes=args.layer_sizes, device=device
)
lit = LitTwoTower(
base, cat_cols=cat_cols, emb_counts=emb_counts, device=device, lr=args.learning_rate
)
return lit
##5) Create the main training function
Next, use the @distributed decorator from serverless_gpu library together with the helper functions and the Trainer API by PyTorch Lightning to launch training on multiple GPUs.
# setup mlflow experiment
username = spark.sql("SELECT current_user()").first()['current_user()']
experiment_path = f'/Users/{username}/sgc-torchrec-example'
experiment = mlflow.set_experiment(experiment_path)
os.environ["MLFLOW_EXPERIMENT_NAME"] = experiment_path
from serverless_gpu import distributed
# get DB Host and Token
db_host = f"https://{dbutils.notebook.entry_point.getDbutils().notebook().getContext().browserHostName().get()}/"
db_token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
# setup arguments for training function
args = Args(epochs=1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT_PATH = f"/Volumes/{catalog}/{schema}/{volume}/checkpoints"
@distributed(gpus=8, gpu_type="H100")
def training_function(args=args, cat_cols=cat_cols, emb_counts=emb_counts, device=device,
train_data=train_df, val_data=val_df, checkpoint_path=CHECKPOINT_PATH):
mlflow.pytorch.autolog()
model = create_two_tower_model(args, device=device, cat_cols=cat_cols, emb_counts=emb_counts)
train_dataloader = get_dataloader(train_data, batch_size=args.batch_size, shuffle=True)
eval_dataloader = get_dataloader(val_data, batch_size=args.batch_size, shuffle=False)
mlflow_logger = MLFlowLogger(
experiment_name=experiment_path,
log_model="all",
)
ckpt_cb = ModelCheckpoint(
dirpath=checkpoint_path,
monitor="val_auroc",
mode="max",
save_top_k=1,
save_last=True, # enables last_model_path
filename="{epoch}-{val_auroc:.4f}",
)
callbacks = [
LearningRateMonitor(logging_interval="step"),
DeviceStatsMonitor(),
ckpt_cb,
]
trainer = Trainer(
max_epochs=args.epochs,
accelerator="gpu",
strategy="ddp",
devices=8,
log_every_n_steps=20,
logger=mlflow_logger,
callbacks=callbacks,
)
trainer.fit(
model,
train_dataloaders=train_dataloader,
val_dataloaders=eval_dataloader
)
# Return run_id and best checkpoint path
result = {
"run_id": trainer.logger.run_id, # MLflow run id
"best_model_checkpoint": ckpt_cb.best_model_path, # best checkpoint path
"last_model_checkpoint": ckpt_cb.last_model_path # last checkpoint path
}
return result
##6) Train the two tower model using the serverless GPU distributed training API
result = training_function.distributed()
##7) Test the best model checkpoint
Retrieve the best model checkpoint and run test to verify results
print(f"Experiment Name: {experiment.name}")
print(f"Experiment ID: {experiment.experiment_id}")
print(f"Artifact Location: {experiment.artifact_location}")
print(f"Lifecycle_stage: {experiment.lifecycle_stage}")
ranked_checkpoints = mlflow.search_logged_models(
experiment_ids=[experiment.experiment_id],
output_format="list",
order_by=[{"field_name": "metrics.accuracy", "ascending": False}]
)
best_checkpoint: mlflow.entities.LoggedModel = ranked_checkpoints[0]
print(best_checkpoint.metrics[0])
run_id = best_checkpoint.source_run_id
artifact_path = best_checkpoint.artifact_location
model_uri = f"runs:/{run_id}/{artifact_path}"
two_tower_model = mlflow.pytorch.load_model(model_uri)
num_batches = 5 # Number of batches to print out at a time
batch_size = 1 # Print out each individual row
test_dataloader = iter(get_dataloader(test_df, batch_size=batch_size, shuffle=False))
device = torch.device("cuda:0")
two_tower_model.to(device)
two_tower_model.eval()
for _ in range(num_batches):
next_batch = next(test_dataloader)
expected_result = next_batch["label"][0]
actual_result = two_tower_model(next_batch)
actual_result = torch.sigmoid(actual_result)
print(f"Expected Result: {expected_result}; Actual Result: {actual_result.round().item()}")
##8) Register your model to MLflow for serving
When the model in the previous step looks correct, use the corresponding run_id from the latest run to register the model. To make this easy to serve, create a PyFunc that wraps the two tower model to take in a simpler input: (Dict[str, List] -> List[float]).
class TwoTowerWrapper(PythonModel):
"""
MLflow PythonModel wrapper for TwoTower model that handles dictionary input and returns list outputs
"""
def __init__(self, two_tower_model):
self.two_tower_model = two_tower_model
def predict(self, model_input: Dict[str, List]) -> List[float]:
batch = {key: torch.tensor(value) for key, value in model_input.items()}
if "label" not in batch:
batch["label"] = torch.zeros(len(next(iter(batch.values()))))
with torch.no_grad():
output = self.two_tower_model(batch).cpu()
output = torch.sigmoid(output)
return output.tolist()
def preprocess_data(batch):
# turn the example test dataset from Dict[str, Tensor] to Dict[str, List] and remove the label
return {key: tensor.tolist() for key, tensor in batch.items() if key != "label"}
def add_and_get_model_signature(two_tower_model, test_dataloader):
current_batch = preprocess_data(next(test_dataloader))
pyfunc_two_tower_model = TwoTowerWrapper(two_tower_model)
current_output = pyfunc_two_tower_model.predict(current_batch)
signature = infer_signature(current_batch, current_output)
logged_model = mlflow.pyfunc.log_model(
artifact_path="two_tower_pyfunc",
python_model=pyfunc_two_tower_model,
signature=signature,
input_example=current_batch
)
return signature, logged_model
signature, logged_model = add_and_get_model_signature(two_tower_model, test_dataloader)
model_name = "two_tower_model"
uc_model_version = mlflow.register_model(
f"models:/{logged_model.model_id}",
name=f"{catalog}.{schema}.{model_name}"
)