Skip to content

Commit 0558c1c

Browse files
[mxfp8 moe training] add triton kernel for blocked swizzled 3d weight scales
1 parent a174a57 commit 0558c1c

File tree

3 files changed

+259
-98
lines changed

3 files changed

+259
-98
lines changed
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
7+
8+
from dataclasses import dataclass
9+
from typing import List
10+
11+
import torch
12+
from tabulate import tabulate
13+
from tqdm import tqdm
14+
from utils import benchmark_cuda_function_in_microseconds
15+
16+
from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
17+
torch_to_blocked_per_group_3d,
18+
triton_mx_block_rearrange_per_group_3d,
19+
)
20+
21+
device = torch.device("cuda")
22+
23+
# Needed since changing args to function causes recompiles
24+
torch._dynamo.config.cache_size_limit = 1000
25+
26+
27+
@dataclass(frozen=True)
28+
class ExperimentConfig:
29+
input_shape: tuple[int]
30+
31+
32+
@dataclass(frozen=True)
33+
class ExperimentResult:
34+
torch_time_us: float
35+
triton_time_us: float
36+
torch_mem_bw_gbps: float
37+
triton_mem_bw_gbps: float
38+
39+
40+
@dataclass(frozen=True)
41+
class Experiment:
42+
config: ExperimentConfig
43+
result: ExperimentResult
44+
45+
46+
def get_configs() -> List[ExperimentConfig]:
47+
# Llama4 shapes. Input activations are scaled along K dim.
48+
block_size = 32
49+
input_shapes = [
50+
# w1, w3
51+
(1, 8192, 5120 // block_size),
52+
(2, 8192, 5120 // block_size),
53+
(8, 8192, 5120 // block_size),
54+
(16, 8192, 5120 // block_size),
55+
# w2
56+
(1, 5120, 8192 // block_size),
57+
(2, 5120, 8192 // block_size),
58+
(8, 5120, 8192 // block_size),
59+
(16, 5120, 8192 // block_size),
60+
]
61+
configs = []
62+
for shape in input_shapes:
63+
configs.append(
64+
ExperimentConfig(
65+
input_shape=shape,
66+
)
67+
)
68+
return configs
69+
70+
71+
def run_experiment(config: ExperimentConfig) -> ExperimentResult:
72+
input_tensor = torch.randint(
73+
low=0,
74+
high=256,
75+
size=config.input_shape,
76+
dtype=torch.uint8,
77+
device=device,
78+
)
79+
80+
E, N, K = config.input_shape
81+
82+
# bench torch
83+
compiled_run_torch = torch.compile(torch_to_blocked_per_group_3d)
84+
_ = compiled_run_torch(input_tensor)
85+
torch_time_us = benchmark_cuda_function_in_microseconds(
86+
compiled_run_torch,
87+
input_tensor,
88+
)
89+
90+
# bench triton
91+
triton_out_scales = triton_mx_block_rearrange_per_group_3d(input_tensor)
92+
triton_time_us = benchmark_cuda_function_in_microseconds(
93+
triton_mx_block_rearrange_per_group_3d,
94+
input_tensor,
95+
)
96+
97+
# mem bw calculations
98+
bytes_per_input_el = torch.finfo(torch.float8_e8m0fnu).bits / 8
99+
bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8
100+
101+
read_bytes = input_tensor.numel() * bytes_per_input_el
102+
write_bytes = triton_out_scales.numel() * bytes_per_output_el
103+
104+
torch_mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_time_us / 1e6)
105+
triton_mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6)
106+
107+
return ExperimentResult(
108+
torch_time_us=torch_time_us,
109+
triton_time_us=triton_time_us,
110+
torch_mem_bw_gbps=torch_mem_bw_gbps,
111+
triton_mem_bw_gbps=triton_mem_bw_gbps,
112+
)
113+
114+
115+
def print_results(experiments: List[Experiment]):
116+
headers = [
117+
"input_shape",
118+
"torch_time_us",
119+
"triton_time_us",
120+
"torch_mem_bw_gbps",
121+
"triton_mem_bw_gbps",
122+
"triton_speedup",
123+
]
124+
rows = []
125+
for experiment in experiments:
126+
input_shape = f"({experiment.config.input_shape[0]}, {experiment.config.input_shape[1]}, {experiment.config.input_shape[2]})"
127+
rows.append(
128+
[
129+
input_shape,
130+
experiment.result.torch_time_us,
131+
experiment.result.triton_time_us,
132+
round(experiment.result.torch_mem_bw_gbps, 3),
133+
round(experiment.result.triton_mem_bw_gbps, 3),
134+
f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x",
135+
]
136+
)
137+
print(tabulate(rows, headers=headers))
138+
139+
140+
def main():
141+
torch.random.manual_seed(123)
142+
configs = get_configs()
143+
results = []
144+
for config in tqdm(configs):
145+
result = run_experiment(config)
146+
results.append(Experiment(config=config, result=result))
147+
148+
# Use Tabulate to print results
149+
print_results(results)
150+
151+
152+
if __name__ == "__main__":
153+
main()

test/prototype/moe_training/test_kernels.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323
from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
2424
compute_per_group_blocked_scale_offsets,
2525
torch_to_blocked_per_group_2d,
26+
torch_to_blocked_per_group_3d,
2627
triton_mx_block_rearrange_per_group_2d,
28+
triton_mx_block_rearrange_per_group_3d,
2729
)
2830
from torchao.prototype.moe_training.utils import (
2931
_is_column_major,
@@ -204,3 +206,27 @@ def test_mxfp8_per_group_blocked_scales_2d(
204206
assert torch.allclose(ref_out_scales, triton_out_scales, atol=0, rtol=0), (
205207
"blocked scales not equal"
206208
)
209+
210+
211+
@skip_if_rocm("ROCm enablement in progress")
212+
@pytest.mark.parametrize("e,n,k", [(1, 8192, 5120), (2, 8192, 5120), (8, 5120, 8192)])
213+
def test_mxfp8_per_group_blocked_scales_3d(
214+
e: int,
215+
n: int,
216+
k: int,
217+
):
218+
device = "cuda"
219+
block_size = 32
220+
weights = torch.randn(e, n, k // block_size, device=device)
221+
weight_scales, _ = to_mx(
222+
weights, elem_dtype=torch.float8_e4m3fn, block_size=block_size
223+
)
224+
225+
# torch reference
226+
ref_out_scales = torch_to_blocked_per_group_3d(weight_scales)
227+
228+
# triton kernel
229+
triton_out_scales = triton_mx_block_rearrange_per_group_3d(weight_scales)
230+
assert torch.allclose(ref_out_scales, triton_out_scales, atol=0, rtol=0), (
231+
"blocked scales not equal"
232+
)

0 commit comments

Comments
 (0)