pytorch-images(Python)


Distributed model inference using PyTorch

This notebook demonstrates how to run distributed model inference with PyTorch, using the ResNet-50 model from torchvision.models and image files as input data.

This guide consists of the following sections:

  • Prepare trained model for inference.
  • Load the data from databricks-datasets into Spark DataFrames.
  • Run model inference via Pandas UDF.

Note:

  • To run the notebook on CPU-only Apache Spark clusters, set the variable cuda = False.
  • To run the notebook on GPU-enabled Apache Spark clusters, set the variable cuda = True.
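A minimal sketch of how the `cuda` flag might be turned into a `torch.device`. The fallback to CPU when CUDA is requested but no GPU is visible is an assumption of this sketch, not necessarily what the original notebook does.

```python
import torch

# Set to True on GPU-enabled clusters and False on CPU-only clusters.
cuda = False

# Sketch assumption: fall back to CPU if CUDA is requested but no GPU is visible.
use_cuda = cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(f"Running inference on: {device}")
```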

Prepare trained model for inference

Load ResNet-50 on the driver node and broadcast its state (the model weights) so that the worker nodes can rebuild the model.

Load the data from databricks-datasets into Spark DataFrames

This notebook uses the flowers dataset from the TensorFlow team as its example dataset. The dataset contains flower photos stored under five sub-directories, one per class.
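Because each class lives in its own sub-directory, a quick sanity check of the layout is to count images per sub-directory. This is a minimal sketch; the function name `count_images_per_class` is hypothetical, and the path in the usage line assumes the dataset is exposed through the `/dbfs` FUSE mount.

```python
import os


def count_images_per_class(root: str) -> dict:
    """Map each immediate sub-directory (one per class) to its number of .jpg files."""
    counts = {}
    with os.scandir(root) as entries:
        for entry in sorted(entries, key=lambda e: e.name):
            if entry.is_dir():  # loose files such as a license file are skipped
                counts[entry.name] = sum(
                    1 for name in os.listdir(entry.path) if name.lower().endswith(".jpg")
                )
    return counts
```

For example, `count_images_per_class("/dbfs/databricks-datasets/flower_photos")` should report five entries, one per class.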

Create a DataFrame of image paths.

Run model inference via Pandas UDF

Create a custom PyTorch dataset class.

Define the function for model inference.

Run the model inference and save the result to a Parquet file.

Load and check the prediction results.