-
Notifications
You must be signed in to change notification settings - Fork 20
[FA] Unifying TMA Kernels with Warp Specialization Flag #232
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
This pull request was exported from Phabricator. Differential Revision: D75308966 |
Summary: Separated the TMA kernel variant handling into distinct code paths rather than using a conditional parameter. Changed from a unified approach with a dynamic `is_warp_specialized` flag to explicit separate conditions for `tma` and `tma_ws` variants. This improves code clarity by making the execution path more explicit + makes it easier for compiler to optimize. Differential Revision: D75308966
Summary: This PR consolidates redundant TMA attention kernels into a unified implementation. Previously, `_attn_fwd_tma` and `_attn_fwd_tma_ws` contained duplicate code (mainly the TMA descriptors) and didn't leverage the existing `ENABLE_WS` flag. I've merged the redundant kernels into a single `_attn_fwd_tma_unified` kernel. We now use the `ENABLE_WS` flag to toggle between regular and warp-specialized execution. Changes: * Merged both kernels into `_attn_fwd_tma_unified` kernel with handling of regular and warp-specialized paths * Utilized existing `ENABLE_WS` parameter to control warp specialization * Unified TMA descriptor creation logic Differential Revision: D75307125
Summary: Separated the TMA kernel variant handling into distinct code paths rather than using a conditional parameter. Changed from a unified approach with a dynamic `is_warp_specialized` flag to explicit separate conditions for `tma` and `tma_ws` variants. This improves code clarity by making the execution path more explicit + makes it easier for compiler to optimize. Differential Revision: D75308966
a89de27
to
fba09d9
Compare
This pull request was exported from Phabricator. Differential Revision: D75308966 |
Thanks for working on this! The patch looks good. We have been focusing on non-causal, so maybe testing
We need WITH_TMA=1 to actually enable tma. |
Thanks for letting me know! I have tested it with I exported it from fbsource. For reference, the diff in fbsource is here (this specific PR is the first two diffs in the stack). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Summary:
This PR consolidates redundant TMA attention kernels into a unified implementation. Previously,
_attn_fwd_tma
and_attn_fwd_tma_ws
contained duplicate code (mainly the TMA descriptors) and didn't leverage the existingENABLE_WS
flag.I've merged the redundant kernels into a single
_attn_fwd_tma_unified
kernel. We now use theENABLE_WS
flag to toggle between regular and warp-specialized execution.Changes:
_attn_fwd_tma_unified
kernel with handling of regular and warp-specialized pathsENABLE_WS
parameter to control warp specializationTest Plan:
Unit Tests and Benchmarking
The performance metrics before and after the code change are identical (4.39805e+12 FLOPS).
Follow Up PR for Base + Opt Kernels: #233
Differential Revision: D75307125 and D75308966 (follow-up diff)