Conversation

danielvegamyhre
Contributor

@danielvegamyhre danielvegamyhre commented Aug 27, 2025

Stacked PRs:


[mxfp8 moe training] add per group blocked scale kernels for 2d input activations

Summary

  • We currently use a PyTorch loop-based implementation for converting mxfp8 scales to block swizzled format on a per-group basis.
  • This is suboptimal for performance and incurs a device-to-host (d2h) sync.
  • This PR implements a Triton kernel that performs the conversion without a d2h sync or host-side looping.
  • Note: to simplify the kernel, we pre-compute the start row of each group in the block-padded output scales tensor (see compute_per_group_blocked_scale_offsets). This takes just a couple of standard torch ops and shouldn't cause a d2h sync. There is probably still room for optimization by doing this in the kernel somehow, but we'll take things one step at a time.
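For illustration, the offset pre-computation described above can be sketched in pure Python (the real compute_per_group_blocked_scale_offsets uses torch ops; the 128-row block size here is an assumption based on the usual blocked-swizzle scale layout):

```python
def blocked_scale_group_start_rows(rows_per_group, block_rows=128):
    """Start row of each group in the block-row-padded output scales tensor.

    Each group's row count is rounded up to a multiple of block_rows, and
    the running sum gives the start row of the next group.
    """
    starts, acc = [], 0
    for rows in rows_per_group:
        starts.append(acc)
        # ceil-divide, then pad this group's rows up to a multiple of block_rows
        acc += -(-rows // block_rows) * block_rows
    return starts

print(blocked_scale_group_start_rows([100, 300, 50]))  # → [0, 128, 512]
```

In the real kernel these offsets are computed once on device (a cumsum over padded group sizes), so no d2h transfer is needed to launch the per-group work.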

Test plan

  • pytest test/prototype/moe_training/test_kernels.py -k blocked

Performance

  • Low memory bandwidth utilization, but 14x faster than the existing torch implementation:
input_shape      torch_time_us    triton_time_us    torch_mem_bw_gbps    triton_mem_bw_gbps  triton_speedup
-------------  ---------------  ----------------  -------------------  --------------------  ----------------
(16640, 160)           866.848            60.416                6.261                89.831  14.35x
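For reference, the bandwidth columns are presumably total bytes moved divided by kernel time; a rough sketch (the exact byte accounting, e.g. padding in the blocked output, is an assumption, which is why the figure below only approximates the table's 89.8 GB/s):

```python
def mem_bw_gbps(total_bytes: int, time_us: float) -> float:
    # bytes / (time_us * 1e-6 s) / 1e9 bytes-per-GB == bytes / (time_us * 1e3)
    return total_bytes / (time_us * 1e3)

# A (16640, 160) e8m0 scale tensor (1 byte/elem) read once and written once,
# ignoring output padding:
total = 2 * 16640 * 160
print(round(mem_bw_gbps(total, 60.416), 1))  # → 88.1
```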


pytorch-bot bot commented Aug 27, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2886

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

# We track how many row blocks we have iterated through.
block_row_id = 0
current_start_row = input_group_start_row
while current_start_row < input_group_end_row:
Contributor Author

@danielvegamyhre commented Aug 27, 2025
Note for reviewer: I think we can probably do this without a loop, and just parallelize across row blocks as well (like in the original impl for dense models). Need to think about it some more.
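In Python terms, the loop in question walks fixed-size row blocks within a single group's row range (a simplified sketch; names are illustrative and the 128-row block size is an assumption):

```python
def group_row_blocks(group_start_row, group_end_row, block_rows=128):
    """Yield (block_row_id, start_row, end_row) for each row block of one group."""
    block_row_id = 0
    current_start_row = group_start_row
    while current_start_row < group_end_row:
        yield (block_row_id,
               current_start_row,
               min(current_start_row + block_rows, group_end_row))  # last block may be partial
        block_row_id += 1
        current_start_row += block_rows

print(list(group_row_blocks(0, 300)))  # → [(0, 0, 128), (1, 128, 256), (2, 256, 300)]
```

The parallelization idea in the comment would replace this serial walk with one program instance per row block, as in the dense-model kernel.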

Contributor
Let's add as a follow-up / TODO.

@danielvegamyhre danielvegamyhre merged commit 4ecc89e into main Aug 28, 2025
13 of 17 checks passed
Labels: CLA Signed, mx, topic: not user facing