From 5302cdf6f5c0bd839b1d4693721813b6e0d3fc9f Mon Sep 17 00:00:00 2001 From: John Welsh <35354912+jaybdub-nv@users.noreply.github.com> Date: Fri, 30 Nov 2018 15:13:57 -0800 Subject: [PATCH 01/56] added object detection example (#2) * added object detection example * PR fixes * localized object detection third party * license headers * license to setup.py --- .gitmodules | 6 + README.md | 11 +- setup.py | 27 + tftrt/__init__.py | 16 + tftrt/examples/__init__.py | 16 + tftrt/examples/object_detection/README.md | 138 ++++ tftrt/examples/object_detection/__init__.py | 19 + .../examples/object_detection/graph_utils.py | 108 +++ .../object_detection/install_dependencies.sh | 71 ++ .../object_detection/object_detection.py | 632 ++++++++++++++++++ tftrt/examples/object_detection/test.py | 101 +++ .../object_detection/third_party/cocoapi | 1 + .../object_detection/third_party/models | 1 + 13 files changed, 1146 insertions(+), 1 deletion(-) create mode 100644 .gitmodules create mode 100644 setup.py create mode 100644 tftrt/__init__.py create mode 100644 tftrt/examples/__init__.py create mode 100644 tftrt/examples/object_detection/README.md create mode 100644 tftrt/examples/object_detection/__init__.py create mode 100644 tftrt/examples/object_detection/graph_utils.py create mode 100755 tftrt/examples/object_detection/install_dependencies.sh create mode 100644 tftrt/examples/object_detection/object_detection.py create mode 100644 tftrt/examples/object_detection/test.py create mode 160000 tftrt/examples/object_detection/third_party/cocoapi create mode 160000 tftrt/examples/object_detection/third_party/models diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..2688d24bc --- /dev/null +++ b/.gitmodules @@ -0,0 +1,6 @@ +[submodule "third_party/models"] + path = tftrt/examples/object_detection/third_party/models + url = https://github.com/tensorflow/models +[submodule "third_party/cocoapi"] + path = tftrt/examples/object_detection/third_party/cocoapi + url = https://github.com/cocodataset/cocoapi diff --git a/README.md b/README.md index 57e3269b2..90394bad2 100644 --- a/README.md +++ b/README.md @@ -1 +1,10 @@ -Coming soon: Examples using [NVIDIA TensorRT](https://developer.nvidia.com/tensorrt) in TensorFlow. +# TensorRT Integration in TensorFlow + +This repository demonstrates TensorRT integration in TensorFlow. Currently +it contains examples for accelerated image classification and object +detection. + + +## Examples + +* [Object Detection](tftrt/examples/object_detection) diff --git a/setup.py b/setup.py new file mode 100644 index 000000000..a727bc46e --- /dev/null +++ b/setup.py @@ -0,0 +1,27 @@ +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# =============================================================================

+from setuptools import find_packages, setup
+
+setup(
+    name='tftrt',
+    version='0.0',
+    description='NVIDIA TensorRT integration in TensorFlow',
+    author='NVIDIA',
+    packages=find_packages(),
+    install_requires=['tqdm']
+)
diff --git a/tftrt/__init__.py b/tftrt/__init__.py
new file mode 100644
index 000000000..04285a017
--- /dev/null
+++ b/tftrt/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
diff --git a/tftrt/examples/__init__.py b/tftrt/examples/__init__.py
new file mode 100644
index 000000000..04285a017
--- /dev/null
+++ b/tftrt/examples/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
diff --git a/tftrt/examples/object_detection/README.md b/tftrt/examples/object_detection/README.md
new file mode 100644
index 000000000..89dead0d2
--- /dev/null
+++ b/tftrt/examples/object_detection/README.md
@@ -0,0 +1,138 @@
+# TensorRT / TensorFlow Object Detection
+
+This package demonstrates object detection using TensorRT integration in TensorFlow.
+It includes utilities for benchmarking accuracy and performance, along with
+utilities for model construction and optimization.
+
+* [Setup](#setup)
+* [Download](#od_download)
+* [Optimize](#od_optimize)
+* [Benchmark](#od_benchmark)
+* [Test](#od_test)
+
+
+## Setup
+
+1. Install the object detection dependencies (from tftrt/examples/object_detection)
+
+```bash
+git submodule update --init
+./install_dependencies.sh
+```
+
+2. Ensure you've installed the tftrt package (from the root folder of the repository)
+
+```bash
+python setup.py install --user
+```
+
+
+## Object Detection
+
+
+### Download
+```python
+from tftrt.examples.object_detection import download_model
+
+config_path, checkpoint_path = download_model('ssd_mobilenet_v1_coco', output_dir='models')
+# help(download_model) for more
+```
+
+
+### Optimize
+
+```python
+from tftrt.examples.object_detection import optimize_model
+
+frozen_graph = optimize_model(
+    config_path=config_path,
+    checkpoint_path=checkpoint_path,
+    use_trt=True,
+    precision_mode='FP16'
+)
+# help(optimize_model) for other parameters
+```
+
+
+### Benchmark
+
+First, we download the validation dataset:
+
+```python
+from tftrt.examples.object_detection import download_dataset
+
+images_dir, annotation_path = download_dataset('val2014', output_dir='dataset')
+# help(download_dataset) for more
+```
+
+Next, we run inference over the dataset to benchmark the optimized model:
+
+```python
+from tftrt.examples.object_detection import benchmark_model
+
+statistics = benchmark_model(
+    frozen_graph=frozen_graph,
+    images_dir=images_dir,
+    annotation_path=annotation_path
+)
+# help(benchmark_model) for more parameters
+```
+
+
+### Test
+To simplify evaluating different models with different optimization parameters,
+we include a ``test`` function that ingests a JSON file containing test arguments
+and combines the model download, optimization, and benchmark steps. Below is an
+example JSON file; call it ``my_test.json``:
+
+```json
+{
+    "source_model": {
+        "model_name": "ssd_inception_v2_coco",
+        "output_dir": "models"
+    },
+    "optimization_config": {
+        "use_trt": true,
+        "precision_mode": "FP16",
+        "force_nms_cpu": true,
+        "replace_relu6": true,
+        "remove_assert": true,
+        "override_nms_score_threshold": 0.3,
+        "max_batch_size": 1
+    },
+    "benchmark_config": {
+        "images_dir": "coco/val2017",
+        "annotation_path": "coco/annotations/instances_val2017.json",
+        "batch_size": 1,
+        "image_shape": [600, 600],
+        "num_images": 4096,
+        "output_path": "stats/ssd_inception_v2_coco_trt_fp16.json"
+    },
+    "assertions": [
+        "statistics['map'] > (0.268 - 0.005)"
+    ]
+}
+```
+
+We execute the test using the ``test`` Python function:
+
+```python
+from tftrt.examples.object_detection import test
+
+test('my_test.json')
+# help(test) for more details
+```
+
+Alternatively, we can directly call the object_detection.test module, which
+is configured to execute this function by default.
+
+```shell
+python -m tftrt.examples.object_detection.test my_test.json
+```
+
+For the example configuration shown above, the following steps will be performed:
+
+1. Downloads ssd_inception_v2_coco
+2. Optimizes with TensorRT and FP16 precision
+3. Benchmarks against the MSCOCO 2017 validation dataset
+4. Asserts that the mAP is greater than some reference value
diff --git a/tftrt/examples/object_detection/__init__.py b/tftrt/examples/object_detection/__init__.py
new file mode 100644
index 000000000..d7675e24e
--- /dev/null
+++ b/tftrt/examples/object_detection/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +from .object_detection import download_model, download_dataset, optimize_model, benchmark_model +from .test import test diff --git a/tftrt/examples/object_detection/graph_utils.py b/tftrt/examples/object_detection/graph_utils.py new file mode 100644 index 000000000..775127abb --- /dev/null +++ b/tftrt/examples/object_detection/graph_utils.py @@ -0,0 +1,108 @@ +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +import tensorflow as tf + + +def make_const6(const6_name='const6'): + graph = tf.Graph() + with graph.as_default(): + tf_6 = tf.constant(dtype=tf.float32, value=6.0, name=const6_name) + return graph.as_graph_def() + + +def make_relu6(output_name, input_name, const6_name='const6'): + graph = tf.Graph() + with graph.as_default(): + tf_x = tf.placeholder(tf.float32, [10, 10], name=input_name) + tf_6 = tf.constant(dtype=tf.float32, value=6.0, name=const6_name) + with tf.name_scope(output_name): + tf_y1 = tf.nn.relu(tf_x, name='relu1') + tf_y2 = tf.nn.relu(tf.subtract(tf_x, tf_6, name='sub1'), name='relu2') + + #tf_y = tf.nn.relu(tf.subtract(tf_6, tf.nn.relu(tf_x, name='relu1'), name='sub'), name='relu2') + #tf_y = tf.subtract(tf_6, tf_y, name=output_name) + tf_y = tf.subtract(tf_y1, tf_y2, name=output_name) + + graph_def = graph.as_graph_def() + graph_def.node[-1].name = output_name + + # remove unused nodes + for node in graph_def.node: + if node.name == input_name: + graph_def.node.remove(node) + for node in graph_def.node: + if node.name == const6_name: + graph_def.node.remove(node) + for node in graph_def.node: + if node.op == '_Neg': + node.op = 'Neg' + + return graph_def + + +def convert_relu6(graph_def, const6_name='const6'): + # add constant 6 + has_const6 = False + for node in graph_def.node: + if node.name == const6_name: + has_const6 = True + if not has_const6: + const6_graph_def = make_const6(const6_name=const6_name) + graph_def.node.extend(const6_graph_def.node) + + for node in graph_def.node: + if node.op == 'Relu6': + input_name = node.input[0] + output_name = node.name + relu6_graph_def = make_relu6(output_name, input_name, const6_name=const6_name) + graph_def.node.remove(node) + graph_def.node.extend(relu6_graph_def.node) + + return graph_def + + +def remove_node(graph_def, node): + for n in graph_def.node: + if node.name in n.input: + n.input.remove(node.name) + ctrl_name = '^' 
+ node.name + if ctrl_name in n.input: + n.input.remove(ctrl_name) + graph_def.node.remove(node) + + +def remove_op(graph_def, op_name): + matches = [node for node in graph_def.node if node.op == op_name] + for match in matches: + remove_node(graph_def, match) + + +def force_nms_cpu(frozen_graph): + for node in frozen_graph.node: + if 'NonMaxSuppression' in node.name: + node.device = '/device:CPU:0' + return frozen_graph + + +def replace_relu6(frozen_graph): + return convert_relu6(frozen_graph) + + +def remove_assert(frozen_graph): + remove_op(frozen_graph, 'Assert') + return frozen_graph diff --git a/tftrt/examples/object_detection/install_dependencies.sh b/tftrt/examples/object_detection/install_dependencies.sh new file mode 100755 index 000000000..0f55d90db --- /dev/null +++ b/tftrt/examples/object_detection/install_dependencies.sh @@ -0,0 +1,71 @@ +#!/bin/bash +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +TF_MODELS_DIR=third_party/models +COCO_API_DIR=third_party/cocoapi + +python -V 2>&1 | grep "Python 3" || \ + ( export DEBIAN_FRONTEND=noninteractive && \ + apt-get update && \ + apt-get install -y --no-install-recommends python-tk ) + +RESEARCH_DIR=$TF_MODELS_DIR/research +SLIM_DIR=$RESEARCH_DIR/slim +PYCOCO_DIR=$COCO_API_DIR/PythonAPI + +pushd $RESEARCH_DIR + +# GET PROTOC 3.5 + +BASE_URL="https://github.com/google/protobuf/releases/download/v3.5.1/" +PROTOC_DIR=protoc +PROTOC_EXE=$PROTOC_DIR/bin/protoc + +mkdir -p $PROTOC_DIR +pushd $PROTOC_DIR +ARCH=$(uname -m) +if [ "$ARCH" == "aarch64" ] ; then + filename="protoc-3.5.1-linux-aarch_64.zip" +elif [ "$ARCH" == "x86_64" ] ; then + filename="protoc-3.5.1-linux-x86_64.zip" +else + echo ERROR: $ARCH not supported. + exit 1; +fi +wget --no-check-certificate ${BASE_URL}${filename} +unzip ${filename} +popd + +# BUILD PROTOBUF FILES +$PROTOC_EXE object_detection/protos/*.proto --python_out=. + +# INSTALL OBJECT DETECTION + +pip install -e . + +popd + +pushd $SLIM_DIR +pip install -e . +popd + +# INSTALL PYCOCOTOOLS + +pushd $PYCOCO_DIR +pip install -e . +popd diff --git a/tftrt/examples/object_detection/object_detection.py b/tftrt/examples/object_detection/object_detection.py new file mode 100644 index 000000000..38e86a300 --- /dev/null +++ b/tftrt/examples/object_detection/object_detection.py @@ -0,0 +1,632 @@ +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + + +from __future__ import absolute_import + +import tensorflow as tf +import tensorflow.contrib.tensorrt as trt +import tqdm +import pdb + +from collections import namedtuple +from PIL import Image +import numpy as np +import time +import json +import subprocess +import os +import glob + +from .graph_utils import force_nms_cpu as f_force_nms_cpu +from .graph_utils import replace_relu6 as f_replace_relu6 +from .graph_utils import remove_assert as f_remove_assert + +from google.protobuf import text_format +from object_detection.protos import pipeline_pb2, image_resizer_pb2 +from object_detection import exporter + +Model = namedtuple('Model', ['name', 'url', 'extract_dir']) + +INPUT_NAME = 'image_tensor' +BOXES_NAME = 'detection_boxes' +CLASSES_NAME = 'detection_classes' +SCORES_NAME = 'detection_scores' +MASKS_NAME = 'detection_masks' +NUM_DETECTIONS_NAME = 'num_detections' +FROZEN_GRAPH_NAME = 'frozen_inference_graph.pb' +PIPELINE_CONFIG_NAME = 'pipeline.config' +CHECKPOINT_PREFIX = 'model.ckpt' + +MODELS = { + 'ssd_mobilenet_v1_coco': + Model( + 'ssd_mobilenet_v1_coco', + 'http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tar.gz', + 'ssd_mobilenet_v1_coco_2018_01_28', + ), + 'ssd_mobilenet_v1_0p75_depth_quantized_coco': + Model( + 'ssd_mobilenet_v1_0p75_depth_quantized_coco', + 'http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_0.75_depth_quantized_300x300_coco14_sync_2018_07_18.tar.gz', + 'ssd_mobilenet_v1_0.75_depth_quantized_300x300_coco14_sync_2018_07_18' + ), + 'ssd_mobilenet_v1_ppn_coco': + Model( + 'ssd_mobilenet_v1_ppn_coco', + 'http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_ppn_shared_box_predictor_300x300_coco14_sync_2018_07_03.tar.gz', + 'ssd_mobilenet_v1_ppn_shared_box_predictor_300x300_coco14_sync_2018_07_03' + ), + 'ssd_mobilenet_v1_fpn_coco': + Model( + 'ssd_mobilenet_v1_fpn_coco', + 'http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03.tar.gz', + 'ssd_mobilenet_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03' + ), + 'ssd_mobilenet_v2_coco': + Model( + 'ssd_mobilenet_v2_coco', + 'http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v2_coco_2018_03_29.tar.gz', + 'ssd_mobilenet_v2_coco_2018_03_29', + ), + 'ssdlite_mobilenet_v2_coco': + Model( + 'ssdlite_mobilenet_v2_coco', + 'http://download.tensorflow.org/models/object_detection/ssdlite_mobilenet_v2_coco_2018_05_09.tar.gz', + 'ssdlite_mobilenet_v2_coco_2018_05_09'), + 'ssd_inception_v2_coco': + Model( + 'ssd_inception_v2_coco', + 'http://download.tensorflow.org/models/object_detection/ssd_inception_v2_coco_2018_01_28.tar.gz', + 'ssd_inception_v2_coco_2018_01_28', + ), + 'ssd_resnet_50_fpn_coco': + Model( + 'ssd_resnet_50_fpn_coco', + 'http://download.tensorflow.org/models/object_detection/ssd_resnet50_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03.tar.gz', + 'ssd_resnet50_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03', + ), + 
'faster_rcnn_resnet50_coco': + Model( + 'faster_rcnn_resnet50_coco', + 'http://download.tensorflow.org/models/object_detection/faster_rcnn_resnet50_coco_2018_01_28.tar.gz', + 'faster_rcnn_resnet50_coco_2018_01_28', + ), + 'faster_rcnn_nas': + Model( + 'faster_rcnn_nas', + 'http://download.tensorflow.org/models/object_detection/faster_rcnn_nas_coco_2018_01_28.tar.gz', + 'faster_rcnn_nas_coco_2018_01_28', + ), + 'mask_rcnn_resnet50_atrous_coco': + Model( + 'mask_rcnn_resnet50_atrous_coco', + 'http://download.tensorflow.org/models/object_detection/mask_rcnn_resnet50_atrous_coco_2018_01_28.tar.gz', + 'mask_rcnn_resnet50_atrous_coco_2018_01_28', + ), + 'facessd_mobilenet_v2_quantized_open_image_v4': + Model( + 'facessd_mobilenet_v2_quantized_open_image_v4', + 'http://download.tensorflow.org/models/object_detection/facessd_mobilenet_v2_quantized_320x320_open_image_v4.tar.gz', + 'facessd_mobilenet_v2_quantized_320x320_open_image_v4') +} + +Dataset = namedtuple( + 'Dataset', + ['images_url', 'images_dir', 'annotation_url', 'annotation_path']) + +DATASETS = { + 'val2014': + Dataset( + 'http://images.cocodataset.org/zips/val2014.zip', 'val2014', + 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip', + 'annotations/instances_val2014.json'), + 'train2014': + Dataset( + 'http://images.cocodataset.org/zips/train2014.zip', 'train2014', + 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip', + 'annotations/instances_train2014.json'), + 'val2017': + Dataset( + 'http://images.cocodataset.org/zips/val2017.zip', 'val2017', + 'http://images.cocodataset.org/annotations/annotations_trainval2017.zip', + 'annotations/instances_val2017.json'), + 'train2017': + Dataset( + 'http://images.cocodataset.org/zips/train2017.zip', 'train2017', + 'http://images.cocodataset.org/annotations/annotations_trainval2017.zip', + 'annotations/instances_train2017.json') +} + + +def download_model(model_name, output_dir='.'): + """Downloads a model from the TensorFlow Object Detection API + + Downloads a model from the TensorFlow Object Detection API to a specific + output directory. The download will be skipped if an existing directory + for the selected model already found under output_dir. + + Args + ---- + model_name: A string representing the model to download. This must be + one of the keys in the module variable + ``trt_samples.object_detection.MODELS``. + output_dir: A string representing the directory to download the model + under. A directory for the specified model will be created at + ``output_dir/``. If output_dir/ + already exists, then the download will be skipped. + + Returns + ------- + config_path: A string representing the path to the object detection + pipeline configuration file of the downloaded model. + checkpoint_path: A string representing the path to the object detection + model checkpoint. 
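+
+    Example
+    -------
+    An illustrative call, mirroring the README above (the model name must
+    be a key of MODELS; the output directory is caller-chosen):
+
+        config_path, checkpoint_path = download_model(
+            'ssd_mobilenet_v1_coco', output_dir='models')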
+ """ + global MODELS + + model_name + + model = MODELS[model_name] + + # make output directory if it doesn't exist + subprocess.call(['mkdir', '-p', output_dir]) + + tar_file = os.path.join(output_dir, os.path.basename(model.url)) + + config_path = os.path.join(output_dir, model.extract_dir, + PIPELINE_CONFIG_NAME) + checkpoint_path = os.path.join(output_dir, model.extract_dir, + CHECKPOINT_PREFIX) + + extract_dir = os.path.join(output_dir, model.extract_dir) + if os.path.exists(extract_dir): + print('Using cached model found at: %s' % extract_dir) + else: + subprocess.call(['wget', model.url, '-O', tar_file]) + subprocess.call(['tar', '-xzf', tar_file, '-C', output_dir]) + + # hack fix to handle mobilenet_v2 config bug + subprocess.call(['sed', '-i', '/batch_norm_trainable/d', config_path]) + + return config_path, checkpoint_path + + +def optimize_model(config_path, + checkpoint_path, + use_trt=True, + force_nms_cpu=True, + replace_relu6=True, + remove_assert=True, + override_nms_score_threshold=None, + override_resizer_shape=None, + max_batch_size=1, + precision_mode='FP32', + minimum_segment_size=50, + max_workspace_size_bytes=1 << 25, + calib_images_dir=None, + num_calib_images=None, + calib_image_shape=None, + tmp_dir='.optimize_model_tmp_dir', + remove_tmp_dir=True, + output_path=None): + """Optimizes an object detection model using TensorRT + + Optimizes an object detection model using TensorRT. This method also + performs pre-tensorrt optimizations specific to the TensorFlow object + detection API models. Please see the list of arguments for other + optimization parameters. + + Args + ---- + config_path: A string representing the path of the object detection + pipeline config file. + checkpoint_path: A string representing the path of the object + detection model checkpoint. + use_trt: A boolean representing whether to optimize with TensorRT. If + False, regular TensorFlow will be used but other optimizations + (like NMS device placement) will still be applied. + force_nms_cpu: A boolean indicating whether to place NMS operations on + the CPU. + replace_relu6: A boolean indicating whether to replace relu6(x) + operations with relu(x) - relu(x-6). + remove_assert: A boolean indicating whether to remove Assert + operations from the graph. + override_nms_score_threshold: An optional float representing + a NMS score threshold to override that specified in the object + detection configuration file. + override_resizer_shape: An optional list/tuple of integers + representing a fixed shape to override the default image resizer + specified in the object detection configuration file. + max_batch_size: An integer representing the max batch size to use for + TensorRT optimization. + precision_mode: A string representing the precision mode to use for + TensorRT optimization. Must be one of 'FP32', 'FP16', or 'INT8'. + minimum_segment_size: An integer representing the minimum segment size + to use for TensorRT graph segmentation. + max_workspace_size_bytes: An integer representing the max workspace + size for TensorRT optimization. + calib_images_dir: A string representing a directory containing images to + use for int8 calibration. + num_calib_images: An integer representing the number of calibration + images to use. If None, will use all images in directory. + calib_image_shape: A tuple of integers representing the height, + width that images will be resized to for calibration. + tmp_dir: A string representing a directory for temporary files. 
This + directory will be created and removed by this function and should + not already exist. If the directory exists, an error will be + thrown. + remove_tmp_dir: A boolean indicating whether we should remove the + tmp_dir or throw error. + output_path: An optional string representing the path to save the + optimized GraphDef to. + + Returns + ------- + A GraphDef representing the optimized model. + """ + if os.path.exists(tmp_dir): + if not remove_tmp_dir: + raise RuntimeError( + 'Cannot create temporary directory, path exists: %s' % tmp_dir) + subprocess.call(['rm', '-rf', tmp_dir]) + + # load config from file + config = pipeline_pb2.TrainEvalPipelineConfig() + with open(config_path, 'r') as f: + text_format.Merge(f.read(), config, allow_unknown_extension=True) + + # override some config parameters + if config.model.HasField('ssd'): + config.model.ssd.feature_extractor.override_base_feature_extractor_hyperparams = True + if override_nms_score_threshold is not None: + config.model.ssd.post_processing.batch_non_max_suppression.score_threshold = override_nms_score_threshold + if override_resizer_shape is not None: + config.model.ssd.image_resizer.fixed_shape_resizer.height = override_resizer_shape[ + 0] + config.model.ssd.image_resizer.fixed_shape_resizer.width = override_resizer_shape[ + 1] + elif config.model.HasField('faster_rcnn'): + if override_nms_score_threshold is not None: + config.model.faster_rcnn.second_stage_post_processing.score_threshold = override_nms_score_threshold + if override_resizer_shape is not None: + config.model.faster_rcnn.image_resizer.fixed_shape_resizer.height = override_resizer_shape[ + 0] + config.model.faster_rcnn.image_resizer.fixed_shape_resizer.width = override_resizer_shape[ + 1] + + tf_config = tf.ConfigProto() + tf_config.gpu_options.allow_growth = True + + # export inference graph to file (initial), this will create tmp_dir + with tf.Session(config=tf_config): + with tf.Graph().as_default(): + exporter.export_inference_graph( + INPUT_NAME, + config, + checkpoint_path, + tmp_dir, + input_shape=[max_batch_size, None, None, 3]) + + # read frozen graph from file + frozen_graph_path = os.path.join(tmp_dir, FROZEN_GRAPH_NAME) + frozen_graph = tf.GraphDef() + with open(frozen_graph_path, 'rb') as f: + frozen_graph.ParseFromString(f.read()) + + # apply graph modifications + if force_nms_cpu: + frozen_graph = f_force_nms_cpu(frozen_graph) + if replace_relu6: + frozen_graph = f_replace_relu6(frozen_graph) + if remove_assert: + frozen_graph = f_remove_assert(frozen_graph) + + # get input names + output_names = [BOXES_NAME, CLASSES_NAME, SCORES_NAME, NUM_DETECTIONS_NAME] + + # optionally perform TensorRT optimization + if use_trt: + with tf.Graph().as_default() as tf_graph: + with tf.Session(config=tf_config) as tf_sess: + frozen_graph = trt.create_inference_graph( + input_graph_def=frozen_graph, + outputs=output_names, + max_batch_size=max_batch_size, + max_workspace_size_bytes=max_workspace_size_bytes, + precision_mode=precision_mode, + minimum_segment_size=minimum_segment_size) + + # perform calibration for int8 precision + if precision_mode == 'INT8': + + if calib_images_dir is None: + raise ValueError('calib_images_dir must be provided for int8 optimization.') + + tf.import_graph_def(frozen_graph, name='') + tf_input = tf_graph.get_tensor_by_name(INPUT_NAME + ':0') + tf_boxes = tf_graph.get_tensor_by_name(BOXES_NAME + ':0') + tf_classes = tf_graph.get_tensor_by_name(CLASSES_NAME + ':0') + tf_scores = tf_graph.get_tensor_by_name(SCORES_NAME + ':0') + 
tf_num_detections = tf_graph.get_tensor_by_name(
+                        NUM_DETECTIONS_NAME + ':0')
+
+                    image_paths = glob.glob(os.path.join(calib_images_dir, '*.jpg'))
+                    image_paths = image_paths[0:num_calib_images]
+
+                    for image_idx in tqdm.tqdm(range(0, len(image_paths), max_batch_size)):
+
+                        # read batch of images
+                        batch_images = []
+                        for image_path in image_paths[image_idx:image_idx+max_batch_size]:
+                            image = _read_image(image_path, calib_image_shape)
+                            batch_images.append(image)
+
+                        # execute batch of images
+                        boxes, classes, scores, num_detections = tf_sess.run(
+                            [tf_boxes, tf_classes, tf_scores, tf_num_detections],
+                            feed_dict={tf_input: batch_images})
+
+                    frozen_graph = trt.calib_graph_to_infer_graph(frozen_graph)
+
+    # re-enable variable batch size, this was forced to max
+    # batch size during export to enable TensorRT optimization
+    for node in frozen_graph.node:
+        if INPUT_NAME == node.name:
+            node.attr['shape'].shape.dim[0].size = -1
+
+    # write optimized model to disk
+    if output_path is not None:
+        with open(output_path, 'wb') as f:
+            f.write(frozen_graph.SerializeToString())
+
+    # remove temporary directory
+    subprocess.call(['rm', '-rf', tmp_dir])
+
+    return frozen_graph
+
+
+def download_dataset(dataset_name, output_dir='.'):
+    """Downloads a COCO dataset
+
+    Downloads a COCO dataset to the specified output directory. A new
+    directory corresponding to the specified dataset will be created under
+    output_dir. This directory will contain the images of the dataset.
+
+    Args
+    ----
+    dataset_name: A string representing the name of the dataset. It must
+        be one of the keys in ``tftrt.examples.object_detection.DATASETS``.
+    output_dir: A string representing the directory that the dataset will
+        be downloaded under.
+
+    Returns
+    -------
+    images_dir: A string representing the path of the directory containing
+        images of the dataset.
+    annotation_path: A string representing the path of the COCO annotation
+        file for the dataset.
+    """
+    global DATASETS
+
+    dataset = DATASETS[dataset_name]
+
+    subprocess.call(['mkdir', '-p', output_dir])
+
+    images_dir = os.path.join(output_dir, dataset.images_dir)
+    images_zip_file = os.path.join(output_dir,
+                                   os.path.basename(dataset.images_url))
+    annotation_path = os.path.join(output_dir, dataset.annotation_path)
+    annotation_zip_file = os.path.join(
+        output_dir, os.path.basename(dataset.annotation_url))
+
+    # download or use cached annotation
+    if os.path.exists(annotation_path):
+        print('Using cached annotation_path: %s' % (annotation_path))
+    else:
+        subprocess.call(
+            ['wget', dataset.annotation_url, '-O', annotation_zip_file])
+        subprocess.call(['unzip', annotation_zip_file, '-d', output_dir])
+
+    # download or use cached images
+    if os.path.exists(images_dir):
+        print('Using cached images_dir: %s' % (images_dir))
+    else:
+        subprocess.call(['wget', dataset.images_url, '-O', images_zip_file])
+        subprocess.call(['unzip', images_zip_file, '-d', output_dir])
+
+    return images_dir, annotation_path
+
+
+def benchmark_model(frozen_graph,
+                    images_dir,
+                    annotation_path,
+                    batch_size=1,
+                    image_shape=None,
+                    num_images=4096,
+                    tmp_dir='.benchmark_model_tmp_dir',
+                    remove_tmp_dir=True,
+                    output_path=None):
+    """Computes accuracy and performance statistics
+
+    Computes accuracy and performance statistics by executing over many images
+    from the MSCOCO dataset defined by images_dir and annotation_path.
+
+    Args
+    ----
+    frozen_graph: A GraphDef representing the object detection model to
+        test. Alternatively, a string representing the path to the saved
+        frozen graph.
+    images_dir: A string representing the path of the COCO images
+        directory. 
+ annotation_path: A string representing the path of the COCO annotation + file. + batch_size: An integer representing the batch size to use when feeding + images to the model. + image_shape: An optional tuple of integers representing a fixed shape + to resize all images before testing. + num_images: An integer representing the number of images in the + dataset to evaluate with. + tmp_dir: A string representing the path where the function may create + a temporary directory to store intermediate files. + output_path: An optional string representing a path to store the + statistics in JSON format. + + Returns + ------- + statistics: A named dictionary of accuracy and performance statistics + computed for the model. + """ + if os.path.exists(tmp_dir): + if not remove_tmp_dir: + raise RuntimeError('Temporary directory exists; %s' % tmp_dir) + subprocess.call(['rm', '-rf', tmp_dir]) + if batch_size > 1 and image_shape is None: + raise RuntimeError( + 'Fixed image shape must be provided for batch size > 1') + + from pycocotools.coco import COCO + from pycocotools.cocoeval import COCOeval + + coco = COCO(annotation_file=annotation_path) + + # get list of image ids to use for evaluation + image_ids = coco.getImgIds() + if num_images > len(image_ids): + print( + 'Num images provided %d exceeds number in dataset %d, using %d images instead' + % (num_images, len(image_ids), len(image_ids))) + num_images = len(image_ids) + image_ids = image_ids[0:num_images] + + # load frozen graph from file if string, otherwise must be GraphDef + if isinstance(frozen_graph, str): + frozen_graph_path = frozen_graph + frozen_graph = tf.GraphDef() + with open(frozen_graph_path, 'rb') as f: + frozen_graph.ParseFromString(f.read()) + elif not isinstance(frozen_graph, tf.GraphDef): + raise TypeError('Expected frozen_graph to be GraphDef or str') + + tf_config = tf.ConfigProto() + tf_config.gpu_options.allow_growth = True + + coco_detections = [] # list of all bounding box detections in coco format + runtimes = [] # list of runtimes for each batch + image_counts = [] # list of number of images in each batch + + with tf.Graph().as_default() as tf_graph: + with tf.Session(config=tf_config) as tf_sess: + tf.import_graph_def(frozen_graph, name='') + tf_input = tf_graph.get_tensor_by_name(INPUT_NAME + ':0') + tf_boxes = tf_graph.get_tensor_by_name(BOXES_NAME + ':0') + tf_classes = tf_graph.get_tensor_by_name(CLASSES_NAME + ':0') + tf_scores = tf_graph.get_tensor_by_name(SCORES_NAME + ':0') + tf_num_detections = tf_graph.get_tensor_by_name( + NUM_DETECTIONS_NAME + ':0') + + # load batches from coco dataset + for image_idx in tqdm.tqdm(range(0, len(image_ids), batch_size)): + batch_image_ids = image_ids[image_idx:image_idx + batch_size] + batch_images = [] + batch_coco_images = [] + + # read images from file + for image_id in batch_image_ids: + coco_img = coco.imgs[image_id] + batch_coco_images.append(coco_img) + image_path = os.path.join(images_dir, + coco_img['file_name']) + image = _read_image(image_path, image_shape) + batch_images.append(image) + + # run once outside of timing to initialize + if image_idx == 0: + boxes, classes, scores, num_detections = tf_sess.run( + [tf_boxes, tf_classes, tf_scores, tf_num_detections], + feed_dict={tf_input: batch_images}) + + # execute model and compute time difference + t0 = time.time() + boxes, classes, scores, num_detections = tf_sess.run( + [tf_boxes, tf_classes, tf_scores, tf_num_detections], + feed_dict={tf_input: batch_images}) + t1 = time.time() + + # log runtime and image count + 
runtimes.append(float(t1 - t0))
+                image_counts.append(len(batch_images))
+
+                # add coco detections for this batch to running list
+                batch_coco_detections = []
+                for i, image_id in enumerate(batch_image_ids):
+                    image_width = batch_coco_images[i]['width']
+                    image_height = batch_coco_images[i]['height']
+
+                    for j in range(int(num_detections[i])):
+                        bbox = boxes[i][j]
+                        bbox_coco_fmt = [
+                            bbox[1] * image_width,  # x0
+                            bbox[0] * image_height,  # y0
+                            (bbox[3] - bbox[1]) * image_width,  # width
+                            (bbox[2] - bbox[0]) * image_height,  # height
+                        ]
+
+                        coco_detection = {
+                            'image_id': image_id,
+                            'category_id': int(classes[i][j]),
+                            'bbox': bbox_coco_fmt,
+                            'score': float(scores[i][j])
+                        }
+
+                        coco_detections.append(coco_detection)
+
+    # write coco detections to file
+    subprocess.call(['mkdir', '-p', tmp_dir])
+    coco_detections_path = os.path.join(tmp_dir, 'coco_detections.json')
+    with open(coco_detections_path, 'w') as f:
+        json.dump(coco_detections, f)
+
+    # compute coco metrics
+    cocoDt = coco.loadRes(coco_detections_path)
+    eval = COCOeval(coco, cocoDt, 'bbox')
+    eval.params.imgIds = image_ids
+
+    eval.evaluate()
+    eval.accumulate()
+    eval.summarize()
+
+    statistics = {
+        'map': eval.stats[0],
+        'avg_latency_ms': 1000.0 * np.mean(runtimes),
+        'avg_throughput_fps': np.sum(image_counts) / np.sum(runtimes)
+    }
+
+    if output_path is not None:
+        subprocess.call(['mkdir', '-p', os.path.dirname(output_path)])
+        with open(output_path, 'w') as f:
+            json.dump(statistics, f)
+
+    subprocess.call(['rm', '-rf', tmp_dir])
+
+    return statistics
+
+
+def _read_image(image_path, image_shape):
+    image = Image.open(image_path).convert('RGB')
+    if image_shape is not None:
+        image = image.resize(image_shape[::-1])
+    return np.array(image)
diff --git a/tftrt/examples/object_detection/test.py b/tftrt/examples/object_detection/test.py
new file mode 100644
index 000000000..b7c2248c6
--- /dev/null
+++ b/tftrt/examples/object_detection/test.py
@@ -0,0 +1,101 @@
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+import argparse
+import json
+from .object_detection import download_model, download_dataset, optimize_model, benchmark_model
+
+
+def test(test_config_path):
+    """Runs an object detection test configuration
+
+    This runs an object detection test configuration. It involves
+
+    1. Downloading a model architecture (or using the cached one).
+    2. Optimizing the downloaded model architecture.
+    3. Benchmarking the optimized model against a dataset.
+    4. (Optional) Running assertions to check the benchmark output.
+
+    The input to this function is a JSON file which specifies the test
+    configuration.
+
+    example_test_config.json:
+
+        {
+            "source_model": { ... },
+            "optimization_config": { ... },
+            "benchmark_config": { ... },
+            "assertions": [ ... ]
+        }
+
+    source_model: A dictionary of arguments passed to download_model, which
+        specify the pre-optimized model architecture. The model downloaded (or
+        the cached model if found) will be passed to optimize_model.
+    optimization_config: A dictionary of arguments passed to optimize_model.
+        Please see help(optimize_model) for more details.
+    benchmark_config: A dictionary of arguments passed to benchmark_model.
+        Please see help(benchmark_model) for more details.
+    assertions: A list of strings containing Python code that will be
+        evaluated. If the code evaluates to false, an error will be thrown.
+        These assertions can reference any variables local to this 'test'
+        function. Some useful values are
+
+            statistics['map']
+            statistics['avg_latency_ms']
+            statistics['avg_throughput_fps']
+
+    Args
+    ----
+    test_config_path: A string corresponding to the test configuration
+        JSON file.
+    """
+    with open(test_config_path, 'r') as f:
+        test_config = json.load(f)
+        print(json.dumps(test_config, sort_keys=True, indent=4))
+
+    # download model or use cached
+    config_path, checkpoint_path = download_model(**test_config['source_model'])
+
+    # optimize model using source model
+    frozen_graph = optimize_model(
+        config_path=config_path,
+        checkpoint_path=checkpoint_path,
+        **test_config['optimization_config'])
+
+    # benchmark optimized model
+    statistics = benchmark_model(
+        frozen_graph=frozen_graph,
+        **test_config['benchmark_config'])
+    print(json.dumps(statistics, sort_keys=True, indent=4))
+
+    # run assertions
+    if 'assertions' in test_config:
+        for a in test_config['assertions']:
+            if not eval(a):
+                raise AssertionError('ASSERTION FAILED: %s' % a)
+            else:
+                print('ASSERTION PASSED: %s' % a)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        'test_config_path',
+        help='Path of JSON file containing test configuration. Please '
+             'see help(tftrt.examples.object_detection.test) for more information')
+    args = parser.parse_args()
+    test(args.test_config_path)
diff --git a/tftrt/examples/object_detection/third_party/cocoapi b/tftrt/examples/object_detection/third_party/cocoapi
new file mode 160000
index 000000000..ed842bffd
--- /dev/null
+++ b/tftrt/examples/object_detection/third_party/cocoapi
@@ -0,0 +1 @@
+Subproject commit ed842bffd41f6ff38707c4f0968d2cfd91088688
diff --git a/tftrt/examples/object_detection/third_party/models b/tftrt/examples/object_detection/third_party/models
new file mode 160000
index 000000000..402b561b0
--- /dev/null
+++ b/tftrt/examples/object_detection/third_party/models
@@ -0,0 +1 @@
+Subproject commit 402b561b03857151f684ee00b3d997e5e6be9778

From 7ce269000fbd31b54de121a5efb9ac120ca30ddd Mon Sep 17 00:00:00 2001
From: Pooya Davoodi
Date: Fri, 30 Nov 2018 19:10:07 -0800
Subject: [PATCH 02/56] Update README.md

---
 tftrt/examples/image-classification/README.md | 51 ++++++++++++++-----
 1 file changed, 39 insertions(+), 12 deletions(-)

diff --git a/tftrt/examples/image-classification/README.md b/tftrt/examples/image-classification/README.md
index d4b8d662f..09929a97d 100644
--- a/tftrt/examples/image-classification/README.md
+++ b/tftrt/examples/image-classification/README.md
@@ -1,24 +1,23 @@
-# TensorFlow-TensorRT Examples
+# Image classification examples
 
-This script will run inference using a few popular image classification models
-on the ImageNet validation set.
+This example includes scripts to run inference using a number of popular image classification models.
 
 You can turn on TensorFlow-TensorRT integration with the flag `--use_trt`. 
This will apply TensorRT inference optimization to speed up execution for portions of the model's graph where supported, and will fall back to native TensorFlow for -layers and operations which are not supported. See -https://devblogs.nvidia.com/tensorrt-integration-speeds-tensorflow-inference/ -for more information. +layers and operations which are not supported. +See https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html for more information. When using TF-TRT, you can also control the precision with `--precision`. float32 is the default (`--precision fp32`) with float16 (`--precision fp16`) or -int8 (`--precision int8`) allowing further performance improvements, at the cost -of some accuracy. int8 mode requires a calibration step which is done +int8 (`--precision int8`) allowing further performance improvements. +int8 mode requires a calibration step which is done automatically. ## Models -This test supports the following models for image classification: +We have verified the following models. + * MobileNet v1 * MobileNet v2 * NASNet - Large @@ -30,6 +29,10 @@ This test supports the following models for image classification: * Inception v3 * Inception v4 +For the accuracy numbers of these models on the +ImageNet validation dataset, see +[Verified Models](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#verified-models) + ## Setup ``` # Clone [tensorflow/models](https://github.com/tensorflow/models) @@ -52,19 +55,43 @@ add `export PYTHONPATH="$PYTHONPATH:/path/to/tensorflow_models"` to your .bashrc file (replacing /path/to/tensorflow_models with the path to your tensorflow/models repository). +See [Setting Up The Environment +](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#image-class-envirn) +for more information. + ### Data -The script supports only TFRecord format for data. The script -assumes that validation TFRecords are named according to the pattern: -`validation-*-of-00128`. +The example supports using a dataset in TFRecords or synthetic data. +In case of using TFRecord files, the scripts assume that TFRecords +are named according to the pattern: `validation-*-of-00128`. +The reported accuracy numbers are the results of running the scripts on +the ImageNet validation dataset. You can download and process Imagenet using [this script provided by TF Slim](https://github.com/tensorflow/models/blob/master/research/slim/datasets/download_imagenet.sh). Please note that this script downloads both the training and validation sets, and this example only requires the validation set. +See [Obtaining The ImageNet Data +](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#image-class-data) +for more information. + ## Usage `python inference.py --data_dir /imagenet_validation_data --model vgg_16 [--use_trt]` Run with `--help` to see all available options. + +See [General Script Usage +](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#image-class-usage) +for more information. + +### Accuracy tests + +There is the script `check_accuracy.py` provided in the example that parses the output log of `inference.py` +to find the reported accuracy, and reports whether that accuracy matches with the +baseline numbers. + +See [Checking Accuracy +](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#image-class-accuracy) +for more information. 
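+
+For reference, the flags described in the Usage section above combine into a
+single invocation like the following (an illustrative sketch: the data
+directory is a placeholder, and any of the supported model names may be
+substituted for `vgg_16`):
+
+```
+python inference.py --data_dir /imagenet_validation_data \
+    --model vgg_16 --use_trt --precision fp16
+```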
From b083b98908042b8f40f31101afb9719d00005b54 Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Fri, 30 Nov 2018 19:11:05 -0800 Subject: [PATCH 03/56] Update README.md --- README.md | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 90394bad2..0e7e2ef7a 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,18 @@ -# TensorRT Integration in TensorFlow +# Examples for TF-TRT -This repository demonstrates TensorRT integration in TensorFlow. Currently -it contains examples for accelerated image classification and object -detection. +[TF-TRT](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html) +(TensorFlow integration with TensorRT) is a part of TensorFlow +that optimizes TensorFlow graphs using +[TensorRT](https://developer.nvidia.com/tensorrt). + +This repository contains a number of different examples +that show how to use TF-TRT. - ## Examples +* [Image Classification](tftrt/examples/image_classification) * [Object Detection](tftrt/examples/object_detection) + +## License + +[Apache License 2.0](LICENSE) From 06ed779f4a8ff594745b199e67dbfa97543573e8 Mon Sep 17 00:00:00 2001 From: Ritwik Sharma Date: Sun, 2 Dec 2018 23:25:42 +0530 Subject: [PATCH 04/56] Correct broken link (#4) * Correct broken link Corrected broken link for image-classification in examples. * Correct broken link --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 0e7e2ef7a..89e7ed62e 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ that show how to use TF-TRT. ## Examples -* [Image Classification](tftrt/examples/image_classification) +* [Image Classification](tftrt/examples/image-classification) * [Object Detection](tftrt/examples/object_detection) ## License From 25ec9950df355ab0d8479f72638672985b2842e7 Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Sun, 9 Dec 2018 09:40:12 -0800 Subject: [PATCH 05/56] Update README.md (#6) --- README.md | 75 ++++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 69 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 89e7ed62e..3bcdf2aab 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,81 @@ -# Examples for TF-TRT +# Examples for TensorRT in TensorFlow (TF-TRT) -[TF-TRT](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html) -(TensorFlow integration with TensorRT) is a part of TensorFlow +This repository contains a number of different examples +that show how to use +[TF-TRT](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/tensorrt). +TF-TRT is a part of TensorFlow that optimizes TensorFlow graphs using [TensorRT](https://developer.nvidia.com/tensorrt). - -This repository contains a number of different examples -that show how to use TF-TRT. +We have used these examples to verify the accuracy and +performance of TF-TRT. For more information see +[Verified Models](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#verified-models). ## Examples * [Image Classification](tftrt/examples/image-classification) * [Object Detection](tftrt/examples/object_detection) + +# Using TensorRT in TensorFlow (TF-TRT) + +This module provides necessary bindings and introduces +`TRTEngineOp` operator that wraps a subgraph in TensorRT. +This module is under active development. + + +## Installing TF-TRT + +Currently Tensorflow nightly builds include TF-TRT by default, +which means you don't need to install TF-TRT separately. 
+You can pull the latest TF containers from Docker Hub or
+install the latest TF pip package to get access to the latest TF-TRT.
+
+If you want to use TF-TRT on the NVIDIA Jetson platform, you can find
+the download links for the relevant TensorFlow pip packages here:
+https://docs.nvidia.com/deeplearning/dgx/index.html#installing-frameworks-for-jetson
+
+
+## Installing TensorRT
+
+In order to make use of TF-TRT, you will need a local installation
+of TensorRT from the
+[NVIDIA Developer website](https://developer.nvidia.com/tensorrt).
+Installation instructions for compatibility with TensorFlow are provided in the
+[TensorFlow GPU support](https://www.tensorflow.org/install/gpu) guide.
+
+
+## Documentation
+
+The [TF-TRT documentation](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html)
+gives an overview of the supported functionalities, provides tutorials
+and verified models, and explains best practices with troubleshooting guides.
+
+
+## Tests
+
+TF-TRT includes both Python tests and C++ unit tests.
+Most of the Python tests are located in the test directory
+and they can be executed using `bazel test` or directly
+with the Python command. Most of the C++ unit tests are
+used to test the conversion functions that convert each TF op to
+a number of TensorRT layers.
+
+
+## Compilation
+
+In order to compile the module, you need to have a local TensorRT installation
+(libnvinfer.so and respective include files). During the configuration step,
+TensorRT should be enabled and the installation path should be set. If installed
+through package managers (deb, rpm), the configure script should find the
+necessary components from the system automatically. If installed from tar
+packages, the user has to set the path to the location where the library is
+installed during configuration.
+
+```shell
+bazel build --config=cuda --config=opt //tensorflow/tools/pip_package:build_pip_package
+bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/
+```
+
+
 ## License
 
 [Apache License 2.0](LICENSE)

From b1dbe8068e7b0fda03260b4788e253503b1ecf0f Mon Sep 17 00:00:00 2001
From: John Welsh <35354912+jaybdub-nv@users.noreply.github.com>
Date: Fri, 21 Dec 2018 09:28:09 -0800
Subject: [PATCH 06/56] Score thresh bug fix (#7)

* added fix
* added runtimes to test
* Clean up install_dependencies.sh
---
 .../object_detection/install_dependencies.sh  | 43 +++++++++----------
 .../object_detection/object_detection.py      |  5 ++-
 tftrt/examples/object_detection/test.py       |  7 ++-
 3 files changed, 30 insertions(+), 25 deletions(-)

diff --git a/tftrt/examples/object_detection/install_dependencies.sh b/tftrt/examples/object_detection/install_dependencies.sh
index 0f55d90db..5d7a0e61a 100755
--- a/tftrt/examples/object_detection/install_dependencies.sh
+++ b/tftrt/examples/object_detection/install_dependencies.sh
@@ -16,26 +16,24 @@
 # limitations under the License.
 # =============================================================================
 
+echo Setup local variables...
 TF_MODELS_DIR=third_party/models
-COCO_API_DIR=third_party/cocoapi
-
-python -V 2>&1 | grep "Python 3" || \
-    ( export DEBIAN_FRONTEND=noninteractive && \
-      apt-get update && \
-      apt-get install -y --no-install-recommends python-tk )
-
 RESEARCH_DIR=$TF_MODELS_DIR/research
 SLIM_DIR=$RESEARCH_DIR/slim
+COCO_API_DIR=third_party/cocoapi
 PYCOCO_DIR=$COCO_API_DIR/PythonAPI
+PROTO_BASE_URL="https://github.com/google/protobuf/releases/download/v3.5.1/"
+PROTOC_DIR=$PWD/protoc
 
-pushd $RESEARCH_DIR
-
-# GET PROTOC 3.5
+#echo Install python-tk ...
+#python -V 2>&1 | grep "Python 3" || \ +# ( export DEBIAN_FRONTEND=noninteractive && \ +# apt-get update && \ +# apt-get install -y --no-install-recommends python-tk ) -BASE_URL="https://github.com/google/protobuf/releases/download/v3.5.1/" -PROTOC_DIR=protoc -PROTOC_EXE=$PROTOC_DIR/bin/protoc +set -v +echo Download protobuf... mkdir -p $PROTOC_DIR pushd $PROTOC_DIR ARCH=$(uname -m) @@ -47,25 +45,26 @@ else echo ERROR: $ARCH not supported. exit 1; fi -wget --no-check-certificate ${BASE_URL}${filename} -unzip ${filename} +wget --no-check-certificate ${PROTO_BASE_URL}${filename} +unzip -o ${filename} popd -# BUILD PROTOBUF FILES -$PROTOC_EXE object_detection/protos/*.proto --python_out=. - -# INSTALL OBJECT DETECTION +echo Compile object detection protobuf files... +pushd $RESEARCH_DIR +$PROTOC_DIR/bin/protoc object_detection/protos/*.proto --python_out=. +popd +echo Install tensorflow/models/research... +pushd $RESEARCH_DIR pip install -e . - popd +echo Install tensorflow/models/research/slim... pushd $SLIM_DIR pip install -e . popd -# INSTALL PYCOCOTOOLS - +echo Install cocodataset/cocoapi/PythonAPI... pushd $PYCOCO_DIR pip install -e . popd diff --git a/tftrt/examples/object_detection/object_detection.py b/tftrt/examples/object_detection/object_detection.py index 38e86a300..17e93a993 100644 --- a/tftrt/examples/object_detection/object_detection.py +++ b/tftrt/examples/object_detection/object_detection.py @@ -302,7 +302,7 @@ def optimize_model(config_path, 1] elif config.model.HasField('faster_rcnn'): if override_nms_score_threshold is not None: - config.model.faster_rcnn.second_stage_post_processing.score_threshold = override_nms_score_threshold + config.model.faster_rcnn.second_stage_post_processing.batch_non_max_suppression.score_threshold = override_nms_score_threshold if override_resizer_shape is not None: config.model.faster_rcnn.image_resizer.fixed_shape_resizer.height = override_resizer_shape[ 0] @@ -612,7 +612,8 @@ def benchmark_model(frozen_graph, statistics = { 'map': eval.stats[0], 'avg_latency_ms': 1000.0 * np.mean(runtimes), - 'avg_throughput_fps': np.sum(image_counts) / np.sum(runtimes) + 'avg_throughput_fps': np.sum(image_counts) / np.sum(runtimes), + 'runtimes_ms': [1000.0 * r for r in runtimes] } if output_path is not None: diff --git a/tftrt/examples/object_detection/test.py b/tftrt/examples/object_detection/test.py index b7c2248c6..89b1175a8 100644 --- a/tftrt/examples/object_detection/test.py +++ b/tftrt/examples/object_detection/test.py @@ -80,7 +80,12 @@ def test(test_config_path): statistics = benchmark_model( frozen_graph=frozen_graph, **test_config['benchmark_config']) - print(json.dumps(statistics, sort_keys=True, indent=4)) + + # print some statistics to command line + print_statistics = statistics + if 'runtimes_ms' in print_statistics: + print_statistics.pop('runtimes_ms') + print(json.dumps(print_statistics, sort_keys=True, indent=4)) # run assertions if 'assertions' in test_config: From 73279e6a15426281fc7ea38abb15ffceb908780f Mon Sep 17 00:00:00 2001 From: mdrozdowski1996 Date: Fri, 21 Dec 2018 14:20:53 -0800 Subject: [PATCH 07/56] change calibration dataset (#9) --- .../image-classification/image_classification.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py index 0090b802a..e92f63ddf 100644 --- a/tftrt/examples/image-classification/image_classification.py +++ 
b/tftrt/examples/image-classification/image_classification.py @@ -53,7 +53,7 @@ def after_run(self, run_context, run_values): self.batch_size / self.iter_times[-1])) def run(frozen_graph, model, data_dir, batch_size, - num_iterations, num_warmup_iterations, use_synthetic, display_every=100): + num_iterations, num_warmup_iterations, use_synthetic, display_every=100, run_calibration=False): """Evaluates a frozen graph This function evaluates a graph on the ImageNet validation set. @@ -82,7 +82,11 @@ def model_fn(features, labels, mode): # Create the dataset preprocess_fn = get_preprocess_fn(model) - validation_files = tf.gfile.Glob(os.path.join(data_dir, 'validation*')) + + if run_calibration: + validation_files = tf.gfile.Glob(os.path.join(data_dir, 'train*')) + else: + validation_files = tf.gfile.Glob(os.path.join(data_dir, 'validation*')) def get_tfrecords_count(files): num_records = 0 @@ -484,7 +488,7 @@ def get_frozen_graph( print('Calibrating INT8...') start_time = time.time() run(calib_graph, model, calib_data_dir, batch_size, - num_calib_inputs // batch_size, 0, False) + num_calib_inputs // batch_size, 0, False, run_calibration=True) times['trt_calibration'] = time.time() - start_time start_time = time.time() From 552479843d84e51273c435bd19d1af1f56afbab5 Mon Sep 17 00:00:00 2001 From: mdrozdowski1996 Date: Fri, 21 Dec 2018 14:25:20 -0800 Subject: [PATCH 08/56] add frozen_graph size (#5) * add frozen_graph size * add workspace_size args --- .gitmodules | 4 ++-- .../image-classification/image_classification.py | 14 +++++++++++--- .../object_detection/install_dependencies.sh | 4 ++-- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/.gitmodules b/.gitmodules index 2688d24bc..36fd3ef55 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,6 @@ [submodule "third_party/models"] - path = tftrt/examples/object_detection/third_party/models + path = tftrt/examples/third_party/models url = https://github.com/tensorflow/models [submodule "third_party/cocoapi"] - path = tftrt/examples/object_detection/third_party/cocoapi + path = tftrt/examples/third_party/cocoapi url = https://github.com/cocodataset/cocoapi diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py index e92f63ddf..b325e89ae 100644 --- a/tftrt/examples/image-classification/image_classification.py +++ b/tftrt/examples/image-classification/image_classification.py @@ -435,7 +435,8 @@ def get_frozen_graph( num_calib_inputs=None, use_synthetic=False, cache=False, - default_models_dir='./data'): + default_models_dir='./data', + max_workspace_size=(2<<32)-1000): """Retreives a frozen GraphDef from model definitions in classification.py and applies TF-TRT model: str, the model name (see NETS table in classification.py) @@ -473,7 +474,7 @@ def get_frozen_graph( input_graph_def=frozen_graph, outputs=['logits', 'classes'], max_batch_size=batch_size, - max_workspace_size_bytes=(4096<<20)-1000, + max_workspace_size_bytes=max_workspace_size, precision_mode=precision, minimum_segment_size=minimum_segment_size, is_dynamic_op=use_dynamic_op @@ -551,6 +552,8 @@ def get_frozen_graph( parser.add_argument('--num_calib_inputs', type=int, default=500, help='Number of inputs (e.g. 
images) used for calibration ' '(last batch is skipped in case it is not full)') + parser.add_argument('--max_workspace_size', type=int, default=(2<<32)-1000, + help='workspace size in bytes') parser.add_argument('--cache', action='store_true', help='If set, graphs will be saved to disk after conversion. If a converted graph is present on disk, it will be loaded instead of building the graph again.') args = parser.parse_args() @@ -579,12 +582,17 @@ def get_frozen_graph( num_calib_inputs=args.num_calib_inputs, use_synthetic=args.use_synthetic, cache=args.cache, - default_models_dir=args.default_models_dir) + default_models_dir=args.default_models_dir, + max_workspace_size=args.max_workspace_size) def print_dict(input_dict, str=''): for k, v in sorted(input_dict.items()): headline = '{}({}): '.format(str, k) if str else '{}: '.format(k) print('{}{}'.format(headline, '%.1f'%v if type(v)==float else v)) + + serialized_graph = frozen_graph.SerializeToString() + print('frozen graph size: {}'.format(len(serialized_graph))) + print_dict(vars(args)) print_dict(num_nodes, str='num_nodes') print_dict(times, str='time(s)') diff --git a/tftrt/examples/object_detection/install_dependencies.sh b/tftrt/examples/object_detection/install_dependencies.sh index 5d7a0e61a..a1041f450 100755 --- a/tftrt/examples/object_detection/install_dependencies.sh +++ b/tftrt/examples/object_detection/install_dependencies.sh @@ -17,10 +17,10 @@ # ============================================================================= echo Setup local variables... -TF_MODELS_DIR=third_party/models +TF_MODELS_DIR=../third_party/models RESEARCH_DIR=$TF_MODELS_DIR/research SLIM_DIR=$RESEARCH_DIR/slim -COCO_API_DIR=third_party/cocoapi +COCO_API_DIR=../third_party/cocoapi PYCOCO_DIR=$COCO_API_DIR/PythonAPI PROTO_BASE_URL="https://github.com/google/protobuf/releases/download/v3.5.1/" PROTOC_DIR=$PWD/protoc From 494824a503eb45579607683bd16ddd54e8562295 Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Tue, 11 Dec 2018 09:28:41 -0800 Subject: [PATCH 09/56] Update README.md Remove check_accuracy section because we have removed the script from the repository. --- tftrt/examples/image-classification/README.md | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/tftrt/examples/image-classification/README.md b/tftrt/examples/image-classification/README.md index 09929a97d..a821a3e11 100644 --- a/tftrt/examples/image-classification/README.md +++ b/tftrt/examples/image-classification/README.md @@ -2,7 +2,7 @@ This example includes scripts to run inference using a number of popular image classification models. -You can turn on TensorFlow-TensorRT integration with the flag `--use_trt`. This +You can turn on TF-TRT integration with the flag `--use_trt`. This will apply TensorRT inference optimization to speed up execution for portions of the model's graph where supported, and will fall back to native TensorFlow for layers and operations which are not supported. @@ -85,13 +85,3 @@ Run with `--help` to see all available options. See [General Script Usage ](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#image-class-usage) for more information. - -### Accuracy tests - -There is the script `check_accuracy.py` provided in the example that parses the output log of `inference.py` -to find the reported accuracy, and reports whether that accuracy matches with the -baseline numbers. 
- -See [Checking Accuracy -](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#image-class-accuracy) -for more information. From 62c15df72e924c301bba225889a319edcb553d0b Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Tue, 11 Dec 2018 12:34:35 -0800 Subject: [PATCH 10/56] Fix an error when there is no data file in data_dir And a bit cleanup around data directories --- .../image_classification.py | 35 +++++++++++-------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py index b325e89ae..d7891364f 100644 --- a/tftrt/examples/image-classification/image_classification.py +++ b/tftrt/examples/image-classification/image_classification.py @@ -52,7 +52,7 @@ def after_run(self, run_context, run_values): current_step, self.num_steps, duration * 1000, self.batch_size / self.iter_times[-1])) -def run(frozen_graph, model, data_dir, batch_size, +def run(frozen_graph, model, data_files, batch_size, num_iterations, num_warmup_iterations, use_synthetic, display_every=100, run_calibration=False): """Evaluates a frozen graph @@ -62,7 +62,7 @@ def run(frozen_graph, model, data_dir, batch_size, frozen_graph: GraphDef, a graph containing input node 'input' and outputs 'logits' and 'classes' model: string, the model name (see NETS table in graph.py) - data_dir: str, directory containing ImageNet validation TFRecord files + data_files: List of TFRecord files used for inference batch_size: int, batch size for TensorRT optimizations num_iterations: int, number of iterations(batches) to run for """ @@ -80,14 +80,9 @@ def model_fn(features, labels, mode): loss=loss, eval_metric_ops={'accuracy': accuracy}) - # Create the dataset + # preprocess function for input data preprocess_fn = get_preprocess_fn(model) - if run_calibration: - validation_files = tf.gfile.Glob(os.path.join(data_dir, 'train*')) - else: - validation_files = tf.gfile.Glob(os.path.join(data_dir, 'validation*')) - def get_tfrecords_count(files): num_records = 0 for fn in files: @@ -111,7 +106,7 @@ def eval_input_fn(): dtype=np.int32) labels = tf.identity(tf.constant(labels)) else: - dataset = tf.data.TFRecordDataset(validation_files) + dataset = tf.data.TFRecordDataset(data_files) dataset = dataset.apply(tf.contrib.data.map_and_batch(map_func=preprocess_fn, batch_size=batch_size, num_parallel_calls=8)) dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE) dataset = dataset.repeat(count=1) @@ -123,7 +118,7 @@ def eval_input_fn(): logger = LoggerHook( display_every=display_every, batch_size=batch_size, - num_records=get_tfrecords_count(validation_files)) + num_records=get_tfrecords_count(data_files)) tf_config = tf.ConfigProto() tf_config.gpu_options.allow_growth = True estimator = tf.estimator.Estimator( @@ -431,7 +426,7 @@ def get_frozen_graph( precision='fp32', batch_size=8, minimum_segment_size=2, - calib_data_dir=None, + calib_files=None, num_calib_inputs=None, use_synthetic=False, cache=False, @@ -488,7 +483,7 @@ def get_frozen_graph( # INT8 calibration step print('Calibrating INT8...') start_time = time.time() - run(calib_graph, model, calib_data_dir, batch_size, + run(calib_graph, model, calib_files, batch_size, num_calib_inputs // batch_size, 0, False, run_calibration=True) times['trt_calibration'] = time.time() - start_time @@ -569,6 +564,18 @@ def get_frozen_graph( raise ValueError('--num_calib_inputs must not be smaller than --batch_size' '({} <= 
{})'.format(args.num_calib_inputs, args.batch_size)) + def get_files(data_dir, filename_pattern): + if data_dir == None: + return [] + files = tf.gfile.Glob(os.path.join(data_dir, filename_pattern)) + if files == []: + raise ValueError('Can not find any files in {} with pattern "{}"'.format( + data_dir, filename_pattern)) + return files + + validation_files = get_files(args.data_dir, 'validation*') + calib_files = get_files(args.calib_data_dir, 'train*') + # Retreive graph using NETS table in graph.py frozen_graph, num_nodes, times = get_frozen_graph( model=args.model, @@ -578,7 +585,7 @@ def get_frozen_graph( precision=args.precision, batch_size=args.batch_size, minimum_segment_size=args.minimum_segment_size, - calib_data_dir=args.calib_data_dir, + calib_files=calib_files, num_calib_inputs=args.num_calib_inputs, use_synthetic=args.use_synthetic, cache=args.cache, @@ -602,7 +609,7 @@ def print_dict(input_dict, str=''): results = run( frozen_graph, model=args.model, - data_dir=args.data_dir, + data_files=validation_files, batch_size=args.batch_size, num_iterations=args.num_iterations, num_warmup_iterations=args.num_warmup_iterations, From bbc973bc6657e2f4787d1869b2fc78bad2813e26 Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Fri, 21 Dec 2018 16:00:50 -0800 Subject: [PATCH 11/56] Decrease default workspace size from 8GB to 4GB --- tftrt/examples/image-classification/image_classification.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py index d7891364f..998ef3e3a 100644 --- a/tftrt/examples/image-classification/image_classification.py +++ b/tftrt/examples/image-classification/image_classification.py @@ -431,7 +431,7 @@ def get_frozen_graph( use_synthetic=False, cache=False, default_models_dir='./data', - max_workspace_size=(2<<32)-1000): + max_workspace_size=(1<<32)): """Retreives a frozen GraphDef from model definitions in classification.py and applies TF-TRT model: str, the model name (see NETS table in classification.py) @@ -547,7 +547,7 @@ def get_frozen_graph( parser.add_argument('--num_calib_inputs', type=int, default=500, help='Number of inputs (e.g. images) used for calibration ' '(last batch is skipped in case it is not full)') - parser.add_argument('--max_workspace_size', type=int, default=(2<<32)-1000, + parser.add_argument('--max_workspace_size', type=int, default=(1<<32), help='workspace size in bytes') parser.add_argument('--cache', action='store_true', help='If set, graphs will be saved to disk after conversion. 
If a converted graph is present on disk, it will be loaded instead of building the graph again.') From 4d2069447acb5888b212534b1e8274d8d24376cd Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Fri, 21 Dec 2018 17:02:15 -0800 Subject: [PATCH 12/56] Print graph size for both TF and TRT graphs --- .../image_classification.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py index 998ef3e3a..a02a0f059 100644 --- a/tftrt/examples/image-classification/image_classification.py +++ b/tftrt/examples/image-classification/image_classification.py @@ -442,6 +442,7 @@ def get_frozen_graph( """ num_nodes = {} times = {} + graph_sizes = {} # Load from pb file if frozen graph was already created and cached if cache: @@ -456,11 +457,13 @@ def get_frozen_graph( times['loading_frozen_graph'] = time.time() - start_time num_nodes['loaded_frozen_graph'] = len(frozen_graph.node) num_nodes['trt_only'] = len([1 for n in frozen_graph.node if str(n.op)=='TRTEngineOp']) - return frozen_graph, num_nodes, times + graph_sizes['loaded_frozen_graph'] = len(frozen_graph.SerializeToString()) + return frozen_graph, num_nodes, times, graph_sizes # Build graph and load weights frozen_graph = build_classification_graph(model, model_dir, default_models_dir) num_nodes['native_tf'] = len(frozen_graph.node) + graph_sizes['native_tf'] = len(frozen_graph.SerializeToString()) # Convert to TensorRT graph if use_trt: @@ -477,9 +480,11 @@ def get_frozen_graph( times['trt_conversion'] = time.time() - start_time num_nodes['tftrt_total'] = len(frozen_graph.node) num_nodes['trt_only'] = len([1 for n in frozen_graph.node if str(n.op)=='TRTEngineOp']) + graph_sizes['trt'] = len(frozen_graph.SerializeToString()) if precision == 'int8': calib_graph = frozen_graph + graph_sizes['calib'] = len(calib_graph.SerializeToString()) # INT8 calibration step print('Calibrating INT8...') start_time = time.time() @@ -490,6 +495,8 @@ def get_frozen_graph( start_time = time.time() frozen_graph = trt.calib_graph_to_infer_graph(calib_graph) times['trt_int8_conversion'] = time.time() - start_time + # This is already set but overwriting it here to ensure the right size + graph_sizes['trt'] = len(frozen_graph.SerializeToString()) del calib_graph print('INT8 graph created.') @@ -506,7 +513,7 @@ def get_frozen_graph( f.write(frozen_graph.SerializeToString()) times['saving_frozen_graph'] = time.time() - start_time - return frozen_graph, num_nodes, times + return frozen_graph, num_nodes, times, graph_sizes if __name__ == '__main__': parser = argparse.ArgumentParser(description='Evaluate model') @@ -577,7 +584,7 @@ def get_files(data_dir, filename_pattern): calib_files = get_files(args.calib_data_dir, 'train*') # Retreive graph using NETS table in graph.py - frozen_graph, num_nodes, times = get_frozen_graph( + frozen_graph, num_nodes, times, graph_sizes = get_frozen_graph( model=args.model, model_dir=args.model_dir, use_trt=args.use_trt, @@ -592,16 +599,15 @@ def get_files(data_dir, filename_pattern): default_models_dir=args.default_models_dir, max_workspace_size=args.max_workspace_size) - def print_dict(input_dict, str=''): + def print_dict(input_dict, str='', scale=None): for k, v in sorted(input_dict.items()): headline = '{}({}): '.format(str, k) if str else '{}: '.format(k) + v = v * scale if scale else v print('{}{}'.format(headline, '%.1f'%v if type(v)==float else v)) - serialized_graph = 
frozen_graph.SerializeToString() - print('frozen graph size: {}'.format(len(serialized_graph))) - print_dict(vars(args)) print_dict(num_nodes, str='num_nodes') + print_dict(graph_sizes, str='graph_size(MB)', scale=1./(1<<20)) print_dict(times, str='time(s)') # Evaluate model From 7fe06a3ef03370e760dfbd900d1d5c5b2e872b12 Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Fri, 21 Dec 2018 17:14:13 -0800 Subject: [PATCH 13/56] Fix third_party submodules --- tftrt/examples/object_detection/third_party/cocoapi | 1 - tftrt/examples/object_detection/third_party/models | 1 - 2 files changed, 2 deletions(-) delete mode 160000 tftrt/examples/object_detection/third_party/cocoapi delete mode 160000 tftrt/examples/object_detection/third_party/models diff --git a/tftrt/examples/object_detection/third_party/cocoapi b/tftrt/examples/object_detection/third_party/cocoapi deleted file mode 160000 index ed842bffd..000000000 --- a/tftrt/examples/object_detection/third_party/cocoapi +++ /dev/null @@ -1 +0,0 @@ -Subproject commit ed842bffd41f6ff38707c4f0968d2cfd91088688 diff --git a/tftrt/examples/object_detection/third_party/models b/tftrt/examples/object_detection/third_party/models deleted file mode 160000 index 402b561b0..000000000 --- a/tftrt/examples/object_detection/third_party/models +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 402b561b03857151f684ee00b3d997e5e6be9778 From 1af92026db37289f964928c2d4adaad02f1382cf Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Fri, 21 Dec 2018 17:25:10 -0800 Subject: [PATCH 14/56] Revert change in .gitmodules --- .gitmodules | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.gitmodules b/.gitmodules index 36fd3ef55..2688d24bc 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,6 @@ [submodule "third_party/models"] - path = tftrt/examples/third_party/models + path = tftrt/examples/object_detection/third_party/models url = https://github.com/tensorflow/models [submodule "third_party/cocoapi"] - path = tftrt/examples/third_party/cocoapi + path = tftrt/examples/object_detection/third_party/cocoapi url = https://github.com/cocodataset/cocoapi From b1754def74c3743621da1b57a556a205078f6c6e Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Fri, 21 Dec 2018 17:36:43 -0800 Subject: [PATCH 15/56] Update submodules --- .gitmodules | 6 ++++++ tftrt/examples/third_party/cocoapi | 1 + tftrt/examples/third_party/models | 1 + 3 files changed, 8 insertions(+) create mode 160000 tftrt/examples/third_party/cocoapi create mode 160000 tftrt/examples/third_party/models diff --git a/.gitmodules b/.gitmodules index 2688d24bc..6c8f43414 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,9 @@ [submodule "third_party/cocoapi"] path = tftrt/examples/object_detection/third_party/cocoapi url = https://github.com/cocodataset/cocoapi +[submodule "tftrt/examples/third_party/models"] + path = tftrt/examples/third_party/models + url = https://github.com/tensorflow/models.git +[submodule "tftrt/examples/third_party/cocoapi"] + path = tftrt/examples/third_party/cocoapi + url = https://github.com/cocodataset/cocoapi.git diff --git a/tftrt/examples/third_party/cocoapi b/tftrt/examples/third_party/cocoapi new file mode 160000 index 000000000..ed842bffd --- /dev/null +++ b/tftrt/examples/third_party/cocoapi @@ -0,0 +1 @@ +Subproject commit ed842bffd41f6ff38707c4f0968d2cfd91088688 diff --git a/tftrt/examples/third_party/models b/tftrt/examples/third_party/models new file mode 160000 index 000000000..402b561b0 --- /dev/null +++ b/tftrt/examples/third_party/models @@ -0,0 +1 @@ 
+Subproject commit 402b561b03857151f684ee00b3d997e5e6be9778 From 24945f61756077a42ae8c380c71b390dc82cd14e Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Fri, 21 Dec 2018 17:47:43 -0800 Subject: [PATCH 16/56] Update submodule --- .gitmodules | 6 ------ 1 file changed, 6 deletions(-) diff --git a/.gitmodules b/.gitmodules index 6c8f43414..9aa63ca53 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,9 +1,3 @@ -[submodule "third_party/models"] - path = tftrt/examples/object_detection/third_party/models - url = https://github.com/tensorflow/models -[submodule "third_party/cocoapi"] - path = tftrt/examples/object_detection/third_party/cocoapi - url = https://github.com/cocodataset/cocoapi [submodule "tftrt/examples/third_party/models"] path = tftrt/examples/third_party/models url = https://github.com/tensorflow/models.git From 9ec4769b5af3154f173001e86f87435d8475eec2 Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Tue, 8 Jan 2019 11:20:41 -0800 Subject: [PATCH 17/56] Add install_dependencies.sh for image-classification Also don't use -e for pip install. --- .../install_dependencies.sh | 30 +++++++++++++++++++ .../object_detection/install_dependencies.sh | 6 ++-- 2 files changed, 33 insertions(+), 3 deletions(-) create mode 100755 tftrt/examples/image-classification/install_dependencies.sh diff --git a/tftrt/examples/image-classification/install_dependencies.sh b/tftrt/examples/image-classification/install_dependencies.sh new file mode 100755 index 000000000..27fd5767f --- /dev/null +++ b/tftrt/examples/image-classification/install_dependencies.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +set +e + +TF_MODELS_DIR=$PWD/../third_party/models + +echo Install slim +pushd $TF_MODELS_DIR/research/slim +pip install . +popd + +echo Install requests +pip install requests + diff --git a/tftrt/examples/object_detection/install_dependencies.sh b/tftrt/examples/object_detection/install_dependencies.sh index a1041f450..81ed5c983 100755 --- a/tftrt/examples/object_detection/install_dependencies.sh +++ b/tftrt/examples/object_detection/install_dependencies.sh @@ -56,15 +56,15 @@ popd echo Install tensorflow/models/research... pushd $RESEARCH_DIR -pip install -e . +pip install . popd echo Install tensorflow/models/research/slim... pushd $SLIM_DIR -pip install -e . +pip install . popd echo Install cocodataset/cocoapi/PythonAPI... pushd $PYCOCO_DIR -pip install -e . +pip install . 
popd From e058613f216589ff5e14af3ec25898a306bc470d Mon Sep 17 00:00:00 2001 From: mdrozdowski1996 Date: Mon, 14 Jan 2019 16:20:21 -0800 Subject: [PATCH 18/56] print model url (#15) --- tftrt/examples/image-classification/image_classification.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py index a02a0f059..b821da41d 100644 --- a/tftrt/examples/image-classification/image_classification.py +++ b/tftrt/examples/image-classification/image_classification.py @@ -175,6 +175,10 @@ def get_input_dims(self): def get_num_classes(self): return self.num_classes + def get_url(self): + return self.url + + def get_netdef(model): """ Creates the dictionary NETS with model names as keys and NetDef as values. @@ -606,6 +610,7 @@ def print_dict(input_dict, str='', scale=None): print('{}{}'.format(headline, '%.1f'%v if type(v)==float else v)) print_dict(vars(args)) + print("url: " + get_netdef(args.model).get_url()) print_dict(num_nodes, str='num_nodes') print_dict(graph_sizes, str='graph_size(MB)', scale=1./(1<<20)) print_dict(times, str='time(s)') From 7d09606f040eeb809b3e2eb901b144e545bfe5b3 Mon Sep 17 00:00:00 2001 From: mdrozdowski1996 Date: Tue, 15 Jan 2019 09:30:37 -0800 Subject: [PATCH 19/56] fixed validation files requirement for synthetic (#8) --- tftrt/examples/image-classification/image_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py index b821da41d..6be178d95 100644 --- a/tftrt/examples/image-classification/image_classification.py +++ b/tftrt/examples/image-classification/image_classification.py @@ -566,7 +566,7 @@ def get_frozen_graph( if args.precision != 'fp32' and not args.use_trt: raise ValueError('TensorRT must be enabled for fp16 or int8 modes (--use_trt).') - if args.precision == 'int8' and not args.calib_data_dir: + if args.precision == 'int8' and not args.calib_data_dir and not args.use_synthetic: raise ValueError('--calib_data_dir is required for int8 mode') if args.num_iterations is not None and args.num_iterations <= args.num_warmup_iterations: raise ValueError('--num_iterations must be larger than --num_warmup_iterations ' From d57752542f2f89758f1503454d24c085bd692204 Mon Sep 17 00:00:00 2001 From: mdrozdowski1996 Date: Tue, 15 Jan 2019 09:30:46 -0800 Subject: [PATCH 20/56] enable use_synthetic for calibration (#17) --- tftrt/examples/image-classification/image_classification.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py index 6be178d95..21427e421 100644 --- a/tftrt/examples/image-classification/image_classification.py +++ b/tftrt/examples/image-classification/image_classification.py @@ -53,7 +53,7 @@ def after_run(self, run_context, run_values): self.batch_size / self.iter_times[-1])) def run(frozen_graph, model, data_files, batch_size, - num_iterations, num_warmup_iterations, use_synthetic, display_every=100, run_calibration=False): + num_iterations, num_warmup_iterations, use_synthetic=False, display_every=100, run_calibration=False): """Evaluates a frozen graph This function evaluates a graph on the ImageNet validation set. 
@@ -493,7 +493,7 @@ def get_frozen_graph(
         print('Calibrating INT8...')
         start_time = time.time()
         run(calib_graph, model, calib_files, batch_size,
-            num_calib_inputs // batch_size, 0, False, run_calibration=True)
+            num_calib_inputs // batch_size, 0, use_synthetic=use_synthetic, run_calibration=True)
         times['trt_calibration'] = time.time() - start_time
 
         start_time = time.time()

From 03a9134de932bf91be8bc15007b8977339141d09 Mon Sep 17 00:00:00 2001
From: Trevor Morris
Date: Mon, 28 Jan 2019 09:25:13 -0800
Subject: [PATCH 21/56] Update readme (#19)

---
 tftrt/examples/image-classification/README.md | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/tftrt/examples/image-classification/README.md b/tftrt/examples/image-classification/README.md
index a821a3e11..1990dd9ee 100644
--- a/tftrt/examples/image-classification/README.md
+++ b/tftrt/examples/image-classification/README.md
@@ -11,8 +11,11 @@ See https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html for mor
 When using TF-TRT, you can also control the precision with `--precision`.
 float32 is the default (`--precision fp32`) with float16 (`--precision fp16`) or
 int8 (`--precision int8`) allowing further performance improvements.
-int8 mode requires a calibration step which is done
-automatically.
+
+int8 mode requires a calibration step which is done automatically, but you will
+also have to specify the directory in which the calibration dataset is stored
+with `--calib_data_dir /imagenet_validation_data`. You can use the same data for
+both calibration and validation.
 
 ## Models
 
@@ -34,6 +37,10 @@ ImageNet validation dataset, see
 [Verified Models](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#verified-models)
 
 ## Setup
+If you are running these examples within the [NVIDIA TensorFlow docker
+container](https://ngc.nvidia.com/catalog/containers/nvidia:tensorflow), you can
+skip these steps by running `./install_dependencies.sh`.
+
 ```
 # Clone [tensorflow/models](https://github.com/tensorflow/models)
 git clone https://github.com/tensorflow/models.git
@@ -78,7 +85,7 @@ for more information.
 
 ## Usage
 
-`python inference.py --data_dir /imagenet_validation_data --model vgg_16 [--use_trt]`
+`python image_classification.py --data_dir /imagenet_validation_data --model vgg_16 [--use_trt]`
 
 Run with `--help` to see all available options.
 

From 34344a9d9ccdb200e2bcc1e4da09f6d2507fc164 Mon Sep 17 00:00:00 2001
From: Anuj Khandelwal
Date: Fri, 1 Feb 2019 23:42:10 +0530
Subject: [PATCH 22/56] Updated install_dependencies.sh (#18)

This script was throwing an error while installing cocoapi; the changes I made
above solved this for my setup.
---
 tftrt/examples/object_detection/install_dependencies.sh | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/tftrt/examples/object_detection/install_dependencies.sh b/tftrt/examples/object_detection/install_dependencies.sh
index 81ed5c983..8633f1939 100755
--- a/tftrt/examples/object_detection/install_dependencies.sh
+++ b/tftrt/examples/object_detection/install_dependencies.sh
@@ -66,5 +66,8 @@ popd
 
 echo Install cocodataset/cocoapi/PythonAPI...
 pushd $PYCOCO_DIR
-pip install .
+python setup.py build_ext --inplace
+make
+# pip install . 
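+# (Building the extension in place and installing via setup.py directly is
+# what proved reliable here; the plain `pip install .` kept above for
+# reference failed for some setups, per this commit's description.)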
+python setup.py install
 popd

From 90d62f6668feb109906b1d9089c9b9ef9da48a36 Mon Sep 17 00:00:00 2001
From: otstrel
Date: Tue, 12 Feb 2019 18:15:57 +0300
Subject: [PATCH 23/56] Adding target_duration argument (#23)

---
 .../image_classification.py                   | 41 +++++++++++++++----
 1 file changed, 33 insertions(+), 8 deletions(-)

diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py
index 21427e421..6d8e24a8a 100644
--- a/tftrt/examples/image-classification/image_classification.py
+++ b/tftrt/examples/image-classification/image_classification.py
@@ -52,10 +52,31 @@ def after_run(self, run_context, run_values):
             current_step, self.num_steps, duration * 1000,
             self.batch_size / self.iter_times[-1]))
 
+class DurationHook(tf.train.SessionRunHook):
+    """Limits run duration"""
+    def __init__(self, target_duration):
+        self.target_duration = target_duration
+        self.start_time = None
+
+    def after_run(self, run_context, run_values):
+        if not self.target_duration:
+            return
+
+        if not self.start_time:
+            self.start_time = time.time()
+            print("  running for target duration from %d" % self.start_time)
+            return
+
+        current_time = time.time()
+        if (current_time - self.start_time) > self.target_duration:
+            print("  target duration %d reached at %d, requesting stop" % (self.target_duration, current_time))
+            run_context.request_stop()
+
 def run(frozen_graph, model, data_files, batch_size,
-    num_iterations, num_warmup_iterations, use_synthetic, display_every=100, run_calibration=False):
+    num_iterations, num_warmup_iterations, use_synthetic, display_every=100, run_calibration=False,
+    target_duration=None):
     """Evaluates a frozen graph
-
+    
     This function evaluates a graph on the ImageNet validation set.
     tf.estimator.Estimator is used to evaluate the accuracy of the model
     and a few other metrics. The results are returned as a dict.
@@ -125,8 +146,9 @@ def eval_input_fn():
         model_fn=model_fn,
         config=tf.estimator.RunConfig(session_config=tf_config),
         model_dir='model_dir')
-    results = estimator.evaluate(eval_input_fn, steps=num_iterations, hooks=[logger])
-
+    duration_hook = DurationHook(target_duration)
+    results = estimator.evaluate(eval_input_fn, steps=num_iterations, hooks=[logger, duration_hook])
+    
     # Gather additional results
     iter_times = np.array(logger.iter_times[num_warmup_iterations:])
     results['total_time'] = np.sum(iter_times)
@@ -137,7 +159,7 @@ def eval_input_fn():
 
 class NetDef(object):
     """Contains definition of a model
-
+    
     name: Name of model
     url: (optional) Where to download archive containing checkpoint
     model_dir_in_archive: (optional) Subdirectory in archive containing
@@ -375,7 +397,7 @@ def get_checkpoint(model, model_dir=None, default_models_dir='.'):
     if get_netdef(model).url:
         download_checkpoint(model, model_dir)
         return find_checkpoint_in_dir(model_dir)
-
+    
     print('No model_dir was provided and the model does not define a download' \
         ' URL.')
     exit(1)
@@ -533,7 +555,7 @@ def get_frozen_graph(
     parser.add_argument('--model_dir', type=str, default=None,
         help='Directory containing model checkpoint. If not provided, a ' \
             'checkpoint may be downloaded automatically and stored in ' \
-            '"{--default_models_dir}/{--model}" for future use.')
+            '"{--default_models_dir}/{--model}" for future use.')
     parser.add_argument('--default_models_dir', type=str, default='./data',
         help='Directory where downloaded model checkpoints will be stored and ' \
             'loaded from if --model_dir is not provided.')
@@ -562,6 +584,8 @@ def get_frozen_graph(
         help='workspace size in bytes')
     parser.add_argument('--cache', action='store_true',
         help='If set, graphs will be saved to disk after conversion. If a converted graph is present on disk, it will be loaded instead of building the graph again.')
+    parser.add_argument('--target_duration', type=int, default=None,
+        help='If set, script will run for specified number of seconds.')
     args = parser.parse_args()
 
     if args.precision != 'fp32' and not args.use_trt:
@@ -625,7 +649,8 @@ def print_dict(input_dict, str=''):
         num_iterations=args.num_iterations,
         num_warmup_iterations=args.num_warmup_iterations,
         use_synthetic=args.use_synthetic,
-        display_every=args.display_every)
+        display_every=args.display_every,
+        target_duration=args.target_duration)
 
     # Display results
     print('results of {}:'.format(args.model))

From ee7f173d9e03ab9150bb8647326754408da77918 Mon Sep 17 00:00:00 2001
From: mdrozdowski1996
Date: Thu, 28 Feb 2019 01:55:51 +0100
Subject: [PATCH 24/56] add_benchmark_mode (#21)

* add_benchmark_mode
* refactor_changes
* fix infinite loop when using synthetic mode
* convert duration_hooks into benchmark_hooks
* change target_duration and iteration_limit to optional parameters
* refactor code
* update comment
* update error message
* remove space
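
For example, with this change run() can be driven in either mode; a
benchmark-mode call looks roughly like this (illustrative values, not taken
from the diff below):

```python
# mode='benchmark' feeds raw image files through estimator.predict and only
# measures throughput; no labels or accuracy checks are involved.
stats = run(frozen_graph, model='resnet_v1_50', data_files=jpeg_paths,
            batch_size=8, num_iterations=None, num_warmup_iterations=50,
            use_synthetic=False, mode='benchmark', target_duration=60)
```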
---
 .../image_classification.py | 145 +++++++++++++-----
 1 file changed, 104 insertions(+), 41 deletions(-)

diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py
index 6d8e24a8a..8c3e22aa2 100644
--- a/tftrt/examples/image-classification/image_classification.py
+++ b/tftrt/examples/image-classification/image_classification.py
@@ -52,29 +52,34 @@ def after_run(self, run_context, run_values):
             current_step, self.num_steps, duration * 1000,
             self.batch_size / self.iter_times[-1]))
 
-class DurationHook(tf.train.SessionRunHook):
-    """Limits run duration"""
-    def __init__(self, target_duration):
+class BenchmarkHook(tf.train.SessionRunHook):
+    """Limits run duration and number of iterations"""
+    def __init__(self, target_duration=None, iteration_limit=None):
         self.target_duration = target_duration
         self.start_time = None
+        self.current_iteration = 0
+        self.iteration_limit = iteration_limit
 
-    def after_run(self, run_context, run_values):
-        if not self.target_duration:
-            return
-
+    def before_run(self, run_context):
         if not self.start_time:
             self.start_time = time.time()
-            print("  running for target duration from %d" % self.start_time)
-            return
+            print("  running for target duration from %d" % self.start_time)
 
-        current_time = time.time()
-        if (current_time - self.start_time) > self.target_duration:
-            print("  target duration %d reached at %d, requesting stop" % (self.target_duration, current_time))
-            run_context.request_stop()
+    def after_run(self, run_context, run_values):
+        if self.target_duration:
+            current_time = time.time()
+            if (current_time - self.start_time) > self.target_duration:
+                print("  target duration %d reached at %d, requesting stop" % (self.target_duration, current_time))
+                run_context.request_stop()
+
+        if self.iteration_limit:
+            self.current_iteration += 1
+            if self.current_iteration >= self.iteration_limit:
+                run_context.request_stop()
 
 def run(frozen_graph, model, data_files, batch_size,
-    num_iterations, num_warmup_iterations, use_synthetic, display_every=100, run_calibration=False,
-    target_duration=None):
+    num_iterations, num_warmup_iterations, use_synthetic, display_every=100, run_calibration=False,
+    mode='validation', target_duration=None):
     """Evaluates a frozen graph
 
     This function evaluates a graph on the ImageNet validation set.
@@ -86,6 +91,12 @@ def run(frozen_graph, model, data_files, batch_size,
     data_files: List of TFRecord files used for inference
     batch_size: int, batch size for TensorRT optimizations
     num_iterations: int, number of iterations(batches) to run for
+    num_warmup_iterations: int, number of iterations (batches) to exclude from benchmark measurements
+    use_synthetic: bool, if true run using synthetic data, otherwise real
+    display_every: int, print log every @display_every iteration
+    run_calibration: bool, run using calibration or not (only int8 precision)
+    mode: validation - using estimator.evaluate with accuracy measurements,
+          benchmark - using estimator.predict
     """
     # Define model function for tf.estimator.Estimator
    def model_fn(features, labels, mode):
@@ -93,16 +104,19 @@ def model_fn(features, labels, mode):
             input_map={'input': features},
             return_elements=['logits:0', 'classes:0'],
             name='')
-        loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits_out)
-        accuracy = tf.metrics.accuracy(labels=labels, predictions=classes_out, name='acc_op')
+        if mode == tf.estimator.ModeKeys.PREDICT:
+            return tf.estimator.EstimatorSpec(mode=mode,
+                predictions={'classes': classes_out})
         if mode == tf.estimator.ModeKeys.EVAL:
+            loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits_out)
+            accuracy = tf.metrics.accuracy(labels=labels, predictions=classes_out, name='acc_op')
             return tf.estimator.EstimatorSpec(
                 mode, loss=loss, eval_metric_ops={'accuracy': accuracy})
 
     # preprocess function for input data
-    preprocess_fn = get_preprocess_fn(model)
+    preprocess_fn = get_preprocess_fn(model, mode)
 
     def get_tfrecords_count(files):
         num_records = 0
@@ -112,7 +126,7 @@ def get_tfrecords_count(files):
         return num_records
 
     # Define the dataset input function for tf.estimator.Estimator
-    def eval_input_fn():
+    def input_fn():
         if use_synthetic:
             input_width, input_height = get_netdef(model).get_input_dims()
             features = np.random.normal(
@@ -127,28 +141,54 @@ def input_fn():
                 dtype=np.int32)
             labels = tf.identity(tf.constant(labels))
         else:
-            dataset = tf.data.TFRecordDataset(data_files)
-            dataset = dataset.apply(tf.contrib.data.map_and_batch(map_func=preprocess_fn, batch_size=batch_size, num_parallel_calls=8))
-            dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
-            dataset = dataset.repeat(count=1)
-            iterator = dataset.make_one_shot_iterator()
-            features, labels = iterator.get_next()
+            if mode == 'validation':
+                dataset = tf.data.TFRecordDataset(data_files)
+                dataset = dataset.apply(tf.contrib.data.map_and_batch(map_func=preprocess_fn, batch_size=batch_size, num_parallel_calls=8))
+                dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
+                dataset = dataset.repeat(count=1)
+                iterator = dataset.make_one_shot_iterator()
+                features, labels = iterator.get_next()
+            elif mode == 'benchmark':
+                dataset = tf.data.Dataset.from_tensor_slices(data_files)
+                dataset = dataset.apply(tf.contrib.data.map_and_batch(map_func=preprocess_fn, batch_size=batch_size, num_parallel_calls=8))
+                dataset = dataset.repeat(count=1)
+                iterator = dataset.make_one_shot_iterator()
+                features = 
iterator.get_next() + labels = np.random.randint( + low=0, + high=get_netdef(model).get_num_classes(), + size=(batch_size), + dtype=np.int32) + labels = tf.identity(tf.constant(labels)) + else: + raise ValueError("Mode must be either 'validation' or 'benchmark'") return features, labels # Evaluate model + if mode == 'validation': + num_records = get_tfrecords_count(data_files) + elif mode == 'benchmark': + num_records = len(data_files) + else: + raise ValueError("Mode must be either 'validation' or 'benchmark'") logger = LoggerHook( display_every=display_every, batch_size=batch_size, - num_records=get_tfrecords_count(data_files)) + num_records=num_records) tf_config = tf.ConfigProto() tf_config.gpu_options.allow_growth = True estimator = tf.estimator.Estimator( model_fn=model_fn, config=tf.estimator.RunConfig(session_config=tf_config), model_dir='model_dir') - duration_hook = DurationHook(target_duration) - results = estimator.evaluate(eval_input_fn, steps=num_iterations, hooks=[logger, duration_hook]) - + results = {} + if mode == 'validation': + results = estimator.evaluate(input_fn, steps=num_iterations, hooks=[logger]) + elif mode == 'benchmark': + benchmark_hook = BenchmarkHook(target_duration=target_duration, iteration_limit=num_iterations) + prediction_results = [p for p in estimator.predict(input_fn, predict_keys=["classes"], hooks=[logger, benchmark_hook])] + else: + raise ValueError("Mode must be either 'validation' or 'benchmark'") # Gather additional results iter_times = np.array(logger.iter_times[num_warmup_iterations:]) results['total_time'] = np.sum(iter_times) @@ -202,10 +242,8 @@ def get_url(self): def get_netdef(model): - """ - Creates the dictionary NETS with model names as keys and NetDef as values. + """Creates the dictionary NETS with model names as keys and NetDef as values. Returns the NetDef corresponding to the model specified in the parameter. 
- model: string, the model name (see NETS table) """ NETS = { @@ -292,14 +330,14 @@ def deserialize_image_record(record): text = obj['image/class/text'] return imgdata, label, bbox, text -def get_preprocess_fn(model, mode='classification'): +def get_preprocess_fn(model, mode='validation'): """Creates a function to parse and process a TFRecord using the model's parameters model: string, the model name (see NETS table) - mode: string, whether the model is for classification or detection + mode: string, which mode to use (validation or benchmark) returns: function, the preprocessing function for a record """ - def process(record): + def validation_process(record): # Parse TFRecord imgdata, label, bbox, text = deserialize_image_record(record) label -= 1 # Change to 0-based (don't use background class) @@ -310,7 +348,22 @@ def process(record): image = netdef.preprocess(image, netdef.input_height, netdef.input_width, is_training=False) return image, label - return process + def benchmark_process(path): + image = tf.read_file(path) + image = tf.image.decode_jpeg(image, channels=3) + net_def = get_netdef(model) + input_width, input_height = net_def.get_input_dims() + image = net_def.preprocess(image, input_width, input_height, is_training=False) + return image + + if mode == 'validation': + return validation_process + elif mode == 'benchmark': + return benchmark_process + else: + raise ValueError("Mode must be either 'validation' or 'benchmark'") + + def build_classification_graph(model, model_dir=None, default_models_dir='./data'): """Builds an image classification model by name @@ -584,6 +637,8 @@ def get_frozen_graph( help='workspace size in bytes') parser.add_argument('--cache', action='store_true', help='If set, graphs will be saved to disk after conversion. 
If a converted graph is present on disk, it will be loaded instead of building the graph again.')
+    parser.add_argument('--mode', choices=['validation', 'benchmark'], default='validation',
+        help='Which mode to use (validation or benchmark)')
     parser.add_argument('--target_duration', type=int, default=None,
         help='If set, script will run for specified number of seconds.')
     args = parser.parse_args()
@@ -598,20 +653,26 @@ def get_frozen_graph(
     if args.num_calib_inputs < args.batch_size:
         raise ValueError('--num_calib_inputs must not be smaller than --batch_size'
             '({} <= {})'.format(args.num_calib_inputs, args.batch_size))
+    if args.mode == 'validation' and args.use_synthetic:
+        raise ValueError('Cannot use both validation mode and synthetic dataset')
 
     def get_files(data_dir, filename_pattern):
         if data_dir == None:
             return []
         files = tf.gfile.Glob(os.path.join(data_dir, filename_pattern))
         if files == []:
-            raise ValueError('Can not find any files in {} with pattern "{}"'.format(
-                data_dir, filename_pattern))
+            raise ValueError('Can not find any files in {} with '
                             'pattern "{}"'.format(data_dir, filename_pattern))
         return files
 
-    validation_files = get_files(args.data_dir, 'validation*')
+    if args.mode == "validation":
+        data_files = get_files(args.data_dir, 'validation*')
+    elif args.mode == "benchmark":
+        data_files = [os.path.join(path, name) for path, _, files in os.walk(args.data_dir) for name in files]
+    else:
+        raise ValueError("Mode must be either 'validation' or 'benchmark'")
     calib_files = get_files(args.calib_data_dir, 'train*')
 
-    # Retreive graph using NETS table in graph.py
     frozen_graph, num_nodes, times, graph_sizes = get_frozen_graph(
         model=args.model,
         model_dir=args.model_dir,
@@ -644,17 +705,19 @@ def print_dict(input_dict, str='', scale=None):
     results = run(
         frozen_graph,
         model=args.model,
-        data_files=validation_files,
+        data_files=data_files,
         batch_size=args.batch_size,
         num_iterations=args.num_iterations,
         num_warmup_iterations=args.num_warmup_iterations,
         use_synthetic=args.use_synthetic,
         display_every=args.display_every,
+        mode=args.mode,
         target_duration=args.target_duration)
 
     # Display results
     print('results of {}:'.format(args.model))
-    print('    accuracy: %.2f' % (results['accuracy'] * 100))
+    if args.mode == 'validation':
+        print('    accuracy: %.2f' % (results['accuracy'] * 100))
     print('    images/sec: %d' % results['images_per_sec'])
     print('    99th_percentile(ms): %.1f' % results['99th_percentile'])
     print('    total_time(s): %.1f' % results['total_time'])

From d57752542f2f89758f1503454d24c085bd692204 Mon Sep 17 00:00:00 2001
From: mdrozdowski1996
Date: Thu, 28 Feb 2019 01:56:25 +0100
Subject: [PATCH 25/56] update models (#24)

* update resnet_v1 and resnet_v2
* put copy into download_model function
---
 .../image-classification/image_classification.py | 20 ++++++++++++-----
 1 file changed, 15 insertions(+), 5 deletions(-)

diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py
index 8c3e22aa2..c8aea134b 100644
--- a/tftrt/examples/image-classification/image_classification.py
+++ b/tftrt/examples/image-classification/image_classification.py
@@ -266,16 +266,16 @@ def get_netdef(model):
 
         'resnet_v1_50': NetDef(
             name='resnet_v1_50',
-            url='http://download.tensorflow.org/models/official/20180601_resnet_v1_imagenet_checkpoint.tar.gz',
-            model_dir_in_archive='20180601_resnet_v1_imagenet_checkpoint',
+            url='http://download.tensorflow.org/models/official/20181001_resnet/checkpoints/resnet_imagenet_v1_fp32_20181001.tar.gz',
+            model_dir_in_archive='resnet_imagenet_v1_fp32_20181001',
             slim=False,
             preprocess='vgg',
             model_fn=official.resnet.imagenet_main.ImagenetModel(resnet_size=50, resnet_version=1)),
 
         'resnet_v2_50': NetDef(
             name='resnet_v2_50',
-            url='http://download.tensorflow.org/models/official/20180601_resnet_v2_imagenet_checkpoint.tar.gz',
-            model_dir_in_archive='20180601_resnet_v2_imagenet_checkpoint',
+            url='http://download.tensorflow.org/models/official/20181001_resnet/checkpoints/resnet_imagenet_v2_fp32_20181001.tar.gz',
+            model_dir_in_archive='resnet_imagenet_v2_fp32_20181001',
             slim=False,
             preprocess='vgg',
             model_fn=official.resnet.imagenet_main.ImagenetModel(resnet_size=50, resnet_version=2)),
@@ -477,7 +477,17 @@ def find_checkpoint_in_dir(model_dir):
         checkpoint_path = '.'.join(parts[:ckpt_index+1])
     return checkpoint_path
 
+
 def download_checkpoint(model, destination_path):
+    # Copy files from source to destination (without any directories)
+    def copy_files(source, destination):
+        try:
+            shutil.copy2(source, destination)
+        except OSError as e:
+            pass
+        except shutil.Error as e:
+            pass
+
     # Make directories if they don't exist.
     if not os.path.exists(destination_path):
         os.makedirs(destination_path)
@@ -495,7 +505,7 @@ def download_checkpoint(model, destination_path):
             get_netdef(model).model_dir_in_archive,
             '*')
         for f in glob.glob(source_files):
-            shutil.copy2(f, destination_path)
+            copy_files(f, destination_path)

From 78f883fd6924c39bc6288def1c718972adc86c83 Mon Sep 17 00:00:00 2001
From: brian pardini
Date: Wed, 27 Feb 2019 20:47:41 -0800
Subject: [PATCH 26/56] Pull in sections from Accelerating Inference Guide (#27)

---
 tftrt/examples/image-classification/README.md | 232 +++++++++++++++---
 1 file changed, 195 insertions(+), 37 deletions(-)

diff --git a/tftrt/examples/image-classification/README.md b/tftrt/examples/image-classification/README.md
index 1990dd9ee..d84b7b8e4 100644
--- a/tftrt/examples/image-classification/README.md
+++ b/tftrt/examples/image-classification/README.md
@@ -1,21 +1,29 @@
-# Image classification examples
+# Image classification example
 
-This example includes scripts to run inference using a number of popular image classification models.
+The example script `image_classification.py` runs inference using a number of
+popular image classification models. This script is included in the NVIDIA
+TensorFlow Docker containers under `/workspace/nvidia-examples`. See [Preparing
+To Use NVIDIA
+Containers](https://docs.nvidia.com/deeplearning/dgx/preparing-containers/index.html)
+for more information.
 
-You can turn on TF-TRT integration with the flag `--use_trt`. This
-will apply TensorRT inference optimization to speed up execution for portions of
-the model's graph where supported, and will fall back to native TensorFlow for
-layers and operations which are not supported.
-See https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html for more information.
+You can enable TF-TRT integration by passing the `--use_trt` flag to the script.
+This causes the script to apply TensorRT inference optimization to speed up
+execution for portions of the model's graph where supported, and to fall back on
+native TensorFlow for layers and operations which are not supported. See
+[Accelerating Inference In TensorFlow With TensorRT User
+Guide](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html) for
+more information.
 
-When using TF-TRT, you can also control the precision with `--precision`.
-float32 is the default (`--precision fp32`) with float16 (`--precision fp16`) or
-int8 (`--precision int8`) allowing further performance improvements.
+
+When using the TF-TRT integration flag, you can use the precision option
+(`--precision`) to control precision. float32 is the default (`--precision
+fp32`) with float16 (`--precision fp16`) or int8 (`--precision int8`) allowing
+further performance improvements.
 
-int8 mode requires a calibration step which is done automatically, but you will
-also have to specify the directory in which the calibration dataset is stored
-with `--calib_data_dir /imagenet_validation_data`. You can use the same data for
-both calibration and validation.
+int8 mode requires a calibration step (which is done automatically), but you
+also must specify the directory in which the calibration dataset is stored
+with `--calib_data_dir /imagenet_validation_data`. You can use the same data
+for both calibration and validation.
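+
+For reference, a rough sketch of what happens internally during calibration
+(condensed from `image_classification.py`; not a complete program):
+
+```python
+import tensorflow.contrib.tensorrt as trt
+
+# calib_graph is the output of trt.create_inference_graph(...,
+# precision_mode='int8'). Running inference over the calibration data
+# collects the activation statistics...
+run(calib_graph, model, calib_files, batch_size,
+    num_calib_inputs // batch_size, 0, use_synthetic=False)
+# ...which are then baked into the final INT8 inference graph.
+frozen_graph = trt.calib_graph_to_infer_graph(calib_graph)
+```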
 
 ## Models
 
@@ -34,61 +42,211 @@ We have verified the following models. For the accuracy numbers of these models
 on the ImageNet validation dataset, see
-[Verified Models](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#verified-models)
+[Verified Models](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#verified-models).
 
 ## Setup
+
+### Setup for running within an NVIDIA TensorFlow Docker container
+
 If you are running these examples within the [NVIDIA TensorFlow docker
-container](https://ngc.nvidia.com/catalog/containers/nvidia:tensorflow), you can
-skip these steps by running `./install_dependencies.sh`.
+container](https://ngc.nvidia.com/catalog/containers/nvidia:tensorflow) under
+`/workspace/nvidia-examples/tensorrt/tftrt/examples/image-classification`, run
+the `install_dependencies.sh` setup script. Then skip below to the
+[Data](#Data) section.
+
+```
+cd /workspace/nvidia-examples/tensorrt/tftrt/examples/image-classification
+./install_dependencies.sh
+cd ../third_party/models
+export PYTHONPATH="$PYTHONPATH:$PWD"
+```
+
+### Setup for running standalone
+
+If you are running these examples within your own TensorFlow environment,
+perform the following steps:
 
 ```
-# Clone [tensorflow/models](https://github.com/tensorflow/models)
+# Clone this repository (tensorflow/tensorrt) if you haven't already.
+git clone https://github.com/tensorflow/tensorrt.git
+
+# Clone tensorflow/models.
 git clone https://github.com/tensorflow/models.git
 
 # Add the models directory to PYTHONPATH to install tensorflow/models.
 cd models
 export PYTHONPATH="$PYTHONPATH:$PWD"
 
-# Run the TF Slim setup.
+# Run the TensorFlow Slim setup.
 cd research/slim
 python setup.py install
 
-# You may also need to install the requests package
+# Install the requests package.
 pip install requests
 ```
 
-Note: the PYTHONPATH environment variable will be not be saved between different
-shells. You can either repeat that step each time you work in a new shell, or
-add `export PYTHONPATH="$PYTHONPATH:/path/to/tensorflow_models"` to your .bashrc
-file (replacing /path/to/tensorflow_models with the path to your
-tensorflow/models repository).
-See [Setting Up The Environment
+### PYTHONPATH environment variable
+
+The `PYTHONPATH` environment variable is not saved between different shell
+sessions. To avoid having to set `PYTHONPATH` in each new shell session, you
+can add the following line to your `.bashrc` file:
+
+```export PYTHONPATH="$PYTHONPATH:/path/to/tensorflow_models"```
+
+replacing `/path/to/tensorflow_models` with the path to your `tensorflow/models`
+repository.
+
+Also see [Setting Up The Environment
 ](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#image-class-envirn)
 for more information.
 
 ### Data
 
-The example supports using a dataset in TFRecords or synthetic data.
-In case of using TFRecord files, the scripts assume that TFRecords
-are named according to the pattern: `validation-*-of-00128`.
+The example script supports either using a dataset in TFRecord format or using
+autogenerated synthetic data (with the `--use_synthetic` flag). If you use
+TFRecord files, the script assumes that the TFRecords are named according to the
+pattern: `validation-*-of-00128`.
 
-The reported accuracy numbers are the results of running the scripts on
+Note: The reported accuracy numbers are the results of running the scripts on
 the ImageNet validation dataset.
 
-You can download and process Imagenet using [this script provided by TF
-Slim](https://github.com/tensorflow/models/blob/master/research/slim/datasets/download_imagenet.sh).
-Please note that this script downloads both the training and validation sets,
-and this example only requires the validation set.
-See [Obtaining The ImageNet Data
+To download and process the ImageNet data, you can:
+
+- Use the scripts provided in the `nvidia-examples/build_imagenet_data`
+  directory in the NVIDIA TensorFlow Docker container `workspace` directory.
+  Follow the `README` file in that directory for instructions on how to use
+  these scripts.
+
+or
+
+- Use the scripts provided by TF Slim in the `tensorflow/models` repository at
+  `research/slim`. Consult the `README` file under `research/slim` for
+  instructions on how to use these scripts. Also please note that these scripts
+  download both the training and validation sets, and this example only requires
+  the validation set.
+
+Also see [Obtaining The ImageNet Data
 ](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#image-class-data)
 for more information.
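+
+As a quick sanity check that the TFRecord files are visible and named as
+expected, something along these lines can be run from Python (a sketch; adjust
+the data directory to your setup):
+
+```python
+import os
+import tensorflow as tf
+
+data_dir = '/data/imagenet/train-val-tfrecord'
+files = tf.gfile.Glob(os.path.join(data_dir, 'validation-*-of-00128'))
+# Counting records reads every shard once, so this can take a few minutes.
+num_records = sum(1 for f in files
+                  for _ in tf.python_io.tf_record_iterator(f))
+print('{} shards, {} records'.format(len(files), num_records))
+```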
+
+`--data_dir`: Path to the ImageNet TFRecord validation files.
+
+`--use_trt`: Convert the graph to a TensorRT graph.
+
+`--precision`: Precision mode to use, in this case FP16.
 
 Run with `--help` to see all available options.
 
-See [General Script Usage
+Also see [General Script Usage
 ](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#image-class-usage)
 for more information.
+
+## Output
+
+The script first loads the pre-trained model. If given the flag `--use_trt`,
+the model is converted to a TensorRT graph, and the script displays (in addition
+to its initial configuration options):
+
+- the number of nodes before conversion (`num_nodes(native_tf)`)
+
+- the number of nodes after conversion (`num_nodes(trt_total)`)
+
+- the number of separate TensorRT nodes (`num_nodes(trt_only)`)
+
+- the size of the graph before conversion (`graph_size(MB)(native_tf)`)
+
+- the size of the graph after conversion (`graph_size(MB)(trt)`)
+
+- how long the conversion took (`time(s)(trt_conversion)`)
+
+For example:
+
+```
+num_nodes(native_tf): 741
+num_nodes(trt_total): 10
+num_nodes(trt_only): 1
+graph_size(MB)(native_tf): ***
+graph_size(MB)(trt): ***
+time(s)(trt_conversion): ***
+```
+
+Note: For a list of supported operations that can be converted to a TensorRT
+graph, see the [Supported
+Ops](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#support-ops)
+section of the [Accelerating Inference In TensorFlow With TensorRT User
+Guide](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html).
+
+The script then begins running inference on the ImageNet validation set,
+displaying run times of each iteration after the interval defined by the
+`--display_every` option (default: `100`):
+
+```
+running inference...
+    step 100/6202, iter_time(ms)=**.****, images/sec=***
+    step 200/6202, iter_time(ms)=**.****, images/sec=***
+    step 300/6202, iter_time(ms)=**.****, images/sec=***
+    ...
+```
+
+On completion, the script prints overall accuracy and timing information over
+the inference session:
+
+```
+results of resnet_v1_50:
+    accuracy: 75.95
+    images/sec: ***
+    99th_percentile(ms): ***
+    total_time(s): ***
+    latency_mean(ms): ***
+```
+
+The accuracy metric measures the percentage of predictions from inference that
+match the labels on the ImageNet Validation set.
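+
+These summary figures are simple aggregates of the per-iteration latencies the
+script records. A minimal sketch of how they can be derived (the `iter_times`
+values are illustrative stand-ins for the script's internal timing list):
+
+```python
+import numpy as np
+
+# Illustrative per-batch latencies in seconds.
+iter_times = np.array([0.0205, 0.0198, 0.0210, 0.0202])
+batch_size = 8
+
+print('images/sec: %d' % np.mean(batch_size / iter_times))
+print('99th_percentile(ms): %.2f' %
+      (np.percentile(iter_times, q=99, interpolation='lower') * 1000))
+print('total_time(s): %.1f' % np.sum(iter_times))
+print('latency_mean(ms): %.2f' % (np.mean(iter_times) * 1000))
+```
+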
The remaining metrics capture +various performance measurements: + +- number of images processed per second (`images/sec`) + +- total time of the inference session (`total_time(s)`) + +- the mean duration for each iteration (`latency_mean(ms)`) + +- the slowest duration for an iteration (`99th_percentile(ms)`) From d2c28ffb775f8b550541fbde7061caf3daf14375 Mon Sep 17 00:00:00 2001 From: mdrozdowski1996 Date: Fri, 1 Mar 2019 14:04:04 -0800 Subject: [PATCH 27/56] add params to tftrt conversion (#30) * add params * Fix style for parenthesis --- tftrt/examples/object_detection/object_detection.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tftrt/examples/object_detection/object_detection.py b/tftrt/examples/object_detection/object_detection.py index 17e93a993..0d3528cb6 100644 --- a/tftrt/examples/object_detection/object_detection.py +++ b/tftrt/examples/object_detection/object_detection.py @@ -349,7 +349,9 @@ def optimize_model(config_path, max_batch_size=max_batch_size, max_workspace_size_bytes=max_workspace_size_bytes, precision_mode=precision_mode, - minimum_segment_size=minimum_segment_size) + minimum_segment_size=minimum_segment_size, + is_dynamic_op=True, + maximum_cached_engines=10) # perform calibration for int8 precision if precision_mode == 'INT8': From 0d97bb87d54b8821f03ea9f8747b4c726ca04826 Mon Sep 17 00:00:00 2001 From: brian pardini Date: Sat, 2 Mar 2019 10:22:44 -0800 Subject: [PATCH 28/56] Create Jupyter notebook for image classification example (#28) --- .../image_classification.ipynb | 417 ++++++++++++++++++ 1 file changed, 417 insertions(+) create mode 100644 tftrt/examples/image-classification/image_classification.ipynb diff --git a/tftrt/examples/image-classification/image_classification.ipynb b/tftrt/examples/image-classification/image_classification.ipynb new file mode 100644 index 000000000..7bb4b8bf2 --- /dev/null +++ b/tftrt/examples/image-classification/image_classification.ipynb @@ -0,0 +1,417 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Image classification example\n", + "\n", + "This example script runs inference using a number of popular image classification models. This script is included in the NVIDIA TensorFlow Docker containers under `/workspace/nvidia-examples`. See [Preparing To Use NVIDIA Containers](https://docs.nvidia.com/deeplearning/dgx/preparing-containers/index.html) for more information.\n", + "\n", + "You can enable TF-TRT integration by passing the `--use_trt` flag to the script. This causes the script to apply TensorRT inference optimization to speed up execution for portions of the model's graph where supported, and to fall back on native TensorFlow for layers and operations which are not supported. See [Accelerating Inference In TensorFlow With TensorRT User Guide](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html) for more information. \n", + "\n", + "When using TF-TRT, you can use the precision option (`--precision`) to control precision. float32 is the default (`--precision fp32`) with float16 (`--precision fp16`) or int8 (`--precision int8`) allowing further performance improvements. \n", + "\n", + "int8 mode requires a calibration step (which is done automatically), but you also must specificy the directory in which the calibration dataset is stored with `--calib_data_dir /imagenet_validation_data`. You can use the same data for both calibration and validation." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Models\n", + "\n", + "We have verified the following models.\n", + "\n", + "* MobileNet v1\n", + "* MobileNet v2\n", + "* NASNet - Large\n", + "* NASNet - Mobile\n", + "* ResNet50 v1\n", + "* ResNet50 v2\n", + "* VGG16\n", + "* VGG19\n", + "* Inception v3\n", + "* Inception v4\n", + "\n", + "For the accuracy numbers of these models on the ImageNet validation dataset, see [Verified Models](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#verified-models)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Usage\n", + "\n", + "The example Python script is `image_classification.py`. You can evaluate inference with TF-TRT integration using the pre-trained ResNet V1 50 model by calling the script with the following arguments:\n", + "\n", + "```\n", + "python image_classification.py --model resnet_v1_50 \\\n", + " --data_dir /path/to/imagenet/tfrecord/files \\\n", + " --use_trt \\\n", + " --precision fp16\n", + "```\n", + "\n", + "Where:\n", + "\n", + "`--model`: Which model to use to run inference, in this case ResNet V1 50.\n", + "\n", + "`--data_dir`: Path to the ImageNet TFRecord validation files.\n", + "\n", + "`--use_trt`: Convert the graph to a TensorRT graph.\n", + "\n", + "`--precision`: Precision mode to use, in this case FP16.\n", + "\n", + "Run with `--help` to see all available options.\n", + "\n", + "Note: In this notebook, we run the script inside IPython using the `%run` built-in command, so that realtime output and tracebacks are displayed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%run image_classification --help" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Also see [General Script Usage](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#image-class-usage) for more information." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Output\n", + "\n", + "The script first loads the pre-trained model. 
If given the flag `--use_trt`, the model is converted to a TensorRT graph, and the script displays (in addition to its inital configuration options):\n", + "\n", + "- the number of nodes before conversion (`num_nodes(native_tf)`)\n", + "\n", + "- the number of nodes after conversion (`num_nodes(trt_total)`)\n", + "\n", + "- the number of separate TensorRT nodes (`num_nodes(trt_only)`)\n", + "\n", + "- the size of the graph before conversion (`graph_size(MB)(native_tf)`)\n", + "\n", + "- the size of the graph after conversion (`graph_size(MB)(trt)`)\n", + "\n", + "- how long the conversion took (`time(s)(trt_conversion)`)\n", + "\n", + "For example:\n", + "\n", + "```\n", + "num_nodes(native_tf): 741\n", + "num_nodes(trt_total): 10\n", + "num_nodes(trt_only): 1\n", + "graph_size(MB)(native_tf): ***\n", + "graph_size(MB)(tft): ***\n", + "time(s)(trt_conversion): ***\n", + "```\n", + "\n", + "Note: For a list of supported operations that can be converted to a TensorRT graph, see the [Supported\n", + "Ops](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#support-ops) section of the [Accelerating Inference In TensorFlow With TensorRT User Guide](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html).\n", + "\n", + "The script then begins running inference on the ImageNet validation set, displaying run times of each iteration after the interval defined by the `--display_every` option (default: `100`):\n", + "\n", + "```\n", + "running inference...\n", + " step 100/6202, iter_time(ms)=**.****, images/sec=***\n", + " step 200/6202, iter_time(ms)=**.****, images/sec=***\n", + " step 300/6202, iter_time(ms)=**.****, images/sec=***\n", + " ...\n", + "```\n", + "\n", + "On completion, the script prints overall accuracy and timing information over the inference session:\n", + "\n", + "```\n", + "results of resnet_v1_50:\n", + " accuracy: 75.95\n", + " images/sec: ***\n", + " 99th_percentile(ms): ***\n", + " total_time(s): ***\n", + " latency_mean(ms): ***\n", + "```\n", + "\n", + "The accuracy metric measures the percentage of predictions from inference that match the labels on the ImageNet Validation set. The remaining metrics capture various performance measurements:\n", + "\n", + "- number of images processed per second (`images/sec`)\n", + "\n", + "- total time of the inference session (`total_time(s)`)\n", + "\n", + "- the mean duration for each iteration (`latency_mean(ms)`)\n", + "\n", + "- the slowest duration for an iteration (`99th_percentile(ms)`)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Using TF-TRT With ResNet V1 50\n", + "\n", + "Here we walk through how to use the example Python scripts in the with the ResNet V1 50 model.\n", + "\n", + "Using TF-TRT with precision modes lower than FP32, that is, FP16 and INT8, improves the performance of inference. The FP16 precision mode uses Tensor Cores or half-precision hardware instructions, if possible, while the INT8 precision mode uses Tensor Cores or integer hardware instructions. INT8 mode also requires running a calibration step, which the script does automatically.\n", + "\n", + "Below we use the example script to compare the accuracy and timing performance of all the precision modes when running inference using the ResNet V1 50 model." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Native TensorFlow Using FP32\n", + "\n", + "This is our baseline session running inference using native TensorFlow without TensorRT integration/conversion.\n", + "\n", + "First, set `DATA_DIR` to where you stored the ImageNet TFRecord validation files:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "DATA_DIR = \"/path/to/imagenet/tfrecord/files\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can run the baseline session with native TensorFlow.\n", + "\n", + "Note: We use the `--cache` flag to allow the script to cache checkpoint and frozen graph files to use with future sessions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%run image_classification --model resnet_v1_50 \\\n", + " --data_dir $DATA_DIR \\\n", + " --cache" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Look for the accuracy and timing information under:\n", + "\n", + "```\n", + "results of resnet_v1_50:\n", + " ...\n", + "```\n", + "\n", + "You can compare the accuracy metrics for the ResNet 50 models with the metrics listed at: [Pre-trained model](https://github.com/tensorflow/models/tree/master/official/resnet#pre-trained-model)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### TF-TRT Using FP32\n", + "\n", + "In this session, we use the same precision mode as in our native TensorFlow session (FP32), but this time we use the `--use_trt` flag to convert the graph to a TensorRT optimized graph." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%run image_classification --model resnet_v1_50 \\\n", + " --data_dir $DATA_DIR \\\n", + " --use_trt \\\n", + " --cache" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Before the script starts running inference, it converts the TensorFlow graph to a TensorRT optimized graph with fewer nodes. Look for the following metrics in the log:\n", + "\n", + "```\n", + "num_nodes(native_tf): ***\n", + "num_nodes(tftrt_total): ***\n", + "num_nodes(trt_only): ***\n", + "graph_size(MB)(native_tf): ***\n", + "graph_size(MB)(tft): ***\n", + "...\n", + "time(s)(trt_conversion): ***\n", + "```\n", + "\n", + "Note: For a list of supported operations that can be converted to a TensorRT graph, see [Supported Ops](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#support-ops).\n", + "\n", + "Again, note the accuracy and timing information under:\n", + "\n", + "```\n", + "results of resnet_v1_50:\n", + " ...\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### TF-TRT Using FP16\n", + "\n", + "In this session, we continue to use TF-TRT conversion, but we reduce the precision mode to FP16, allowing the use of Tensor Cores for performance improvements during inference, while preserving accuracy within the acceptable tolerance level (0.1%)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%run image_classification --model resnet_v1_50 \\\n", + " --data_dir $DATA_DIR \\\n", + " --use_trt \\\n", + " --precision fp16 \\\n", + " --cache" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Again, we see that the native TensorFlow graph gets converted to a TensorRT graph. 
Look again for the following in the log to confirm:\n", + "\n", + "```\n", + "num_nodes(native_tf): ***\n", + "num_nodes(tftrt_total): ***\n", + "num_nodes(trt_only): ***\n", + "graph_size(MB)(native_tf): ***\n", + "graph_size(MB)(tft): ***\n", + "...\n", + "time(s)(trt_conversion): ***\n", + "```\n", + "\n", + "Compare the results with the previous sessions:\n", + "\n", + "```\n", + "results of resnet_v1_50:\n", + " ...\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### TF-TRT Using INT8\n", + "\n", + "For this session we continue to use TF-TRT conversion, and we reduce the precision further to INT8 for faster computation. Because INT8 has significantly lower precision and dynamic range than FP32, the INT8 precision mode requires an additional calibration step before performing the type conversion. In this calibration step, inference is first run with FP32 precision on a calibration dataset to generate many INT8 quantizations of the weights and activations in the trained TensorFlow graph, from which are chosen the INT8 quantizations that minimize information loss. For more details on the calibration process, see the [8-bit Inference with TensorRT presentation](http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf).\n", + "\n", + "The calibration dataset should closely reflect the distribution of the problem dataset. In this walkthrough, we use the same ImageNet validation set training data for the calibration data, with `--calib_data_dir $DATA_DIR`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%run image_classification --model resnet_v1_50 \\\n", + " --data_dir $DATA_DIR \\\n", + " --use_trt \\\n", + " --precision int8 \\\n", + " --calib_data_dir $DATA_DIR \\\n", + " --cache" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This time, we see the script performing the calibration step:\n", + "\n", + "```\n", + "Calibrating INT8...\n", + "...\n", + "INFO:tensorflow:Evaluation [6/62]\n", + "INFO:tensorflow:Evaluation [12/62]\n", + "INFO:tensorflow:Evaluation [18/62]\n", + "...\n", + "```\n", + "\n", + "The process completes with the message:\n", + "\n", + "```\n", + "INT8 graph created.\n", + "```\n", + "\n", + "When the calibration step completes -- it may take some time -- we again see that the native TensorFlow graph gets converted to a TensorRT graph. Look again for the following in the log to confirm:\n", + "\n", + "```\n", + "num_nodes(native_tf): ***\n", + "num_nodes(tftrt_total): ***\n", + "num_nodes(trt_only): ***\n", + "graph_size(MB)(native_tf): ***\n", + "graph_size(MB)(tft): ***\n", + "...\n", + "time(s)(trt_conversion): ***\n", + "```\n", + "\n", + "Also notice the following INT8-specific timing information:\n", + "\n", + "```\n", + "time(s)(trt_calibration): ***\n", + "...\n", + "time(s)(trt_int8_conversion): ***\n", + "```\n", + "\n", + "Compare the results with the previous sessions:\n", + "\n", + "```\n", + "results of resnet_v1_50:\n", + " ...\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "Congratulations! You have run inference with an image classification model using various modes of precision and taking advantge of TensorRT inference optimization where possible." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.5.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From c4a7d5ac8aa78213cbe41b8e43f5832df426a21a Mon Sep 17 00:00:00 2001 From: Haibo Hao Date: Wed, 6 Mar 2019 01:51:32 +0800 Subject: [PATCH 29/56] Fixed bug that precision does not support lowercase (#29) --- tftrt/examples/image-classification/image_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py index c8aea134b..9afc83efb 100644 --- a/tftrt/examples/image-classification/image_classification.py +++ b/tftrt/examples/image-classification/image_classification.py @@ -562,7 +562,7 @@ def get_frozen_graph( outputs=['logits', 'classes'], max_batch_size=batch_size, max_workspace_size_bytes=max_workspace_size, - precision_mode=precision, + precision_mode=precision.upper(), minimum_segment_size=minimum_segment_size, is_dynamic_op=use_dynamic_op ) From 703031eb0c96201aca42fe8f0c5fba3a617b4989 Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Sun, 10 Mar 2019 12:02:19 -0700 Subject: [PATCH 30/56] Remove redundant argument `run_calibration` --- tftrt/examples/image-classification/image_classification.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py index 9afc83efb..b722f6ebd 100644 --- a/tftrt/examples/image-classification/image_classification.py +++ b/tftrt/examples/image-classification/image_classification.py @@ -78,7 +78,7 @@ def after_run(self, run_context, run_values): run_context.request_stop() def run(frozen_graph, model, data_files, batch_size, - num_iterations, num_warmup_iterations, use_synthetic, display_every=100, run_calibration=False, + num_iterations, num_warmup_iterations, use_synthetic, display_every=100, mode='validation', target_duration=None): """Evaluates a frozen graph @@ -94,7 +94,6 @@ def run(frozen_graph, model, data_files, batch_size, num_warmup_iterations: int, number of iteration(batches) to exclude from benchmark measurments use_synthetic: bool, if true run using real data, otherwise synthetic display_every: int, print log every @display_every iteration - run_calibration: bool, run using calibration or not (only int8 precision) mode: validation - using estimator.evaluate with accuracy measurments, benchmark - using estimator.predict """ @@ -578,7 +577,7 @@ def get_frozen_graph( print('Calibrating INT8...') start_time = time.time() run(calib_graph, model, calib_files, batch_size, - num_calib_inputs // batch_size, 0, use_synthetic=use_synthetic, run_calibration=True) + num_calib_inputs // batch_size, 0, use_synthetic=use_synthetic) times['trt_calibration'] = time.time() - start_time start_time = time.time() From 9c53bcbef9e6afb056dde625c7c3f22beb882143 Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Tue, 12 Mar 2019 12:43:46 -0700 Subject: [PATCH 31/56] Update README.md --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 3bcdf2aab..afaf1d0b1 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,7 @@ +# 
Documentation for TensorRT in TensorFlow (TF-TRT) + +The documentaion on how to accelerate inference in TensorFlow with TensorRT (TF-TRT) is here: https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html + # Examples for TensorRT in TensorFlow (TF-TRT) This repository contains a number of different examples From 950811e386a6b82da5609eb045ddc7260bef062e Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Thu, 14 Mar 2019 10:50:33 -0700 Subject: [PATCH 32/56] Change default value of obj detection args (#31) --- tftrt/examples/object_detection/object_detection.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tftrt/examples/object_detection/object_detection.py b/tftrt/examples/object_detection/object_detection.py index 0d3528cb6..2572d8070 100644 --- a/tftrt/examples/object_detection/object_detection.py +++ b/tftrt/examples/object_detection/object_detection.py @@ -216,8 +216,9 @@ def optimize_model(config_path, override_resizer_shape=None, max_batch_size=1, precision_mode='FP32', - minimum_segment_size=50, - max_workspace_size_bytes=1 << 25, + minimum_segment_size=2, + max_workspace_size_bytes=1 << 32, + maximum_cached_engines=100, calib_images_dir=None, num_calib_images=None, calib_image_shape=None, @@ -351,7 +352,7 @@ def optimize_model(config_path, precision_mode=precision_mode, minimum_segment_size=minimum_segment_size, is_dynamic_op=True, - maximum_cached_engines=10) + maximum_cached_engines=maximum_cached_engines) # perform calibration for int8 precision if precision_mode == 'INT8': From bb8a441900ff11f4d0203fff36c9ea784257d27a Mon Sep 17 00:00:00 2001 From: Pietro Cicotti Date: Thu, 14 Mar 2019 13:04:22 -0700 Subject: [PATCH 33/56] restrict latency timing around session.run (#37) --- tftrt/examples/image-classification/image_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py index b722f6ebd..ac7d5a719 100644 --- a/tftrt/examples/image-classification/image_classification.py +++ b/tftrt/examples/image-classification/image_classification.py @@ -38,7 +38,7 @@ def __init__(self, batch_size, num_records, display_every): self.num_steps = (num_records + batch_size - 1) / batch_size self.batch_size = batch_size - def begin(self): + def before_run(self, run_context): self.start_time = time.time() def after_run(self, run_context, run_values): From ccc5f18e5811b3e41969bbd878c7fd3c8fc92fc0 Mon Sep 17 00:00:00 2001 From: mdrozdowski1996 Date: Thu, 14 Mar 2019 15:13:49 -0700 Subject: [PATCH 34/56] update log print (#38) * update log print * print log for calibration * Update object_detection.py --- .../object_detection/object_detection.py | 28 +++++++++++++++---- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/tftrt/examples/object_detection/object_detection.py b/tftrt/examples/object_detection/object_detection.py index 2572d8070..00ce61706 100644 --- a/tftrt/examples/object_detection/object_detection.py +++ b/tftrt/examples/object_detection/object_detection.py @@ -20,7 +20,6 @@ import tensorflow as tf import tensorflow.contrib.tensorrt as trt -import tqdm import pdb from collections import namedtuple @@ -224,7 +223,8 @@ def optimize_model(config_path, calib_image_shape=None, tmp_dir='.optimize_model_tmp_dir', remove_tmp_dir=True, - output_path=None): + output_path=None, + display_every=100): """Optimizes an object detection model using TensorRT Optimizes an object detection model using TensorRT. 
This method also
@@ -275,6 +275,7 @@
         tmp_dir or throw error.
     output_path: An optional string representing the path to save the
         optimized GraphDef to.
+    display_every: print log for calibration every display_every iterations
 
     Returns
     -------
@@ -342,6 +343,7 @@
 
     # optionally perform TensorRT optimization
     if use_trt:
+        runtimes = []
         with tf.Graph().as_default() as tf_graph:
             with tf.Session(config=tf_config) as tf_sess:
                 frozen_graph = trt.create_inference_graph(
@@ -371,7 +373,7 @@
                     image_paths = glob.glob(os.path.join(calib_images_dir, '*.jpg'))
                     image_paths = image_paths[0:num_calib_images]
 
-                    for image_idx in tqdm.tqdm(range(0, len(image_paths), max_batch_size)):
+                    for image_idx in range(0, len(image_paths), max_batch_size):
 
                         # read batch of images
                         batch_images = []
@@ -379,10 +381,18 @@
                             image = _read_image(image_path, calib_image_shape)
                             batch_images.append(image)
 
+                        t0 = time.time()
                         # execute batch of images
                         boxes, classes, scores, num_detections = tf_sess.run(
                             [tf_boxes, tf_classes, tf_scores, tf_num_detections],
                             feed_dict={tf_input: batch_images})
+                        t1 = time.time()
+                        runtimes.append(float(t1 - t0))
+                        if len(runtimes) % display_every == 0:
+                            print("    step %d/%d, iter_time(ms)=%.4f" % (
+                                len(runtimes),
+                                (len(image_paths) + max_batch_size - 1) / max_batch_size,
+                                np.mean(runtimes) * 1000))
 
                         pdb.set_trace()
                         frozen_graph = trt.calib_graph_to_infer_graph(frozen_graph)
@@ -462,7 +472,8 @@
                     num_images=4096,
                     tmp_dir='.benchmark_model_tmp_dir',
                     remove_tmp_dir=True,
-                    output_path=None):
+                    output_path=None,
+                    display_every=100):
     """Computes accuracy and performance statistics
 
     Computes accuracy and performance statistics by executing over many images
@@ -487,7 +498,7 @@
         a temporary directory to store intermediate files.
     output_path: An optional string representing a path to store the
         statistics in JSON format.
-
+    display_every: int, print log every display_every iterations
     Returns
     -------
     statistics: A named dictionary of accuracy and performance statistics
@@ -542,7 +553,7 @@
                 NUM_DETECTIONS_NAME + ':0')
 
             # load batches from coco dataset
-            for image_idx in tqdm.tqdm(range(0, len(image_ids), batch_size)):
+            for image_idx in range(0, len(image_ids), batch_size):
                 batch_image_ids = image_ids[image_idx:image_idx + batch_size]
                 batch_images = []
                 batch_coco_images = []
@@ -571,6 +582,11 @@
 
                 # log runtime and image count
                 runtimes.append(float(t1 - t0))
+                if len(runtimes) % display_every == 0:
+                    print("    step %d/%d, iter_time(ms)=%.4f" % (
+                        len(runtimes),
+                        (len(image_ids) + batch_size - 1) / batch_size,
+                        np.mean(runtimes) * 1000))
                 image_counts.append(len(batch_images))
 
                 # add coco detections for this batch to running list

From b5ff3a1d864f86c985f731d70298e059cf03ab8f Mon Sep 17 00:00:00 2001
From: Pietro Cicotti
Date: Wed, 27 Mar 2019 09:41:22 -0700
Subject: [PATCH 35/56] Added median and min to performance report. (#35)

* Added median and min to performance report.
* fixed line break in target duration message --- .../image_classification.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py index ac7d5a719..04810cbfa 100644 --- a/tftrt/examples/image-classification/image_classification.py +++ b/tftrt/examples/image-classification/image_classification.py @@ -63,13 +63,16 @@ def __init__(self, target_duration=None, iteration_limit=None): def before_run(self, run_context): if not self.start_time: self.start_time = time.time() - print(" running for target duration from %d", self.start_time) + if self.target_duration: + print(" running for target duration {} seconds".format(self.target_duration), end="") + print(" from {}".format(time.asctime(time.localtime(self.start_time)))) def after_run(self, run_context, run_values): if self.target_duration: current_time = time.time() if (current_time - self.start_time) > self.target_duration: - print(" target duration %d reached at %d, requesting stop" % (self.target_duration, current_time)) + print(" target duration {}".format(self.target_duration), end="") + print(" reached at {}, requesting stop".format(time.asctime(time.localtime(current_time)))) run_context.request_stop() if self.iteration_limit: @@ -194,6 +197,8 @@ def input_fn(): results['images_per_sec'] = np.mean(batch_size / iter_times) results['99th_percentile'] = np.percentile(iter_times, q=99, interpolation='lower') * 1000 results['latency_mean'] = np.mean(iter_times) * 1000 + results['latency_median'] = np.median(iter_times) * 1000 + results['latency_min'] = np.min(iter_times) * 1000 return results class NetDef(object): @@ -679,7 +684,7 @@ def get_files(data_dir, filename_pattern): elif args.mode == "benchmark": data_files = [os.path.join(path, name) for path, _, files in os.walk(args.data_dir) for name in files] else: - raise ValueError("Mode must be either 'validation' or 'benchamark'") + raise ValueError("Mode must be either 'validation' or 'benchmark'") calib_files = get_files(args.calib_data_dir, 'train*') frozen_graph, num_nodes, times, graph_sizes = get_frozen_graph( @@ -728,6 +733,8 @@ def print_dict(input_dict, str='', scale=None): if args.mode == 'validation': print(' accuracy: %.2f' % (results['accuracy'] * 100)) print(' images/sec: %d' % results['images_per_sec']) - print(' 99th_percentile(ms): %.1f' % results['99th_percentile']) + print(' 99th_percentile(ms): %.2f' % results['99th_percentile']) print(' total_time(s): %.1f' % results['total_time']) - print(' latency_mean(ms): %.1f' % results['latency_mean']) + print(' latency_mean(ms): %.2f' % results['latency_mean']) + print(' latency_median(ms): %.2f' % results['latency_median']) + print(' latency_min(ms): %.2f' % results['latency_min']) From 6cc38460b42fa35a4f7cc74b39a2d41cb9613304 Mon Sep 17 00:00:00 2001 From: Pietro Cicotti Date: Wed, 27 Mar 2019 09:43:05 -0700 Subject: [PATCH 36/56] =?UTF-8?q?Add=20option=20to=20store=20trt=20engine.?= =?UTF-8?q?=20Save=20frozen=20graph=20before=20tf=5Ftrt=20conve=E2=80=A6?= =?UTF-8?q?=20(#36)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add option to store trt engine. Save frozen graph before tf_trt conversion. 
* removed graph output before conversion which can be obtained with use_trt=false --- .../image-classification/image_classification.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py index 04810cbfa..5b49efd20 100644 --- a/tftrt/examples/image-classification/image_classification.py +++ b/tftrt/examples/image-classification/image_classification.py @@ -515,6 +515,7 @@ def get_frozen_graph( model, model_dir=None, use_trt=False, + engine_dir=None, use_dynamic_op=False, precision='fp32', batch_size=8, @@ -575,6 +576,16 @@ def get_frozen_graph( num_nodes['trt_only'] = len([1 for n in frozen_graph.node if str(n.op)=='TRTEngineOp']) graph_sizes['trt'] = len(frozen_graph.SerializeToString()) + if engine_dir: + segment_number = 0 + for node in frozen_graph.node: + if node.op == "TRTEngineOp": + engine = node.attr["serialized_segment"].s + engine_path = engine_dir+'/{}_{}_{}_segment{}.trtengine'.format(model, precision, batch_size, segment_number) + segment_number += 1 + with open(engine_path, "wb") as f: + f.write(engine) + if precision == 'int8': calib_graph = frozen_graph graph_sizes['calib'] = len(calib_graph.SerializeToString()) @@ -628,6 +639,9 @@ def get_frozen_graph( 'loaded from if --model_dir is not provided.') parser.add_argument('--use_trt', action='store_true', help='If set, the graph will be converted to a TensorRT graph.') + parser.add_argument('--engine_dir', type=str, default=None, + help='Directory where to write trt engines. Engines are written only if the directory ' \ + 'is provided. This option is ignored when not using tf_trt.') parser.add_argument('--use_trt_dynamic_op', action='store_true', help='If set, TRT conversion will be done using dynamic op instead of statically.') parser.add_argument('--precision', type=str, choices=['fp32', 'fp16', 'int8'], default='fp32', @@ -691,6 +705,7 @@ def get_files(data_dir, filename_pattern): model=args.model, model_dir=args.model_dir, use_trt=args.use_trt, + engine_dir=args.engine_dir, use_dynamic_op=args.use_trt_dynamic_op, precision=args.precision, batch_size=args.batch_size, From 9464cc2cba010c2764a30daae48a2f5c5f671bf0 Mon Sep 17 00:00:00 2001 From: mdrozdowski1996 Date: Wed, 27 Mar 2019 15:51:25 -0700 Subject: [PATCH 37/56] support synthetic for object detection (#26) * support synthetic for object detection * update object_detection --- .../object_detection/object_detection.py | 203 ++++++++++-------- 1 file changed, 111 insertions(+), 92 deletions(-) diff --git a/tftrt/examples/object_detection/object_detection.py b/tftrt/examples/object_detection/object_detection.py index 00ce61706..4028b7228 100644 --- a/tftrt/examples/object_detection/object_detection.py +++ b/tftrt/examples/object_detection/object_detection.py @@ -39,6 +39,9 @@ from object_detection.protos import pipeline_pb2, image_resizer_pb2 from object_detection import exporter +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval + Model = namedtuple('Model', ['name', 'url', 'extract_dir']) INPUT_NAME = 'image_tensor' @@ -473,7 +476,9 @@ def benchmark_model(frozen_graph, tmp_dir='.benchmark_model_tmp_dir', remove_tmp_dir=True, output_path=None, - display_every=100): + display_every=100, + use_synthetic=False, + num_warmup_iterations=50): """Computes accuracy and performance statistics Computes accuracy and performance statistics by executing over many images @@ -491,7 +496,8 @@ def benchmark_model(frozen_graph, 
     batch_size: An integer representing the batch size to use when feeding
         images to the model.
     image_shape: An optional tuple of integers representing a fixed shape
-        to resize all images before testing.
+        to resize all images before testing. For synthetic data the default
+        image_shape is [600, 600, 3].
     num_images: An integer representing the number of images in the dataset
         to evaluate with.
     tmp_dir: A string representing the path where the function may create
@@ -499,6 +505,8 @@
     output_path: An optional string representing a path to store the
         statistics in JSON format.
     display_every: int, print log every display_every iterations
+    num_warmup_iterations: An integer representing the number of initial
+        iterations that are not covered by the performance statistics
     Returns
     -------
     statistics: A named dictionary of accuracy and performance statistics
@@ -512,19 +520,18 @@
         raise RuntimeError(
             'Fixed image shape must be provided for batch size > 1')
 
-    from pycocotools.coco import COCO
-    from pycocotools.cocoeval import COCOeval
+    if not use_synthetic:
+        coco = COCO(annotation_file=annotation_path)
 
-    coco = COCO(annotation_file=annotation_path)
-
-    # get list of image ids to use for evaluation
-    image_ids = coco.getImgIds()
-    if num_images > len(image_ids):
-        print(
-            'Num images provided %d exceeds number in dataset %d, using %d images instead'
-            % (num_images, len(image_ids), len(image_ids)))
-        num_images = len(image_ids)
-    image_ids = image_ids[0:num_images]
+        # get list of image ids to use for evaluation
+        image_ids = coco.getImgIds()
+        if num_images > len(image_ids):
+            print(
+                'Num images provided %d exceeds number in dataset %d, using %d images instead'
+                % (num_images, len(image_ids), len(image_ids)))
+            num_images = len(image_ids)
+        image_ids = image_ids[0:num_images]
 
     # load frozen graph from file if string, otherwise must be GraphDef
     if isinstance(frozen_graph, str):
@@ -553,93 +560,105 @@
                 NUM_DETECTIONS_NAME + ':0')
 
             # load batches from coco dataset
-            for image_idx in range(0, len(image_ids), batch_size):
-                batch_image_ids = image_ids[image_idx:image_idx + batch_size]
-                batch_images = []
-                batch_coco_images = []
-
-                # read images from file
-                for image_id in batch_image_ids:
-                    coco_img = coco.imgs[image_id]
-                    batch_coco_images.append(coco_img)
-                    image_path = os.path.join(images_dir,
-                                              coco_img['file_name'])
-                    image = _read_image(image_path, image_shape)
-                    batch_images.append(image)
-
-                # run once outside of timing to initialize
-                if image_idx == 0:
+            for image_idx in range(0, num_images, batch_size):
+                if use_synthetic:
+                    if image_shape is None:
+                        batch_images = np.random.randint(256, size=(batch_size, 600, 600, 3))
+                    else:
+                        batch_images = np.random.randint(256, size=(batch_size, image_shape[0], image_shape[1], 3))
+                else:
+                    batch_image_ids = image_ids[image_idx:image_idx + batch_size]
+                    batch_images = []
+                    batch_coco_images = []
+                    # read images from file
+                    for image_id in batch_image_ids:
+                        coco_img = coco.imgs[image_id]
+                        batch_coco_images.append(coco_img)
+                        image_path = os.path.join(images_dir,
+                                                  coco_img['file_name'])
+                        image = _read_image(image_path, image_shape)
+                        batch_images.append(image)
+
+                # run num_warmup_iterations outside of timing
+                if image_idx < num_warmup_iterations:
                     boxes, classes, scores, num_detections = tf_sess.run(
                         [tf_boxes, tf_classes, tf_scores, tf_num_detections],
                         feed_dict={tf_input: batch_images})
-
-                # execute model and compute time difference
-                t0 = time.time()
-                boxes, classes,
scores, num_detections = tf_sess.run( - [tf_boxes, tf_classes, tf_scores, tf_num_detections], - feed_dict={tf_input: batch_images}) - t1 = time.time() - - # log runtime and image count - runtimes.append(float(t1 - t0)) - if len(runtimes) % display_every == 0: - print(" step %d/%d, iter_time(ms)=%.4f" % ( - len(runtimes), - (len(image_ids) + batch_size - 1) / batch_size, - np.mean(runtimes) * 1000)) - image_counts.append(len(batch_images)) - - # add coco detections for this batch to running list - batch_coco_detections = [] - for i, image_id in enumerate(batch_image_ids): - image_width = batch_coco_images[i]['width'] - image_height = batch_coco_images[i]['height'] - - for j in range(int(num_detections[i])): - bbox = boxes[i][j] - bbox_coco_fmt = [ - bbox[1] * image_width, # x0 - bbox[0] * image_height, # x1 - (bbox[3] - bbox[1]) * image_width, # width - (bbox[2] - bbox[0]) * image_height, # height - ] - - coco_detection = { - 'image_id': image_id, - 'category_id': int(classes[i][j]), - 'bbox': bbox_coco_fmt, - 'score': float(scores[i][j]) - } - - coco_detections.append(coco_detection) - - # write coco detections to file - subprocess.call(['mkdir', '-p', tmp_dir]) - coco_detections_path = os.path.join(tmp_dir, 'coco_detections.json') - with open(coco_detections_path, 'w') as f: - json.dump(coco_detections, f) - - # compute coco metrics - cocoDt = coco.loadRes(coco_detections_path) - eval = COCOeval(coco, cocoDt, 'bbox') - eval.params.imgIds = image_ids - - eval.evaluate() - eval.accumulate() - eval.summarize() - - statistics = { - 'map': eval.stats[0], - 'avg_latency_ms': 1000.0 * np.mean(runtimes), - 'avg_throughput_fps': np.sum(image_counts) / np.sum(runtimes), - 'runtimes_ms': [1000.0 * r for r in runtimes] - } + else: + # execute model and compute time difference + t0 = time.time() + boxes, classes, scores, num_detections = tf_sess.run( + [tf_boxes, tf_classes, tf_scores, tf_num_detections], + feed_dict={tf_input: batch_images}) + t1 = time.time() + + # log runtime and image count + runtimes.append(float(t1 - t0)) + if len(runtimes) % display_every == 0: + print(" step %d/%d, iter_time(ms)=%.4f" % ( + len(runtimes), + (len(image_ids) + batch_size - 1) / batch_size, + np.mean(runtimes) * 1000)) + image_counts.append(len(batch_images)) + + if not use_synthetic: + # add coco detections for this batch to running list + batch_coco_detections = [] + for i, image_id in enumerate(batch_image_ids): + image_width = batch_coco_images[i]['width'] + image_height = batch_coco_images[i]['height'] + + for j in range(int(num_detections[i])): + bbox = boxes[i][j] + bbox_coco_fmt = [ + bbox[1] * image_width, # x0 + bbox[0] * image_height, # x1 + (bbox[3] - bbox[1]) * image_width, # width + (bbox[2] - bbox[0]) * image_height, # height + ] + + coco_detection = { + 'image_id': image_id, + 'category_id': int(classes[i][j]), + 'bbox': bbox_coco_fmt, + 'score': float(scores[i][j]) + } + + coco_detections.append(coco_detection) + + if not use_synthetic: + # write coco detections to file + subprocess.call(['mkdir', '-p', tmp_dir]) + coco_detections_path = os.path.join(tmp_dir, 'coco_detections.json') + with open(coco_detections_path, 'w') as f: + json.dump(coco_detections, f) + + # compute coco metrics + cocoDt = coco.loadRes(coco_detections_path) + eval = COCOeval(coco, cocoDt, 'bbox') + eval.params.imgIds = image_ids + + eval.evaluate() + eval.accumulate() + eval.summarize() + + statistics = { + 'map': eval.stats[0], + 'avg_latency_ms': 1000.0 * np.mean(runtimes), + 'avg_throughput_fps': np.sum(image_counts) / 
np.sum(runtimes), + 'runtimes_ms': [1000.0 * r for r in runtimes] + } + else: + statistics = { + 'avg_latency_ms': 1000.0 * np.mean(runtimes), + 'avg_throughput_fps': np.sum(image_counts) / np.sum(runtimes), + 'runtimes_ms': [1000.0 * r for r in runtimes] + } if output_path is not None: subprocess.call(['mkdir', '-p', os.path.dirname(output_path)]) with open(output_path, 'w') as f: json.dump(statistics, f) - subprocess.call(['rm', '-rf', tmp_dir]) return statistics From 35ae9eeebf4cb67bd1eeb4e89d5d7b40e3d66fdb Mon Sep 17 00:00:00 2001 From: mdrozdowski1996 Date: Wed, 27 Mar 2019 15:54:24 -0700 Subject: [PATCH 38/56] add conversion stats (#48) --- tftrt/examples/object_detection/object_detection.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tftrt/examples/object_detection/object_detection.py b/tftrt/examples/object_detection/object_detection.py index 4028b7228..27361999b 100644 --- a/tftrt/examples/object_detection/object_detection.py +++ b/tftrt/examples/object_detection/object_detection.py @@ -349,6 +349,9 @@ def optimize_model(config_path, runtimes = [] with tf.Graph().as_default() as tf_graph: with tf.Session(config=tf_config) as tf_sess: + graph_size = len(frozen_graph.SerializeToString()) + num_nodes = len(frozen_graph.node) + start_time = time.time() frozen_graph = trt.create_inference_graph( input_graph_def=frozen_graph, outputs=output_names, @@ -358,6 +361,14 @@ def optimize_model(config_path, minimum_segment_size=minimum_segment_size, is_dynamic_op=True, maximum_cached_engines=maximum_cached_engines) + end_time = time.time() + print("graph_size(MB)(native_tf): %.1f" % (float(graph_size)/(1<<20))) + print("graph_size(MB)(trt): %.1f" % + (float(len(frozen_graph.SerializeToString()))/(1<<20))) + print("num_nodes(native_tf): %d" % num_nodes) + print("num_nodes(tftrt_total): %d" % len(frozen_graph.node)) + print("num_nodes(trt_only): %d" % len([1 for n in frozen_graph.node if str(n.op)=='TRTEngineOp'])) + print("time(s) (trt_conversion): %.4f" % (end_time - start_time)) # perform calibration for int8 precision if precision_mode == 'INT8': From 8aa02d9ded8f23bef11e268f3cd5e6815c797184 Mon Sep 17 00:00:00 2001 From: mdrozdowski1996 Date: Wed, 27 Mar 2019 15:58:18 -0700 Subject: [PATCH 39/56] fix data files requirement for synthetic data (#47) --- .../image_classification.py | 27 ++++++++++++------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py index 5b49efd20..94944ce62 100644 --- a/tftrt/examples/image-classification/image_classification.py +++ b/tftrt/examples/image-classification/image_classification.py @@ -167,7 +167,9 @@ def input_fn(): return features, labels # Evaluate model - if mode == 'validation': + if use_synthetic: + num_records = num_iterations * batch_size + elif mode == 'validation': num_records = get_tfrecords_count(data_files) elif mode == 'benchmark': num_records = len(data_files) @@ -626,7 +628,7 @@ def get_frozen_graph( 'resnet_v1_50', 'resnet_v2_50', 'resnet_v2_152', 'vgg_16', 'vgg_19', 'inception_v3', 'inception_v4'], help='Which model to use.') - parser.add_argument('--data_dir', type=str, required=True, + parser.add_argument('--data_dir', type=str, default=None, help='Directory containing validation set TFRecord files.') parser.add_argument('--calib_data_dir', type=str, help='Directory containing TFRecord files for calibrating int8.') @@ -683,6 +685,10 @@ def get_frozen_graph( '({} <= 
{})'.format(args.num_calib_inputs, args.batch_size)) if args.mode == 'validation' and args.use_synthetic: raise ValueError('Cannot use both validation mode and synthetic dataset') + if args.data_dir is None and not args.use_synthetic: + raise ValueError("--data_dir required if you are not using synthetic data") + if args.use_synthetic and args.num_iterations is None: + raise ValueError("--num_iterations is required for --use_synthetic") def get_files(data_dir, filename_pattern): if data_dir == None: @@ -693,13 +699,16 @@ def get_files(data_dir, filename_pattern): 'pattern "{}"'.format(data_dir, filename_pattern)) return files - if args.mode == "validation": - data_files = get_files(args.data_dir, 'validation*') - elif args.mode == "benchmark": - data_files = [os.path.join(path, name) for path, _, files in os.walk(args.data_dir) for name in files] - else: - raise ValueError("Mode must be either 'validation' or 'benchmark'") - calib_files = get_files(args.calib_data_dir, 'train*') + calib_files = [] + data_files = [] + if not args.use_synthetic: + if args.mode == "validation": + data_files = get_files(args.data_dir, 'validation*') + elif args.mode == "benchmark": + data_files = [os.path.join(path, name) for path, _, files in os.walk(args.data_dir) for name in files] + else: + raise ValueError("Mode must be either 'validation' or 'benchamark'") + calib_files = get_files(args.calib_data_dir, 'train*') frozen_graph, num_nodes, times, graph_sizes = get_frozen_graph( model=args.model, From e3622b14bdb2c934775bfaede7facc3fdd4503ff Mon Sep 17 00:00:00 2001 From: "Xiaodong (Vincent) Huang" Date: Fri, 29 Mar 2019 03:53:00 +0800 Subject: [PATCH 40/56] update hook position for performance measurement (#43) use before_run() instead of begin() as start time of each inference iteration. --- tftrt/examples/image-classification/image_classification.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py index 94944ce62..6ad6325b1 100644 --- a/tftrt/examples/image-classification/image_classification.py +++ b/tftrt/examples/image-classification/image_classification.py @@ -44,7 +44,6 @@ def before_run(self, run_context): def after_run(self, run_context, run_values): current_time = time.time() duration = current_time - self.start_time - self.start_time = current_time self.iter_times.append(duration) current_step = len(self.iter_times) if current_step % self.display_every == 0: From 7cb3fd3d07656aa7d092b0e11d07951fb1e8f9d4 Mon Sep 17 00:00:00 2001 From: Marek Drozdowski Date: Thu, 28 Mar 2019 16:16:06 -0700 Subject: [PATCH 41/56] update readme --- tftrt/examples/image-classification/README.md | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tftrt/examples/image-classification/README.md b/tftrt/examples/image-classification/README.md index d84b7b8e4..60bec5baf 100644 --- a/tftrt/examples/image-classification/README.md +++ b/tftrt/examples/image-classification/README.md @@ -102,7 +102,8 @@ for more information. ### Data -The example script supports either using a dataset in TFRecord format or using +The example script supports either using a dataset (for validation +mode - TFRecord format, for benchmark mode - jpeg format) or using autogenerated synthetic data (with the `--use_synthetic` flag). If you use TFRecord files, the script assumes that the TFRecords are named according to the pattern: `validation-*-of-00128`. 
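In synthetic mode the input pipeline amounts to generating random batches in
place of decoded records. A minimal sketch under that assumption (the shape
and distribution mirror the synthetic input code elsewhere in these patches):

```python
import numpy as np

def synthetic_batch(batch_size=8, height=224, width=224):
    # Random float32 images clipped to the valid pixel range, standing in
    # for decoded TFRecord inputs when --use_synthetic is set.
    images = np.random.normal(
        loc=112, scale=70,
        size=(batch_size, height, width, 3)).astype(np.float32)
    return np.clip(images, 0.0, 255.0)

print(synthetic_batch().shape)  # (8, 224, 224, 3)
```
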
@@ -161,7 +162,8 @@ model as follows:
 python image_classification.py --model resnet_v1_50 \
     --data_dir /data/imagenet/train-val-tfrecord \
     --use_trt \
-    --precision fp16
+    --precision fp16 \
+    --mode validation
 ```
 
 Where:
@@ -174,6 +176,8 @@ Where:
 
 `--precision`: Precision mode to use, in this case FP16.
 
+`--mode`: Which mode to use (validation or benchmark). In validation mode we run inference with both accuracy and performance measurements; in benchmark mode, performance only.
+
 Run with `--help` to see all available options.
 
 Also see [General Script Usage

From 360916694fe09684e6ccb00dcd9a1e17f9204e67 Mon Sep 17 00:00:00 2001
From: Pooya Davoodi
Date: Thu, 28 Mar 2019 17:22:06 -0700
Subject: [PATCH 42/56] Remove end argument from print for py2 compatibility

---
 .../examples/image-classification/image_classification.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py
index 6ad6325b1..009e38828 100644
--- a/tftrt/examples/image-classification/image_classification.py
+++ b/tftrt/examples/image-classification/image_classification.py
@@ -63,15 +63,15 @@ def before_run(self, run_context):
         if not self.start_time:
             self.start_time = time.time()
             if self.target_duration:
-                print("    running for target duration {} seconds".format(self.target_duration), end="")
-                print(" from {}".format(time.asctime(time.localtime(self.start_time))))
+                print("    running for target duration {} seconds from {}".format(
+                    self.target_duration, time.asctime(time.localtime(self.start_time))))
 
     def after_run(self, run_context, run_values):
         if self.target_duration:
             current_time = time.time()
             if (current_time - self.start_time) > self.target_duration:
-                print("    target duration {}".format(self.target_duration), end="")
-                print(" reached at {}, requesting stop".format(time.asctime(time.localtime(current_time))))
+                print("    target duration {} reached at {}, requesting stop".format(
+                    self.target_duration, time.asctime(time.localtime(current_time))))
                 run_context.request_stop()
 
         if self.iteration_limit:

From 36ef354341618271e5ca19a6986d8298ef30456e Mon Sep 17 00:00:00 2001
From: Pooya Davoodi
Date: Thu, 28 Mar 2019 17:24:42 -0700
Subject: [PATCH 43/56] Store constant input tensor on GPU

---
 tftrt/examples/image-classification/image_classification.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py
index 009e38828..4014375d8 100644
--- a/tftrt/examples/image-classification/image_classification.py
+++ b/tftrt/examples/image-classification/image_classification.py
@@ -134,13 +134,14 @@ def input_fn():
                 loc=112, scale=70,
                 size=(batch_size, input_height, input_width, 3)).astype(np.float32)
             features = np.clip(features, 0.0, 255.0)
-            features = tf.identity(tf.constant(features))
             labels = np.random.randint(
                 low=0,
                 high=get_netdef(model).get_num_classes(),
                 size=(batch_size),
                 dtype=np.int32)
-            labels = tf.identity(tf.constant(labels))
+            with tf.device('/device:GPU:0'):
+                features = tf.identity(tf.constant(features))
+                labels = tf.identity(tf.constant(labels))
         else:
             if mode == 'validation':
                 dataset = tf.data.TFRecordDataset(data_files)

From a60cd64bc0f03552545799d23074d7d88d58c38a Mon Sep 17 00:00:00 2001
From: Pooya Davoodi
Date: Sat, 2 Mar 2019 23:02:04 -0800
Subject: [PATCH 44/56] Use uppercase letters for precision mode

---
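For context, PATCH 29 uppercases the precision string before handing it to
TensorRT, while this patch tightens the accepted argparse choices to uppercase.
A sketch of an alternative that accepts either spelling (an editorial
illustration, not what the diff below does):

```python
import argparse

parser = argparse.ArgumentParser()
# str.upper normalizes 'fp16' -> 'FP16' before the choices check runs,
# so either spelling is accepted and downstream code sees uppercase.
parser.add_argument('--precision', type=str.upper,
                    choices=['FP32', 'FP16', 'INT8'], default='FP32')

print(parser.parse_args(['--precision', 'fp16']).precision)  # FP16
```
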
tftrt/examples/image-classification/image_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py index 4014375d8..0e021471c 100644 --- a/tftrt/examples/image-classification/image_classification.py +++ b/tftrt/examples/image-classification/image_classification.py @@ -646,7 +646,7 @@ def get_frozen_graph( 'is provided. This option is ignored when not using tf_trt.') parser.add_argument('--use_trt_dynamic_op', action='store_true', help='If set, TRT conversion will be done using dynamic op instead of statically.') - parser.add_argument('--precision', type=str, choices=['fp32', 'fp16', 'int8'], default='fp32', + parser.add_argument('--precision', type=str, choices=['FP32', 'FP16', 'INT8'], default='FP32', help='Precision mode to use. FP16 and INT8 only work in conjunction with --use_trt') parser.add_argument('--batch_size', type=int, default=8, help='Number of images per batch.') From d36fea651d19f879d07cda4c9cda65ca0d27e6ce Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Mon, 4 Mar 2019 09:46:24 -0800 Subject: [PATCH 45/56] Add docs for new args --- tftrt/examples/object_detection/object_detection.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tftrt/examples/object_detection/object_detection.py b/tftrt/examples/object_detection/object_detection.py index 27361999b..c1d163015 100644 --- a/tftrt/examples/object_detection/object_detection.py +++ b/tftrt/examples/object_detection/object_detection.py @@ -264,6 +264,8 @@ def optimize_model(config_path, to use for TensorRT graph segmentation. max_workspace_size_bytes: An integer representing the max workspace size for TensorRT optimization. + maximum_cached_engines: An integer represenging the number of TRT engines + that can be stored in the cache. calib_images_dir: A string representing a directory containing images to use for int8 calibration. num_calib_images: An integer representing the number of calibration From 210771904d9d228f9d6fcddf9d4622889248eb3b Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Sun, 17 Mar 2019 20:47:19 -0700 Subject: [PATCH 46/56] Error out if max_batch_size>1 and calib_image_shape is not set --- tftrt/examples/object_detection/object_detection.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tftrt/examples/object_detection/object_detection.py b/tftrt/examples/object_detection/object_detection.py index c1d163015..28f70c892 100644 --- a/tftrt/examples/object_detection/object_detection.py +++ b/tftrt/examples/object_detection/object_detection.py @@ -286,6 +286,9 @@ def optimize_model(config_path, ------- A GraphDef representing the optimized model. 
""" + if max_batch_size > 1 and calib_image_shape is None: + raise RuntimeError( + 'Fixed calibration image shape must be provided for max_batch_size > 1') if os.path.exists(tmp_dir): if not remove_tmp_dir: raise RuntimeError( From 945461d9f3bc9b8d69bf6c9b80283bc584d93c59 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Wed, 10 Apr 2019 09:24:15 -0700 Subject: [PATCH 47/56] Fix links to TF-TRT docs (#51) * Update links to TF-TRT docs * Fix typo --- README.md | 6 +++--- tftrt/examples/image-classification/README.md | 14 +++++++------- .../image_classification.ipynb | 4 ++-- tftrt/examples/third_party/models | 2 +- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index afaf1d0b1..95272a1ef 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Documentation for TensorRT in TensorFlow (TF-TRT) -The documentaion on how to accelerate inference in TensorFlow with TensorRT (TF-TRT) is here: https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html +The documentaion on how to accelerate inference in TensorFlow with TensorRT (TF-TRT) is here: https://docs.nvidia.com/deeplearning/dgx/tf-trt-user-guide/index.html # Examples for TensorRT in TensorFlow (TF-TRT) @@ -12,7 +12,7 @@ that optimizes TensorFlow graphs using [TensorRT](https://developer.nvidia.com/tensorrt). We have used these examples to verify the accuracy and performance of TF-TRT. For more information see -[Verified Models](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#verified-models). +[Verified Models](https://docs.nvidia.com/deeplearning/dgx/tf-trt-user-guide/index.html#verified-models). ## Examples @@ -50,7 +50,7 @@ Installation instructions for compatibility with TensorFlow are provided on the ## Documentation -[TF-TRT documentaion](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html) +[TF-TRT documentaion](https://docs.nvidia.com/deeplearning/dgx/tf-trt-user-guide/index.html) gives an overview of the supported functionalities, provides tutorials and verified models, explains best practices with troubleshooting guides. diff --git a/tftrt/examples/image-classification/README.md b/tftrt/examples/image-classification/README.md index 60bec5baf..9c69df1ee 100644 --- a/tftrt/examples/image-classification/README.md +++ b/tftrt/examples/image-classification/README.md @@ -12,7 +12,7 @@ This causes the script to apply TensorRT inference optimization to speed up execution for portions of the model's graph where supported, and to fall back on native TensorFlow for layers and operations which are not supported. See [Accelerating Inference In TensorFlow With TensorRT User -Guide](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html) for +Guide](https://docs.nvidia.com/deeplearning/dgx/tf-trt-user-guide/index.html) for more information. When using the TF-TRT integration flag, you can use the precision option @@ -42,7 +42,7 @@ We have verified the following models. For the accuracy numbers of these models on the ImageNet validation dataset, see -[Verified Models](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#verified-models). +[Verified Models](https://docs.nvidia.com/deeplearning/dgx/tf-trt-user-guide/index.html#verified-models). ## Setup @@ -97,7 +97,7 @@ replacing `/path/to/tensorflow_models` with the path to your `tensorflow/models` repository). 
Also see [Setting Up The Environment -](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#image-class-envirn) +](https://docs.nvidia.com/deeplearning/dgx/tf-trt-user-guide/index.html#image-class-envirn) for more information. ### Data @@ -127,7 +127,7 @@ or the validation set. Also see [Obtaining The ImageNet Data -](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#image-class-data) +](https://docs.nvidia.com/deeplearning/dgx/tf-trt-user-guide/index.html#image-class-data) for more information. ## Running the examples as a Jupyter notebook @@ -181,7 +181,7 @@ Where: Run with `--help` to see all available options. Also see [General Script Usage -](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#image-class-usage) +](https://docs.nvidia.com/deeplearning/dgx/tf-trt-user-guide/index.html#image-class-usage) for more information. ## Output @@ -215,9 +215,9 @@ time(s)(trt_conversion): *** Note: For a list of supported operations that can be converted to a TensorRT graph, see the [Supported -Ops](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#support-ops) +Ops](https://docs.nvidia.com/deeplearning/dgx/tf-trt-user-guide/index.html#support-ops) section of the [Accelerating Inference In TensorFlow With TensorRT User -Guide](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html). +Guide](https://docs.nvidia.com/deeplearning/dgx/tf-trt-user-guide/index.html). The script then begins running inference on the ImageNet validation set, displaying run times of each iteration after the interval defined by the diff --git a/tftrt/examples/image-classification/image_classification.ipynb b/tftrt/examples/image-classification/image_classification.ipynb index 7bb4b8bf2..40202dacd 100644 --- a/tftrt/examples/image-classification/image_classification.ipynb +++ b/tftrt/examples/image-classification/image_classification.ipynb @@ -34,7 +34,7 @@ "* Inception v3\n", "* Inception v4\n", "\n", - "For the accuracy numbers of these models on the ImageNet validation dataset, see [Verified Models](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#verified-models)." + "For the accuracy numbers of these models on the ImageNet validation dataset, see [Verified Models](https://docs.nvidia.com/deeplearning/dgx/tf-trt-user-guide/index.html#verified-models)." ] }, { @@ -80,7 +80,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Also see [General Script Usage](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#image-class-usage) for more information." + "Also see [General Script Usage](https://docs.nvidia.com/deeplearning/dgx/tf-trt-user-guide/index.html#image-class-usage) for more information." 
] }, { diff --git a/tftrt/examples/third_party/models b/tftrt/examples/third_party/models index 402b561b0..416bfdfc1 160000 --- a/tftrt/examples/third_party/models +++ b/tftrt/examples/third_party/models @@ -1 +1 @@ -Subproject commit 402b561b03857151f684ee00b3d997e5e6be9778 +Subproject commit 416bfdfc10307f896fa1f218b4f58800599b0cd7 From 95e5941a6aae1491ba5494db35d7b65b9d7897f1 Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Wed, 10 Apr 2019 14:16:59 -0700 Subject: [PATCH 48/56] Use `-q` for wget in image_classification.py (#55) * Use `-q` for wget in image_classification.py * Use -q for wget in object_detection.py --- tftrt/examples/image-classification/image_classification.py | 2 +- tftrt/examples/object_detection/object_detection.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py index 0e021471c..a519437fb 100644 --- a/tftrt/examples/image-classification/image_classification.py +++ b/tftrt/examples/image-classification/image_classification.py @@ -501,7 +501,7 @@ def copy_files(source, destination): archive_path = os.path.join(destination_path, os.path.basename(get_netdef(model).url)) if not os.path.isfile(archive_path): - subprocess.call(['wget', '--no-check-certificate', + subprocess.call(['wget', '--no-check-certificate', '-q', get_netdef(model).url, '-O', archive_path]) # Extract. subprocess.call(['tar', '-xzf', archive_path, '-C', destination_path]) diff --git a/tftrt/examples/object_detection/object_detection.py b/tftrt/examples/object_detection/object_detection.py index 28f70c892..c00a17784 100644 --- a/tftrt/examples/object_detection/object_detection.py +++ b/tftrt/examples/object_detection/object_detection.py @@ -199,7 +199,7 @@ def download_model(model_name, output_dir='.'): if os.path.exists(extract_dir): print('Using cached model found at: %s' % extract_dir) else: - subprocess.call(['wget', model.url, '-O', tar_file]) + subprocess.call(['wget', '-q', model.url, '-O', tar_file]) subprocess.call(['tar', '-xzf', tar_file, '-C', output_dir]) # hack fix to handle mobilenet_v2 config bug @@ -470,14 +470,14 @@ def download_dataset(dataset_name, output_dir='.'): print('Using cached annotation_path; %s' % (annotation_path)) else: subprocess.call( - ['wget', dataset.annotation_url, '-O', annotation_zip_file]) + ['wget', '-q', dataset.annotation_url, '-O', annotation_zip_file]) subprocess.call(['unzip', annotation_zip_file, '-d', output_dir]) # download or use cached images if os.path.exists(images_dir): print('Using cached images_dir; %s' % (images_dir)) else: - subprocess.call(['wget', dataset.images_url, '-O', images_zip_file]) + subprocess.call(['wget', '-q', dataset.images_url, '-O', images_zip_file]) subprocess.call(['unzip', images_zip_file, '-d', output_dir]) return images_dir, annotation_path From 45a91657470a882947901ccd30e2813b53310c8c Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Fri, 12 Apr 2019 15:11:44 -0700 Subject: [PATCH 49/56] Revert updating submodule tftrt/examples/third_party/models The change happened in commit 945461d9f3bc9b8d69bf6c9b80283bc584d93c59 --- tftrt/examples/third_party/models | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tftrt/examples/third_party/models b/tftrt/examples/third_party/models index 416bfdfc1..402b561b0 160000 --- a/tftrt/examples/third_party/models +++ b/tftrt/examples/third_party/models @@ -1 +1 @@ -Subproject commit 416bfdfc10307f896fa1f218b4f58800599b0cd7 
+Subproject commit 402b561b03857151f684ee00b3d997e5e6be9778

From d67a79f8b3d56034945e6c5ad880ac9e831ee0ff Mon Sep 17 00:00:00 2001
From: otstrel
Date: Tue, 16 Apr 2019 20:18:14 +0300
Subject: [PATCH 50/56] Fixing parentheses (#57)

---
 tftrt/examples/image-classification/image_classification.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py
index a519437fb..e70fcad23 100644
--- a/tftrt/examples/image-classification/image_classification.py
+++ b/tftrt/examples/image-classification/image_classification.py
@@ -71,7 +71,7 @@ def after_run(self, run_context, run_values):
             current_time = time.time()
             if (current_time - self.start_time) > self.target_duration:
                 print("    target duration {} reached at {}, requesting stop".format(
-                    self.target_duration), time.asctime(time.localtime(current_time)))
+                    self.target_duration, time.asctime(time.localtime(current_time))))
                 run_context.request_stop()

         if self.iteration_limit:

From e51c6c8ac45ce8adab2ccff0d1e2045110762b49 Mon Sep 17 00:00:00 2001
From: Marek Drozdowski
Date: Mon, 22 Apr 2019 09:23:32 -0700
Subject: [PATCH 51/56] add inference_script

---
 tftrt/examples/ncf/inference.py | 264 ++++++++++++++++++++++++++++++++
 1 file changed, 264 insertions(+)
 create mode 100644 tftrt/examples/ncf/inference.py

diff --git a/tftrt/examples/ncf/inference.py b/tftrt/examples/ncf/inference.py
new file mode 100644
index 000000000..a66d4c49e
--- /dev/null
+++ b/tftrt/examples/ncf/inference.py
@@ -0,0 +1,264 @@
+import tensorflow as tf
+import tensorflow.contrib.tensorrt as trt
+import time
+import random
+import numpy as np
+from official.dataset import movielens
+
+from neumf import ncf_model
+from neumf import NeuMF
+import os
+import argparse
+import csv
+
+class LoggerHook(tf.train.SessionRunHook):
+    """Logs runtime of each iteration"""
+    def __init__(self, batch_size, num_records, display_every):
+        self.iter_times = []
+        self.display_every = display_every
+        self.num_steps = (num_records + batch_size - 1) / batch_size
+        self.batch_size = batch_size
+
+    def before_run(self, run_context):
+        self.start_time = time.time()
+
+    def after_run(self, run_context, run_values):
+        current_time = time.time()
+        duration = current_time - self.start_time
+        self.iter_times.append(duration)
+        current_step = len(self.iter_times)
+        if current_step % self.display_every == 0:
+            print("    step %d/%d, iter_time(ms)=%.4f, images/sec=%d" % (
+                current_step, self.num_steps, duration * 1000,
+                self.batch_size / self.iter_times[-1]))
+
+class BenchmarkHook(tf.train.SessionRunHook):
+    """Limits run duration and number of iterations"""
+    def __init__(self, target_duration=None, iteration_limit=None):
+        self.target_duration = target_duration
+        self.start_time = None
+        self.current_iteration = 0
+        self.iteration_limit = iteration_limit
+
+    def before_run(self, run_context):
+        if not self.start_time:
+            self.start_time = time.time()
+            if self.target_duration:
+                print("    running for target duration {} seconds from {}".format(
+                    self.target_duration, time.asctime(time.localtime(self.start_time))))
+
+    def after_run(self, run_context, run_values):
+        if self.target_duration:
+            current_time = time.time()
+            if (current_time - self.start_time) > self.target_duration:
+                print("    target duration {} reached at {}, requesting stop".format(
+                    self.target_duration, time.asctime(time.localtime(current_time))))
+                run_context.request_stop()
+        if self.iteration_limit:
+
self.current_iteration += 1 + if self.current_iteration >= self.iteration_limit: + run_context.request_stop() + + +def get_frozen_graph(model_checkpoint="/data/marek_ckpt/model.ckpt", + model_dtype=tf.float32, + mf_dim=64, + mf_reg=64, + mlp_layer_sizes=[256, 256, 128, 64], + mlp_layer_regs=[.0, .0, .0, .0], + nb_items=26744, + nb_users=138493): + tf_config = tf.ConfigProto() + with tf.Graph().as_default() as tf_graph: + with tf.Session(config=tf_config) as tf_sess: + users = tf.placeholder(shape=(None,), dtype=tf.int32, name="user_input") + items = tf.placeholder(shape=(None,), dtype=tf.int32, name="item_input") + with tf.variable_scope("neumf"): + logits = NeuMF(users, items, model_dtype, nb_users, nb_items, mf_dim, mf_reg, mlp_layer_sizes, mlp_layer_regs) + if mode == "validation": + found_positive, dcg = compute_eval_metrics(logits, dup_mask, val_batch_size, K) + hit_rate = tf.metrics.mean(found_positive, name='hit_rate') + ndcg = tf.metrics.mean(dcg, name='ndcg') + + saver = tf.train.Saver() + saver.restore(tf_sess, "/data/marek_ckpt/model.ckpt") + graph0 = tf.graph_util.convert_variables_to_constants(tf_sess, + tf_sess.graph_def, output_node_names=['neumf/dense_3/BiasAdd']) + frozen_graph = tf.graph_util.remove_training_nodes(graph0) + + for node in frozen_graph.node: + if node.op == "Assign": + node.op = "Identity" + if 'use_locking' in node.attr: del node.attr['use_locking'] + if 'validate_shape' in node.attr: del node.attr['validate_shape'] + if len(node.input) == 2: + node.input[0] = node.input[1] + del node.input[1] + return frozen_graph + + +def optimize_model(frozen_graph, + use_trt=True, + precision_mode="FP16", + batch_size=128): + if use_trt: + trt_graph = trt.create_inference_graph(frozen_graph, ['neumf/dense_3/BiasAdd:0'], max_batch_size=batch_size, precision_mode=precision_mode) + return trt_graph + +def run(frozen_graph, + data_dir='/data/cache/ml-20m', + num_iterations=None, + num_warmup_iterations=None, + use_synthetic=False, + display_every=100, + mode='validation', + target_duration=None): + + def model_fn(features, labels, mode): + logits_out = tf.import_graph_def(frozen_graph, + input_map={'input': features}, + return_elements=['logits:0'], + name='') + found_possitive, dcg = compute_eval_metrics(logits, dup_mask, val_batch_size, K) + hit_rate = tf.metrics.mean(found_positive, name='hit_rate') + ndcg = tf.metrics.mean(found_positive, name='ndcg') + if mode == tf.estimator.ModeKeys.PREDICT: + return tf.estimator.EstimatorSpec(mode=mode, + predictions={'logits': logits_out}) + if mode == tf.estimator.ModeKeys.EVAL: + return tf.estimator.EstimatorSpec( + mode=mode, + eval_metrics_ops={'found_positive': found_possitive, 'ndcg': ndcg}) + + def input_fn(): + if use_synthetic: + items = [random.randint(1, nb_items) for _ in range(batch_size)] + users = [random.randint(1, nb_users) for _ in range(batch_size)] + with tf.devices('/device:GPU:0'): + items = tf.identity(items) + users = tf.identity(users) + else: + data_path = os.path.join(data_dir, 'test_ratings.pickle') + dataset = pd.read_pickle(data_path) + users = dataset["user_id"] + items = dataset["item_id"] + + user_dataset = tf.data.Dataset_from_tensor_slices(users) + user_dataset = user_dataset.batch(batch_size) + user_dataset = user_dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE) + user_dataset = user_dataset.repeat(count=1) + user_iterator = user_dataset.make_one_shot_iterator() + users = user_iterator.get_next() + + item_dataset = tf.data.Dataset_from_tensor_slices(items) + item_dataset = 
item_dataset.batch(batch_size) + item_dataset = item_dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE) + item_dataset = item_dataset.repeat(count=1) + item_iterator = item_dataset.make_one_shot_iterator() + items = item_iterator.get_next() + return items, users + + +def input_data(use_synthetic=True, + batch_size=128, + data_dir="/data/cache/ml-20m", + num_iterations=None, + nb_items=26744, + nb_users=1388493): + + if use_synthetic and num_iterations is None: + num_iterations = 10000 + + if use_synthetic: + items = [[random.randint(1, nb_items) for _ in range(batch_size)] for _ in range(num_iterations)] + users = [[random.randint(1, nb_users) for _ in range(batch_size)] for _ in range(num_iterations)] + else: + if os.path.exists(data_dir): + print("Using cached dataset: %s" % (data_dir)) + else: + data_path = os.path.join(data_dir, 'test_ratings.pickle') + dataset = pd.read_pickle(data_path) + users = dataset["user_id"] + items = dataset["item_id"] + + return items, users + + +def run_inference(frozen_graph, + use_synthetic=True, + mode='benchmark', + batch_size=128, + data_dir=None, + num_iterations=10000, + num_warmup_iterations=2000, + nb_items=26744, + nb_users=1388493, + display_every=100): + + items, users = input_data(use_synthetic, + batch_size, + data_dir, + num_iterations, + nb_items, + nb_users) + + with tf.Graph().as_default() as tf_graph: + with tf.Session(config=tf.ConfigProto()) as tf_sess: + tf.import_graph_def(frozen_graph, name='') + output = tf_sess.graph.get_tensor_by_name('neumf/dense_3/BiasAdd:0') + runtimes = [] + res = [] + + for n in range(num_iterations): + item = items[n] + user = users[n] + + beg = time.time() + r = tf_sess.run(output, feed_dict={'item_input:0': item, 'user_input:0': user}) + end = time.time() + + res.append(r) + runtimes.append(end-beg) + if n % display_every == 0: + print(" step %d/%d, iter_time(ms)=%.4f" % ( + len(runtimes), + num_iterations, + np.mean(runtimes[(-1)*display_every]) * 1000)) + print("throghput: %.1f" % + (batch_size * num_iterations/np.sum(runtimes))) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Evaluate model') + parse.add_argument('--use_synthetic', action='store_true', + help='If set, one batch of random data is generated and used at every iteration.') + parser.add_argument('--mode', choices=['validation', 'benchmark'], + help='Which mode to use (validation or benchmark)') + parser.add_argument('--data_dir', type=str, default=None, + help='Directory containing validation set csv files.') + parser.add_argument('--model_dir', type=str, default=None, + help='Directory containing model checkpoint.') + parser.add_argument('--use_trt', action='store_true', + help='If set, the graph will be converted to a TensorRT graph.') + parser.add_argument('--precision', type=str, choices=['fp32', 'fp16', 'int8'], default='fp32', + help='Precision mode to use. 
FP16 and INT8 only work in conjunction with --use_trt') + parser.add_argument('--nb_items', type=int, default=26744, + help='Number of items') + parser.add_argument('--nb_users', type=int, default=1388493, + help='Number of users') + parser.add_argument('--batch_size', type=int, default=8, + help='Batch size') + parser.add_argument('--mf_dim', type=int, default=64) + parser.add_argument('--mf_reg', type=int, default=64) + parser.add_argument('--mlp_layer_sizes', default=[256, 256, 128, 64]) + parser.add_argument('--mlp_layer_regs', default=[.0, .0, .0, .0]) + + args = parser.parse_args() + if not args.use_synthetic and args.data_dir: + raise ValueError("Data_dir is not provided") + + frozen_graph = get_frozen_graph() + + frozen_graph = optimize_model(frozen_graph) + + run_inference(frozen_graph) + From 660b21d5d72af4eb1d32aa5dd23963e9347e4cf2 Mon Sep 17 00:00:00 2001 From: Marek Drozdowski Date: Mon, 22 Apr 2019 10:47:18 -0700 Subject: [PATCH 52/56] add DeepLearningExamples submodule --- .gitmodules | 3 +++ tftrt/examples/third_party/DeepLearningExamples | 1 + 2 files changed, 4 insertions(+) create mode 160000 tftrt/examples/third_party/DeepLearningExamples diff --git a/.gitmodules b/.gitmodules index 9aa63ca53..fb9b3b0fd 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "tftrt/examples/third_party/cocoapi"] path = tftrt/examples/third_party/cocoapi url = https://github.com/cocodataset/cocoapi.git +[submodule "tftrt/examples/third_party/DeepLearningExamples"] + path = tftrt/examples/third_party/DeepLearningExamples + url = https://github.com/NVIDIA/DeepLearningExamples.git diff --git a/tftrt/examples/third_party/DeepLearningExamples b/tftrt/examples/third_party/DeepLearningExamples new file mode 160000 index 000000000..531a570c5 --- /dev/null +++ b/tftrt/examples/third_party/DeepLearningExamples @@ -0,0 +1 @@ +Subproject commit 531a570c5cc1705041ca69f3841bdb437022309b From 29a6c84530614ffdb29c1105ce516fe8c6afa926 Mon Sep 17 00:00:00 2001 From: Marek Drozdowski Date: Wed, 24 Apr 2019 11:15:23 -0700 Subject: [PATCH 53/56] support estimator --- tftrt/examples/ncf/inference.py | 260 +++++++++++++++++++++++++------- 1 file changed, 208 insertions(+), 52 deletions(-) diff --git a/tftrt/examples/ncf/inference.py b/tftrt/examples/ncf/inference.py index a66d4c49e..0e44229a9 100644 --- a/tftrt/examples/ncf/inference.py +++ b/tftrt/examples/ncf/inference.py @@ -3,10 +3,11 @@ import time import random import numpy as np -from official.dataset import movielens +import pandas as pd +from official.datasets import movielens -from neumf import ncf_model -from neumf import NeuMF +from neumf import compute_eval_metrics +from neumf import neural_mf import os import argparse import csv @@ -60,24 +61,37 @@ def after_run(self, run_context, run_values): def get_frozen_graph(model_checkpoint="/data/marek_ckpt/model.ckpt", + mode="benchmark", + use_trt=True, + batch_size=1024, + use_dynamic_op=True, + precision="FP32", model_dtype=tf.float32, mf_dim=64, mf_reg=64, mlp_layer_sizes=[256, 256, 128, 64], mlp_layer_regs=[.0, .0, .0, .0], nb_items=26744, - nb_users=138493): + nb_users=138493, + dup_mask=0.1, + K=10, + minimum_segment_size=2, + calib_data_dir=None, + num_calib_inputs=None, + use_synthetic=False, + max_workspace_size=(1<<32)): + + num_nodes = {} + times = {} + graph_sizes = {} + tf_config = tf.ConfigProto() with tf.Graph().as_default() as tf_graph: with tf.Session(config=tf_config) as tf_sess: users = tf.placeholder(shape=(None,), dtype=tf.int32, name="user_input") items = 
tf.placeholder(shape=(None,), dtype=tf.int32, name="item_input")
             with tf.variable_scope("neumf"):
-                logits = NeuMF(users, items, model_dtype, nb_users, nb_items, mf_dim, mf_reg, mlp_layer_sizes, mlp_layer_regs)
-                if mode == "validation":
-                    found_positive, dcg = compute_eval_metrics(logits, dup_mask, val_batch_size, K)
-                    hit_rate = tf.metrics.mean(found_positive, name='hit_rate')
-                    ndcg = tf.metrics.mean(dcg, name='ndcg')
+                logits = neural_mf(users, items, model_dtype, nb_users, nb_items, mf_dim, mf_reg, mlp_layer_sizes, mlp_layer_regs, 0.1)

             saver = tf.train.Saver()
@@ -93,47 +107,82 @@ def get_frozen_graph(model_checkpoint="/data/marek_ckpt/model.ckpt",
                 if len(node.input) == 2:
                     node.input[0] = node.input[1]
                     del node.input[1]
-    return frozen_graph
-
-
-def optimize_model(frozen_graph,
-                   use_trt=True,
-                   precision_mode="FP16",
-                   batch_size=128):
     if use_trt:
-        trt_graph = trt.create_inference_graph(frozen_graph, ['neumf/dense_3/BiasAdd:0'], max_batch_size=batch_size, precision_mode=precision_mode)
-    return trt_graph
+        start_time = time.time()
+        frozen_graph = trt.create_inference_graph(
+            input_graph_def=frozen_graph,
+            outputs=['neumf/dense_3/BiasAdd:0'],
+            max_batch_size=batch_size,
+            max_workspace_size_bytes=max_workspace_size,
+            precision_mode=precision,
+            minimum_segment_size=minimum_segment_size,
+            is_dynamic_op=use_dynamic_op)
+        times['trt_conversion'] = time.time() - start_time
+        num_nodes['tftrt_total'] = len(frozen_graph.node)
+        num_nodes['trt_only'] = len([1 for n in frozen_graph.node if str(n.op)=='TRTEngineOp'])
+        graph_sizes['trt'] = len(frozen_graph.SerializeToString())
+
+        if precision == 'int8':
+            calib_graph = frozen_graph
+            graph_sizes['calib'] = len(calib_graph.SerializeToString())
+            # INT8 calibration step
+            print('Calibrating INT8...')
+            start_time = time.time()
+            run(calib_graph,
+                data_dir=calib_data_dir,
+                batch_size=batch_size,
+                num_iterations=num_calib_inputs // batch_size,
+                num_warmup_iterations=0,
+                use_synthetic=use_synthetic)
+            times['trt_calibration'] = time.time() - start_time
+            start_time = time.time()
+            frozen_graph = trt.calib_graph_to_infer_graph(calib_graph)
+            times['trt_int8_conversion'] = time.time() - start_time
+            graph_sizes['trt'] = len(frozen_graph.SerializeToString())
+
+            del calib_graph
+            print('INT8 graph created')
+
+    return frozen_graph, num_nodes, times, graph_sizes


 def run(frozen_graph,
-        data_dir='/data/cache/ml-20m',
+        data_dir='/data/ml-20m/',
+        batch_size=1024,
         num_iterations=None,
         num_warmup_iterations=None,
         use_synthetic=False,
         display_every=100,
-        mode='validation',
-        target_duration=None):
+        mode='benchmark',
+        target_duration=None,
+        nb_items=26744,
+        nb_users=138493,
+        dup_mask=0.1,
+        K=10):

     def model_fn(features, labels, mode):
         logits_out = tf.import_graph_def(frozen_graph,
-            input_map={'input': features},
-            return_elements=['logits:0'],
+            input_map={'user_input:0': features["user_input"], 'item_input:0': features["item_input"]},
+            return_elements=['neumf/dense_3/BiasAdd:0'],
             name='')
-        found_possitive, dcg = compute_eval_metrics(logits, dup_mask, val_batch_size, K)
-        hit_rate = tf.metrics.mean(found_positive, name='hit_rate')
-        ndcg = tf.metrics.mean(found_positive, name='ndcg')
         if mode == tf.estimator.ModeKeys.PREDICT:
             return tf.estimator.EstimatorSpec(mode=mode,
-                predictions={'logits': logits_out})
+                predictions={'logits': logits_out[0]})
         if mode == tf.estimator.ModeKeys.EVAL:
+            found_positive, dcg = compute_eval_metrics(logits_out[0], dup_mask, batch_size, K)
+            hit_rate
= tf.metrics.mean(found_positive, name='hit_rate') + ndcg = tf.metrics.mean(dcg, name='ndcg') return tf.estimator.EstimatorSpec( mode=mode, - eval_metrics_ops={'found_positive': found_possitive, 'ndcg': ndcg}) + loss=dcg, + eval_metric_ops={'hit_rate': hit_rate, 'ndcg': ndcg}) def input_fn(): if use_synthetic: items = [random.randint(1, nb_items) for _ in range(batch_size)] users = [random.randint(1, nb_users) for _ in range(batch_size)] - with tf.devices('/device:GPU:0'): + with tf.device('/device:GPU:0'): items = tf.identity(items) users = tf.identity(users) else: @@ -141,26 +190,69 @@ def input_fn(): dataset = pd.read_pickle(data_path) users = dataset["user_id"] items = dataset["item_id"] - - user_dataset = tf.data.Dataset_from_tensor_slices(users) + print(type(users)) + users = users.astype('int32') + items = items.astype('int32') + user_dataset = tf.data.Dataset.from_tensor_slices(users) user_dataset = user_dataset.batch(batch_size) user_dataset = user_dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE) user_dataset = user_dataset.repeat(count=1) user_iterator = user_dataset.make_one_shot_iterator() users = user_iterator.get_next() - item_dataset = tf.data.Dataset_from_tensor_slices(items) + item_dataset = tf.data.Dataset.from_tensor_slices(items) item_dataset = item_dataset.batch(batch_size) item_dataset = item_dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE) item_dataset = item_dataset.repeat(count=1) item_iterator = item_dataset.make_one_shot_iterator() items = item_iterator.get_next() - return items, users + return {"user_input": users, "item_input": items}, [] + + if use_synthetic and num_iterations is None: + num_iterations=1000 + + if use_synthetic: + num_records=num_iterations*batch_size + else: + data_path = os.path.join(data_dir, 'test_ratings.pickle') + dataset = pd.read_pickle(data_path) + users = dataset["user_id"] + num_records = len(users) + + logger = LoggerHook( + display_every=display_every, + batch_size=batch_size, + num_records=num_records) + tf_config = tf.ConfigProto() + estimator = tf.estimator.Estimator( + model_fn=model_fn, + config=tf.estimator.RunConfig(session_config=tf_config), + model_dir='model_dir') + results = {} + + if mode == 'validation': + results = estimator.evaluate(input_fn, steps=num_iterations, hooks=[logger]) + elif mode == 'benchmark': + benchmark_hook = BenchmarkHook(target_duration=target_duration, iteration_limit=num_iterations) + prediction_results = [p for p in estimator.predict(input_fn, predict_keys=["logits"], hooks=[logger, benchmark_hook])] + print(prediction_results) + else: + raise ValueError("Mode must be either 'validation' or 'benchmark'") + + iter_times = np.array(logger.iter_times[num_warmup_iterations:]) + results['total_time'] = np.sum(iter_times) + results['images_per_sec'] = np.mean(batch_size / iter_times) + results['99th_percentile'] = np.percentile(iter_times, q=99, interpolation='lower') * 1000 + results['latency_mean'] = np.mean(iter_times) * 1000 + results['latency_median'] = np.median(iter_times) * 1000 + results['latency_min'] = np.min(iter_times) * 1000 + return results -def input_data(use_synthetic=True, - batch_size=128, - data_dir="/data/cache/ml-20m", + +def input_data(use_synthetic=False, + batch_size=1024, + data_dir='/data/ml-20m/', num_iterations=None, nb_items=26744, nb_users=1388493): @@ -184,10 +276,10 @@ def input_data(use_synthetic=True, def run_inference(frozen_graph, - use_synthetic=True, + use_synthetic=False, mode='benchmark', - batch_size=128, - data_dir=None, + batch_size=1024, + 
data_dir='/data/ml-20m/', num_iterations=10000, num_warmup_iterations=2000, nb_items=26744, @@ -204,6 +296,8 @@ def run_inference(frozen_graph, with tf.Graph().as_default() as tf_graph: with tf.Session(config=tf.ConfigProto()) as tf_sess: tf.import_graph_def(frozen_graph, name='') + for i in tf.get_default_graph().as_graph_def().node: + print(i.name) output = tf_sess.graph.get_tensor_by_name('neumf/dense_3/BiasAdd:0') runtimes = [] res = [] @@ -222,43 +316,105 @@ def run_inference(frozen_graph, print(" step %d/%d, iter_time(ms)=%.4f" % ( len(runtimes), num_iterations, - np.mean(runtimes[(-1)*display_every]) * 1000)) - print("throghput: %.1f" % + np.mean(runtimes[(-1)*display_every:]) * 1000)) + print("throughput: %.1f" % (batch_size * num_iterations/np.sum(runtimes))) if __name__ == '__main__': parser = argparse.ArgumentParser(description='Evaluate model') - parse.add_argument('--use_synthetic', action='store_true', + parser.add_argument('--use_synthetic', action='store_true', + default=False, help='If set, one batch of random data is generated and used at every iteration.') parser.add_argument('--mode', choices=['validation', 'benchmark'], - help='Which mode to use (validation or benchmark)') - parser.add_argument('--data_dir', type=str, default=None, + default='validation', help='Which mode to use (validation or benchmark)') + parser.add_argument('--data_dir', type=str, default='/data/ml-20m/', help='Directory containing validation set csv files.') + parser.add_argument('--calib_data_dir', type=str, + help='Directory containing TFRecord files for calibrating int8.') parser.add_argument('--model_dir', type=str, default=None, help='Directory containing model checkpoint.') parser.add_argument('--use_trt', action='store_true', help='If set, the graph will be converted to a TensorRT graph.') - parser.add_argument('--precision', type=str, choices=['fp32', 'fp16', 'int8'], default='fp32', + parser.add_argument('--use_dynamic_op', action='store_true', + help='If set, TRT conversion will be done using dynamic op instead of statically.') + parser.add_argument('--precision', type=str, + choices=['fp32', 'fp16', 'int8'], default='fp32', help='Precision mode to use. FP16 and INT8 only work in conjunction with --use_trt') parser.add_argument('--nb_items', type=int, default=26744, help='Number of items') - parser.add_argument('--nb_users', type=int, default=1388493, + parser.add_argument('--nb_users', type=int, default=138493, help='Number of users') - parser.add_argument('--batch_size', type=int, default=8, + parser.add_argument('--batch_size', type=int, default=1024, help='Batch size') + parser.add_argument('--minimum_segment_size', type=int, default=2, + help='Minimum number of TF ops in a TRT engine') + parser.add_argument('--num_iterations', type=int, default=None, + help='How many iterations(batches) to evaluate. If not supplied, the whole set will be evaluated.') + parser.add_argument('--num_warmup_iterations', type=int, default=50, + help='Number of initial iterations skipped from timing') + parser.add_argument('--num_calib_inputs', type=int, default=500, + help='Number of inputs (e.g. 
images) used for calibration '
+             '(last batch is skipped in case it is not full)')
+    parser.add_argument('--max_workspace_size', type=int, default=(1<<32),
+        help='workspace size in bytes')
+    parser.add_argument('--display_every', type=int, default=100,
+        help='Number of iterations executed between two consecutive display of metrics')
+    parser.add_argument('--dup_mask', type=float, default=0.1)
+    parser.add_argument('--K', type=int, default=10)
     parser.add_argument('--mf_dim', type=int, default=64)
     parser.add_argument('--mf_reg', type=int, default=64)
     parser.add_argument('--mlp_layer_sizes', default=[256, 256, 128, 64])
     parser.add_argument('--mlp_layer_regs', default=[.0, .0, .0, .0])

     args = parser.parse_args()
-    if not args.use_synthetic and args.data_dir:
+    if not args.use_synthetic and args.data_dir is None:
         raise ValueError("Data_dir is not provided")

-    frozen_graph = get_frozen_graph()
-
-    frozen_graph = optimize_model(frozen_graph)
-
-    run_inference(frozen_graph)
+    frozen_graph, num_nodes, times, graph_sizes = get_frozen_graph(
+        model_checkpoint=args.model_dir,
+        use_trt=args.use_trt,
+        use_dynamic_op=args.use_dynamic_op,
+        precision=args.precision,
+        batch_size=args.batch_size,
+        mf_dim=args.mf_dim,
+        mf_reg=args.mf_reg,
+        mlp_layer_sizes=args.mlp_layer_sizes,
+        mlp_layer_regs=args.mlp_layer_regs,
+        nb_items=args.nb_items,
+        nb_users=args.nb_users,
+        minimum_segment_size=args.minimum_segment_size,
+        calib_data_dir=args.calib_data_dir,
+        num_calib_inputs=args.num_calib_inputs,
+        use_synthetic=args.use_synthetic,
+        max_workspace_size=args.max_workspace_size
+    )
+
+
+    def print_dict(input_dict, str='', scale=None):
+        for k, v in sorted(input_dict.items()):
+            headline = '{}({}): '.format(str, k) if str else '{}: '.format(k)
+            v = v * scale if scale else v
+            print('{}{}'.format(headline, '%.1f'%v if type(v)==float else v))
+
+    print_dict(num_nodes)
+    print_dict(graph_sizes)
+    print_dict(times)
+
+    results = run(frozen_graph,
+        data_dir=args.data_dir,
+        batch_size=args.batch_size,
+        num_iterations=args.num_iterations,
+        num_warmup_iterations=args.num_warmup_iterations,
+        use_synthetic=args.use_synthetic,
+        display_every=args.display_every,
+        mode=args.mode,
+        target_duration=None,
+        nb_items=args.nb_items,
+        nb_users=args.nb_users,
+        dup_mask=args.dup_mask,
+        K=args.K)
+
+
+    print_dict(results)

From b5c0e13600ea2eec2acc8c949bfa41983467689f Mon Sep 17 00:00:00 2001
From: Marek Drozdowski
Date: Wed, 24 Apr 2019 11:19:20 -0700
Subject: [PATCH 54/56] add README.md file

---
 tftrt/examples/ncf/README.md | 61 ++++++++++++++++++++++++++++++++
 1 file changed, 61 insertions(+)
 create mode 100644 tftrt/examples/ncf/README.md

diff --git a/tftrt/examples/ncf/README.md b/tftrt/examples/ncf/README.md
new file mode 100644
index 000000000..17af60da6
--- /dev/null
+++ b/tftrt/examples/ncf/README.md
@@ -0,0 +1,61 @@
+F examples
+
+The example script `inference.py` runs inference with NVIDIA NCF model implementation.
+This script is included in the NVIDIA Tensorflow Docker
+containers under `/workspace/nvidia-examples`.
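For orientation, the graph this script freezes has the classic NeuMF shape: a GMF branch (element-wise product of user and item embeddings) and an MLP branch (concatenated embeddings through a stack of dense layers), fused by a final dense layer that surfaces in the frozen graph as the output node `neumf/dense_3/BiasAdd`. The sketch below is illustrative only; the real implementation is `neumf.neural_mf` from NVIDIA's DeepLearningExamples, and the variable names here are hypothetical (only the default sizes mirror the script's arguments).

```python
# Illustrative TF 1.x sketch of a NeuMF-style graph; not the NVIDIA implementation.
import tensorflow as tf

def neumf_sketch(users, items, nb_users=138493, nb_items=26744,
                 mf_dim=64, mlp_layer_sizes=(256, 256, 128, 64)):
    with tf.variable_scope("neumf"):
        # GMF branch: element-wise product of user and item embeddings.
        mf_user = tf.get_variable("mf_user_embed", [nb_users, mf_dim])
        mf_item = tf.get_variable("mf_item_embed", [nb_items, mf_dim])
        gmf = tf.nn.embedding_lookup(mf_user, users) * tf.nn.embedding_lookup(mf_item, items)

        # MLP branch: concatenated embeddings through dense layers
        # (these become dense, dense_1, dense_2 in the graph).
        mlp_user = tf.get_variable("mlp_user_embed", [nb_users, mlp_layer_sizes[0] // 2])
        mlp_item = tf.get_variable("mlp_item_embed", [nb_items, mlp_layer_sizes[0] // 2])
        x = tf.concat([tf.nn.embedding_lookup(mlp_user, users),
                       tf.nn.embedding_lookup(mlp_item, items)], axis=1)
        for size in mlp_layer_sizes[1:]:
            x = tf.layers.dense(x, size, activation=tf.nn.relu)

        # Fusion layer; its bias-add is the output node 'neumf/dense_3/BiasAdd'.
        return tf.layers.dense(tf.concat([gmf, x], axis=1), 1)
```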
+
+
+## Model
+
+The model that we use is available here:
+`https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow/Recommendation/NCF`
+
+### Setup for running within an NVIDIA Tensorflow Docker container
+
+
+If you are running these examples within the [NVIDIA TensorFlow docker
+container](https://ngc.nvidia.com/catalog/containers/nvidia:tensorflow):
+
+```
+cd ../third_party/models
+export PYTHONPATH="$PYTHONPATH:$PWD"
+```
+
+
+
+### Setup for runnign standalone
+
+If you are running these examples within your own TensorFlow environment,
+perform the following steps:
+
+```
+# Clone this repository (tensorflow/tensorrt) if you haven't already.
+git clone https://github.com/tensorflow/tensorrt.git --recurse-submodules
+
+# Add official models to python path
+cd tensorrt/tftrt/examples/third_party/models/
+export PYTHONPATH="$PYTHONPATH:$PWD"
+```
+## Usage
+
+The main Python script is `inference.py`
+
+```
+python inference.py
+    --data_dir /data/cache/ml-20m/
+    --use_trt
+    --precision FP16
+ ```
+
+ Where:
+
+ `--data_dir`: Path to the ml-20m test dataset
+
+ `--use_trt`: Convert the graph to a TensorRT graph.
+
+ `--precision`: Precision mode to use, in this case FP16.
+
+
+ Run with `--help` to see all available options.
+

From a2fb19aab15a338cfab8d3ac8b6ddf44ac6c8766 Mon Sep 17 00:00:00 2001
From: mdrozdowski1996
Date: Wed, 24 Apr 2019 11:21:31 -0700
Subject: [PATCH 55/56] Update README.md

---
 tftrt/examples/ncf/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tftrt/examples/ncf/README.md b/tftrt/examples/ncf/README.md
index 17af60da6..e523f640b 100644
--- a/tftrt/examples/ncf/README.md
+++ b/tftrt/examples/ncf/README.md
@@ -1,4 +1,4 @@
-F examples
+## NCF examples

 The example script `inference.py` runs inference with NVIDIA NCF model implementation.
 This script is included in the NVIDIA Tensorflow Docker

From b6bb44ad32eaea65bc91de1bff411b902c72a7d7 Mon Sep 17 00:00:00 2001
From: Marek Drozdowski
Date: Wed, 24 Apr 2019 14:00:59 -0700
Subject: [PATCH 56/56] code refactor

---
 .../{ncf => recommendation}/README.md    | 28 +++---
 .../{ncf => recommendation}/inference.py | 89 +++----------------
 2 files changed, 28 insertions(+), 89 deletions(-)
 rename tftrt/examples/{ncf => recommendation}/README.md (58%)
 rename tftrt/examples/{ncf => recommendation}/inference.py (83%)

diff --git a/tftrt/examples/ncf/README.md b/tftrt/examples/recommendation/README.md
similarity index 58%
rename from tftrt/examples/ncf/README.md
rename to tftrt/examples/recommendation/README.md
index e523f640b..7a5d4b60e 100644
--- a/tftrt/examples/ncf/README.md
+++ b/tftrt/examples/recommendation/README.md
@@ -21,9 +21,17 @@ cd ../third_party/models
 export PYTHONPATH="$PYTHONPATH:$PWD"
 ```
+### Prepare dataset
+We use the standard MovieLens (ml-20m) dataset, which is available here:
+`https://grouplens.org/datasets/movielens/`
+
+To use it with our script you first need to prepare it (we require a csv file).
+You can do that using the script here:
+`tensorrt/tftrt/examples/third_party/DeepLearningExamples/TensorFlow/Recommendation/NCF/prepare_dataset.sh`
+
+You need to provide the path where you downloaded the ml-20m dataset.
+
-### Setup for runnign standalone
+### Setup for running standalone

 If you are running these examples within your own TensorFlow environment,
 perform the following steps:

@@ -38,24 +46,24 @@ export PYTHONPATH="$PYTHONPATH:$PWD"
 ```
 ## Usage

-The main Python script is `inference.py`
+The main Python script is `inference.py`. Here is an example of usage:

 ```
 python inference.py
     --data_dir /data/cache/ml-20m/
-    --use_trt
-    --precision FP16
- ```
+    --use_trt
+    --precision FP16
+```

 Where:

- `--data_dir`: Path to the ml-20m test dataset
+
+`--data_dir`: Path to the ml-20m test dataset

- `--use_trt`: Convert the graph to a TensorRT graph.
+`--use_trt`: Convert the graph to a TensorRT graph.

- `--precision`: Precision mode to use, in this case FP16.
+`--precision`: Precision mode to use, in this case FP16.

- Run with `--help` to see all available options.
+Run with `--help` to see all available options.

diff --git a/tftrt/examples/ncf/inference.py b/tftrt/examples/recommendation/inference.py
similarity index 83%
rename from tftrt/examples/ncf/inference.py
rename to tftrt/examples/recommendation/inference.py
index 0e44229a9..7ac0163e1 100644
--- a/tftrt/examples/ncf/inference.py
+++ b/tftrt/examples/recommendation/inference.py
@@ -60,7 +60,7 @@ def after_run(self, run_context, run_values):
         run_context.request_stop()


-def get_frozen_graph(model_checkpoint="/data/marek_ckpt/model.ckpt",
+def get_frozen_graph(model_checkpoint=None,
                      mode="benchmark",
                      use_trt=True,
                      batch_size=1024,
@@ -94,7 +94,7 @@ def get_frozen_graph(model_checkpoint="/data/marek_ckpt/model.ckpt",
                 logits = neural_mf(users, items, model_dtype, nb_users, nb_items, mf_dim, mf_reg, mlp_layer_sizes, mlp_layer_regs, 0.1)

             saver = tf.train.Saver()
-            saver.restore(tf_sess, "/data/marek_ckpt/model.ckpt")
+            saver.restore(tf_sess, model_checkpoint)
             graph0 = tf.graph_util.convert_variables_to_constants(tf_sess,
                 tf_sess.graph_def, output_node_names=['neumf/dense_3/BiasAdd'])
             frozen_graph = tf.graph_util.remove_training_nodes(graph0)
@@ -123,7 +123,7 @@ def get_frozen_graph(model_checkpoint="/data/marek_ckpt/model.ckpt",
         num_nodes['trt_only'] = len([1 for n in frozen_graph.node if str(n.op)=='TRTEngineOp'])
         graph_sizes['trt'] = len(frozen_graph.SerializeToString())

-        if precision == 'int8':
+        if precision == 'INT8':
             calib_graph = frozen_graph
             graph_sizes['calib'] = len(calib_graph.SerializeToString())
             # INT8 calibration step
@@ -148,7 +148,7 @@ def run(frozen_graph,
-        data_dir='/data/ml-20m/',
+        data_dir=None,
         batch_size=1024,
         num_iterations=None,
         num_warmup_iterations=None,
@@ -190,7 +190,7 @@ def input_fn():
             dataset = pd.read_pickle(data_path)
             users = dataset["user_id"]
             items = dataset["item_id"]
-            print(type(users))
+
             users = users.astype('int32')
             items = items.astype('int32')
             user_dataset = tf.data.Dataset.from_tensor_slices(users)
@@ -218,6 +218,8 @@ def input_fn():
         dataset = pd.read_pickle(data_path)
         users = dataset["user_id"]
         num_records = len(users)
+        if num_iterations is None:
+            num_iterations = num_records // batch_size

     logger = LoggerHook(
         display_every=display_every,
         batch_size=batch_size,
         num_records=num_records)
@@ -247,78 +249,7 @@ def input_fn():
     results['latency_median'] = np.median(iter_times) * 1000
     results['latency_min'] = np.min(iter_times) * 1000

-    return results
-
-
-def input_data(use_synthetic=False,
-               batch_size=1024,
-               data_dir='/data/ml-20m/',
-               num_iterations=None,
-               nb_items=26744,
-               nb_users=1388493):
-
-    if use_synthetic and num_iterations is None:
-        num_iterations = 10000
-
-    if use_synthetic:
-        items = [[random.randint(1, nb_items) for _ in range(batch_size)] for _ in range(num_iterations)]
-        users = [[random.randint(1, nb_users) for _ in range(batch_size)] for _ in range(num_iterations)]
-    else:
-        if os.path.exists(data_dir):
-            print("Using cached dataset: %s" % (data_dir))
-        else:
-            data_path =
os.path.join(data_dir, 'test_ratings.pickle') - dataset = pd.read_pickle(data_path) - users = dataset["user_id"] - items = dataset["item_id"] - - return items, users - - -def run_inference(frozen_graph, - use_synthetic=False, - mode='benchmark', - batch_size=1024, - data_dir='/data/ml-20m/', - num_iterations=10000, - num_warmup_iterations=2000, - nb_items=26744, - nb_users=1388493, - display_every=100): - - items, users = input_data(use_synthetic, - batch_size, - data_dir, - num_iterations, - nb_items, - nb_users) - - with tf.Graph().as_default() as tf_graph: - with tf.Session(config=tf.ConfigProto()) as tf_sess: - tf.import_graph_def(frozen_graph, name='') - for i in tf.get_default_graph().as_graph_def().node: - print(i.name) - output = tf_sess.graph.get_tensor_by_name('neumf/dense_3/BiasAdd:0') - runtimes = [] - res = [] - - for n in range(num_iterations): - item = items[n] - user = users[n] - - beg = time.time() - r = tf_sess.run(output, feed_dict={'item_input:0': item, 'user_input:0': user}) - end = time.time() - - res.append(r) - runtimes.append(end-beg) - if n % display_every == 0: - print(" step %d/%d, iter_time(ms)=%.4f" % ( - len(runtimes), - num_iterations, - np.mean(runtimes[(-1)*display_every:]) * 1000)) - print("throughput: %.1f" % - (batch_size * num_iterations/np.sum(runtimes))) + return results if __name__ == '__main__': @@ -328,7 +259,7 @@ def run_inference(frozen_graph, help='If set, one batch of random data is generated and used at every iteration.') parser.add_argument('--mode', choices=['validation', 'benchmark'], default='validation', help='Which mode to use (validation or benchmark)') - parser.add_argument('--data_dir', type=str, default='/data/ml-20m/', + parser.add_argument('--data_dir', type=str, default=None, help='Directory containing validation set csv files.') parser.add_argument('--calib_data_dir', type=str, help='Directory containing TFRecord files for calibrating int8.') @@ -339,7 +270,7 @@ def run_inference(frozen_graph, parser.add_argument('--use_dynamic_op', action='store_true', help='If set, TRT conversion will be done using dynamic op instead of statically.') parser.add_argument('--precision', type=str, - choices=['fp32', 'fp16', 'int8'], default='fp32', + choices=['FP32', 'FP16', 'INT8'], default='FP32', help='Precision mode to use. FP16 and INT8 only work in conjunction with --use_trt') parser.add_argument('--nb_items', type=int, default=26744, help='Number of items')
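Taken together, these patches converge on a single conversion flow. The sketch below is a condensed, hedged restatement of the calls already present in `inference.py` above (the TF 1.x `tensorflow.contrib.tensorrt` API); the wrapper function and its defaults are illustrative, not part of any patch in this series:

```python
# Condensed sketch of the TF-TRT 1.x conversion flow used throughout this series.
import tensorflow.contrib.tensorrt as trt

def convert_sketch(frozen_graph, precision='FP32', batch_size=1024,
                   minimum_segment_size=2, max_workspace_size=(1 << 32),
                   use_dynamic_op=False):
    # Replace TensorRT-compatible subgraphs with TRTEngineOp nodes.
    trt_graph = trt.create_inference_graph(
        input_graph_def=frozen_graph,
        outputs=['neumf/dense_3/BiasAdd:0'],
        max_batch_size=batch_size,
        max_workspace_size_bytes=max_workspace_size,
        precision_mode=precision,
        minimum_segment_size=minimum_segment_size,
        is_dynamic_op=use_dynamic_op)
    # For INT8, this graph is first a calibration graph: run representative
    # batches through it, then finalize with trt.calib_graph_to_infer_graph().
    return trt_graph
```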