For model inference, Databricks recommends the following workflow. For details about how to do model inference with TensorFlow, Keras, and PyTorch, see the model inference examples.
Load the data into Spark DataFrames. Depending on the data type, Databricks recommends the following ways to load data:
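- Image files (JPG, PNG): Load the image paths into a Spark DataFrame. Image loading and preprocessing of the input data then happen inside a pandas UDF.
  ```python
  # `spark` is the SparkSession predefined in Databricks notebooks;
  # `file_paths` is a Python list of image file paths, one row per path
  files_df = spark.createDataFrame(map(lambda path: (path,), file_paths), ["path"])
  ```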
- TFRecords: Load the data using the spark-tensorflow-connector.
  ```python
  # `image_path` points to the TFRecord file(s) to read
  df = spark.read.format("tfrecords").load(image_path)
  ```
- Parquet, CSV, JSON, JDBC, and other metadata sources: Load the data using the corresponding Spark data sources.
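The image-file approach only needs a list of paths on the driver, and the snippet above assumes `file_paths` already exists. The following is a minimal sketch of one way to build it, assuming the images sit under a single directory that is readable from both the driver and the workers. The helper names `list_image_paths` and `to_path_rows` and the extension list are hypothetical, not part of any Databricks API.

```python
from pathlib import Path
from typing import Iterable, List, Tuple

# Hypothetical default: file extensions treated as images (JPG, PNG).
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")


def list_image_paths(root: str, extensions: Iterable[str] = IMAGE_EXTENSIONS) -> List[str]:
    """Recursively collect image file paths under `root`.

    Extensions are matched case-insensitively, and the result is sorted
    so that the row order of the resulting DataFrame is deterministic.
    """
    wanted = {ext.lower() for ext in extensions}
    return sorted(
        str(p)
        for p in Path(root).rglob("*")
        if p.is_file() and p.suffix.lower() in wanted
    )


def to_path_rows(paths: Iterable[str]) -> List[Tuple[str]]:
    """Wrap each path in a 1-tuple, the row shape for a single "path" column."""
    return [(path,) for path in paths]
```

On Databricks, the rows could then be passed to Spark as `spark.createDataFrame(to_path_rows(list_image_paths(root)), ["path"])`, which produces the same single-column DataFrame as the `map(lambda path: (path,), ...)` form. Listing paths on the driver is cheap because only strings are collected; the image bytes are read later, in parallel, inside the pandas UDF.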
Perform model inference using pandas UDFs. Pandas UDFs use Apache Arrow to transfer data between the JVM and Python, and pandas to work with the data. Model inference with pandas UDFs broadly follows these steps:
- Load the trained model: For efficiency, Databricks recommends broadcasting the model weights from the driver, then, inside the pandas UDF, loading the model graph and getting the weights from the broadcast variable.
- Load and preprocess input data: To load data in batches, Databricks recommends the tf.data API for TensorFlow and the DataLoader class for PyTorch. Both also support prefetching and multi-threaded loading to hide I/O-bound latency.
- Run model prediction: Run model inference on each data batch.
- Send predictions back to Spark DataFrames: Collect the prediction results and return as