Distributed model inference using PyTorch
This notebook demonstrates distributed model inference with PyTorch, using the ResNet-50 model from torchvision.models and image files as input data.
This guide consists of the following sections:
- Prepare trained model for inference.
- Load the data from Databricks Datasets into Spark DataFrames.
- Run model inference via Pandas UDF.
Note:
- To run the notebook on a CPU-enabled Apache Spark cluster, set the variable `cuda = False`.
- To run the notebook on a GPU-enabled Apache Spark cluster, set the variable `cuda = True`.
Prepare trained model for inference
Load ResNet50 on driver node and broadcast its state.
Load the data from Databricks Datasets into Spark DataFrames
This notebook uses the flowers dataset from the TensorFlow team as its example dataset; it contains flower photos stored under five sub-directories, one per class.
Create a DataFrame of image paths.
Run model inference via Pandas UDF
Create a custom PyTorch dataset class.
Define the function for model inference.
Run the model inference and save the result to a Parquet file.
Load and check the prediction results.
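A minimal sanity check for the final step, assuming the Parquet path used when the results were written:

```python
def check_predictions(spark, output_path):
    """Reload the Parquet output and sanity-check row count and schema."""
    result_df = spark.read.parquet(output_path)
    print("rows:", result_df.count())          # expect one row per input image
    result_df.show(5, truncate=50)             # eyeball a few path/prediction pairs
    return result_df


# Hypothetical usage:
# check_predictions(spark, "/tmp/flowers_predictions")
```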