Skip to content

Commit 15a6de6

Browse files
[mxfp8 moe] add support for fbgemm 2d-3d mx8mx8bf16 grouped gemm (#2848)
1 parent 8669213 commit 15a6de6

File tree

6 files changed

+324
-168
lines changed

6 files changed

+324
-168
lines changed

test/prototype/moe_training/test_scaled_grouped_mm.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -230,25 +230,26 @@ def compute_reference_forward(
230230
@pytest.mark.parametrize("num_experts", (1, 8, 16))
231231
def test_emulate_mxfp8_grouped_gemm_2d_3d(M, K, N, num_experts):
232232
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
233-
w_t = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device="cuda")
233+
w = torch.randn(num_experts, N, K, dtype=torch.bfloat16, device="cuda")
234234
offs = generate_jagged_offs(num_experts, M)
235-
x_ref, w_t_ref, offs_ref = x.clone(), w_t.clone(), offs.clone()
235+
x_ref, w_ref, offs_ref = x.clone(), w.clone(), offs.clone()
236236

237237
# Quantize inputs to mxfp8 for emulated mxfp8 scaled grouped mm
238238
block_size = 32
239-
x_scale, x_mx = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
239+
x_scale, x_fp8 = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
240240

241241
# Cast each expert's weight of shape (num_experts, N, K) to mxfp8 along the last dim (K); no transpose/untranspose is needed.
242-
w_scale, w_mx = to_mx(
243-
w_t.transpose(-2, -1).contiguous(),
242+
w_scale, w_fp8 = to_mx(
243+
w,
244244
elem_dtype=torch.float8_e4m3fn,
245245
block_size=block_size,
246246
)
247-
w_t_scale, w_t_mx = w_scale.transpose(-2, -1), w_mx.transpose(-2, -1)
248247

249-
ref_out = torch._grouped_mm(x_ref, w_t_ref, offs=offs_ref, out_dtype=torch.bfloat16)
248+
ref_out = torch._grouped_mm(
249+
x_ref, w_ref.transpose(-2, -1), offs=offs_ref, out_dtype=torch.bfloat16
250+
)
250251
out = _emulated_mxfp8_scaled_grouped_mm_2d_3d(
251-
x_mx, x_scale, w_t_mx, w_t_scale, offs=offs, out_dtype=torch.bfloat16
252+
x_fp8, x_scale, w_fp8, w_scale, offs=offs, out_dtype=torch.bfloat16
252253
)
253254

254255
sqnr = compute_error(ref_out, out)
@@ -305,18 +306,25 @@ def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts):
305306

306307

307308
@skip_if_rocm("ROCm not supported")
308-
@pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
309-
@pytest.mark.parametrize("num_experts", (1, 8, 16))
309+
@pytest.mark.parametrize(
310+
"M,K,N", [(1024, 5120, 8192), (2048, 5120, 8192), (16640, 5120, 8192)]
311+
)
312+
@pytest.mark.parametrize("num_experts", (2, 4, 8, 16))
310313
def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(M, K, N, num_experts):
311314
from torchao.prototype.moe_training.scaled_grouped_mm import (
312315
_MXFP8GroupedMM,
313316
)
314317

315318
block_size = 32
316319
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
317-
w_t = torch.randn(
318-
num_experts, K, N, dtype=torch.bfloat16, device="cuda", requires_grad=True
320+
w = torch.randn(
321+
num_experts,
322+
N,
323+
K,
324+
dtype=torch.bfloat16,
325+
device="cuda",
319326
)
327+
w_t = w.transpose(-2, -1).requires_grad_(True)
320328
offs = generate_jagged_offs(num_experts, M, multiple_of=block_size)
321329
x_ref, w_t_ref, offs_ref = (
322330
x.clone().detach().requires_grad_(True),

test/prototype/moe_training/test_training.py

Lines changed: 29 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -40,109 +40,38 @@
4040
],
4141
)
4242
@pytest.mark.parametrize("compile", [False, True])
43-
def test_moe_float8_training(target_fqns: list[str], compile: bool):
44-
# Set token group alignment size to 16. This is required so that
45-
# each logically distinct gemm in the grouped gemm `grad_weight = grad_output_t @ input`
46-
# has the contraction dim be divisible by 16. 16 byte alignment is required
47-
# for the slowest moving dim (stride 1), so 16 bytes / 1 byte per element in fp8 = 16 elements.
48-
set_token_group_alignment_size_m(16)
49-
model_args = MoEArgs(
50-
num_experts=8,
51-
)
52-
init_std = 0.02
53-
device = torch.device("cuda")
54-
55-
# reference bf16 MoE
56-
dim, hidden_dim = 5120, 8192
57-
ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
58-
torch.manual_seed(42)
59-
ref_model.init_weights(init_std, device)
60-
61-
# target MoE for testing conversion
62-
model = copy.deepcopy(ref_model)
63-
64-
# assert starting params are identical for both models
65-
for param1, param2 in zip(model.parameters(), ref_model.parameters()):
66-
assert torch.equal(param1, param2)
67-
68-
# convert MoE to float8 training
69-
def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
70-
for target_fqn in target_fqns:
71-
if target_fqn in cur_fqn:
72-
return True
73-
return False
74-
75-
# quantize test model
76-
config = MoETrainingConfig()
77-
quantize_(model, config=config, filter_fn=moe_module_filter_fn)
78-
79-
# validate that only the experts were converted
80-
_validate_model_conversion(
81-
model,
82-
target_fqns=target_fqns,
83-
)
84-
if compile:
85-
# TODO: compile with fullgraph=True when torchtitan llama4 moe supports it
86-
model = torch.compile(model, fullgraph=False)
87-
ref_model = torch.compile(ref_model, fullgraph=False)
88-
89-
# inputs
90-
batch, seq = 8, 2048
91-
ref_x = torch.randn(
92-
batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
93-
)
94-
x = ref_x.detach().clone().requires_grad_(True)
95-
96-
# forward pass
97-
ref_out = ref_model(ref_x)
98-
out = model(x)
99-
100-
# validate output
101-
out_sqnr = compute_error(out, ref_out)
102-
min_out_sqnr = 29.0
103-
assert out_sqnr.item() >= min_out_sqnr, (
104-
f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}."
105-
)
106-
107-
# compute loss
108-
labels = torch.ones_like(ref_out)
109-
ref_loss = F.mse_loss(ref_out, labels)
110-
out_loss = F.mse_loss(out, labels)
111-
112-
# backward pass
113-
ref_loss.backward()
114-
out_loss.backward()
115-
116-
# validate input gradient
117-
input_grad_sqnr = compute_error(x.grad, ref_x.grad)
118-
min_input_grad_sqnr = 29.0
119-
assert input_grad_sqnr.item() >= min_input_grad_sqnr, (
120-
f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}."
121-
)
122-
123-
# validate param gradients
124-
min_param_grad_sqnr = 23.0
125-
for param1, param2 in zip(model.parameters(), ref_model.parameters()):
126-
param_grad_sqnr = compute_error(param1.grad, param2.grad)
127-
assert param_grad_sqnr.item() >= min_param_grad_sqnr, (
128-
f"SQNR must be >= {min_param_grad_sqnr}, got {param_grad_sqnr.item()}."
129-
)
130-
131-
13243
@pytest.mark.parametrize(
133-
"target_fqns",
44+
"recipe_config",
13445
[
135-
["experts"],
136-
["does.not.exist"],
46+
# {"recipe": MoEScalingType.FP8_ROWWISE, "group_alignment_size": 16, "min_out_sqnr": 29.0, "min_input_grad_sqnr": 29.0, "min_param_grad_sqnr": 23.0},
47+
{
48+
"recipe": MoEScalingType.MXFP8,
49+
"group_alignment_size": 32,
50+
"min_out_sqnr": 28.0,
51+
"min_input_grad_sqnr": 29.0,
52+
"min_param_grad_sqnr": 21.0,
53+
},
13754
],
13855
)
139-
@pytest.mark.parametrize("compile", [False, True])
140-
def test_moe_mxfp8_training(target_fqns: list[str], compile: bool):
141-
block_size = 32
142-
143-
# Token groups must be divisible by 32 for mxfp8
144-
set_token_group_alignment_size_m(block_size)
145-
56+
def test_moe_training(target_fqns: list[str], compile: bool, recipe_config: dict):
57+
(
58+
recipe,
59+
group_alignment_size,
60+
min_out_sqnr,
61+
min_input_grad_sqnr,
62+
min_param_grad_sqnr,
63+
) = (
64+
recipe_config["recipe"],
65+
recipe_config["group_alignment_size"],
66+
recipe_config["min_out_sqnr"],
67+
recipe_config["min_input_grad_sqnr"],
68+
recipe_config["min_param_grad_sqnr"],
69+
)
70+
# Set token group alignment size. This is required so that
71+
# each logically distinct gemm in the grouped gemm `grad_weight = grad_output_t @ input`
72+
# has the contraction dim be divisible by `group_alignment_size` (e.g. 16 for fp8 rowwise, 32 for mxfp8). 16 byte alignment is required
73+
# for the fastest moving dim (stride 1).
74+
set_token_group_alignment_size_m(group_alignment_size)
14675
model_args = MoEArgs(
14776
num_experts=8,
14877
)
@@ -170,15 +99,14 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
17099
return False
171100

172101
# quantize test model
173-
config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
102+
config = MoETrainingConfig(scaling_type=recipe)
174103
quantize_(model, config=config, filter_fn=moe_module_filter_fn)
175104

176105
# validate that only the experts were converted
177106
_validate_model_conversion(
178107
model,
179108
target_fqns=target_fqns,
180109
)
181-
182110
if compile:
183111
# TODO: compile with fullgraph=True when torchtitan llama4 moe supports it
184112
model = torch.compile(model, fullgraph=False)
@@ -197,7 +125,6 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
197125

198126
# validate output
199127
out_sqnr = compute_error(out, ref_out)
200-
min_out_sqnr = 28.0
201128
assert out_sqnr.item() >= min_out_sqnr, (
202129
f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}."
203130
)
@@ -213,13 +140,11 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
213140

214141
# validate input gradient
215142
input_grad_sqnr = compute_error(x.grad, ref_x.grad)
216-
min_input_grad_sqnr = 30.0
217143
assert input_grad_sqnr.item() >= min_input_grad_sqnr, (
218144
f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}."
219145
)
220146

221147
# validate param gradients
222-
min_param_grad_sqnr = 21.0
223148
for param1, param2 in zip(model.parameters(), ref_model.parameters()):
224149
param_grad_sqnr = compute_error(param1.grad, param2.grad)
225150
assert param_grad_sqnr.item() >= min_param_grad_sqnr, (

torchao/prototype/moe_training/kernels/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,6 @@
77
from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
88
triton_fp8_per_group_rowwise_scales as triton_fp8_per_group_rowwise_scales,
99
)
10+
from torchao.prototype.moe_training.kernels.mxfp8 import (
11+
fbgemm_mxfp8_grouped_mm_2d_3d as fbgemm_mxfp8_grouped_mm_2d_3d,
12+
)
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import logging
2+
3+
import torch
4+
5+
from torchao.prototype.mx_formats.utils import (
6+
to_blocked_per_group_2d,
7+
to_blocked_per_group_3d,
8+
)
9+
10+
logger: logging.Logger = logging.getLogger(__name__)

# The import is needed only for its side effects (hence `noqa: F401`); it
# presumably registers the `torch.ops.fbgemm.*` kernels used by this module.
# A failed import is reported but deliberately not re-raised, so that merely
# importing this module does not crash on machines without fbgemm_gpu_genai.
try:
    import fbgemm_gpu.experimental.gen_ai  # noqa: F401
except Exception as e:
    # Log through the module logger: the root-level `logging.warning` helper
    # implicitly calls `basicConfig()` and would configure the application's
    # root logger as a side effect of importing a library module.
    # Each sentence ends with an explicit separator so that the implicitly
    # concatenated literals do not run together in the rendered message.
    logger.warning(
        "fbgemm_gpu_genai package is required for this feature but import failed with exception: %s. "
        "Please install nightly builds of pytorch and fbgemm_gpu_genai build using this command and try again: "
        "pip3 install --force-reinstall --pre torch fbgemm-gpu-genai --index-url https://download.pytorch.org/whl/nightly/cu129 "
        "If errors persist, please file a bug report.",
        e,
    )
21+
22+
23+
@torch.library.custom_op("torchao::fbgemm_mxfp8_grouped_mm_2d_3d", mutates_args={})
24+
def fbgemm_mxfp8_grouped_mm_2d_3d(
25+
A_fp8: torch.Tensor,
26+
A_scales: torch.Tensor,
27+
B_fp8: torch.Tensor,
28+
B_scales: torch.Tensor,
29+
offs: torch.Tensor,
30+
block_size: int = 32,
31+
out_dtype: torch.dtype = torch.bfloat16,
32+
) -> torch.Tensor:
33+
assert A_fp8.ndim == 2, "A_fp8 tensor must be 2D"
34+
assert B_fp8.ndim == 3, "B_fp8 tensor must be 3D"
35+
assert block_size == 32, "Only block_size=32 is supported"
36+
assert out_dtype == torch.bfloat16, "Only out_dtype=bfloat16 is supported"
37+
assert A_fp8.shape[-1] == B_fp8.shape[-1], "A_fp8 and B_fp8 must have same last dim"
38+
39+
# Convert scales for each group to blocked format.
40+
Mg, K = A_fp8.shape
41+
A_scales_blocked, starting_row_after_padding = to_blocked_per_group_2d(
42+
A_scales, offs, Mg, K
43+
)
44+
B_scales_blocked = to_blocked_per_group_3d(B_scales)
45+
46+
# From this, we compute `group_sizes` and `starting_row_after_padding`:
47+
# group_sizes = [32, 32, 64]
48+
# starting_row_after_padding = [0, 32, 64, 128]
49+
zero = torch.tensor([0], dtype=offs.dtype, device=offs.device)
50+
group_sizes = torch.diff(offs, prepend=zero).to(torch.int64)
51+
52+
# TODO: remove debug logging once prototype is more mature.
53+
_log_inputs(
54+
A_fp8,
55+
B_fp8,
56+
A_scales,
57+
A_scales_blocked,
58+
B_scales,
59+
B_scales_blocked,
60+
offs,
61+
group_sizes,
62+
starting_row_after_padding,
63+
)
64+
65+
out = torch.ops.fbgemm.mx8mx8bf16_grouped_stacked(
66+
A_fp8,
67+
B_fp8,
68+
A_scales_blocked,
69+
B_scales_blocked,
70+
group_sizes,
71+
starting_row_after_padding=starting_row_after_padding,
72+
)
73+
return out
74+
75+
76+
@fbgemm_mxfp8_grouped_mm_2d_3d.register_fake
def _fbgemm_mxfp8_grouped_mm_2d_3d_fake(
    A_fp8: torch.Tensor,
    A_scales: torch.Tensor,
    B_fp8: torch.Tensor,
    B_scales: torch.Tensor,
    offs: torch.Tensor,
    block_size: int = 32,
    out_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Fake implementation of `fbgemm_mxfp8_grouped_mm_2d_3d`.

    Validates the arguments and returns an uninitialized tensor with the
    output's shape, dtype and device; no GEMM is computed.

    Args:
        A_fp8: 2D tensor of shape (Mg, K).
        A_scales: unused here; present to mirror the real op's signature.
        B_fp8: 3D tensor of shape (E, N, K).
        B_scales: unused here; present to mirror the real op's signature.
        offs: tensor with one entry per group; must have E elements.
        block_size: scaling block size; only 32 is supported.
        out_dtype: output dtype; only `torch.bfloat16` is supported.

    Returns:
        An empty (Mg, N) bfloat16 tensor on the device of `A_fp8`.

    Raises:
        AssertionError: if any of the shape / argument checks fails.
    """
    assert A_fp8.ndim == 2, "A_fp8 tensor must be 2D"
    assert B_fp8.ndim == 3, "B_fp8 tensor must be 3D"
    # Mirrors the real implementation so tracing rejects the same arguments
    # that would otherwise only fail later, when the real op runs.
    assert block_size == 32, "Only block_size=32 is supported"
    assert out_dtype == torch.bfloat16, "Only out_dtype=bfloat16 is supported"
    assert A_fp8.shape[-1] == B_fp8.shape[-1], "A_fp8 and B_fp8 must have same last dim"
    mg = A_fp8.shape[0]
    # K was already checked against A_fp8 above, so it is not needed again.
    e, n, _ = B_fp8.shape
    n_groups = offs.numel()
    assert n_groups == e, (
        "Size of `offs` (number of groups) must match first dim of `B_fp8`"
    )
    output = torch.empty((mg, n), dtype=torch.bfloat16, device=A_fp8.device)
    return output
98+
99+
100+
def _log_inputs(
    A_fp8: torch.Tensor,
    B_fp8: torch.Tensor,
    A_scales: torch.Tensor,
    A_scales_blocked: torch.Tensor,
    B_scales: torch.Tensor,
    B_scales_blocked: torch.Tensor,
    offs: torch.Tensor,
    group_sizes: torch.Tensor,
    starting_row_after_padding: torch.Tensor,
) -> None:
    """Log shape, stride and dtype of the grouped GEMM inputs at INFO level.

    For `offs`, `group_sizes` and `starting_row_after_padding` the tensor
    values themselves are included in the message as well.
    """
    # The messages below are f-strings, so they would be built even when
    # INFO logging is off; three of them embed full tensor reprs, which
    # requires reading the tensor values. Return early to skip that work
    # whenever nothing would be emitted.
    if not logger.isEnabledFor(logging.INFO):
        return

    logger.info(f"offs: {offs}, dtype: {offs.dtype}")
    logger.info(
        f"A_fp8.shape: {A_fp8.shape}, stride: {A_fp8.stride()}, dtype: {A_fp8.dtype}"
    )
    logger.info(
        f"B_fp8.shape: {B_fp8.shape}, stride: {B_fp8.stride()}, dtype: {B_fp8.dtype}"
    )
    logger.info(
        f"A_scales (non-blocked) shape: {A_scales.shape}, stride: {A_scales.stride()}, dtype: {A_scales.dtype}"
    )
    logger.info(
        f"A_scales_blocked.shape: {A_scales_blocked.shape}, stride: {A_scales_blocked.stride()}, dtype: {A_scales_blocked.dtype}"
    )
    logger.info(
        f"B_scales (non-blocked) shape: {B_scales.shape}, stride: {B_scales.stride()}, dtype: {B_scales.dtype}"
    )
    logger.info(
        f"B_scales_blocked.shape: {B_scales_blocked.shape}, stride: {B_scales_blocked.stride()}, dtype: {B_scales_blocked.dtype}"
    )
    logger.info(
        f"group_sizes: {group_sizes}, stride: {group_sizes.stride()}, dtype: {group_sizes.dtype}"
    )
    logger.info(
        f"starting_row_after_padding: {starting_row_after_padding}, stride: {starting_row_after_padding.stride()}, dtype: {starting_row_after_padding.dtype}"
    )

0 commit comments

Comments
 (0)