diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp
index 90754b45c..c2ee67f89 100644
--- a/csrc/flash_attn/flash_api.cpp
+++ b/csrc/flash_attn/flash_api.cpp
@@ -46,7 +46,8 @@ void set_params_fprop(Flash_fwd_params &params,
                       int window_size_right,
                       const float softcap,
                       bool seqlenq_ngroups_swapped=false,
-                      const bool unpadded_lse=false) {
+                      const bool unpadded_lse=false,
+                      const bool is_kvc=false) {

     // Reset the parameters
     params = {};
@@ -59,11 +60,19 @@ void set_params_fprop(Flash_fwd_params &params,
     params.v_ptr = v.data_ptr();
     // All stride are in elements, not bytes.
     params.q_row_stride = q.stride(-3);
-    params.k_row_stride = k.stride(-3);
-    params.v_row_stride = v.stride(-3);
     params.q_head_stride = q.stride(-2);
-    params.k_head_stride = k.stride(-2);
-    params.v_head_stride = v.stride(-2);
+    params.is_kvc_cache = is_kvc;
+    if (!is_kvc) {
+        params.k_row_stride = k.stride(-3);
+        params.v_row_stride = v.stride(-3);
+        params.k_head_stride = k.stride(-2);
+        params.v_head_stride = v.stride(-2);
+    } else {
+        params.k_row_stride = k.stride(1);
+        params.v_row_stride = v.stride(1);
+        // head stride not used
+    }
+
     params.o_ptr = out.data_ptr();
     params.o_row_stride = out.stride(-3);
     params.o_head_stride = out.stride(-2);
@@ -159,6 +168,7 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split
         HEADDIM_SWITCH(params.d, [&] {
             BOOL_SWITCH(params.is_causal, Is_causal, [&] {
                 if (params.num_splits <= 1 && !force_split_kernel) {  // If we don't set it num_splits == 0
+                    assert(false);
                     run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
                 } else {
                     run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params, stream);
@@ -502,6 +512,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
         TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
         TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
     }
+    const bool is_KVC = paged_KV && (block_table.dim() > 2);
+

     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
@@ -514,11 +526,11 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     const int batch_size = cu_seqlens_q.numel() - 1;
     int num_heads = sizes[1];
     const int head_size_og = sizes[2];
-    const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
+    const int num_heads_k = paged_KV ? (!is_KVC ? k.size(2) : block_table.size(1)) : k.size(1);

     if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }

-    const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
+    const int max_num_blocks_per_seq = !paged_KV ? 0 : (!is_KVC ? block_table.size(1) : block_table.size(2));
     const int num_blocks = !paged_KV ? 0 : k.size(0);
     const int page_block_size = !paged_KV ? 1 : k.size(1);
     TORCH_CHECK(!paged_KV || page_block_size % 16 == 0, "Paged KV cache block size must be divisible by 16");
@@ -554,13 +566,29 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
         CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
         CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
     } else {
-        CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
-        CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
-        CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+        if (!is_KVC) {
+            CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
+            CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
+            CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+        } else {
+            CHECK_SHAPE(k, num_blocks, page_block_size, head_size_og);
+            CHECK_SHAPE(v, num_blocks, page_block_size, head_size_og);
+            // [ batch_size, kv_heads, blocks ]
+            // printf("batch_size=%d, num_heads_k=%d, max_num_blocks_per_seq=%d",
+            //        batch_size, num_heads_k, max_num_blocks_per_seq);
+            // std::cout << "block_tables shape\n" << block_table.sizes() << std::endl;
+            CHECK_SHAPE(block_table, batch_size, num_heads_k, max_num_blocks_per_seq);
+        }
     }

+    bool seqlen_by_head = false;
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
-    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+    if (!is_KVC) {
+        CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+    } else {
+        seqlen_by_head = cu_seqlens_k.size(0) > batch_size + 1;
+        // CHECK_SHAPE(cu_seqlens_k, batch_size + 1 + batch_size * num_heads_k + 1);
+    }
     if (seqused_k.has_value()){
         auto seqused_k_ = seqused_k.value();
         TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
@@ -639,12 +667,22 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                      window_size_right,
                      softcap,
                      seqlenq_ngroups_swapped,
-                     /*unpadded_lse*/true);
+                     /*unpadded_lse*/true,
+                     /*is_kvc*/is_KVC);
     params.total_q = total_q;
+    params.seqlen_by_head = seqlen_by_head;

     if (paged_KV) {
         params.block_table = block_table.data_ptr<int>();
-        params.block_table_batch_stride = block_table.stride(0);
+        if (!is_KVC) {
+            params.block_table_batch_stride = block_table.stride(0);
+        } else {
+            params.kseqlen_batch_stride = num_heads_k;
+            params.block_table_batch_stride = block_table.stride(0);
+            params.block_table_head_stride = block_table.stride(1);
+        }
+        // std::cout << "\n" << k_padded.strides() << std::endl;
+        // std::cout << k_padded.sizes() << std::endl;
         params.k_batch_stride = k_padded.stride(0);
         params.v_batch_stride = v_padded.stride(0);
     }
@@ -759,6 +797,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
         TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
     }
+    const bool is_KVC = paged_KV && (block_table.dim() > 2);

     const auto sizes = q.sizes();

@@ -769,12 +808,12 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     const int num_heads_og = num_heads;
     const int head_size_og = sizes[3];

-    const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
+    const int max_num_blocks_per_seq = !paged_KV ? 0 : (!is_KVC ? block_table.size(1) : block_table.size(2));
     const int num_blocks = !paged_KV ? 0 : kcache.size(0);
     const int page_block_size = !paged_KV ? 1 : kcache.size(1);
     TORCH_CHECK(!paged_KV || page_block_size % 16 == 0, "Paged KV cache block size must be divisible by 16");
     const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
-    const int num_heads_k = kcache.size(2);
+    const int num_heads_k = !is_KVC ? kcache.size(2) : block_table.size(1);
     const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
     TORCH_CHECK(batch_size > 0, "batch size must be postive");
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
@@ -802,9 +841,16 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
         CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
     } else {
-        CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
-        CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
-        CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+        if (!is_KVC) {
+            CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
+            CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
+            CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+        } else {
+            CHECK_SHAPE(kcache, num_blocks, page_block_size, head_size_og);
+            CHECK_SHAPE(vcache, num_blocks, page_block_size, head_size_og);
+            // [ batch_size, kv_heads, blocks ]
+            CHECK_SHAPE(block_table, batch_size, num_heads_k, max_num_blocks_per_seq);
+        }
     }

     at::Tensor q_padded, kcache_padded, vcache_padded;
@@ -865,8 +911,10 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                      softmax_scale,
                      window_size_left,
                      window_size_right,
-                     softcap
-                     );
+                     softcap,
+                     /*seqlenq_ngroups_swapped=*/false,
+                     /*unpadded_lse=*/false,
+                     /*is_kvc=*/is_KVC);

     at::Tensor k, v, k_padded, v_padded;
     if (k_.has_value()) {
@@ -907,8 +955,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
         CHECK_DEVICE(seqlens_k);
         CHECK_CONTIGUOUS(seqlens_k);
-        CHECK_SHAPE(seqlens_k, batch_size);
+        if (!is_KVC) {
+            CHECK_SHAPE(seqlens_k, batch_size);
+        } else {
+            CHECK_SHAPE(seqlens_k, batch_size * num_heads_k);
+        }
         params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
+        params.seqlen_by_head = is_KVC;
     }
     params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());

@@ -954,7 +1007,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he

     if (paged_KV) {
         params.block_table = block_table.data_ptr<int>();
-        params.block_table_batch_stride = block_table.stride(0);
+        if (!is_KVC) {
+            params.block_table_batch_stride = block_table.stride(0);
+        } else {
+            params.kseqlen_batch_stride = num_heads_k;
+            params.block_table_batch_stride = block_table.stride(0);
+            params.block_table_head_stride = block_table.stride(1);
+        }
     }
     params.page_block_size = page_block_size;

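The host-side changes above amount to a shape contract for the head-separated ("KVC") paged layout. The sketch below restates that contract in PyTorch for readability; it mirrors the CHECK_SHAPE/TORCH_CHECK calls in the hunks above, and the tensor names and example sizes are illustrative assumptions, not code from this patch.

# Illustrative shape contract for the is_KVC path (assumed names and sizes).
import torch

num_blocks, page_block_size, head_size = 128, 16, 64      # page_block_size must be a multiple of 16
batch_size, num_kv_heads, max_num_blocks_per_seq = 2, 4, 8

# Head-separated cache: each block holds a single KV head, so there is no head dimension.
k_cache = torch.empty(num_blocks, page_block_size, head_size, dtype=torch.float16)
v_cache = torch.empty_like(k_cache)

# One block-table row per (sequence, KV head); int32 with a contiguous last dimension.
block_table = torch.zeros(batch_size, num_kv_heads, max_num_blocks_per_seq, dtype=torch.int32)

# Per-(sequence, KV head) K lengths, flattened to batch_size * num_kv_heads entries.
seqlens_k = torch.zeros(batch_size * num_kv_heads, dtype=torch.int32)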
diff --git a/csrc/flash_attn/src/block_info.h b/csrc/flash_attn/src/block_info.h
index 3a23a1e1f..99c1b8115 100644
--- a/csrc/flash_attn/src/block_info.h
+++ b/csrc/flash_attn/src/block_info.h
@@ -12,13 +12,26 @@ template<bool Varlen=true>
 struct BlockInfo {

     template<typename Params>
-    __device__ BlockInfo(const Params &params, const int bidb)
+    __device__ BlockInfo(const Params &params, const int bidb, const int bidkh)
         : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
-        , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
+        , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : (
+              ((bidkh < 0) || !params.seqlen_by_head) ? (params.cu_seqlens_k[bidb]) :
+              (params.cu_seqlens_k[bidb * params.kseqlen_batch_stride + bidkh])
+          ))
         , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
         // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
         // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
-        , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
+        , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (
+              params.is_seqlens_k_cumulative ?
+              (
+                  ((bidkh < 0) || !params.seqlen_by_head) ? (params.cu_seqlens_k[bidb + 1] - sum_s_k) :
+                  (params.cu_seqlens_k[bidb * params.kseqlen_batch_stride + bidkh + 1] - sum_s_k)
+              ) :
+              (
+                  ((bidkh < 0) || !params.seqlen_by_head) ? params.cu_seqlens_k[bidb] :
+                  params.cu_seqlens_k[bidb * params.kseqlen_batch_stride + bidkh]
+              )
+          ))
         , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
         {
         }
diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h
index 8e2352b66..f7299268c 100644
--- a/csrc/flash_attn/src/flash.h
+++ b/csrc/flash_attn/src/flash.h
@@ -103,7 +103,11 @@ struct Flash_fwd_params : public Qkv_params {

     // Paged KV cache
     int * __restrict__ block_table;
+    bool is_kvc_cache;
+    bool seqlen_by_head;
+    index_t kseqlen_batch_stride;
     index_t block_table_batch_stride;
+    index_t block_table_head_stride;
     int page_block_size;

     // The dropout probability (probability of keeping an activation).
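For orientation, this is the cu_seqlens_k indexing that the new BlockInfo constructor performs when seqlen_by_head is set. The Python sketch below assumes kseqlen_batch_stride == num_heads_k, as set in flash_api.cpp; it is illustrative, not code from the patch.

# Illustrative only: per-(batch, KV head) K-length lookup mirroring the BlockInfo logic above.
def lookup_seqlen_k(cu_seqlens_k, bidb, bidkh, num_heads_k, is_cumulative, seqlen_by_head):
    if bidkh < 0 or not seqlen_by_head:
        idx = bidb                        # classic layout: one entry per sequence
    else:
        idx = bidb * num_heads_k + bidkh  # KVC layout: one entry per (sequence, KV head)
    if is_cumulative:
        return cu_seqlens_k[idx + 1] - cu_seqlens_k[idx]  # lengths stored as a prefix sum
    return cu_seqlens_k[idx]                              # lengths stored directly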
diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h
index 73ad023b9..b02960343 100644
--- a/csrc/flash_attn/src/flash_fwd_kernel.h
+++ b/csrc/flash_attn/src/flash_fwd_kernel.h
@@ -82,7 +82,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         params.rng_state[1] = std::get<1>(seed_offset);
     }

-    const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
+    const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb, -1);
     if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

     const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
@@ -526,7 +526,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     >;
     using ElementO = std::conditional_t<!Split, Element, ElementAccum>;

-    const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
+    const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb, bidh / params.h_h_k_ratio);
+    // const BlockInfo binfo1(params, bidb, 0);
+
     // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); }
     // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); }
     if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
@@ -587,18 +589,21 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // We move K and V to the last block.
     const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
-    const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride;
+    const int *block_table = params.block_table == nullptr ? nullptr : (!params.is_kvc_cache ?
+        params.block_table + bidb * params.block_table_batch_stride :
+        params.block_table + bidb * params.block_table_batch_stride +
+            (bidh / params.h_h_k_ratio) * params.block_table_head_stride);
     const index_t row_offset_k = block_table == nullptr
         ? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache)
           + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride
-        : (bidh / params.h_h_k_ratio) * params.k_head_stride; // block addresses are later resolved per-thread
+        : (!params.is_kvc_cache ? (bidh / params.h_h_k_ratio) * params.k_head_stride : 0); // block addresses are later resolved per-thread
     const index_t row_offset_v = block_table == nullptr
         ? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
          + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
-        : (bidh / params.h_h_k_ratio) * params.v_head_stride;
+        : (!params.is_kvc_cache ? (bidh / params.h_h_k_ratio) * params.v_head_stride : 0);
+
-
     Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
                             make_shape(binfo.actual_seqlen_q, params.h, params.d),
@@ -638,10 +643,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     Tensor tVsV = make_tensor(tVsV_.data(), reshape_thread_tile(tVsV_.layout()));

     if (block_table != nullptr) {
-        tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset(tidx, n_block_max, params.page_block_size,
+        const int64_t tKgK_offset = flash::resolve_thread_kv_page_slice_offset(
+            tidx, n_block_max, params.page_block_size,
             block_table, params.k_batch_stride, params.k_row_stride);
-        tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset(tidx, n_block_max, params.page_block_size,
+        const int64_t tVgV_offset = flash::resolve_thread_kv_page_slice_offset(
+            tidx, n_block_max, params.page_block_size,
             block_table, params.v_batch_stride, params.v_row_stride);
+        tKgK.data() = gK.data() + tKgK_offset;
+        tVgV.data() = gV.data() + tVgV_offset;
     }

     typename Kernel_traits::TiledMma tiled_mma;
@@ -704,7 +713,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
     typename Kernel_traits::GmemTiledCopyRotcossinContPaged gmem_tiled_copy_rotary_cont;
     auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
-    
+
     // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
     // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
     // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
@@ -721,7 +730,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
                                       Shape, Int>{}, make_stride(params.rotary_dim / 2, _1{}));
-        
+
         Tensor tRgCos_ = gmem_thr_copy_rotary.partition_S(gCos);
         Tensor tRgSin_ = gmem_thr_copy_rotary.partition_S(gSin);
         Tensor tRgCosCont_ = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
@@ -798,9 +807,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
                 tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
             } else {
                 if (n_block > n_block_copy_min) {
-                    tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, 
+                    tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size,
                         block_table, params.v_batch_stride, params.v_row_stride);
-                    tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, 
+                    tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size,
                         block_table, params.k_batch_stride, params.k_row_stride);
                 }
             }
@@ -930,7 +939,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         if (block_table == nullptr) {
             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
         } else {
-            tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, 
+            tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size,
                 block_table, params.k_batch_stride, params.k_row_stride);
         }
         flash::copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV);
@@ -970,7 +979,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         if (block_table == nullptr) {
             tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
         } else {
-            tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset(tidx, n_block + 1, params.page_block_size, 
+            tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset(tidx, n_block + 1, params.page_block_size,
                 block_table, params.v_batch_stride, params.v_row_stride);
         }
         flash::copy(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV);
@@ -991,8 +1000,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             if (block_table == nullptr) {
                 tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
             } else {
-                tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, 
-                    block_table, params.k_batch_stride, params.k_row_stride);
+                tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size,
+                                                                                     block_table, params.k_batch_stride, params.k_row_stride);
             }
             flash::copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV);
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
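The kernel-side changes above reduce to selecting the (sequence, KV head) row of the block table before the usual per-thread page resolution. A rough Python equivalent of the resulting addressing follows; it is an interpretation of the code above with assumed names, not part of the patch.

# Illustrative only: fetch the K vector for token position t of (sequence b, KV head h)
# from a head-separated paged cache, following the block_table/bidh indexing added above.
def load_k(k_cache, block_table, b, h, t, page_block_size):
    # block_table: [batch, num_kv_heads, max_blocks]; k_cache: [num_blocks, block_size, head_size]
    physical_block = block_table[b, h, t // page_block_size]
    return k_cache[physical_block, t % page_block_size]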
diff --git a/tests/test_vllm_flash_attn.py b/tests/test_vllm_flash_attn.py
index 758c77b68..6cc149621 100644
--- a/tests/test_vllm_flash_attn.py
+++ b/tests/test_vllm_flash_attn.py
@@ -7,6 +7,7 @@
 import pytest
 import torch
+import itertools

 import flash_attn_wrapper  # noqa: F401

@@ -19,6 +20,71 @@
 NUM_BLOCKS = [32768, 2048]


+def ref_paged_attn_kvc(
+    query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    query_lens: List[int],
+    kv_lens: List[int],
+    block_tables: torch.Tensor,
+    scale: float,
+    sliding_window: Optional[int] = None,
+    soft_cap: Optional[float] = None,
+) -> torch.Tensor:
+    num_seqs = len(query_lens)
+    block_tables = block_tables.cpu().numpy()
+    num_kv_heads = block_tables.shape[1]
+    num_heads = query.shape[1]
+    _, block_size, head_size = key_cache.shape
+
+    outputs: List[torch.Tensor] = []
+    start_idx = 0
+    for i in range(num_seqs):
+        query_len = query_lens[i]
+        q = query[start_idx:start_idx + query_len]
+        q *= scale
+
+        outs = []
+        for j in range(num_kv_heads):
+            queries_per_key = num_heads // num_kv_heads
+            q_ = q[:,j*queries_per_key:(j+1)*queries_per_key]
+
+            kv_len = kv_lens[i * num_kv_heads + j]
+            num_kv_blocks = (kv_len + block_size - 1) // block_size
+            block_indices = block_tables[i, j, :num_kv_blocks]
+
+            k = key_cache[block_indices]
+            k = k.view(-1, 1, head_size)
+            k = k[:kv_len]
+            v = value_cache[block_indices]
+            v = v.view(-1, 1, head_size)
+            v = v[:kv_len]
+
+            k = torch.repeat_interleave(k, queries_per_key, dim=1)
+            v = torch.repeat_interleave(v, queries_per_key, dim=1)
+            attn = torch.einsum("qhd,khd->hqk", q_, k).float()
+            empty_mask = torch.ones(query_len, kv_len)
+            mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
+            if sliding_window is not None:
+                sliding_window_mask = torch.triu(empty_mask,
+                                                 diagonal=kv_len -
+                                                 (query_len + sliding_window) +
+                                                 1).bool().logical_not()
+                mask |= sliding_window_mask
+            if soft_cap is not None:
+                attn = soft_cap * torch.tanh(attn / soft_cap)
+            attn.masked_fill_(mask[None], float("-inf"))
+            attn = torch.softmax(attn, dim=-1).to(v.dtype)
+            out = torch.einsum("hqk,khd->qhd", attn, v)
+            outs.append(out)
+
+        outputs.append(torch.cat(outs, dim=1))
+
+        start_idx += query_len
+
+    return torch.cat(outputs, dim=0)
+
+
 def ref_paged_attn(
     query: torch.Tensor,
     key_cache: torch.Tensor,
@@ -43,11 +109,13 @@ def ref_paged_attn(
         q *= scale

         num_kv_blocks = (kv_len + block_size - 1) // block_size
-        block_indices = block_tables[i, :num_kv_blocks]
-        k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
+        block_indices = block_tables[i, :num_kv_blocks]
+        k = key_cache[block_indices]
+        k = k.view(-1, num_kv_heads, head_size)
         k = k[:kv_len]
-        v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
+        v = value_cache[block_indices]
+        v = v.view(-1, num_kv_heads, head_size)
         v = v[:kv_len]

         if q.shape[1] != k.shape[1]:
@@ -81,6 +149,7 @@ def ref_paged_attn(
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("kvc", [True, False])
 @torch.inference_mode()
 def test_flash_attn_with_paged_kv(
     kv_lens: List[int],
@@ -90,6 +159,7 @@ def test_flash_attn_with_paged_kv(
     block_size: int,
     soft_cap: Optional[float],
     num_blocks: int,
+    kvc: bool,
 ) -> None:
     torch.set_default_device("cuda")
     torch.cuda.manual_seed_all(0)
@@ -107,6 +177,16 @@ def test_flash_attn_with_paged_kv(
                             head_size,
                             dtype=dtype)
     value_cache = torch.randn_like(key_cache)
+    if kvc:
+        key_cache = (key_cache.transpose(1, 2)
+                     .transpose(0, 1)
+                     .reshape(-1, block_size, head_size)).contiguous()
+        value_cache = (value_cache.transpose(1, 2)
+                       .transpose(0, 1)
+                       .reshape(-1, block_size, head_size)).contiguous()
+        kv_lens = list(itertools.chain.from_iterable(
+            [l // 2] + [l] * (num_kv_heads - 2) + [l // 3]
+            for l in kv_lens))
     kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
     max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
@@ -114,6 +194,10 @@ def test_flash_attn_with_paged_kv(
                                  num_blocks,
                                  (num_seqs, max_num_blocks_per_seq),
                                  dtype=torch.int32)
+    if kvc:
+        block_tables = (block_tables[:,None]
+                        + torch.arange(num_kv_heads).type(torch.int)[None,:,None]
+                        * num_blocks)

     output = torch.ops.vllm.flash_attn_with_kvcache(
         decode_query=query.unsqueeze(1),
@@ -145,7 +229,8 @@ def test_flash_attn_with_paged_kv(
         ),
         test_utils=test_utils)

-    ref_output = ref_paged_attn(
+    ref_func = ref_paged_attn_kvc if kvc else ref_paged_attn
+    ref_output = ref_func(
         query=query,
         key_cache=key_cache,
         value_cache=value_cache,
@@ -167,6 +252,7 @@ def test_flash_attn_with_paged_kv(
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("kvc", [True, False])
 @torch.inference_mode()
 def test_varlen_with_paged_kv(
     seq_lens: List[Tuple[int, int]],
@@ -177,14 +263,19 @@ def test_varlen_with_paged_kv(
     block_size: int,
     soft_cap: Optional[float],
     num_blocks: int,
+    kvc: bool,
 ) -> None:
     torch.set_default_device("cuda")
     torch.cuda.manual_seed_all(0)
+    num_query_heads = num_heads[0]
+    num_kv_heads = num_heads[1]
     num_seqs = len(seq_lens)
     query_lens = [x[0] for x in seq_lens]
     kv_lens = [x[1] for x in seq_lens]
-    num_query_heads = num_heads[0]
-    num_kv_heads = num_heads[1]
+    if kvc:
+        kv_lens = list(itertools.chain.from_iterable(
+            [l // 2] + [l] * (num_kv_heads - 2) + [l // 3]
+            for l in kv_lens))
     assert num_query_heads % num_kv_heads == 0
     max_query_len = max(query_lens)
     max_kv_len = max(kv_lens)
@@ -203,18 +294,30 @@ def test_varlen_with_paged_kv(
                             head_size,
                             dtype=dtype)
     value_cache = torch.randn_like(key_cache)
+    if kvc:
+        key_cache = (key_cache.transpose(1, 2)
+                     .transpose(0, 1)
+                     .reshape(-1, block_size, head_size)).contiguous()
+        value_cache = (value_cache.transpose(1, 2)
+                       .transpose(0, 1)
+                       .reshape(-1, block_size, head_size)).contiguous()
     cu_query_lens = torch.tensor([0] + query_lens,
                                  dtype=torch.int32).cumsum(dim=0,
                                                            dtype=torch.int32)
-    cu_kv_lens = torch.tensor([0] + kv_lens,
+    kv_lens_ = kv_lens
+    cu_kv_lens = torch.tensor([0] + kv_lens_,
                               dtype=torch.int32).cumsum(dim=0,
                                                         dtype=torch.int32)

     max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
     block_tables = torch.randint(0,
-                                     num_blocks,
-                                     (num_seqs, max_num_blocks_per_seq),
-                                     dtype=torch.int32)
+                                 num_blocks,
+                                 (num_seqs, max_num_blocks_per_seq),
+                                 dtype=torch.int32)
+    if kvc:
+        block_tables = (block_tables[:,None]
+                        + torch.arange(num_kv_heads).type(torch.int)[None,:,None]
+                        * num_blocks)

     output = torch.ops.vllm.flash_attn_varlen_func(
         q=query,
@@ -254,16 +357,18 @@ def test_varlen_with_paged_kv(
         ),
         test_utils=test_utils)

-    ref_output = ref_paged_attn(
+    ref_func = ref_paged_attn_kvc if kvc else ref_paged_attn
+    ref_output = ref_func(
         query=query,
         key_cache=key_cache,
         value_cache=value_cache,
         query_lens=query_lens,
-        kv_lens=kv_lens,
+        kv_lens=kv_lens_,
         block_tables=block_tables,
         scale=scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
     )
+
     torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
         f"{torch.max(torch.abs(output - ref_output))}"
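Taken together, the test changes show how a standard [num_blocks, block_size, num_kv_heads, head_size] cache and a [num_seqs, max_blocks] block table are mapped onto the KVC layout. The helper below condenses that conversion for reference; it mirrors the test code above but is not an API exported by this patch.

# Mirrors the conversion performed in the tests above; illustrative, not a library function.
import torch

def to_kvc_layout(key_cache, value_cache, block_tables, num_kv_heads):
    num_blocks, block_size, _, head_size = key_cache.shape
    # [num_blocks, block_size, H, D] -> [H * num_blocks, block_size, D]:
    # head h owns physical blocks [h * num_blocks, (h + 1) * num_blocks).
    k = key_cache.transpose(1, 2).transpose(0, 1).reshape(-1, block_size, head_size).contiguous()
    v = value_cache.transpose(1, 2).transpose(0, 1).reshape(-1, block_size, head_size).contiguous()
    # [num_seqs, max_blocks] -> [num_seqs, H, max_blocks], offsetting head h into its own slab.
    offsets = torch.arange(num_kv_heads, dtype=torch.int32, device=block_tables.device)
    bt = block_tables[:, None] + offsets[None, :, None] * num_blocks
    return k, v, bt

With this layout each KV head can also carry its own sequence length, which is why the tests expand kv_lens to batch_size * num_kv_heads entries before building cu_kv_lens.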