import os
import pandas as pd
import shutil
import uuid
from pyspark.sql.functions import col, pandas_udf, PandasUDFType
from pyspark.sql.types import ArrayType, FloatType
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.python.compiler.tensorrt import trt_convert as trt
from tensorflow.python.saved_model import signature_constants, tag_constants
from tensorflow.python.framework import convert_to_constants
# Convert the ResNet-50 SavedModel into a TensorRT-optimized SavedModel.
# FP16 precision roughly halves memory traffic and enables Tensor Core
# execution on supported GPUs, at a small accuracy cost.
# NOTE(review): `model_dir` and `trt_model_dir` are defined elsewhere in
# the notebook — confirm they point at valid SavedModel paths.
params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(precision_mode='FP16')
converter = trt.TrtGraphConverterV2(
    input_saved_model_dir=model_dir,
    conversion_params=params,
)
converter.convert()
converter.save(output_saved_model_dir=trt_model_dir)
@pandas_udf(ArrayType(FloatType()), PandasUDFType.SCALAR_ITER)
def predict_batch_udf(image_batch_iter):
    """Scalar-iterator pandas UDF that runs batched inference with the
    TensorRT-optimized ResNet-50 SavedModel.

    The iterator form lets us pay the model-loading cost once per Spark
    task instead of once per batch.

    Args:
        image_batch_iter: iterator of pd.Series, each holding serialized
            image records consumed by ``parse_example`` (defined elsewhere
            in the notebook — confirm its expected input format).

    Yields:
        pd.Series of per-image prediction vectors (list[float]).
    """
    batch_size = 64
    # Load the TRT-converted model once; reused for every incoming batch.
    saved_model_loaded = tf.saved_model.load(
        trt_model_dir, tags=[tag_constants.SERVING])
    graph_func = saved_model_loaded.signatures[
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    # Freeze variables into constants so repeated invocations skip
    # variable resolution overhead.
    graph_func = convert_to_constants.convert_variables_to_constants_v2(graph_func)
    for image_batch in image_batch_iter:
        dataset = tf.data.Dataset.from_tensor_slices(image_batch)
        # Use batch_size instead of a second hard-coded 64 so the two
        # stay in sync if the batch size is tuned.
        dataset = dataset.map(parse_example).prefetch(512).batch(batch_size)
        predictions = []
        for batch_images in dataset:
            batch_preds = graph_func(batch_images)[0].numpy()
            # extend() is linear; the original `prediction + list(...)`
            # rebuilt the list each iteration (quadratic).
            predictions.extend(batch_preds)
        yield pd.Series(predictions)
Model inference using TensorRT
This notebook demonstrates how to perform distributed model inference using TensorFlow and TensorRT with the ResNet-50 model.
Note:
To run the notebook, create a GPU-enabled cluster with Databricks Runtime 7.0 ML or above.