WIP stream k scheduling #67

Draft: wants to merge 104 commits into base: main

Changes from all commits (104 commits)
6b1d059
Support ROCM builds from source distribution, and improve error handl…
mgorny Jan 18, 2025
cd393e0
[Build] Update version of setuptools used to generate core package (#…
tmm1 Jan 29, 2025
bb135af
Don't compile for CUDA 11, compile for official pytorch 2.6.0
tridao Jan 29, 2025
979702c
Bump to v2.7.4
tridao Jan 29, 2025
5231d95
Drop Pytorch 2.1
tridao Jan 29, 2025
454ce31
[FA3] Compile with nvcc 12.8 instead of 12.3
tridao Jan 29, 2025
803f609
Fix comment in assert
tridao Jan 30, 2025
02541ac
[CE] Assert logit_scale > 0
tridao Jan 30, 2025
2a20412
Implement HeadDim_V != HeadDim_QK, support hdimQK=192, hdimV=128
tridao Feb 3, 2025
6d199aa
Fix shape_O in epilogue params when kHeadDimV != kHeadDim
tridao Feb 4, 2025
86bcd05
Remove old combine.h
tridao Feb 4, 2025
e3b2400
Fix loading paged V when kHeadDimV != kHeadDim
tridao Feb 4, 2025
9e07d6d
Fix shape_V for storing new KV when kHeadDimV != kHeadDim
tridao Feb 4, 2025
f0f2523
Implement the case of LargeHeadDimV
tridao Feb 4, 2025
4c8819d
Rename Mma0->MmaQK, Mma1->MmaPV, use Cluster only if hdimV >= 192
tridao Feb 7, 2025
dd87691
Pass _1 or _0 to cute::aligned_struct
tridao Feb 8, 2025
ed53b5f
Fix compilation for FP8 when kHeadDimV != kHeadDim
tridao Feb 8, 2025
4e8496a
Support Qv
tridao Feb 8, 2025
893a22a
Test varlen_q=True by default for kvcache
tridao Feb 8, 2025
5fab938
Fix num_splits heuristic being called before get_pack_gqa
tridao Feb 8, 2025
5fc5ebf
Fix num_splits heuristic again when PackGQA
tridao Feb 8, 2025
5378bc3
Tile fwd_combine kernel along headdim, don't need kBlockM > 128
tridao Feb 8, 2025
db8ca79
Use bf16 instead of fp16 in benchmark_gemm.py
tridao Feb 9, 2025
982c480
Update Cutlass to 3.7
tridao Feb 9, 2025
58ebfa5
Use nvcc 12.6 but ptxas 12.8
tridao Feb 9, 2025
ed435c6
cicc uses the same version as ptxas
tridao Feb 9, 2025
8668823
Split hdimdiff into a separate translation unit
tridao Feb 9, 2025
b2fc79d
Update benchmark script
tridao Feb 9, 2025
c091545
Update Cutlass to 3.8
tridao Feb 9, 2025
5e39b10
Adjust tile size for hdim 64
tridao Feb 9, 2025
1a7f4df
Adjust ninja build file
tridao Feb 10, 2025
15cf7ee
Rename collective_mainloop -> mainloop, move tile_scheduler variable
tridao Feb 11, 2025
9f313c7
Move functions getting number of m/n blocks to a separate file
tridao Feb 12, 2025
eafd53c
Update cutlass 3.8 to fix error w cudaGetDriverEntryPointByVersion
tridao Feb 12, 2025
fa445ff
Fix FP8 test
tridao Feb 12, 2025
a09abcd
make seqused optional on top level interface (#1497)
vasqu Feb 16, 2025
40cbd52
Temporarily change package name of FA3 to allow FA2 & FA3 install
tridao Feb 18, 2025
91917b4
Update benchmark_split_kv.py to work w new API
tridao Feb 18, 2025
ea3ecea
Add tp_degree to benchmark_split_kv
tridao Feb 18, 2025
74dfa43
Fix divide by 0 in causal tile_scheduler for large seqlen
tridao Feb 19, 2025
b36ad4e
Use split for super long sequences that don't fit into L2
tridao Feb 19, 2025
ecdb528
Make rotary test optional in FA3
tridao Feb 22, 2025
06e34f6
Enable MLA flag in FA3 (rope=64, latent=512) (#1504)
tzadouri Feb 23, 2025
6aed835
Add simple script to benchmark MLA decode
tridao Feb 24, 2025
6752d62
Add dynamic splits
tridao Feb 24, 2025
cdda5bf
Update to Cutlass 3.8.0 tag
tridao Feb 24, 2025
9505c74
Adjust seqlen_q in MLA decode benchmark script
tridao Feb 24, 2025
3b5047d
Fix loop in prepare_scheduler.cu (h/t Jay Shah)
tridao Feb 25, 2025
dec83a1
fix: add "typename" prior to dependent type name (#1517)
zhiweij1 Feb 28, 2025
08f4c80
Add FLOPS to MLA decode benchmark
tridao Feb 28, 2025
085ce58
Change margin in prepare_scheduler.cu from 20% to 10%
tridao Feb 28, 2025
39e7197
Fix cuda 12.1 build (#1511)
LucasWilkinson Mar 1, 2025
20b84d6
Don't use IntraWGOverlap for hdim 64,512
tridao Mar 2, 2025
5458c78
Remove sink token
tridao Mar 2, 2025
6865e60
fix: prompt index to type longlong to avoid numerical overflow (#1500)
xin-w8023 Mar 2, 2025
45c48af
Add option for WG1 to use RS MMA but WG2 using SS MMA
tridao Mar 4, 2025
3edf7e0
Add kwargs to _write_ninja_file for compatibility with new torch
tridao Mar 4, 2025
4f0640d
Move writing P to smem as separate function
tridao Mar 5, 2025
d82bbf2
Fix causal scheduler not considering hdim_v != hdim
tridao Mar 5, 2025
9c036e4
Always split fwd_combine_kernel on batch
tridao Mar 7, 2025
81643fa
For each batch, if num_splits=1, write to O instead of O_partial
tridao Mar 8, 2025
1d30bb4
Enable TMA when page size is a multiple of kBlockN
tridao Mar 9, 2025
a3a9cc5
Update ptxas to 12.8.93 (i.e. 12.8.1)
tridao Mar 9, 2025
322bec9
Use tile size 192 x 128 for hdim 64 causal
tridao Mar 9, 2025
5639b9d
Update benchmark_mla_decode.py
tridao Mar 9, 2025
48b3acb
Benchmark MHA, GQA, MQA, MLA in the same script
tridao Mar 11, 2025
d904855
Benchmark FlashMLA if it's available
tridao Mar 11, 2025
cdaf2de
Run all 4 attn variants in benchmark
tridao Mar 12, 2025
cf1b809
Move scheduler.get_next_work to before the epilogue
tridao Mar 12, 2025
3cf8998
Enable Cluster for hdim128 back
tridao Mar 12, 2025
6063dc5
Move tOrO init in mainloop
tridao Mar 12, 2025
430954a
Adjust heuristic for get_pagedkv_tma
tridao Mar 12, 2025
000090d
Enable PDL
tridao Mar 13, 2025
46e1d4a
Simplify prepare_varlen_num_blocks_kernel, restrict to batch <= 992
tridao Mar 13, 2025
897c845
Fix: num_splits_dynamic_ptr needs to be set before get_num_splits
tridao Mar 14, 2025
90f27a2
Loop on num_splits instead of parameterizing it in kvcache test
tridao Mar 15, 2025
fa60e7c
Add option to precompute scheduler metadata
tridao Mar 15, 2025
6c87fac
Update MLA decode benchmark to use get_scheduler_metadata
tridao Mar 15, 2025
4b5eeab
Fix FP8 test to quantize KV cache for reference impl as well
tridao Mar 15, 2025
27f501d
Dynamic autotune configs for devices with warp size != 32 (#1534)
schung-amd Mar 15, 2025
7ae5f8c
Add option for rotary_seqlens
tridao Mar 21, 2025
fef4fcf
Use StreamkBarrier0/1 barriers instead of TileCountSmemEmpty/Full
tridao Mar 22, 2025
b1951a4
Update Cutlass to 3.9
tridao Mar 22, 2025
df11fca
Support hdim 64,256
tridao Mar 22, 2025
f6a294a
Update benchmark with GLA
tridao Mar 22, 2025
29ef580
Adjust warp scheduler sync for HasQv case
tridao Mar 22, 2025
2f9ef08
num_head -> args.num_head (#1552)
yeqcharlotte Mar 25, 2025
1a58058
Fix zeroing out the scheduler semaphore when reusing metadata
tridao Mar 29, 2025
2dd8078
fix deprecation warning for newer torch versions (#1565)
vasqu Apr 1, 2025
7ff1b62
Don't use FusedDense anymore to simplify code
tridao Apr 7, 2025
aa04de6
Fix FA3 qkvpacked interface
tridao Apr 7, 2025
2afa43c
Launch more thread blocks in layer_norm_bwd
tridao Apr 8, 2025
9f2d2ae
check valid tile before storing num_splits in split_idx (#1578)
jayhshah Apr 9, 2025
d836a6b
Tune rotary kernel to use 2 warps if rotary_dim <= 64
tridao Apr 9, 2025
0ffdc94
Merge remote-tracking branch 'upstream/main' into lwilkinson/upstream…
LucasWilkinson Apr 9, 2025
8046156
update api
LucasWilkinson Apr 9, 2025
70cd625
single wg for decode
LucasWilkinson Apr 11, 2025
65c54ad
disable masking for pure decode
LucasWilkinson Apr 11, 2025
62c987b
Seperate out `get_n_block_min_max`
LucasWilkinson Apr 14, 2025
57635db
add `TileSchedulerCommon`
LucasWilkinson Apr 14, 2025
5e026f6
build
LucasWilkinson Apr 15, 2025
626a871
building
LucasWilkinson Apr 18, 2025
8a9354c
Merge remote-tracking branch 'origin/main' into lwilkinson/stream-k-s…
LucasWilkinson Apr 26, 2025
3552c40
streamk working
LucasWilkinson May 3, 2025
7 changes: 6 additions & 1 deletion CMakeLists.txt
@@ -226,6 +226,7 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
SOURCES
hopper/flash_fwd_combine.cu
hopper/flash_prepare_scheduler.cu
hopper/streamk.cu
hopper/flash_api.cpp
hopper/flash_api_torch_lib.cpp
${FA3_GEN_SRCS}
@@ -244,11 +245,15 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
FLASHATTENTION_DISABLE_BACKWARD
FLASHATTENTION_DISABLE_DROPOUT
# FLASHATTENTION_DISABLE_ALIBI
# FLASHATTENTION_DISABLE_SOFTCAP
FLASHATTENTION_DISABLE_SOFTCAP
FLASHATTENTION_DISABLE_UNEVEN_K
# FLASHATTENTION_DISABLE_LOCAL
FLASHATTENTION_DISABLE_PYBIND
FLASHATTENTION_VARLEN_ONLY # Custom flag to save on binary size
FLASHATTENTION_DISABLE_HDIM64
FLASHATTENTION_DISABLE_HDIM96
FLASHATTENTION_DISABLE_HDIM192
FLASHATTENTION_DISABLE_HDIM256
)
elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 12.0)
message(STATUS "FA3 is disabled because CUDA version is not 12.0 or later.")
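The new FLASHATTENTION_DISABLE_HDIM64/96/192/256 and FLASHATTENTION_DISABLE_SOFTCAP definitions trim which kernel variants are compiled into this build to keep binary size down. As a rough illustration only (the actual head-dim dispatch lives in the hopper sources and is not part of this diff), such a definition is consumed as a preprocessor guard, the same pattern `get_max_headdim` uses further down in `flash_api.cpp`:

```cpp
// Sketch only: how a FLASHATTENTION_DISABLE_* compile definition typically
// gates a code path; the real dispatch in the hopper sources is more involved.
inline bool headdim64_compiled_in() {
#ifdef FLASHATTENTION_DISABLE_HDIM64
    return false;   // hdim-64 kernels were excluded from this build
#else
    return true;
#endif
}
```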
45 changes: 21 additions & 24 deletions hopper/epilogue_fwd.hpp
@@ -217,18 +217,16 @@ struct CollectiveEpilogueFwd {
SharedStorage& shared_storage,
TiledMma tiled_mma,
int thread_idx,
cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord
BlockCoord<false> const& block_coord
) {

auto [m_block, bidh, bidb, split_idx] = block_coord;
int num_splits = get<4>(params.shape_O_packed);
if constexpr (Split && Varlen) {
uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // first 16 bits are for num_splits
int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);
num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx
}
bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1);
int const m_block = block_coord.m_block;
int const bidh = block_coord.bidh;
int const bidb = block_coord.bidb;
int const peer_id = block_coord.peer_id;
int const num_peers = block_coord.num_peers;

bool const is_split = !Split ? false : (!Varlen ? true : num_peers > 1);

Tensor sO = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_o.data()), SmemLayoutO{});
// Tensor sO_pi = cute::as_position_independent_swizzle_tensor(sO);
@@ -292,7 +290,7 @@ struct CollectiveEpilogueFwd {

Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)),
params.shape_LSE_packed,
!is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx);
!is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : peer_id);
// if (thread_idx == 0) { printf("Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf("\n"); }
if (!LargeHeadDimV || warp_group_idx == 0) {
if constexpr (!PackGQA) {
@@ -308,7 +306,7 @@

// Step 3: Write O from smem -> gmem
if constexpr (Use_TMA_O) {
Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, split_idx);
Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, peer_id);
Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
auto block_tma_O = params.tma_store_O.get_slice(_0{});
Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K)
@@ -361,7 +359,7 @@ struct CollectiveEpilogueFwd {
PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
}
} else {
Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset_o * get<0>(params.stride_O_partial)), params.shape_O_packed, params.stride_O_partial_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx);
Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset_o * get<0>(params.stride_O_partial)), params.shape_O_packed, params.stride_O_partial_packed)(_, _, bidh, !is_varlen ? bidb : 0, peer_id);
Tensor gOpartial = local_tile(mOpartial, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
// We already arrived on barrier_O earlier if !Use_smem
if constexpr (Use_smem) {
@@ -410,18 +408,17 @@
store_zero(
Params const& params,
int thread_idx,
cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord
BlockCoord<false> const& block_coord
) {
int const m_block = block_coord.m_block;
int const bidh = block_coord.bidh;
int const bidb = block_coord.bidb;
int const peer_id = block_coord.peer_id;
int const num_peers = block_coord.num_peers;

static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});
auto [m_block, bidh, bidb, split_idx] = block_coord;
int num_splits = get<4>(params.shape_O_packed);
if constexpr (Split && Varlen) {
uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // first 16 bits are for num_splits
int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);
num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx
}
bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1);

bool const is_split = !Split ? false : (!Varlen ? true : num_peers > 1);

flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused};
bool const is_varlen = Varlen && params.cu_seqlens;
@@ -430,7 +427,7 @@
int qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor;
Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)),
params.shape_LSE_packed,
!is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx);
!is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : peer_id);
Tensor gLSE = local_tile(mLSE, Shape<Int<kBlockM>>{}, make_coord(m_block));

static_assert(kBlockM <= NumEpilogueThreads);
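The epilogue now takes its coordinates as a `BlockCoord<false>` struct instead of a packed `cute::tuple` whose `split_idx` carried a dynamic split count in its upper 16 bits; `peer_id`/`num_peers` make that information explicit, so `is_split` reduces to `num_peers > 1` in the varlen case. The struct itself is defined in the stream-K headers added by this PR and is not shown in this diff; the sketch below is inferred purely from the five fields accessed here, and the meaning of the bool template parameter is an assumption.

```cpp
// Sketch only: field set inferred from the accesses in epilogue_fwd.hpp
// (m_block, bidh, bidb, peer_id, num_peers). The real definition may carry
// more state; the template parameter is assumed to select an extended variant.
template <bool Extended>
struct BlockCoord {
    int m_block;    // query-tile index along M
    int bidh;       // head index
    int bidb;       // batch index
    int peer_id;    // which partial result this CTA produces for the tile
    int num_peers;  // how many CTAs cooperate on this (m_block, bidh, bidb) tile
};
```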
9 changes: 9 additions & 0 deletions hopper/flash.h
@@ -7,6 +7,8 @@
#include <cuda.h>
#include <vector>

#include "streamk.h"

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Qkv_params {
@@ -149,6 +151,7 @@ struct Flash_fwd_params : public Qkv_params {

int num_splits; // For split-KV version
bool pack_gqa;
bool use_one_mma_wg;

int * __restrict__ tile_count_semaphore;
// int * __restrict__ num_m_blocks_ptr;
@@ -158,6 +161,12 @@

int arch;
int num_sm;

// Streamk stuff
int * __restrict__ sm_work_tile_ind_ptr = nullptr;
StreamKWorkTile * __restrict__ work_tiles_ptr = nullptr;
StreamKCombineTile * __restrict__ combine_tiles_ptr = nullptr;
StreamKSchedulerDescisions const* host_scheduler_metadata_ptr = nullptr;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
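`flash.h` now includes `streamk.h` and adds raw pointers to the stream-K scheduler output. The descriptor types are declared there rather than in this diff; the sketch below lists only the `StreamKSchedulerDescisions` members that `flash_api.cpp` reads later in this PR (`num_work_tiles`, `num_combine_blocks`, `max_num_peers`, `pack_gqa`, `use_one_mma_wg`). Everything else, including the layouts of `StreamKWorkTile` and `StreamKCombineTile`, is omitted because it is not visible here.

```cpp
// Sketch only: host-side scheduler decisions as consumed by flash_api.cpp in
// this PR. The identifier spelling ("Descisions") matches the code above.
struct StreamKSchedulerDescisions {
    int  num_work_tiles;      // entries behind work_tiles_ptr
    int  num_combine_blocks;  // entries behind combine_tiles_ptr
    int  max_num_peers;       // worst-case peers per output tile, copied into params.num_splits
    bool pack_gqa;            // decided by the scheduler instead of get_pack_gqa()
    bool use_one_mma_wg;      // decided by the scheduler instead of a runtime heuristic
};
```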
124 changes: 48 additions & 76 deletions hopper/flash_api.cpp
@@ -15,6 +15,7 @@
#include "tile_size.h"
#include "heuristics.h"
#include "cuda_check.h"
#include "streamk.h"

// Copied from https://github.com/pytorch/pytorch/commit/7931eee5c5ebcdf468bff4d308510b03355cd909
// This is so that we can pass in torch.dtype as a parameter to the function.
@@ -407,53 +408,6 @@ inline bool get_pagedkv_tma(Flash_fwd_params const& params) {
return params.page_size % kBlockN == 0 && params.seqlen_q * (params.h / params.h_k) > kBlockM;
}

inline bool get_pack_gqa(Flash_fwd_params const& params) {
// Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation and binary size.
// Has little effect on speed.
if (params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1) { return true; }
#ifdef FLASHATTENTION_DISABLE_PACKGQA
return false;
#else
// params.page_table must already be set
if (params.h == params.h_k) { return false; }
// This needs to match the kernel configs
auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);
int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90);
return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM);
#endif
}

inline int get_num_splits(Flash_fwd_params const& params) {
#ifdef FLASHATTENTION_DISABLE_SPLIT
return 1;
#else
// Always enable PackGQA for Split
// params.page_table must already be set
// This needs to match the kernel configs
bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k;
auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f, use_one_mma_wg(params));
// Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits
// has not been set here. It's OK though because we might just underestimate kBlockN a bit
auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr);
int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);
int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);
int seqlen_q_packgqa = params.seqlen_q * (params.h / params.h_k);
// If is_local, we're not going to load all of seqlen_k
int const seqlen_k_loaded = !params.is_local
? params.seqlen_k
: std::max(0, std::min(params.seqlen_k, params.window_size_right + params.window_size_left + 1 + kBlockM));
int const num_n_blocks = (seqlen_k_loaded + kBlockN - 1) / kBlockN;
int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM;
int const size_one_kv_head = params.seqlen_k * (params.d + params.dv) * (params.is_e4m3 ? 1 : 2);
// Always enable PackGQA for Split
// If varlen, we use dynamic split, so this heuristic just needs to get an upper bound on num_splits.
// We assume the case where there's 1 long sequence and the rest are short, i.e. pretending
// that batch = 1.
int total_mblocks = (params.num_splits_dynamic_ptr ? 1 : params.b) * params.h_k * num_m_blocks;
return num_splits_heuristic(total_mblocks, params.num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, params.is_causal || params.is_local, 128);
#endif
}

inline int get_max_headdim() {
#ifndef FLASHATTENTION_DISABLE_HDIM256
return 256;
@@ -502,7 +456,7 @@ inline int round_up_headdimv(int head_size) {
}

// Only applicable to the case where seqused_k (i.e. cache_seqlens) is available
at::Tensor
std::tuple<at::Tensor, at::Tensor>
mha_fwd_get_scheduler_metadata(
int batch_size,
int max_seqlen_q,
@@ -531,7 +485,8 @@

TORCH_CHECK(qkv_dtype == at::ScalarType::Half || qkv_dtype == at::ScalarType::BFloat16 || qkv_dtype == at::ScalarType::Float8_e4m3fn,
"FlashAttention only supports fp16, bf16, and fp8_e4m3 data type");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
TORCH_CHECK(num_heads % num_heads_k == 0 && num_heads_k > 1,
"Number of heads in key/value must divide number of heads in query");

// Reset the parameters
Flash_fwd_params params{};
@@ -581,15 +536,11 @@ mha_fwd_get_scheduler_metadata(
params.page_size = page_size.has_value() ? page_size.value() : 1;
params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast<int*>(1);

bool const use_dynamic_split = params.b <= 992;
params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast<int*>(1);
bool const use_stream_k = params.b <= 992;
params.num_splits_dynamic_ptr = !use_stream_k ? nullptr : reinterpret_cast<int*>(1);

params.pagedkv_tma = get_pagedkv_tma(params);
// Determine if we should pack GQA before num_splits since it impacts use_one_mma_wg (in get_num_splits)
params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);
params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
// Always enable PackGQA for Split
params.pack_gqa = params.num_splits > 1;
determine_pack_gqa_splits_and_mma_wgs(params, num_splits, pack_gqa_, use_stream_k);

bool is_varlen = true;

@@ -599,17 +550,21 @@

auto opts = seqused_k.options();
// This needs to be set after get_num_splits
at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic
bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1;
if (scheduler_needs_semaphore || use_dynamic_split) {
tile_count_semaphore = torch::empty({int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b}, opts.dtype(torch::kInt32));
if (scheduler_needs_semaphore) {
if (!use_dynamic_split) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing
params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
} else {
params.tile_count_semaphore = nullptr;
}
params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr<int>() + 1 : nullptr;
at::Tensor device_metadata; // Contains the semaphore and optionally num_splits_dynamic
at::Tensor host_metadata;

bool const scheduler_needs_semaphore = (params.arch >= 90 || params.num_splits > 1) && !use_stream_k;

if (scheduler_needs_semaphore) {
device_metadata = torch::empty({1}, opts.dtype(torch::kInt32));
} else {
std::tie(device_metadata, host_metadata) = streamk_schedule(
params.arch, params.num_sm, params.b, cu_seqlens_q_, seqused_k, params.seqlen_q, params.seqlen_k,
params.h, params.h_k, params.d, params.dv, params.is_causal, params.is_local,
params.is_e4m3 ? 1 : 2, false /*v_colmajor*/, true /*pagedkv*/, params.pagedkv_tma,
params.softcap, params.seqlen_knew > 0
);
return {host_metadata, device_metadata};
}

if (params.num_splits_dynamic_ptr) {
@@ -621,7 +576,7 @@ mha_fwd_get_scheduler_metadata(
prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/);
CHECK_CUDA_KERNEL_LAUNCH();
}
return tile_count_semaphore;
return {host_metadata, device_metadata};
}

// b: batch_size
Expand Down Expand Up @@ -664,6 +619,8 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
float const softcap,
bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
std::optional<at::Tensor> &scheduler_metadata_, // (b + 1)
std::optional<const at::Tensor> &device_scheduler_metadata_,
std::optional<const at::Tensor> &host_scheduler_metadata_,
int num_splits,
std::optional<bool> pack_gqa_,
int const sm_margin
@@ -935,16 +892,31 @@
}

// 992 = 32 * 31 is the max supported batch in prepare_varlen_num_blocks kernel
bool const use_dynamic_split = is_varlen && params.b <= 992;
bool const use_stream_k = device_scheduler_metadata_ && host_scheduler_metadata_;
bool const use_dynamic_split = is_varlen && params.b <= 992 && !use_stream_k;
// Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it
params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast<int*>(1);

params.pagedkv_tma = get_pagedkv_tma(params);
// Determine if we should pack GQA before num_splits since it impacts use_one_mma_wg (in get_num_splits)
params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);
params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
// Always enable PackGQA for Split
params.pack_gqa = params.num_splits > 1;
determine_pack_gqa_splits_and_mma_wgs(params, num_splits, pack_gqa_, use_stream_k);

// Assume streamk scheduling
if (use_stream_k) {
auto host_metadata_ptr = reinterpret_cast<StreamKSchedulerDescisions const*>(host_scheduler_metadata_->data_ptr<uint8_t>());
params.host_scheduler_metadata_ptr = host_metadata_ptr;
params.num_splits = host_metadata_ptr->max_num_peers;
params.pack_gqa = host_metadata_ptr->pack_gqa;
params.use_one_mma_wg = host_metadata_ptr->use_one_mma_wg;

int num_work_tiles = host_metadata_ptr->num_work_tiles;
assert(device_scheduler_metadata_.has_value());
auto device_metadata_ptr = device_scheduler_metadata_->data_ptr<uint8_t>();

auto [offsets, total_size] = get_device_metadata_offsets_and_size(
params.num_sm, num_work_tiles, host_metadata_ptr->num_combine_blocks);

params.work_tiles_ptr = reinterpret_cast<StreamKWorkTile*>(device_metadata_ptr + offsets.work_tiles_offset);
params.sm_work_tile_ind_ptr = reinterpret_cast<int*>(device_metadata_ptr + offsets.work_tiles_ind_ptr_offset);
params.combine_tiles_ptr = reinterpret_cast<StreamKCombineTile*>(device_metadata_ptr + offsets.combine_tiles_offset);
}

// This needs to be set after get_num_splits
at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic
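When stream-K metadata is supplied, the work tiles, the per-SM work-tile index array, and the combine tiles all live in a single device tensor and are located via byte offsets returned by `get_device_metadata_offsets_and_size`. Below is a minimal sketch of one plausible packing, assuming back-to-back arrays with 16-byte alignment and one index entry per SM; the real helper in the stream-K sources may order or align the sections differently, and the tile byte sizes are passed in here only because `StreamKWorkTile`/`StreamKCombineTile` are not defined in this diff.

```cpp
// Sketch only: one plausible layout of the single device metadata buffer.
// The three offsets mirror the fields read in mha_fwd above.
#include <cstddef>
#include <tuple>

struct DeviceMetadataOffsets {
    std::size_t work_tiles_offset;          // StreamKWorkTile array
    std::size_t work_tiles_ind_ptr_offset;  // per-SM index into the work-tile array
    std::size_t combine_tiles_offset;       // StreamKCombineTile array for the combine kernel
};

inline std::tuple<DeviceMetadataOffsets, std::size_t>
get_device_metadata_offsets_and_size_sketch(int num_sm, int num_work_tiles,
                                            int num_combine_blocks,
                                            std::size_t work_tile_bytes,
                                            std::size_t combine_tile_bytes) {
    auto align_up = [](std::size_t x, std::size_t a) { return (x + a - 1) / a * a; };
    DeviceMetadataOffsets off{};
    std::size_t cur = 0;
    off.work_tiles_offset = cur;
    cur = align_up(cur + std::size_t(num_work_tiles) * work_tile_bytes, 16);
    off.work_tiles_ind_ptr_offset = cur;
    cur = align_up(cur + std::size_t(num_sm) * sizeof(int), 16);   // assumed one int per SM
    off.combine_tiles_offset = cur;
    cur = align_up(cur + std::size_t(num_combine_blocks) * combine_tile_bytes, 16);
    return {off, cur};  // second element: total bytes to allocate for the buffer
}
```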