Skip to content

Commit 6444a5e

Browse files
[mxfp8 moe] add support for fbgemm 2d-3d mx8mx8bf16 grouped gemm
stack-info: PR: #2848, branch: danielvegamyhre/stack/55
1 parent b663faf commit 6444a5e

File tree

6 files changed

+303
-53
lines changed

6 files changed

+303
-53
lines changed

test/prototype/moe_training/test_scaled_grouped_mm.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -230,25 +230,26 @@ def compute_reference_forward(
230230
@pytest.mark.parametrize("num_experts", (1, 8, 16))
231231
def test_emulate_mxfp8_grouped_gemm_2d_3d(M, K, N, num_experts):
232232
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
233-
w_t = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device="cuda")
233+
w = torch.randn(num_experts, N, K, dtype=torch.bfloat16, device="cuda")
234234
offs = generate_jagged_offs(num_experts, M)
235-
x_ref, w_t_ref, offs_ref = x.clone(), w_t.clone(), offs.clone()
235+
x_ref, w_ref, offs_ref = x.clone(), w.clone(), offs.clone()
236236

237237
# Quantize inputs to mxpf8 for emulated mxfp8 scaled grouped mm
238238
block_size = 32
239-
x_scale, x_mx = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
239+
x_scale, x_fp8 = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
240240

241241
# To cast B_t per-expert to mxfp8 across dim1, we transpose the experts, cast along dim -1, then untranspose.
242-
w_scale, w_mx = to_mx(
243-
w_t.transpose(-2, -1).contiguous(),
242+
w_scale, w_fp8 = to_mx(
243+
w,
244244
elem_dtype=torch.float8_e4m3fn,
245245
block_size=block_size,
246246
)
247-
w_t_scale, w_t_mx = w_scale.transpose(-2, -1), w_mx.transpose(-2, -1)
248247

249-
ref_out = torch._grouped_mm(x_ref, w_t_ref, offs=offs_ref, out_dtype=torch.bfloat16)
248+
ref_out = torch._grouped_mm(
249+
x_ref, w_ref.transpose(-2, -1), offs=offs_ref, out_dtype=torch.bfloat16
250+
)
250251
out = _emulated_mxfp8_scaled_grouped_mm_2d_3d(
251-
x_mx, x_scale, w_t_mx, w_t_scale, offs=offs, out_dtype=torch.bfloat16
252+
x_fp8, x_scale, w_fp8, w_scale, offs=offs, out_dtype=torch.bfloat16
252253
)
253254

254255
sqnr = compute_error(ref_out, out)
@@ -305,19 +306,27 @@ def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts):
305306

306307

307308
@skip_if_rocm("ROCm not supported")
308-
@pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
309-
@pytest.mark.parametrize("num_experts", (1, 8, 16))
309+
@pytest.mark.parametrize("M,K,N", [(256, 512, 512)])
310+
@pytest.mark.parametrize("num_experts", (2,))
310311
def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(M, K, N, num_experts):
311312
from torchao.prototype.moe_training.scaled_grouped_mm import (
312313
_MXFP8GroupedMM,
313314
)
314315

315316
block_size = 32
316317
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
317-
w_t = torch.randn(
318-
num_experts, K, N, dtype=torch.bfloat16, device="cuda", requires_grad=True
318+
w = torch.randn(
319+
num_experts,
320+
N,
321+
K,
322+
dtype=torch.bfloat16,
323+
device="cuda",
319324
)
320-
offs = generate_jagged_offs(num_experts, M, multiple_of=block_size)
325+
w_t = w.transpose(-2, -1).requires_grad_(True)
326+
# TODO: use non-uniform group sizes once kernel supports it
327+
# offs = generate_jagged_offs(num_experts, M, multiple_of=block_size)
328+
group_size = M // num_experts
329+
offs = torch.arange(group_size, M + 1, group_size, device="cuda", dtype=torch.int32)
321330
x_ref, w_t_ref, offs_ref = (
322331
x.clone().detach().requires_grad_(True),
323332
w_t.clone().detach().requires_grad_(True),

test/prototype/moe_training/test_training.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
129129
)
130130

131131

132+
@pytest.mark.skip(
133+
"temporarily disable until non-uniform group sizes are supported by mxfp8 grouped gemm"
134+
)
132135
@pytest.mark.parametrize(
133136
"target_fqns",
134137
[

torchao/prototype/moe_training/kernels/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,6 @@
77
from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
88
triton_fp8_per_group_rowwise_scales as triton_fp8_per_group_rowwise_scales,
99
)
10+
from torchao.prototype.moe_training.kernels.mxfp8 import (
11+
fbgemm_mxfp8_grouped_mm_2d_3d as fbgemm_mxfp8_grouped_mm_2d_3d,
12+
)
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
import logging
2+
3+
import torch
4+
5+
from torchao.prototype.mx_formats.utils import (
6+
to_blocked_per_group_2d,
7+
to_blocked_per_group_3d,
8+
)
9+
10+
logger: logging.Logger = logging.getLogger(__name__)
11+
12+
try:
13+
import fbgemm_gpu.experimental.gen_ai # noqa: F401
14+
except Exception as e:
15+
logging.warning(
16+
f"fbgemm_gpu_genai package is required for this feature but import failed with exception: {e}"
17+
"Please install nightly builds of pytorch and fbgemm_gpu_genai build using this command and try again: "
18+
"pip3 install --force-reinstall --pre torch fbgemm-gpu-genai --index-url https://download.pytorch.org/whl/nightly/cu129"
19+
"If errors persist, please file a bug report."
20+
)
21+
22+
23+
@torch.library.custom_op("torchao::fbgemm_mxfp8_grouped_mm_2d_3d", mutates_args={})
24+
def fbgemm_mxfp8_grouped_mm_2d_3d(
25+
A_fp8: torch.Tensor,
26+
A_scales: torch.Tensor,
27+
B_fp8: torch.Tensor,
28+
B_scales: torch.Tensor,
29+
offs: torch.Tensor,
30+
block_size: int = 32,
31+
out_dtype: torch.dtype = torch.bfloat16,
32+
) -> torch.Tensor:
33+
assert A_fp8.ndim == 2, "A_fp8 tensor must be 2D"
34+
assert B_fp8.ndim == 3, "B_fp8 tensor must be 3D"
35+
assert block_size == 32, "Only block_size=32 is supported"
36+
assert out_dtype == torch.bfloat16, "Only out_dtype=bfloat16 is supported"
37+
assert A_fp8.shape[-1] == B_fp8.shape[-1], "A_fp8 and B_fp8 must have same last dim"
38+
39+
# Convert scales for each group to blocked format.
40+
Mg, K = A_fp8.shape
41+
A_scales_blocked, starting_row_after_padding = to_blocked_per_group_2d(
42+
A_scales, offs, Mg, K
43+
)
44+
B_scales_blocked = to_blocked_per_group_3d(B_scales)
45+
46+
# From this, we compute `group_sizes` and `starting_row_after_padding`:
47+
# group_sizes = [32, 32, 64]
48+
# starting_row_after_padding = [0, 32, 64, 128]
49+
group_sizes = torch.diff(starting_row_after_padding).to(torch.int64)
50+
51+
# TODO: remove debug logging once prototype is more mature.
52+
53+
out = torch.ops.fbgemm.mx8mx8bf16_grouped_stacked(
54+
A_fp8,
55+
B_fp8,
56+
A_scales_blocked,
57+
B_scales_blocked,
58+
group_sizes,
59+
starting_row_after_padding=starting_row_after_padding,
60+
)
61+
return out
62+
63+
64+
@fbgemm_mxfp8_grouped_mm_2d_3d.register_fake
65+
def _fbgemm_mxfp8_grouped_mm_2d_3d_fake(
66+
A_fp8: torch.Tensor,
67+
B_fp8: torch.Tensor,
68+
A_scales: torch.Tensor,
69+
B_scales: torch.Tensor,
70+
offs: torch.Tensor,
71+
) -> torch.Tensor:
72+
assert A_fp8.ndim == 2, "A_fp8 tensor must be 2D"
73+
assert B_fp8.ndim == 3, "B_fp8 tensor must be 3D"
74+
mg, k = A_fp8.shape
75+
e, k, n = B_fp8.shape
76+
n_groups = offs.numel()
77+
assert n_groups == e, (
78+
"Size of `offs` (number of groups) must match first dim of `B_fp8`"
79+
)
80+
output = torch.empty((mg, n), dtype=torch.bfloat16, device=A_fp8.device)
81+
return output
82+
83+
84+
def _log_inputs(
85+
A_fp8: torch.Tensor,
86+
B_fp8: torch.Tensor,
87+
A_scales: torch.Tensor,
88+
A_scales_blocked: torch.Tensor,
89+
B_scales: torch.Tensor,
90+
B_scales_blocked: torch.Tensor,
91+
offs: torch.Tensor,
92+
group_sizes: torch.Tensor,
93+
starting_row_after_padding: torch.Tensor,
94+
):
95+
print(f"offs: {offs}")
96+
print("A_fp8.shape", A_fp8.shape, "stride", A_fp8.stride(), "dtype", A_fp8.dtype)
97+
print(
98+
"B_fp8.shape",
99+
B_fp8.shape,
100+
"stride",
101+
B_fp8.stride(),
102+
"dtype",
103+
B_fp8.dtype,
104+
)
105+
print(
106+
"A_scales (non-blocked)",
107+
A_scales.shape,
108+
"stride",
109+
A_scales.stride(),
110+
"dtype",
111+
A_scales.dtype,
112+
)
113+
print(
114+
"A_scaless_blocked.shape",
115+
A_scales_blocked.shape,
116+
"stride",
117+
A_scales_blocked.stride(),
118+
"dtype",
119+
A_scales_blocked.dtype,
120+
)
121+
print(
122+
"B_scales (non-blocked)",
123+
B_scales.shape,
124+
"stride",
125+
B_scales.stride(),
126+
"dtype",
127+
B_scales.dtype,
128+
)
129+
print(
130+
"B_scales_blocked.shape",
131+
B_scales_blocked.shape,
132+
"stride",
133+
B_scales_blocked.stride(),
134+
"dtype",
135+
B_scales_blocked.dtype,
136+
)
137+
print(
138+
"group_sizes",
139+
group_sizes,
140+
"group_sizes.stride",
141+
group_sizes.stride(),
142+
"dtype",
143+
group_sizes.dtype,
144+
)
145+
print(
146+
"starting_row_after_padding",
147+
starting_row_after_padding,
148+
"stride",
149+
starting_row_after_padding.stride(),
150+
"dtype",
151+
starting_row_after_padding.dtype,
152+
)

0 commit comments

Comments
 (0)