Skip to content

Commit a08719c

Browse files
[mxfp8 moe training] add grouped gemm benchmark script
stack-info: PR: #2882, branch: danielvegamyhre/stack/61
1 parent fb1628c commit a08719c

File tree

3 files changed

+243
-149
lines changed

3 files changed

+243
-149
lines changed

benchmarks/float8/bench_grouped_mm.py

Lines changed: 0 additions & 149 deletions
This file was deleted.
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
7+
import argparse
8+
import itertools
9+
from dataclasses import dataclass
10+
from typing import List
11+
12+
import torch
13+
from tabulate import tabulate
14+
from tqdm import tqdm
15+
from utils import benchmark_cuda_function_in_microseconds
16+
17+
from torchao.float8.config import ScalingGranularity
18+
from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
19+
from torchao.prototype.moe_training.utils import generate_jagged_offs
20+
from torchao.prototype.mx_formats.mx_tensor import to_mx
21+
from torchao.prototype.mx_formats.utils import (
22+
to_blocked_per_group_2d,
23+
to_blocked_per_group_3d,
24+
)
25+
26+
# CUDA device on which run_experiment allocates its benchmark input tensors.
device = torch.device("cuda")
27+
28+
29+
@dataclass(frozen=True)
class ExperimentConfig:
    """Problem shape for a single grouped GEMM benchmark run.

    run_experiment builds the left operand A with shape (m, k) and the right
    operand from a tensor of shape (e, n, k); ``e`` is also used as the number
    of groups the rows of A are split into.
    """

    # Leading (batch) dimension of the right operand; also the group count.
    e: int
    # Total number of rows of A, across all groups.
    m: int
    # Second dimension of the (e, n, k) tensor the right operand is built from.
    n: int
    # Shared inner (contraction) dimension of both operands.
    k: int
35+
36+
37+
@dataclass(frozen=True)
class ExperimentResult:
    """Timings measured for one ExperimentConfig, one field per GEMM variant.

    The ``_us`` suffix suggests microseconds, matching the name of the
    benchmark_cuda_function_in_microseconds helper that produces the values.
    """

    # Time of torch._grouped_mm on bfloat16 inputs.
    bf16_us: float
    # Time of torch._scaled_grouped_mm on row-wise scaled float8 inputs.
    fp8_rowwise_us: float
    # Time of the fbgemm mx8mx8bf16_grouped_stacked op on mxfp8 inputs.
    mxfp8_us: float
42+
43+
44+
@dataclass(frozen=True)
class Experiment:
    """Pairs a benchmark configuration with the timings measured for it."""

    # Problem shape that was benchmarked.
    config: ExperimentConfig
    # Timings obtained for that shape.
    result: ExperimentResult
48+
49+
50+
def get_configs() -> List[ExperimentConfig]:
    """Return every problem shape to benchmark.

    The result is the Cartesian product of the candidate values for each
    dimension, ordered with ``e`` varying slowest and ``k`` fastest.

    Returns:
        One ExperimentConfig per (e, m, n, k) combination.
    """
    # Llama4 shapes
    m_values = [16640]
    k_values = [5120]
    n_values = [8192]
    e_values = [16]
    return [
        ExperimentConfig(e=e, m=m, n=n, k=k)
        for e, m, n, k in itertools.product(e_values, m_values, n_values, k_values)
    ]
72+
73+
74+
def run_experiment(
    config: ExperimentConfig, args: argparse.Namespace
) -> ExperimentResult:
    """Benchmark the bf16, fp8 rowwise and mxfp8 grouped GEMMs for one shape.

    Args:
        config: Problem shape (e, m, n, k) to benchmark.
        args: Parsed command-line arguments. Not read by this function at
            present; accepted so callers can pass options through.

    Returns:
        An ExperimentResult whose three timings are rounded to 3 decimals.
    """
    e, m, n, k = config.e, config.m, config.n, config.k

    # define test inputs
    # A is created before B_t so the random stream is consumed in a fixed order.
    A = torch.randn((m, k), dtype=torch.bfloat16, device=device)
    # Allocated as (e, n, k) and then viewed as (e, k, n) via the transpose.
    B_t = torch.randn(
        (e, n, k),
        dtype=torch.bfloat16,
        device=device,
        requires_grad=True,
    ).transpose(-2, -1)

    # Configure groups
    n_groups = e
    Mg = A.shape[0]
    alignment_size = 16
    offs = generate_jagged_offs(n_groups, Mg, multiple_of=alignment_size)

    # benchmark bf16 grouped mm
    bf16_us = benchmark_cuda_function_in_microseconds(
        torch._grouped_mm,
        A,
        B_t,
        offs,
        out_dtype=torch.bfloat16,
    )

    # bench fp8 rowwise grouped mm
    fp8_rowwise_us = bench_fp8_rowwise_grouped_mm(A, B_t, offs)

    # benchmark mxfp8 grouped mm
    mxfp8_us = bench_mxfp8_grouped_mm(A, B_t, offs)

    return ExperimentResult(
        bf16_us=round(bf16_us, 3),
        fp8_rowwise_us=round(fp8_rowwise_us, 3),
        mxfp8_us=round(mxfp8_us, 3),
    )
118+
119+
120+
def print_results(experiments: List[Experiment]):
    """Print one table row per experiment: its shape followed by its timings.

    Args:
        experiments: Completed experiments, printed in the order given.
    """
    headers = [
        "E",
        "M",
        "N",
        "K",
        "bf16_time_us",
        "fp8_rowwise_time_us",
        "mxfp8_time_us",
    ]
    # Column order here must stay aligned with `headers` above.
    rows = [
        [
            exp.config.e,
            exp.config.m,
            exp.config.n,
            exp.config.k,
            exp.result.bf16_us,
            exp.result.fp8_rowwise_us,
            exp.result.mxfp8_us,
        ]
        for exp in experiments
    ]
    print(tabulate(rows, headers=headers))
144+
145+
146+
# benchmark fp8 grouped mm
147+
def bench_fp8_rowwise_grouped_mm(A, B_t, offs) -> float:
    """Time torch._scaled_grouped_mm on row-wise scaled float8 operands.

    Quantization of both operands happens once, up front; only the grouped
    GEMM call itself is handed to the benchmarking helper.

    Args:
        A: Left operand; the caller in this file passes a 2D bfloat16 tensor.
        B_t: Right operand; the caller passes a transposed 3D bfloat16 tensor.
        offs: Group offsets forwarded unchanged to the grouped GEMM.

    Returns:
        The value reported by benchmark_cuda_function_in_microseconds
        (presumably microseconds, per the helper's name).
    """
    fp8_dtype = torch.float8_e4m3fn

    # Convert A to float8, row-major for left operand of grouped GEMM.
    A_scales = tensor_to_scale(
        A,
        fp8_dtype,
        scaling_granularity=ScalingGranularity.AXISWISE,
        axiswise_dim=-1,
        round_scales_to_power_of_2=True,
    )
    A_fp8_row_major = to_fp8_saturated(A.to(torch.float32) * A_scales, fp8_dtype)

    # Convert B_t to float8, column-major for right operand of grouped GEMM.
    B_t_scales = tensor_to_scale(
        B_t,
        fp8_dtype,
        scaling_granularity=ScalingGranularity.AXISWISE,
        axiswise_dim=-2,
        round_scales_to_power_of_2=True,
    )
    B_t_fp8_col_major = to_fp8_saturated(
        B_t.to(torch.float32) * B_t_scales, fp8_dtype
    )

    # The scales above multiply the high-precision data, so the GEMM is given
    # their reciprocals with the size-1 dim at index 1 dropped.
    # NOTE(review): assumes _scaled_grouped_mm expects dequantization scales - confirm.
    A_scales_inv = A_scales.squeeze(1).reciprocal()
    B_t_scales_inv = B_t_scales.squeeze(1).reciprocal()

    # Bench the gemm
    return benchmark_cuda_function_in_microseconds(
        torch._scaled_grouped_mm,
        A_fp8_row_major,
        B_t_fp8_col_major,
        A_scales_inv,
        B_t_scales_inv,
        offs,
        out_dtype=torch.bfloat16,
        use_fast_accum=True,
    )
182+
183+
184+
def bench_mxfp8_grouped_mm(A, B_t, offs, block_size=32) -> float:
    """Time the fbgemm mx8mx8bf16_grouped_stacked op on mxfp8 operands.

    Quantization and scale re-layout happen once, up front; only the grouped
    GEMM call itself is handed to the benchmarking helper.

    Args:
        A: Left operand; the caller in this file passes a 2D bfloat16 tensor.
        B_t: Right operand; it is transposed back on its last two dims before
            being quantized.
        offs: Cumulative group end offsets along the rows of A.
        block_size: Block size forwarded to to_mx for both operands.

    Returns:
        The value reported by benchmark_cuda_function_in_microseconds
        (presumably microseconds, per the helper's name).
    """
    # A_mx shape: (M, K)
    # A_scale shape: (M, K//block_size)
    A_scales, A_fp8 = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)

    # B_mx shape: (E, N, K)
    # B_scale shape: (E, N, K//block_size)
    B_scales, B_fp8 = to_mx(
        B_t.transpose(-2, -1),
        elem_dtype=torch.float8_e4m3fn,
        block_size=block_size,
    )

    # Convert scales for each group to blocked format.
    Mg, K = A_fp8.shape
    A_scales_blocked, starting_row_after_padding = to_blocked_per_group_2d(
        A_scales, offs, Mg, K
    )
    B_scales_blocked = to_blocked_per_group_3d(B_scales)

    # Per-group row counts are the consecutive differences of the cumulative
    # offsets, with a leading 0 prepended so the first group keeps its size.
    # Example: offs = [32, 64, 128] -> group_sizes = [32, 32, 64]
    zero = torch.tensor([0], dtype=offs.dtype, device=offs.device)
    group_sizes = torch.diff(offs, prepend=zero).to(torch.int64)

    # Run the grouped mm
    return benchmark_cuda_function_in_microseconds(
        torch.ops.fbgemm.mx8mx8bf16_grouped_stacked,
        A_fp8,
        B_fp8,
        A_scales_blocked,
        B_scales_blocked,
        group_sizes,
        starting_row_after_padding=starting_row_after_padding,
    )
221+
222+
223+
def main(args: argparse.Namespace):
    """Run every configured experiment and print the results as a table.

    Args:
        args: Parsed command-line arguments, forwarded to run_experiment.
    """
    # Fixed seed so the randomly generated inputs are the same on every run.
    torch.random.manual_seed(123)
    experiments = [
        Experiment(config=config, result=run_experiment(config, args))
        for config in tqdm(get_configs())
    ]

    # Use Tabulate to print results
    print_results(experiments)
233+
234+
235+
if __name__ == "__main__":
    # No options are registered on the parser; the parsed (empty) namespace
    # is passed straight through to main.
    arg_parser = argparse.ArgumentParser()
    args = arg_parser.parse_args()
    main(args)

0 commit comments

Comments
 (0)