[mxfp8 moe training] add triton kernel for blocked swizzled 3d weight scales #2894
Conversation
pid_col = tl.program_id(2)

# Update base pointers based on this group id
input_ptr += pid_group * input_stride_dim0
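The snippet above is the core of the 3d launch-grid scheme: each program advances the base pointer by `pid_group * input_stride_dim0` to land on its own group's 2d slice. A minimal NumPy sketch of this indexing (emulating the pointer arithmetic on a flat buffer; shapes and names here are illustrative, not taken from the kernel):

```python
import numpy as np

# Hedged sketch: for a contiguous 3d tensor of shape (E, M, K), advancing the
# base pointer by pid_group * stride(0) selects that group's 2d slice, which is
# why the 3d grid only needs this one per-program offset.
E, M, K = 4, 8, 16
x = np.arange(E * M * K, dtype=np.float32).reshape(E, M, K)
flat = x.reshape(-1)                   # stand-in for the raw device pointer
stride0 = x.strides[0] // x.itemsize   # elements between consecutive groups

for pid_group in range(E):             # one "program" per group
    base = pid_group * stride0         # input_ptr += pid_group * input_stride_dim0
    group = flat[base : base + M * K].reshape(M, K)
    assert np.array_equal(group, x[pid_group])
```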
For this one we can just call the other Triton kernel impl though, right?
We could potentially modify the other impl, but the result would be rather ugly IMO; see below (from the PR description, sorry it was a bit long):
"I could have potentially expanded the existing dense model kernel to do this, since there are no variable group sizes, but I didn't because (1) the kernel needs a different launch grid and additional params passed into it, so we'd have a weird if-else launching different kernels with different grids and different params; and (2) I would like to keep all the MoE training code in prototype/moe_training to make the code base as self-contained as possible for now."
What do you think?
Maybe... you could also just factor out the main kernel into a series of Triton JIT functions and then have the kernel build them up.
Non-blocking.
Hmm, that could work. Sounds good.
Stacked PRs:
[mxfp8 moe training] add triton kernel for blocked swizzled 3d weight scales

Summary

Adds a Triton kernel that converts 3d weight scales to blocked swizzled format for mxfp8 MoE training. All MoE training code is kept in prototype/moe_training to make the code base as self-contained as possible for now.

Test plan

pytest test/prototype/moe_training/test_kernels.py -k scales_3d
Benchmarks
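No benchmark numbers are included in this excerpt. As context for what the kernel produces, a blocked swizzled scale layout can be sketched in NumPy. This sketch loosely follows the 128x4-tile rearrangement commonly used for mx scale factors (as in torchao's `to_blocked` helper); the tile dimensions and exact rearrange steps here are assumptions for illustration, not taken from this PR's Triton kernel, and the real kernel of course does this on-device rather than via host-side reshapes.

```python
import numpy as np

def to_blocked_sketch(scales: np.ndarray) -> np.ndarray:
    """Hedged sketch of a blocked swizzle for a 2d scale matrix.

    Pads to multiples of 128 rows x 4 cols, tiles into 128x4 blocks, then
    rearranges each block into 32x16 chunks. Tile sizes are assumed.
    """
    rows, cols = scales.shape
    n_row_blocks = -(-rows // 128)  # ceil division
    n_col_blocks = -(-cols // 4)
    padded = np.zeros((n_row_blocks * 128, n_col_blocks * 4), dtype=scales.dtype)
    padded[:rows, :cols] = scales
    # Group into (row_block, col_block, 128, 4) tiles.
    blocks = padded.reshape(n_row_blocks, 128, n_col_blocks, 4).transpose(0, 2, 1, 3)
    # Swizzle each 128x4 tile into 32x16 chunks.
    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(0, 2, 1, 3).reshape(-1, 32, 16)
    return rearranged.reshape(-1)
```

For a 3d weight-scale tensor of shape (E, M, K), the per-group version would apply this layout independently to each of the E slices, which is what the 3d launch grid in this PR parallelizes over.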