
Conversation

@danielvegamyhre (Contributor) commented Aug 23, 2025

Stacked PRs:

  • [moe fp8 training] test and bench new faster method for per group rowwise scaling

Summary

  • The per-group rowwise scaling kernel used in the backward pass for MoE fp8 rowwise training is slow, due to uncoalesced global memory accesses: the row-major layout of the tensor conflicts with the kernel's memory access pattern.
  • To get around this, we can compute per-group colwise scales on the transpose of the input and then transpose the result, which effectively computes per-group rowwise scales.
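The transpose trick above rests on a simple identity: rowwise reductions on `x` are colwise reductions on `x.T`. Below is a minimal NumPy sketch of that identity for per-group max-abs scale computation. The function names, the group-offset convention, and the fp8 scale formula (`fp8_max / amax`) are illustrative assumptions; the actual PR implements this on GPU tensors with Triton kernels.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # max representable magnitude in float8_e4m3 (assumed target format)

def per_group_rowwise_scales(x, group_offsets):
    """One scale per (row, group): amax over each group's column slice."""
    scales = np.empty((x.shape[0], len(group_offsets)))
    start = 0
    for g, end in enumerate(group_offsets):
        amax = np.abs(x[:, start:end]).max(axis=1)
        scales[:, g] = FP8_E4M3_MAX / np.clip(amax, 1e-12, None)
        start = end
    return scales

def per_group_colwise_scales(x, group_offsets):
    """One scale per (group, column): amax over each group's row slice."""
    scales = np.empty((len(group_offsets), x.shape[1]))
    start = 0
    for g, end in enumerate(group_offsets):
        amax = np.abs(x[start:end, :]).max(axis=0)
        scales[g, :] = FP8_E4M3_MAX / np.clip(amax, 1e-12, None)
        start = end
    return scales

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 12))
offsets = [4, 9, 12]  # hypothetical group boundaries

direct = per_group_rowwise_scales(x, offsets)
via_transpose = per_group_colwise_scales(x.T, offsets).T  # the transpose trick
assert np.allclose(direct, via_transpose)
```

On GPU, the colwise variant on the transposed (row-major) tensor reads contiguous memory, which is why the mathematically equivalent path is so much faster.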

Changes

  • Add tests and benchmarks for this transpose-based method.
  • Split the kernel benchmark scripts into separate rowwise and colwise per-group scaling scripts, rather than one combined script.
  • Update the memory bandwidth calculations.
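For context on the bandwidth columns in the table below, achieved memory bandwidth is total bytes moved divided by elapsed time. The sketch below shows the general formula; the byte accounting (which tensors are counted as read/written, and at what precision) is a hypothetical example, not the exact accounting used in the PR's benchmark scripts.

```python
def mem_bw_gbps(bytes_read: int, bytes_written: int, time_us: float) -> float:
    """Achieved memory bandwidth in GB/s: total bytes moved / elapsed seconds."""
    return (bytes_read + bytes_written) / 1e9 / (time_us * 1e-6)

# Hypothetical accounting for quantizing a (16640, 8192) bf16 tensor to fp8:
M, N = 16640, 8192
bytes_read = M * N * 2             # bf16 input, 2 bytes per element
bytes_written = M * N * 1 + M * 4  # fp8 output (1 byte/elem) + one fp32 scale per row
bw = mem_bw_gbps(bytes_read, bytes_written, time_us=768.416)

# Sanity check: moving 1 GB in 1 second is 1 GB/s.
assert abs(mem_bw_gbps(10**9, 0, 10**6) - 1.0) < 1e-9
```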

Benchmarks

  • Benchmarking shows the new transpose-based method is ~3x faster than the existing Triton rowwise kernel across all shapes tested.
| Mg,N | n_groups | torch_loop_time_us | triton_time_us | triton_transpose_us | torch_mem_bw_gbps | triton_mem_bw_gbps | triton_transpose_mem_bw_gbps | triton_speedup | triton_transpose_speedup |
|---|---|---|---|---|---|---|---|---|---|
| (16640, 8192) | 1 | 1705.38 | 2422.82 | 768.416 | 399.681 | 281.328 | 887.029 | 0.70x | 2.22x |
| (16640, 8192) | 16 | 3078.21 | 2310.27 | 686.912 | 221.59 | 295.246 | 992.993 | 1.33x | 4.48x |
| (16640, 8192) | 64 | 10639.2 | 2177.7 | 784.288 | 64.26 | 313.943 | 871.71 | 4.89x | 13.57x |


pytorch-bot commented Aug 23, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2863

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a1c0745 with merge base 253d65a:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the `CLA Signed` label on Aug 23, 2025
@danielvegamyhre force-pushed the danielvegamyhre/stack/57 branch from afd9cb6 to 7af9f68 on August 23, 2025 23:42
@danielvegamyhre force-pushed the danielvegamyhre/stack/57 branch from 7af9f68 to f5f64e0 on August 24, 2025 00:13
@danielvegamyhre added the `topic: not user facing` label on Aug 24, 2025
…wise scaling

stack-info: PR: #2863, branch: danielvegamyhre/stack/57
@danielvegamyhre force-pushed the danielvegamyhre/stack/57 branch from f5f64e0 to a1c0745 on August 24, 2025 01:08
@danielvegamyhre merged commit 8722c0c into main on Aug 27, 2025
18 checks passed