Conversation
@danielvegamyhre (Contributor) commented Aug 22, 2025

Stacked PRs:


[mxfp8 moe] add support for fbgemm 2d-3d mx8mx8bf16 grouped gemm

Summary

  • fbgemm recently added a 2d-3d mxfp8 grouped gemm in FBGEMM#4710 (Enable MXFP8 grouped GEMM)
  • This PR integrates the new gemm into the MoE training code base for the following 2d-3d grouped gemms:
    • output = input @ weight^T
    • grad_input = grad_output @ weight
  • Add new mxfp8 utils to_blocked_per_group_2d (for input scales) and to_blocked_per_group_3d (for weight scales). These are pytorch reference implementations that are not performant. We can implement equivalent triton kernels for them later.
  • Notes on fbgemm API and pytorch grouped mm API:
    • x must have shape (Mg, K) and be row major / contiguous
    • x scales must be preprocessed into the per-group blocked layout and be contiguous
    • weights must have shape (E, N, K) and be row major / contiguous
    • weight scales must be preprocessed into the per-group blocked layout and be contiguous
    • group_sizes is a vector containing the number of rows in each token group of the x tensor
    • starting_row_after_padding corresponds to the x_scales tensor and must have size len(group_sizes) + 1; the first entry is always 0, and entry i is the starting row of group i in the x_scales tensor AFTER padding
  • Refactor _emulated_mxfp8_scaled_grouped_mm_2d_3d to have the same function signature and input constraints as the fbgemm API
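To make the starting_row_after_padding constraint concrete, here is a minimal plain-Python sketch of how such a vector could be derived from the group sizes. The helper name and the alignment value of 128 (blocked scale layouts pad the row dimension to a multiple of the tile height) are assumptions for illustration, not the fbgemm API:

```python
def starting_rows_after_padding(group_sizes, alignment=128):
    # Hypothetical reference helper: each group's row count in the
    # x_scales tensor is padded up to a multiple of `alignment`.
    # Entry 0 is always 0; entry i is the start row of group i after
    # padding; the final entry is the total padded row count.
    starts = [0]
    for size in group_sizes:
        padded = ((size + alignment - 1) // alignment) * alignment
        starts.append(starts[-1] + padded)
    return starts

# e.g. token groups of 100, 300, and 50 rows, padded to multiples of 128:
print(starting_rows_after_padding([100, 300, 50]))  # [0, 128, 512, 640]
```

Note the output has len(group_sizes) + 1 entries, matching the constraint above.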

Test plan

  • pytest test/prototype/moe_training/test_scaled_grouped_mm.py -k mx
  • pytest test/prototype/moe_training/test_training.py -k mx
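For reference, the 2d-3d grouped gemm these tests exercise computes output = input @ weight^T independently per token group. A minimal plain-Python sketch (ignoring quantization and scales; function names are hypothetical) is:

```python
def matmul_t(a, b):
    # (m, k) @ (n, k)^T -> (m, n)
    return [[sum(x * y for x, y in zip(row, col)) for col in b] for row in a]

def grouped_mm_2d_3d(x, w, group_sizes):
    # x: (Mg, K) with token groups stacked along dim 0; w: (E, N, K).
    # Computes out[g] = x[g] @ w[g]^T per group, concatenated along dim 0.
    out, start = [], 0
    for g, size in enumerate(group_sizes):
        out.extend(matmul_t(x[start:start + size], w[g]))
        start += size
    return out

x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # Mg=3, K=2
w = [[[1.0, 2.0]], [[3.0, 4.0]]]          # E=2, N=1, K=2
print(grouped_mm_2d_3d(x, w, [2, 1]))     # [[1.0], [2.0], [7.0]]
```

The real kernel fuses this per-group loop into one GPU launch; the sketch only shows the shape semantics.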

pytorch-bot (bot) commented Aug 22, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2848

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ You can merge normally! (1 Unrelated Failure)

As of commit f70cc90 with merge base 8722c0c:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

danielvegamyhre added commits referencing this pull request (stack-info: PR: #2848, branch: danielvegamyhre/stack/55) and force-pushed the danielvegamyhre/stack/55 branch several times between Aug 22 and Aug 26, 2025 (a5a29db → cabb470)
@meta-cla bot added the CLA Signed label Aug 22, 2025
@danielvegamyhre added the topic: not user facing label Aug 23, 2025
@danielvegamyhre changed the title to [mxfp8 moe] add support for fbgemm 2d-3d mx8mx8bf16 grouped gemm, using uniform group sizes, then back to [mxfp8 moe] add support for fbgemm 2d-3d mx8mx8bf16 grouped gemm Aug 23, 2025

pytorch-bot (bot) commented Aug 25, 2025

This PR needs to be approved by an authorized maintainer before merge.
@danielvegamyhre (Contributor, Author) commented:

@drisspg @vkuzo this is working for all tested cases now and is ready for review/land

@@ -402,12 +400,30 @@ def backward(ctx, grad_out: torch.Tensor):
 def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
     A_mx: torch.Tensor,
     A_scale: torch.Tensor,
-    B_t_mx: torch.Tensor,
-    B_t_scale: torch.Tensor,
+    B_mx: torch.Tensor,
Contributor commented:
nit: _mx should be reserved for a combination of raw data and scale; if B_mx is just the data, it would be better to call it something else

blocked_scales: Tensor
start_row_after_padding: Tensor of shape (num_groups + 1,) which contains the start row after padding for each group.
"""
from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import _to_blocked
Contributor commented:
is this the same function as the one we have in torchao?

def to_blocked(input_matrix, use_triton_kernel: bool = False) -> Tensor:

Contributor (Author) replied:

Will test if they're the same and replace if we can. I was having trouble getting the kernel working without CUDA errors, so I was trying to minimize differences between the fbgemm unit test code and this torchao code path.

Contributor replied:
should be the exact same

@vkuzo (Contributor) left a comment:

stamping since this is prototype

stack-info: PR: #2848, branch: danielvegamyhre/stack/55
@danielvegamyhre force-pushed the danielvegamyhre/stack/55 branch from 0dd17b5 to f70cc90 on August 27, 2025 16:23
@danielvegamyhre merged commit 15a6de6 into main Aug 27, 2025
19 of 20 checks passed
Labels
ciflow/rocm · ciflow/4xh100 · CLA Signed · mx · topic: not user facing
Projects
None yet
3 participants