
[KVCache] Per Layer Sliding Window #17928

Open · wants to merge 2 commits into base: main

39 changes: 23 additions & 16 deletions python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -20,7 +20,7 @@
# pylint: disable=too-many-statements,too-many-lines,too-many-arguments,invalid-name
import enum
import math
-from typing import Any, Dict, List, Literal, Optional, Tuple
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import tvm
from tvm import relax as rx
@@ -86,6 +86,7 @@ class AttnKind(enum.IntEnum):

MHA = 0
MLA = 1
+MHA_SLIDING = 3


class RopeMode(enum.IntEnum):
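
The new MHA_SLIDING kind lets a model describe its attention layout layer by layer instead of with a single global setting. Below is a minimal sketch (not part of this PR) of how a per-layer kind list maps onto these enum values; the alternating sliding/full pattern and the per_layer_kinds helper are hypothetical illustrations.

```python
import enum
from typing import List


class AttnKind(enum.IntEnum):
    """Mirrors the enum in kv_cache.py (LinearAttn omitted for brevity)."""
    MHA = 0
    MLA = 1
    MHA_SLIDING = 3


def per_layer_kinds(num_hidden_layers: int) -> List[str]:
    # Hypothetical layout: even layers use sliding-window attention,
    # odd layers use full attention.
    return ["mha_sliding" if i % 2 == 0 else "mha" for i in range(num_hidden_layers)]


kinds = per_layer_kinds(4)
codes = [int(AttnKind[k.upper()]) for k in kinds]
print(kinds)  # ['mha_sliding', 'mha', 'mha_sliding', 'mha']
print(codes)  # [3, 0, 3, 0]
```
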
@@ -301,7 +302,7 @@ class FlashInferPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-methods

def __init__( # pylint: disable=too-many-locals
self,
-attn_kind: Literal["mha", "mla"],
+attn_kind: Union[Literal["mha", "mla"], List[Literal["mha", "mla","mha_sliding"]]],
max_batch_size: tir.Var,
max_total_seq_len: tir.Var,
prefill_chunk_size: tir.Var,
@@ -377,8 +378,8 @@ def __init__( # pylint: disable=too-many-locals
dtype_q=dtype,
dtype_kv=dtype,
dtype_o=dtype,
-qk_head_dim=qk_head_dim if attn_kind == "mha" else mla_original_qk_head_dim,
-v_head_dim=v_head_dim if attn_kind == "mha" else mla_original_v_head_dim,
+qk_head_dim=qk_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_qk_head_dim,
+v_head_dim=v_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_v_head_dim,
target=target,
enable_inline_rope=rope_mode == RopeMode.INLINE,
)
@@ -391,7 +392,7 @@ def __init__( # pylint: disable=too-many-locals
v_head_dim=v_head_dim,
target=target,
)
if attn_kind == "mha"
if (attn_kind == "mha" or isinstance(attn_kind, List))
else []
)
flashinfer_mla_mods = (
@@ -420,7 +421,7 @@ def __init__( # pylint: disable=too-many-locals
rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]),
rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]),
]
if attn_kind == "mha"
if (attn_kind == "mha" or isinstance(attn_kind, List))
else [rx.Tuple([]) for _ in range(6)]
)
mla_function = rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_mla_paged_attention_run"), rx.ExternFunc("batch_mla_paged_attention_plan")] if attn_kind == "mla" else [])
@@ -430,6 +431,11 @@ def __init__( # pylint: disable=too-many-locals
if attn_kind == "mla":
attn_merge_functions.append(bb.add_func(_merge_state_inplace(num_attention_heads, mla_original_v_head_dim, dtype, target, "tir_attention_merge_state_mla"), "tir_attention_merge_state_mla"))


+if isinstance(attn_kind, List):
+attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind]
+else:
+attn_kind = [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)]
args = [
rx.ShapeExpr(
[
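
Both cache constructors now normalize attn_kind the same way: a single string is broadcast to every hidden layer, while a list is converted element by element. A standalone sketch of that normalization is shown below; the normalize_attn_kind name is hypothetical, and the AttnKind import assumes the enum defined in kv_cache.py above.

```python
from typing import List, Union

from tvm.relax.frontend.nn.llm.kv_cache import AttnKind


def normalize_attn_kind(attn_kind: Union[str, List[str]], num_hidden_layers: int) -> List[int]:
    """Return one integer AttnKind code per hidden layer."""
    if isinstance(attn_kind, list):
        # Per-layer specification, e.g. ["mha_sliding", "mha", ...]
        return [int(AttnKind[kind.upper()]) for kind in attn_kind]
    # Single kind broadcast to every layer, e.g. "mla" -> [1, 1, ..., 1]
    return [int(AttnKind[attn_kind.upper()])] * num_hidden_layers


# normalize_attn_kind("mha", 3)                  -> [0, 0, 0]
# normalize_attn_kind(["mha_sliding", "mha"], 2) -> [3, 0]
```
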
@@ -482,7 +488,7 @@ class TIRPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-methods

def __init__( # pylint: disable=too-many-locals
self,
-attn_kind: Literal["mha", "mla"],
+attn_kind: Union[Literal["mha", "mla"], List[Literal["mha", "mla","mha_sliding"]]],
max_batch_size: tir.Var,
max_total_seq_len: tir.Var,
prefill_chunk_size: tir.Var,
@@ -553,7 +559,10 @@ def __init__( # pylint: disable=too-many-locals
target : Target
The target to build the model to.
"""

+if isinstance(attn_kind, List):
+attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind]
+else:
+attn_kind = [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)]
bb = rx.BlockBuilder.current()
args = [
rx.ShapeExpr(
@@ -570,9 +579,7 @@ def __init__( # pylint: disable=too-many-locals
rx.PrimValue(num_key_value_heads),
rx.PrimValue(qk_head_dim),
rx.PrimValue(v_head_dim),
-rx.ShapeExpr(
-[int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)]
-),
+rx.ShapeExpr(attn_kind),
rx.PrimValue(enable_disaggregation),
rx.PrimValue(rope_mode),
rx.PrimValue(rope_scale),
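
Because the per-layer kinds are plain integers by this point, the constructor can forward them to the runtime as a constant shape; that is what rx.ShapeExpr(attn_kind) above does. A small illustrative sketch follows (the four-layer code list is hypothetical):

```python
from tvm import relax as rx

# Hypothetical 4-layer model: layers 0 and 2 use sliding-window MHA (code 3),
# layers 1 and 3 use regular MHA (code 0).
attn_kind_codes = [3, 0, 3, 0]

# A constant shape carrying one AttnKind code per layer, ready to be passed
# to the paged KV cache creation call.
layer_kinds_arg = rx.ShapeExpr(attn_kind_codes)
```
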
@@ -614,9 +621,9 @@ def __init__( # pylint: disable=too-many-locals
else:
# pylint: disable=line-too-long
# fmt: off
-ragged_qk_head_dim = qk_head_dim if attn_kind == "mha" else mla_original_qk_head_dim
-ragged_v_head_dim = v_head_dim if attn_kind == "mha" else mla_original_v_head_dim
-args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if attn_kind == "mha" else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")]))
+ragged_qk_head_dim = qk_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_qk_head_dim
+ragged_v_head_dim = v_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_v_head_dim
+args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if (attn_kind == "mha" or isinstance(attn_kind, List)) else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")]))
mha_functions = (
[
rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, False, rope_scaling, target), "tir_attention_prefill")]),
@@ -626,7 +633,7 @@ def __init__( # pylint: disable=too-many-locals
rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]),
rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]),
]
if attn_kind == "mha"
if (attn_kind == "mha" or isinstance(attn_kind, List))
else [rx.Tuple([]) for _ in range(6)]
)
mla_function = rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_mla(num_attention_heads, v_head_dim, qk_head_dim - v_head_dim, dtype, False, target), "tir_attention_prefill_mla")] if attn_kind == "mla" else [])
@@ -641,7 +648,7 @@ def __init__( # pylint: disable=too-many-locals
[
rx.Tuple(attn_merge_functions),
bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, qk_head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"),
-bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
+bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if (attn_kind == "mha" or isinstance(attn_kind, List)) else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, qk_head_dim, dtype), "kv_cache_debug_get_kv"),
bb.add_func(_compact_kv_copy(num_key_value_heads, qk_head_dim, dtype, target), "kv_cache_compact_kv_copy"),
]
3 changes: 2 additions & 1 deletion src/runtime/relax_vm/attn_utils.h
@@ -62,13 +62,14 @@ enum class AttnKind : int {
kMHA = 0,
kMLA = 1,
kLinearAttn = 2,
+kMHASliding = 3,
};

/*! \brief Given the attention kind and other metadata, return the one-layer KV cache shape. */
inline ShapeTuple GetKVCacheShape(AttnKind attn_kind, int64_t num_total_pages, int num_sequence,
int64_t num_kv_heads, int64_t page_size, int64_t qk_head_dim,
int64_t v_head_dim) {
-if (attn_kind == AttnKind::kMHA) {
+if (attn_kind == AttnKind::kMHA || attn_kind == AttnKind::kMHASliding) {
// Ignore v_head_dim since multi-head attention requires K/V to have the same head dim.
return {num_total_pages, 2, num_kv_heads, page_size, qk_head_dim};
} else if (attn_kind == AttnKind::kMLA) {
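
On the runtime side, the cache layout for a sliding-window layer is identical to a regular MHA layer, so GetKVCacheShape only needs to accept kMHASliding in the MHA branch. A hedged Python rendering of that branch is below; the MLA and linear-attention layouts are omitted because they are outside this excerpt.

```python
from typing import Tuple

# Integer codes mirroring AttnKind in attn_utils.h
MHA, MLA, LINEAR_ATTN, MHA_SLIDING = 0, 1, 2, 3


def kv_cache_shape_mha(attn_kind: int, num_total_pages: int, num_kv_heads: int,
                       page_size: int, qk_head_dim: int) -> Tuple[int, ...]:
    """One-layer KV cache shape for MHA-style layers (sketch of GetKVCacheShape)."""
    if attn_kind in (MHA, MHA_SLIDING):
        # The "2" axis holds the K and V planes; v_head_dim is ignored because
        # multi-head attention requires K and V to share the same head dim.
        return (num_total_pages, 2, num_kv_heads, page_size, qk_head_dim)
    raise NotImplementedError("MLA / linear-attention layouts are not shown in this excerpt")
```
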