5
5
import os
6
6
7
7
import abc
8
+ import contextlib
8
9
import copy
9
10
import json
10
11
import logging
@@ -390,7 +391,12 @@ def log_step(step_idx, display_every, iter_time, memcpyHtoD_time, dequeue_time):
390
391
)
391
392
392
393
@force_gpu_resync
@tf.function(jit_compile=self._args.use_xla)
def dequeue_batch(ds_iter):
    """Fetch and return the next element produced by `ds_iter`.

    The call is wrapped in `tf.function` (XLA-compiled when
    `self._args.use_xla` is set) and in `force_gpu_resync`, so the caller
    can time the dequeue as one unit.
    """
    # NOTE(review): whether an iterator "get next" can be XLA-compiled
    # depends on the TF build -- confirm this works with use_xla=True.
    next_batch = next(ds_iter)
    return next_batch
397
+
398
+ @force_gpu_resync
399
+ @tf .function (jit_compile = self ._args .use_xla )
394
400
def force_data_on_gpu (data , device = "/gpu:0" ):
395
401
with tf .device (device ):
396
402
if isinstance (data , (list , tuple )):
@@ -403,58 +409,77 @@ def force_data_on_gpu(data, device="/gpu:0"):
403
409
output_data [k ] = tf .identity (v )
404
410
else :
405
411
output_data = tf .identity (data )
412
+
406
413
return output_data
407
414
415
# Select the profiling helpers once, so the loop below is written the same
# way whether or not a profile export path was requested.
if not self._args.tf_profile_export_path:
    # Profiling disabled: both helpers degrade to no-op context managers.
    profiling_ctx = contextlib.nullcontext()

    def tracing_ctx(*a, **kw):
        # Accepts (and ignores) the same arguments the real Trace is given.
        return contextlib.nullcontext()
else:
    # Profiling enabled: one Profile session built from the export path,
    # plus the Trace class used to annotate each phase of a step.
    profiling_ctx = tf.profiler.experimental.Profile(
        self._args.tf_profile_export_path
    )
    tracing_ctx = tf.profiler.experimental.Trace
423
+
408
424
# Main benchmark loop. Every step dequeues one batch, copies it to the GPU,
# preprocesses it and runs inference; the wall-clock duration of the dequeue,
# copy and inference phases is appended to `dequeue_times`, `memcopy_times`
# and `iter_times` respectively.
#
# `step_idx` is 1-based inside the loop and, once the loop exits, equals the
# number of fully completed steps.
step_idx = 0
ds_iter = iter(dataset)

with profiling_ctx:

    while True:

        step_idx += 1

        with tracing_ctx('Inference Step', step_num=step_idx, _r=1):

            with tracing_ctx('Input Dequeueing', step_num=step_idx, _r=1):
                try:
                    start_time = time.time()
                    data_batch = dequeue_batch(ds_iter)
                    dequeue_times.append(time.time() - start_time)
                except (StopIteration, tf.errors.OutOfRangeError):
                    # Only exhaustion of the input ends the run quietly;
                    # any other failure (or Ctrl-C) must propagate instead
                    # of being reported as "end of dataset".
                    print("[Exiting] Reached end of dataset ...")
                    # This step never ran: undo the increment so the
                    # counter matches the number of completed steps.
                    step_idx -= 1
                    break

            with tracing_ctx('Inputs MemcpyHtoD', step_num=step_idx, _r=1):
                start_time = time.time()
                data_batch = force_data_on_gpu(data_batch)
                memcopy_times.append(time.time() - start_time)

            with tracing_ctx('Inputs Preprocessing', step_num=step_idx, _r=1):
                x, y = self.preprocess_model_inputs(data_batch)

            with tracing_ctx('GPU Inference', step_num=step_idx, _r=1):
                start_time = time.time()
                y_pred = infer_batch(x)
                iter_times.append(time.time() - start_time)

            if not self._args.debug_performance:
                # Means over the last `display_every` steps, converted
                # from seconds to milliseconds.
                log_step(
                    step_idx,
                    display_every=self._args.display_every,
                    iter_time=np.mean(iter_times[-self._args.display_every:]) * 1000,
                    memcpyHtoD_time=np.mean(memcopy_times[-self._args.display_every:]) * 1000,
                    dequeue_time=np.mean(dequeue_times[-self._args.display_every:]) * 1000
                )
            else:
                # Debug mode prints the raw timings (seconds) of this step.
                print(f"{'GPU Iteration Time':18s} : {iter_times[-1]:08.4f} s")
                # The list is `memcopy_times`; `memcpyHtoD_time` is only the
                # keyword name used by log_step and is not defined here.
                print(f"{'Data MemCopyHtoD Time':18s} : {memcopy_times[-1]:08.4f} s")
                print(f"{'Data Dequeue Time':18s} : {dequeue_times[-1]:08.4f} s")

            if not self._args.use_synthetic_data:
                data_aggregator.aggregate_data(y_pred, y)

        # Check the iteration budget only after the step has completed, so
        # that exactly `num_iterations` steps are executed.
        if (self._args.num_iterations is not None and
                step_idx >= self._args.num_iterations):
            break
451
476
452
477
if (
453
478
not self ._args .debug_performance and
454
479
step_idx % self ._args .display_every != 0
455
480
): # avoids double printing
456
481
log_step (
457
- step_idx + 1 ,
482
+ step_idx ,
458
483
display_every = 1 , # force print
459
484
iter_time = np .mean (iter_times [- self ._args .display_every :]) * 1000 ,
460
485
memcpyHtoD_time = np .mean (memcopy_times [- self ._args .display_every :]) * 1000 ,
0 commit comments