Deep learning model inference workflow

Databricks recommends the following workflow for deep learning model inference. For example notebooks that use TensorFlow and PyTorch, see Deep learning model inference examples.

  1. Load the data into Spark DataFrames. Depending on the data type, Databricks recommends the following ways to load data:

    • Image files (JPG, PNG): Load the image paths into a Spark DataFrame. Image loading and input data preprocessing occur in a pandas UDF.

    files_df = spark.createDataFrame(map(lambda path: (path,), file_paths), ["path"])
    
    • TFRecords files: Load the data using the spark-tensorflow-connector.

    df = spark.read.format("tfrecords").load(image_path)
    
    • Data sources such as Parquet, CSV, JSON, and JDBC: Load the data using Spark data sources.

  2. Perform model inference using pandas UDFs. pandas UDFs use Apache Arrow to transfer data and pandas to work with the data. Model inference with pandas UDFs consists of the following broad steps.

    1. Load the trained model: For efficiency, Databricks recommends broadcasting the model's weights from the driver, then loading the model graph and restoring the weights from the broadcast variable inside the pandas UDF.

    2. Load and preprocess input data: To load data in batches, Databricks recommends using the tf.data API for TensorFlow and the DataLoader class for PyTorch. Both also support prefetching and multi-threaded loading to hide I/O-bound latency.

    3. Run model prediction: Run model inference on the data batch.

    4. Send predictions back to Spark DataFrames: Collect the prediction results and return them as a pd.Series.
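The batched-loading step above can be illustrated with a framework-free sketch. In a real job you would use tf.data.Dataset with .batch() and .prefetch() for TensorFlow, or torch.utils.data.DataLoader with batch_size and num_workers for PyTorch; the generator below only demonstrates the batching behavior those APIs provide.

```python
# Minimal stand-in for batched data loading. Real pipelines should use
# tf.data (TensorFlow) or DataLoader (PyTorch), which add prefetching and
# multi-threaded loading on top of this basic slicing.
def iter_batches(items, batch_size):
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

# Example: five inputs in batches of two.
batches = list(iter_batches([1, 2, 3, 4, 5], batch_size=2))
# batches == [[1, 2], [3, 4], [5]]
```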
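The remaining steps (broadcast the weights, rebuild the model inside the UDF, predict, and return a pd.Series) can be sketched without a cluster. FakeBroadcast and the linear stub model below are hypothetical stand-ins for pyspark's sc.broadcast and a real TensorFlow/PyTorch network, so the control flow is runnable locally; in Spark, predict_udf would be registered with pandas_udf and applied to a DataFrame column.

```python
import pandas as pd

class FakeBroadcast:
    """Hypothetical stand-in for pyspark.Broadcast: holds a read-only value."""
    def __init__(self, value):
        self.value = value

# Driver side: extract the trained weights once and broadcast them, so each
# executor reads them from local cache instead of reloading the full model.
trained_weights = (0.5, 1.5)                 # hypothetical (w, b) of a linear model
bc_weights = FakeBroadcast(trained_weights)  # sc.broadcast(...) in a real job

def build_model(weights):
    """Rebuild the model graph; a real UDF would construct the network here."""
    w, b = weights
    return lambda x: w * x + b

def predict_udf(batch: pd.Series) -> pd.Series:
    model = build_model(bc_weights.value)    # load weights from the broadcast
    preds = batch.map(model)                 # run model prediction on the batch
    return pd.Series(preds)                  # hand results back as pd.Series

# In Spark (requires a cluster, shown for orientation only):
# from pyspark.sql.functions import pandas_udf, col
# df.withColumn("pred", pandas_udf(predict_udf, "double")(col("feature")))
```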

Deep learning model inference examples

The examples in this section follow the recommended deep learning inference workflow. They illustrate how to perform model inference using a pre-trained deep residual network (ResNet) model.