DLRM (using TorchRec + TorchDistributor + StreamingDataset)
This notebook illustrates how to create a distributed DLRM recommendation model for predicting click-through rates. This notebook was tested on g4dn.12xlarge instances (one instance as the driver, one instance as the worker) on Databricks Runtime 14.3 LTS for Machine Learning. It uses some code from the Facebook DLRM implementation (which is available under an MIT License). For more insight into the DLRM recommendation model, see the following resources:
- Facebook Research's repository: https://github.com/facebookresearch/dlrm
- Nvidia's repository: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Recommendation/DLRM/README.md
- Video overview: https://www.youtube.com/watch?v=r9J3UZmddC4
Note: Where you see `# TODO` in this notebook, you may need to enter custom code to ensure that the notebook runs successfully.
Requirements
This notebook requires Databricks Runtime 14.3 LTS ML.
Step 1. Saving Data in UC Volumes in the MDS (Mosaic Data Shard) format
This notebook creates a synthetic dataset with 100k rows that will be used to train a DLRM model.
1.1. Creating a Synthetic Dataset
This notebook creates a synthetic dataset for predicting a binary label given both dense (numerical) and sparse (categorical) data. This synthetic dataset has a similar layout to other publicly available datasets, such as the Criteo click logs dataset. You can update this notebook to support those datasets as long as the data preprocessing is done correctly.
For a tangible example in retail, the numerical columns could represent features like the user's age, the product's weight, or the time of day, and the sparse columns could represent features like the user's location, the product's category, and so on. The `label` column describes the interaction between the user and the product. For example, a positive `label` of 1 might indicate that the user would buy the product, while a negative `label` of 0 might indicate that the user would give the product a 1-star rating.
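To make this concrete, here is a minimal sketch of how such a dataset could be generated with PySpark. The column names, counts, and cardinalities are illustrative assumptions, not the notebook's exact schema.

```python
# Illustrative sketch: generate 100k rows of dense, sparse, and label columns.
# Column names, counts, and cardinalities are assumptions.
import numpy as np
import pandas as pd

NUM_ROWS = 100_000
NUM_DENSE = 13             # assumption: Criteo-style layout
NUM_SPARSE = 26            # assumption
SPARSE_CARDINALITY = 1000  # assumption: distinct values per sparse column

rng = np.random.default_rng(42)
data = {f"dense_{i}": rng.random(NUM_ROWS).astype(np.float32) for i in range(NUM_DENSE)}
data.update({f"sparse_{i}": rng.integers(0, SPARSE_CARDINALITY, NUM_ROWS) for i in range(NUM_SPARSE)})
data["label"] = rng.integers(0, 2, NUM_ROWS)

# `spark` is the session provided by the Databricks notebook environment.
df = spark.createDataFrame(pd.DataFrame(data))
```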
1.2. Preprocessing the Data
If you are using a dataset other than the provided synthetic dataset, update this cell to perform preprocessing and data cleaning as needed. For this synthetic dataset, all that is required is to normalize the dense columns.
Note: You can repartition the dataset as needed to help improve performance for this cell.
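As a sketch of what this preprocessing might look like for the synthetic dataset (assuming the `dense_*` column naming above), min-max normalization of the dense columns could be done as follows:

```python
# Sketch: min-max normalize each dense column. Adapt for other datasets.
from pyspark.sql import functions as F

dense_cols = [c for c in df.columns if c.startswith("dense_")]  # assumed naming
stats = df.agg(
    *[F.min(c).alias(f"{c}_min") for c in dense_cols],
    *[F.max(c).alias(f"{c}_max") for c in dense_cols],
).first()

for c in dense_cols:
    lo, hi = stats[f"{c}_min"], stats[f"{c}_max"]
    df = df.withColumn(c, (F.col(c) - F.lit(lo)) / F.lit(max(hi - lo, 1e-9)))

# Optional: repartition to improve parallelism for this cell and the next.
df = df.repartition(64)  # assumption: tune to your cluster size
```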
1.3. Saving to MDS Format within UC Volumes
In this step, you convert the data to MDS to allow for efficient train/validation/test splitting and then save it to a UC Volume.
See the Mosaic Streaming documentation for more details:
- General details: https://docs.mosaicml.com/projects/streaming/en/stable/
- Main concepts: https://docs.mosaicml.com/projects/streaming/en/stable/getting_started/main_concepts.html#dataset-conversion
- `dataframeToMDS` details: https://docs.mosaicml.com/projects/streaming/en/stable/preparing_datasets/spark_dataframe_to_mds.html
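A hedged sketch of the conversion follows; the exact function name and signature may differ between `mosaicml-streaming` versions, so check the `dataframeToMDS` docs linked above. The output path and column encodings are assumptions.

```python
# Sketch: convert the Spark DataFrame to MDS and write it to a UC Volume.
from streaming.base.converters import dataframe_to_mds

out_path = "/Volumes/<catalog>/<schema>/<volume>/dlrm_mds"  # TODO: fill in

mds_kwargs = {
    "out": out_path,
    # Assumed encodings: float32 for dense columns, int64 for sparse + label.
    "columns": {
        col: "float32" if col.startswith("dense_") else "int64"
        for col in df.columns
    },
}
dataframe_to_mds(df, merge_index=True, mds_kwargs=mds_kwargs)
```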
Step 2. Helper Functions for Recommendation Dataloading
In this step, you install the necessary libraries, add imports, and define the helper functions used to train the model.
2.1. Installs and Imports
2.2. Helper functions for Converting to Pipelineable DataType
Using TorchRec pipelines requires a pipelineable data type (which is `Batch` in this case). In this step, you create a helper function that takes each batch from the StreamingDataset and passes it through a transformation function to convert it into a pipelineable batch.
For further context, see the TorchRec repository: https://github.com/pytorch/torchrec
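A minimal sketch of such a helper is shown below, assuming the `dense_*`/`sparse_*` column naming from Step 1 and single-valued (one-hot) sparse features.

```python
# Sketch: convert a raw StreamingDataLoader batch (a dict of tensors) into
# TorchRec's pipelineable `Batch`. Column names are assumptions.
import torch
from torchrec.datasets.utils import Batch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

DENSE_COLS = [f"dense_{i}" for i in range(13)]    # assumption
SPARSE_COLS = [f"sparse_{i}" for i in range(26)]  # assumption

def batch_to_pipelineable(raw: dict) -> Batch:
    dense = torch.stack([raw[c].float() for c in DENSE_COLS], dim=1)
    # Each sparse feature contributes exactly one id per example here, so
    # every length is 1; multi-hot features would need real lengths.
    values = torch.cat([raw[c].long() for c in SPARSE_COLS])
    lengths = torch.ones(values.numel(), dtype=torch.int32)
    sparse = KeyedJaggedTensor.from_lengths_sync(
        keys=SPARSE_COLS, values=values, lengths=lengths
    )
    return Batch(dense_features=dense, sparse_features=sparse,
                 labels=raw["label"].long())
```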
2.3. Helper Function for DataLoading using Mosaic's StreamingDataset
This utilizes Mosaic's StreamingDataset and Mosaic's StreamingDataLoader for efficient data loading. For more information, see the Mosaic Streaming documentation: https://docs.mosaicml.com/projects/streaming/en/stable/
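A sketch of such a dataloading helper, with paths and batch size as assumptions:

```python
# Sketch: build a StreamingDataset/StreamingDataLoader over the MDS data.
from streaming import StreamingDataLoader, StreamingDataset

def get_dataloader(remote_path: str, local_cache: str, batch_size: int = 1024):
    dataset = StreamingDataset(
        remote=remote_path,  # e.g., the UC Volume path written in Step 1.3
        local=local_cache,   # node-local cache directory
        shuffle=True,
        batch_size=batch_size,
    )
    return StreamingDataLoader(dataset, batch_size=batch_size)
```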
Step 3. Creating the Relevant TorchRec code for Training
This contains all of the training and evaluation code.
3.1. Base Dataclass for Training inputs
Feel free to modify any of the variables mentioned here, but note that the final layer size in `dense_arch_layer_sizes` must equal `embedding_dim`.
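For illustration, a dataclass of this shape might look like the following; the field names and default values are assumptions, but note that the last entry of `dense_arch_layer_sizes` equals `embedding_dim`.

```python
# Sketch: a training-args dataclass. Fields and defaults are assumptions.
from dataclasses import dataclass, field
from typing import List

@dataclass
class TrainArgs:
    epochs: int = 3
    batch_size: int = 1024
    learning_rate: float = 0.005
    embedding_dim: int = 128
    num_embeddings_per_feature: List[int] = field(
        default_factory=lambda: [1000] * 26)
    # Final entry must equal embedding_dim (128 here).
    dense_arch_layer_sizes: List[int] = field(
        default_factory=lambda: [512, 256, 128])
    over_arch_layer_sizes: List[int] = field(
        default_factory=lambda: [512, 512, 256, 1])
```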
3.2. LR Scheduler
This scheduler is not used unless you specifically want to schedule the learning rate for the Adagrad optimizer.
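If you do want one, a simple warmup-then-decay schedule can be attached to Adagrad with a standard PyTorch `LambdaLR`; the schedule shape below is an assumption, not the notebook's exact policy.

```python
# Sketch: optional warmup-then-decay LR schedule for Adagrad.
import torch

def make_scheduler(optimizer: torch.optim.Adagrad, warmup_steps: int = 1000):
    def lr_lambda(step: int) -> float:
        # Linear warmup to the base LR, then inverse-sqrt decay.
        if step < warmup_steps:
            return (step + 1) / warmup_steps
        return (warmup_steps / (step + 1)) ** 0.5
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
```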
3.3. Training and Evaluation Helper Functions
3.4.1. Helper Functions for Distributed Model Saving
3.4.2. Helper Functions for Distributed Model Training and Evaluation
3.5. Setting up MLflow
Note: You must update `db_host` to the URL of your Databricks workspace.
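A sketch of the setup, with the experiment path as an assumption; the token lookup uses the standard Databricks notebook-context pattern:

```python
# Sketch: point MLflow (and any spawned workers) at this workspace.
import os
import mlflow

db_host = "https://<your-workspace>.cloud.databricks.com"  # TODO: update
db_token = (dbutils.notebook.entry_point.getDbutils()
            .notebook().getContext().apiToken().get())

# Exported so training processes launched by TorchDistributor can log runs.
os.environ["DATABRICKS_HOST"] = db_host
os.environ["DATABRICKS_TOKEN"] = db_token

mlflow.set_tracking_uri("databricks")
mlflow.set_experiment("/Users/<your-user>/dlrm-torchrec")  # assumption
```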
3.6. The Main Function
This function trains the DLRM recommendation model. For more information, see the resources listed at the top of this notebook and the TorchRec repository: https://github.com/pytorch/torchrec
Step 4. Single Node + Single GPU Training
Here, you set the environment variables to run training over the sample set of 100,000 data points (stored in Volumes in Unity Catalog and loaded using Mosaic's StreamingDataset). You can expect each epoch to take ~40 minutes.
Step 5. Single Node + Multi GPU Training
This notebook uses the TorchDistributor for handling training on a g4dn.12xlarge instance with 4 T4 GPUs. You can view the sharding plan to see which tables are located on which GPUs. This takes ~14 minutes to run per epoch.
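The launch itself looks roughly like the following, where `main` stands in for the notebook's training entry point (Section 3.6) and its arguments are placeholders:

```python
# Sketch: run training on all 4 local GPUs via TorchDistributor.
from pyspark.ml.torch.distributor import TorchDistributor

distributor = TorchDistributor(num_processes=4, local_mode=True, use_gpu=True)
result = distributor.run(main)  # pass your training args after `main`
```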
Note: There may be cases where you receive unexpected errors (like the Python kernel crashing or segmentation faults). This is a transient error, and the easiest way to overcome it is to skip the single node single GPU training code before you run any distributed code (single node multi GPU or multi node multi GPU).
Note: If you see error logs associated with Mosaic data loading, these indicate transient errors that can be overcome by simply rerunning the failed cell.
Step 6. Multi Node + Multi GPU Training
This is tested with a g4dn.12xlarge instance as a worker (with 4 T4 GPUs). You can view the sharding plan to see which tables are located on which GPUs. This takes ~10 minutes to run per epoch.
Note: There may be cases where you receive unexpected errors (like the Python kernel crashing or segmentation faults). This is a transient error, and the easiest way to overcome it is to skip the single node single GPU training code before you run any distributed code (single node multi GPU or multi node multi GPU).
Note: If you see error logs associated with Mosaic data loading, these indicate transient errors that can be overcome by simply rerunning the failed cell.
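The launch mirrors the single node multi GPU case, except `local_mode=False` places the processes on the workers; `num_processes` then counts GPUs across all workers (4 for one g4dn.12xlarge worker). `main` again stands in for the notebook's entry point.

```python
# Sketch: run training on the worker GPUs via TorchDistributor.
from pyspark.ml.torch.distributor import TorchDistributor

distributor = TorchDistributor(num_processes=4, local_mode=False, use_gpu=True)
result = distributor.run(main)  # pass your training args after `main`
```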
Step 7. Inference
Since the DLRM model's `state_dict`s are logged to MLflow, you can use the following code to load any of the saved `state_dict`s and create the associated DLRM model with it. You can further expand this by 1) saving the loaded model to MLflow for inference or 2) doing batch inference using a UDF.
Note: The saving and loading code shown here loads the entire DLRM model on one node and is useful as an example. In real-world use cases, the model size could be significant (as the embedding tables can scale with the number of users or the number of products and items), so it might be worthwhile to consider distributed inference.
7.1. Creating the DLRM model from a saved `state_dict`
Note: You must update this with the correct `run_id` and path to the MLflow artifact.
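A sketch of the load, where `run_id`, the artifact path, and `make_dlrm_model` (a hypothetical helper that rebuilds the DLRM architecture with the training-time hyperparameters) are placeholders:

```python
# Sketch: download a logged state_dict from MLflow and load it into a model.
import mlflow
import torch

run_id = "<run-id>"                     # TODO: update
artifact_path = "<path/to/state_dict>"  # TODO: update

local_path = mlflow.artifacts.download_artifacts(
    run_id=run_id, artifact_path=artifact_path)
state_dict = torch.load(local_path, map_location="cpu")

model = make_dlrm_model()  # hypothetical helper; must match training config
model.load_state_dict(state_dict)
model.eval()
```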
7.2. Helper Function to Transform Dataloader to DLRM Inputs
The inputs that DLRM expects are `dense_features` and `sparse_features`, and so this section reuses aspects of the code from Section 3.4.2. The code shown here is verbose for clarity.
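A compressed sketch of the idea (the notebook's actual code is more verbose), reusing the hypothetical `batch_to_pipelineable` helper sketched in Section 2.2:

```python
# Sketch: transform dataloader batches into the (dense_features,
# sparse_features) pair that DLRM's forward pass expects.
import torch

@torch.no_grad()
def dataloader_to_dlrm_inputs(dataloader):
    for raw in dataloader:
        batch = batch_to_pipelineable(raw)  # hypothetical helper (Section 2.2)
        yield batch.dense_features, batch.sparse_features, batch.labels
```

Each yielded pair can then be scored with `model(dense_features, sparse_features)`.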
7.3. Getting the Data
7.4. Running Tests
In this example, you ran training for 3 epochs and the results were reasonable. Training for a larger number of epochs would likely improve performance further.