@@ -46,7 +46,8 @@ void set_params_fprop(Flash_fwd_params &params,
                      int window_size_right,
                      const float softcap,
                      bool seqlenq_ngroups_swapped=false,
-                     const bool unpadded_lse=false) {
+                     const bool unpadded_lse=false,
+                     const bool is_kvc=false) {

    // Reset the parameters
    params = {};
@@ -59,11 +60,19 @@ void set_params_fprop(Flash_fwd_params &params,
    params.v_ptr = v.data_ptr();
    // All stride are in elements, not bytes.
    params.q_row_stride = q.stride(-3);
-   params.k_row_stride = k.stride(-3);
-   params.v_row_stride = v.stride(-3);
    params.q_head_stride = q.stride(-2);
-   params.k_head_stride = k.stride(-2);
-   params.v_head_stride = v.stride(-2);
+   params.is_kvc_cache = is_kvc;
+   if (!is_kvc) {
+       params.k_row_stride = k.stride(-3);
+       params.v_row_stride = v.stride(-3);
+       params.k_head_stride = k.stride(-2);
+       params.v_head_stride = v.stride(-2);
+   } else {
+       params.k_row_stride = k.stride(1);
+       params.v_row_stride = v.stride(1);
+       // head stride not used
+   }
+
    params.o_ptr = out.data_ptr();
    params.o_row_stride = out.stride(-3);
    params.o_head_stride = out.stride(-2);
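Note (not part of the patch): in the is_kvc branch above, the K/V row strides come from dimension 1 and no head stride is set, which is consistent with a per-head paged cache laid out as [num_blocks, page_block_size, head_size], where the head is implied by which physical blocks the block table maps to. A minimal standalone sketch of the addressing this implies, using illustrative names and hypothetical sizes:

    #include <cstdint>
    #include <cstdio>

    // Offset of one element in a headless (per-head-paged) KV block pool of
    // assumed shape [num_blocks, page_block_size, head_size].
    int64_t kvc_offset(int64_t block_id, int64_t row_in_block, int64_t dim,
                       int64_t block_stride, int64_t row_stride) {
        return block_id * block_stride + row_in_block * row_stride + dim;
    }

    int main() {
        const int64_t page_block_size = 16, head_size = 128;      // hypothetical sizes
        const int64_t row_stride   = head_size;                   // k.stride(1)
        const int64_t block_stride = page_block_size * head_size; // k.stride(0)
        std::printf("%lld\n", (long long) kvc_offset(3, 5, 7, block_stride, row_stride));
    }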
@@ -502,6 +511,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
        TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
        TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
    }
+   const bool is_KVC = paged_KV && (block_table.dim() > 2);
+

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
@@ -514,11 +525,11 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
    const int batch_size = cu_seqlens_q.numel() - 1;
    int num_heads = sizes[1];
    const int head_size_og = sizes[2];
-   const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
+   const int num_heads_k = paged_KV ? (!is_KVC ? k.size(2) : block_table.size(1)) : k.size(1);

    if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }

-   const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
+   const int max_num_blocks_per_seq = !paged_KV ? 0 : (!is_KVC ? block_table.size(1) : block_table.size(2));
    const int num_blocks = !paged_KV ? 0 : k.size(0);
    const int page_block_size = !paged_KV ? 1 : k.size(1);
    TORCH_CHECK(!paged_KV || page_block_size % 16 == 0, "Paged KV cache block size must be divisible by 16");
@@ -554,13 +565,29 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
        CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
        CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
    } else {
-       CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
-       CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
-       CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+       if (!is_KVC) {
+           CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
+           CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
+           CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+       } else {
+           CHECK_SHAPE(k, num_blocks, page_block_size, head_size_og);
+           CHECK_SHAPE(v, num_blocks, page_block_size, head_size_og);
+           // [ batch_size, kv_heads, blocks ]
+           // printf("batch_size=%d, num_heads_k=%d, max_num_blocks_per_seq=%d",
+           //        batch_size, num_heads_k, max_num_blocks_per_seq);
+           // std::cout << "block_tables shape\n" << block_table.sizes() << std::endl;
+           CHECK_SHAPE(block_table, batch_size, num_heads_k, max_num_blocks_per_seq);
+       }
    }

+   bool seqlen_by_head = false;
    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
-   CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+   if (!is_KVC) {
+       CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+   } else {
+       seqlen_by_head = cu_seqlens_k.size(0) > batch_size + 1;
+       // CHECK_SHAPE(cu_seqlens_k, batch_size + 1 + batch_size * num_heads_k + 1);
+   }
    if (seqused_k.has_value()){
        auto seqused_k_ = seqused_k.value();
        TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
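Note (not part of the patch): the KVC branch above accepts a block table of shape [batch_size, num_heads_k, max_num_blocks_per_seq], so each KV head keeps its own list of physical blocks, and cu_seqlens_k may carry more than batch_size + 1 entries when lengths are tracked per head (the seqlen_by_head flag). A minimal standalone sketch of the per-head lookup such a layout implies, assuming row-major storage and using illustrative names:

    #include <cstdio>
    #include <vector>

    // Physical block for (batch b, KV head h, logical block i) in a row-major
    // [batch_size, num_heads_k, max_num_blocks_per_seq] block table.
    int physical_block(const std::vector<int>& block_table, int num_heads_k,
                       int max_blocks_per_seq, int b, int h, int i) {
        return block_table[(b * num_heads_k + h) * max_blocks_per_seq + i];
    }

    int main() {
        const int num_heads_k = 2, max_blocks_per_seq = 3;
        std::vector<int> block_table = {10, 11, 12,   // batch 0, head 0
                                        20, 21, 22};  // batch 0, head 1
        std::printf("%d\n", physical_block(block_table, num_heads_k,
                                           max_blocks_per_seq, /*b=*/0, /*h=*/1, /*i=*/2)); // 22
    }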
@@ -639,12 +666,22 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                     window_size_right,
                     softcap,
                     seqlenq_ngroups_swapped,
-                    /*unpadded_lse*/true);
+                    /*unpadded_lse*/true,
+                    /*is_kvc*/is_KVC);
    params.total_q = total_q;
+   params.seqlen_by_head = seqlen_by_head;

    if (paged_KV) {
        params.block_table = block_table.data_ptr<int>();
-       params.block_table_batch_stride = block_table.stride(0);
+       if (!is_KVC) {
+           params.block_table_batch_stride = block_table.stride(0);
+       } else {
+           params.kseqlen_batch_stride = num_heads_k;
+           params.block_table_batch_stride = block_table.stride(0);
+           params.block_table_head_stride = block_table.stride(1);
+       }
+       // std::cout << "\n" << k_padded.strides() << std::endl;
+       // std::cout << k_padded.sizes() << std::endl;
        params.k_batch_stride = k_padded.stride(0);
        params.v_batch_stride = v_padded.stride(0);
    }
@@ -759,6 +796,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
        TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
        TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
    }
+   const bool is_KVC = paged_KV && (block_table.dim() > 2);

    const auto sizes = q.sizes();

@@ -769,12 +807,12 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
    const int num_heads_og = num_heads;
    const int head_size_og = sizes[3];

-   const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
+   const int max_num_blocks_per_seq = !paged_KV ? 0 : (!is_KVC ? block_table.size(1) : block_table.size(2));
    const int num_blocks = !paged_KV ? 0 : kcache.size(0);
    const int page_block_size = !paged_KV ? 1 : kcache.size(1);
    TORCH_CHECK(!paged_KV || page_block_size % 16 == 0, "Paged KV cache block size must be divisible by 16");
    const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
-   const int num_heads_k = kcache.size(2);
+   const int num_heads_k = !is_KVC ? kcache.size(2) : block_table.size(1);
    const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
    TORCH_CHECK(batch_size > 0, "batch size must be postive");
    TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
@@ -802,9 +840,16 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
        CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
        CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
    } else {
-       CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
-       CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
-       CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+       if (!is_KVC) {
+           CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
+           CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
+           CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+       } else {
+           CHECK_SHAPE(kcache, num_blocks, page_block_size, head_size_og);
+           CHECK_SHAPE(vcache, num_blocks, page_block_size, head_size_og);
+           // [ batch_size, kv_heads, blocks ]
+           CHECK_SHAPE(block_table, batch_size, num_heads_k, max_num_blocks_per_seq);
+       }
    }

    at::Tensor q_padded, kcache_padded, vcache_padded;
@@ -865,8 +910,10 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                     softmax_scale,
                     window_size_left,
                     window_size_right,
-                    softcap
-                    );
+                    softcap,
+                    /*seqlenq_ngroups_swapped=*/false,
+                    /*unpadded_lse=*/false,
+                    /*is_kvc=*/is_KVC);

    at::Tensor k, v, k_padded, v_padded;
    if (k_.has_value()) {
@@ -907,8 +954,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
        TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
        CHECK_DEVICE(seqlens_k);
        CHECK_CONTIGUOUS(seqlens_k);
-       CHECK_SHAPE(seqlens_k, batch_size);
+       if (!is_KVC) {
+           CHECK_SHAPE(seqlens_k, batch_size);
+       } else {
+           CHECK_SHAPE(seqlens_k, batch_size * num_heads_k);
+       }
        params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
+       params.seqlen_by_head = is_KVC;
    }
    params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());

@@ -954,7 +1006,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he

    if (paged_KV) {
        params.block_table = block_table.data_ptr<int>();
-       params.block_table_batch_stride = block_table.stride(0);
+       if (!is_KVC) {
+           params.block_table_batch_stride = block_table.stride(0);
+       } else {
+           params.kseqlen_batch_stride = num_heads_k;
+           params.block_table_batch_stride = block_table.stride(0);
+           params.block_table_head_stride = block_table.stride(1);
+       }
    }
    params.page_block_size = page_block_size;

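Note (not part of the patch): with kseqlen_batch_stride set to num_heads_k here and seqlens_k checked against batch_size * num_heads_k earlier, the KVC path appears to keep one cache length per (batch, KV head) pair rather than one per batch. A minimal standalone sketch of that indexing, with illustrative names:

    #include <cstdio>
    #include <vector>

    // Cache length for (batch b, KV head h) in a flattened [batch_size, num_heads_k]
    // seqlens_k buffer, where kseqlen_batch_stride == num_heads_k.
    int seqlen_k_for(const std::vector<int>& seqlens_k, int kseqlen_batch_stride,
                     int b, int h) {
        return seqlens_k[b * kseqlen_batch_stride + h];
    }

    int main() {
        const int num_heads_k = 4;
        std::vector<int> seqlens_k = {7, 7, 5, 3,    // batch 0: one length per KV head
                                      9, 9, 9, 2};   // batch 1
        std::printf("%d\n", seqlen_k_for(seqlens_k, num_heads_k, /*b=*/1, /*h=*/2)); // 9
    }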