
[FA] Unify Base + Opt FWD kernels #233


Open

wants to merge 3 commits into main

Conversation

@codingwithsurya commented May 27, 2025

Summary:

I merged the "base" (`_attn_fwd`) and "opt" (`_attn_fwd_opt`) variants of the attention forward-pass kernel into a single `_attn_fwd_unified` kernel to streamline the codebase and avoid redundancy. The base and opt kernels were good candidates for merging because their code paths were nearly identical.
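
One common way to express this kind of merge in Triton is a `tl.constexpr` flag that keeps the shared body in one place while still letting the compiler specialize each variant. Below is a minimal, toy sketch of that pattern only (not the actual `_attn_fwd_unified` code; `USE_OPT_PATH`, `_fwd_unified_demo`, and the element-wise body are illustrative):

```python
# Illustrative only: a toy Triton kernel showing the constexpr-flag merge pattern.
# The flag name and the element-wise body are hypothetical stand-ins, not the
# real flash-attention kernel.
import torch
import triton
import triton.language as tl


@triton.jit
def _fwd_unified_demo(X, Y, n_elements,
                      BLOCK: tl.constexpr,
                      USE_OPT_PATH: tl.constexpr):  # hypothetical base-vs-opt switch
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(X + offs, mask=mask)
    if USE_OPT_PATH:
        y = x * 2.0  # stand-in for the "opt" code path
    else:
        y = x + 1.0  # stand-in for the "base" code path
    tl.store(Y + offs, y, mask=mask)


def fwd(x: torch.Tensor, use_opt: bool) -> torch.Tensor:
    # Because USE_OPT_PATH is a constexpr, Triton compiles a separate specialized
    # binary per value, so unifying the kernels adds no runtime branching.
    y = torch.empty_like(x)
    grid = (triton.cdiv(x.numel(), 1024),)
    _fwd_unified_demo[grid](x, y, x.numel(), BLOCK=1024, USE_OPT_PATH=use_opt)
    return y
```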

This PR also includes diffs from #232 because of the GitHub export from Phabricator; #232 should be merged before this one.

Test Plan:
Unit Tests and Benchmarking

python -m unittest test/test_gpu/main.py -k test_gpu_tritonbench_flash_attention

python run.py --op flash_attention --only triton_tutorial_flash_v2 --batch 4 --seq-len 16384 --n-heads 32 --d-head 64 --precision fp16 --causal --metrics flops

python run.py --op flash_attention --only triton_tutorial_flash_v2_opt --batch 4 --seq-len 16384 --n-heads 32 --d-head 64 --precision fp16 --causal --metrics flops

Differential Revision: D75388323

@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D75388323

Summary:

This PR consolidates redundant TMA attention kernels into a unified implementation. Previously, `_attn_fwd_tma` and `_attn_fwd_tma_ws` contained duplicate code (mainly the TMA descriptors) and didn't leverage the existing `ENABLE_WS` flag. 

I've merged the redundant kernels into a single `_attn_fwd_tma_unified` kernel. We now use the `ENABLE_WS` flag to toggle between regular and warp-specialized execution.

Changes:

*   Merged both kernels into a single `_attn_fwd_tma_unified` kernel that handles both the regular and warp-specialized paths
*   Used the existing `ENABLE_WS` parameter to control warp specialization
*   Unified TMA descriptor creation logic
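
As a rough sketch of the resulting launch shape (the descriptor helper and argument list below are hypothetical, not the actual tritonbench code):

```python
# Hypothetical launch-side sketch: the TMA descriptors are built once and the
# single unified kernel serves both variants, with ENABLE_WS selecting the
# warp-specialized path at compile time. make_tma_descriptors and the argument
# list are illustrative names.
def attn_fwd_tma(q, k, v, o, grid, variant: str):
    # shared setup, previously duplicated in _attn_fwd_tma and _attn_fwd_tma_ws
    desc_q, desc_k, desc_v, desc_o = make_tma_descriptors(q, k, v, o)
    _attn_fwd_tma_unified[grid](
        desc_q, desc_k, desc_v, desc_o,
        ENABLE_WS=(variant == "tma_ws"),  # constexpr toggle: regular vs. warp-specialized
    )
```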

Differential Revision: D75307125

Summary:


Separated the TMA kernel variant handling into distinct code paths rather than using a conditional parameter. 

Changed from a unified approach with a dynamic `is_warp_specialized` flag to explicit, separate conditions for the `tma` and `tma_ws` variants. This improves code clarity by making the execution path more explicit and makes it easier for the compiler to optimize.
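
A hypothetical sketch of the before/after dispatch shape (variant strings and argument lists are illustrative, not the actual tritonbench code):

```python
# Before: one launch site driven by a dynamic flag.
#
#     is_warp_specialized = (variant == "tma_ws")
#     _attn_fwd_tma_unified[grid](..., ENABLE_WS=is_warp_specialized)
#
# After: explicit, separate conditions per variant, so each launch site is fully static.
def launch(variant: str, grid, *args):
    if variant == "tma":
        return _attn_fwd_tma_unified[grid](*args, ENABLE_WS=False)
    elif variant == "tma_ws":
        return _attn_fwd_tma_unified[grid](*args, ENABLE_WS=True)
    raise ValueError(f"unexpected TMA variant: {variant!r}")
```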

Differential Revision: D75308966

Summary:

I merged the "base" (`_attn_fwd`) and "opt" (`_attn_fwd_opt`) variants of the attention forward-pass kernel into a single `_attn_fwd_unified` kernel to streamline the codebase and avoid redundancy. The base and opt kernels were good candidates for merging because their code paths were nearly identical.

I still need to integrate the warp-specialization path (`attn_fwd_ws`) after some bug fixing. For now it is still available as its own autotuned kernel.

Differential Revision: D75388323
@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D75388323

@codingwithsurya changed the title from "Unify Base and Opt Attention Kernels" to "[FA] Unify TMA attention kernels with Warp Spec flag and consolidate base/opt kernels" on May 27, 2025
@codingwithsurya changed the title from "[FA] Unify TMA attention kernels with Warp Spec flag and consolidate base/opt kernels" to "[FA] Unify TMA attention kernels with Warp Spec flag and Consolidate Base/Opt kernels" on May 27, 2025
@codingwithsurya changed the title from "[FA] Unify TMA attention kernels with Warp Spec flag and Consolidate Base/Opt kernels" to "[FA] Unify Base + Opt FWD kernels" on May 27, 2025
@codingwithsurya self-assigned this on May 28, 2025