[mxfp8 moe] add support for fbgemm 2d-3d mx8mx8bf16 grouped gemm #2848
Conversation
Helpful links: artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2848.
Dr. CI: as of commit f70cc90 with merge base 8722c0c, there is 1 unrelated failure (flaky on trunk); you can merge normally.
@@ -402,12 +400,30 @@ def backward(ctx, grad_out: torch.Tensor):
 def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
     A_mx: torch.Tensor,
     A_scale: torch.Tensor,
-    B_t_mx: torch.Tensor,
-    B_t_scale: torch.Tensor,
+    B_mx: torch.Tensor,
nit: the `_mx` suffix should denote a combination of raw data and scale; if `B_mx` is just the raw data, it's better to call it something else.
        blocked_scales: Tensor
        start_row_after_padding: Tensor of shape (num_groups,) which contains the start row after padding for each group.
    """
    from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import _to_blocked
is this the same function as the one we have in torchao?
torchao/prototype/mx_formats/utils.py, line 18 at 6f035e8: `def to_blocked(input_matrix, use_triton_kernel: bool = False) -> Tensor:`
Will test if they're the same and replace if we can. I was having trouble getting the kernel working without CUDA errors, so I was trying to minimize differences between the fbgemm unit test code and this torchao code path.
should be the exact same
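A quick way to check (a sketch; it assumes `_to_blocked` mirrors torchao's `to_blocked` in taking a 2D scale matrix and returning the blocked buffer, which has not been verified here):

```python
import torch
from torchao.prototype.mx_formats.utils import to_blocked
from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import _to_blocked

# Random e8m0-style scale bytes; rows/cols deliberately not multiples of (128, 4).
scales = torch.randint(0, 256, (300, 17), dtype=torch.uint8, device="cuda")

# If both helpers implement the same 128x4 blocked layout, the outputs should
# match bit-for-bit (flattened here to stay agnostic to the returned shape).
torch.testing.assert_close(to_blocked(scales).flatten(), _to_blocked(scales).flatten())
```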
stamping since this is prototype
Stacked PRs:
#2848 [mxfp8 moe] add support for fbgemm 2d-3d mx8mx8bf16 grouped gemm

Summary
- Add support for the fbgemm 2d-3d mx8mx8bf16 grouped GEMM, which computes, per group (see the reference sketch after this list):
  - forward: `output = input @ weight^T`
  - backward: `grad_input = grad_output @ weight`
- Add `to_blocked_per_group_2d` (for input scales) and `to_blocked_per_group_3d` (for weight scales). These are PyTorch reference implementations and are not performant; we can implement equivalent Triton kernels for them later (a per-group sketch follows this list).
- The start-row offsets for the `x` tensor's `x_scales` tensor must be of size `len(group_sizes) + 1`, where the first starting row is always 0 and each value corresponds to the starting row of group[i] in the `x_scales` tensor AFTER padding (see the padding sketch after this list).
- Update `_emulated_mxfp8_scaled_grouped_mm_2d_3d` to have the same function signature and input constraints as the fbgemm API.
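For reference, a minimal emulated 2d-3d grouped GEMM in plain PyTorch (a sketch with hypothetical names, not the PR's `_emulated_mxfp8_scaled_grouped_mm_2d_3d` signature; it skips quantization entirely and just shows the per-group `output = input @ weight^T` semantics):

```python
import torch

def grouped_mm_2d_3d_reference(
    A: torch.Tensor,     # (total_tokens, K): 2D stacked activations
    B: torch.Tensor,     # (num_groups, N, K): 3D expert weights
    offs: torch.Tensor,  # (num_groups,): end row of each group in A
) -> torch.Tensor:
    # out[start:end] = A[start:end] @ B[g]^T for each group g, in bf16.
    out = torch.empty(A.shape[0], B.shape[1], dtype=torch.bfloat16, device=A.device)
    start = 0
    for g in range(B.shape[0]):
        end = int(offs[g])
        out[start:end] = (A[start:end] @ B[g].t()).to(torch.bfloat16)
        start = end
    return out

# Two groups of 3 and 5 tokens, K=8, N=4.
A = torch.randn(8, 8)
B = torch.randn(2, 4, 8)
out = grouped_mm_2d_3d_reference(A, B, torch.tensor([3, 8]))
assert out.shape == (8, 4)
```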
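A sketch of what a per-group 2D blocking helper could look like, assuming torchao's `to_blocked` as the single-matrix primitive (the helper name here is illustrative, not the PR's implementation):

```python
import torch
from torchao.prototype.mx_formats.utils import to_blocked

def to_blocked_per_group_2d_sketch(
    scales: torch.Tensor,       # (total_rows, num_cols) e8m0 scales, uint8
    group_sizes: torch.Tensor,  # (num_groups,) rows per group
) -> torch.Tensor:
    # Convert each group's row-slice to the blocked layout independently,
    # then concatenate the (padded) per-group buffers.
    blocked, start = [], 0
    for size in group_sizes.tolist():
        blocked.append(to_blocked(scales[start:start + size]).flatten())
        start += size
    return torch.cat(blocked)
```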
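And a sketch of deriving the start rows after padding, assuming the blocked layout pads each group's scale rows up to a multiple of 128 (the exact padding rule in the PR may differ):

```python
import torch

def start_rows_after_padding(group_sizes: torch.Tensor) -> torch.Tensor:
    # The blocked scale layout tiles rows in chunks of 128, so each group's
    # rows are assumed to be padded up to the next multiple of 128.
    padded = ((group_sizes + 127) // 128) * 128
    starts = torch.zeros(len(group_sizes) + 1, dtype=torch.int64)
    starts[1:] = torch.cumsum(padded, dim=0)
    return starts

print(start_rows_after_padding(torch.tensor([100, 300, 50])))
# tensor([  0, 128, 512, 640])
```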
Test plan
- `pytest test/prototype/moe_training/test_scaled_grouped_mm.py -k mx`
- `pytest test/prototype/moe_training/test_training.py -k mx`