import os
import pandas as pd
import shutil
import uuid
from pyspark.sql.functions import col, pandas_udf, PandasUDFType
from pyspark.sql.types import ArrayType, FloatType
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.python.compiler.tensorrt import trt_convert as trt
from tensorflow.python.saved_model import signature_constants, tag_constants
from tensorflow.python.framework import convert_to_constants
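The converter below reads a TensorFlow SavedModel from model_dir and writes the optimized model to trt_model_dir. If those paths are not defined earlier in the notebook, a minimal sketch of exporting the pre-trained Keras ResNet-50 as a SavedModel might look like this (the /dbfs/ml/tmp locations are an assumption, not from this section):

# Assumed paths: unique temporary DBFS directories for the original and the
# TensorRT-optimized SavedModel.
model_dir = "/dbfs/ml/tmp/model-{}".format(uuid.uuid4().hex)
trt_model_dir = "/dbfs/ml/tmp/trt-model-{}".format(uuid.uuid4().hex)

# Export the pre-trained ImageNet ResNet-50 in TensorFlow SavedModel format.
model = ResNet50(weights="imagenet")
model.save(model_dir)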
# Configure TF-TRT conversion to use FP16 precision.
conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS
conversion_params = conversion_params._replace(precision_mode='FP16')

# Convert the SavedModel in `model_dir` and write the TensorRT-optimized
# SavedModel to `trt_model_dir`.
converter = trt.TrtGraphConverterV2(
    input_saved_model_dir=model_dir,
    conversion_params=conversion_params,
)
converter.convert()
converter.save(output_saved_model_dir=trt_model_dir)
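The imports of tag_constants, signature_constants, and convert_to_constants above are the usual pieces for loading the converted model back and freezing its serving signature before running inference. A quick local check of the optimized model could look like the following sketch (the random batch is illustration only):

import numpy as np

# Load the TensorRT-optimized SavedModel and grab its serving signature.
saved_model_loaded = tf.saved_model.load(trt_model_dir, tags=[tag_constants.SERVING])
graph_func = saved_model_loaded.signatures[
    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

# Freeze the variables into constants so the function can be called positionally.
frozen_func = convert_to_constants.convert_variables_to_constants_v2(graph_func)

# Run a small test batch through the optimized model (random data, illustration only).
batch = preprocess_input(np.random.uniform(0, 255, (4, 224, 224, 3)).astype("float32"))
predictions = frozen_func(tf.constant(batch))[0].numpy()  # shape (4, 1000)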
Model inference using TensorRT
This notebook demonstrates how to do distributed model inference using TensorFlow and TensorRT with the ResNet-50 model.
Note:
To run the notebook, create a GPU-enabled cluster with Databricks Runtime 7.0 ML or above.
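Distributed inference on a Spark DataFrame is what the pandas_udf, PandasUDFType, ArrayType, and FloatType imports point to: a Scalar Iterator pandas UDF that loads the optimized model once per task and streams record batches through it. The sketch below reuses the load-and-freeze pattern shown earlier and makes several assumptions not taken from this section: a DataFrame df whose image column holds flattened 224x224x3 float arrays.

import numpy as np

@pandas_udf(ArrayType(FloatType()), PandasUDFType.SCALAR_ITER)
def predict_batch_udf(image_batch_iter):
    # Load and freeze the TensorRT-optimized serving signature once per task.
    loaded = tf.saved_model.load(trt_model_dir, tags=[tag_constants.SERVING])
    graph_func = convert_to_constants.convert_variables_to_constants_v2(
        loaded.signatures[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY])
    for image_batch in image_batch_iter:
        # Reassemble each row into a 224x224x3 image and preprocess it for ResNet-50.
        images = np.stack([np.asarray(x, dtype=np.float32).reshape(224, 224, 3)
                           for x in image_batch])
        preds = graph_func(tf.constant(preprocess_input(images)))[0].numpy()
        yield pd.Series(list(preds))

# Apply the UDF to the assumed DataFrame `df` with its `image` column of float arrays.
predictions_df = df.select(col("image"), predict_batch_udf(col("image")).alias("prediction"))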