
Commit 28f6637

support KV-Compress paged KV cache
Signed-off-by: Isaac Rehg <[email protected]>
1 parent: 5259c58

5 files changed: +242 -53 lines changed

csrc/flash_attn/flash_api.cpp

Lines changed: 80 additions & 22 deletions
@@ -46,7 +46,8 @@ void set_params_fprop(Flash_fwd_params &params,
                       int window_size_right,
                       const float softcap,
                       bool seqlenq_ngroups_swapped=false,
-                      const bool unpadded_lse=false) {
+                      const bool unpadded_lse=false,
+                      const bool is_kvc=false) {
 
     // Reset the parameters
     params = {};
@@ -59,11 +60,19 @@ void set_params_fprop(Flash_fwd_params &params,
     params.v_ptr = v.data_ptr();
     // All stride are in elements, not bytes.
     params.q_row_stride = q.stride(-3);
-    params.k_row_stride = k.stride(-3);
-    params.v_row_stride = v.stride(-3);
     params.q_head_stride = q.stride(-2);
-    params.k_head_stride = k.stride(-2);
-    params.v_head_stride = v.stride(-2);
+    params.is_kvc_cache = is_kvc;
+    if (!is_kvc) {
+        params.k_row_stride = k.stride(-3);
+        params.v_row_stride = v.stride(-3);
+        params.k_head_stride = k.stride(-2);
+        params.v_head_stride = v.stride(-2);
+    } else {
+        params.k_row_stride = k.stride(1);
+        params.v_row_stride = v.stride(1);
+        // head stride not used
+    }
+
     params.o_ptr = out.data_ptr();
     params.o_row_stride = out.stride(-3);
     params.o_head_stride = out.stride(-2);
@@ -502,6 +511,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
         TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
         TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
     }
+    const bool is_KVC = paged_KV && (block_table.dim() > 2);
+
 
     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
@@ -514,11 +525,11 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     const int batch_size = cu_seqlens_q.numel() - 1;
     int num_heads = sizes[1];
    const int head_size_og = sizes[2];
-    const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
+    const int num_heads_k = paged_KV ? (!is_KVC ? k.size(2): block_table.size(1)) : k.size(1);
 
     if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
 
-    const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
+    const int max_num_blocks_per_seq = !paged_KV ? 0 : (!is_KVC ? block_table.size(1) : block_table.size(2));
     const int num_blocks = !paged_KV ? 0 : k.size(0);
     const int page_block_size = !paged_KV ? 1 : k.size(1);
     TORCH_CHECK(!paged_KV || page_block_size % 16 == 0, "Paged KV cache block size must be divisible by 16");
@@ -554,13 +565,29 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
         CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
         CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
     } else {
-        CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
-        CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
-        CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+        if (!is_KVC) {
+            CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
+            CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
+            CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+        } else {
+            CHECK_SHAPE(k, num_blocks, page_block_size, head_size_og);
+            CHECK_SHAPE(v, num_blocks, page_block_size, head_size_og);
+            // [ batch_size, kv_heads, blocks ]
+            // printf("batch_size=%d, num_heads_k=%d, max_num_blocks_per_seq=%d",
+            //     batch_size, num_heads_k, max_num_blocks_per_seq);
+            // std::cout << "block_tables shape\n" << block_table.sizes() << std::endl;
+            CHECK_SHAPE(block_table, batch_size, num_heads_k, max_num_blocks_per_seq);
+        }
     }
 
+    bool seqlen_by_head = false;
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
-    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+    if (!is_KVC) {
+        CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+    } else {
+        seqlen_by_head = cu_seqlens_k.size(0) > batch_size + 1;
+        // CHECK_SHAPE(cu_seqlens_k, batch_size + 1 + batch_size * num_heads_k + 1);
+    }
     if (seqused_k.has_value()){
         auto seqused_k_ = seqused_k.value();
         TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
@@ -639,12 +666,22 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                      window_size_right,
                      softcap,
                      seqlenq_ngroups_swapped,
-                     /*unpadded_lse*/true);
+                     /*unpadded_lse*/true,
+                     /*is_kvc*/is_KVC);
     params.total_q = total_q;
+    params.seqlen_by_head = seqlen_by_head;
 
     if (paged_KV) {
         params.block_table = block_table.data_ptr<int>();
-        params.block_table_batch_stride = block_table.stride(0);
+        if (!is_KVC) {
+            params.block_table_batch_stride = block_table.stride(0);
+        } else {
+            params.kseqlen_batch_stride = num_heads_k;
+            params.block_table_batch_stride = block_table.stride(0);
+            params.block_table_head_stride = block_table.stride(1);
+        }
+        // std::cout << "\n" << k_padded.strides() << std::endl;
+        // std::cout << k_padded.sizes() << std::endl;
         params.k_batch_stride = k_padded.stride(0);
         params.v_batch_stride = v_padded.stride(0);
     }
@@ -759,6 +796,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
         TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
     }
+    const bool is_KVC = paged_KV && (block_table.dim() > 2);
 
     const auto sizes = q.sizes();
 
@@ -769,12 +807,12 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     const int num_heads_og = num_heads;
     const int head_size_og = sizes[3];
 
-    const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
+    const int max_num_blocks_per_seq = !paged_KV ? 0 : (!is_KVC ? block_table.size(1) : block_table.size(2));
     const int num_blocks = !paged_KV ? 0 : kcache.size(0);
     const int page_block_size = !paged_KV ? 1 : kcache.size(1);
     TORCH_CHECK(!paged_KV || page_block_size % 16 == 0, "Paged KV cache block size must be divisible by 16");
     const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
-    const int num_heads_k = kcache.size(2);
+    const int num_heads_k = !is_KVC ? kcache.size(2) : block_table.size(1);
     const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
     TORCH_CHECK(batch_size > 0, "batch size must be postive");
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
@@ -802,9 +840,16 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
         CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
     } else {
-        CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
-        CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
-        CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+        if (!is_KVC) {
+            CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
+            CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
+            CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+        } else {
+            CHECK_SHAPE(kcache, num_blocks, page_block_size, head_size_og);
+            CHECK_SHAPE(vcache, num_blocks, page_block_size, head_size_og);
+            // [ batch_size, kv_heads, blocks ]
+            CHECK_SHAPE(block_table, batch_size, num_heads_k, max_num_blocks_per_seq);
+        }
     }
 
     at::Tensor q_padded, kcache_padded, vcache_padded;
@@ -865,8 +910,10 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                      softmax_scale,
                      window_size_left,
                      window_size_right,
-                     softcap
-                     );
+                     softcap,
+                     /*seqlenq_ngroups_swapped=*/false,
+                     /*unpadded_lse=*/false,
+                     /*is_kvc=*/is_KVC);
 
     at::Tensor k, v, k_padded, v_padded;
     if (k_.has_value()) {
@@ -907,8 +954,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
         CHECK_DEVICE(seqlens_k);
         CHECK_CONTIGUOUS(seqlens_k);
-        CHECK_SHAPE(seqlens_k, batch_size);
+        if (!is_KVC) {
+            CHECK_SHAPE(seqlens_k, batch_size);
+        } else {
+            CHECK_SHAPE(seqlens_k, batch_size * num_heads_k);
+        }
         params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
+        params.seqlen_by_head = is_KVC;
     }
     params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());
 
@@ -954,7 +1006,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
 
     if (paged_KV) {
         params.block_table = block_table.data_ptr<int>();
-        params.block_table_batch_stride = block_table.stride(0);
+        if (!is_KVC) {
+            params.block_table_batch_stride = block_table.stride(0);
+        } else {
+            params.kseqlen_batch_stride = num_heads_k;
+            params.block_table_batch_stride = block_table.stride(0);
+            params.block_table_head_stride = block_table.stride(1);
+        }
    }
     params.page_block_size = page_block_size;
 
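For orientation, here is a minimal sketch (not part of the commit) of tensor shapes that satisfy the new is_KVC branch of mha_fwd_kvcache: the KV-Compress cache pages have no per-head dimension, the block table gains a KV-head dimension, and seqlens_k carries one length per (sequence, KV head). All sizes below are made-up example values, and the snippet only builds the tensors; it does not call into flash-attn.

```cpp
#include <torch/torch.h>

int main() {
    // Illustrative sizes only; page_block_size must stay a multiple of 16.
    const int64_t num_blocks = 1024, page_block_size = 16, head_size = 128;
    const int64_t batch_size = 4, num_heads_k = 8, max_num_blocks_per_seq = 32;
    auto fp16_cuda = torch::dtype(torch::kFloat16).device(torch::kCUDA);
    auto i32_cuda  = torch::dtype(torch::kInt32).device(torch::kCUDA);

    // KV-Compress pages: [num_blocks, page_block_size, head_size], no head dim.
    auto kcache = torch::zeros({num_blocks, page_block_size, head_size}, fp16_cuda);
    auto vcache = torch::zeros({num_blocks, page_block_size, head_size}, fp16_cuda);

    // Block table gains a KV-head dimension: [batch_size, num_heads_k, max_num_blocks_per_seq].
    auto block_table = torch::zeros({batch_size, num_heads_k, max_num_blocks_per_seq}, i32_cuda);

    // One K length per (sequence, KV head) instead of one per sequence.
    auto seqlens_k = torch::zeros({batch_size * num_heads_k}, i32_cuda);
    return 0;
}
```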
csrc/flash_attn/src/block_info.h

Lines changed: 16 additions & 3 deletions
@@ -12,13 +12,26 @@ template<bool Varlen=true>
 struct BlockInfo {
 
     template<typename Params>
-    __device__ BlockInfo(const Params &params, const int bidb)
+    __device__ BlockInfo(const Params &params, const int bidb, const int bidkh)
         : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
-        , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
+        , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : (
+            ((bidkh < 0) || !params.seqlen_by_head) ? (params.cu_seqlens_k[bidb]) :
+            (params.cu_seqlens_k[bidb * params.kseqlen_batch_stride + bidkh])
+          ))
         , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
         // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
         // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
-        , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
+        , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (
+            params.is_seqlens_k_cumulative ?
+            (
+                ((bidkh < 0) || !params.seqlen_by_head) ? (params.cu_seqlens_k[bidb + 1] - sum_s_k) :
+                (params.cu_seqlens_k[bidb * params.kseqlen_batch_stride + bidkh + 1] - sum_s_k)
+            ) :
+            (
+                ((bidkh < 0) || !params.seqlen_by_head) ? params.cu_seqlens_k[bidb] :
+                params.cu_seqlens_k[bidb * params.kseqlen_batch_stride + bidkh]
+            )
+          ))
         , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
     {
     }
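The constructor now also takes a KV-head index (bidkh) and, when params.seqlen_by_head is set, reads one cumulative offset per (sequence, KV head) from cu_seqlens_k, flattened with kseqlen_batch_stride = num_heads_k. Below is a host-side sketch of just that index arithmetic for the cumulative case; seqlen_k_for is a hypothetical helper and the numbers are illustrative, not values from the commit.

```cpp
#include <cstdio>
#include <vector>

// Length of K for sequence bidb and KV head bidkh, mirroring the
// seqlen_by_head branch of BlockInfo (cumulative cu_seqlens_k case).
int seqlen_k_for(const std::vector<int>& cu_seqlens_k,
                 int kseqlen_batch_stride, int bidb, int bidkh) {
    const int sum_s_k = cu_seqlens_k[bidb * kseqlen_batch_stride + bidkh];
    return cu_seqlens_k[bidb * kseqlen_batch_stride + bidkh + 1] - sum_s_k;
}

int main() {
    // 2 sequences x 2 KV heads; per-(seq, head) lengths are 3, 5, 2, 4.
    std::vector<int> cu_seqlens_k = {0, 3, 8, 10, 14};
    std::printf("%d\n", seqlen_k_for(cu_seqlens_k, /*kseqlen_batch_stride=*/2,
                                     /*bidb=*/1, /*bidkh=*/0));  // prints 2
    return 0;
}
```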

csrc/flash_attn/src/flash.h

Lines changed: 4 additions & 0 deletions
@@ -103,7 +103,11 @@ struct Flash_fwd_params : public Qkv_params {
 
     // Paged KV cache
     int * __restrict__ block_table;
+    bool is_kvc_cache;
+    bool seqlen_by_head;
+    index_t kseqlen_batch_stride;
     index_t block_table_batch_stride;
+    index_t block_table_head_stride;
     int page_block_size;
 
     // The dropout probability (probability of keeping an activation).
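Reading the new fields against how the KVC branch of mha_fwd_kvcache fills them, here is a small sketch under stated assumptions (a contiguous [batch, kv_heads, blocks] block table); KvcParamsSketch and make_kvc_params are hypothetical names, not code from the commit.

```cpp
#include <cstdint>

using index_t = int64_t;

struct KvcParamsSketch {
    bool is_kvc_cache;                 // set when block_table.dim() > 2
    bool seqlen_by_head;               // seqlens_k holds one entry per (sequence, KV head)
    index_t kseqlen_batch_stride;      // = num_heads_k
    index_t block_table_batch_stride;  // = block_table.stride(0)
    index_t block_table_head_stride;   // = block_table.stride(1)
};

// Strides for a contiguous int32 block table of shape [batch, kv_heads, blocks].
KvcParamsSketch make_kvc_params(int num_heads_k, int max_num_blocks_per_seq) {
    KvcParamsSketch p{};
    p.is_kvc_cache = true;
    p.seqlen_by_head = true;
    p.kseqlen_batch_stride = num_heads_k;
    p.block_table_batch_stride =
        static_cast<index_t>(num_heads_k) * max_num_blocks_per_seq;  // stride(0)
    p.block_table_head_stride = max_num_blocks_per_seq;              // stride(1)
    return p;
}
```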
