
Commit 3aa31f1

support KV-Compress paged KV cache
1 parent 5259c58 commit 3aa31f1

File tree

5 files changed (+243, -53 lines)


csrc/flash_attn/flash_api.cpp

Lines changed: 81 additions & 22 deletions
@@ -46,7 +46,8 @@ void set_params_fprop(Flash_fwd_params &params,
                       int window_size_right,
                       const float softcap,
                       bool seqlenq_ngroups_swapped=false,
-                      const bool unpadded_lse=false) {
+                      const bool unpadded_lse=false,
+                      const bool is_kvc=false) {

     // Reset the parameters
     params = {};
@@ -59,11 +60,19 @@ void set_params_fprop(Flash_fwd_params &params,
     params.v_ptr = v.data_ptr();
     // All stride are in elements, not bytes.
     params.q_row_stride = q.stride(-3);
-    params.k_row_stride = k.stride(-3);
-    params.v_row_stride = v.stride(-3);
     params.q_head_stride = q.stride(-2);
-    params.k_head_stride = k.stride(-2);
-    params.v_head_stride = v.stride(-2);
+    params.is_kvc_cache = is_kvc;
+    if (!is_kvc) {
+        params.k_row_stride = k.stride(-3);
+        params.v_row_stride = v.stride(-3);
+        params.k_head_stride = k.stride(-2);
+        params.v_head_stride = v.stride(-2);
+    } else {
+        params.k_row_stride = k.stride(1);
+        params.v_row_stride = v.stride(1);
+        // head stride not used
+    }
+
     params.o_ptr = out.data_ptr();
     params.o_row_stride = out.stride(-3);
     params.o_head_stride = out.stride(-2);
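For orientation, a minimal sketch of what the two stride conventions above imply; the helper and its names are mine, not part of this commit. In the standard paged layout K/V are [num_blocks, page_block_size, num_heads_k, head_size]; the KV-Compress layout drops the head dimension, since each physical block already belongs to a single (sequence, kv head) pair, so only stride(0) and stride(1) matter.

#include <cstdint>

// Sketch only: element offset of token `row` of kv head `h` inside physical block `blk`.
// Strides are in elements, mirroring the assignments in set_params_fprop above.
struct KStrides {
    int64_t block_stride;  // k.stride(0)
    int64_t row_stride;    // k.stride(-3) standard, k.stride(1) for KV-Compress
    int64_t head_stride;   // k.stride(-2); unused when is_kvc is true
};

int64_t k_token_offset(const KStrides &s, bool is_kvc, int64_t blk, int64_t row, int64_t h) {
    // KV-Compress blocks hold a single head, so no head offset is applied.
    return is_kvc ? blk * s.block_stride + row * s.row_stride
                  : blk * s.block_stride + row * s.row_stride + h * s.head_stride;
}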
@@ -159,6 +168,7 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split
     HEADDIM_SWITCH(params.d, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
+                assert(false);
                 run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
             } else {
                 run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params, stream);
@@ -502,6 +512,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
         TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
         TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
     }
+    const bool is_KVC = paged_KV && (block_table.dim() > 2);
+

     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
@@ -514,11 +526,11 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     const int batch_size = cu_seqlens_q.numel() - 1;
     int num_heads = sizes[1];
     const int head_size_og = sizes[2];
-    const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
+    const int num_heads_k = paged_KV ? (!is_KVC ? k.size(2): block_table.size(1)) : k.size(1);

     if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }

-    const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
+    const int max_num_blocks_per_seq = !paged_KV ? 0 : (!is_KVC ? block_table.size(1) : block_table.size(2));
     const int num_blocks = !paged_KV ? 0 : k.size(0);
     const int page_block_size = !paged_KV ? 1 : k.size(1);
     TORCH_CHECK(!paged_KV || page_block_size % 16 == 0, "Paged KV cache block size must be divisible by 16");
@@ -554,13 +566,29 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
         CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
         CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
     } else {
-        CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
-        CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
-        CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+        if (!is_KVC) {
+            CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
+            CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
+            CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+        } else {
+            CHECK_SHAPE(k, num_blocks, page_block_size, head_size_og);
+            CHECK_SHAPE(v, num_blocks, page_block_size, head_size_og);
+            // [ batch_size, kv_heads, blocks ]
+            // printf("batch_size=%d, num_heads_k=%d, max_num_blocks_per_seq=%d",
+            //        batch_size, num_heads_k, max_num_blocks_per_seq);
+            // std::cout << "block_tables shape\n" << block_table.sizes() << std::endl;
+            CHECK_SHAPE(block_table, batch_size, num_heads_k, max_num_blocks_per_seq);
+        }
     }

+    bool seqlen_by_head = false;
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
-    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+    if (!is_KVC) {
+        CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+    } else {
+        seqlen_by_head = cu_seqlens_k.size(0) > batch_size + 1;
+        // CHECK_SHAPE(cu_seqlens_k, batch_size + 1 + batch_size * num_heads_k + 1);
+    }
     if (seqused_k.has_value()){
         auto seqused_k_ = seqused_k.value();
         TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
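To make the is_KVC shapes above concrete, an illustrative construction of matching tensors follows (sizes and names are placeholders of mine; only the shapes come from the CHECK_SHAPE calls). Note that for cu_seqlens_k this path only tests size(0) > batch_size + 1 to set seqlen_by_head; the exact per-head layout is resolved on the kernel side.

#include <torch/torch.h>
#include <tuple>

// Sketch only: tensors matching the is_KVC shape checks above (sizes are placeholders).
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> make_kvc_inputs() {
    const int64_t num_blocks = 1024, page_block_size = 16, head_size_og = 128;
    const int64_t batch_size = 4, num_heads_k = 8, max_num_blocks_per_seq = 32;

    auto cache_opts = torch::dtype(torch::kBFloat16).device(torch::kCUDA);
    // KV-Compress cache has no head dimension: a block holds tokens of one (sequence, head).
    auto k = torch::zeros({num_blocks, page_block_size, head_size_og}, cache_opts);
    auto v = torch::zeros({num_blocks, page_block_size, head_size_og}, cache_opts);
    // One row of physical block indices per (batch, kv head) pair.
    auto block_table = torch::zeros({batch_size, num_heads_k, max_num_blocks_per_seq},
                                    torch::dtype(torch::kInt32).device(torch::kCUDA));
    return {k, v, block_table};
}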
@@ -639,12 +667,22 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                      window_size_right,
                      softcap,
                      seqlenq_ngroups_swapped,
-                     /*unpadded_lse*/true);
+                     /*unpadded_lse*/true,
+                     /*is_kvc*/is_KVC);
     params.total_q = total_q;
+    params.seqlen_by_head = seqlen_by_head;

     if (paged_KV) {
         params.block_table = block_table.data_ptr<int>();
-        params.block_table_batch_stride = block_table.stride(0);
+        if (!is_KVC) {
+            params.block_table_batch_stride = block_table.stride(0);
+        } else {
+            params.kseqlen_batch_stride = num_heads_k;
+            params.block_table_batch_stride = block_table.stride(0);
+            params.block_table_head_stride = block_table.stride(1);
+        }
+        // std::cout << "\n" << k_padded.strides() << std::endl;
+        // std::cout << k_padded.sizes() << std::endl;
         params.k_batch_stride = k_padded.stride(0);
         params.v_batch_stride = v_padded.stride(0);
     }
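A sketch of the lookup these strides enable, not the kernel code itself: with a 3-D block table, each (batch, kv head) pair has its own row of physical block indices, addressed through the batch and head strides set above.

#include <cstdint>

// Sketch only: resolve logical block `n` of kv head `h` in batch `b` to a physical block.
int physical_block(const int *block_table, bool is_kvc,
                   int64_t batch_stride, int64_t head_stride,
                   int b, int h, int n) {
    return is_kvc ? block_table[b * batch_stride + h * head_stride + n]
                  : block_table[b * batch_stride + n];
}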
@@ -759,6 +797,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
         TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
     }
+    const bool is_KVC = paged_KV && (block_table.dim() > 2);

     const auto sizes = q.sizes();

@@ -769,12 +808,12 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     const int num_heads_og = num_heads;
     const int head_size_og = sizes[3];

-    const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
+    const int max_num_blocks_per_seq = !paged_KV ? 0 : (!is_KVC ? block_table.size(1) : block_table.size(2));
     const int num_blocks = !paged_KV ? 0 : kcache.size(0);
     const int page_block_size = !paged_KV ? 1 : kcache.size(1);
     TORCH_CHECK(!paged_KV || page_block_size % 16 == 0, "Paged KV cache block size must be divisible by 16");
     const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
-    const int num_heads_k = kcache.size(2);
+    const int num_heads_k = !is_KVC ? kcache.size(2) : block_table.size(1);
     const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
     TORCH_CHECK(batch_size > 0, "batch size must be postive");
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
@@ -802,9 +841,16 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
         CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
     } else {
-        CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
-        CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
-        CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+        if (!is_KVC) {
+            CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
+            CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
+            CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+        } else {
+            CHECK_SHAPE(kcache, num_blocks, page_block_size, head_size_og);
+            CHECK_SHAPE(vcache, num_blocks, page_block_size, head_size_og);
+            // [ batch_size, kv_heads, blocks ]
+            CHECK_SHAPE(block_table, batch_size, num_heads_k, max_num_blocks_per_seq);
+        }
     }

     at::Tensor q_padded, kcache_padded, vcache_padded;
@@ -865,8 +911,10 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                      softmax_scale,
                      window_size_left,
                      window_size_right,
-                     softcap
-                     );
+                     softcap,
+                     /*seqlenq_ngroups_swapped=*/false,
+                     /*unpadded_lse=*/false,
+                     /*is_kvc=*/is_KVC);

     at::Tensor k, v, k_padded, v_padded;
     if (k_.has_value()) {
@@ -907,8 +955,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
         CHECK_DEVICE(seqlens_k);
         CHECK_CONTIGUOUS(seqlens_k);
-        CHECK_SHAPE(seqlens_k, batch_size);
+        if (!is_KVC) {
+            CHECK_SHAPE(seqlens_k, batch_size);
+        } else {
+            CHECK_SHAPE(seqlens_k, batch_size * num_heads_k);
+        }
         params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
+        params.seqlen_by_head = is_KVC;
     }
     params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());
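As an illustration of the batch_size * num_heads_k layout checked above, a caller might flatten per-head K lengths as below (helper name and types are mine); the b * num_heads_k + h ordering matches the kseqlen_batch_stride = num_heads_k used for the paged-KV case.

#include <torch/torch.h>
#include <vector>

// Sketch only: flatten per-(batch, kv head) K lengths into the batch_size * num_heads_k
// layout checked above.
torch::Tensor flatten_seqlens_k(const std::vector<std::vector<int>> &lens) {
    const int64_t batch_size = lens.size();
    const int64_t num_heads_k = lens.empty() ? 0 : lens[0].size();
    auto out = torch::empty({batch_size * num_heads_k}, torch::dtype(torch::kInt32));
    auto acc = out.accessor<int, 1>();
    for (int64_t b = 0; b < batch_size; ++b) {
        for (int64_t h = 0; h < num_heads_k; ++h) {
            acc[b * num_heads_k + h] = lens[b][h];
        }
    }
    return out;  // move to the CUDA device before passing it as seqlens_k
}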

@@ -954,7 +1007,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he

     if (paged_KV) {
         params.block_table = block_table.data_ptr<int>();
-        params.block_table_batch_stride = block_table.stride(0);
+        if (!is_KVC) {
+            params.block_table_batch_stride = block_table.stride(0);
+        } else {
+            params.kseqlen_batch_stride = num_heads_k;
+            params.block_table_batch_stride = block_table.stride(0);
+            params.block_table_head_stride = block_table.stride(1);
+        }
     }
     params.page_block_size = page_block_size;

csrc/flash_attn/src/block_info.h

Lines changed: 16 additions & 3 deletions
@@ -12,13 +12,26 @@ template<bool Varlen=true>
 struct BlockInfo {

     template<typename Params>
-    __device__ BlockInfo(const Params &params, const int bidb)
+    __device__ BlockInfo(const Params &params, const int bidb, const int bidkh)
         : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
-        , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
+        , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : (
+            ((bidkh < 0) || !params.seqlen_by_head) ? (params.cu_seqlens_k[bidb]) :
+            (params.cu_seqlens_k[bidb * params.kseqlen_batch_stride + bidkh])
+          ))
         , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
         // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
         // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
-        , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
+        , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (
+            params.is_seqlens_k_cumulative ?
+            (
+                ((bidkh < 0) || !params.seqlen_by_head) ? (params.cu_seqlens_k[bidb + 1] - sum_s_k) :
+                (params.cu_seqlens_k[bidb * params.kseqlen_batch_stride + bidkh + 1] - sum_s_k)
+            ) :
+            (
+                ((bidkh < 0) || !params.seqlen_by_head) ? params.cu_seqlens_k[bidb] :
+                params.cu_seqlens_k[bidb * params.kseqlen_batch_stride + bidkh]
+            )
+          ))
         , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
     {
     }
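The nested ternaries above are easier to follow as a standalone helper. This is a restatement sketch, not code from the commit, of how seqlen_k_cache is chosen once the fields added in flash.h below are populated (the params.seqlen_k fallback for a null cu_seqlens_k is omitted):

// Sketch only: equivalent selection of seqlen_k_cache for batch index `bidb`
// and kv-head index `bidkh` (negative bidkh means the original, non-per-head behaviour).
template<typename Params>
__device__ inline int kvc_seqlen_k_cache(const Params &params, int bidb, int bidkh) {
    const bool by_head = (bidkh >= 0) && params.seqlen_by_head;
    const int idx = by_head ? bidb * params.kseqlen_batch_stride + bidkh : bidb;
    if (params.is_seqlens_k_cumulative) {
        // Cumulative layout: length is the difference of adjacent prefix sums.
        return params.cu_seqlens_k[idx + 1] - params.cu_seqlens_k[idx];
    }
    // Non-cumulative layout (e.g. the kv-cache path): lengths are stored directly.
    return params.cu_seqlens_k[idx];
}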

csrc/flash_attn/src/flash.h

Lines changed: 4 additions & 0 deletions
@@ -103,7 +103,11 @@ struct Flash_fwd_params : public Qkv_params {

     // Paged KV cache
     int * __restrict__ block_table;
+    bool is_kvc_cache;
+    bool seqlen_by_head;
+    index_t kseqlen_batch_stride;
     index_t block_table_batch_stride;
+    index_t block_table_head_stride;
     int page_block_size;

     // The dropout probability (probability of keeping an activation).
