Two Tower (using TorchRec + TorchDistributor + StreamingDataset)
This notebook illustrates how to create a distributed Two Tower recommendation model. It was tested on `g4dn.12xlarge` instances (one instance as the driver, one instance as the worker) using Databricks Runtime 14.3 LTS for ML. For more insight into the Two Tower recommendation model, you can view the following resources:
- Hopsworks definition: https://www.hopsworks.ai/dictionary/two-tower-embedding-model
- TorchRec's training implementation: https://github.com/pytorch/torchrec/blob/main/examples/retrieval/two_tower_train.py#L75
Note: Where you see `# TODO` in this notebook, you must enter custom code to ensure that the notebook runs successfully.
Requirements
This notebook requires Databricks Runtime 14.3 LTS for ML.
1. Saving "Learning From Sets of Items" Data in UC Volumes in the MDS (Mosaic Data Shard) format
This notebook uses the small sample of 100k ratings from "Learning From Sets of Items". In this section you preprocess it and save it to a Volume in Unity Catalog.
1.1. Downloading the Dataset
Download the dataset from https://grouplens.org/datasets/learning-from-sets-of-items-2019/ to `/databricks/driver` and then save the data to a UC table. The "Learning from Sets of Items" dataset is distributed under the Creative Commons 4.0 license.
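Below is a minimal sketch of this step. The archive URL, the CSV file name inside the zip, and the target catalog/schema/table names are assumptions for illustration; check the GroupLens page linked above for the actual download link.

```python
import urllib.request
import zipfile

# Assumption: the exact archive URL and file layout may differ -- verify on the GroupLens page.
url = "https://files.grouplens.org/datasets/learning-from-sets-2019/learning-from-sets-2019.zip"
local_zip = "/databricks/driver/learning-from-sets-2019.zip"
urllib.request.urlretrieve(url, local_zip)

with zipfile.ZipFile(local_zip, "r") as z:
    z.extractall("/databricks/driver/")

# Read the extracted ratings CSV with Spark and persist it as a UC table.
df = (
    spark.read.option("header", "true")
    .option("inferSchema", "true")
    .csv("file:/databricks/driver/learning-from-sets-2019/item_ratings.csv")  # file name is an assumption
)
df.write.mode("overwrite").saveAsTable("ml.recommender_system.learning_from_sets")  # TODO: use your catalog/schema
```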
1.2. Reading the Dataset from UC
The original dataset contains 500k data points. This example uses a sample of 100k data points from the dataset.
1.3. Preprocessing and Cleaning the Data
The first step is to convert each user's string hash to an integer using Spark's `StringIndexer`.

The Two Tower model provided by TorchRec requires a binary label, so the code in this section converts all ratings below the mean to `0` and all ratings above the mean to `1`. For your own use case, you can modify the training task described here to use `MSELoss` instead (which can scale to ratings from 0 to 5).
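A minimal sketch of both steps, assuming hypothetical column names (`userId`, `rating`) and the table name from Section 1.1:

```python
import pyspark.sql.functions as F
from pyspark.ml.feature import StringIndexer

df = spark.table("ml.recommender_system.learning_from_sets")  # TODO: your table

# Map each user's string hash to a contiguous integer ID.
indexer = StringIndexer(inputCol="userId", outputCol="user_id", handleInvalid="skip")
df = indexer.fit(df).transform(df).withColumn("user_id", F.col("user_id").cast("int"))

# Binarize ratings around the mean so they work with binary cross entropy.
mean_rating = df.select(F.mean("rating")).first()[0]
df = df.withColumn("label", (F.col("rating") > F.lit(mean_rating)).cast("int"))
```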
1.4. Saving to MDS Format within UC Volumes
In this step, you convert the data to MDS to allow for efficient train/validation/test splitting and then save it to a UC Volume.
For more details, view the Mosaic Streaming guides:
- General details: https://docs.mosaicml.com/projects/streaming/en/stable/
- Main concepts: https://docs.mosaicml.com/projects/streaming/en/stable/getting_started/main_concepts.html#dataset-conversion
- `dataframeToMDS` details: https://docs.mosaicml.com/projects/streaming/en/stable/preparing_datasets/spark_dataframe_to_mds.html
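A sketch of the conversion, assuming the preprocessed dataframe from Section 1.3 and a hypothetical Volume path; see the `dataframeToMDS` docs linked above for the full set of options:

```python
from streaming.base.converters import dataframeToMDS

out_root = "/Volumes/ml/recommender_system/learning_from_sets_data"  # TODO: your UC Volume
columns = {"user_id": "int32", "movie_id": "int32", "label": "int32"}  # assumed schema

# Split once, then write each split to its own MDS directory.
train_df, val_df, test_df = df.randomSplit([0.7, 0.2, 0.1], seed=42)
for name, split_df in [("train", train_df), ("validation", val_df), ("test", test_df)]:
    dataframeToMDS(
        split_df.repartition(8),
        merge_index=True,
        mds_kwargs={"out": f"{out_root}/{name}", "columns": columns},
    )
```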
2. Helper Functions for Recommendation Dataloading
In this section, you install the necessary libraries, add imports, and add some relevant helper functions to train the model.
2.1. Installs and Imports
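The exact packages and version pins depend on your runtime's torch version; the unpinned installs below are an assumption:

```python
%pip install torchrec mosaicml-streaming
dbutils.library.restartPython()
```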
2.2. Helper functions for Converting to Pipelineable DataType
Using TorchRec pipelines requires a pipelineable data type (which is `Batch` in this case). In this step, you create a helper function that takes each batch from the StreamingDataset and passes it through a transformation function to convert it into a pipelineable batch.
For further context, see https://github.com/pytorch/torchrec/blob/main/torchrec/datasets/utils.py#L28.
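A sketch of such a transform, assuming each example carries one `user_id`, one `movie_id`, and a `label` (the hashing into the tables' ranges mirrors the general pattern, not the notebook's exact code):

```python
import torch
from torchrec.datasets.utils import Batch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

cat_cols = ["user_id", "movie_id"]  # assumption: one id per feature per example

def transform_to_torchrec_batch(batch, num_embeddings_per_feature):
    kjt_values, kjt_lengths = [], []
    for col, num_embeddings in zip(cat_cols, num_embeddings_per_feature):
        values = batch[col] % num_embeddings  # keep ids inside the table's range
        kjt_values.append(values)
        kjt_lengths.append(torch.ones_like(values))  # exactly one value per example
    sparse_features = KeyedJaggedTensor.from_lengths_sync(
        keys=cat_cols,
        values=torch.cat(kjt_values),
        lengths=torch.cat(kjt_lengths),
    )
    # dense_features is a placeholder: the Two Tower model only uses sparse features.
    return Batch(
        dense_features=torch.zeros(1),
        sparse_features=sparse_features,
        labels=batch["label"].to(torch.float32),
    )
```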
2.3. Helper Function for DataLoading using Mosaic's StreamingDataset
This section uses Mosaic's StreamingDataset and StreamingDataLoader for efficient data loading. For more information, see the Mosaic Streaming documentation: https://docs.mosaicml.com/projects/streaming/en/stable/.
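A minimal sketch, assuming the MDS splits written in Section 1.4 are readable from the Volume path:

```python
from streaming import StreamingDataset, StreamingDataLoader

def get_dataloader(path: str, batch_size: int) -> StreamingDataLoader:
    # StreamingDataset handles shard-aware shuffling and resumption;
    # StreamingDataLoader adds a stateful loader on top of it.
    dataset = StreamingDataset(local=path, shuffle=True, batch_size=batch_size)
    return StreamingDataLoader(dataset, batch_size=batch_size)
```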
3. Creating the Relevant TorchRec code for Training
This section contains all of the training and evaluation code.
3.1. Two Tower Model Definition
This is taken directly from the TorchRec example linked above. Note that the loss is binary cross entropy, which requires labels to be within the values {0, 1}.
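A condensed sketch of that model and its training-task wrapper (names and structure follow the linked example; details are abbreviated here):

```python
import torch
import torch.nn as nn
from torchrec import EmbeddingBagCollection
from torchrec.modules.mlp import MLP
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

class TwoTower(nn.Module):
    def __init__(self, embedding_bag_collection: EmbeddingBagCollection, layer_sizes, device=None):
        super().__init__()
        tables = embedding_bag_collection.embedding_bag_configs()
        self.ebc = embedding_bag_collection
        self._feature_names_query = tables[0].feature_names
        self._feature_names_candidate = tables[1].feature_names
        self.query_proj = MLP(in_size=tables[0].embedding_dim, layer_sizes=layer_sizes, device=device)
        self.candidate_proj = MLP(in_size=tables[1].embedding_dim, layer_sizes=layer_sizes, device=device)

    def forward(self, kjt: KeyedJaggedTensor):
        pooled = self.ebc(kjt)  # KeyedTensor: one pooled embedding per feature
        query = self.query_proj(pooled[self._feature_names_query[0]])
        candidate = self.candidate_proj(pooled[self._feature_names_candidate[0]])
        return query, candidate

class TwoTowerTrainTask(nn.Module):
    def __init__(self, two_tower: TwoTower):
        super().__init__()
        self.two_tower = two_tower
        self.loss_fn = nn.BCEWithLogitsLoss()  # labels must be in {0, 1}

    def forward(self, batch):
        query, candidate = self.two_tower(batch.sparse_features)
        logits = (query * candidate).sum(dim=1)  # dot product as the match score
        loss = self.loss_fn(logits, batch.labels.float())
        return loss, (loss.detach(), logits.detach(), batch.labels.detach())
```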
3.2. Base Dataclass for Training inputs
Feel free to modify any of the variables mentioned here, but note that the first entry of `layer_sizes` should be equal to `embedding_dim`.
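An illustrative version of such a dataclass; the field names and default values are assumptions:

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class Args:
    epochs: int = 3
    embedding_dim: int = 128  # width of each embedding table row
    # The first entry must equal embedding_dim, since the MLP consumes the pooled embedding.
    layer_sizes: List[int] = field(default_factory=lambda: [128, 64])
    learning_rate: float = 0.01
    batch_size: int = 1024
```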
3.3. Training and Evaluation Helper Functions
3.3.1. Helper Functions for Distributed Model Saving
3.3.2. Helper Functions for Distributed Model Training and Evaluation
3.4. The Main Function
This function trains the Two Tower recommendation model. For more information, see the guides, docs, and code linked throughout this notebook.
3.5. Setting up MLflow
Note: You must update `db_host` to the URL of your Databricks workspace.
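A sketch of the setup, assuming the common pattern of exporting the host and token as environment variables so that worker processes can log to MLflow; the experiment path is a placeholder:

```python
import os
import mlflow

db_host = "https://<your-workspace-url>"  # TODO: update to your workspace URL
db_token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()

os.environ["DATABRICKS_HOST"] = db_host
os.environ["DATABRICKS_TOKEN"] = db_token

mlflow.set_experiment("/Users/<user>/torchrec-two-tower")  # TODO: your experiment path
```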
4. Single Node + Single GPU Training
Here, you set the environment variables needed to run training over the sample set of 100,000 data points (stored in Volumes in Unity Catalog and loaded using Mosaic's StreamingDataset). You can expect each epoch to take ~16 minutes.
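A sketch of the single-process setup, assuming `main` and `args` are the training entry point and config from Section 3:

```python
import os

# Minimal torch.distributed environment for a single process on one GPU.
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"

main(args)
```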
5. Single Node + Multi GPU Training
This notebook uses TorchDistributor to handle training on a `g4dn.12xlarge` instance with 4 T4 GPUs. You can view the sharding plan in the output logs to see which tables are placed on which GPUs. Each epoch takes ~8 minutes to run.
Note: There may be cases where you receive unexpected errors (like the Python kernel crashing or segmentation faults). These are transient errors, and the easiest way to overcome them is to skip the single node single GPU training code before you run any distributed code (single node multi GPU or multi node multi GPU).

Note: If you see any errors associated with Mosaic data loading, these are transient errors that can be overcome by rerunning the failed cell.
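A sketch of the TorchDistributor call, assuming `main` and `args` from Section 3; `local_mode=True` runs all 4 processes on the driver:

```python
from pyspark.ml.torch.distributor import TorchDistributor

distributor = TorchDistributor(num_processes=4, local_mode=True, use_gpu=True)
distributor.run(main, args)
```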
6. Multi Node + Multi GPU Training
This is tested with a `g4dn.12xlarge` instance as a worker (with 4 T4 GPUs). You can view the sharding plan in the output logs to see which tables are placed on which GPUs. Each epoch takes ~6 minutes to run.
Note: There may be cases where you receive unexpected errors (like the Python kernel crashing or segmentation faults). These are transient errors, and the easiest way to overcome them is to skip the single node single GPU training code before you run any distributed code (single node multi GPU or multi node multi GPU).

Note: If you see any errors associated with Mosaic data loading, these are transient errors that can be overcome by rerunning the failed cell.
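The call mirrors the single node case, except `local_mode=False` places the processes on the workers and `num_processes` counts GPUs across all workers (4 here, for one `g4dn.12xlarge` worker):

```python
from pyspark.ml.torch.distributor import TorchDistributor

distributor = TorchDistributor(num_processes=4, local_mode=False, use_gpu=True)
distributor.run(main, args)
```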
7. Inference
Because the Two Tower model's `state_dict`s are logged to MLflow, you can use the following code to load any of the saved `state_dict`s and create the associated Two Tower model from it. You can further expand this by 1) saving the loaded model to MLflow for inference or 2) doing batch inference using a UDF.
Note: The saving and loading code shown here loads the entire Two Tower model on one node and is useful as an example. In real-world use cases, the model could be significantly larger (embedding tables scale with the number of users and items), so it might be worthwhile to consider distributed inference.
7.1. Creating the Two Tower model from a saved `state_dict`
Note: You must update this with the correct `run_id` and path to the MLflow artifact.
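A sketch of the loading step; the `run_id`, artifact path, and the `ebc`/`args` objects (an unsharded `EmbeddingBagCollection` and the config from Section 3.2) are placeholders:

```python
import mlflow
import torch

run_id = "<run_id>"  # TODO: update
artifact_path = "model_state_dict/state_dict.pth"  # TODO: update to the logged artifact path

local_path = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=artifact_path)
state_dict = torch.load(local_path, map_location="cpu")

two_tower = TwoTower(embedding_bag_collection=ebc, layer_sizes=args.layer_sizes)
two_tower.load_state_dict(state_dict)
two_tower.eval()
```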
7.2. Helper Function to Transform Dataloader to Two Tower Inputs
The Two Tower model expects `sparse_features` as input, so this section reuses aspects of the code from Section 3.3.2. The code shown here is verbose for clarity.
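A compact sketch of the inference loop, reusing the (hypothetical) `transform_to_torchrec_batch` helper from Section 2.2 and the `two_tower` model loaded in Section 7.1:

```python
import torch

with torch.no_grad():
    for raw_batch in dataloader:  # a StreamingDataLoader from Section 2.3
        batch = transform_to_torchrec_batch(raw_batch, num_embeddings_per_feature)
        query_emb, candidate_emb = two_tower(batch.sparse_features)
        logits = (query_emb * candidate_emb).sum(dim=1)
        preds = torch.sigmoid(logits)  # probability of a high rating
```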
7.3. Getting the Data
7.4. Running Tests
In this example, you ran training for 3 epochs and the results were reasonable. Running more epochs would likely improve performance further.
8. Model Serving and Vector Search
For information about how to serve the model, see the Databricks Model Serving documentation (AWS | Azure).
Also, the Two Tower model is unique in that it generates a `query` embedding and a `candidate` embedding, which allows you to create a vector index of movies and then find the K movies that a user (given their generated vector) would most likely rate highly. For more information on creating your own FAISS index, see the TorchRec retrieval examples: https://github.com/pytorch/torchrec/tree/main/examples/retrieval. You can also take a similar approach with Databricks Vector Search (AWS | Azure).
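A minimal FAISS sketch, assuming `candidate_embs` (movie embeddings) and `user_emb` (one user's query embedding) are float32 NumPy arrays produced by the model above:

```python
import faiss
import numpy as np

d = candidate_embs.shape[1]   # embedding dimension
index = faiss.IndexFlatIP(d)  # inner product, matching the dot-product score used in training
index.add(candidate_embs)

k = 10
scores, movie_ids = index.search(user_emb.reshape(1, -1).astype(np.float32), k)
print(movie_ids[0])  # row indices of the top-K candidate movies for this user
```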