Simplify data conversion from Spark to PyTorch

This notebook demonstrates the following workflow on Databricks: load data using Spark, convert the Spark DataFrame to a PyTorch DataLoader using the petastorm spark_dataset_converter, and feed the data into a distributed PyTorch model for training. The example in this notebook is based on the transfer learning tutorial from PyTorch. It applies the pre-trained MobileNetV2 model to the flowers dataset.

Requirements

1. Load data using Spark
This example uses the flowers dataset from the TensorFlow team, which contains flower photos stored under five subdirectories, one per class. The dataset is available under Databricks Datasets at dbfs:/databricks-datasets/flowers.
The example loads the flowers table, which contains the preprocessed flowers dataset, using the binary file data source. To reduce running time, this notebook uses a small subset of the flowers dataset: about 90 training images and 10 validation images. When you run this notebook, you can increase the number of images to improve model accuracy.
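The following is a minimal sketch of this step, assuming the raw JPEG files live under dbfs:/databricks-datasets/flowers and that the resulting DataFrames are wrapped in petastorm converters for the later steps. The cache directory, the 100-image subset, and the label_index column are illustrative choices, not fixed by the notebook.

```python
from petastorm.spark import SparkDatasetConverter, make_spark_converter
from pyspark.sql.functions import col, regexp_extract, udf

# `spark` is the session Databricks provides in every notebook.
# Read the raw JPEG files with the binary file data source; each row holds
# the file path plus the raw image bytes in the `content` column.
df = (
    spark.read.format("binaryFile")
    .option("recursiveFileLookup", "true")
    .option("pathGlobFilter", "*.jpg")
    .load("dbfs:/databricks-datasets/flowers")  # assumed location of the photos
)

# The class name is the subdirectory a photo sits in; map it to an integer
# index (hypothetical `label_index` column) for the PyTorch loss function.
df = df.withColumn("label", regexp_extract(col("path"), r"/(\w+)/[^/]+\.jpg$", 1))
classes = sorted(r.label for r in df.select("label").distinct().collect())
to_index = udf(lambda label: classes.index(label), "long")
df = df.select(col("content"), to_index(col("label")).alias("label_index"))

# Keep only ~100 images so the notebook runs quickly; raise this for accuracy.
df_train, df_val = df.limit(100).randomSplit([0.9, 0.1], seed=12345)

# petastorm materializes each DataFrame as Parquet under this cache directory
# and its readers stream from those files.
spark.conf.set(SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF,
               "file:///dbfs/tmp/petastorm/cache")
converter_train = make_spark_converter(df_train)
converter_val = make_spark_converter(df_val)
```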
Preprocess images
Before feeding the dataset into the model, you need to decode the raw image bytes and apply standard ImageNet transforms. Databricks recommends not doing this transformation on the Spark DataFrame, since that substantially increases the size of the intermediate files and might decrease performance. Instead, do this transformation in a TransformSpec function in petastorm.
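Below is a sketch of such a TransformSpec, assuming the content and label_index columns from the loading step above. It decodes each image's raw bytes, applies the standard ImageNet resize, crop, and normalization, and declares the resulting float32 features tensor in the output schema.

```python
import io

import numpy as np
from petastorm import TransformSpec
from PIL import Image
from torchvision import transforms

def preprocess(content):
    # Decode the raw JPEG bytes and apply standard ImageNet transforms.
    image = Image.open(io.BytesIO(content)).convert("RGB")
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    return transform(image)

def transform_row(pd_batch):
    # petastorm passes a pandas DataFrame of rows to transform.
    pd_batch["features"] = pd_batch["content"].map(
        lambda content: preprocess(content).numpy())
    return pd_batch.drop(labels=["content"], axis=1)

# Declare the schema change: `content` (bytes) is replaced by a float32
# `features` tensor of shape (3, 224, 224).
transform_spec_fn = TransformSpec(
    transform_row,
    edit_fields=[("features", np.float32, (3, 224, 224), False)],
    selected_fields=["features", "label_index"],
)
```

Because the decoding happens inside the petastorm reader rather than in Spark, the intermediate Parquet files keep the compact JPEG bytes instead of the much larger decoded tensors.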
5. Feed the data into a distributed PyTorch model for training.
Use HorovodRunner for distributed training.
The example uses the default value of the num_epochs parameter (num_epochs=None) to generate an infinite stream of batches, which avoids handling the last incomplete batch. This is particularly useful in distributed training, where the number of data records seen on every worker must be identical per step. Because the lengths of the data shards may differ, setting num_epochs to any specific number would fail to meet that guarantee.
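As a sketch of how these pieces fit together, the following HorovodRunner training function shards the converter across workers and takes a fixed number of identically sized steps per epoch. It assumes the converter_train and transform_spec_fn objects defined above; the batch size, learning rate, epoch count, and np=2 worker count are illustrative, not prescribed by the notebook.

```python
import horovod.torch as hvd
import torch
import torch.nn as nn
import torch.optim as optim
from sparkdl import HorovodRunner
from torchvision import models

BATCH_SIZE = 32  # illustrative
EPOCHS = 5       # illustrative

def train_and_evaluate_hvd():
    hvd.init()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Transfer learning: freeze the MobileNetV2 backbone, retrain the head.
    model = models.mobilenet_v2(pretrained=True)
    for param in model.parameters():
        param.requires_grad = False
    model.classifier[1] = nn.Linear(model.last_channel, 5)  # 5 flower classes
    model.to(device)

    # Scale the learning rate with the worker count, let Horovod average
    # gradients, and broadcast initial state so all workers start identically.
    optimizer = optim.SGD(model.classifier[1].parameters(),
                          lr=0.01 * hvd.size(), momentum=0.9)
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters())
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)
    loss_fn = nn.CrossEntropyLoss()

    # num_epochs is left at its default (None): the loader yields batches
    # forever, so every worker takes exactly steps_per_epoch full-size steps
    # per epoch regardless of how evenly the shards divide.
    with converter_train.make_torch_dataloader(
            transform_spec=transform_spec_fn,
            cur_shard=hvd.rank(), shard_count=hvd.size(),
            batch_size=BATCH_SIZE) as loader:
        steps_per_epoch = len(converter_train) // (BATCH_SIZE * hvd.size())
        data_iter = iter(loader)
        model.train()
        for epoch in range(EPOCHS):
            for step in range(steps_per_epoch):
                batch = next(data_iter)
                features = batch["features"].to(device)
                labels = batch["label_index"].to(device)
                optimizer.zero_grad()
                loss = loss_fn(model(features), labels)
                loss.backward()
                optimizer.step()

# np=2 launches one training process on each of two workers.
hr = HorovodRunner(np=2)
hr.run(train_and_evaluate_hvd)
```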