tensorflow-tensorrt (Python)

Model inference using TensorRT

This notebook demonstrates how to run distributed model inference using TensorFlow and TensorRT with the ResNet-50 model.

Note:

To run the notebook, create a GPU-enabled cluster with Databricks Runtime 7.0 ML or above.

import os
import pandas as pd
import shutil
import uuid
  
from pyspark.sql.functions import col, pandas_udf, PandasUDFType
from pyspark.sql.types import ArrayType, FloatType
 
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.python.compiler.tensorrt import trt_convert as trt
from tensorflow.python.saved_model import signature_constants, tag_constants
from tensorflow.python.framework import convert_to_constants

Define the input and output directories.

uid = str(uuid.uuid1())
 
model_dir = f"/dbfs/ml/tmp/{uid}/model"
trt_model_dir = f"/dbfs/ml/tmp/{uid}/trt_model"
output_dbfs_dir = f"/ml/tmp/{uid}/predictions"
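
Note that model_dir and trt_model_dir use the /dbfs FUSE mount so that TensorFlow and the os module can treat DBFS as a local filesystem, while output_dbfs_dir is a DBFS-relative path meant for Spark. A minimal sketch of how the two forms map onto each other:

# The same DBFS location is addressable two ways on Databricks:
#   /dbfs/ml/tmp/<uid>/...  -> POSIX path via the FUSE mount (os, shutil, TensorFlow)
#   dbfs:/ml/tmp/<uid>/...  -> Spark path (DataFrame readers/writers)
spark_output_path = "dbfs:" + output_dbfs_dir  # usable with Spark, e.g. df.write.parquet(...)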

Prepare the trained model and data for inference

Save the ResNet-50 Model

os.makedirs(model_dir)
model = ResNet50()
model.save(model_dir)
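
As an optional sanity check (not part of the original notebook), you can reload the SavedModel and confirm the expected ResNet-50 input shape:

# Optional: reload the SavedModel and verify the 224x224 RGB input signature.
reloaded = tf.keras.models.load_model(model_dir)
assert reloaded.input_shape == (None, 224, 224, 3)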

Optimize the model with TensorRT. For more details, see the TF-TRT User Guide.

conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS
conversion_params = conversion_params._replace(precision_mode='FP16')
converter = trt.TrtGraphConverterV2(
    input_saved_model_dir=model_dir,
    conversion_params=conversion_params,
)
converter.convert()
converter.save(output_saved_model_dir=trt_model_dir)
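
After conversion, the optimized SavedModel can be loaded back as a frozen concrete function for inference. The following sketch uses the standard TF-TRT loading pattern (matching the tag_constants, signature_constants, and convert_to_constants imports above); the notebook's own inference code may differ:

import numpy as np

# Load the TF-TRT SavedModel and grab its serving signature.
saved_model_loaded = tf.saved_model.load(trt_model_dir, tags=[tag_constants.SERVING])
graph_func = saved_model_loaded.signatures[
    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
# Freeze variables into constants so the function is self-contained.
frozen_func = convert_to_constants.convert_variables_to_constants_v2(graph_func)

# Smoke test on one dummy 224x224 RGB image.
dummy = preprocess_input(np.random.uniform(0, 255, (1, 224, 224, 3)).astype("float32"))
preds = frozen_func(tf.constant(dummy))[0].numpy()  # shape (1, 1000) class probabilities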

Load the data into Spark DataFrames
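
A minimal sketch of this step, assuming the flowers sample dataset that ships with Databricks under /databricks-datasets/flowers/delta (the dataset and columns actually used later may differ):

# Read image paths and raw bytes into a Spark DataFrame (dataset path is an assumption).
df = (spark.read.format("delta")
      .load("/databricks-datasets/flowers/delta")
      .select(col("path"), col("content")))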