diff --git a/torchrec/distributed/benchmark/benchmark_train_pipeline.py b/torchrec/distributed/benchmark/benchmark_train_pipeline.py
index eabc1f0d9..9b3751661 100644
--- a/torchrec/distributed/benchmark/benchmark_train_pipeline.py
+++ b/torchrec/distributed/benchmark/benchmark_train_pipeline.py
@@ -40,7 +40,11 @@
     TestTowerCollectionSparseNNConfig,
     TestTowerSparseNNConfig,
 )
-from torchrec.distributed.benchmark.benchmark_utils import benchmark_func, cmd_conf
+from torchrec.distributed.benchmark.benchmark_utils import (
+    benchmark_func,
+    benchmark_operators,
+    cmd_conf,
+)
 from torchrec.distributed.comm import get_local_size
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.planner import Topology
@@ -110,6 +114,9 @@ class RunOptions:
     sparse_lr: float = 0.1
     sparse_momentum: Optional[float] = None
     sparse_weight_decay: Optional[float] = None
+    benchmark_operators: bool = False
+    target_operators: Optional[List[str]] = None
+    limit_operator_results: int = 10
 
 
 @dataclass
@@ -379,10 +386,11 @@ def _func_to_benchmark(
         if jit_suffix
         else type(pipeline).__name__
     )
+
     result = benchmark_func(
         name=name,
-        bench_inputs=bench_inputs,  # pyre-ignore
-        prof_inputs=bench_inputs,  # pyre-ignore
+        bench_inputs=bench_inputs,  # pyre-ignore[6]
+        prof_inputs=bench_inputs,  # pyre-ignore[6]
         num_benchmarks=5,
         num_profiles=2,
         profile_dir=run_option.profile,
@@ -393,6 +401,19 @@
     )
     results.append(result)
 
+    if run_option.benchmark_operators:
+        op_results = benchmark_operators(
+            func_to_benchmark=pipeline,
+            bench_inputs=bench_inputs,
+            num_benchmarks=5,
+            device_type="cuda",
+            target_operators=run_option.target_operators,
+            is_pipeline=True,
+            rank=rank,
+            limit_results=run_option.limit_operator_results,
+        )
+        results.extend(op_results)
+
     if rank == 0:
         for result in results:
             print(result)
diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py
index 00da22230..72672c6b9 100644
--- a/torchrec/distributed/benchmark/benchmark_utils.py
+++ b/torchrec/distributed/benchmark/benchmark_utils.py
@@ -905,6 +905,83 @@ def trace_handler(prof) -> None:
     )
 
 
+def benchmark_operators(
+    func_to_benchmark: Any,  # pyre-ignore[2]
+    bench_inputs: List[Any],  # pyre-ignore[2]
+    num_benchmarks: int,
+    device_type: str = "cuda",
+    target_operators: Optional[List[str]] = None,
+    is_pipeline: bool = False,
+    rank: int = -1,
+    limit_results: int = 10,
+) -> List[BenchmarkResult]:
+    activities = [torch.profiler.ProfilerActivity.CPU]
+    if device_type == "cuda":
+        activities.append(torch.profiler.ProfilerActivity.CUDA)
+
+    results = []
+    elapsed_times = {}
+    peak_memory_usage = {}
+
+    for _ in range(num_benchmarks):
+        with torch.profiler.profile(
+            activities=activities,
+            record_shapes=True,
+            profile_memory=True,
+            with_stack=True,
+            with_flops=True,
+            with_modules=True,
+        ) as prof:
+            if is_pipeline:
+                dataloader = iter(bench_inputs)
+                while True:
+                    try:
+                        func_to_benchmark.progress(dataloader)
+                    except StopIteration:
+                        break
+            else:
+                for bench_input in bench_inputs:
+                    func_to_benchmark(bench_input)
+
+        for evt in prof.key_averages():
+            if evt.key not in elapsed_times:
+                elapsed_times[evt.key] = []
+                peak_memory_usage[evt.key] = 0
+
+            elapsed_times[evt.key].append(evt.self_device_time_total / 1e3)
+            peak_memory_usage[evt.key] = max(
+                peak_memory_usage[evt.key], evt.self_device_memory_usage
+            )
+
+    for op in elapsed_times:
+        if target_operators is not None and op not in target_operators:
+            continue
+
+        mem_stats = [
+            MemoryStats(
+                rank=rank,
+                malloc_retries=-1,  # Not supported in profiler
+                max_mem_allocated_mbs=peak_memory_usage[op] / 1024 / 1024,
+                max_mem_reserved_mbs=-1,  # Not supported in profiler
+            )
+        ]
+
+        results.append(
+            BenchmarkResult(
+                short_name=f"operator_{op}",
+                elapsed_time=torch.tensor(elapsed_times[op], dtype=torch.float),
+                mem_stats=mem_stats,
+                rank=rank,
+            )
+        )
+
+    sorted_results = sorted(
+        results, key=lambda x: x.runtime_percentile(90), reverse=True
+    )
+
+    return sorted_results[:limit_results]
+
+
 def benchmark_type_name(compile_mode: CompileMode, sharding_type: ShardingType) -> str:
     if sharding_type == ShardingType.TABLE_WISE:
         name = "tw-sharded"
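
For reference, a minimal standalone sketch (not part of the diff) of how the new benchmark_operators helper could be driven directly, outside the train pipeline path. The model and inputs below are hypothetical stand-ins; only the benchmark_operators signature and its keyword arguments are taken from the change above, and the snippet assumes a CUDA device is available.

import torch

from torchrec.distributed.benchmark.benchmark_utils import benchmark_operators

# Hypothetical workload: any callable that accepts a single batch works here,
# since is_pipeline=False makes benchmark_operators invoke it once per input.
model = torch.nn.Linear(128, 64).cuda()
bench_inputs = [torch.randn(1024, 128, device="cuda") for _ in range(10)]

op_results = benchmark_operators(
    func_to_benchmark=model,
    bench_inputs=bench_inputs,
    num_benchmarks=5,
    device_type="cuda",
    target_operators=None,  # keep every profiled operator
    is_pipeline=False,
    rank=0,
    limit_results=10,
)

# Results come back sorted by p90 runtime and truncated to limit_results entries.
for result in op_results:
    print(result)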