Skip to content

Commit b52d7d1

Browse files
[mxfp8 moe training] add per group blocked scale kernels
stack-info: PR: #2886, branch: danielvegamyhre/stack/62
1 parent a08719c commit b52d7d1

File tree

7 files changed

+519
-88
lines changed

7 files changed

+519
-88
lines changed
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
7+
8+
import itertools
9+
from dataclasses import dataclass
10+
from typing import List
11+
12+
import torch
13+
from tabulate import tabulate
14+
from tqdm import tqdm
15+
from utils import benchmark_cuda_function_in_microseconds
16+
17+
from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
18+
compute_per_group_blocked_scale_offsets,
19+
torch_to_blocked_per_group_2d,
20+
triton_mx_block_rearrange_per_group,
21+
)
22+
from torchao.prototype.moe_training.utils import generate_jagged_offs
23+
24+
device = torch.device("cuda")
25+
26+
# Needed since changing args to function causes recompiles
27+
torch._dynamo.config.cache_size_limit = 1000
28+
29+
30+
@dataclass(frozen=True)
class ExperimentConfig:
    """Immutable description of a single benchmark case.

    Attributes:
        input_shape: Two-element shape of the scale tensor that is generated
            and rearranged; consumers unpack it as ``Mg, K``.
        num_groups: Number of jagged groups the rows are partitioned into.
    """

    # Annotated as a 2-tuple: ``tuple[int]`` would describe a 1-tuple, but the
    # shape is always unpacked/indexed as exactly two dimensions.
    input_shape: tuple[int, int]
    num_groups: int
34+
35+
36+
@dataclass(frozen=True)
class ExperimentResult:
    """Immutable record of the measurements taken for one benchmark case.

    Times are in microseconds; bandwidths are in GB/s.
    """

    # Runtime of the torch.compile'd reference implementation.
    torch_time_us: float
    # Runtime of the Triton kernel.
    triton_time_us: float
    # Bandwidth derived from the torch runtime.
    torch_mem_bw_gbps: float
    # Bandwidth derived from the Triton runtime.
    triton_mem_bw_gbps: float
42+
43+
44+
@dataclass(frozen=True)
class Experiment:
    """Immutable pairing of a benchmark configuration with its measurements."""

    # The case that was benchmarked.
    config: ExperimentConfig
    # The timings and bandwidths measured for that case.
    result: ExperimentResult
48+
49+
50+
def get_configs() -> List[ExperimentConfig]:
    """Return the benchmark cases: the cross product of shapes and group counts."""
    # Llama4 shapes. Input activations are scaled along K dim, so each scale
    # tensor has one column per ``scale_block`` elements of K.
    scale_block = 32
    scale_shapes = [(16640, 5120 // scale_block)]
    group_counts = [16]
    return [
        ExperimentConfig(input_shape=scale_shape, num_groups=group_count)
        for scale_shape, group_count in itertools.product(scale_shapes, group_counts)
    ]
69+
70+
71+
def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    """Benchmark the torch reference and the Triton kernel on one configuration.

    A random uint8 tensor of shape ``config.input_shape`` stands in for the
    scales. Both implementations are invoked once before being timed so that
    compilation is not included in the measurement.

    Args:
        config: Shape of the scale tensor and number of jagged groups.

    Returns:
        Runtimes in microseconds and the derived memory bandwidths in GB/s for
        both implementations.
    """
    input_shape, num_groups = config.input_shape, config.num_groups
    input_tensor = torch.randint(
        low=0,
        high=256,
        size=input_shape,
        dtype=torch.uint8,
        device=device,
    )

    Mg, K = input_shape
    # Keep the offsets explicitly on the same device as the input tensor.
    input_group_offsets = generate_jagged_offs(
        num_groups, Mg, multiple_of=32, device=device
    )

    # bench torch
    # NOTE(review): K here is the scale tensor's column count, whereas the unit
    # test for torch_to_blocked_per_group_2d passes the unscaled k together
    # with block_size=32 -- confirm which value the helper expects.
    compiled_run_torch = torch.compile(torch_to_blocked_per_group_2d)
    # Warm-up call so compilation happens before timing; outputs are not needed.
    compiled_run_torch(input_tensor, input_group_offsets, Mg, K)
    torch_time_us = benchmark_cuda_function_in_microseconds(
        compiled_run_torch,
        input_tensor,
        input_group_offsets,
        Mg,
        K,
    )

    # bench triton
    _, output_group_offsets = compute_per_group_blocked_scale_offsets(
        input_group_offsets
    )
    # This call doubles as warm-up and provides the output size for the
    # bandwidth calculation below.
    triton_out_scales = triton_mx_block_rearrange_per_group(
        input_tensor,
        input_group_offsets,
        output_group_offsets,
    )
    triton_time_us = benchmark_cuda_function_in_microseconds(
        triton_mx_block_rearrange_per_group,
        input_tensor,
        input_group_offsets,
        output_group_offsets,
    )

    # mem bw calculations
    # Take element sizes from the tensors themselves instead of from unrelated
    # dtypes, so the byte counts always match what is actually read/written.
    read_bytes = input_tensor.numel() * input_tensor.element_size()
    write_bytes = triton_out_scales.numel() * triton_out_scales.element_size()
    total_gigabytes = (read_bytes + write_bytes) / 1e9

    torch_mem_bw_gbps = total_gigabytes / (torch_time_us / 1e6)
    triton_mem_bw_gbps = total_gigabytes / (triton_time_us / 1e6)

    return ExperimentResult(
        torch_time_us=torch_time_us,
        triton_time_us=triton_time_us,
        torch_mem_bw_gbps=torch_mem_bw_gbps,
        triton_mem_bw_gbps=triton_mem_bw_gbps,
    )
129+
130+
131+
def print_results(experiments: List[Experiment]):
    """Print a table with one row per experiment, including the Triton speedup."""
    column_names = [
        "input_shape",
        "torch_time_us",
        "triton_time_us",
        "torch_mem_bw_gbps",
        "triton_mem_bw_gbps",
        "triton_speedup",
    ]

    def _as_row(exp):
        # One table row; cell order matches ``column_names``.
        shape = exp.config.input_shape
        res = exp.result
        return [
            f"({shape[0]}, {shape[1]})",
            res.torch_time_us,
            res.triton_time_us,
            round(res.torch_mem_bw_gbps, 3),
            round(res.triton_mem_bw_gbps, 3),
            # Speedup is torch time over triton time, so >1 means Triton is faster.
            f"{res.torch_time_us / res.triton_time_us:.2f}x",
        ]

    table_rows = [_as_row(exp) for exp in experiments]
    print(tabulate(table_rows, headers=column_names))
156+
157+
158+
def main():
    """Run every configured benchmark case and print the results table."""
    # Fixed seed so the random inputs are reproducible across runs.
    torch.random.manual_seed(123)
    experiments = [
        Experiment(config=cfg, result=run_experiment(cfg))
        for cfg in tqdm(get_configs())
    ]

    # Use Tabulate to print results
    print_results(experiments)


if __name__ == "__main__":
    main()

test/prototype/moe_training/test_kernels.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,19 @@
2020
triton_fp8_per_group_colwise_scales,
2121
triton_fp8_per_group_rowwise_scales,
2222
)
23+
from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
24+
compute_per_group_blocked_scale_offsets,
25+
torch_to_blocked_per_group_2d,
26+
triton_mx_block_rearrange_per_group,
27+
)
2328
from torchao.prototype.moe_training.utils import (
2429
_is_column_major,
30+
generate_jagged_offs,
2531
torch_to_3d_rowwise_float8_transpose_rhs,
2632
torch_to_float8_per_group_colwise,
2733
torch_to_float8_per_group_rowwise,
2834
)
35+
from torchao.prototype.mx_formats.mx_tensor import to_mx
2936
from torchao.testing.utils import skip_if_rocm
3037

3138

@@ -159,3 +166,41 @@ def test_fp8_rowwise_3d_transpose_rhs(round_scales_to_power_of_2: bool):
159166
assert ref_fp8.shape == triton_fp8.shape, "output shapes not equal"
160167
assert ref_fp8.stride() == triton_fp8.stride(), "output strides not equal"
161168
assert torch.allclose(ref_fp8, triton_fp8, rtol=0, atol=0), "fp8 data not equal"
169+
170+
171+
@skip_if_rocm("ROCm enablement in progress")
@pytest.mark.parametrize(
    "m,k,n_groups", [(256, 256, 4), (16640, 5120, 16), (16640, 8192, 16)]
)
def test_mxfp8_per_group_blocked_scales_2d(
    m: int,
    k: int,
    n_groups: int,
):
    """Check that the Triton per-group blocked-scale kernel matches the torch reference."""
    device = "cuda"
    block_size = 32

    # Quantize random data to MX format; only the scales are used here, the
    # second output is discarded.
    high_precision = torch.randn(m, k, device=device)
    e8m0_scales, _ = to_mx(
        high_precision, elem_dtype=torch.float8_e4m3fn, block_size=block_size
    )
    group_end_offsets = generate_jagged_offs(
        n_groups, m, multiple_of=block_size, device=device
    )

    # torch reference
    expected_scales, _ = torch_to_blocked_per_group_2d(
        e8m0_scales, group_end_offsets, m, k, block_size=block_size
    )

    # triton kernel
    _, blocked_group_offsets = compute_per_group_blocked_scale_offsets(
        group_end_offsets
    )
    actual_scales = triton_mx_block_rearrange_per_group(
        e8m0_scales,
        group_end_offsets,
        blocked_group_offsets,
    )

    # Zero tolerances: the rearranged scales must match exactly.
    scales_match = torch.allclose(expected_scales, actual_scales, atol=0, rtol=0)
    assert scales_match, "blocked scales not equal"

torchao/prototype/moe_training/kernels/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@
77
from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
88
triton_fp8_per_group_rowwise_scales as triton_fp8_per_group_rowwise_scales,
99
)
10-
from torchao.prototype.moe_training.kernels.mxfp8 import (
10+
from torchao.prototype.moe_training.kernels.mxfp8_gemms import (
1111
fbgemm_mxfp8_grouped_mm_2d_3d as fbgemm_mxfp8_grouped_mm_2d_3d,
1212
)

0 commit comments

Comments
 (0)