Two Tower (using TorchRec + TorchDistributor + StreamingDataset)

This notebook illustrates how to create a distributed Two Tower recommendation model. This notebook was tested on g4dn.12xlarge instances (one instance as the driver, one instance as the worker) using the Databricks Runtime for ML 14.3 LTS. For more background on the Two Tower architecture, see the TorchRec documentation and the examples referenced throughout this notebook.

Note: Where you see # TODO in this notebook, you must enter custom code to ensure that the notebook runs successfully.

Requirements

This notebook requires Databricks Runtime 14.3 LTS ML.

1. Saving "Learning From Sets of Items" Data in UC Volumes in the MDS (Mosaic Data Shard) format

This notebook uses a small sample of 100k ratings from "Learning From Sets of Items". In this section, you preprocess the data and save it to a Volume in Unity Catalog.

%pip install -q mosaicml-streaming==0.7.5
dbutils.library.restartPython()

1.1. Downloading the Dataset

Download the dataset from https://grouplens.org/datasets/learning-from-sets-of-items-2019/ to /databricks/driver and then save the data to a Unity Catalog table. The "Learning from Sets of Items" dataset is released under the Creative Commons 4.0 license.

%sh wget https://files.grouplens.org/datasets/learning-from-sets-2019/learning-from-sets-2019.zip -O /databricks/driver/learning-from-sets-2019.zip && unzip /databricks/driver/learning-from-sets-2019.zip -d /databricks/driver/
import pandas as pd

# Load the CSV file into a pandas DataFrame (the data was downloaded to the driver's local disk)
df = pd.read_csv("/databricks/driver/learning-from-sets-2019/item_ratings.csv")

# Create a Spark DataFrame from the pandas DataFrame and save it to UC
spark_df = spark.createDataFrame(df)

# TODO: Update this with a path in UC for where this data should be saved
spark_df.write.saveAsTable("ml.recommender_systems.learning_from_sets_dataset")

1.2. Reading the Dataset from UC

The original dataset contains 500k data points. This example uses a sample of 100k data points from the dataset.

# TODO: Update this with the path in UC where this data is saved
spark_df = spark.table("ml.recommender_systems.learning_from_sets_dataset")
print(f"Dataset size: {spark_df.count()}")
display(spark_df)
# Order by userId and movieId before limiting so the 100k sample is deterministic and covers a consistent set of userIds and movieIds
ordered_df = spark_df.orderBy("userId", "movieId").limit(100_000)

print(f"Updated Dataset Size: {ordered_df.count()}")
# Show the result
display(ordered_df)
from pyspark.sql.functions import countDistinct

# Get the total number of data points
print("Total # of data points:", ordered_df.count())

# Get the total number of users
total_users = ordered_df.select(countDistinct("userId")).collect()[0][0]
print(f"Total # of users: {total_users}")

# Get the total number of movies
total_movies = ordered_df.select(countDistinct("movieId")).collect()[0][0]
print(f"Total # of movies: {total_movies}")

1.3. Preprocessing and Cleaning the Data

The first step is to convert the hashes (in string format) of each user to an integer using the StringIndexer.

The Two Tower model provided by TorchRec requires a binary label. The code in this section converts ratings below the mean to 0 and ratings at or above the mean to 1. For your own use case, you can modify the training task to use MSELoss instead, which supports the full 0-5 rating scale (a sketch of this variant appears after the train task definition in Section 3.1).

from pyspark.ml.feature import StringIndexer
from pyspark.sql.types import LongType

string_indexer = StringIndexer(inputCol="userId", outputCol="userId_index")
indexed_df = string_indexer.fit(ordered_df).transform(ordered_df)
indexed_df = indexed_df.withColumn("userId_index", indexed_df["userId_index"].cast(LongType()))
indexed_df = indexed_df.withColumn("userId", indexed_df["userId_index"]).drop("userId_index")

display(indexed_df)
from pyspark.sql import functions as F

# Select only the userId, movieId, and ratings columns
relevant_df = indexed_df.select('userId', 'movieId', 'rating')

# Calculate the mean of the ratings column
ratings_mean = relevant_df.groupBy().avg('rating').collect()[0][0]
print(f"Mean rating: {ratings_mean}")

# Map ratings below the mean to 0 and ratings at or above the mean to 1, applying the transformation with a UDF
modify_rating_udf = F.udf(lambda x: 0 if x < ratings_mean else 1, 'int')
relevant_df = relevant_df.withColumn('rating', modify_rating_udf('rating'))

# Rename rating to label
relevant_df = relevant_df.withColumnRenamed('rating', 'label')

# Displaying the dataframe
display(relevant_df)

1.4. Saving to MDS Format within UC Volumes

In this step, you convert the data to MDS to allow for efficient train/validation/test splitting and then save it to a UC Volume.

For more details, see the Mosaic Streaming documentation:

  1. General details: https://docs.mosaicml.com/projects/streaming/en/stable/
  2. Main concepts: https://docs.mosaicml.com/projects/streaming/en/stable/getting_started/main_concepts.html#dataset-conversion
  3. dataframeToMDS details: https://docs.mosaicml.com/projects/streaming/en/stable/preparing_datasets/spark_dataframe_to_mds.html
# Split the dataframe into train, validation, and test sets
train_df, validation_df, test_df = relevant_df.randomSplit([0.7, 0.2, 0.1], seed=42)

# Show the count of each split to verify the distribution
print(f"Training Dataset Count: {train_df.count()}")
print(f"Validation Dataset Count: {validation_df.count()}")
print(f"Test Dataset Count: {test_df.count()}")
from streaming import StreamingDataset
from streaming.base.converters import dataframe_to_mds
from streaming.base import MDSWriter
from shutil import rmtree
import os
from tqdm import tqdm

# Parameters required for saving data in MDS format
cols = ["userId", "movieId"]
cat_dict = { key: 'int64' for key in cols }
label_dict = { 'label' : 'int' }
columns = {**label_dict, **cat_dict}

compression = 'zstd:7'

# TODO: Specify where the data will be stored
output_dir_train = "/Volumes/ml/recommender_systems/learning_from_sets_data/mds_train"
output_dir_validation = "/Volumes/ml/recommender_systems/learning_from_sets_data/mds_validation"
output_dir_test = "/Volumes/ml/recommender_systems/learning_from_sets_data/mds_test"

# Save the training data using the `dataframe_to_mds` function, which divides the dataframe into `num_workers` parts and merges the `index.json` from each part into one in a parent directory.
def save_data(df, output_path, label, num_workers=40):
    print(f"Saving {label} data to: {output_path}")
    mds_kwargs = {'out': output_path, 'columns': columns, 'compression': compression}
    dataframe_to_mds(df.repartition(num_workers), merge_index=True, mds_kwargs=mds_kwargs)

save_data(train_df, output_dir_train, 'train')
save_data(validation_df, output_dir_validation, 'validation')
save_data(test_df, output_dir_test, 'test')
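
As an optional sanity check (a sketch; it assumes dbutils can list the Volume path used above), you can list the train output directory to confirm that the MDS shards and the merged index.json were written:

# List the train output directory; expect shard subdirectories plus a merged index.json.
display(dbutils.fs.ls(output_dir_train))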

2. Helper Functions for Recommendation Dataloading

In this section, you install the necessary libraries, add imports, and add some relevant helper functions to train the model.

2.1. Installs and Imports

%pip install -q --upgrade --no-deps --force-reinstall torch==2.2.2 torchvision==0.17.2 torchrec==0.6.0 fbgemm-gpu==0.6.0 --index-url https://download.pytorch.org/whl/cu118
%pip install torchmetrics==1.0.3 iopath==0.1.10 pyre_extensions==0.0.32 mosaicml-streaming==0.7.5 
dbutils.library.restartPython()
import os
from typing import List, Optional
from streaming import StreamingDataset, StreamingDataLoader

import torch
import torchmetrics as metrics
from torch import distributed as dist
from torch.distributed.optim import (
    _apply_optimizer_in_backward as apply_optimizer_in_backward,
)
from torch.utils.data import DataLoader
from torch import nn
from torchrec import inference as trec_infer
from torchrec.distributed import TrainPipelineSparseDist
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.inference.state_dict_transform import (
    state_dict_gather,
    state_dict_to_device,
)
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.optim.keyed import KeyedOptimizerWrapper
from torchrec.optim.rowwise_adagrad import RowWiseAdagrad
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.datasets.utils import Batch
from torch.distributed._sharded_tensor import ShardedTensor

from collections import defaultdict
from functools import partial
import mlflow

from typing import Tuple, List, Optional
from torchrec.modules.mlp import MLP

from pyspark.ml.torch.distributor import TorchDistributor
from tqdm import tqdm
import torchmetrics as metrics

2.2. Helper functions for Converting to Pipelineable DataType

Using TorchRec pipelines requires a pipelineable data type (which is Batch in this case). In this step, you create a helper function that takes each batch from the StreamingDataset and passes it through a transformation function to convert it into a pipelineable batch.

For further context, see https://github.com/pytorch/torchrec/blob/main/torchrec/datasets/utils.py#L28.

# TODO: These values come from the outputs in Section 1.2; if you use a different dataset, update them accordingly
cat_cols = ["userId", "movieId"]
emb_counts = [193, 9740]

def transform_to_torchrec_batch(batch, num_embeddings_per_feature: Optional[List[int]] = None) -> Batch:
    kjt_values: List[int] = []
    kjt_lengths: List[int] = []
    for col_idx, col_name in enumerate(cat_cols):
        values = batch[col_name]
        for value in values:
            if value:
                kjt_values.append(
                    value % num_embeddings_per_feature[col_idx]
                )
                kjt_lengths.append(1)
            else:
                kjt_lengths.append(0)

    sparse_features = KeyedJaggedTensor.from_lengths_sync(
        cat_cols,
        torch.tensor(kjt_values),
        torch.tensor(kjt_lengths, dtype=torch.int32),
    )
    labels = torch.tensor(batch["label"], dtype=torch.int32)
    assert isinstance(labels, torch.Tensor)

    return Batch(
        dense_features=torch.zeros(1),
        sparse_features=sparse_features,
        labels=labels,
    )

transform_partial = partial(transform_to_torchrec_batch, num_embeddings_per_feature=emb_counts)
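
To make the conversion concrete, here is a small hedged example with a hand-built batch dict in the same dict-of-tensors format that the StreamingDataLoader yields (the values are made up):

# A hand-built batch in the same dict-of-tensors format the StreamingDataLoader yields.
fake_batch = {
    "userId": torch.tensor([3, 7, 150]),
    "movieId": torch.tensor([10, 9999, 42]),  # out-of-range ids are wrapped by the modulo in the transform
    "label": torch.tensor([1, 0, 1]),
}

torchrec_batch = transform_partial(fake_batch)
print(torchrec_batch.sparse_features.keys())    # ["userId", "movieId"]
print(torchrec_batch.sparse_features.values())  # flattened userId values followed by movieId values
print(torchrec_batch.labels)                    # int32 tensor of labels, one per row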

2.3. Helper Function for DataLoading using Mosaic's StreamingDataset

This section uses Mosaic's StreamingDataset and StreamingDataLoader for efficient data loading. For more information, see the Mosaic Streaming documentation (https://docs.mosaicml.com/projects/streaming/en/stable/).

def get_dataloader_with_mosaic(path, batch_size, label):
    print(f"Getting {label} data from UC Volumes")
    dataset = StreamingDataset(local=path, shuffle=True, batch_size=batch_size)
    return StreamingDataLoader(dataset, batch_size=batch_size)

3. Creating the Relevant TorchRec code for Training

This section contains all of the training and evaluation code.

3.1. Two Tower Model Definition

This model definition is taken directly from the TorchRec two tower example. Note that the loss is binary cross entropy, which requires labels to be in {0, 1}.

import torch.nn.functional as F

class TwoTower(nn.Module):
    def __init__(
        self,
        embedding_bag_collection: EmbeddingBagCollection,
        layer_sizes: List[int],
        device: Optional[torch.device] = None
    ) -> None:
        super().__init__()

        assert len(embedding_bag_collection.embedding_bag_configs()) == 2, "Expected two EmbeddingBags in the two tower model"
        assert embedding_bag_collection.embedding_bag_configs()[0].embedding_dim == embedding_bag_collection.embedding_bag_configs()[1].embedding_dim, "Both EmbeddingBagConfigs must have the same dimension"

        embedding_dim = embedding_bag_collection.embedding_bag_configs()[0].embedding_dim
        self._feature_names_query: List[str] = embedding_bag_collection.embedding_bag_configs()[0].feature_names
        self._candidate_feature_names: List[str] = embedding_bag_collection.embedding_bag_configs()[1].feature_names
        self.ebc = embedding_bag_collection
        self.query_proj = MLP(in_size=embedding_dim, layer_sizes=layer_sizes, device=device)
        self.candidate_proj = MLP(in_size=embedding_dim, layer_sizes=layer_sizes, device=device)

    def forward(self, kjt: KeyedJaggedTensor) -> Tuple[torch.Tensor, torch.Tensor]:
        pooled_embeddings = self.ebc(kjt)
        query_embedding: torch.Tensor = self.query_proj(
            torch.cat(
                [pooled_embeddings[feature] for feature in self._feature_names_query],
                dim=1,
            )
        )
        candidate_embedding: torch.Tensor = self.candidate_proj(
            torch.cat(
                [
                    pooled_embeddings[feature]
                    for feature in self._candidate_feature_names
                ],
                dim=1,
            )
        )
        return query_embedding, candidate_embedding


class TwoTowerTrainTask(nn.Module):
    def __init__(self, two_tower: TwoTower) -> None:
        super().__init__()
        self.two_tower = two_tower
        # The BCEWithLogitsLoss combines a sigmoid layer and binary cross entropy loss
        self.loss_fn: nn.Module = nn.BCEWithLogitsLoss()

    def forward(self, batch: Batch) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        query_embedding, candidate_embedding = self.two_tower(batch.sparse_features)
        logits = (query_embedding * candidate_embedding).sum(dim=1).squeeze()
        loss = self.loss_fn(logits, batch.labels.float())
        return loss, (loss.detach(), logits.detach(), batch.labels.detach())
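
If you prefer to predict the raw 0-5 rating instead of a binary label (the MSELoss alternative mentioned in Section 1.3), a minimal sketch of a regression variant of the train task could look like the following. It assumes you keep the raw rating column rather than binarizing it, and you would also swap the AUROC metric in the evaluation code for a regression metric such as torchmetrics.MeanSquaredError.

class TwoTowerRegressionTask(nn.Module):
    """Hypothetical regression variant: predicts the raw rating with MSELoss
    instead of a binary label with BCEWithLogitsLoss."""

    def __init__(self, two_tower: TwoTower) -> None:
        super().__init__()
        self.two_tower = two_tower
        self.loss_fn: nn.Module = nn.MSELoss()

    def forward(self, batch: Batch) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        query_embedding, candidate_embedding = self.two_tower(batch.sparse_features)
        # Use the dot product directly as the predicted rating (no sigmoid).
        preds = (query_embedding * candidate_embedding).sum(dim=1).squeeze()
        loss = self.loss_fn(preds, batch.labels.float())
        return loss, (loss.detach(), preds.detach(), batch.labels.detach())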

3.2. Base Dataclass for Training inputs

Feel free to modify any of the variables mentioned here. Note that in this example the first entry of layer_sizes matches embedding_dim, and the last entry sets the size of the output embeddings.

from dataclasses import dataclass, field
import itertools 

# TODO: Update these values for training as needed
@dataclass
class Args:
    """
    Training arguments.
    """
    epochs: int = 3  # Number of training epochs
    embedding_dim: int = 128  # Embedding dimension is 128
    layer_sizes: List[int] = field(default_factory=lambda: [128, 64]) # The layers for the two tower model are 128, 64 (with the final embedding size for the outputs being 64)
    learning_rate: float = 0.01
    batch_size: int = 1024 # Set a larger batch size due to the large size of dataset
    print_sharding_plan: bool = True
    print_lr: bool = False  # Optional, prints the learning rate at each iteration step
    validation_freq: Optional[int] = None  # Optional; how often (in training steps) to run validation during training
    limit_train_batches: Optional[int] = None  # Optional; limits the number of training batches
    limit_val_batches: Optional[int] = None  # Optional; limits the number of validation batches
    limit_test_batches: Optional[int] = None  # Optional; limits the number of test batches

# Collect the relevant fields to log to MLflow as run parameters
def get_relevant_fields(args, cat_cols, emb_counts):
    fields_to_save = ["epochs", "embedding_dim", "layer_sizes", "learning_rate", "batch_size"]
    result = { key: getattr(args, key) for key in fields_to_save }
    # Add the categorical columns and their embedding cardinalities
    result["cat_cols"] = cat_cols
    result["emb_counts"] = emb_counts
    return result

3.3. Training and Evaluation Helper Functions

def batched(it, n):
    assert n >= 1
    for x in it:
        yield itertools.chain((x,), itertools.islice(it, n - 1))
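
For context, batched lazily chunks an iterator into groups of at most n items; each chunk must be consumed before the next one is requested, which is exactly how the training loop uses it. A minimal illustration:

# Each chunk is an itertools.chain over the next n items of the underlying iterator.
it = iter(range(7))
for chunk in batched(it, 3):
    print(list(chunk))  # prints [0, 1, 2], then [3, 4, 5], then [6]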

3.3.1. Helper Functions for Distributed Model Saving

# Two Tower and TorchRec use special tensors called ShardedTensors.
# This code gathers the shards onto rank 0 so the full state_dict can be logged to MLflow.
def gather_and_get_state_dict(model):
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    state_dict = model.state_dict()
    gathered_state_dict = {}

    # Iterate over all items in the state_dict
    for fqn, tensor in state_dict.items():
        if isinstance(tensor, ShardedTensor):
            # Collect all shards of the tensor across ranks
            full_tensor = None
            if rank == 0:
                full_tensor = torch.zeros(tensor.size()).to(tensor.device)
            tensor.gather(0, full_tensor)
            if rank == 0:
                gathered_state_dict[fqn] = full_tensor
        else:
            # Directly add non-sharded tensors to the new state_dict
            if rank == 0:
                gathered_state_dict[fqn] = tensor

    return gathered_state_dict

def log_state_dict_to_mlflow(model, artifact_path) -> None:
    # All ranks participate in gathering
    state_dict = gather_and_get_state_dict(model)
    # Only rank 0 logs the state dictionary
    if dist.get_rank() == 0 and state_dict:
        mlflow.pytorch.log_state_dict(state_dict, artifact_path=artifact_path)

3.3.2. Helper Functions for Distributed Model Training and Evaluation

import torchmetrics as metrics

def evaluate(
    limit_batches: Optional[int],
    pipeline: TrainPipelineSparseDist,
    eval_dataloader: DataLoader,
    stage: str) -> Tuple[float, float]:
    """
    Evaluates model. Computes and prints AUROC and average loss. Helper function for train_val_test.

    Args:
        limit_batches (Optional[int]): Limits the dataloader to the first `limit_batches` batches.
        pipeline (TrainPipelineSparseDist): data pipeline.
        eval_dataloader (DataLoader): Dataloader for either the validation set or test set.
        stage (str): "val" or "test".

    Returns:
        Tuple[float, float]: a tuple of (average loss, AUROC)
    """
    pipeline._model.eval()
    device = pipeline._device

    iterator = itertools.islice(iter(eval_dataloader), limit_batches)

    # We are using the AUROC for binary classification
    auroc = metrics.AUROC(task="binary").to(device)

    is_rank_zero = dist.get_rank() == 0
    if is_rank_zero:
        pbar = tqdm(
            iter(int, 1),
            desc=f"Evaluating {stage} set",
            total=len(eval_dataloader),
            disable=False,
        )
    
    total_loss = torch.tensor(0.0).to(device)  # Initialize total_loss as a tensor on the same device as _loss
    total_samples = 0
    with torch.no_grad():
        while True:
            try:
                _loss, logits, labels = pipeline.progress(map(transform_partial, iterator))
                # Calculating AUROC
                preds = torch.sigmoid(logits)
                auroc(preds, labels)
                # Calculating loss
                total_loss += _loss.detach()  # Detach _loss to prevent gradients from being calculated
                total_samples += len(labels)
                if is_rank_zero:
                    pbar.update(1)
            except StopIteration:
                break
    
    auroc_result = auroc.compute().item()
    average_loss = total_loss / total_samples if total_samples > 0 else torch.tensor(0.0).to(device)
    average_loss_value = average_loss.item()

    if is_rank_zero:
        print(f"Average loss over {stage} set: {average_loss_value:.4f}.")
        print(f"AUROC over {stage} set: {auroc_result}")
    
    return average_loss_value, auroc_result
def train(
    pipeline: TrainPipelineSparseDist,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    epoch: int,
    print_lr: bool,
    validation_freq: Optional[int],
    limit_train_batches: Optional[int],
    limit_val_batches: Optional[int]) -> None:
    """
    Trains model for 1 epoch. Helper function for train_val_test.

    Args:
        pipeline (TrainPipelineSparseDist): data pipeline.
        train_dataloader (DataLoader): Training set's dataloader.
        val_dataloader (DataLoader): Validation set's dataloader.
        epoch (int): The number of complete passes through the training set so far.
        print_lr (bool): Whether to print the learning rate every training step.
        validation_freq (Optional[int]): The number of training steps between validation runs within an epoch.
        limit_train_batches (Optional[int]): Limits the training set to the first `limit_train_batches` batches.
        limit_val_batches (Optional[int]): Limits the validation set to the first `limit_val_batches` batches.

    Returns:
        None.
    """
    pipeline._model.train()

    # Get the first `limit_train_batches` batches
    iterator = itertools.islice(iter(train_dataloader), limit_train_batches)

    # Only print out the progress bar on rank 0
    is_rank_zero = dist.get_rank() == 0
    if is_rank_zero:
        pbar = tqdm(
            iter(int, 1),
            desc=f"Epoch {epoch}",
            total=len(train_dataloader),
            disable=False,
        )

    # TorchRec's pipeline paradigm is unique as it takes in an iterator of batches for training.
    start_it = 0
    n = validation_freq if validation_freq else len(train_dataloader)
    for batched_iterator in batched(iterator, n):
        for it in itertools.count(start_it):
            try:
                if is_rank_zero and print_lr:
                    for i, g in enumerate(pipeline._optimizer.param_groups):
                        print(f"lr: {it} {i} {g['lr']:.6f}")
                pipeline.progress(map(transform_partial, batched_iterator))
                if is_rank_zero:
                    pbar.update(1)
            except StopIteration:
                if is_rank_zero:
                    print("Total number of iterations:", it)
                start_it = it
                break

        # If you are validating frequently, use the evaluation function
        if validation_freq and start_it % validation_freq == 0:
            evaluate(limit_val_batches, pipeline, val_dataloader, "val")
            pipeline._model.train()
def train_val_test(args, model, optimizer, device, train_dataloader, val_dataloader, test_dataloader) -> float:
    """
    Train/validation/test loop.

    Args:
        args (Args): args for training.
        model (torch.nn.Module): model to train.
        optimizer (torch.optim.Optimizer): optimizer to use.
        device (torch.device): device to use.
        train_dataloader (DataLoader): Training set's dataloader.
        val_dataloader (DataLoader): Validation set's dataloader.
        test_dataloader (DataLoader): Test set's dataloader.

    Returns:
        float: AUROC over the test set.
    """
    pipeline = TrainPipelineSparseDist(model, optimizer, device)
    
    # Getting base auroc and saving it to mlflow
    val_loss, val_auroc = evaluate(args.limit_val_batches, pipeline, val_dataloader, "val")
    if int(os.environ["RANK"]) == 0:
        mlflow.log_metric('val_loss', val_loss)
        mlflow.log_metric('val_auroc', val_auroc)

    # Running a training loop
    for epoch in range(args.epochs):
        train(
            pipeline,
            train_dataloader,
            val_dataloader,
            epoch,
            args.print_lr,
            args.validation_freq,
            args.limit_train_batches,
            args.limit_val_batches,
        )

        # Evaluate after each training epoch
        val_loss, val_auroc = evaluate(args.limit_val_batches, pipeline, val_dataloader, "val")
        if int(os.environ["RANK"]) == 0:
            mlflow.log_metric('val_loss', val_loss)
            mlflow.log_metric('val_auroc', val_auroc)

        # Save the underlying model and results to mlflow
        log_state_dict_to_mlflow(pipeline._model.module, artifact_path=f"model_state_dict_{epoch}")
    
    # Evaluate on the test set after training loop finishes
    test_loss, test_auroc = evaluate(args.limit_test_batches, pipeline, test_dataloader, "test")
    if int(os.environ["RANK"]) == 0:
        mlflow.log_metric('test_loss', test_loss)
        mlflow.log_metric('test_auroc', test_auroc)
    return test_auroc

3.4. The Main Function

The main function ties everything together: it loads the MDS data from UC Volumes, builds the embedding tables and Two Tower model, shards them across the available GPUs with TorchRec, and runs the train/validation/test loop while logging parameters, metrics, and state_dicts to MLflow.

# TODO: Specify where the data is stored in UC Volumes
output_dir_train = "/Volumes/ml/recommender_systems/learning_from_sets_data/mds_train"
output_dir_validation = "/Volumes/ml/recommender_systems/learning_from_sets_data/mds_validation"
output_dir_test = "/Volumes/ml/recommender_systems/learning_from_sets_data/mds_test"

from torchrec.distributed.comm import get_local_size
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.storage_reservations import (
    HeuristicalStorageReservation,
)
from torchrec.distributed.model_parallel import (
    DistributedModelParallel,
    get_default_sharders,
)

def main(args: Args):
    import torch
    import mlflow
    from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict

    # Some preliminary torch setup
    torch.jit._state.disable()
    global_rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device(f"cuda:{local_rank}")
    backend = "nccl"
    torch.cuda.set_device(device)

    # Start MLflow
    os.environ['DATABRICKS_HOST'] = db_host
    os.environ['DATABRICKS_TOKEN'] = db_token
    experiment = mlflow.set_experiment(experiment_path)

    # Save parameters to MLflow
    if global_rank == 0:
        param_dict = get_relevant_fields(args, cat_cols, emb_counts)
        mlflow.log_params(param_dict)

    # Start distributed process group
    dist.init_process_group(backend=backend)

    # Loading the data
    train_dataloader = get_dataloader_with_mosaic(output_dir_train, args.batch_size, "train")
    val_dataloader = get_dataloader_with_mosaic(output_dir_validation, args.batch_size, "val")
    test_dataloader = get_dataloader_with_mosaic(output_dir_test, args.batch_size, "test")

    # Create the embedding tables
    eb_configs = [
        EmbeddingBagConfig(
            name=f"t_{feature_name}",
            embedding_dim=args.embedding_dim,
            num_embeddings=emb_counts[feature_idx],
            feature_names=[feature_name],
        )
        for feature_idx, feature_name in enumerate(cat_cols)
    ]

    # Create the Two Tower model
    embedding_bag_collection = EmbeddingBagCollection(
        tables=eb_configs,
        device=torch.device("meta"),
    )
    two_tower_model = TwoTower(
        embedding_bag_collection=embedding_bag_collection,
        layer_sizes=args.layer_sizes,
        device=device,
    )
    two_tower_train_task = TwoTowerTrainTask(two_tower_model)
    apply_optimizer_in_backward(
        RowWiseAdagrad,
        two_tower_train_task.two_tower.ebc.parameters(),
        {"lr": args.learning_rate},
    )

    # Create a plan to shard the embedding tables across the GPUs and create a distributed model
    planner = EmbeddingShardingPlanner(
        topology=Topology(
            local_world_size=get_local_size(),
            world_size=dist.get_world_size(),
            compute_device=device.type,
        ),
        batch_size=args.batch_size,
        # If you get an out-of-memory error, increase the percentage. See
        # https://pytorch.org/torchrec/torchrec.distributed.planner.html#torchrec.distributed.planner.storage_reservations.HeuristicalStorageReservation
        storage_reservation=HeuristicalStorageReservation(percentage=0.05),
    )
    plan = planner.collective_plan(
        two_tower_model, get_default_sharders(), dist.GroupMember.WORLD
    )
    model = DistributedModelParallel(
        module=two_tower_train_task,
        device=device,
    )

    # Print out the sharding information to see how the embedding tables are sharded across the GPUs
    if global_rank == 0 and args.print_sharding_plan:
        for collectionkey, plans in model._plan.plan.items():
            print(collectionkey)
            for table_name, plan in plans.items():
                print(table_name, "\n", plan, "\n")
    
    log_state_dict_to_mlflow(model.module.two_tower, artifact_path="model_state_dict_base")

    optimizer = KeyedOptimizerWrapper(
        dict(model.named_parameters()),
        lambda params: torch.optim.Adam(params, lr=args.learning_rate),
    )

    # Start the training loop
    results = train_val_test(
        args,
        model,
        optimizer,
        device,
        train_dataloader,
        val_dataloader,
        test_dataloader,
    )

    # Destroy the process group
    dist.destroy_process_group()

3.5. Setting up MLflow

Note: You must update db_host to the URL of your Databricks workspace.

username = spark.sql("SELECT current_user()").first()['current_user()']
username

experiment_path = f'/Users/{username}/torchrec-learning-from-sets-example'
 
# TODO: Update the `db_host` with the URL for your Databricks workspace
db_host = "https://workspace-name.cloud.databricks.com/"
db_token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
 
# Manually create the experiment so that you know the id and can send that to the worker nodes when you scale later.
experiment = mlflow.set_experiment(experiment_path)

4. Single Node + Single GPU Training

Here, you set the environment variables to run training over the sample set of 100,000 data points (stored in Volumes in Unity Catalog and loaded with Mosaic's StreamingDataset). You can expect each epoch to take ~16 minutes.

os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"

args = Args()
main(args)

5. Single Node - Multi GPU Training

This notebook uses TorchDistributor to handle training on a g4dn.12xlarge instance with 4 T4 GPUs. You can view the sharding plan in the output logs to see what tables are located on what GPUs. This takes ~8 minutes to run per epoch.

Note: There may be cases where you receive unexpected errors (like the Python kernel crashing or segmentation faults). These are transient errors, and the easiest way to avoid them is to skip the single node, single GPU training code before you run any distributed code (single node multi GPU or multi node multi GPU).

Note: If you see any errors that are associated with Mosaic Data Loading, these are transient errors that can be overcome by rerunning the failed cell.

args = Args()
TorchDistributor(num_processes=4, local_mode=True, use_gpu=True).run(main, args)

6. Multi Node + Multi GPU Training

This is tested with a g4dn.12xlarge instance as a worker (with 4 T4 GPUs). You can view the sharding plan in the output logs to see what tables are located on what GPUs. This takes ~6 minutes to run per epoch.

Note: There may be cases where you receive unexpected errors (like the Python kernel crashing or segmentation faults). These are transient errors, and the easiest way to avoid them is to skip the single node, single GPU training code before you run any distributed code (single node multi GPU or multi node multi GPU).

Note: If you see any errors that are associated with Mosaic Data Loading, these are transient errors that can be overcome by rerunning the failed cell.

args = Args()
TorchDistributor(num_processes=4, local_mode=False, use_gpu=True).run(main, args)

7. Inference

Because the Two Tower model's state_dicts are logged to MLflow, you can use the following code to load any of the saved state_dicts and reconstruct the associated Two Tower model. You can take this further by 1) logging the loaded model to MLflow for inference or 2) running batch inference with a UDF (a sketch of the UDF approach appears after Section 7.4).

Note: The saving and loading code shown here loads the entire Two Tower model onto a single node and is meant as an example. In real-world use cases, the model can be very large (the embedding tables scale with the number of users and items), so distributed inference may be worth considering.

7.1. Creating the Two Tower model from saved state_dict

Note: You must update this with the correct run_id and path to the MLflow artifact.

def get_mlflow_model(run_id, artifact_path="model_state_dict"):
    from mlflow import MlflowClient

    device = torch.device("cuda")
    run = mlflow.get_run(run_id)
    
    cat_cols = eval(run.data.params.get('cat_cols'))
    emb_counts = eval(run.data.params.get('emb_counts'))
    layer_sizes = eval(run.data.params.get('layer_sizes'))
    embedding_dim = eval(run.data.params.get('embedding_dim'))

    MlflowClient().download_artifacts(run_id, f"{artifact_path}/state_dict.pth", "/databricks/driver")
    state_dict = mlflow.pytorch.load_state_dict(f"/databricks/driver/{artifact_path}")
    
    # Remove the prefix "two_tower." from all of the keys in the state_dict
    state_dict = {k[10:]: v for k, v in state_dict.items()}

    eb_configs = [
        EmbeddingBagConfig(
            name=f"t_{feature_name}",
            embedding_dim=embedding_dim,
            num_embeddings=emb_counts[feature_idx],
            feature_names=[feature_name],
        )
        for feature_idx, feature_name in enumerate(cat_cols)
    ]

    embedding_bag_collection = EmbeddingBagCollection(
        tables=eb_configs,
        device=device,
    )
    two_tower_model = TwoTower(
        embedding_bag_collection=embedding_bag_collection,
        layer_sizes=layer_sizes,
        device=device,
    )

    two_tower_model.load_state_dict(state_dict)

    return two_tower_model, cat_cols, emb_counts

# Load the model (epoch 2) from the MLflow run
# TODO: Update this with the correct run_id and path
two_tower_model, cat_cols, emb_counts = get_mlflow_model("f69cbccb2f6949efa4794b46ddcaceb4", artifact_path="model_state_dict_1")

7.2. Helper Function to Transform Dataloader to Two Tower Inputs

The Two Tower model expects sparse_features as input, so this section reuses the transformation code from Section 2.2. The code shown here is kept verbose for clarity.

def transform_test(batch, cat_cols, emb_counts):
    kjt_values: List[int] = []
    kjt_lengths: List[int] = []
    for col_idx, col_name in enumerate(cat_cols):
        values = batch[col_name]
        for value in values:
            if value:
                kjt_values.append(
                    value % emb_counts[col_idx]
                )
                kjt_lengths.append(1)
            else:
                kjt_lengths.append(0)

    sparse_features = KeyedJaggedTensor.from_lengths_sync(
        cat_cols,
        torch.tensor(kjt_values),
        torch.tensor(kjt_lengths, dtype=torch.int32),
    )
    return sparse_features

7.3. Getting the Data

num_batches = 5 # Number of batches to print out at a time 
batch_size = 1 # Print out each individual row

# TODO: Update this path to point to the test dataset stored in UC Volumes
test_data_path = "/Volumes/ml/recommender_systems/learning_from_sets_data/mds_test"
test_dataloader = iter(get_dataloader_with_mosaic(test_data_path, batch_size, "test"))

7.4. Running Tests

In this example, you trained for 3 epochs and the results are reasonable. Training for more epochs would likely improve performance further.

for _ in range(num_batches):
    device = torch.device("cuda:0")
    two_tower_model.to(device)
    two_tower_model.eval()

    next_batch = next(test_dataloader)
    expected_result = next_batch["label"][0]
    
    sparse_features = transform_test(next_batch, cat_cols, emb_counts)
    sparse_features = sparse_features.to(device)
    
    query_embedding, candidate_embedding = two_tower_model(kjt=sparse_features)
    actual_result = (query_embedding * candidate_embedding).sum(dim=1).squeeze()
    actual_result = torch.sigmoid(actual_result)
    print(f"Expected Result: {expected_result}; Actual Result: {actual_result.round().item()}")

8. Model Serving and Vector Search

For information about how to serve the model, see the Databricks Model Serving documentation (AWS | Azure).
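
As a starting point for serving (a sketch, not a full serving setup), you could log the reconstructed model from Section 7.1 to MLflow with the PyTorch flavor; a serving endpoint would additionally need a pyfunc wrapper that converts raw (userId, movieId) requests into the KeyedJaggedTensor input the model expects.

# Sketch: log the reconstructed Two Tower model as an MLflow PyTorch model so it can be
# registered and deployed; the request-to-KeyedJaggedTensor wrapper is omitted here.
with mlflow.start_run():
    mlflow.pytorch.log_model(
        two_tower_model.to(torch.device("cpu")),
        artifact_path="two_tower_model",
    )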

Also, because the Two Tower model produces separate query (user) and candidate (movie) embeddings, you can build a vector index of the movie embeddings and then retrieve the K movies that a user (given their query embedding) is most likely to rate highly. You can build such an index with FAISS, or take a similar approach with Databricks Vector Search (AWS | Azure).
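
As a rough, brute-force illustration of that retrieval step (a sketch assuming the two_tower_model, cat_cols, and emb_counts loaded in Section 7.1; example_user_id and K are hypothetical), the following scores every movie embedding against one user's query embedding and takes the top K. At catalog scale, a FAISS index or Databricks Vector Search would replace the torch.topk call.

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

device = torch.device("cuda:0")
two_tower_model.to(device)
two_tower_model.eval()

K = 10                       # number of movies to retrieve (hypothetical)
example_user_id = 0          # hypothetical user index; must be < emb_counts[0]
num_movies = emb_counts[1]   # number of movie embedding rows

with torch.no_grad():
    # One userId and no movieIds: only the query (user) tower output is used.
    user_kjt = KeyedJaggedTensor.from_lengths_sync(
        cat_cols,
        torch.tensor([example_user_id]),
        torch.tensor([1, 0], dtype=torch.int32),
    ).to(device)
    user_embedding, _ = two_tower_model(user_kjt)

    # Every movie embedding row and no userIds: only the candidate (movie) tower output is used.
    movie_ids = torch.arange(num_movies)
    movie_kjt = KeyedJaggedTensor.from_lengths_sync(
        cat_cols,
        movie_ids,
        torch.tensor([0] * num_movies + [1] * num_movies, dtype=torch.int32),
    ).to(device)
    _, movie_embeddings = two_tower_model(movie_kjt)

    # Brute-force inner-product retrieval; a FAISS index or Vector Search would replace this at scale.
    scores = movie_embeddings @ user_embedding.squeeze(0)
    top_scores, top_movie_rows = torch.topk(scores, k=K)

# Note: these are embedding-row indices (movieId % emb_counts[1]), not raw movieIds.
print("Top movie embedding rows:", top_movie_rows.tolist())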
