This repository was archived by the owner on Feb 3, 2025. It is now read-only.

Commit eb97b19: TF Profiler Instrumentation
Author: DEKHTIARJonathan
1 parent: 263043b

2 files changed: 71 additions, 38 deletions


tftrt/examples/benchmark_args.py (8 additions, 0 deletions)

@@ -245,6 +245,14 @@ def __init__(self):
             "to the set location in JSON format for further processing."
         )
 
+        self._parser.add_argument(
+            "--tf_profile_export_path",
+            type=str,
+            default=None,
+            help="If set, the script will export tf.profile files for further "
+            "performance analysis."
+        )
+
         self._add_bool_argument(
             name="debug",
             default=False,
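
The new --tf_profile_export_path value is consumed below in benchmark_runner.py as a TensorBoard log directory. As a minimal standalone sketch of the API the flag feeds, assuming TF 2.x (the path and workload here are illustrative, not part of this commit):

import tensorflow as tf

logdir = "/tmp/tf_profile"  # stands in for args.tf_profile_export_path

# tf.profiler.experimental.Profile records everything executed inside the
# context manager and writes a TensorBoard-compatible trace to logdir.
with tf.profiler.experimental.Profile(logdir):
    x = tf.random.uniform((1024, 1024))
    for _ in range(10):
        x = tf.matmul(x, x)  # work captured by the profiler

The exported files can then be inspected in TensorBoard's Profile plugin pointed at the same directory.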

tftrt/examples/benchmark_runner.py (63 additions, 38 deletions)

@@ -5,6 +5,7 @@
 import os
 
 import abc
+import contextlib
 import copy
 import json
 import logging
@@ -390,7 +391,12 @@ def log_step(step_idx, display_every, iter_time, memcpyHtoD_time, dequeue_time):
         )
 
         @force_gpu_resync
-        @tf.function()
+        @tf.function(jit_compile=self._args.use_xla)
+        def dequeue_batch(ds_iter):
+            return next(ds_iter)
+
+        @force_gpu_resync
+        @tf.function(jit_compile=self._args.use_xla)
         def force_data_on_gpu(data, device="/gpu:0"):
             with tf.device(device):
                 if isinstance(data, (list, tuple)):
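
Both helpers above are now traced with tf.function(jit_compile=self._args.use_xla), so XLA compilation is requested only when the benchmark runs with --use_xla. A toy illustration of the jit_compile flag (TF 2.5+; the function below is hypothetical, not the repository's code):

import tensorflow as tf

@tf.function(jit_compile=True)  # request XLA compilation of this function
def scale(x):
    return x * 2.0

print(scale(tf.constant([1.0, 2.0])))  # [2. 4.]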
@@ -403,58 +409,77 @@ def force_data_on_gpu(data, device="/gpu:0"):
                     output_data[k] = tf.identity(v)
             else:
                 output_data = tf.identity(data)
+
             return output_data
 
+        if self._args.tf_profile_export_path:
+            profiling_ctx = tf.profiler.experimental.Profile(
+                self._args.tf_profile_export_path
+            )
+            tracing_ctx = tf.profiler.experimental.Trace
+        else:
+            profiling_ctx = contextlib.nullcontext()
+            tracing_ctx = lambda *a, **kw: contextlib.nullcontext()
+
         step_idx = 0
         ds_iter = iter(dataset)
 
-        while True:
+        with profiling_ctx:
 
-            try:
-                start_time = time.time()
-                data_batch = next(ds_iter)
-                dequeue_times.append(time.time() - start_time)
-            except:
-                break
-
-            start_time = time.time()
-            data_batch = force_data_on_gpu(data_batch)
-            memcopy_times.append(time.time() - start_time)
-
-            x, y = self.preprocess_model_inputs(data_batch)
-
-            start_time = time.time()
-            y_pred = infer_batch(x)
-            iter_times.append(time.time() - start_time)
-
-            if not self._args.debug_performance:
-                log_step(
-                    step_idx + 1,
-                    display_every=self._args.display_every,
-                    iter_time=np.mean(iter_times[-self._args.display_every:]) * 1000,
-                    memcpyHtoD_time=np.mean(memcopy_times[-self._args.display_every:]) * 1000,
-                    dequeue_time=np.mean(dequeue_times[-self._args.display_every:]) * 1000
-                )
-            else:
-                print(f"{'GPU Iteration Time':18s}: {iter_times[-1]:08.4f}s")
-                print(f"{'Data MemCopyHtoD Time':18s}: {memcpyHtoD_time[-1]:08.4f}s")
-                print(f"{'Data Dequeue Time':18s}: {dequeue_times[-1]:08.4f}s")
+            while True:
 
-            if not self._args.use_synthetic_data:
-                data_aggregator.aggregate_data(y_pred, y)
+                step_idx += 1
 
-            if (self._args.num_iterations is not None and
-                    step_idx + 1 >= self._args.num_iterations):
-                break
+                if (self._args.num_iterations is not None and
+                        step_idx >= self._args.num_iterations):
+                    break
+
+                with tracing_ctx('Inference Step', step_num=step_idx, _r=1):
+
+                    with tracing_ctx('Input Dequeueing', step_num=step_idx, _r=1):
+                        try:
+                            start_time = time.time()
+                            data_batch = dequeue_batch(ds_iter)
+                            dequeue_times.append(time.time() - start_time)
+                        except:
+                            print("[Exiting] Reached end of dataset ...")
+                            break
+
+                    with tracing_ctx('Inputs MemcpyHtoD', step_num=step_idx, _r=1):
+                        start_time = time.time()
+                        data_batch = force_data_on_gpu(data_batch)
+                        memcopy_times.append(time.time() - start_time)
+
+                    with tracing_ctx('Inputs Preprocessing', step_num=step_idx, _r=1):
+                        x, y = self.preprocess_model_inputs(data_batch)
+
+                    with tracing_ctx('GPU Inference', step_num=step_idx, _r=1):
+                        start_time = time.time()
+                        y_pred = infer_batch(x)
+                        iter_times.append(time.time() - start_time)
+
+                    if not self._args.debug_performance:
+                        log_step(
+                            step_idx,
+                            display_every=self._args.display_every,
+                            iter_time=np.mean(iter_times[-self._args.display_every:]) * 1000,
+                            memcpyHtoD_time=np.mean(memcopy_times[-self._args.display_every:]) * 1000,
+                            dequeue_time=np.mean(dequeue_times[-self._args.display_every:]) * 1000
+                        )
+                    else:
+                        print(f"{'GPU Iteration Time':18s}: {iter_times[-1]:08.4f}s")
+                        print(f"{'Data MemCopyHtoD Time':18s}: {memcopy_times[-1]:08.4f}s")
+                        print(f"{'Data Dequeue Time':18s}: {dequeue_times[-1]:08.4f}s")
 
-            step_idx += 1
+                    if not self._args.use_synthetic_data:
+                        data_aggregator.aggregate_data(y_pred, y)
 
         if (
             not self._args.debug_performance and
             step_idx % self._args.display_every != 0
         ): # avoids double printing
             log_step(
-                step_idx + 1,
+                step_idx,
                 display_every=1, # force print
                 iter_time=np.mean(iter_times[-self._args.display_every:]) * 1000,
                 memcpyHtoD_time=np.mean(memcopy_times[-self._args.display_every:]) * 1000,
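
The hunk above makes profiling strictly optional: when --tf_profile_export_path is unset, profiling_ctx and tracing_ctx degrade to no-op context managers, so the measurement loop is written only once. A self-contained sketch of the same pattern, assuming TF 2.x (toy loop and path, not the repository's code):

import contextlib

import tensorflow as tf

export_path = None  # set to e.g. "/tmp/tf_profile" to enable profiling

if export_path:
    profiling_ctx = tf.profiler.experimental.Profile(export_path)
    tracing_ctx = tf.profiler.experimental.Trace
else:
    # No-op stand-ins keep the loop body identical either way.
    profiling_ctx = contextlib.nullcontext()
    tracing_ctx = lambda *a, **kw: contextlib.nullcontext()

with profiling_ctx:
    for step in range(1, 4):
        # Per the TF profiler guide, _r=1 marks the trace as a root event
        # so TensorBoard can group per-step timelines keyed on step_num.
        with tracing_ctx("Inference Step", step_num=step, _r=1):
            tf.matmul(tf.eye(128), tf.eye(128))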
