%md # DLRM (using TorchRec + TorchDistributor + StreamingDataset)

This notebook illustrates how to create a distributed DLRM recommendation model for predicting click-through rates. This notebook was tested on `g4dn.12xlarge` instances (one instance as the driver, one instance as the worker) on the Databricks Runtime for ML 14.3 LTS. It uses some code from the [Facebook DLRM implementation](https://github.com/facebookresearch/dlrm) (which has an [MIT License](https://github.com/facebookresearch/dlrm/blob/main/LICENSE)).

For more insight into the DLRM recommendation model, see the following resources:
- Facebook Research's repository: https://github.com/facebookresearch/dlrm
- Nvidia's repository: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Recommendation/DLRM/README.md
- Video overview: https://www.youtube.com/watch?v=r9J3UZmddC4

**Note: Where you see `# TODO` in this notebook, you may need to enter custom code to ensure that the notebook runs successfully.**

## Requirements

This notebook requires Databricks Runtime 14.3 LTS ML.
%md ## Step 1. Saving Data in UC Volumes in the MDS (Mosaic Data Shard) format

This notebook creates a synthetic dataset with 100k rows that will be used to train a DLRM model.
%pip install -q mosaicml-streaming==0.7.5

dbutils.library.restartPython()
%md ### 1.1. Creating a Synthetic Dataset

This notebook creates a synthetic dataset for predicting a binary label given both dense (numerical) and sparse (categorical) data. This synthetic dataset has a similar layout to other publicly available datasets, such as the [Criteo](https://ailab.criteo.com/download-criteo-1tb-click-logs-dataset/) click logs dataset. You can update this notebook to support those datasets as long as the data preprocessing is done correctly.

For a tangible example in retail, the numerical columns could represent features like the user's age, product's weight, or time of day, and the sparse columns could represent features like user's location, product's category, and so on. The `label` column describes the interaction between the user and the product. For example, a positive `label` of 1 might indicate that the user would buy the product, while a negative `label` of 0 might indicate that the user would give the product a 1-star rating.
import pandas as pd
import numpy as np
import random

# Set random seed for reproducibility
random.seed(42)
np.random.seed(42)

# Parameters for dataset generation
num_samples = 100000
num_int_features = 10
num_cat_features = 10

# Define the ranges for each integer feature
int_feature_ranges = [(0, random.randint(501, 10000)) for _ in range(num_int_features)]

# Generate unique category sizes for each categorical feature
cat_feature_categories = [np.random.randint(1, 201) for _ in range(num_cat_features)]

# Generate the label column (0 or 1 depending on whether the interaction is negative or positive)
labels = np.random.randint(0, 2, size=num_samples)

# Generate integer features within the specified ranges
int_features = np.column_stack([
    np.random.randint(low, high, size=num_samples)
    for (low, high) in int_feature_ranges
])

# Generate categorical features with different numbers of unique categories
cat_features = np.column_stack([
    np.random.randint(0, num_categories, size=num_samples)
    for num_categories in cat_feature_categories
])

# Combine all features into a DataFrame
data = np.column_stack((labels, int_features, cat_features))
columns = ['label'] + [f'int_{i}' for i in range(1, num_int_features+1)] + [f'cat_{i}' for i in range(1, num_cat_features+1)]
df = pd.DataFrame(data, columns=columns)

# Convert to a Spark DataFrame and write the dataset to a UC table
spark_df = spark.createDataFrame(df)
spark_df.write.saveAsTable("ml.recommender_systems.dlrm_sample_dataset")
# TODO: Update this with the path in UC where this data is saved
df = spark.table("ml.recommender_systems.dlrm_sample_dataset")
display(df)
%md ### 1.2. Preprocessing the Data

If you are using a dataset other than the provided synthetic dataset, update this cell to perform preprocessing and data cleaning as needed. For this synthetic dataset, all that is required is to normalize the dense columns.

**Note:** You can repartition the dataset as needed to help improve performance for this cell.
from pyspark.ml.feature import StandardScaler, StringIndexer, VectorAssembler
from pyspark.ml import Pipeline
from pyspark.sql.functions import udf, col
from pyspark.sql.types import DoubleType, IntegerType
from pyspark.ml.linalg import VectorUDT

# UDF to extract the first element of a vector
extract_value = udf(lambda vector: float(vector[0]), DoubleType())

# List of dense features
dense_cols = [c for c in df.columns if 'int' in c]

# Normalize the dense columns
stages = []
for column in dense_cols:
    assembler = VectorAssembler(inputCols=[column], outputCol=column + "_vec")
    scaler = StandardScaler(inputCol=column + "_vec", outputCol=column + "_scaled", withStd=True, withMean=True)
    stages += [assembler, scaler]

# Define and fit the pipeline
pipeline = Pipeline(stages=stages)
model = pipeline.fit(df)

# Transform the dataframe
df_transformed = model.transform(df)

# Use the UDF to overwrite each dense column with the single value extracted from its scaled vector
for column in dense_cols:
    df_transformed = df_transformed.withColumn(column, extract_value(col(column + "_scaled")))

# Drop the intermediate vector columns
for column in dense_cols:
    df_transformed = df_transformed.drop(column + "_vec").drop(column + "_scaled")

# Display the transformed dataset
display(df_transformed)
from pyspark.sql.functions import col

# Identify dense and categorical columns from the Spark DataFrame
dense_cols = [c for c in df_transformed.columns if 'int' in c]
cat_cols = [c for c in df_transformed.columns if 'cat' in c]

# Calculate the number of unique values (distinct count) for each categorical column
emb_counts = [df_transformed.select(c).distinct().count() for c in cat_cols]

# Print the results
print(f"The number of rows in df is {df_transformed.count()}")
print(f"There are {len(dense_cols)} dense columns: {dense_cols}")
print(f"There are {len(cat_cols)} categorical columns, where the max embedding count is {max(emb_counts)}:")

# Create a dictionary to print the categorical column names alongside their unique counts
emb_count_dict = {cat_col: emb_count for cat_col, emb_count in zip(cat_cols, emb_counts)}
print(emb_count_dict)
# Split the dataframe into train, validation, and test sets
train_df, validation_df, test_df = df_transformed.randomSplit([0.7, 0.2, 0.1], seed=42)

# Show the count of each split to verify the distribution
print(f"Training Dataset Count: {train_df.count()}")
print(f"Validation Dataset Count: {validation_df.count()}")
print(f"Test Dataset Count: {test_df.count()}")
%md ### 1.3. Saving to MDS Format within UC Volumes

In this step, you convert the data to MDS to allow for efficient train/validation/test splitting and then save it to a UC Volume.

View the Mosaic Streaming guide for more details:
1. General details: https://docs.mosaicml.com/projects/streaming/en/stable/
2. Main concepts: https://docs.mosaicml.com/projects/streaming/en/stable/getting_started/main_concepts.html#dataset-conversion
3. `dataframeToMDS` details: https://docs.mosaicml.com/projects/streaming/en/stable/preparing_datasets/spark_dataframe_to_mds.html
from streaming import StreamingDataset
from streaming.base.converters import dataframe_to_mds
from streaming.base import MDSWriter
from shutil import rmtree
import os
from tqdm import tqdm

# Parameters required for saving data in MDS format
dense_dict = {key: 'float64' for key in dense_cols}
cat_dict = {key: 'int64' for key in cat_cols}
label_dict = {'label': 'int64'}
columns = {**label_dict, **dense_dict, **cat_dict}
compression = 'zstd:7'

# TODO: Specify where the data will be stored
output_dir_train = "/Volumes/ml/recommender_systems/dlrm_sample_data/mds_train"
output_dir_validation = "/Volumes/ml/recommender_systems/dlrm_sample_data/mds_validation"
output_dir_test = "/Volumes/ml/recommender_systems/dlrm_sample_data/mds_test"

# Save each split using the `dataframe_to_mds` function, which divides the dataframe into `num_workers` parts
# and merges the `index.json` from each part into one in a parent directory.
def save_data(df, output_path, label, num_workers=20):
    print(f"Saving {label} data to: {output_path}")
    mds_kwargs = {'out': output_path, 'columns': columns, 'compression': compression}
    dataframe_to_mds(df.repartition(num_workers), merge_index=True, mds_kwargs=mds_kwargs)

save_data(train_df, output_dir_train, 'train')
save_data(validation_df, output_dir_validation, 'validation')
save_data(test_df, output_dir_test, 'test')
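%md As an optional sanity check (a sketch, not part of the original flow), you can read a record back from the freshly written train split to confirm the MDS write succeeded. `StreamingDataset` supports local-only reads with random access by index.

# Sketch: read a single sample back from the train split written above.
check_ds = StreamingDataset(local=output_dir_train, shuffle=False, batch_size=1)
print(check_ds[0]['label'], check_ds[0]['int_1'])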
%md ## Step 2. Helper Functions for Recommendation Dataloading

In this step, you install the necessary libraries, add imports, and add some relevant helper functions to train the model.
%md ### 2.1. Installs and Imports
%pip install -q --upgrade --no-deps --force-reinstall torch==2.2.2 torchvision==0.17.2 torchrec==0.6.0 fbgemm-gpu==0.6.0 --index-url https://download.pytorch.org/whl/cu118
%pip install torchmetrics==1.0.3 iopath==0.1.10 pyre_extensions==0.0.32 mosaicml-streaming==0.7.5

dbutils.library.restartPython()
import contextlib
import csv
import dataclasses
import itertools
import os
import random
import sys
import tempfile
from collections import defaultdict
from dataclasses import asdict, dataclass, field
from functools import partial
from typing import Generator, Iterable, List, Optional, cast

import mlflow
import torch
import torch.distributed as dist
import torchmetrics as metrics
from pyre_extensions import none_throws
from pyspark.ml.torch.distributor import TorchDistributor
from streaming import StreamingDataset, StreamingDataLoader
from streaming.base import MDSWriter
from torch import nn
from torch.distributed._sharded_tensor import ShardedTensor
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from torchrec import EmbeddingBagCollection
from torchrec.datasets.utils import Batch
from torchrec.distributed import TrainPipelineSparseDist
from torchrec.distributed.comm import get_local_size
from torchrec.distributed.model_parallel import (
    DistributedModelParallel,
    get_default_sharders,
)
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.storage_reservations import (
    HeuristicalStorageReservation,
)
from torchrec.models.dlrm import DLRM, DLRM_DCN, DLRM_Projection, DLRMTrain
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
from torchrec.optim.optimizers import in_backward_optimizer_filter
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from tqdm import tqdm
%md ### 2.2. Helper functions for Converting to Pipelineable DataType

Using TorchRec pipelines requires a pipelineable data type (which is `Batch` in this case). In this step, you create a helper function that takes each batch from the StreamingDataset and passes it through a transformation function to convert it into a pipelineable batch.

For further context, see:
1. https://github.com/pytorch/torchrec/blob/29f503a8855040bc49788d3ad24e7ab93d944885/torchrec/datasets/utils.py#L28
# TODO: These values come from earlier outputs (section 1.2, cell 2); if another dataset is being used, these values need to be updated
cat_data = {'cat_1': 103, 'cat_2': 180, 'cat_3': 93, 'cat_4': 15, 'cat_5': 107, 'cat_6': 72, 'cat_7': 189, 'cat_8': 21, 'cat_9': 103, 'cat_10': 122}

dense_cols = [f"int_{i}" for i in range(1, 11)]
cat_cols = [f"cat_{i}" for i in range(1, 11)]
emb_counts = [cat_data[k] for k in cat_cols]

def transform_to_torchrec_batch(batch, num_embeddings_per_feature: Optional[List[int]] = None) -> Batch:
    # Stack the dense columns into a single (batch_size, num_dense) float tensor
    cat_list = []
    for col_name in dense_cols:
        val = torch.tensor(batch[col_name], dtype=torch.float32)
        cat_list.append(val.unsqueeze(0).T)
    dense_features = torch.cat(cat_list, dim=1)

    # Build the sparse features as a KeyedJaggedTensor; a length of 0 encodes a missing value
    kjt_values: List[int] = []
    kjt_lengths: List[int] = []
    for col_idx, col_name in enumerate(cat_cols):
        values = batch[col_name]
        for value in values:
            if value:
                kjt_values.append(value % num_embeddings_per_feature[col_idx])
                kjt_lengths.append(1)
            else:
                kjt_lengths.append(0)
    sparse_features = KeyedJaggedTensor.from_lengths_sync(
        cat_cols,
        torch.tensor(kjt_values),
        torch.tensor(kjt_lengths, dtype=torch.int32),
    )

    labels = torch.tensor(batch["label"], dtype=torch.int32)
    assert isinstance(labels, torch.Tensor)

    return Batch(
        dense_features=dense_features,
        sparse_features=sparse_features,
        labels=labels,
    )

transform_partial = partial(transform_to_torchrec_batch, num_embeddings_per_feature=emb_counts)
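%md To make the conversion concrete, the following self-contained sketch (illustrative values only) hand-builds the kind of `KeyedJaggedTensor` that `transform_to_torchrec_batch` produces for the sparse features.

# Illustrative sketch: a KeyedJaggedTensor for two categorical features over
# three rows. `lengths` is laid out key-major, one entry per (key, row) pair;
# a length of 0 marks a missing value, exactly as in the transform above.
example_kjt = KeyedJaggedTensor.from_lengths_sync(
    ["cat_1", "cat_2"],
    torch.tensor([11, 5, 42, 7, 3]),                      # concatenated ids for all keys
    torch.tensor([1, 1, 1, 1, 0, 1], dtype=torch.int32),  # cat_1: rows 0-2, then cat_2: rows 0-2
)
print(example_kjt["cat_2"].values())  # tensor([7, 3]) -- row 1 of cat_2 is missing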
%md ### 2.3. Helper Function for DataLoading using Mosaic's StreamingDataset

This utilizes Mosaic's StreamingDataset and Mosaic's StreamingDataLoader for efficient data loading. For more information, view this [documentation](https://docs.mosaicml.com/projects/streaming/en/stable/distributed_training/fast_resumption.html#saving-and-loading-state).
def get_dataloader_with_mosaic(path, batch_size, label):
    print(f"Getting {label} data from UC Volumes")
    dataset = StreamingDataset(local=path, shuffle=True, batch_size=batch_size)
    return StreamingDataLoader(dataset, batch_size=batch_size)
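%md The linked documentation covers fast mid-epoch resumption: `StreamingDataLoader` can checkpoint its position. The following sketch (not required for this notebook's training runs; the path is the train split written in Step 1.3) shows the save/restore pattern. In practice you would restore the state after a restart rather than in the same session, since Streaming keeps per-dataset shared-memory state.

# A minimal fast-resumption sketch: capture the loader's position after a few
# batches, then restore it later to continue from the same point.
train_mds_path = "/Volumes/ml/recommender_systems/dlrm_sample_data/mds_train"
loader = get_dataloader_with_mosaic(train_mds_path, batch_size=512, label="train")
for i, batch in enumerate(loader):
    if i == 10:
        loader_state = loader.state_dict()  # records position and shuffle state
        break

# Later (for example, after an interruption): build an identical loader and
# restore the saved state so iteration resumes from batch 11.
resumed = get_dataloader_with_mosaic(train_mds_path, batch_size=512, label="train")
resumed.load_state_dict(loader_state)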
%md ## Step 3. Creating the Relevant TorchRec code for Training

This contains all of the training and evaluation code.
%md ### 3.1. Base Dataclass for Training inputs

Feel free to modify any of the variables mentioned here, but note that the final layer for `dense_arch_layer_sizes` should be equivalent to `embedding_dim`.
# TODO: Update these values for training as needed
@dataclass
class Args:
    """
    Training arguments.
    """
    epochs: int = 3  # Train for three epochs
    embedding_dim: int = 128  # Embedding dimension is 128
    dense_arch_layer_sizes: list = dataclasses.field(default_factory=lambda: [512, 256, 128])  # The last layer for dense should be equivalent to `embedding_dim`
    over_arch_layer_sizes: list = dataclasses.field(default_factory=lambda: [512, 512, 256, 1])  # This is essentially binary classification, so the last layer is 1
    learning_rate: float = 0.03
    eps: float = 1e-8
    batch_size: int = 512
    print_sharding_plan: bool = True
    print_lr: bool = False  # Optional, prints the learning rate at each iteration step
    lr_warmup_steps: int = 0  # Optional, sets the learning rate warmup steps
    lr_decay_start: int = 0  # Optional, sets the decay start for learning rate
    lr_decay_steps: int = 0  # Optional, sets the decay steps for learning rate
    validation_freq: Optional[int] = None  # Optional, determines how often during training you want to run validation (# of training steps)
    limit_train_batches: Optional[int] = None  # Optional, limits the number of training batches
    limit_val_batches: Optional[int] = None  # Optional, limits the number of validation batches
    limit_test_batches: Optional[int] = None  # Optional, limits the number of test batches

# Used to store the training configuration in MLflow
def get_relevant_fields(args, dense_cols, cat_cols, emb_counts):
    fields_to_save = ["epochs", "embedding_dim", "dense_arch_layer_sizes", "over_arch_layer_sizes", "learning_rate", "eps", "batch_size"]
    result = {key: getattr(args, key) for key in fields_to_save}
    # Also record the column metadata
    result["dense_cols"] = dense_cols
    result["cat_cols"] = cat_cols
    result["emb_counts"] = emb_counts
    return result
@dataclass
class TrainValTestResults:
    """
    A dataclass to store the results.
    """
    val_aurocs: List[float] = field(default_factory=list)
    test_auroc: Optional[float] = None
%md ### 3.2. LR Scheduler

This isn't used unless you want to schedule the learning rate for the Adagrad optimizer.
class LRPolicyScheduler(_LRScheduler):
    def __init__(self, optimizer, num_warmup_steps, decay_start_step, num_decay_steps):
        self.num_warmup_steps = num_warmup_steps
        self.decay_start_step = decay_start_step
        self.decay_end_step = decay_start_step + num_decay_steps
        self.num_decay_steps = num_decay_steps
        if self.decay_start_step < self.num_warmup_steps:
            sys.exit("Learning rate warmup must finish before the decay starts")
        super(LRPolicyScheduler, self).__init__(optimizer)

    def get_lr(self):
        step_count = self._step_count
        if step_count < self.num_warmup_steps:
            # Warmup: scale the learning rate up linearly
            scale = 1.0 - (self.num_warmup_steps - step_count) / self.num_warmup_steps
            lr = [base_lr * scale for base_lr in self.base_lrs]
            self.last_lr = lr
        elif self.decay_start_step <= step_count and step_count < self.decay_end_step:
            # Decay: scale the learning rate down quadratically
            decayed_steps = step_count - self.decay_start_step
            scale = ((self.num_decay_steps - decayed_steps) / self.num_decay_steps) ** 2
            min_lr = 0.0000001
            lr = [max(min_lr, base_lr * scale) for base_lr in self.base_lrs]
            self.last_lr = lr
        else:
            if self.num_decay_steps > 0:
                # Freeze at the last value, either because decay has finished
                # or because we are between warmup and decay
                lr = self.last_lr
            else:
                # Do not adjust
                lr = self.base_lrs
        return lr
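%md A hypothetical usage sketch for `LRPolicyScheduler` (step counts are illustrative, not tuned): warm up linearly over the first 1,000 steps, hold the base rate, then decay quadratically over 5,000 steps starting at step 10,000. In this notebook, the equivalent wiring happens inside `main()` using `args.lr_warmup_steps`, `args.lr_decay_start`, and `args.lr_decay_steps`.

# Illustrative only: attach the scheduler to a throwaway Adagrad optimizer.
demo_opt = torch.optim.Adagrad(nn.Linear(4, 1).parameters(), lr=0.03)
demo_sched = LRPolicyScheduler(demo_opt, num_warmup_steps=1000, decay_start_step=10000, num_decay_steps=5000)
demo_sched.step()  # call once per training step; get_lr() applies the policy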
%md ### 3.3. Training and Evaluation Helper Functions
def batched(it, n):
    # Lazily yield chunks of size n from iterator `it` without materializing them
    assert n >= 1
    for x in it:
        yield itertools.chain((x,), itertools.islice(it, n - 1))
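%md A quick illustration of `batched` with toy values: it lazily groups an iterator into chunks of size `n`, which is how the training loop below slices the dataloader into validation-frequency-sized segments. Each chunk must be consumed before the next one is requested, since all chunks share the same underlying iterator.

# Toy example: 7 items in chunks of 3
print([list(chunk) for chunk in batched(iter(range(7)), 3)])  # [[0, 1, 2], [3, 4, 5], [6]]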
%md ### 3.4.1. Helper Functions for Distributed Model Saving
# DLRM and TorchRec use special tensors called ShardedTensors, so we localize them
# and put them on the same rank that is saving to MLflow
def gather_and_get_state_dict(model):
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    state_dict = model.state_dict()
    gathered_state_dict = {}

    # Iterate over all items in the state_dict
    for fqn, tensor in state_dict.items():
        if isinstance(tensor, ShardedTensor):
            # Collect all shards of the tensor across ranks
            full_tensor = None
            if rank == 0:
                full_tensor = torch.zeros(tensor.size()).to(tensor.device)
            tensor.gather(0, full_tensor)
            if rank == 0:
                gathered_state_dict[fqn] = full_tensor
        else:
            # Directly add non-sharded tensors to the new state_dict
            if rank == 0:
                gathered_state_dict[fqn] = tensor

    return gathered_state_dict

def log_state_dict_to_mlflow(model, artifact_path) -> None:
    # All ranks participate in gathering
    state_dict = gather_and_get_state_dict(model)
    # Only rank 0 logs the state dictionary
    if dist.get_rank() == 0 and state_dict:
        mlflow.pytorch.log_state_dict(state_dict, artifact_path=artifact_path)
%md ### 3.4.2. Helper Functions for Distributed Model Training and Evaluation
def train(
    pipeline: TrainPipelineSparseDist,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    epoch: int,
    lr_scheduler,
    print_lr: bool,
    validation_freq: Optional[int],
    limit_train_batches: Optional[int],
    limit_val_batches: Optional[int]) -> None:
    """
    Trains the model for 1 epoch. Helper function for train_val_test.

    Args:
        pipeline (TrainPipelineSparseDist): data pipeline.
        train_dataloader (DataLoader): Training set's dataloader.
        val_dataloader (DataLoader): Validation set's dataloader.
        epoch (int): The number of complete passes through the training set so far.
        lr_scheduler (LRPolicyScheduler): Learning rate scheduler.
        print_lr (bool): Whether to print the learning rate every training step.
        validation_freq (Optional[int]): The number of training steps between validation runs within an epoch.
        limit_train_batches (Optional[int]): Limits the training set to the first `limit_train_batches` batches.
        limit_val_batches (Optional[int]): Limits the validation set to the first `limit_val_batches` batches.

    Returns:
        None.
    """
    pipeline._model.train()

    # Get the first `limit_train_batches` batches
    iterator = itertools.islice(iter(train_dataloader), limit_train_batches)

    # Only print out the progress bar on rank 0
    is_rank_zero = dist.get_rank() == 0
    if is_rank_zero:
        pbar = tqdm(
            iter(int, 1),
            desc=f"Epoch {epoch}",
            total=len(train_dataloader),
            disable=False,
        )

    # TorchRec's pipeline paradigm is unique in that it takes in an iterator of batches for training.
    start_it = 0
    n = validation_freq if validation_freq else len(train_dataloader)
    for batched_iterator in batched(iterator, n):
        for it in itertools.count(start_it):
            try:
                if is_rank_zero and print_lr:
                    for i, g in enumerate(pipeline._optimizer.param_groups):
                        print(f"lr: {it} {i} {g['lr']:.6f}")
                pipeline.progress(map(transform_partial, batched_iterator))
                lr_scheduler.step()
                if is_rank_zero:
                    pbar.update(1)
            except StopIteration:
                if is_rank_zero:
                    print("Total number of iterations:", it)
                start_it = it
                break

        # If validation is run during the epoch, use the evaluation function
        if validation_freq and start_it % validation_freq == 0:
            evaluate(limit_val_batches, pipeline, val_dataloader, "val")
            pipeline._model.train()
def evaluate(
    limit_batches: Optional[int],
    pipeline: TrainPipelineSparseDist,
    eval_dataloader: DataLoader,
    stage: str) -> float:
    """
    Evaluates the model. Computes and prints AUROC. Helper function for train_val_test.

    Args:
        limit_batches (Optional[int]): Limits the dataloader to the first `limit_batches` batches.
        pipeline (TrainPipelineSparseDist): data pipeline.
        eval_dataloader (DataLoader): Dataloader for either the validation set or the test set.
        stage (str): "val" or "test".

    Returns:
        float: AUROC result.
    """
    pipeline._model.eval()
    device = pipeline._device

    iterator = itertools.islice(iter(eval_dataloader), limit_batches)

    # AUROC is the evaluation metric
    auroc = metrics.AUROC(task="binary").to(device)

    is_rank_zero = dist.get_rank() == 0
    if is_rank_zero:
        pbar = tqdm(
            iter(int, 1),
            desc=f"Evaluating {stage} set",
            total=len(eval_dataloader),
            disable=False,
        )
    with torch.no_grad():
        while True:
            try:
                _loss, logits, labels = pipeline.progress(map(transform_partial, iterator))
                preds = torch.sigmoid(logits)
                auroc(preds, labels)
                if is_rank_zero:
                    pbar.update(1)
            except StopIteration:
                break

    auroc_result = auroc.compute().item()
    num_samples = torch.tensor(sum(map(len, auroc.target)), device=device)
    dist.reduce(num_samples, 0, op=dist.ReduceOp.SUM)

    if is_rank_zero:
        print(f"AUROC over {stage} set: {auroc_result}.")
        print(f"Number of {stage} samples: {num_samples}")
    return auroc_result
def train_val_test(args, model, optimizer, device, train_dataloader, val_dataloader, test_dataloader, lr_scheduler) -> TrainValTestResults:
    """
    Train/validation/test loop.

    Args:
        args (Args): the arguments for the model defined earlier.
        model (torch.nn.Module): model to train.
        optimizer (torch.optim.Optimizer): optimizer to use.
        device (torch.device): device to use.
        train_dataloader (DataLoader): Training set's dataloader.
        val_dataloader (DataLoader): Validation set's dataloader.
        test_dataloader (DataLoader): Test set's dataloader.
        lr_scheduler (LRPolicyScheduler): Learning rate scheduler.

    Returns:
        TrainValTestResults.
    """
    results = TrainValTestResults()
    pipeline = TrainPipelineSparseDist(model, optimizer, device, execute_all_batches=True)

    # Get the baseline AUROC and save it to MLflow
    val_auroc = evaluate(args.limit_val_batches, pipeline, val_dataloader, "val")
    results.val_aurocs.append(val_auroc)
    if int(os.environ["RANK"]) == 0:
        mlflow.log_metric('val_auroc', val_auroc)

    # Run the training loop
    for epoch in range(args.epochs):
        train(
            pipeline,
            train_dataloader,
            val_dataloader,
            epoch,
            lr_scheduler,
            args.print_lr,
            args.validation_freq,
            args.limit_train_batches,
            args.limit_val_batches,
        )

        # Evaluate after each training epoch
        val_auroc = evaluate(args.limit_val_batches, pipeline, val_dataloader, "val")
        results.val_aurocs.append(val_auroc)
        if int(os.environ["RANK"]) == 0:
            mlflow.log_metric('val_auroc', val_auroc)

        # Save the underlying model and results to MLflow
        log_state_dict_to_mlflow(pipeline._model.module, artifact_path=f"model_state_dict_{epoch}")

    # Evaluate on the test set after the training loop finishes
    test_auroc = evaluate(args.limit_test_batches, pipeline, test_dataloader, "test")
    results.test_auroc = test_auroc
    if int(os.environ["RANK"]) == 0:
        mlflow.log_metric('test_auroc', test_auroc)

    return results
%md ### 3.5. Setting up MLflow

**Note:** You must update the route for `db_host` to the URL of your Databricks workspace.
username = spark.sql("SELECT current_user()").first()['current_user()']
username

experiment_path = f'/Users/{username}/torchrec-dlrm-example'

# TODO: Update the `db_host` with the URL for your Databricks workspace
db_host = "https://workspace-name.cloud.databricks.com/"
db_token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()

# Manually create the experiment so that you know the id and can send it to the worker nodes when you scale
experiment = mlflow.set_experiment(experiment_path)
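%md As an alternative to hard-coding `db_host`, the notebook context that already supplies the API token can usually supply the workspace URL as well. The `apiUrl()` accessor below is an assumption about the context API; verify it in your workspace before relying on it.

# Sketch (assumed context accessor -- verify in your environment): derive the
# host from the same notebook context used for db_token above.
db_host = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get()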
%md ### 3.6. The Main Function

This function will train the DLRM recommendation model. For more information, view the following guides/docs/code:
- https://pytorch.org/torchrec/
- https://github.com/pytorch/torchrec
- https://github.com/facebookresearch/dlrm/blob/main/torchrec_dlrm/dlrm_main.py
# TODO: Specify where the data is stored in UC Volumes
input_dir_train = "/Volumes/ml/recommender_systems/dlrm_sample_data/mds_train"
input_dir_validation = "/Volumes/ml/recommender_systems/dlrm_sample_data/mds_validation"
input_dir_test = "/Volumes/ml/recommender_systems/dlrm_sample_data/mds_test"

def main(args):
    import torch
    import mlflow
    from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict

    # Some preliminary torch setup
    torch.jit._state.disable()
    global_rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device(f"cuda:{local_rank}")
    backend = "nccl"
    torch.cuda.set_device(device)

    # Start MLflow
    os.environ['DATABRICKS_HOST'] = db_host
    os.environ['DATABRICKS_TOKEN'] = db_token
    experiment = mlflow.set_experiment(experiment_path)

    # Save parameters to MLflow
    if global_rank == 0:
        param_dict = get_relevant_fields(args, dense_cols, cat_cols, emb_counts)
        mlflow.log_params(param_dict)

    # Start the distributed process group
    dist.init_process_group(backend=backend)

    # Load the data
    train_dataloader = get_dataloader_with_mosaic(input_dir_train, args.batch_size, "train")
    val_dataloader = get_dataloader_with_mosaic(input_dir_validation, args.batch_size, "val")
    test_dataloader = get_dataloader_with_mosaic(input_dir_test, args.batch_size, "test")

    # Create the embedding tables
    eb_configs = [
        EmbeddingBagConfig(
            name=f"t_{feature_name}",
            embedding_dim=args.embedding_dim,
            num_embeddings=emb_counts[feature_idx],
            feature_names=[feature_name],
        )
        for feature_idx, feature_name in enumerate(cat_cols)
    ]

    # Create the DLRM model
    dlrm_model = DLRM(
        embedding_bag_collection=EmbeddingBagCollection(
            tables=eb_configs, device=torch.device("meta")
        ),
        dense_in_features=len(dense_cols),
        dense_arch_layer_sizes=args.dense_arch_layer_sizes,
        over_arch_layer_sizes=args.over_arch_layer_sizes,
        dense_device=device,
    )
    train_model = DLRMTrain(dlrm_model)

    # Set up the optimizer
    embedding_optimizer = torch.optim.Adagrad
    optimizer_kwargs = {"lr": args.learning_rate, "eps": args.eps}
    apply_optimizer_in_backward(
        embedding_optimizer,
        train_model.model.sparse_arch.parameters(),
        optimizer_kwargs,
    )

    # Create a plan to shard the embedding tables across the GPUs and create a distributed model
    planner = EmbeddingShardingPlanner(
        topology=Topology(
            local_world_size=get_local_size(),
            world_size=dist.get_world_size(),
            compute_device=device.type,
        ),
        batch_size=args.batch_size,
        # If you experience OOM, increase the percentage. See
        # https://pytorch.org/torchrec/torchrec.distributed.planner.html#torchrec.distributed.planner.storage_reservations.HeuristicalStorageReservation
        storage_reservation=HeuristicalStorageReservation(percentage=0.05),
    )
    plan = planner.collective_plan(
        train_model, get_default_sharders(), dist.GroupMember.WORLD
    )
    model = DistributedModelParallel(
        module=train_model,
        device=device,
        plan=plan,
    )
    log_state_dict_to_mlflow(model.module, artifact_path="model_state_dict_base")

    # Print out the sharding information to see how the embedding tables are sharded across the GPUs
    if global_rank == 0 and args.print_sharding_plan:
        for collectionkey, plans in model._plan.plan.items():
            print(collectionkey)
            for table_name, plan in plans.items():
                print(table_name, "\n", plan, "\n")

    # Create the optimizer for the distributed model
    def optimizer_with_params():
        return lambda params: torch.optim.Adagrad(params, lr=args.learning_rate, eps=args.eps)

    dense_optimizer = KeyedOptimizerWrapper(
        dict(in_backward_optimizer_filter(model.named_parameters())),
        optimizer_with_params(),
    )
    optimizer = CombinedOptimizer([model.fused_optimizer, dense_optimizer])
    lr_scheduler = LRPolicyScheduler(
        optimizer, args.lr_warmup_steps, args.lr_decay_start, args.lr_decay_steps
    )

    # Start the training loop
    results = train_val_test(
        args,
        model,
        optimizer,
        device,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        lr_scheduler,
    )

    # Destroy the process group
    dist.destroy_process_group()
%md ## Step 4. Single Node + Single GPU Training

Here, you set the environment variables to run training over the sample set of 100,000 data points (stored in Volumes in Unity Catalog and collected using Mosaic StreamingDataset). You can expect each epoch to take ~40 minutes.
os.environ["RANK"] = "0" os.environ["LOCAL_RANK"] = "0" os.environ["WORLD_SIZE"] = "1" os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "29500" args = Args() main(args)
%md ## Step 5. Single Node + Multi GPU Training

This notebook uses the TorchDistributor for handling training on a `g4dn.12xlarge` instance with 4 T4 GPUs. You can view the sharding plan to see which tables are located on which GPUs. This takes ~14 minutes to run per epoch.

**Note**: There may be cases where you receive unexpected errors (like the Python kernel crashing or segmentation faults). This is a transient error, and the easiest way to overcome it is to skip the single node single GPU training code before you run any distributed code (single node multi GPU or multi node multi GPU).

**Note**: If you see any errors associated with Mosaic data loading, these are transient and can be overcome by simply rerunning the failed cell.
args = Args()
TorchDistributor(num_processes=4, local_mode=True, use_gpu=True).run(main, args)
%md ## Step 6. Multi Node + Multi GPU Training

This is tested with a `g4dn.12xlarge` instance as a worker (with 4 T4 GPUs). You can view the sharding plan to see which tables are located on which GPUs. This takes ~10 minutes to run per epoch.

**Note**: There may be cases where you receive unexpected errors (like the Python kernel crashing or segmentation faults). This is a transient error, and the easiest way to overcome it is to skip the single node single GPU training code before you run any distributed code (single node multi GPU or multi node multi GPU).

**Note**: If you see any errors associated with Mosaic data loading, these are transient and can be overcome by simply rerunning the failed cell.
args = Args()
TorchDistributor(num_processes=4, local_mode=False, use_gpu=True).run(main, args)
%md ## Step 7. Inference

Since the DLRM model's `state_dict`s are logged to MLflow, you can use the following code to load any of the saved `state_dict`s and create the associated DLRM model with it. You can further expand this by 1) saving the loaded model to MLflow for inference or 2) doing batch inference using a UDF.

**Note:** The saving and loading code shown here loads the entire DLRM model on one node and is useful as an example. In real-world use cases, the expected model size could be significant (as the embedding tables can scale with the number of users or the number of products and items). It might be worthwhile to consider distributed inference.
%md ### 7.1. Creating the DLRM model from saved `state_dict`

**Note:** You must update this with the correct `run_id` and path to the MLflow artifact.
from mlflow import MlflowClient

def get_latest_run_id(experiment):
    latest_run = mlflow.search_runs(experiment_ids=[experiment.experiment_id], order_by=["start_time desc"], max_results=1).iloc[0]
    return latest_run.run_id

def get_latest_artifact_path(run_id):
    client = MlflowClient()
    run = client.get_run(run_id)
    artifact_uri = run.info.artifact_uri
    artifact_paths = [i.path for i in client.list_artifacts(run_id) if "base" not in i.path]
    return artifact_paths[-1]

def get_mlflow_model(run_id, artifact_path="model_state_dict"):
    from mlflow import MlflowClient

    device = torch.device("cuda")
    run = mlflow.get_run(run_id)

    cat_cols = eval(run.data.params.get('cat_cols'))
    emb_counts = eval(run.data.params.get('emb_counts'))
    dense_cols = eval(run.data.params.get('dense_cols'))
    embedding_dim = int(run.data.params.get('embedding_dim'))
    dense_arch_layer_sizes = eval(run.data.params.get('dense_arch_layer_sizes'))
    over_arch_layer_sizes = eval(run.data.params.get('over_arch_layer_sizes'))

    MlflowClient().download_artifacts(run_id, f"{artifact_path}/state_dict.pth", "/databricks/driver")
    state_dict = mlflow.pytorch.load_state_dict(f"/databricks/driver/{artifact_path}")

    # Remove the prefix "model." from all of the keys in the state_dict
    state_dict = {k[6:]: v for k, v in state_dict.items()}

    eb_configs = [
        EmbeddingBagConfig(
            name=f"t_{feature_name}",
            embedding_dim=embedding_dim,
            num_embeddings=emb_counts[feature_idx],
            feature_names=[feature_name],
        )
        for feature_idx, feature_name in enumerate(cat_cols)
    ]

    dlrm_model = DLRM(
        embedding_bag_collection=EmbeddingBagCollection(
            tables=eb_configs, device=device
        ),
        dense_in_features=len(dense_cols),
        dense_arch_layer_sizes=dense_arch_layer_sizes,
        over_arch_layer_sizes=over_arch_layer_sizes,
        dense_device=device,
    )
    dlrm_model.load_state_dict(state_dict)
    return dlrm_model, dense_cols, cat_cols, emb_counts

# Load the latest model state dict from the latest run of the current experiment
latest_run_id = get_latest_run_id(experiment)
latest_artifact_path = get_latest_artifact_path(latest_run_id)
dlrm_model, dense_cols, cat_cols, emb_counts = get_mlflow_model(latest_run_id, artifact_path=latest_artifact_path)
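%md To follow up on option 1) from the Step 7 introduction (saving the loaded model to MLflow for inference), a minimal sketch might look like the following. The artifact path name here is arbitrary and not used elsewhere in this notebook.

# Sketch: log the reconstructed model as an MLflow PyTorch model so it can be
# registered or served later.
with mlflow.start_run():
    mlflow.pytorch.log_model(dlrm_model, artifact_path="reconstructed_dlrm")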
%md ### 7.2. Helper Function to Transform Dataloader to DLRM Inputs

The inputs that DLRM expects are `dense_features` and `sparse_features`, so this section reuses aspects of the transformation code from Section 2.2. The code shown here is verbose for clarity.
def transform_test(batch, dense_cols, cat_cols, emb_counts):
    # Stack the dense columns into a single (batch_size, num_dense) float tensor
    cat_list = []
    for col_name in dense_cols:
        val = torch.tensor(batch[col_name], dtype=torch.float32)
        cat_list.append(val.unsqueeze(0).T)
    dense_features = torch.cat(cat_list, dim=1)

    # Build the sparse features as a KeyedJaggedTensor; a length of 0 encodes a missing value
    kjt_values: List[int] = []
    kjt_lengths: List[int] = []
    for col_idx, col_name in enumerate(cat_cols):
        values = batch[col_name]
        for value in values:
            if value:
                kjt_values.append(value % emb_counts[col_idx])
                kjt_lengths.append(1)
            else:
                kjt_lengths.append(0)
    sparse_features = KeyedJaggedTensor.from_lengths_sync(
        cat_cols,
        torch.tensor(kjt_values),
        torch.tensor(kjt_lengths, dtype=torch.int32),
    )
    return dense_features, sparse_features
%md ### 7.3. Getting the Data
num_batches = 5  # Number of batches to print out at a time
batch_size = 1   # Print out each individual row

# TODO: Update this path to point to the test dataset stored in UC Volumes
test_data_path = "/Volumes/ml/recommender_systems/dlrm_sample_data/mds_test"
test_dataloader = iter(get_dataloader_with_mosaic(test_data_path, batch_size, "test"))
%md ### 7.4. Running Tests

In this example, you ran training for 3 epochs. The results were reasonable, and running a larger number of epochs would likely lead to better performance.
# Move the model to the GPU and switch to eval mode once, outside the loop
device = torch.device("cuda:0")
dlrm_model.to(device)
dlrm_model.eval()

for _ in range(num_batches):
    next_batch = next(test_dataloader)
    expected_result = next_batch["label"][0]

    dense_features, sparse_features = transform_test(next_batch, dense_cols, cat_cols, emb_counts)
    dense_features = dense_features.to(device)
    sparse_features = sparse_features.to(device)

    actual_result = torch.sigmoid(dlrm_model(dense_features=dense_features, sparse_features=sparse_features))
    print(f"Expected Result: {expected_result}; Actual Result: {round(actual_result[0][0].item())}")
%md ## Step 8. Model Serving

For information about how to serve the model, see the Databricks Model Serving documentation ([AWS](https://docs.databricks.com/en/machine-learning/model-serving/index.html) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/machine-learning/model-serving/)).