@@ -46,7 +46,8 @@ void set_params_fprop(Flash_fwd_params &params,
                       int window_size_right,
                       const float softcap,
                       bool seqlenq_ngroups_swapped=false,
-                      const bool unpadded_lse=false) {
+                      const bool unpadded_lse=false,
+                      const bool is_kvc=false) {

     // Reset the parameters
     params = {};
@@ -59,11 +60,19 @@ void set_params_fprop(Flash_fwd_params &params,
     params.v_ptr = v.data_ptr();
     // All strides are in elements, not bytes.
     params.q_row_stride = q.stride(-3);
-    params.k_row_stride = k.stride(-3);
-    params.v_row_stride = v.stride(-3);
     params.q_head_stride = q.stride(-2);
-    params.k_head_stride = k.stride(-2);
-    params.v_head_stride = v.stride(-2);
+    params.is_kvc_cache = is_kvc;
+    if (!is_kvc) {
+        params.k_row_stride = k.stride(-3);
+        params.v_row_stride = v.stride(-3);
+        params.k_head_stride = k.stride(-2);
+        params.v_head_stride = v.stride(-2);
+    } else {
+        params.k_row_stride = k.stride(1);
+        params.v_row_stride = v.stride(1);
+        // head stride not used
+    }
+
     params.o_ptr = out.data_ptr();
     params.o_row_stride = out.stride(-3);
     params.o_head_stride = out.stride(-2);
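In the KVC layout assumed by this branch, the K/V caches carry no per-head dimension (each block is [page_block_size, head_size], per the shape checks later in this diff), so the row stride comes from stride(1) and a head stride does not apply. A minimal sketch of the addressing this implies, using a hypothetical helper name:

    // Sketch only: element offset into a KVC-style cache laid out as
    // [num_blocks, page_block_size, head_size]. row_stride corresponds to
    // k.stride(1) above; the physical block index comes from the block table.
    inline int64_t kvc_offset(int64_t block_idx, int64_t row_in_block,
                              int64_t dim_idx, int64_t block_stride,
                              int64_t row_stride) {
        return block_idx * block_stride + row_in_block * row_stride + dim_idx;
    }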
@@ -159,6 +168,7 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split
     HEADDIM_SWITCH(params.d, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if (params.num_splits <= 1 && !force_split_kernel) {  // If we don't set it num_splits == 0
+                assert(false);
                 run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
             } else {
                 run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params, stream);
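The bare assert(false) guards the combined (non-split) path, which suggests the KVC layout is only implemented in the split-KV kernel. Note that assert compiles away under NDEBUG, so a runtime check might be the more robust guard; a hedged alternative, not part of this patch:

    // Sketch only: fail loudly if the combined kernel would run with a KVC
    // cache configured, since only the split-KV path handles that layout.
    TORCH_CHECK(!params.is_kvc_cache,
                "KVC paged KV cache requires the split-KV kernel");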
@@ -502,6 +512,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
         TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
         TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
     }
+    const bool is_KVC = paged_KV && (block_table.dim() > 2);
+

     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
@@ -514,11 +526,11 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     const int batch_size = cu_seqlens_q.numel() - 1;
     int num_heads = sizes[1];
     const int head_size_og = sizes[2];
-    const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
+    const int num_heads_k = paged_KV ? (!is_KVC ? k.size(2) : block_table.size(1)) : k.size(1);

     if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }

-    const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
+    const int max_num_blocks_per_seq = !paged_KV ? 0 : (!is_KVC ? block_table.size(1) : block_table.size(2));
     const int num_blocks = !paged_KV ? 0 : k.size(0);
     const int page_block_size = !paged_KV ? 1 : k.size(1);
     TORCH_CHECK(!paged_KV || page_block_size % 16 == 0, "Paged KV cache block size must be divisible by 16");
@@ -554,13 +566,29 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
         CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
         CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
     } else {
-        CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
-        CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
-        CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+        if (!is_KVC) {
+            CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
+            CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
+            CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+        } else {
+            CHECK_SHAPE(k, num_blocks, page_block_size, head_size_og);
+            CHECK_SHAPE(v, num_blocks, page_block_size, head_size_og);
+            // [batch_size, kv_heads, blocks]
+            // printf("batch_size=%d, num_heads_k=%d, max_num_blocks_per_seq=%d",
+            //        batch_size, num_heads_k, max_num_blocks_per_seq);
+            // std::cout << "block_tables shape\n" << block_table.sizes() << std::endl;
+            CHECK_SHAPE(block_table, batch_size, num_heads_k, max_num_blocks_per_seq);
+        }
     }

+    bool seqlen_by_head = false;
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
-    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+    if (!is_KVC) {
+        CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+    } else {
+        seqlen_by_head = cu_seqlens_k.size(0) > batch_size + 1;
+        // CHECK_SHAPE(cu_seqlens_k, batch_size + 1 + batch_size * num_heads_k + 1);
+    }
     if (seqused_k.has_value()){
         auto seqused_k_ = seqused_k.value();
         TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
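With is_KVC set, the block table gains a head dimension, [batch_size, num_heads_k, max_num_blocks_per_seq], so every KV head can map a token position to its own physical block. A minimal lookup sketch with hypothetical names, assuming the shapes checked above:

    // Sketch only: resolve the physical block for (batch b, kv head h, token t)
    // in a 3-D block table of shape [batch_size, num_heads_k, blocks].
    // bt is block_table.data_ptr<int>().
    inline int physical_block(const int *bt, int b, int h, int t,
                              int64_t batch_stride, int64_t head_stride,
                              int page_block_size) {
        return bt[b * batch_stride + h * head_stride + t / page_block_size];
    }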
@@ -639,12 +667,22 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                      window_size_right,
                      softcap,
                      seqlenq_ngroups_swapped,
-                     /*unpadded_lse*/true);
+                     /*unpadded_lse*/true,
+                     /*is_kvc*/is_KVC);
     params.total_q = total_q;
+    params.seqlen_by_head = seqlen_by_head;

     if (paged_KV) {
         params.block_table = block_table.data_ptr<int>();
-        params.block_table_batch_stride = block_table.stride(0);
+        if (!is_KVC) {
+            params.block_table_batch_stride = block_table.stride(0);
+        } else {
+            params.kseqlen_batch_stride = num_heads_k;
+            params.block_table_batch_stride = block_table.stride(0);
+            params.block_table_head_stride = block_table.stride(1);
+        }
+        // std::cout << "\n" << k_padded.strides() << std::endl;
+        // std::cout << k_padded.sizes() << std::endl;
         params.k_batch_stride = k_padded.stride(0);
         params.v_batch_stride = v_padded.stride(0);
     }
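Setting kseqlen_batch_stride = num_heads_k alongside seqlen_by_head implies cu_seqlens_k holds one K length per (batch, head) pair rather than one per batch. A sketch of the lookup a kernel would then perform, inferred from these strides rather than taken from the kernel code:

    // Sketch only: per-head K sequence length under seqlen_by_head.
    // cu_seqlens_k is batch-major with num_heads_k entries per batch, so
    // head h of batch b lives at b * kseqlen_batch_stride + h.
    inline int k_seqlen_for(const int *cu_seqlens_k, int b, int h,
                            int kseqlen_batch_stride) {
        return cu_seqlens_k[b * kseqlen_batch_stride + h];
    }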
@@ -759,6 +797,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
         TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
     }
+    const bool is_KVC = paged_KV && (block_table.dim() > 2);

     const auto sizes = q.sizes();

@@ -769,12 +808,12 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     const int num_heads_og = num_heads;
     const int head_size_og = sizes[3];

-    const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
+    const int max_num_blocks_per_seq = !paged_KV ? 0 : (!is_KVC ? block_table.size(1) : block_table.size(2));
     const int num_blocks = !paged_KV ? 0 : kcache.size(0);
     const int page_block_size = !paged_KV ? 1 : kcache.size(1);
     TORCH_CHECK(!paged_KV || page_block_size % 16 == 0, "Paged KV cache block size must be divisible by 16");
     const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
-    const int num_heads_k = kcache.size(2);
+    const int num_heads_k = !is_KVC ? kcache.size(2) : block_table.size(1);
     const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
@@ -802,9 +841,16 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
         CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
     } else {
-        CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
-        CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
-        CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+        if (!is_KVC) {
+            CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
+            CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
+            CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+        } else {
+            CHECK_SHAPE(kcache, num_blocks, page_block_size, head_size_og);
+            CHECK_SHAPE(vcache, num_blocks, page_block_size, head_size_og);
+            // [batch_size, kv_heads, blocks]
+            CHECK_SHAPE(block_table, batch_size, num_heads_k, max_num_blocks_per_seq);
+        }
     }

     at::Tensor q_padded, kcache_padded, vcache_padded;
@@ -865,8 +911,10 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                      softmax_scale,
                      window_size_left,
                      window_size_right,
-                     softcap
-                     );
+                     softcap,
+                     /*seqlenq_ngroups_swapped=*/false,
+                     /*unpadded_lse=*/false,
+                     /*is_kvc=*/is_KVC);

     at::Tensor k, v, k_padded, v_padded;
     if (k_.has_value()) {
@@ -907,8 +955,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
         CHECK_DEVICE(seqlens_k);
         CHECK_CONTIGUOUS(seqlens_k);
-        CHECK_SHAPE(seqlens_k, batch_size);
+        if (!is_KVC) {
+            CHECK_SHAPE(seqlens_k, batch_size);
+        } else {
+            CHECK_SHAPE(seqlens_k, batch_size * num_heads_k);
+        }
         params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
+        params.seqlen_by_head = is_KVC;
     }
     params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());

@@ -954,7 +1007,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he

     if (paged_KV) {
         params.block_table = block_table.data_ptr<int>();
-        params.block_table_batch_stride = block_table.stride(0);
+        if (!is_KVC) {
+            params.block_table_batch_stride = block_table.stride(0);
+        } else {
+            params.kseqlen_batch_stride = num_heads_k;
+            params.block_table_batch_stride = block_table.stride(0);
+            params.block_table_head_stride = block_table.stride(1);
+        }
     }
     params.page_block_size = page_block_size;

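Taken together, the KVC path changes the expected input shapes for mha_fwd_kvcache. A minimal libtorch sketch of constructing conforming inputs, with shapes inferred from the CHECK_SHAPE calls in this diff and all sizes placeholder values:

    #include <torch/torch.h>

    // Sketch only: inputs that would satisfy the KVC checks above.
    void make_kvc_inputs() {
        const int num_blocks = 256, page_block_size = 16, head_size = 128;
        const int batch_size = 4, num_heads_k = 8, max_blocks_per_seq = 32;
        auto opts = torch::dtype(torch::kHalf).device(torch::kCUDA);
        // Per-head cache: blocks carry no head dimension.
        auto kcache = torch::zeros({num_blocks, page_block_size, head_size}, opts);
        auto vcache = torch::zeros_like(kcache);
        // A 3-D block table is what flips is_KVC (block_table.dim() > 2).
        auto block_table = torch::zeros(
            {batch_size, num_heads_k, max_blocks_per_seq},
            torch::dtype(torch::kInt32).device(torch::kCUDA));
        // One K length per (batch, head) pair, matching seqlen_by_head.
        auto seqlens_k = torch::zeros({batch_size * num_heads_k},
            torch::dtype(torch::kInt32).device(torch::kCUDA));
    }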