Simplify Data Loading from Spark to PyTorch using Mosaic Streaming
This notebook demonstrates the following workflow on Databricks:
- Load data using Spark.
- Use the `streaming.base.converters.dataframe_to_mds` API to save the dataset to a UC Volume.
- Use `StreamingDataset` and `StreamingDataLoader` to feed the data into a single-node PyTorch model for training.
- Use `StreamingDataset` and `StreamingDataLoader` to feed the data into a multi-node PyTorch model for training.
Requirements:
- Databricks Runtime for ML 15.2 or higher.
- Node configuration:
- One GPU driver.
- Two GPU workers.
0. Imports and Inputs
Please update the following cells with information on where data should be stored.
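A minimal sketch of what the input cell might look like. Every name and path here is a placeholder to be replaced with your own catalog, schema, and Volume:

```python
# Hypothetical configuration cell: update these values for your workspace.
CATALOG = "main"        # assumption: your Unity Catalog name
SCHEMA = "default"      # assumption: your schema name
VOLUME = "mds_flowers"  # assumption: a pre-created UC Volume

# Where the MDS-formatted dataset will be written.
MDS_OUTPUT_PATH = f"/Volumes/{CATALOG}/{SCHEMA}/{VOLUME}/mds-flowers"
# Node-local cache directory used by StreamingDataset when reading back.
LOCAL_CACHE_DIR = "/local_disk0/mds-cache"

NUM_EPOCHS = 3
BATCH_SIZE = 32
```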
1. Load data using Spark
This example uses the flowers dataset from the TensorFlow team, which contains flower photos stored under five subdirectories, one per class. The dataset is available at `/databricks-datasets/flowers`.
The example loads the flowers table, which contains the preprocessed flowers dataset, using the binary file data source. To reduce running time, this notebook uses a small subset of the flowers dataset, including ~90 training images and ~10 validation images. When you run this notebook, you can increase the number of images used for better model accuracy.
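The loading step can be sketched as follows. The function name, sample fraction, and seed are assumptions chosen to roughly match the ~90/~10 split described above; `spark` is the notebook's SparkSession:

```python
# Sketch: load the flower photos with Spark's binary file data source and
# split off small training/validation subsets.
def load_flowers(spark):
    from pyspark.sql.functions import col, regexp_extract

    df = (
        spark.read.format("binaryFile")
        .option("recursiveFileLookup", "true")
        .option("pathGlobFilter", "*.jpg")
        .load("/databricks-datasets/flowers/flower_photos")
    )

    # The class label is the name of the subdirectory holding each image.
    df = df.withColumn(
        "label", regexp_extract(col("path"), r"flower_photos/([^/]+)/", 1)
    )

    # Keep a small subset and split it ~90/10 into train/validation.
    sample_df = df.sample(fraction=0.03, seed=12)
    return sample_df.randomSplit([0.9, 0.1], seed=12)
```

Increasing `fraction` brings in more images at the cost of longer training time.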
2. Save data to a UC Volume using Mosaic Streaming
The compressed training/evaluation/testing data can be saved either on the local machine for transient usage or to a UC Volume for persistent storage. For any form of distributed training, saving to a UC Volume is required so that the data can be accessed by all of the Spark workers. The data is saved in the MDS (Mosaic Data Shard) format.
Various data types are supported, including `bytes`, `str`, `int`, `pkl`, `img`, and more; the full list can be found here. Furthermore, you can specify compression algorithms, hashing, shard size limits, and so on. View this page for more information.
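A sketch of the conversion step using `dataframe_to_mds`. The column names (`content`, `label`), shard count, and compression choice are assumptions based on the binary-file load above:

```python
# Sketch: write a Spark DataFrame of (content: bytes, label: str) rows to a
# UC Volume in MDS format.
def save_to_mds(df, output_path):
    from streaming.base.converters import dataframe_to_mds

    # Map each DataFrame column to an MDS-supported encoding.
    columns = {"content": "bytes", "label": "str"}

    dataframe_to_mds(
        df.repartition(4),      # controls the number of output shards (assumption)
        merge_index=True,       # write a single merged index.json for the dataset
        mds_kwargs={
            "out": output_path,
            "columns": columns,
            "compression": "zstd",  # optional shard compression
        },
    )
```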
3. Data Loading and Preprocessing Functions
The following functions prepare the training data for consumption by the underlying model. As you can see, the key function, `get_dataloader_with_mosaic`, takes only two significant lines to set up.
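A sketch of what `get_dataloader_with_mosaic` might look like, assuming the MDS data was written with `content` (JPEG bytes) and `label` (class name) columns as above. The two significant lines are the `StreamingDataset` and `StreamingDataLoader` constructors; the transform pipeline and class list are assumptions:

```python
# The five class subdirectories in the flowers dataset, in sorted order.
CLASS_NAMES = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]

def get_dataloader_with_mosaic(remote_path, local_path, batch_size):
    import io
    import torch
    from PIL import Image
    from streaming import StreamingDataLoader, StreamingDataset
    from torchvision import transforms

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])

    def collate_fn(samples):
        # Decode JPEG bytes and map string labels to integer class ids.
        images = torch.stack([
            transform(Image.open(io.BytesIO(s["content"])).convert("RGB"))
            for s in samples
        ])
        labels = torch.tensor([CLASS_NAMES.index(s["label"]) for s in samples])
        return images, labels

    # The two key lines: stream shards from the Volume into a PyTorch loader.
    dataset = StreamingDataset(remote=remote_path, local=local_path, shuffle=True)
    return StreamingDataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)
```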
4. Single Node Training Procedure
The following functions load the `mobilenet_v2` model and train it for `NUM_EPOCHS` epochs.
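The single-node loop can be sketched as below, building on the `get_dataloader_with_mosaic` helper. The function name, learning rate, and optimizer are assumptions:

```python
# Sketch: train mobilenet_v2 on a single node for a given number of epochs.
def train_one_node(remote_path, local_path, num_epochs, batch_size, lr=1e-3):
    import torch
    from torchvision import models

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = models.mobilenet_v2(num_classes=5).to(device)  # 5 flower classes
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = torch.nn.CrossEntropyLoss()

    loader = get_dataloader_with_mosaic(remote_path, local_path, batch_size)
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(images), labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"epoch {epoch}: total loss = {total_loss:.4f}")
    return model
```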
5. Multi Node Distributed Training Procedure
Creating the distributed training function involves a few minor changes. The following code is an example of DDP (Distributed Data Parallel) and uses `TorchDistributor` to start multi-node, multi-GPU training. View the TorchDistributor docs for more details.
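The distributed variant can be sketched as below. The body mirrors the single-node loop, with the model wrapped in `DistributedDataParallel`; the function name and hyperparameters are assumptions, and `StreamingDataset` shards samples across ranks automatically, so the same dataloader helper works unchanged:

```python
# Sketch: DDP training function to be launched via TorchDistributor.
def train_distributed(remote_path, local_path, num_epochs, batch_size, lr=1e-3):
    import os
    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torchvision import models

    # TorchDistributor sets the environment variables init_process_group reads.
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = models.mobilenet_v2(num_classes=5).cuda(local_rank)
    model = DDP(model, device_ids=[local_rank])
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = torch.nn.CrossEntropyLoss()

    loader = get_dataloader_with_mosaic(remote_path, local_path, batch_size)
    model.train()
    for epoch in range(num_epochs):
        for images, labels in loader:
            images, labels = images.cuda(local_rank), labels.cuda(local_rank)
            optimizer.zero_grad()
            loss = loss_fn(model(images), labels)
            loss.backward()
            optimizer.step()

    dist.destroy_process_group()
    return model.module

# Launch on the cluster: 2 processes to match the two GPU workers above.
# from pyspark.ml.torch.distributor import TorchDistributor
# TorchDistributor(num_processes=2, local_mode=False, use_gpu=True).run(
#     train_distributed, remote_path, local_path, num_epochs, batch_size)
```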