Skip to content

Commit

Permalink
support KV-Compress paged KV cache
Browse files Browse the repository at this point in the history
  • Loading branch information
IsaacRe committed Nov 27, 2024
1 parent 5259c58 commit 3aa31f1
Show file tree
Hide file tree
Showing 5 changed files with 243 additions and 53 deletions.
103 changes: 81 additions & 22 deletions csrc/flash_attn/flash_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ void set_params_fprop(Flash_fwd_params &params,
int window_size_right,
const float softcap,
bool seqlenq_ngroups_swapped=false,
const bool unpadded_lse=false) {
const bool unpadded_lse=false,
const bool is_kvc=false) {

// Reset the parameters
params = {};
Expand All @@ -59,11 +60,19 @@ void set_params_fprop(Flash_fwd_params &params,
params.v_ptr = v.data_ptr();
// All stride are in elements, not bytes.
params.q_row_stride = q.stride(-3);
params.k_row_stride = k.stride(-3);
params.v_row_stride = v.stride(-3);
params.q_head_stride = q.stride(-2);
params.k_head_stride = k.stride(-2);
params.v_head_stride = v.stride(-2);
params.is_kvc_cache = is_kvc;
if (!is_kvc) {
params.k_row_stride = k.stride(-3);
params.v_row_stride = v.stride(-3);
params.k_head_stride = k.stride(-2);
params.v_head_stride = v.stride(-2);
} else {
params.k_row_stride = k.stride(1);
params.v_row_stride = v.stride(1);
// head stride not used
}

params.o_ptr = out.data_ptr();
params.o_row_stride = out.stride(-3);
params.o_head_stride = out.stride(-2);
Expand Down Expand Up @@ -159,6 +168,7 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split
HEADDIM_SWITCH(params.d, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
assert(false);
run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
} else {
run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params, stream);
Expand Down Expand Up @@ -502,6 +512,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
}
const bool is_KVC = paged_KV && (block_table.dim() > 2);


TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
Expand All @@ -514,11 +526,11 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
const int batch_size = cu_seqlens_q.numel() - 1;
int num_heads = sizes[1];
const int head_size_og = sizes[2];
const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
const int num_heads_k = paged_KV ? (!is_KVC ? k.size(2): block_table.size(1)) : k.size(1);

if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }

const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
const int max_num_blocks_per_seq = !paged_KV ? 0 : (!is_KVC ? block_table.size(1) : block_table.size(2));
const int num_blocks = !paged_KV ? 0 : k.size(0);
const int page_block_size = !paged_KV ? 1 : k.size(1);
TORCH_CHECK(!paged_KV || page_block_size % 16 == 0, "Paged KV cache block size must be divisible by 16");
Expand Down Expand Up @@ -554,13 +566,29 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
} else {
CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
if (!is_KVC) {
CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
} else {
CHECK_SHAPE(k, num_blocks, page_block_size, head_size_og);
CHECK_SHAPE(v, num_blocks, page_block_size, head_size_og);
// [ batch_size, kv_heads, blocks ]
// printf("batch_size=%d, num_heads_k=%d, max_num_blocks_per_seq=%d",
// batch_size, num_heads_k, max_num_blocks_per_seq);
// std::cout << "block_tables shape\n" << block_table.sizes() << std::endl;
CHECK_SHAPE(block_table, batch_size, num_heads_k, max_num_blocks_per_seq);
}
}

bool seqlen_by_head = false;
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
if (!is_KVC) {
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
} else {
seqlen_by_head = cu_seqlens_k.size(0) > batch_size + 1;
// CHECK_SHAPE(cu_seqlens_k, batch_size + 1 + batch_size * num_heads_k + 1);
}
if (seqused_k.has_value()){
auto seqused_k_ = seqused_k.value();
TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
Expand Down Expand Up @@ -639,12 +667,22 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
window_size_right,
softcap,
seqlenq_ngroups_swapped,
/*unpadded_lse*/true);
/*unpadded_lse*/true,
/*is_kvc*/is_KVC);
params.total_q = total_q;
params.seqlen_by_head = seqlen_by_head;

if (paged_KV) {
params.block_table = block_table.data_ptr<int>();
params.block_table_batch_stride = block_table.stride(0);
if (!is_KVC) {
params.block_table_batch_stride = block_table.stride(0);
} else {
params.kseqlen_batch_stride = num_heads_k;
params.block_table_batch_stride = block_table.stride(0);
params.block_table_head_stride = block_table.stride(1);
}
// std::cout << "\n" << k_padded.strides() << std::endl;
// std::cout << k_padded.sizes() << std::endl;
params.k_batch_stride = k_padded.stride(0);
params.v_batch_stride = v_padded.stride(0);
}
Expand Down Expand Up @@ -759,6 +797,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
}
const bool is_KVC = paged_KV && (block_table.dim() > 2);

const auto sizes = q.sizes();

Expand All @@ -769,12 +808,12 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
const int num_heads_og = num_heads;
const int head_size_og = sizes[3];

const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
const int max_num_blocks_per_seq = !paged_KV ? 0 : (!is_KVC ? block_table.size(1) : block_table.size(2));
const int num_blocks = !paged_KV ? 0 : kcache.size(0);
const int page_block_size = !paged_KV ? 1 : kcache.size(1);
TORCH_CHECK(!paged_KV || page_block_size % 16 == 0, "Paged KV cache block size must be divisible by 16");
const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
const int num_heads_k = kcache.size(2);
const int num_heads_k = !is_KVC ? kcache.size(2) : block_table.size(1);
const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
TORCH_CHECK(batch_size > 0, "batch size must be postive");
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
Expand Down Expand Up @@ -802,9 +841,16 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
} else {
CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
if (!is_KVC) {
CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
} else {
CHECK_SHAPE(kcache, num_blocks, page_block_size, head_size_og);
CHECK_SHAPE(vcache, num_blocks, page_block_size, head_size_og);
// [ batch_size, kv_heads, blocks ]
CHECK_SHAPE(block_table, batch_size, num_heads_k, max_num_blocks_per_seq);
}
}

at::Tensor q_padded, kcache_padded, vcache_padded;
Expand Down Expand Up @@ -865,8 +911,10 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
softmax_scale,
window_size_left,
window_size_right,
softcap
);
softcap,
/*seqlenq_ngroups_swapped=*/false,
/*unpadded_lse=*/false,
/*is_kvc=*/is_KVC);

at::Tensor k, v, k_padded, v_padded;
if (k_.has_value()) {
Expand Down Expand Up @@ -907,8 +955,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
CHECK_DEVICE(seqlens_k);
CHECK_CONTIGUOUS(seqlens_k);
CHECK_SHAPE(seqlens_k, batch_size);
if (!is_KVC) {
CHECK_SHAPE(seqlens_k, batch_size);
} else {
CHECK_SHAPE(seqlens_k, batch_size * num_heads_k);
}
params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
params.seqlen_by_head = is_KVC;
}
params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());

Expand Down Expand Up @@ -954,7 +1007,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he

if (paged_KV) {
params.block_table = block_table.data_ptr<int>();
params.block_table_batch_stride = block_table.stride(0);
if (!is_KVC) {
params.block_table_batch_stride = block_table.stride(0);
} else {
params.kseqlen_batch_stride = num_heads_k;
params.block_table_batch_stride = block_table.stride(0);
params.block_table_head_stride = block_table.stride(1);
}
}
params.page_block_size = page_block_size;

Expand Down
19 changes: 16 additions & 3 deletions csrc/flash_attn/src/block_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,26 @@ template<bool Varlen=true>
struct BlockInfo {

template<typename Params>
__device__ BlockInfo(const Params &params, const int bidb)
__device__ BlockInfo(const Params &params, const int bidb, const int bidkh)
: sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
, sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
, sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : (
((bidkh < 0) || !params.seqlen_by_head) ? (params.cu_seqlens_k[bidb]) :
(params.cu_seqlens_k[bidb * params.kseqlen_batch_stride + bidkh])
))
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
, seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
, seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (
params.is_seqlens_k_cumulative ?
(
((bidkh < 0) || !params.seqlen_by_head) ? (params.cu_seqlens_k[bidb + 1] - sum_s_k) :
(params.cu_seqlens_k[bidb * params.kseqlen_batch_stride + bidkh + 1] - sum_s_k)
) :
(
((bidkh < 0) || !params.seqlen_by_head) ? params.cu_seqlens_k[bidb] :
params.cu_seqlens_k[bidb * params.kseqlen_batch_stride + bidkh]
)
))
, actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
{
}
Expand Down
4 changes: 4 additions & 0 deletions csrc/flash_attn/src/flash.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,11 @@ struct Flash_fwd_params : public Qkv_params {

// Paged KV cache
int * __restrict__ block_table;
bool is_kvc_cache;
bool seqlen_by_head;
index_t kseqlen_batch_stride;
index_t block_table_batch_stride;
index_t block_table_head_stride;
int page_block_size;

// The dropout probability (probability of keeping an activation).
Expand Down
Loading

0 comments on commit 3aa31f1

Please sign in to comment.