dlrm-recommender-model

(Python)

DLRM (using TorchRec + TorchDistributor + StreamingDataset)

This notebook illustrates how to create a distributed DLRM recommendation model for predicting click-through rates. It was tested on g4dn.12xlarge instances (one instance as the driver and one as the worker) running Databricks Runtime 14.3 LTS ML. It uses some code from the Facebook DLRM implementation (which has an MIT License). For more insight into the DLRM recommendation model, see the following resources:

Note: Where you see # TODO in this notebook, you may need to enter custom code to ensure that the notebook runs successfully.

Requirements

This notebook requires Databricks Runtime 14.3 LTS ML.

Step 1. Saving Data in UC Volumes in the MDS (Mosaic Data Shard) format

This notebook creates a synthetic dataset with 100k rows that will be used to train a DLRM model.

%pip install -q mosaicml-streaming==0.7.5
dbutils.library.restartPython()

1.1. Creating a Synthetic Dataset

This notebook creates a synthetic dataset for predicting a binary label given both dense (numerical) and sparse (categorical) data. This synthetic dataset has a similar layout to other publicly available datasets, such as the Criteo click logs dataset. You can update this notebook to support those datasets as long as the data preprocessing is done correctly.

For a tangible example in retail, the numerical columns could represent features like the user's age, product's weight, or time of day, and the sparse columns could represent features like user's location, product's category, and so on. The label column describes the interaction between the user and the product. For example, a positive label of 1 might indicate that the user would buy the product, while a negative label of 0 might indicate that the user would give the product a 1-star rating.

import pandas as pd
import numpy as np
import random

# Set random seed for reproducibility
random.seed(42)
np.random.seed(42)

# Parameters for dataset generation
num_samples = 100000
num_int_features = 10
num_cat_features = 10

# Define the ranges for each integer feature
int_feature_ranges = [(0, random.randint(501, 10000)) for _ in range(num_int_features)]

# Generate unique category sizes for each categorical feature
cat_feature_categories = [np.random.randint(1, 201) for _ in range(num_cat_features)]

# Generate label column (0 or 1 depending on whether the interaction is negative or positive)
labels = np.random.randint(0, 2, size=num_samples)

# Generate integer features with the specified range
int_features = np.column_stack([
    np.random.randint(low, high, size=num_samples) 
    for (low, high) in int_feature_ranges
])

# Generate categorical features with different unique categories
cat_features = np.column_stack([
    np.random.randint(0, num_categories, size=num_samples) 
    for num_categories in cat_feature_categories
])

# Combine all features into a DataFrame
data = np.column_stack((labels, int_features, cat_features))
columns = ['label'] + [f'int_{i}' for i in range(1, num_int_features+1)] + [f'cat_{i}' for i in range(1, num_cat_features+1)]
df = pd.DataFrame(data, columns=columns)

# Convert to a Spark DataFrame and write the dataset to a Unity Catalog (UC) table
# TODO: Update this with the UC table name where this data should be saved
table_name = "ml.recommender_systems.dlrm_sample_dataset"
spark_df = spark.createDataFrame(df)
# Overwrite so that the cell can be rerun without failing on an existing table
spark_df.write.mode("overwrite").saveAsTable(table_name)
df = spark.table(table_name)
display(df)

1.2. Preprocessing the Data

If you are using a dataset other than the provided synthetic dataset, update this cell to perform any preprocessing and data cleaning you need. For this synthetic dataset, the only required step is to normalize the dense columns.

Note: You can repartition the dataset as needed to help improve performance for this cell.
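For reference, each dense column is transformed as z = (x - mean) / std. The following stdlib-only sketch reproduces that per-column arithmetic on a plain Python list. It is an illustration rather than part of the notebook, and it assumes, as the PySpark `StandardScaler` documentation describes, that the standard deviation is the corrected sample (n - 1) estimate.

```python
import statistics


def standardize(values):
    """Center and scale one column, like StandardScaler(withMean=True, withStd=True)."""
    mean = statistics.mean(values)
    # Corrected sample standard deviation (divides by n - 1)
    std = statistics.stdev(values)
    if std == 0:
        # A constant column has no spread: every centered value is 0
        return [0.0 for _ in values]
    return [(v - mean) / std for v in values]


print(standardize([2, 4, 6]))  # mean 4, sample std 2 -> [-1.0, 0.0, 1.0]
```

Each scaled column ends up with mean 0 and sample standard deviation 1, which keeps the dense inputs to the DLRM dense arch on a comparable scale.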

from pyspark.ml.feature import StandardScaler, VectorAssembler
from pyspark.ml import Pipeline
from pyspark.sql.functions import udf, col
from pyspark.sql.types import DoubleType

# UDF to extract the first element of the vector
extract_value = udf(lambda vector: float(vector[0]), DoubleType())

# List of dense features
dense_cols = [c for c in df.columns if 'int' in c]

# Normalize dense columns
stages = []
for column in dense_cols:
    assembler = VectorAssembler(inputCols=[column], outputCol=column + "_vec")
    scaler = StandardScaler(inputCol=column + "_vec", outputCol=column + "_scaled", withStd=True, withMean=True)
    stages += [assembler, scaler]

# Define the pipeline
pipeline = Pipeline(stages=stages)
model = pipeline.fit(df)

# Transform the dataframe
df_transformed = model.transform(df)

# Use the UDF to overwrite each original dense column with its single scaled value
for column in dense_cols:
    df_transformed = df_transformed.withColumn(column, extract_value(col(column + "_scaled")))

# Drop the intermediate vector columns, keeping only the transformed dense columns
for column in dense_cols:
    df_transformed = df_transformed.drop(column + "_vec", column + "_scaled")

# Display the transformed dataset
display(df_transformed)
from pyspark.sql.functions import col

# Identify the dense and categorical columns in the transformed Spark DataFrame
dense_cols = [c for c in df_transformed.columns if 'int' in c]
cat_cols = [c for c in df_transformed.columns if 'cat' in c]

# Calculating the number of unique values (distinct count) for each categorical column
emb_counts = [df_transformed.select(c).distinct().count() for c in cat_cols]

# Printing the results
print(f"The number of rows in df are {df_transformed.count()}")
print(f"There are {len(dense_cols)} dense columns: {dense_cols}")
print(f"There are {len(cat_cols)} categorical columns, where the max embedding count is {max(emb_counts)}:")

# Creating a dictionary to print the categorical column names alongside their unique counts
emb_count_dict = {cat_col: emb_count for cat_col, emb_count in zip(cat_cols, emb_counts)}
print(emb_count_dict)
# Split the DataFrame into train (70%), validation (20%), and test (10%) sets
train_df, validation_df, test_df = df_transformed.randomSplit([0.7, 0.2, 0.1], seed=42)

# Showing the count of each split to verify our distribution
print(f"Training Dataset Count: {train_df.count()}")
print(f"Validation Dataset Count: {validation_df.count()}")
print(f"Test Dataset Count: {test_df.count()}")

1.3. Saving to MDS Format within UC Volumes

In this step, you convert each of the train, validation, and test splits to the MDS format and save them to a UC Volume, so that the training processes can stream the data efficiently.

See the Mosaic Streaming documentation for more details:

  1. General details: https://docs.mosaicml.com/projects/streaming/en/stable/
  2. Main concepts: https://docs.mosaicml.com/projects/streaming/en/stable/getting_started/main_concepts.html#dataset-conversion
  3. dataframe_to_mds details: https://docs.mosaicml.com/projects/streaming/en/stable/preparing_datasets/spark_dataframe_to_mds.html
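Rather than hard-coding the `columns` schema passed to `dataframe_to_mds`, you could derive it from the DataFrame's `dtypes`, which PySpark returns as a list of `(name, type_string)` pairs. The sketch below uses a hypothetical `build_mds_columns` helper with a deliberately small type mapping; the Spark-to-MDS pairs shown (for example `double` to `float64` and `bigint` to `int64`) are assumptions to extend for your own schema.

```python
# Hypothetical mapping from Spark simple type strings to MDS column encodings
SPARK_TO_MDS = {
    "double": "float64",
    "float": "float32",
    "bigint": "int64",
    "int": "int32",
    "string": "str",
}


def build_mds_columns(dtypes, label_col="label"):
    """Build an MDS `columns` dict from (name, spark_type) pairs, label column first."""
    columns = {}
    # sorted() is stable: the label moves to the front, other columns keep their order
    ordered = sorted(dtypes, key=lambda item: item[0] != label_col)
    for name, spark_type in ordered:
        if spark_type not in SPARK_TO_MDS:
            raise ValueError(f"No MDS encoding mapped for column {name!r} of type {spark_type!r}")
        columns[name] = SPARK_TO_MDS[spark_type]
    return columns


dtypes = [("int_1", "double"), ("label", "bigint"), ("cat_1", "bigint")]
print(build_mds_columns(dtypes))  # {'label': 'int64', 'int_1': 'float64', 'cat_1': 'int64'}
```

In the notebook this would be called as `build_mds_columns(df_transformed.dtypes)`, so that the schema follows any change to the preprocessing step instead of drifting out of sync.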
from streaming import StreamingDataset
from streaming.base.converters import dataframe_to_mds
from streaming.base import MDSWriter
from shutil import rmtree
import os
from tqdm import tqdm

# Parameters required for saving data in MDS format
dense_dict = { key: 'float64' for key in dense_cols }
cat_dict = { key: 'int64' for key in cat_cols }
label_dict = { 'label' : 'int64' }
columns = {**label_dict, **dense_dict, **cat_dict}

compression = 'zstd:7'

# TODO: Specify where the data will be stored
output_dir_train = "/Volumes/ml/recommender_systems/dlrm_sample_data/mds_train"
output_dir_validation = "/Volumes/ml/recommender_systems/dlrm_sample_data/mds_validation"
output_dir_test = "/Volumes/ml/recommender_systems/dlrm_sample_data/mds_test"

# Save each split with `dataframe_to_mds`, which writes the DataFrame's `num_workers` partitions in parallel and merges each partition's `index.json` into a single index in the parent directory.
def save_data(df, output_path, label, num_workers=20):
    print(f"Saving {label} data to: {output_path}")
    mds_kwargs = {'out': output_path, 'columns': columns, 'compression': compression}
    dataframe_to_mds(df.repartition(num_workers), merge_index=True, mds_kwargs=mds_kwargs)

save_data(train_df, output_dir_train, 'train')
save_data(validation_df, output_dir_validation, 'validation')
save_data(test_df, output_dir_test, 'test')

Step 2. Helper Functions for Recommendation Dataloading

In this step, you install the necessary libraries, add imports, and add some relevant helper functions to train the model.

2.1. Installs and Imports

%pip install -q --upgrade --no-deps --force-reinstall torch==2.2.2 torchvision==0.17.2 torchrec==0.6.0 fbgemm-gpu==0.6.0 --index-url https://download.pytorch.org/whl/cu118
%pip install torchmetrics==1.0.3 iopath==0.1.10 pyre_extensions==0.0.32 mosaicml-streaming==0.7.5 
dbutils.library.restartPython()
import contextlib
import tempfile
from typing import Generator
import csv
import os
import random

from streaming import StreamingDataset, StreamingDataLoader
from streaming.base import MDSWriter

import torch.distributed as dist
from torch.utils.data import DataLoader
import itertools
import torchmetrics as metrics
from tqdm import tqdm
from dataclasses import dataclass, field
from typing import List, Optional
import torch

from torchrec import EmbeddingBagCollection
from torchrec.distributed import TrainPipelineSparseDist
from torchrec.distributed.comm import get_local_size
from torchrec.distributed.model_parallel import (
    DistributedModelParallel,
    get_default_sharders,
)
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.storage_reservations import (
    HeuristicalStorageReservation,
)
from torchrec.models.dlrm import DLRM, DLRM_DCN, DLRM_Projection, DLRMTrain
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
from torchrec.optim.optimizers import in_backward_optimizer_filter
import sys

from torch.optim.lr_scheduler import _LRScheduler
from pyre_extensions import none_throws
from torchrec.datasets.utils import Batch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from typing import Optional, List, Iterable, cast
from collections import defaultdict
from functools import partial
import dataclasses
from dataclasses import asdict

import mlflow
from torch import nn
from torch.distributed._sharded_tensor import ShardedTensor

from pyspark.ml.torch.distributor import TorchDistributor

2.2. Helper functions for Converting to Pipelineable DataType

Using TorchRec pipelines requires a pipelineable data type (which is Batch in this case). In this step, you create a helper function that takes each batch from the StreamingDataset and passes it through a transformation function to convert it into a pipelineable batch.

For further context, see:

  1. https://github.com/pytorch/torchrec/blob/29f503a8855040bc49788d3ad24e7ab93d944885/torchrec/datasets/utils.py#L28
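`KeyedJaggedTensor.from_lengths_sync(keys, values, lengths)` expects the IDs laid out key by key (all samples for the first key, then all samples for the next), with one length per (key, sample) pair. The stdlib-only sketch below, built around a hypothetical `to_kjt_layout` helper, shows that layout and the modulo step that folds each ID into its table's range. In this sketch, only `None` is treated as a missing value (an empty bag of length 0).

```python
def to_kjt_layout(batch, keys, num_embeddings):
    """Flatten per-feature columns into KeyedJaggedTensor-style values and lengths."""
    values, lengths = [], []
    for key, n in zip(keys, num_embeddings):  # key-major order
        for v in batch[key]:                  # one entry per sample
            if v is None:
                lengths.append(0)             # missing value -> empty bag
            else:
                values.append(v % n)          # fold the ID into [0, n)
                lengths.append(1)             # one ID for this sample
    return values, lengths


batch = {"cat_1": [5, None, 2], "cat_2": [7, 1, 3]}
values, lengths = to_kjt_layout(batch, ["cat_1", "cat_2"], [4, 5])
print(values)   # [1, 2, 2, 1, 3]
print(lengths)  # [1, 0, 1, 1, 1, 1]
```

`len(lengths)` is always `len(keys) * batch_size`, and `sum(lengths)` equals `len(values)`, which is what lets the jagged tensor recover which IDs belong to which sample.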
# TODO: This is from earlier outputs (section 1.2, cell 2); if another dataset is being used, these values need to be updated
cat_data = {'cat_1': 103, 'cat_2': 180, 'cat_3': 93, 'cat_4': 15, 'cat_5': 107, 'cat_6': 72, 'cat_7': 189, 'cat_8': 21, 'cat_9': 103, 'cat_10': 122}
dense_cols = [f"int_{i}" for i in range(1, 11)]
cat_cols = [f"cat_{i}" for i in range(1, 11)]
emb_counts = [cat_data[k] for k in cat_cols]

def transform_to_torchrec_batch(batch, num_embeddings_per_feature: Optional[List[int]] = None) -> Batch:
    cat_list = []
    for col_name in dense_cols:
        val = torch.tensor(batch[col_name], dtype=torch.float32)
        cat_list.append(val.unsqueeze(0).T)
    dense_features = torch.cat(cat_list, dim=1)

    kjt_values: List[int] = []
    kjt_lengths: List[int] = []
    for col_idx, col_name in enumerate(cat_cols):
        values = batch[col_name]
        for value in values:
            # Only a missing value (None) yields an empty bag; 0 is a valid category ID here
            if value is not None:
                kjt_values.append(
                    value % num_embeddings_per_feature[col_idx]
                )
                kjt_lengths.append(1)
            else:
                kjt_lengths.append(0)

    sparse_features = KeyedJaggedTensor.from_lengths_sync(
        cat_cols,
        torch.tensor(kjt_values),
        torch.tensor(kjt_lengths, dtype=torch.int32),
    )
    labels = torch.tensor(batch["label"], dtype=torch.int32)
    assert isinstance(labels, torch.Tensor)

    return Batch(
        dense_features=dense_features,
        sparse_features=sparse_features,
        labels=labels,
    )

transform_partial = partial(transform_to_torchrec_batch, num_embeddings_per_feature=emb_counts)

2.3. Helper Function for DataLoading using Mosaic's StreamingDataset

This utilizes Mosaic's StreamingDataset and Mosaic's StreamingDataLoader for efficient data loading. For more information, view this documentation.
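`StreamingDataset` yields one dict per sample, and `StreamingDataLoader`, a subclass of PyTorch's `DataLoader`, batches those dicts with the default collate function, which produces a single dict of batched columns. That is why `transform_to_torchrec_batch` can read `batch[col_name]`. The hypothetical `collate_dicts` sketch below mimics that reshaping with plain lists instead of tensors.

```python
def collate_dicts(samples):
    """Turn a list of per-sample dicts into one dict of per-column lists."""
    if not samples:
        return {}
    keys = samples[0].keys()
    return {key: [sample[key] for sample in samples] for key in keys}


samples = [{"label": 1, "cat_1": 3}, {"label": 0, "cat_1": 7}]
print(collate_dicts(samples))  # {'label': [1, 0], 'cat_1': [3, 7]}
```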

def get_dataloader_with_mosaic(path, batch_size, label):
    print(f"Getting {label} data from UC Volumes")
    dataset = StreamingDataset(local=path, shuffle=True, batch_size=batch_size)
    return StreamingDataLoader(dataset, batch_size=batch_size)

Step 3. Creating the Relevant TorchRec code for Training

This contains all of the training and evaluation code.

3.1. Base Dataclass for Training inputs

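Feel free to modify any of these values, but note that the last entry in dense_arch_layer_sizes must equal embedding_dim, because DLRM combines the dense-arch output with the embedding vectors (through pairwise dot products) in its feature interaction layer.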
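A quick sanity check can catch dimension mismatches before the model is built. The sketch assumes, based on our reading of TorchRec's `DLRM` source, that the interaction layer concatenates the dense output with the pairwise dot products among the dense vector and the F sparse embeddings, so that the over arch receives embedding_dim + F(F + 1)/2 features. `check_dlrm_dims` is a hypothetical helper, not a TorchRec API.

```python
def check_dlrm_dims(embedding_dim, dense_arch_layer_sizes, over_arch_layer_sizes, num_sparse_features):
    """Validate DLRM layer sizes and return the assumed over-arch input width."""
    if dense_arch_layer_sizes[-1] != embedding_dim:
        raise ValueError(
            f"dense_arch_layer_sizes[-1]={dense_arch_layer_sizes[-1]} must equal "
            f"embedding_dim={embedding_dim}"
        )
    if over_arch_layer_sizes[-1] != 1:
        raise ValueError("over_arch_layer_sizes[-1] should be 1 (one logit for binary classification)")
    # Pairwise dot products among the dense vector and the F sparse embeddings
    num_vectors = num_sparse_features + 1
    num_interactions = num_vectors * (num_vectors - 1) // 2
    return embedding_dim + num_interactions


print(check_dlrm_dims(128, [512, 256, 128], [512, 512, 256, 1], num_sparse_features=10))  # 183
```

With the defaults in `Args` and the 10 categorical columns, this works out to 128 + 55 = 183 input features for the first over-arch layer.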

# TODO: Update these values for training as needed
@dataclass
class Args:
    """
    Training arguments.
    """
    epochs: int = 3  # Number of complete passes over the training set
    embedding_dim: int = 128  # Embedding dimension is 128
    dense_arch_layer_sizes: list = dataclasses.field(default_factory=lambda: [512, 256, 128])  # The last layer for dense should be equivalent to `embedding_dim`
    over_arch_layer_sizes: list = dataclasses.field(default_factory=lambda: [512, 512, 256, 1])  # We are essentially doing binary classification here, so the last layer is 1
    learning_rate: float = 0.03
    eps: float = 1e-8
    batch_size: int = 512
    print_sharding_plan: bool = True
    print_lr: bool = False  # Optional, prints the learning rate at each iteration step
    lr_warmup_steps: int = 0  # Optional, sets the learning rate warmup steps
    lr_decay_start: int = 0  # Optional, sets the decay start for learning rate
    lr_decay_steps: int = 0  # Optional, sets the decay steps for learning rate
    validation_freq: Optional[int] = None  # Optional, determines how often during training you want to run validation (# of training steps)
    limit_train_batches: Optional[int] = None  # Optional, limits the number of training batches
    limit_val_batches: Optional[int] = None  # Optional, limits the number of validation batches
    limit_test_batches: Optional[int] = None  # Optional, limits the number of test batches

# We will use this to store our results in mlflow
def get_relevant_fields(args, dense_cols, cat_cols, emb_counts):
    fields_to_save = ["epochs", "embedding_dim", "dense_arch_layer_sizes", "over_arch_layer_sizes", "learning_rate", "eps", "batch_size"]
    result = { key: getattr(args, key) for key in fields_to_save }
    # Add the feature column names and embedding counts
    result["dense_cols"] = dense_cols
    result["cat_cols"] = cat_cols
    result["emb_counts"] = emb_counts
    return result
@dataclass
class TrainValTestResults:
    """
    A dataclass to store our results.
    """
    val_aurocs: List[float] = field(default_factory=list)
    test_auroc: Optional[float] = None

3.2. LR Scheduler

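The scheduler is always created and stepped during training, but with the default Args (lr_warmup_steps=0, lr_decay_steps=0) it leaves the learning rate unchanged. Set these fields if you want to schedule the learning rate for the Adagrad optimizers.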
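The policy implemented below is a linear warmup over `num_warmup_steps`, followed, starting at `decay_start_step`, by a quadratic decay over `num_decay_steps` that is floored at 1e-7. The stateless sketch below, a hypothetical `lr_at_step` helper, is a simplification: it models the plateau between warmup and decay as the base learning rate, whereas the scheduler freezes at the last value it computed.

```python
def lr_at_step(base_lr, step, warmup_steps, decay_start, decay_steps, min_lr=1e-7):
    """Stateless approximation of LRPolicyScheduler's learning rate at `step`."""
    if step < warmup_steps:
        # Linear warmup from 0 toward base_lr
        return base_lr * (1.0 - (warmup_steps - step) / warmup_steps)
    decay_end = decay_start + decay_steps
    if decay_start <= step < decay_end:
        # Quadratic decay toward zero, floored at min_lr
        remaining = (decay_steps - (step - decay_start)) / decay_steps
        return max(min_lr, base_lr * remaining ** 2)
    if decay_steps > 0 and step >= decay_end:
        # After decay: hold the last decayed value
        return max(min_lr, base_lr * (1.0 / decay_steps) ** 2)
    # Plateau between warmup and decay, or no schedule configured at all
    return base_lr


print([lr_at_step(0.03, s, 10, 20, 10) for s in (0, 5, 10, 20, 25, 30)])
```

With the notebook's defaults (all three step counts set to 0), every step returns the base learning rate.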

class LRPolicyScheduler(_LRScheduler):
    def __init__(self, optimizer, num_warmup_steps, decay_start_step, num_decay_steps):
        self.num_warmup_steps = num_warmup_steps
        self.decay_start_step = decay_start_step
        self.decay_end_step = decay_start_step + num_decay_steps
        self.num_decay_steps = num_decay_steps

        if self.decay_start_step < self.num_warmup_steps:
            raise ValueError("Learning rate warmup must finish before the decay starts")

        super(LRPolicyScheduler, self).__init__(optimizer)

    def get_lr(self):
        step_count = self._step_count
        if step_count < self.num_warmup_steps:
            # warmup
            scale = 1.0 - (self.num_warmup_steps - step_count) / self.num_warmup_steps
            lr = [base_lr * scale for base_lr in self.base_lrs]
            self.last_lr = lr
        elif self.decay_start_step <= step_count and step_count < self.decay_end_step:
            # decay
            decayed_steps = step_count - self.decay_start_step
            scale = ((self.num_decay_steps - decayed_steps) / self.num_decay_steps) ** 2
            min_lr = 0.0000001
            lr = [max(min_lr, base_lr * scale) for base_lr in self.base_lrs]
            self.last_lr = lr
        else:
            if self.num_decay_steps > 0:
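                # freeze at the last computed value, either because we're after decay
                # or because we're between warmup and decay (fall back to the base LR
                # if no warmup or decay value has been computed yet)
                lr = getattr(self, "last_lr", self.base_lrs)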
            else:
                # do not adjust
                lr = self.base_lrs
        return lr

3.3. Training and Evaluation Helper Functions

def batched(it, n):
    assert n >= 1
    for x in it:
        yield itertools.chain((x,), itertools.islice(it, n - 1))
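`batched` lazily splits one iterator into consecutive chunks of at most `n` items; `train` uses it with `n = validation_freq` so that it can pause for validation between chunks. Two properties matter, as the self-contained usage example below shows: it must be given an iterator (a list would restart from the beginning in every chunk), and each chunk must be consumed before the next one is requested, because all chunks share the same underlying iterator.

```python
import itertools


def batched(it, n):
    # Same helper as above: lazily yields chunks of up to n items from one iterator
    assert n >= 1
    for x in it:
        yield itertools.chain((x,), itertools.islice(it, n - 1))


# list(chunk) consumes each chunk fully before the next one is requested
chunks = [list(chunk) for chunk in batched(iter(range(7)), 3)]
print(chunks)  # [[0, 1, 2], [3, 4, 5], [6]]
```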

3.4.1. Helper Functions for Distributed Model Saving

# DLRM and TorchRec use special tensors called ShardedTensors, so we gather each one
# into a full tensor on rank 0, the rank that saves to MLflow
def gather_and_get_state_dict(model):
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    state_dict = model.state_dict()
    gathered_state_dict = {}

    # Iterate over all items in the state_dict
    for fqn, tensor in state_dict.items():
        if isinstance(tensor, ShardedTensor):
            # Collect all shards of the tensor across ranks
            full_tensor = None
            if rank == 0:
                full_tensor = torch.zeros(tensor.size()).to(tensor.device)
            tensor.gather(0, full_tensor)
            if rank == 0:
                gathered_state_dict[fqn] = full_tensor
        else:
            # Directly add non-sharded tensors to the new state_dict
            if rank == 0:
                gathered_state_dict[fqn] = tensor

    return gathered_state_dict

def log_state_dict_to_mlflow(model, artifact_path) -> None:
    # All ranks participate in gathering
    state_dict = gather_and_get_state_dict(model)
    # Only rank 0 logs the state dictionary
    if dist.get_rank() == 0 and state_dict:
        mlflow.pytorch.log_state_dict(state_dict, artifact_path=artifact_path)
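Depending on the sharding plan, TorchRec may keep whole tables on one GPU or split a table by rows or columns across GPUs. Conceptually, `ShardedTensor.gather(dst, out)` copies each shard into its offset within a full-size tensor on the destination rank. The plain-Python sketch below illustrates that assembly step with a hypothetical `assemble_shards` helper; it is not how PyTorch implements the gather.

```python
def assemble_shards(full_shape, shards):
    """Place 2-D shards, given as (row_offset, col_offset, block), into one full table."""
    rows, cols = full_shape
    full = [[0.0] * cols for _ in range(rows)]  # zero-filled, like the rank-0 buffer
    for row_off, col_off, block in shards:
        for i, block_row in enumerate(block):
            for j, value in enumerate(block_row):
                full[row_off + i][col_off + j] = value
    return full


# Row-wise sharding: each "GPU" holds two rows of a 4 x 2 table
row_shards = [(0, 0, [[1, 2], [3, 4]]), (2, 0, [[5, 6], [7, 8]])]
print(assemble_shards((4, 2), row_shards))  # [[1, 2], [3, 4], [5, 6], [7, 8]]
```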

3.4.2. Helper Functions for Distributed Model Training and Evaluation

def train(
    pipeline: TrainPipelineSparseDist,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    epoch: int,
    lr_scheduler,
    print_lr: bool,
    validation_freq: Optional[int],
    limit_train_batches: Optional[int],
    limit_val_batches: Optional[int]) -> None:
    """
    Trains model for 1 epoch. Helper function for train_val_test.

    Args:
        pipeline (TrainPipelineSparseDist): data pipeline.
        train_dataloader (DataLoader): Training set's dataloader.
        val_dataloader (DataLoader): Validation set's dataloader.
        epoch (int): The number of complete passes through the training set so far.
        lr_scheduler (LRPolicyScheduler): Learning rate scheduler.
        print_lr (bool): Whether to print the learning rate every training step.
        validation_freq (Optional[int]): The number of training steps between validation runs within an epoch.
        limit_train_batches (Optional[int]): Limits the training set to the first `limit_train_batches` batches.
        limit_val_batches (Optional[int]): Limits the validation set to the first `limit_val_batches` batches.

    Returns:
        None.
    """
    pipeline._model.train()

    # Get the first `limit_train_batches` batches
    iterator = itertools.islice(iter(train_dataloader), limit_train_batches)

    # Only print out the progress bar on rank 0
    is_rank_zero = dist.get_rank() == 0
    if is_rank_zero:
        pbar = tqdm(
            iter(int, 1),
            desc=f"Epoch {epoch}",
            total=len(train_dataloader),
            disable=False,
        )

    # TorchRec's pipeline paradigm is unique as it takes in an iterator of batches for training.
    start_it = 0
    n = validation_freq if validation_freq else len(train_dataloader)
    for batched_iterator in batched(iterator, n):
        for it in itertools.count(start_it):
            try:
                if is_rank_zero and print_lr:
                    for i, g in enumerate(pipeline._optimizer.param_groups):
                        print(f"lr: {it} {i} {g['lr']:.6f}")
                pipeline.progress(map(transform_partial, batched_iterator))
                lr_scheduler.step()
                if is_rank_zero:
                    pbar.update(1)
            except StopIteration:
                if is_rank_zero:
                    print("Total number of iterations:", it)
                start_it = it
                break

        # If we are validating frequently, use the evaluation function
        if validation_freq and start_it % validation_freq == 0:
            evaluate(limit_val_batches, pipeline, val_dataloader, "val")
            pipeline._model.train()
def evaluate(
    limit_batches: Optional[int],
    pipeline: TrainPipelineSparseDist,
    eval_dataloader: DataLoader,
    stage: str) -> float:
    """
    Evaluates model. Computes and prints AUROC. Helper function for train_val_test.

    Args:
        limit_batches (Optional[int]): Limits the dataloader to the first `limit_batches` batches.
        pipeline (TrainPipelineSparseDist): data pipeline.
        eval_dataloader (DataLoader): Dataloader for either the validation set or test set.
        stage (str): "val" or "test".

    Returns:
        float: auroc result
    """
    pipeline._model.eval()
    device = pipeline._device

    iterator = itertools.islice(iter(eval_dataloader), limit_batches)

    # We are using AUROC for our metric
    auroc = metrics.AUROC(task="binary").to(device)

    is_rank_zero = dist.get_rank() == 0
    if is_rank_zero:
        pbar = tqdm(
            iter(int, 1),
            desc=f"Evaluating {stage} set",
            total=len(eval_dataloader),
            disable=False,
        )
    with torch.no_grad():
        while True:
            try:
                _loss, logits, labels = pipeline.progress(map(transform_partial, iterator))
                preds = torch.sigmoid(logits)
                auroc(preds, labels)
                if is_rank_zero:
                    pbar.update(1)
            except StopIteration:
                break

    auroc_result = auroc.compute().item()
    num_samples = torch.tensor(sum(map(len, auroc.target)), device=device)
    dist.reduce(num_samples, 0, op=dist.ReduceOp.SUM)

    if is_rank_zero:
        print(f"AUROC over {stage} set: {auroc_result}.")
        print(f"Number of {stage} samples: {num_samples}")
    return auroc_result
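`evaluate` reports AUROC, which is the probability that a randomly chosen positive example is scored higher than a randomly chosen negative one (ties count as one half). Because it depends only on the ranking of scores, applying `torch.sigmoid` to the logits first does not change it. The quadratic-time stdlib sketch below, a hypothetical `auroc` helper, only illustrates the definition; use torchmetrics for real workloads.

```python
import math


def auroc(scores, labels):
    """Fraction of (positive, negative) pairs ranked correctly; ties count 1/2."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUROC needs at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for q in neg:
            if p > q:
                wins += 1.0
            elif p == q:
                wins += 0.5
    return wins / (len(pos) * len(neg))


logits = [-2.0, 0.3, -0.1, 1.5]
labels = [0, 0, 1, 1]
probs = [1 / (1 + math.exp(-z)) for z in logits]  # sigmoid is monotonic
print(auroc(logits, labels), auroc(probs, labels))  # 0.75 0.75
```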
def train_val_test(args, model, optimizer, device, train_dataloader, val_dataloader, test_dataloader, lr_scheduler) -> TrainValTestResults:
    """
    Train/validation/test loop.

    Args:
        args (Args): the arguments for the model defined earlier.
        model (torch.nn.Module): model to train.
        optimizer (torch.optim.Optimizer): optimizer to use.
        device (torch.device): device to use.
        train_dataloader (DataLoader): Training set's dataloader.
        val_dataloader (DataLoader): Validation set's dataloader.
        test_dataloader (DataLoader): Test set's dataloader.
        lr_scheduler (LRPolicyScheduler): Learning rate scheduler.

    Returns:
        TrainValTestResults.
    """
    results = TrainValTestResults()
    pipeline = TrainPipelineSparseDist(model, optimizer, device, execute_all_batches=True)
    
    # Getting base auroc and saving it to mlflow
    val_auroc = evaluate(args.limit_val_batches, pipeline, val_dataloader, "val")
    results.val_aurocs.append(val_auroc)
    if int(os.environ["RANK"]) == 0:
        mlflow.log_metric('val_auroc', val_auroc)

    # Running a training loop
    for epoch in range(args.epochs):
        train(
            pipeline,
            train_dataloader,
            val_dataloader,
            epoch,
            lr_scheduler,
            args.print_lr,
            args.validation_freq,
            args.limit_train_batches,
            args.limit_val_batches,
        )

        # Evaluate after each training epoch
        val_auroc = evaluate(args.limit_val_batches, pipeline, val_dataloader, "val")
        results.val_aurocs.append(val_auroc)
        if int(os.environ["RANK"]) == 0:
            mlflow.log_metric('val_auroc', val_auroc)

        # Save the underlying model and results to mlflow
        log_state_dict_to_mlflow(pipeline._model.module, artifact_path=f"model_state_dict_{epoch}")
    
    # Evaluate on the test set after training loop finishes
    test_auroc = evaluate(args.limit_test_batches, pipeline, test_dataloader, "test")
    results.test_auroc = test_auroc
    if int(os.environ["RANK"]) == 0:
        mlflow.log_metric('test_auroc', test_auroc)
    return results

3.5. Setting up MLflow

Note: You must update db_host to the URL of your Databricks workspace.

username = spark.sql("SELECT current_user()").first()['current_user()']
username
 
experiment_path = f'/Users/{username}/torchrec-dlrm-example'
 
# TODO: Update the `db_host` with the URL for your Databricks workspace
db_host = "https://workspace-name.cloud.databricks.com/"
db_token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
 
# We manually create the experiment so that we know the id and can send that to the worker nodes when we scale
experiment = mlflow.set_experiment(experiment_path)

3.6. The Main Function

This function will train the DLRM recommendation model. For more information, view the following guides/docs/code:

# TODO: Specify where the data is stored in UC Volumes
input_dir_train = "/Volumes/ml/recommender_systems/dlrm_sample_data/mds_train"
input_dir_validation = "/Volumes/ml/recommender_systems/dlrm_sample_data/mds_validation"
input_dir_test = "/Volumes/ml/recommender_systems/dlrm_sample_data/mds_test"

def main(args):
    import torch
    import mlflow
    from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict

    # Some preliminary torch setup
    torch.jit._state.disable()
    global_rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device(f"cuda:{local_rank}")
    backend = "nccl"
    torch.cuda.set_device(device)

    # Starting mlflow
    os.environ['DATABRICKS_HOST'] = db_host
    os.environ['DATABRICKS_TOKEN'] = db_token
    experiment = mlflow.set_experiment(experiment_path)

    # Saving parameters to mlflow
    if global_rank == 0:
        param_dict = get_relevant_fields(args, dense_cols, cat_cols, emb_counts)
        mlflow.log_params(param_dict)

    # Start the distributed process group
    dist.init_process_group(backend=backend)

    # Loading the data
    train_dataloader = get_dataloader_with_mosaic(input_dir_train, args.batch_size, "train")
    val_dataloader = get_dataloader_with_mosaic(input_dir_validation, args.batch_size, "val")
    test_dataloader = get_dataloader_with_mosaic(input_dir_test, args.batch_size, "test")

    # Creating the embedding tables
    eb_configs = [
        EmbeddingBagConfig(
            name=f"t_{feature_name}",
            embedding_dim=args.embedding_dim,
            num_embeddings=emb_counts[feature_idx],
            feature_names=[feature_name],
        )
        for feature_idx, feature_name in enumerate(cat_cols)
    ]

    # Creating the DLRM model
    dlrm_model = DLRM(
            embedding_bag_collection=EmbeddingBagCollection(
                tables=eb_configs, device=torch.device("meta")
            ),
            dense_in_features=len(dense_cols),
            dense_arch_layer_sizes=args.dense_arch_layer_sizes,
            over_arch_layer_sizes=args.over_arch_layer_sizes,
            dense_device=device,
        )

    train_model = DLRMTrain(dlrm_model)

    # Setting up the optimizer
    embedding_optimizer = torch.optim.Adagrad
    optimizer_kwargs = {"lr": args.learning_rate, "eps": args.eps}

    apply_optimizer_in_backward(
        embedding_optimizer,
        train_model.model.sparse_arch.parameters(),
        optimizer_kwargs,
    )

    # Creating a plan to shard the embedding tables across the GPUs and creating a distributed model
    planner = EmbeddingShardingPlanner(
        topology=Topology(
            local_world_size=get_local_size(),
            world_size=dist.get_world_size(),
            compute_device=device.type,
        ),
        batch_size=args.batch_size,
        # If you run out of memory (OOM), increase the percentage. See
        # https://pytorch.org/torchrec/torchrec.distributed.planner.html#torchrec.distributed.planner.storage_reservations.HeuristicalStorageReservation
        storage_reservation=HeuristicalStorageReservation(percentage=0.05),
    )
    plan = planner.collective_plan(
        train_model, get_default_sharders(), dist.GroupMember.WORLD
    )
    model = DistributedModelParallel(
        module=train_model,
        device=device,
        plan=plan,
    )

    log_state_dict_to_mlflow(model.module, artifact_path="model_state_dict_base")

    # Printing out the sharding information to see how the embedding tables are sharded across the GPUs
    if global_rank == 0 and args.print_sharding_plan:
        for collectionkey, plans in model._plan.plan.items():
            print(collectionkey)
            for table_name, plan in plans.items():
                print(table_name, "\n", plan, "\n")

    # Creating the optimizer for the distributed model
    def optimizer_with_params():
        return lambda params: torch.optim.Adagrad(params, lr=args.learning_rate, eps=args.eps)

    dense_optimizer = KeyedOptimizerWrapper(
        dict(in_backward_optimizer_filter(model.named_parameters())),
        optimizer_with_params(),
    )
    optimizer = CombinedOptimizer([model.fused_optimizer, dense_optimizer])
    lr_scheduler = LRPolicyScheduler(
        optimizer, args.lr_warmup_steps, args.lr_decay_start, args.lr_decay_steps
    )

    # Starting the training loop
    results = train_val_test(
        args,
        model,
        optimizer,
        device,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        lr_scheduler,
    )

    # Destroying the process group
    dist.destroy_process_group()

Step 4. Single Node + Single GPU Training

Here, you set the environment variables needed to run training as a single process on a single GPU, using the synthetic 100,000-row dataset (stored in Unity Catalog Volumes and loaded with Mosaic StreamingDataset). You can expect each epoch to take about 40 minutes.
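The default `env://` initialization of `dist.init_process_group` reads `MASTER_ADDR`, `MASTER_PORT`, `RANK`, and `WORLD_SIZE`, and `main` reads `LOCAL_RANK` to pick its GPU. The cell below hard-codes these for one process. With several processes, torchrun-style launchers assign global ranks as node_rank * nproc_per_node + local_rank; in Steps 5 and 6, TorchDistributor launches the processes and sets these variables for you. The hypothetical helper below sketches that convention.

```python
def torchrun_style_env(node_rank, nproc_per_node, nnodes, local_rank,
                       master_addr="localhost", master_port=29500):
    """Environment variables one process needs for env:// initialization (sketch)."""
    return {
        "RANK": str(node_rank * nproc_per_node + local_rank),  # global rank
        "LOCAL_RANK": str(local_rank),                          # GPU index on this node
        "WORLD_SIZE": str(nnodes * nproc_per_node),             # total processes
        "MASTER_ADDR": master_addr,
        "MASTER_PORT": str(master_port),
    }


# Single node, single GPU: the same values the next cell sets by hand
print(torchrun_style_env(node_rank=0, nproc_per_node=1, nnodes=1, local_rank=0))
```

Applying the result with `os.environ.update(...)` would be equivalent to the five assignments in the next cell.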

os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"

args = Args()
main(args)

Step 5. Single Node + Multi GPU Training

This step uses TorchDistributor to run training on a single g4dn.12xlarge instance with 4 T4 GPUs. You can view the sharding plan to see which tables are placed on which GPUs. Each epoch takes about 14 minutes.

Note: You may encounter unexpected errors, such as the Python kernel crashing or segmentation faults. These errors are transient, and the easiest way to avoid them is to skip the single node, single GPU training cell (Step 4) before running any distributed code (single node multi GPU or multi node multi GPU).

Note: If you see any logs that are associated with Mosaic Data Loading, these are transient errors that can be overcome by simply rerunning the failed cell.

args = Args()
TorchDistributor(num_processes=4, local_mode=True, use_gpu=True).run(main, args)
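TorchRec's planner picks a sharding type for each table (for example table-wise, row-wise, or column-wise) from cost and memory estimates, so the printed plan is not necessarily table-wise. As a much simpler mental model of what a table-wise plan means, the hypothetical greedy_table_wise_plan below places whole tables, largest first, on whichever GPU currently holds the fewest bytes. This is an illustration only, not how TorchRec's EmbeddingShardingPlanner works, and it assumes fp32 weights.

```python
import heapq


def greedy_table_wise_plan(tables, num_gpus, bytes_per_param=4):
    """Assign whole embedding tables to GPUs, largest first, onto the least-loaded GPU.

    tables maps a table name to (num_embeddings, embedding_dim).
    """
    # Min-heap of (bytes assigned so far, gpu index)
    loads = [(0, gpu) for gpu in range(num_gpus)]
    heapq.heapify(loads)
    plan = {}
    # Place the biggest tables first so they spread across devices
    for name, (num_embeddings, dim) in sorted(
        tables.items(), key=lambda kv: kv[1][0] * kv[1][1], reverse=True
    ):
        load, gpu = heapq.heappop(loads)
        plan[name] = gpu
        heapq.heappush(loads, (load + num_embeddings * dim * bytes_per_param, gpu))
    return plan
```

On four made-up tables of 1,000, 500, 400, and 100 rows (embedding_dim=8) and two GPUs, this puts the largest table alone on GPU 0 and the other three on GPU 1, leaving both GPUs with 32,000 bytes.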

Step 6. Multi Node + Multi GPU Training

This step was tested with one g4dn.12xlarge worker instance (4 T4 GPUs). You can view the printed sharding plan to see which embedding tables are placed on which GPUs. Each epoch takes ~10 minutes.

Note: You might occasionally see unexpected errors, such as the Python kernel crashing or segmentation faults. These errors are transient, and the easiest way to avoid them is to skip the single node single GPU training cell (Step 4) before you run any distributed code (single node multi GPU or multi node multi GPU).

Note: If you see errors in the logs associated with Mosaic data loading, these are transient and can be resolved by rerunning the failed cell.

args = Args()
TorchDistributor(num_processes=4, local_mode=False, use_gpu=True).run(main, args)
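With local_mode=False, TorchDistributor runs the training processes on the worker nodes rather than on the driver, so each process gets its own RANK, LOCAL_RANK, and WORLD_SIZE instead of the hand-set values from Step 4. The sketch below assumes the conventional torchrun-style numbering, in which each node owns a contiguous block of global ranks; rank_layout is a hypothetical helper, not part of the TorchDistributor API.

```python
def rank_layout(num_nodes, procs_per_node):
    """Map each (node, local rank) pair to the global RANK used by torch.distributed."""
    world_size = num_nodes * procs_per_node
    layout = []
    for node_rank in range(num_nodes):
        for local_rank in range(procs_per_node):
            # torchrun-style numbering: node k owns ranks k*procs_per_node .. (k+1)*procs_per_node - 1
            layout.append(
                {
                    "NODE_RANK": node_rank,
                    "LOCAL_RANK": local_rank,
                    "RANK": node_rank * procs_per_node + local_rank,
                    "WORLD_SIZE": world_size,
                }
            )
    return layout
```

For the configuration tested here (one 4-GPU worker, num_processes=4), rank_layout(1, 4) gives ranks 0 through 3 on a single node. Under the same convention, two 4-GPU workers would give ranks 4 through 7 to the second node, and rank_layout(1, 1) reproduces the single-process values from Step 4.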

Step 7. Inference

Because the DLRM model's state_dicts are logged to MLflow, you can use the following code to load any of the saved state_dicts and rebuild the corresponding DLRM model from it. You can extend this further by 1) logging the loaded model to MLflow for inference or 2) running batch inference with a UDF.

Note: The saving and loading code here loads the entire DLRM model on a single node and is meant as an example. In real-world use cases, the model can be very large, because the embedding tables scale with the number of users, products, or items. For such models, consider distributed inference.
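To get a feel for how quickly the embedding tables grow, you can estimate their size as rows x embedding dimension x bytes per parameter. The hypothetical helpers below assume fp32 weights (4 bytes per parameter) and ignore optimizer state and any other runtime overhead, so treat the result as a lower bound.

```python
def embedding_table_bytes(emb_counts, embedding_dim, bytes_per_param=4):
    """Estimate embedding table memory: sum of rows x dim x bytes (fp32 by default)."""
    return sum(num_rows * embedding_dim * bytes_per_param for num_rows in emb_counts)


def to_gib(num_bytes):
    """Convert bytes to GiB (2**30 bytes)."""
    return num_bytes / 2**30
```

With made-up counts of 10 million users, 1 million products, and 50,000 categories at embedding_dim=128, embedding_table_bytes([10_000_000, 1_000_000, 50_000], 128) returns 5,657,600,000 bytes (about 5.3 GiB), roughly a third of a T4's 16 GB of memory for the tables alone.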

7.1. Creating the DLRM model from saved state_dict

Note: By default, the code below loads the latest run of the current experiment. To load a different model, update the run_id and the path to the MLflow artifact.

import ast

import mlflow
from mlflow import MlflowClient

def get_latest_run_id(experiment):
    # Most recently started run in the given experiment
    latest_run = mlflow.search_runs(experiment_ids=[experiment.experiment_id], order_by=["start_time desc"], max_results=1).iloc[0]
    return latest_run.run_id

def get_latest_artifact_path(run_id):
    # Last top-level artifact path of the run, skipping any path that contains "base"
    client = MlflowClient()
    artifact_paths = [i.path for i in client.list_artifacts(run_id) if "base" not in i.path]
    return artifact_paths[-1]

def get_mlflow_model(run_id, artifact_path="model_state_dict"):
    device = torch.device("cuda")
    run = mlflow.get_run(run_id)

    # Hyperparameters were logged as strings; parse them back into Python literals
    # (ast.literal_eval only accepts literals, unlike eval)
    params = run.data.params
    cat_cols = ast.literal_eval(params.get('cat_cols'))
    emb_counts = ast.literal_eval(params.get('emb_counts'))
    dense_cols = ast.literal_eval(params.get('dense_cols'))
    embedding_dim = int(params.get('embedding_dim'))
    dense_arch_layer_sizes = ast.literal_eval(params.get('dense_arch_layer_sizes'))
    over_arch_layer_sizes = ast.literal_eval(params.get('over_arch_layer_sizes'))

    # Download state_dict.pth to the driver and load it
    MlflowClient().download_artifacts(run_id, f"{artifact_path}/state_dict.pth", "/databricks/driver")
    state_dict = mlflow.pytorch.load_state_dict(f"/databricks/driver/{artifact_path}")

    # Remove the prefix "model." from all of the keys in the state_dict
    state_dict = {k.removeprefix("model."): v for k, v in state_dict.items()}

    eb_configs = [
        EmbeddingBagConfig(
            name=f"t_{feature_name}",
            embedding_dim=embedding_dim,
            num_embeddings=emb_counts[feature_idx],
            feature_names=[feature_name],
        )
        for feature_idx, feature_name in enumerate(cat_cols)
    ]

    dlrm_model = DLRM(
        embedding_bag_collection=EmbeddingBagCollection(
            tables=eb_configs, device=device
        ),
        dense_in_features=len(dense_cols),
        dense_arch_layer_sizes=dense_arch_layer_sizes,
        over_arch_layer_sizes=over_arch_layer_sizes,
        dense_device=device,
    )

    dlrm_model.load_state_dict(state_dict)

    return dlrm_model, dense_cols, cat_cols, emb_counts

# Loading the latest model state dict from the latest run of the current experiment
latest_run_id = get_latest_run_id(experiment)
latest_artifact_path = get_latest_artifact_path(latest_run_id)
dlrm_model, dense_cols, cat_cols, emb_counts = get_mlflow_model(latest_run_id, artifact_path=latest_artifact_path)

7.2. Helper Function to Transform Dataloader to DLRM Inputs

The inputs that DLRM expects are dense_features and sparse_features, and so this section reuses aspects of the code from Section 3.4.2. The code shown here is verbose for clarity.

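def transform_test(batch, dense_cols, cat_cols, emb_counts):
    # Dense features: turn each numeric column into a (batch_size, 1) tensor, then concatenate
    dense_list = []
    for col_name in dense_cols:
        val = torch.tensor(batch[col_name], dtype=torch.float32)
        dense_list.append(val.unsqueeze(1))
    dense_features = torch.cat(dense_list, dim=1)

    # Sparse features: flatten values feature by feature, with one length (0 or 1) per row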
    kjt_values: List[int] = []
    kjt_lengths: List[int] = []
    for col_idx, col_name in enumerate(cat_cols):
        values = batch[col_name]
        for value in values:
            if value:
                kjt_values.append(
                    value % emb_counts[col_idx]
                )
                kjt_lengths.append(1)
            else:
                kjt_lengths.append(0)

    sparse_features = KeyedJaggedTensor.from_lengths_sync(
        cat_cols,
        torch.tensor(kjt_values),
        torch.tensor(kjt_lengths, dtype=torch.int32),
    )
    return dense_features, sparse_features
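KeyedJaggedTensor.from_lengths_sync takes the feature keys, one flat values tensor, and one lengths entry per (feature, row) pair, ordered feature by feature. The hypothetical build_kjt_lists below mirrors the sparse half of transform_test with plain Python lists so the layout is easy to inspect, and values_per_key shows how the lengths let you recover each feature's values. Like transform_test, it treats any falsy value, including a legitimate id of 0, as missing.

```python
def build_kjt_lists(batch, cat_cols, emb_counts):
    """Pure-Python mirror of the values/lengths that transform_test passes to the KJT."""
    values, lengths = [], []
    # Key-major order: all rows for the first feature, then all rows for the next, ...
    for col_idx, col_name in enumerate(cat_cols):
        for value in batch[col_name]:
            if value:  # same truthiness test as transform_test (0 and None count as missing)
                values.append(value % emb_counts[col_idx])  # fold ids into the table size
                lengths.append(1)
            else:
                lengths.append(0)
    return values, lengths


def values_per_key(values, lengths, cat_cols):
    """Split the flat values back into per-feature lists using the lengths."""
    batch_size = len(lengths) // len(cat_cols)
    out, offset = {}, 0
    for key_idx, key in enumerate(cat_cols):
        key_lengths = lengths[key_idx * batch_size:(key_idx + 1) * batch_size]
        count = sum(key_lengths)
        out[key] = values[offset:offset + count]
        offset += count
    return out
```

On the made-up batch {"c0": [3, 0, 7], "c1": [12, 5, None]} with emb_counts [5, 10], build_kjt_lists produces values [3, 2, 2, 5] and lengths [1, 0, 1, 1, 1, 0]: the 0 in c0 and the None in c1 both become zero-length (missing) entries.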

7.3. Getting the Data

num_batches = 5  # Number of batches to run inference on and print
batch_size = 1  # One row per batch, so each prediction is printed individually

# TODO: Update this path to point to the test dataset stored in UC Volumes
test_data_path = "/Volumes/ml/recommender_systems/dlrm_sample_data/mds_test"
test_dataloader = iter(get_dataloader_with_mosaic(test_data_path, batch_size, "test"))

7.4. Running Tests

In this example, training ran for 3 epochs, which gave reasonable results. Training for more epochs would likely improve performance further.

# Move the model to the GPU and switch to evaluation mode once, outside the loop
device = torch.device("cuda:0")
dlrm_model.to(device)
dlrm_model.eval()

# Disable gradient tracking during inference
with torch.no_grad():
    for _ in range(num_batches):
        next_batch = next(test_dataloader)
        expected_result = next_batch["label"][0]

        dense_features, sparse_features = transform_test(next_batch, dense_cols, cat_cols, emb_counts)
        dense_features = dense_features.to(device)
        sparse_features = sparse_features.to(device)

        # The model outputs logits; sigmoid turns them into click probabilities
        actual_result = torch.sigmoid(dlrm_model(dense_features=dense_features, sparse_features=sparse_features))
        print(f"Expected Result: {expected_result}; Actual Result: {round(actual_result[0][0].item())}")
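The loop above turns the model output into a probability with torch.sigmoid and then calls Python's built-in round. Because round uses round-half-to-even, a probability of exactly 0.5 becomes 0. If you want an explicit decision threshold, or a single accuracy number across many rows instead of printed pairs, the hypothetical pure-Python helpers below sketch one way to do it; the 0.5 default threshold is an assumption you can tune.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def predict_labels(logits, threshold=0.5):
    """Turn raw DLRM logits into 0/1 predictions using an explicit threshold."""
    return [1 if sigmoid(z) >= threshold else 0 for z in logits]


def accuracy(predictions, labels):
    """Fraction of predictions that match the labels."""
    return sum(int(p == y) for p, y in zip(predictions, labels)) / len(labels)
```

For logits [2.0, -1.0, 0.0], predict_labels returns [1, 0, 1], whereas round(sigmoid(0.0)) returns 0 for the tied last row.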

Step 8. Model Serving

For information about how to serve the model, see the Databricks Model Serving documentation (AWS | Azure).
