# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import fire
import pandas as pd
import torch
from utils import do_benchmarks, get_name_to_moe_shapes_iter

from torchao.prototype.moe_training.utils import generate_jagged_offs
from torchao.testing.training.roofline_utils import get_specs


@torch.inference_mode()
def run(
    n_limit: Optional[int] = None,
    out_filename: Optional[str] = None,
    M: Optional[int] = None,
    K: Optional[int] = None,
    N: Optional[int] = None,
    E: Optional[int] = None,  # dim 0 of B tensor (num experts)
    use_gpu_kernel_time: bool = True,
    shape_gen_name: str = "llama4_17bx16e",
    recipe: str = "rowwise",
):
    """Benchmark bf16 ``torch._grouped_mm`` vs fp8 ``torch._scaled_grouped_mm``
    over a set of MoE shapes.

    Args:
        n_limit: stop after this many shapes (None = all).
        out_filename: optional CSV path for the results table.
        M, K, N, E: optional shape overrides passed to the shape iterator.
        use_gpu_kernel_time: measure GPU kernel time rather than wall time.
        shape_gen_name: named shape set understood by
            ``get_name_to_moe_shapes_iter``.
        recipe: fp8 scaling recipe; only "rowwise" is supported.
    """
    device = "cuda"

    assert recipe in ("rowwise",), "unsupported"

    specs = get_specs()
    bf16_peak_tops = specs["bf16_peak_tops"]
    fp8_peak_tops = specs["fp8_peak_tops"]
    print(f"gpu_name: {torch.cuda.get_device_name(0)}")
    print(f"peak tops: bf16 {bf16_peak_tops:.2e}, fp8 {fp8_peak_tops:.2e}")
    # BUGFIX: header names now match the row layout appended below
    # (bf16 time, fp8 time, speedup). The previous names
    # ("time_s", "speedup", "fp8_speedup") were shifted by one column, so the
    # bf16 time was labeled "time_s" and the fp8 time was labeled "speedup".
    headers = (
        "name",
        "recipe",
        "M",
        "K",
        "N",
        "E",
        "bf16_time_s",
        "fp8_time_s",
        "fp8_speedup",
    )
    results = []

    dtype = torch.bfloat16
    name_to_shapes = get_name_to_moe_shapes_iter(shape_gen_name, M, K, N, E)

    for idx, (name, (M, K, N, E)) in enumerate(name_to_shapes):
        if n_limit is not None and idx >= n_limit:
            break
        assert M % E == 0, (
            "tokens (M) must be evenly divisible by num experts (E) for this benchmark"
        )
        # 2*M*N*K per expert-group worth of tokens, summed over E experts.
        tops = 2 * M * N * K * E
        print("M, K, N, E:", M, K, N, E, f"tops: {tops:.2E}")

        # Run bf16 torch._grouped_mm baseline.
        A = torch.randn(M, K, device=device, dtype=dtype)
        B = torch.randn(E, N, K, device=device, dtype=dtype)
        offs = generate_jagged_offs(E, M)
        print(f"offs: {offs}")
        ref_time_sec, ref_tops_sec, ref_pct_top_peak = do_benchmarks(
            tops,
            bf16_peak_tops,
            use_gpu_kernel_time,
            torch._grouped_mm,
            A,
            B.transpose(-2, -1),
            offs,
        )
        print(
            f"{dtype} time_sec {ref_time_sec:.2E}, tops/sec {ref_tops_sec:.2E}, pct_peak {ref_pct_top_peak:.3f}"
        )
        # Free the bf16 operands before allocating the fp8 ones.
        del A
        del B

        # Run scaled_grouped_mm.
        A_hp = torch.randn(M, K, device=device)
        B_hp_t = torch.randn(E, N, K, device=device).transpose(-2, -1)

        if recipe == "rowwise":
            # TODO: add e5m2
            A = A_hp.to(torch.float8_e4m3fn)
            B = B_hp_t.to(torch.float8_e4m3fn)
            peak_tops = fp8_peak_tops
            # Unit scales: this benchmarks kernel speed, not numerics.
            scale_a = torch.ones(M, device=device)
            scale_b = torch.ones(E, N, device=device)
        else:
            assert False, f"unknown recipe {recipe}"

        def do_scaled_grouped_mm(A, B):
            # Closure over the per-shape scales and offsets computed above.
            nonlocal scale_a
            nonlocal scale_b
            nonlocal offs
            return torch._scaled_grouped_mm(A, B, scale_a, scale_b, offs=offs)

        if recipe == "rowwise":
            do_matmul = do_scaled_grouped_mm
        else:
            raise ValueError(f"unknown recipe {recipe}")

        time_sec, tops_sec, pct_top_peak = do_benchmarks(
            tops, peak_tops, use_gpu_kernel_time, do_matmul, A, B
        )
        print(
            f"time_sec {time_sec:.2E}, tops/sec {tops_sec:.2E}, pct_peak {pct_top_peak:.3f}"
        )

        del A, B
        if scale_a is not None:
            del scale_a
        if scale_b is not None:
            del scale_b

        results.append(
            [
                name,
                recipe,
                M,
                K,
                N,
                E,
                ref_time_sec,
                time_sec,
                ref_time_sec / time_sec,
            ]
        )

    data_df = pd.DataFrame(results, columns=headers)
    print(data_df)

    if out_filename is not None:
        data_df.to_csv(out_filename)


def main() -> None:
    fire.Fire(run)


if __name__ == "__main__":
    main()  # pragma: no cover
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
import argparse
import itertools
from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from tqdm import tqdm
from utils import benchmark_cuda_function_in_microseconds

from torchao.float8.config import ScalingGranularity
from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
from torchao.prototype.moe_training.utils import generate_jagged_offs
from torchao.prototype.mx_formats.mx_tensor import to_mx
from torchao.prototype.mx_formats.utils import (
    to_blocked_per_group_2d,
    to_blocked_per_group_3d,
)

device = torch.device("cuda")


@dataclass(frozen=True)
class ExperimentConfig:
    # One grouped-GEMM problem shape:
    # e = number of experts (groups), m = total tokens (jagged along dim 0),
    # n, k = per-expert GEMM dims (A is (m, k), B is (e, n, k)).
    e: int
    m: int
    n: int
    k: int


@dataclass(frozen=True)
class ExperimentResult:
    # Measured kernel times, in microseconds, for each recipe.
    bf16_us: float
    fp8_rowwise_us: float
    mxfp8_us: float


@dataclass(frozen=True)
class Experiment:
    # A config paired with its measured result.
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    """Return the problem shapes to benchmark (Llama4 17b x 16e shapes)."""
    M = [16640]
    K = [5120]
    N = [8192]
    E = [16]
    return [
        ExperimentConfig(e=e, m=m, n=n, k=k)
        for e, m, n, k in itertools.product(E, M, N, K)
    ]


def run_experiment(
    config: ExperimentConfig, args: argparse.Namespace
) -> ExperimentResult:
    """Benchmark bf16, fp8 rowwise, and mxfp8 2D-3D grouped GEMMs for one shape.

    Args:
        config: problem shape to benchmark.
        args: parsed CLI args; currently unused, kept so future flags can be
            threaded through without changing the call sites.

    Returns:
        ExperimentResult with each kernel time rounded to 3 decimal places (us).
    """
    e, m, n, k = config.e, config.m, config.n, config.k

    # 2D left operand (token activations) and 3D right operand (expert
    # weights, transposed to (e, k, n) as the grouped GEMM expects).
    A = torch.randn(
        (m, k),
        dtype=torch.bfloat16,
        device=device,
    )
    # NOTE(review): requires_grad=True means the benchmarked ops record an
    # autograd graph; confirm this is intended for a forward-only benchmark
    # (A is created without it).
    B_t = torch.randn(
        (e, n, k),
        dtype=torch.bfloat16,
        device=device,
        requires_grad=True,
    ).transpose(-2, -1)

    # Jagged per-expert token offsets along M, each group a multiple of 16.
    n_groups = e
    Mg = A.shape[0]
    alignment_size = 16
    offs = generate_jagged_offs(n_groups, Mg, multiple_of=alignment_size)

    # benchmark bf16 grouped mm
    bf16_us = benchmark_cuda_function_in_microseconds(
        torch._grouped_mm,
        A,
        B_t,
        offs,
        out_dtype=torch.bfloat16,
    )

    # bench fp8 rowwise grouped mm
    fp8_rowwise_us = bench_fp8_rowwise_grouped_mm(A, B_t, offs)

    # benchmark mxfp8 grouped mm
    mxfp8_us = bench_mxfp8_grouped_mm(A, B_t, offs)

    return ExperimentResult(
        bf16_us=round(bf16_us, 3),
        fp8_rowwise_us=round(fp8_rowwise_us, 3),
        mxfp8_us=round(mxfp8_us, 3),
    )


def print_results(experiments: List[Experiment]):
    """Print one table row per experiment: shape dims + per-recipe times (us)."""
    headers = [
        "E",
        "M",
        "N",
        "K",
        "bf16_time_us",
        "fp8_rowwise_time_us",
        "mxfp8_time_us",
    ]
    rows = [
        [
            experiment.config.e,
            experiment.config.m,
            experiment.config.n,
            experiment.config.k,
            experiment.result.bf16_us,
            experiment.result.fp8_rowwise_us,
            experiment.result.mxfp8_us,
        ]
        for experiment in experiments
    ]
    print(tabulate(rows, headers=headers))


def bench_fp8_rowwise_grouped_mm(A, B_t, offs) -> float:
    """Time ``torch._scaled_grouped_mm`` with rowwise fp8 (e4m3) operands.

    Quantization happens outside the timed region; only the GEMM is measured.
    Returns the kernel time in microseconds.
    """
    # Convert A to float8, row-major for left operand of grouped GEMM.
    A_scales = tensor_to_scale(
        A,
        torch.float8_e4m3fn,
        scaling_granularity=ScalingGranularity.AXISWISE,
        axiswise_dim=-1,
        round_scales_to_power_of_2=True,
    )
    A_scaled = A.to(torch.float32) * A_scales
    A_fp8_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn)

    # Convert B_t to float8, column-major for right operand of grouped GEMM.
    B_t_scales = tensor_to_scale(
        B_t,
        torch.float8_e4m3fn,
        scaling_granularity=ScalingGranularity.AXISWISE,
        axiswise_dim=-2,
        round_scales_to_power_of_2=True,
    )
    B_t_scaled = B_t.to(torch.float32) * B_t_scales
    B_t_fp8_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn)

    # Bench the gemm. The kernel expects descale factors (reciprocals of the
    # quantization scales), with the reduced axis squeezed out.
    fp8_us = benchmark_cuda_function_in_microseconds(
        torch._scaled_grouped_mm,
        A_fp8_row_major,
        B_t_fp8_col_major,
        A_scales.squeeze(1).reciprocal(),
        B_t_scales.squeeze(1).reciprocal(),
        offs,
        out_dtype=torch.bfloat16,
        use_fast_accum=True,
    )
    return fp8_us


def bench_mxfp8_grouped_mm(A, B_t, offs, block_size=32) -> float:
    """Time the fbgemm mxfp8 grouped GEMM (``mx8mx8bf16_grouped_stacked``).

    Quantization and scale-layout conversion happen outside the timed region.
    Requires fbgemm with the mx8mx8bf16 ops registered. Returns microseconds.
    """
    # A_mx shape: (M, K)
    # A_scale shape: (M, K//block_size)
    A_scales, A_fp8 = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)

    # B_mx shape: (E, N, K)
    # B_scale shape: (E, N, K//block_size)
    B_scales, B_fp8 = to_mx(
        B_t.transpose(-2, -1),
        elem_dtype=torch.float8_e4m3fn,
        block_size=block_size,
    )

    # Convert scales for each group to blocked format; this also yields the
    # padded per-group starting rows the kernel needs.
    Mg, K = A_fp8.shape
    A_scales_blocked, starting_row_after_padding = to_blocked_per_group_2d(
        A_scales, offs, Mg, K
    )
    B_scales_blocked = to_blocked_per_group_3d(B_scales)

    # Derive per-group sizes from the cumulative offsets, e.g.
    # offs = [32, 64, 128] -> group_sizes = [32, 32, 64].
    # (`starting_row_after_padding` already came from to_blocked_per_group_2d.)
    zero = torch.tensor([0], dtype=offs.dtype, device=offs.device)
    group_sizes = torch.diff(offs, prepend=zero).to(torch.int64)

    # Run the grouped mm
    mxfp8_us = benchmark_cuda_function_in_microseconds(
        torch.ops.fbgemm.mx8mx8bf16_grouped_stacked,
        A_fp8,
        B_fp8,
        A_scales_blocked,
        B_scales_blocked,
        group_sizes,
        starting_row_after_padding=starting_row_after_padding,
    )
    return mxfp8_us


def main(args: argparse.Namespace):
    """Run every configured shape and print the results table."""
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []
    for config in tqdm(configs):
        result = run_experiment(config, args)
        results.append(Experiment(config=config, result=result))

    # Use Tabulate to print results
    print_results(results)


if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    args = arg_parser.parse_args()
    main(args)