[mxfp8 moe training] add per group blocked scale kernels #2886
Conversation
🔗 Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2886
Note: links to docs will display an error until the docs builds have completed.
stack-info: PR: #2886, branch: danielvegamyhre/stack/62
# We track how many row blocks we have iterated through.
block_row_id = 0
current_start_row = input_group_start_row
while current_start_row < input_group_end_row:
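To illustrate what this loop iterates over, here is a minimal plain-Python sketch of the same per-group row-block walk. `BLOCK_ROWS` and the function name are illustrative assumptions, not the kernel's actual launch parameters:

```python
# Illustrative block size; the real kernel's row-block size may differ.
BLOCK_ROWS = 128

def iterate_row_blocks(input_group_start_row, input_group_end_row):
    """Yield (block_row_id, start_row) for each row block inside one
    group's row range, mirroring the kernel's while-loop above."""
    block_row_id = 0
    current_start_row = input_group_start_row
    while current_start_row < input_group_end_row:
        yield block_row_id, current_start_row
        block_row_id += 1
        current_start_row += BLOCK_ROWS

# A group spanning rows [0, 300) covers 3 row blocks of 128 rows each
# (the last one partially filled).
blocks = list(iterate_row_blocks(0, 300))
```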
Note for reviewer: I think we can probably do this without a loop, and just parallelize across row blocks as well (like in the original impl for dense models). Need to think about it some more.
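The loop-free variant suggested here would fold the row-block iteration into the launch grid, so each program handles exactly one (group, row_block) pair. A hypothetical plain-Python sketch of that mapping (function name, block size, and group bounds are all assumptions for illustration):

```python
BLOCK_ROWS = 128  # illustrative block size

def row_block_for_program(pid, group_start_rows, group_end_rows):
    """Map a flat program id to (group_id, start_row), mimicking a 1D
    grid that covers every (group, row_block) pair with no in-kernel loop."""
    for g, (start, end) in enumerate(zip(group_start_rows, group_end_rows)):
        # Number of row blocks needed to cover this group (ceil division).
        n_blocks = -(-(end - start) // BLOCK_ROWS)
        if pid < n_blocks:
            return g, start + pid * BLOCK_ROWS
        pid -= n_blocks
    raise IndexError("program id out of range")

# Two groups: rows [0, 300) -> 3 blocks, rows [300, 500) -> 2 blocks,
# so a grid of 5 programs covers everything.
mapping = [row_block_for_program(p, [0, 300], [300, 500]) for p in range(5)]
```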
Let's add that as a follow-up / TODO.
Stacked PRs:
[mxfp8 moe training] add per group blocked scale kernels for 2d input activations
Summary
`compute_per_group_blocked_scale_offsets` is just a couple of standard torch ops and shouldn't cause a d2h sync. There is probably still room for optimization here by doing this in the kernel somehow, but we'll take things one step at a time.
Test plan
pytest test/prototype/moe_training/test_kernels.py -k blocked
Performance