support KV-Compress paged KV cache #27

Open
wants to merge 1 commit into main
Conversation

@IsaacRe commented Nov 27, 2024

This PR adds support for paged KV cache following the structure of KV-Compress, where cache blocks are paged out on a per-head basis.

In this case the cache shape becomes [num_blocks, block_size, d] (rather than [num_blocks, block_size, num_heads, d]) and the block table shape becomes [num_seqs, num_heads, max_blocks_per_seq] (rather than [num_seqs, max_blocks_per_seq]). Sequence lengths also need to be specified per head, so a tensor of size (num_seqs * num_heads) or (num_seqs * num_heads + 1) is provided, depending on whether per-head sequence lengths or cumulative offsets are used.

I configured it to use the dimensionality of the block tables tensor to detect whether a KV-Compress cache is being used, assuming KV-Compress when dim > 2 and following the existing logic otherwise.
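A minimal sketch of that detection rule, assuming the block table is a PyTorch tensor; this is illustrative only, not the PR's actual code:

```python
import torch

def uses_kvcompress_cache(block_table: torch.Tensor) -> bool:
    # Standard paged KV cache:    [num_seqs, max_blocks_per_seq]            -> dim == 2
    # KV-Compress paged KV cache: [num_seqs, num_heads, max_blocks_per_seq] -> dim == 3
    return block_table.dim() > 2
```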

@@ -159,6 +168,7 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split
HEADDIM_SWITCH(params.d, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
assert(false);
What is going on here? Perhaps you mean to also remove the call that follows?

@IsaacRe (Author)

Good catch. That was just to make sure it was launching the right kernel when I was testing. It should be removed

@WoosukKwon

@IsaacRe Amazing! Excited to see this work 🚀 BTW, are you in the vLLM slack workspace?

@IsaacRe (Author) commented Nov 28, 2024

> @IsaacRe Amazing! Excited to see this work 🚀 BTW, are you in the vLLM slack workspace?

Thanks! Yes I am. I'm wrapping up chunked-prefill compat and will update in the channel when benchmarks are done
