
[FA] Unifying TMA Kernels with Warp Specialization Flag #232


Open · wants to merge 2 commits into `main`

Conversation

@codingwithsurya commented May 27, 2025

Summary:
This PR consolidates the redundant TMA attention kernels into a unified implementation. Previously, `_attn_fwd_tma` and `_attn_fwd_tma_ws` contained duplicate code (mainly the TMA descriptor setup) and didn't leverage the existing `ENABLE_WS` flag.

I've merged the redundant kernels into a single `_attn_fwd_tma_unified` kernel. We now use the `ENABLE_WS` flag to toggle between regular and warp-specialized execution; a minimal sketch of this pattern follows the Changes list below.

Changes:

  • Merged both kernels into a single `_attn_fwd_tma_unified` kernel that handles both the regular and warp-specialized paths
  • Used the existing `ENABLE_WS` parameter to control warp specialization
  • Unified the TMA descriptor creation logic
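
A minimal sketch of this flag-based consolidation, for illustration only: it is not the actual tritonbench kernel. The body is placeholder arithmetic rather than flash attention, and the names `_attn_fwd_tma_unified_sketch` and `launch` are hypothetical; the point is just that a single `tl.constexpr` flag such as `ENABLE_WS` lets one jitted kernel compile into both the regular and warp-specialized variants.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _attn_fwd_tma_unified_sketch(X, Out, N,
                                 BLOCK: tl.constexpr,
                                 ENABLE_WS: tl.constexpr):
    # Placeholder body: the real kernel computes flash attention with TMA
    # loads; this only demonstrates the compile-time branch on ENABLE_WS.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(X + offs, mask=mask)
    if ENABLE_WS:
        y = x * 2.0  # stands in for the warp-specialized schedule
    else:
        y = x + 1.0  # stands in for the regular schedule
    tl.store(Out + offs, y, mask=mask)


def launch(x: torch.Tensor, variant: str) -> torch.Tensor:
    # One entry point: the variant name selects the constexpr value, so the
    # compiler still specializes each path separately.
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 128),)
    _attn_fwd_tma_unified_sketch[grid](x, out, n, BLOCK=128,
                                       ENABLE_WS=(variant == "tma_ws"))
    return out
```

For example, `launch(torch.randn(4096, device="cuda"), "tma")` and `launch(torch.randn(4096, device="cuda"), "tma_ws")` exercise the two specializations of the same source.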

Test Plan:
Unit Tests and Benchmarking

WITH_TMA=1  python run.py --op flash_attention --only triton_tutorial_flash_v2_tma_ws --num-inputs 1 --seq-len 8192 --metrics tflops --batch 8 --n-heads 16 --d-head 128
WITH_TMA=1  python run.py --op flash_attention --only triton_tutorial_flash_v2_tma --num-inputs 1 --seq-len 8192 --metrics tflops --batch 8 --n-heads 16 --d-head 128
python -m unittest test/test_gpu/main.py -k test_gpu_tritonbench_flash_attention

The performance metrics before and after the code change are identical (4.39805e+12 FLOPS).

Follow-up PR for Base + Opt Kernels: #233

Differential Revision: D75307125 and D75308966 (follow-up diff)

@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D75308966

@codingwithsurya changed the title from "Refactor TMA kernel variant handling for improved readability (#231)" to "[FA] Unifying TMA Kernels with Warp Specialization Flag" on May 27, 2025
facebook-github-bot pushed a commit that referenced this pull request May 27, 2025
Summary:


Separated the TMA kernel variant handling into distinct code paths rather than using a conditional parameter.

Changed from a unified approach with a dynamic `is_warp_specialized` flag to explicit, separate conditions for the `tma` and `tma_ws` variants. This improves code clarity by making the execution path more explicit and makes it easier for the compiler to optimize.

Differential Revision: D75308966
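
For readers skimming the thread, here is a hypothetical, purely illustrative sketch of the two dispatch styles the commit message above contrasts. The helper names (`attn_forward_unified`, `attn_forward_explicit`, `_run_tma`, `_run_tma_ws`) are made up and the bodies are stubs; the point is only why explicit per-variant branches can be clearer than threading a dynamic `is_warp_specialized` flag through a single call site.

```python
def _run_tma(q, k, v):
    # Stand-in for launching the regular TMA kernel path.
    return "tma"


def _run_tma_ws(q, k, v):
    # Stand-in for launching the warp-specialized TMA kernel path.
    return "tma_ws"


# Before: one call site decides the path through a dynamic flag.
def attn_forward_unified(variant, q, k, v):
    is_warp_specialized = variant == "tma_ws"
    return _run_tma_ws(q, k, v) if is_warp_specialized else _run_tma(q, k, v)


# After: each variant gets its own explicit, statically obvious branch.
def attn_forward_explicit(variant, q, k, v):
    if variant == "tma":
        return _run_tma(q, k, v)
    if variant == "tma_ws":
        return _run_tma_ws(q, k, v)
    raise ValueError(f"unknown TMA variant: {variant!r}")
```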

@manman-ren (Contributor)

Thanks for working on this! The patch looks good. We have been focusing on the non-causal case, so maybe test with:

WITH_TMA=1 CUDA_VISIBLE_DEVICES=5 python run.py --op flash_attention --only triton_tutorial_flash_v2_tma_ws --num-inputs 1 --seq-len 8192 --metrics tflops --batch 8 --n-heads 16 --d-head 128
WITH_TMA=1 CUDA_VISIBLE_DEVICES=5 python run.py --op flash_attention --only triton_tutorial_flash_v2_tma --num-inputs 1 --seq-len 8192 --metrics tflops --batch 8 --n-heads 16 --d-head 128

We need WITH_TMA=1 to actually enable TMA.
If you haven't imported the diff to fbsource yet, please do so; it will trigger tests with our internal Triton.

@codingwithsurya (Author) commented May 28, 2025

> Thanks for working on this! The patch looks good. We have been focusing on the non-causal case, so maybe test with:
>
> WITH_TMA=1 CUDA_VISIBLE_DEVICES=5 python run.py --op flash_attention --only triton_tutorial_flash_v2_tma_ws --num-inputs 1 --seq-len 8192 --metrics tflops --batch 8 --n-heads 16 --d-head 128
> WITH_TMA=1 CUDA_VISIBLE_DEVICES=5 python run.py --op flash_attention --only triton_tutorial_flash_v2_tma --num-inputs 1 --seq-len 8192 --metrics tflops --batch 8 --n-heads 16 --d-head 128
>
> We need WITH_TMA=1 to actually enable TMA. If you haven't imported the diff to fbsource yet, please do so; it will trigger tests with our internal Triton.

Thanks for letting me know! I have tested it with WITH_TMA=1 and it works. I have updated the test plan.

I exported it from fbsource. For reference, the diff in fbsource is here (this specific PR is the first two diffs in the stack).

@codingwithsurya codingwithsurya self-assigned this May 28, 2025
@mandroid6 (Contributor) left a comment

LGTM!
