
Perform batch inference using a Spark DataFrame

This article describes how to perform batch inference on a Spark DataFrame using a registered model in Databricks. The workflow applies to various machine learning and deep learning models, including TensorFlow, PyTorch, and scikit-learn. It includes best practices for data loading, model inference, and performance tuning.

For deep learning model inference, Databricks recommends the following workflow. For example notebooks that use TensorFlow and PyTorch, see Batch inference examples.

Model inference workflow

Databricks recommends the following workflow for performing batch inference using Spark DataFrames.

Step 1: Environment setup

Ensure that your cluster runs a Databricks Runtime ML version compatible with the training environment. The model, logged with MLflow, contains a requirements file that you can install to ensure that the training and inference environments match.

Python
import os

# local_path is the local directory containing the downloaded MLflow model artifacts.
requirements_path = os.path.join(local_path, "requirements.txt")
if not os.path.exists(requirements_path):
    # Create an empty requirements file if the model did not log one,
    # so that the %pip install below does not fail.
    dbutils.fs.put("file:" + requirements_path, "", True)

%pip install -r $requirements_path
%restart_python

Step 2: Load data into Spark DataFrames

Depending on the data type, use the appropriate method to load data into a Spark DataFrame:

  • Table from Unity Catalog (Recommended): table = spark.table(input_table_name)
  • Image files (JPG, PNG): files_df = spark.createDataFrame(map(lambda path: (path,), file_paths), ["path"])
  • TFRecords: df = spark.read.format("tfrecords").load(image_path)
  • Other formats (Parquet, CSV, JSON, JDBC): load using Spark data sources.
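
For example, the following sketch builds a single-column DataFrame of image paths; the directory is a placeholder, and you should replace it with your own storage location:

Python
import glob

# Placeholder image directory; replace with the location of your images.
file_paths = glob.glob("/Volumes/main/default/images/*.jpg")

# One row per image, in a single column named "path".
files_df = spark.createDataFrame(map(lambda path: (path,), file_paths), ["path"])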

Step 3: Load model from model registry

This example loads a model from the Databricks Model Registry and wraps it as a Spark UDF for inference.

Python
predict_udf = mlflow.pyfunc.spark_udf(spark, model_uri)
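
The model_uri identifies the registered model and version. A minimal sketch, assuming a model registered in Unity Catalog (the model name and version are placeholders):

Python
import mlflow

# Placeholder registered model name and version; replace with your own.
model_uri = "models:/main.default.my_model/1"

# Wrap the model as a Spark UDF that can be applied to DataFrame columns.
predict_udf = mlflow.pyfunc.spark_udf(spark, model_uri)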

Step 4: Perform model inference using Pandas UDFs

Pandas UDFs leverage Apache Arrow for efficient data transfer and pandas for processing. The typical steps for inference with pandas UDFs are:

  1. Load the trained model: Use MLflow to create a Spark UDF for inference.
  2. Preprocess input data: Ensure the input schema matches the model requirements.
  3. Run model prediction: Apply the model’s UDF to the DataFrame.
Python
df_result = df_spark.withColumn("prediction", predict_udf(*df_spark.columns))
  4. (Recommended) Save predictions to Unity Catalog.

The following example saves predictions to Unity Catalog.

Python
df_result.write.mode("overwrite").saveAsTable(output_table)
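
Putting the steps together, a minimal end-to-end sketch looks like the following; the table and model names are placeholders:

Python
import mlflow

# Placeholder Unity Catalog table and registered model names.
input_table_name = "main.default.features"
output_table = "main.default.predictions"
model_uri = "models:/main.default.my_model/1"

df_spark = spark.table(input_table_name)
predict_udf = mlflow.pyfunc.spark_udf(spark, model_uri)
df_result = df_spark.withColumn("prediction", predict_udf(*df_spark.columns))
df_result.write.mode("overwrite").saveAsTable(output_table)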

Performance tuning for model inference

This section provides tips for debugging and performance tuning for model inference on Databricks. For an overview, see Perform batch inference using a Spark DataFrame.

A model inference workflow typically has two main parts: the data input pipeline and model inference itself. The data input pipeline is I/O-heavy, while model inference is compute-heavy. Here are some approaches for determining the bottleneck of the workflow:

  • Reduce the model to a trivial model and measure the examples per second (see the sketch after this list). If the difference in end-to-end time between the full model and the trivial model is minimal, the data input pipeline is likely the bottleneck; otherwise, model inference is the bottleneck.
  • If you run model inference on a GPU, check the GPU utilization metrics. If GPU utilization is not consistently high, the data input pipeline may be the bottleneck.
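
As a minimal sketch for measuring examples per second, you can force full execution through Spark’s noop sink and time it; this assumes the df_spark and df_result DataFrames from the workflow above:

Python
import time

# Count input rows up front so the timed section covers only inference.
n = df_spark.count()

# The "noop" sink forces full execution of the plan without writing output.
start = time.time()
df_result.write.format("noop").mode("overwrite").save()
print(f"{n / (time.time() - start):.0f} examples/sec")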

Optimize data input pipeline

GPUs can significantly accelerate model inference. As GPUs and other accelerators become faster, it is important that the data input pipeline keep up with demand. The data input pipeline reads data into Spark DataFrames, transforms it, and loads it as the input for model inference. If data input is the bottleneck, here are some tips to increase I/O throughput:

  • Set the maximum records per batch. A larger batch reduces the overhead of invoking the UDF, as long as the records fit in memory. To set the batch size, set the following configuration:

    Python
    spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "5000")
  • Load the data in batches and prefetch it when preprocessing the input data in the pandas UDF (see the sketch after this list).

For TensorFlow, Databricks recommends using the tf.data API. You can parse examples in parallel by setting num_parallel_calls in the map function, and call prefetch and batch for prefetching and batching.

    Python
    dataset.map(parse_example, num_parallel_calls=num_process).prefetch(prefetch_size).batch(batch_size)

    For PyTorch, Databricks recommends using the DataLoader class. You can set batch_size for batching and num_workers for parallel data loading.

    Python
    torch.utils.data.DataLoader(images, batch_size=batch_size, num_workers=num_process)
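
As a sketch of how batching fits into a pandas UDF, the following iterator-style UDF loads the model once per worker and streams batches through it; the load_model helper and the float return type are assumptions, not part of any specific API:

Python
from typing import Iterator

import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf("float")
def predict_batch(batches: Iterator[pd.Series]) -> Iterator[pd.Series]:
    # Hypothetical helper; load the trained model once per worker process.
    model = load_model()
    for batch in batches:
        # Each batch is sized by spark.sql.execution.arrow.maxRecordsPerBatch.
        yield pd.Series(model.predict(batch.to_numpy()))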

Batch inference examples

The examples in this section follow the recommended deep learning inference workflow. These examples illustrate how to perform model inference using a pre-trained deep residual network (ResNet) model.