@danielvegamyhre danielvegamyhre commented Aug 28, 2025

Stacked PRs:


[mxfp8 moe training] add triton kernel for blocked swizzled 3d weight scales

Summary

  • Add kernel to convert 3d mxfp8 weight scales to blocked swizzled format
  • This kernel is straightforward: it is a modified version of the one used for 2d linear weights, with (1) different launch params to parallelize across groups, and (2) different stride calculations in the kernel, updating the base pointer to account for the group we're in.
  • I could have expanded the existing dense-model kernel to handle this, since there are no variable group sizes, but I didn't because (1) this kernel needs a different launch grid and additional params passed into it, so we'd end up with an awkward if-else launching different kernels with different grids and different params; and (2) I'd like to keep all the MoE training code in prototype/moe_training, to keep the code base as self-contained as possible for now.
  • Note that there are no data-dependent group sizes, so torch.compile should technically be competitive here; maybe we don't need this kernel?
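For reference, the layout conversion the kernel performs can be sketched host-side in NumPy (a reference implementation, not the Triton kernel itself; the exact 128x4-tile, 32x16-interleave pattern follows the standard blocked scale layout used for 2d mxfp8 scales, so treat the details as an assumption):

```python
import numpy as np

def ceil_div(a: int, b: int) -> int:
    return -(a // -b)

def to_blocked_2d(scales: np.ndarray) -> np.ndarray:
    """Rearrange a 2d (rows, cols) scale matrix into blocked swizzled form:
    pad to multiples of (128, 4), then store each 128x4 tile contiguously
    with its four 32-row chunks interleaved into a 32x16 block."""
    rows, cols = scales.shape
    n_row_blocks = ceil_div(rows, 128)
    n_col_blocks = ceil_div(cols, 4)
    padded = np.zeros((n_row_blocks * 128, n_col_blocks * 4), dtype=scales.dtype)
    padded[:rows, :cols] = scales
    # Split into (row_block, 128, col_block, 4) tiles, tiles outermost.
    tiles = padded.reshape(n_row_blocks, 128, n_col_blocks, 4).transpose(0, 2, 1, 3)
    # Within each tile, interleave the four 32-row chunks: (4, 32, 4) -> (32, 16).
    out = tiles.reshape(-1, 4, 32, 4).transpose(0, 2, 1, 3).reshape(-1, 32, 16)
    return out.reshape(-1)

def to_blocked_3d(scales: np.ndarray) -> np.ndarray:
    """3d version for per-expert weight scales of shape (n_groups, rows, cols).
    Group sizes are static, so this is just the 2d conversion applied per group,
    mirroring how the kernel offsets its base pointer by the group id times the
    dim0 stride."""
    return np.stack([to_blocked_2d(g) for g in scales]).reshape(scales.shape[0], -1)
```

The 3d case adds no new layout logic; it only changes where each group's conversion starts, which is why the kernel mainly differs in launch grid and base-pointer math.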

Test plan

  • pytest test/prototype/moe_training/test_kernels.py -k scales_3d

Benchmarks

  • The kernel is completely unoptimized and can be improved. NCU flags uncoalesced global accesses and partial waves; resolving these should provide a nice speedup.
    • As an aside, I've noticed that when Triton loads 2d blocks of data, uncoalesced loads from GMEM always seem to happen with row-major data but not column-major data. This surprises me: for a 2d block in either layout, we should always be able to get coalesced loads by assigning the threads in each warp to load along whichever dimension is contiguous. Not sure if the compiler isn't able to figure this out, or if I'm misunderstanding something.
input_shape        torch_time_us    triton_time_us    torch_mem_bw_gbps    triton_mem_bw_gbps  triton_speedup
---------------  ---------------  ----------------  -------------------  --------------------  ----------------
(1, 8192, 160)            19.424            10.112              134.959               259.241  1.92x
(2, 8192, 160)            16.384            10.432              320                   502.577  1.57x
(4, 8192, 160)            28.736            13.216              364.9                 793.414  2.17x
(8, 8192, 160)           133.184            20.512              157.463              1022.4    6.49x
(16, 8192, 160)          104.096            35.648              402.927              1176.59   2.92x
(1, 5120, 256)            26.56             10.08                98.699               260.063  2.63x
(2, 5120, 256)            25.248            10.688              207.655               490.539  2.36x
(4, 5120, 256)           104.512            14.464              100.331               724.956  7.23x
(8, 5120, 256)           544.8              22.56                38.494               929.589  24.15x
(16, 5120, 256)          134.144            38.912              312.672              1077.89   3.45x
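The coalescing question in the aside above can be made concrete with a tiny host-side model (purely illustrative; the thread-to-element mapping here is hypothetical, since the real assignment is chosen by the Triton compiler):

```python
# Toy model of warp-level coalescing for a 2d block load from a
# row-major parent tensor. Not real GPU code -- it just computes which
# element addresses one 32-thread warp would touch under two different
# (hypothetical) thread-to-element assignments.
BLOCK_ROWS, BLOCK_COLS = 32, 32
ROW_STRIDE = 1024  # elements per row of the row-major parent tensor

def warp_addresses(threads_along_row: bool) -> list[int]:
    """Addresses touched by one warp's 32 threads on their first access."""
    if threads_along_row:
        # thread t loads element (0, t): consecutive in row-major memory
        return [0 * ROW_STRIDE + t for t in range(32)]
    # thread t loads element (t, 0): each access strided by a full row
    return [t * ROW_STRIDE + 0 for t in range(32)]

def is_coalesced(addrs: list[int]) -> bool:
    """Idealized check: an access is coalesced when consecutive threads
    touch consecutive elements."""
    return all(b - a == 1 for a, b in zip(addrs, addrs[1:]))
```

The point of the aside is that for a 2d block this choice always exists: the compiler can walk threads along whichever dimension is contiguous, so in principle neither row-major nor column-major input should force uncoalesced loads.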

… scales

stack-info: PR: #2894, branch: danielvegamyhre/stack/63

pytorch-bot bot commented Aug 28, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2894

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

danielvegamyhre added a commit that referenced this pull request Aug 28, 2025
… scales

stack-info: PR: #2894, branch: danielvegamyhre/stack/63
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/63 branch from 0558c1c to 29aa1f2 Compare August 28, 2025 01:55
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Aug 28, 2025
@danielvegamyhre danielvegamyhre added mx topic: not user facing Use this tag if you don't want this PR to show up in release notes labels Aug 28, 2025
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/62 to main August 28, 2025 04:06
```python
pid_col = tl.program_id(2)

# Update base pointers based on this group id
input_ptr += pid_group * input_stride_dim0
```
Contributor


For this one we can just call the other triton kernel impl though, right?

Contributor Author

@danielvegamyhre danielvegamyhre Aug 28, 2025


We could potentially modify the other impl but the result would be rather ugly IMO, see below (from PR description, sorry it was a bit long):

"I could have potentially expanded the existing dense model kernel to do this since there's no variable group sizes, but i didn't because (1) kernel needs different launch grid and additional params passed into it, so we'd have a weird if-else launching different kernels with different grids and different params. And (2) I would like to keep all the MoE training code in the prototype/moe_training to make the code base as self-contained as possible for now."

What do you think?

Contributor


maybe... you could also just factor out the main kernel into a series of triton jit functions and then the kernel builds them up

Non-blocking.

Contributor Author


Hmm, that could work. Sounds good
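A rough sketch of that factoring in plain Python (hypothetical structure: in the real code `_convert_tile` would be a `@triton.jit` helper shared by both kernels, the group loop would be a `tl.program_id` axis, and the plain copy below stands in for the actual blocked-swizzle write):

```python
def _convert_tile(src, dst, base, rows, cols):
    # Shared body used by both entry points. Placeholder: a plain copy
    # stands in for the real blocked-swizzle write pattern.
    n = rows * cols
    dst[base:base + n] = src[base:base + n]

def convert_2d(src, dst, rows, cols):
    # Dense (2d) entry point: a single tile at offset 0.
    _convert_tile(src, dst, 0, rows, cols)

def convert_3d(src, dst, n_groups, rows, cols):
    # MoE (3d) entry point: the same body, with the base offset advanced
    # per group (in Triton, pid_group would come from tl.program_id).
    group_stride = rows * cols
    for pid_group in range(n_groups):
        _convert_tile(src, dst, pid_group * group_stride, rows, cols)
```

With this shape, the 2d and 3d kernels share all layout logic and differ only in their launch grid and base-offset computation.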

@danielvegamyhre danielvegamyhre merged commit 83a20c7 into main Aug 29, 2025
3 checks passed