tensorflow-tensorrt (Python)


Model inference using TensorRT

This notebook demonstrates how to perform distributed model inference using TensorFlow and TensorRT with the ResNet-50 model.

Note:

To run the notebook, create a GPU-enabled cluster with Databricks Runtime 7.0 ML or above.
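
Before going further, you can confirm that TensorFlow actually sees the cluster's GPU. This check is a minimal sketch and is not part of the original notebook:

import tensorflow as tf

# List the GPUs visible to TensorFlow; a GPU-enabled cluster should report at least one.
gpus = tf.config.list_physical_devices("GPU")
print(f"GPUs available: {gpus}")
assert gpus, "No GPU found; TF-TRT conversion requires a GPU-enabled cluster."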

import os
import pandas as pd
import shutil
import uuid
  
from pyspark.sql.functions import col, pandas_udf, PandasUDFType
from pyspark.sql.types import ArrayType, FloatType
 
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.python.compiler.tensorrt import trt_convert as trt
from tensorflow.python.saved_model import signature_constants, tag_constants
from tensorflow.python.framework import convert_to_constants

Define the input and output directories.

uid = str(uuid.uuid1())
 
model_dir = f"/dbfs/ml/tmp/{uid}/model"
trt_model_dir = f"/dbfs/ml/tmp/{uid}/trt_model"
output_dbfs_dir = f"/ml/tmp/{uid}/predictions"
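
Note that model_dir and trt_model_dir are FUSE paths (under /dbfs) for use with local Python file APIs, while output_dbfs_dir is the matching DBFS path for Spark readers and writers. As a small illustration (assuming the standard Databricks FUSE mount; not part of the original notebook), the two views address the same storage:

# Spark writes to output_dbfs_dir; the same files appear locally under /dbfs.
local_output_dir = "/dbfs" + output_dbfs_dir
print(local_output_dir)  # /dbfs/ml/tmp/<uid>/predictions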

Prepare trained model and data for inference

Save the ResNet-50 Model

os.makedirs(model_dir)
model = ResNet50()
model.save(model_dir)

Optimize the model with TensorRT. For more details, check the TF-TRT User Guide.

conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS
conversion_params = conversion_params._replace(precision_mode='FP16')
converter = trt.TrtGraphConverterV2(
    input_saved_model_dir=model_dir,
    conversion_params=conversion_params,
)
converter.convert()
converter.save(output_saved_model_dir=trt_model_dir)
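
To verify the conversion, you can reload the converted SavedModel and inspect its serving signature. This is a quick sanity check, not part of the original notebook:

loaded = tf.saved_model.load(trt_model_dir, tags=[tag_constants.SERVING])
infer = loaded.signatures[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
print(infer.structured_input_signature)  # expect a (None, 224, 224, 3) float32 input
print(infer.structured_outputs)          # expect a (None, 1000) probability output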

Load the data into Spark DataFrames

df = spark.read.format("delta").load("/databricks-datasets/flowers/delta")
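
The flowers dataset stores each image's raw JPEG bytes in the content column and its file location in the path column. A quick look at the schema (with an optional way to shrink the input for a trial run; the limit value below is arbitrary):

df.printSchema()
# For a faster trial run, you can cap the number of rows, e.g.:
# df = df.limit(100)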

Run model inference via pandas UDF

Define the function to parse the input data.

def parse_example(image_data):
  image = tf.image.decode_jpeg(image_data, channels=3)
  image = tf.image.resize(image, [224, 224])
  return preprocess_input(image)
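
Because parse_example is ordinary eager TensorFlow code, you can spot-check it on a single row before wiring it into the UDF. This is an illustrative check, not part of the original notebook:

sample = bytes(df.select("content").head()[0])  # raw JPEG bytes from one row
image = parse_example(sample)
print(image.shape, image.dtype)  # expect (224, 224, 3) float32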

Define the function for model inference.

@pandas_udf(ArrayType(FloatType()), PandasUDFType.SCALAR_ITER)
def predict_batch_udf(image_batch_iter):
  batch_size = 64
  # Load the TensorRT-optimized model once per executor process.
  saved_model_loaded = tf.saved_model.load(
      trt_model_dir, tags=[tag_constants.SERVING])
  graph_func = saved_model_loaded.signatures[
      signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
  graph_func = convert_to_constants.convert_variables_to_constants_v2(graph_func)
  for image_batch in image_batch_iter:
    # Decode, resize, and preprocess the raw JPEG bytes, then batch for the GPU.
    dataset = tf.data.Dataset.from_tensor_slices(image_batch)
    dataset = dataset.map(parse_example).prefetch(512).batch(batch_size)
    prediction = []
    for batch_images in dataset:
      batch_preds = graph_func(batch_images)[0].numpy()
      prediction.extend(batch_preds.tolist())
    yield pd.Series(prediction)
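
Each pandas Series yielded by the iterator corresponds to one Arrow record batch from Spark. If the executors run into GPU memory pressure, you can shrink those batches with a standard Spark setting (a tuning suggestion, not part of the original notebook):

# Cap the number of rows per Arrow batch handed to the pandas UDF.
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "512")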

Run the model inference and save the results to a Parquet file.

predictions_df = df.select(col("path"), predict_batch_udf(col("content")).alias("prediction"))
predictions_df.write.mode("overwrite").parquet(output_dbfs_dir)

Load and check the prediction results.

result_df = spark.read.parquet(output_dbfs_dir)
display(result_df)
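
Each prediction is the 1000-way ImageNet probability vector produced by ResNet-50. To make it human-readable, you can map the top entries to class names with Keras's decode_predictions (an illustrative post-processing step, not part of the original notebook):

import numpy as np
from tensorflow.keras.applications.resnet50 import decode_predictions

first = result_df.select("prediction").head()[0]        # one probability vector
top3 = decode_predictions(np.array([first]), top=3)[0]  # [(class_id, name, score), ...]
print(top3)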

Clean up the directories.

shutil.rmtree("/"+model_dir, ignore_errors=True)
shutil.rmtree("/"+trt_model_dir, ignore_errors=True)
shutil.rmtree("/dbfs/"+output_dbfs_dir, ignore_errors=True)