Skip to content
This repository was archived by the owner on Feb 3, 2025. It is now read-only.

Commit 27ef7b8

Browse files
author
DEKHTIARJonathan
committed
Metric Export to CSV added
1 parent 73b6db5 commit 27ef7b8

File tree

3 files changed

+73
-21
lines changed

3 files changed

+73
-21
lines changed

tftrt/examples/benchmark_args.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,14 @@ def __init__(self):
256256
"to the set location in JSON format for further processing."
257257
)
258258

259+
self._parser.add_argument(
260+
"--export_metrics_csv_path",
261+
type=str,
262+
default=None,
263+
help="If set, the script will export runtime metrics and arguments "
264+
"to the set location in CSV format for further processing."
265+
)
266+
259267
self._parser.add_argument(
260268
"--tf_profile_export_path",
261269
type=str,

tftrt/examples/benchmark_runner.py

Lines changed: 64 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import abc
88
import contextlib
99
import copy
10+
import csv
1011
import json
1112
import logging
1213
import sys
@@ -125,28 +126,70 @@ def _debug_print(self, msg):
125126

126127
def _export_runtime_metrics_to_json(self, metric_dict):
127128

128-
metric_dict = {
129-
# Creating a copy to avoid modifying the original
130-
"results": copy.deepcopy(metric_dict),
131-
"runtime_arguments": vars(self._args)
132-
}
129+
try:
133130

134-
json_path = self._args.export_metrics_json_path
135-
if json_path is not None:
136-
try:
137-
with open(json_path, 'w') as json_f:
138-
json_string = json.dumps(
139-
metric_dict,
140-
default=lambda o: o.__dict__,
141-
sort_keys=True,
142-
indent=4
143-
)
144-
print(json_string, file=json_f)
145-
except Exception as e:
146-
print(
147-
"[ERROR] Impossible to save JSON File at path: "
148-
f"{json_path}.\nError: {str(e)}"
131+
file_path = self._args.export_metrics_json_path
132+
if file_path is None:
133+
return
134+
135+
metric_dict = {
136+
# Creating a copy to avoid modifying the original
137+
"results": copy.deepcopy(metric_dict),
138+
"runtime_arguments": vars(self._args)
139+
}
140+
141+
with open(file_path, 'w') as json_f:
142+
json_string = json.dumps(
143+
metric_dict,
144+
default=lambda o: o.__dict__,
145+
sort_keys=True,
146+
indent=4
149147
)
148+
print(json_string, file=json_f)
149+
150+
except Exception as e:
151+
print(f"An exception occured during export to JSON: {e}")
152+
153+
def _export_runtime_metrics_to_csv(self, metric_dict):
154+
155+
try:
156+
157+
file_path = self._args.export_metrics_csv_path
158+
if file_path is None:
159+
return
160+
161+
data = {f"metric_{k}": v for k, v in metric_dict.items()}
162+
163+
args_to_save = [
164+
"batch_size",
165+
"input_saved_model_dir",
166+
"minimum_segment_size",
167+
"no_tf32",
168+
"precision",
169+
"use_dynamic_shape",
170+
"use_synthetic_data",
171+
"use_tftrt",
172+
"use_xla",
173+
"use_xla_auto_jit"
174+
]
175+
-
176+
runtime_arguments = vars(self._args)
177+
for key in args_to_save:
178+
data[f"arg_{key}"] = str(runtime_arguments[key]).split("/")[-1]
179+
180+
fieldnames = sorted(data.keys())
181+
182+
if not os.path.isfile(file_path):
183+
with open(file_path, 'w') as outcsv:
184+
writer = csv.DictWriter(outcsv, fieldnames=fieldnames, delimiter=',')
185+
writer.writeheader()
186+
187+
with open(file_path, 'a') as outcsv:
188+
writer = csv.DictWriter(outcsv, fieldnames=fieldnames, delimiter=',')
189+
writer.writerow(data)
190+
191+
except Exception as e:
192+
print(f"An exception occured during export to CSV: {e}")
150193

151194
def _get_graph_func(self):
152195
"""Retreives a frozen SavedModel and applies TF-TRT
@@ -524,6 +567,7 @@ def timing_metrics(time_arr, log_prefix):
524567
metrics.update(timing_metrics(memcopy_times, "Data MemCopyHtoD Time"))
525568

526569
self._export_runtime_metrics_to_json(metrics)
570+
self._export_runtime_metrics_to_csv(metrics)
527571

528572
def log_value(key, val):
529573
if isinstance(val, int):

tftrt/examples/benchmark_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def timed_section(msg, activate=True, start_end_mode=True):
4949
total_time = time.time() - start_time
5050

5151
if start_end_mode:
52-
print(f"[END] `{msg}` - Duration: {total_time:.1f}s")
52+
print(f"[END] {msg} - Duration: {total_time:.1f}s")
5353
print("=" * 80, "\n")
5454
else:
5555
print(f"{msg:18s}: {total_time:.4f}s")

0 commit comments

Comments
 (0)