diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py index 1d06bf2f3595..314c1b8a21f6 100644 --- a/python/tvm/relax/frontend/nn/llm/kv_cache.py +++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -20,7 +20,7 @@ # pylint: disable=too-many-statements,too-many-lines,too-many-arguments,invalid-name import enum import math -from typing import Any, Dict, List, Literal, Optional, Tuple +from typing import Any, Dict, List, Literal, Optional, Tuple, Union import tvm from tvm import relax as rx @@ -86,6 +86,7 @@ class AttnKind(enum.IntEnum): MHA = 0 MLA = 1 + MHA_SLIDING = 3 class RopeMode(enum.IntEnum): @@ -301,7 +302,7 @@ class FlashInferPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-me def __init__( # pylint: disable=too-many-locals self, - attn_kind: Literal["mha", "mla"], + attn_kind: Union[Literal["mha", "mla"], List[Literal["mha", "mla","mha_sliding"]]], max_batch_size: tir.Var, max_total_seq_len: tir.Var, prefill_chunk_size: tir.Var, @@ -377,8 +378,8 @@ def __init__( # pylint: disable=too-many-locals dtype_q=dtype, dtype_kv=dtype, dtype_o=dtype, - qk_head_dim=qk_head_dim if attn_kind == "mha" else mla_original_qk_head_dim, - v_head_dim=v_head_dim if attn_kind == "mha" else mla_original_v_head_dim, + qk_head_dim=qk_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_qk_head_dim, + v_head_dim=v_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_v_head_dim, target=target, enable_inline_rope=rope_mode == RopeMode.INLINE, ) @@ -391,7 +392,7 @@ def __init__( # pylint: disable=too-many-locals v_head_dim=v_head_dim, target=target, ) - if attn_kind == "mha" + if (attn_kind == "mha" or isinstance(attn_kind, List)) else [] ) flashinfer_mla_mods = ( @@ -420,7 +421,7 @@ def __init__( # pylint: disable=too-many-locals rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, 
rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]), rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]), ] - if attn_kind == "mha" + if (attn_kind == "mha" or isinstance(attn_kind, List)) else [rx.Tuple([]) for _ in range(6)] ) mla_function = rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_mla_paged_attention_run"), rx.ExternFunc("batch_mla_paged_attention_plan")] if attn_kind == "mla" else []) @@ -430,6 +431,11 @@ def __init__( # pylint: disable=too-many-locals if attn_kind == "mla": attn_merge_functions.append(bb.add_func(_merge_state_inplace(num_attention_heads, mla_original_v_head_dim, dtype, target, "tir_attention_merge_state_mla"), "tir_attention_merge_state_mla")) + + if isinstance(attn_kind, List): + attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind] + else: + attn_kind = [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)] args = [ rx.ShapeExpr( [ @@ -482,7 +488,7 @@ class TIRPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-methods def __init__( # pylint: disable=too-many-locals self, - attn_kind: Literal["mha", "mla"], + attn_kind: Union[Literal["mha", "mla"], List[Literal["mha", "mla","mha_sliding"]]], max_batch_size: tir.Var, max_total_seq_len: tir.Var, prefill_chunk_size: tir.Var, @@ -553,7 +559,10 @@ def __init__( # pylint: disable=too-many-locals target : Target The target to build the model to. 
""" - + if isinstance(attn_kind, List): + attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind] + else: + attn_kind = [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)] bb = rx.BlockBuilder.current() args = [ rx.ShapeExpr( @@ -570,9 +579,7 @@ def __init__( # pylint: disable=too-many-locals rx.PrimValue(num_key_value_heads), rx.PrimValue(qk_head_dim), rx.PrimValue(v_head_dim), - rx.ShapeExpr( - [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)] - ), + rx.ShapeExpr(attn_kind), rx.PrimValue(enable_disaggregation), rx.PrimValue(rope_mode), rx.PrimValue(rope_scale), @@ -614,9 +621,9 @@ def __init__( # pylint: disable=too-many-locals else: # pylint: disable=line-too-long # fmt: off - ragged_qk_head_dim = qk_head_dim if attn_kind == "mha" else mla_original_qk_head_dim - ragged_v_head_dim = v_head_dim if attn_kind == "mha" else mla_original_v_head_dim - args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if attn_kind == "mha" else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")])) + ragged_qk_head_dim = qk_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_qk_head_dim + ragged_v_head_dim = v_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_v_head_dim + args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if (attn_kind == "mha" or isinstance(attn_kind, List)) else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")])) mha_functions = ( [ rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, False, rope_scaling, target), "tir_attention_prefill")]), @@ -626,7 +633,7 @@ def __init__( 
# pylint: disable=too-many-locals rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]), rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]), ] - if attn_kind == "mha" + if (attn_kind == "mha" or isinstance(attn_kind, List)) else [rx.Tuple([]) for _ in range(6)] ) mla_function = rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_mla(num_attention_heads, v_head_dim, qk_head_dim - v_head_dim, dtype, False, target), "tir_attention_prefill_mla")] if attn_kind == "mla" else []) @@ -641,7 +648,7 @@ def __init__( # pylint: disable=too-many-locals [ rx.Tuple(attn_merge_functions), bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, qk_head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"), - bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"), + bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if (attn_kind == "mha" or isinstance(attn_kind, List)) else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"), bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, qk_head_dim, dtype), "kv_cache_debug_get_kv"), bb.add_func(_compact_kv_copy(num_key_value_heads, qk_head_dim, dtype, target), "kv_cache_compact_kv_copy"), ] diff --git a/src/runtime/relax_vm/attn_utils.h b/src/runtime/relax_vm/attn_utils.h index 8138aa7bbdf6..f2a63aeb9044 100644 --- a/src/runtime/relax_vm/attn_utils.h +++ b/src/runtime/relax_vm/attn_utils.h @@ -62,13 +62,14 @@ enum class AttnKind : int { 
kMHA = 0, kMLA = 1, kLinearAttn = 2, + kMHASliding = 3, }; /*! \brief Given the attention kind and other metadata, return the one-layer KV cache shape. */ inline ShapeTuple GetKVCacheShape(AttnKind attn_kind, int64_t num_total_pages, int num_sequence, int64_t num_kv_heads, int64_t page_size, int64_t qk_head_dim, int64_t v_head_dim) { - if (attn_kind == AttnKind::kMHA) { + if (attn_kind == AttnKind::kMHA || attn_kind == AttnKind::kMHASliding) { // Ignore v_head_dim since multi-head attention requires K/V to have the same head dim. return {num_total_pages, 2, num_kv_heads, page_size, qk_head_dim}; } else if (attn_kind == AttnKind::kMLA) { diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index 496891f2172b..02bb13473449 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -98,6 +98,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { const int64_t prefill_chunk_size_; /*! \brief A boolean flag indicating if the KV cache supports sliding window. */ const bool support_sliding_window_; + /*! \brief A boolean flag indicating if the KV cache has per layer sliding window. */ + const bool support_layer_sliding_window_; /*! \brief The attention kinds for each layer. 
*/ const std::vector attn_kinds_; @@ -195,10 +197,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { std::vector qo_indptr_on_depths_host_; std::vector page_indptr_on_depths_host_; std::vector page_indices_on_depths_host_; + std::vector page_indptr_sliding_window_on_depths_host_; + std::vector page_indices_sliding_window_on_depths_host_; std::vector last_page_len_on_depths_host_; std::vector sliding_window_offset_on_depths_host_; std::vector sink_size_on_depths_host_; std::vector k_rope_pos_offset_on_depths_host_; + std::vector k_rope_pos_offset_sliding_window_on_depths_host_; HostMemoryVector k_ragged_rope_pos_offset_host_; HostMemoryVector q_rope_position_map_host_; HostMemoryVector append_position_map_host_; @@ -236,8 +241,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { std::vector qo_indptr_on_depths_view_; std::vector page_indptr_on_depths_view_; std::vector page_indices_on_depths_view_; + std::vector page_indptr_sliding_window_on_depths_view_; + std::vector page_indices_sliding_window_on_depths_view_; std::vector length_info_on_depths_view_; + std::vector layer_sliding_window_length_info_on_depths_view_; std::vector k_rope_pos_offset_view_; + std::vector k_rope_pos_offset_sliding_window_view_; std::vector tree_attn_mask_view_; std::vector tree_attn_mn_indptr_view_; @@ -297,7 +306,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { v_head_dim_(v_head_dim), num_total_pages_(num_total_pages), prefill_chunk_size_(prefill_chunk_size), - support_sliding_window_(support_sliding_window), + support_sliding_window_(std::find(attn_kinds.begin(), attn_kinds.end(), AttnKind::kMHASliding) != attn_kinds.end() ? false : support_sliding_window), + support_layer_sliding_window_(std::find(attn_kinds.begin(), attn_kinds.end(), AttnKind::kMHASliding) != attn_kinds.end()), attn_kinds_(std::move(attn_kinds)), rope_mode_(support_sliding_window && rope_mode != RoPEMode::kNone ? 
RoPEMode::kInline : rope_mode), @@ -373,6 +383,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device)); page_indices_on_depths_host_.push_back( HostMemoryVector(num_total_pages, dtype_aux_, preferred_host_device)); + page_indptr_sliding_window_on_depths_host_.push_back( + HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device)); + page_indices_sliding_window_on_depths_host_.push_back( + HostMemoryVector(num_total_pages, dtype_aux_, preferred_host_device)); last_page_len_on_depths_host_.push_back( HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device)); sliding_window_offset_on_depths_host_.push_back( @@ -381,6 +395,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device)); k_rope_pos_offset_on_depths_host_.push_back( HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device)); + k_rope_pos_offset_sliding_window_on_depths_host_.push_back( + HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device)); tree_attn_mask_host_.push_back(HostMemoryVector(kTreeAttnMaxTreeSize * 2 * reserved_num_seqs, dtype_aux_, preferred_host_device)); tree_attn_mn_indptr_host_.push_back( @@ -423,8 +439,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { qo_indptr_on_depths_view_.push_back(NDArray()); page_indptr_on_depths_view_.push_back(NDArray()); page_indices_on_depths_view_.push_back(NDArray()); + page_indptr_sliding_window_on_depths_view_.push_back(NDArray()); + page_indices_sliding_window_on_depths_view_.push_back(NDArray()); length_info_on_depths_view_.push_back(NDArray()); + layer_sliding_window_length_info_on_depths_view_.push_back(NDArray()); k_rope_pos_offset_view_.push_back(NDArray()); + k_rope_pos_offset_sliding_window_view_.push_back(NDArray()); tree_attn_mask_view_.push_back(NDArray()); tree_attn_mn_indptr_view_.push_back(NDArray()); 
is_chain_on_depths_.push_back(true); @@ -711,7 +731,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { void EnableSlidingWindowForSeq(int64_t seq_id, int32_t sliding_window_size, int32_t attn_sink_size) final { - CHECK(support_sliding_window_) << "The KV cache does not support sliding window."; + // If per layer sliding window exists, enable sliding window for sequence + CHECK(support_sliding_window_ || support_layer_sliding_window_) << "The KV cache does not support sliding window."; auto it = seq_map_.find(seq_id); CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache."; CHECK_GE(attn_sink_size, 0) @@ -933,28 +976,37 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { HostMemoryVector& qo_indptr_h = qo_indptr_on_depths_host_[d]; HostMemoryVector& page_indptr_h = page_indptr_on_depths_host_[d]; HostMemoryVector& page_indices_h = page_indices_on_depths_host_[d]; + HostMemoryVector& page_indptr_sliding_window_h = page_indptr_sliding_window_on_depths_host_[d]; + HostMemoryVector& page_indices_sliding_window_h = page_indices_sliding_window_on_depths_host_[d]; HostMemoryVector& last_page_len_h =
last_page_len_on_depths_host_[d]; HostMemoryVector& sliding_window_offset_h = sliding_window_offset_on_depths_host_[d]; HostMemoryVector& sink_size_h = sink_size_on_depths_host_[d]; HostMemoryVector& k_rope_pos_offset_h = k_rope_pos_offset_on_depths_host_[d]; + HostMemoryVector& k_rope_pos_offset_sliding_window_h = k_rope_pos_offset_sliding_window_on_depths_host_[d]; qo_indptr_h.clear(); page_indptr_h.clear(); page_indices_h.clear(); + page_indptr_sliding_window_h.clear(); + page_indices_sliding_window_h.clear(); last_page_len_h.clear(); sliding_window_offset_h.clear(); sink_size_h.clear(); k_rope_pos_offset_h.clear(); + k_rope_pos_offset_sliding_window_h.clear(); qo_indptr_h.push_back(0); page_indptr_h.push_back(0); + page_indptr_sliding_window_h.push_back(0); for (int i = 0; i < static_cast(chunked_block_ids_arr[d].size()); ++i) { const auto& [block_id, chunk_append_length] = chunked_block_ids_arr[d][i]; qo_indptr_h.push_back(qo_indptr_h.back() + chunk_append_length); if (block_id == -1) { page_indptr_h.push_back(page_indptr_h.back()); + page_indptr_sliding_window_h.push_back(page_indptr_sliding_window_h.back()); last_page_len_h.push_back(0); sliding_window_offset_h.push_back(0); sink_size_h.push_back(0); k_rope_pos_offset_h.push_back(0); + k_rope_pos_offset_sliding_window_h.push_back(0); } else { if (d < kPagedKVCacheMaxBlockDepth - 1) { // Blocks not at maximum depth @@ -962,16 +1014,44 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { page_indptr_h.push_back(page_indptr_h.back() + block.page_ids.size()); for (int32_t page_id : block.page_ids) { page_indices_h.push_back(page_id); + // Do the same for page_indices_sliding_window } + + // For sliding window, the first page and last page will both be partially used + page_indptr_sliding_window_h.push_back( + page_indptr_sliding_window_h.back() + std::min(static_cast(block.page_ids.size()), + sequences[d]->sliding_window_size / page_size_ + (block.seq_length % sequences[d]->sliding_window_size ? 
1 : 0) + )); + for (int i = page_indices_h.size() - page_indptr_sliding_window_h.back(); i < static_cast(page_indices_h.size()); i++) { + page_indices_sliding_window_h.push_back(page_indices_h[i]); + } + // set up the page indices properly by choosing the last (sliding_window_size / + // page_size_) pages (at most) last_page_len_h.push_back( block.seq_length == 0 ? 0 : (block.seq_length - block.sink_length + block.sliding_window_offset - 1) % page_size_ + 1); - sliding_window_offset_h.push_back(block.sliding_window_offset); + if (support_layer_sliding_window_) { + if (block.seq_length < sequences[d]->sliding_window_size) { + sliding_window_offset_h.push_back(0); + } else { + sliding_window_offset_h.push_back(block.seq_length % page_size_); + } + } else { + sliding_window_offset_h.push_back(block.sliding_window_offset); + } sink_size_h.push_back(block.sink_length); k_rope_pos_offset_h.push_back(block.start_pos); + + // If sliding window, we need to calculate the positional offset + if (support_layer_sliding_window_) { + k_rope_pos_offset_sliding_window_h.push_back(std::max(0, block.start_pos + block.seq_length - sequences[d]->sliding_window_size)); + } + // if (page_indices_sliding_window_h.size() > 0) { + // LOG(INFO) << "WOffset: "<< sliding_window_offset_h.back() << " SWIdx: "<< page_indptr_sliding_window_h.back() << " LastPgIdx: " << page_indices_sliding_window_h.back() << " LastPgLen: " << last_page_len_h.back() << " RopeOffset: " << k_rope_pos_offset_h.back() << " RopeOffsetSlide: " << k_rope_pos_offset_sliding_window_h.back(); + // } } else { // Blocks at maximum depth const Block& block = global_block_pool_[block_id]; @@ -991,7 +1071,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { total_seq_length += block.seq_length; last_block_id = id; } + // Also add sliding window here? 
page_indptr_h.push_back(page_indptr_h.back() + num_pages); + page_indptr_sliding_window_h.push_back( + page_indptr_sliding_window_h.back() + std::min(static_cast(block.page_ids.size()), + sequences[d]->sliding_window_size / page_size_ + (block.seq_length % sequences[d]->sliding_window_size ? 1 : 0) + )); + for (int i = page_indices_h.size() - page_indptr_sliding_window_h.back(); i < static_cast(page_indices_h.size()); i++) { + page_indices_sliding_window_h.push_back(page_indices_h[i]); + } const Block& last_block = global_block_pool_[last_block_id]; last_page_len_h.push_back(total_seq_length == 0 ? 0 @@ -999,9 +1087,20 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { last_block.sliding_window_offset - 1) % page_size_ + 1); - sliding_window_offset_h.push_back(last_block.sliding_window_offset); + if (support_layer_sliding_window_) { + if (last_block.seq_length < sequences[d]->sliding_window_size) { + sliding_window_offset_h.push_back(0); + } else { + sliding_window_offset_h.push_back(last_block.seq_length % page_size_); + } + } else { + sliding_window_offset_h.push_back(last_block.sliding_window_offset); + } sink_size_h.push_back(last_block.sink_length); k_rope_pos_offset_h.push_back(block.start_pos); + if (support_layer_sliding_window_) { + k_rope_pos_offset_sliding_window_h.push_back(std::max(0, block.start_pos + block.seq_length - sequences[d]->sliding_window_size)); + } } } } @@ -1187,7 +1286,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { NDArray pages = pages_[local_layer_id]; CHECK(qkv_data.DataType() == pages.DataType()); CHECK(o_data.DataType() == pages.DataType()); - CHECK(attn_kinds_[layer_id] == AttnKind::kMHA); + CHECK(attn_kinds_[layer_id] == AttnKind::kMHA || attn_kinds_[layer_id] == AttnKind::kMHASliding); // qkv_data: (num_total_length, num_qo_heads + 2 * num_kv_heads, qk_head_dim) // o_data: (num_total_length, num_qo_heads, qk_head_dim) @@ -1763,7 +1862,7 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj { */ void SlideWindowForSequence(Sequence* seq) { // - No action when the sequence is not enabled for sliding window. - if (seq->sliding_window_size == -1) { + if (seq->sliding_window_size == -1 || !support_sliding_window_) { return; } // - No action when the sequence length does not exceed the window size. @@ -1805,7 +1904,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // - The first sliding page after sliding is either the last sink page, // or the page next to the last sink page. ICHECK(page_idx_after_sliding == num_sink_pages - 1 || - page_idx_after_sliding == num_sink_pages); + page_idx_after_sliding == num_sink_pages); // - Update the length of the sequence and the block. seq->seq_length = seq->sliding_window_size; @@ -1815,9 +1914,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { ICHECK_GE(block.seq_length, block.sink_length); ICHECK_GE(block.sliding_window_offset, block.sink_length); ICHECK_EQ( - (block.sliding_window_offset + (block.seq_length - block.sink_length) + page_size_ - 1) / - page_size_, - block.page_ids.size()); + (block.sliding_window_offset + (block.seq_length - block.sink_length) + page_size_ - 1) / + page_size_, + block.page_ids.size()); } /*! @@ -1849,7 +1948,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { for (int64_t page_idx = cur_npage; page_idx < tgt_npage; ++page_idx) { // When sliding window is enabled for the seq, we can "borrow temporary pages (-1)", // since the pages need to be slidden out might not have been released. 
- if (free_page_ids_.empty() && seq->sliding_window_size != -1) { + if (free_page_ids_.empty() && seq->sliding_window_size != -1 && support_sliding_window_) { block.page_ids.push_back(kPagedKVCacheTempPageId); } else { block.page_ids.push_back(GetFreePage()); @@ -1860,10 +1959,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // ==================== Slide ==================== // Slide the sequences so that the pages exceed the sliding window are released. SlideWindowForSequence(seq); - for (int i = 0; i < static_cast(block.page_ids.size()); ++i) { - if (block.page_ids[i] == kPagedKVCacheTempPageId) { - // Re-allocate the temporary pages after sliding window release. - block.page_ids[i] = GetFreePage(); + if (support_sliding_window_) { + for (int i = 0; i < static_cast(block.page_ids.size()); ++i) { + if (block.page_ids[i] == kPagedKVCacheTempPageId) { + // Re-allocate the temporary pages after sliding window release. + block.page_ids[i] = GetFreePage(); + } } } @@ -1921,7 +2022,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { if (page_indices_on_depths_view_[d]->shape[0] == 0) { continue; } - CHECK(!support_sliding_window_) << "Kernel BeginForward doesn't support sliding window."; + CHECK(!support_sliding_window_ || !support_layer_sliding_window_) << "Kernel BeginForward doesn't support sliding window."; if (use_decode_kernel_[d]) { if (f_attention_decode_ != nullptr && f_attention_decode_->backend_kind == AttnBackendKind::kFlashInfer) { @@ -2039,9 +2140,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { bool MHACrossAttnInternal(int64_t local_layer_id, NDArray q_data, NDArray o_data, NDArray lse_data, double sm_scale, bool is_first_kernel) { std::unique_ptr& f_prefill = - !support_sliding_window_ ? f_attention_prefill_ : f_attention_prefill_sliding_window_; + (!support_sliding_window_ && attn_kinds_[local_layer_id + layer_id_begin_offset_] != AttnKind::kMHASliding) ? 
f_attention_prefill_ : f_attention_prefill_sliding_window_; std::unique_ptr& f_decode = - !support_sliding_window_ ? f_attention_decode_ : f_attention_decode_sliding_window_; + (!support_sliding_window_ && attn_kinds_[local_layer_id + layer_id_begin_offset_] != AttnKind::kMHASliding) ? f_attention_decode_ : f_attention_decode_sliding_window_; CHECK_GE(num_depths_, 1) << "The number of effective depths must be greater or equal to 1."; bool cross_attn_computed = false; @@ -2058,29 +2159,46 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { attn_output = temp_attn_output_view_; attn_lse = temp_attn_lse_view_; } + // If layer is sliding window, use sliding window index pointer/indices + NDArray page_indptr; + NDArray page_indices; + NDArray length_info; + NDArray k_rope_pos; + if (attn_kinds_[local_layer_id + layer_id_begin_offset_] == AttnKind::kMHASliding) { + page_indptr = page_indptr_sliding_window_on_depths_view_[d]; + page_indices = page_indices_sliding_window_on_depths_view_[d]; + length_info = layer_sliding_window_length_info_on_depths_view_[d]; + k_rope_pos = k_rope_pos_offset_sliding_window_view_[d]; + } else { + page_indptr = page_indptr_on_depths_view_[d]; + page_indices = page_indices_on_depths_view_[d]; + length_info = length_info_on_depths_view_[d]; + k_rope_pos = k_rope_pos_offset_view_[d]; + } + if (append_before_attn_ && !is_chain_on_depths_[d]) { ICHECK_NOTNULL(f_attention_prefill_with_tree_mask_paged_kv_); f_attention_prefill_with_tree_mask_paged_kv_->MHA( q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], - page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d], - length_info_on_depths_view_[d], k_rope_pos_offset_view_[d], q_rope_position_map_view_, + page_indptr, page_indices, + length_info, k_rope_pos, q_rope_position_map_view_, tree_attn_mn_indptr_view_[d], tree_attn_mask_view_[d], rope_mode_, rotary_scale_, rotary_theta_, sm_scale, attn_output, attn_lse, compute_stream_); } else if (use_decode_kernel_[d]) { // 
Use decode kernel for depth d ICHECK_NOTNULL(f_decode); - f_decode->MHA(d, q_data, pages_[local_layer_id], page_indptr_on_depths_view_[d], - page_indices_on_depths_view_[d], length_info_on_depths_view_[d], - k_rope_pos_offset_view_[d], q_rope_position_map_view_, rope_mode_, + f_decode->MHA(d, q_data, pages_[local_layer_id], page_indptr, + page_indices, length_info, + k_rope_pos, q_rope_position_map_view_, rope_mode_, rotary_scale_, rotary_theta_, sm_scale, attn_output, attn_lse, compute_stream_); } else { // Use prefill kernel for depth d ICHECK_NOTNULL(f_prefill); f_prefill->MHA(d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], - page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d], - length_info_on_depths_view_[d], q_rope_position_map_view_, - k_rope_pos_offset_view_[d], /*causal=*/false, + page_indptr, page_indices, + length_info, q_rope_position_map_view_, + k_rope_pos, /*causal=*/false, /*rotary_mode=*/rope_mode_, rotary_scale_, rotary_theta_, sm_scale, attn_output, attn_lse, compute_stream_); } @@ -2193,7 +2311,23 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { page_indices_on_depths_view_[d] = aux_data_manager_->CopyPageIndicesOnDepthAsync(&page_indices_on_depths_host_[d], d); } - // 5. length_info_on_depths + + // If per layer sliding window exists, must copy additional vectors + if (support_layer_sliding_window_) { + // 5. page_indptr_sliding_window_on_depths + for (int d = 0; d < num_depths_; ++d) { + ICHECK_EQ(page_indptr_sliding_window_on_depths_host_[d].size(), qo_indptr_on_depths_host_[d].size()); + page_indptr_sliding_window_on_depths_view_[d] = + aux_data_manager_->CopyPageIndptrOnDepthAsync(&page_indptr_sliding_window_on_depths_host_[d], d); + } + // 6. 
page_indices_sliding_window_on_depths + for (int d = 0; d < num_depths_; ++d) { + ICHECK_EQ(page_indices_sliding_window_on_depths_host_[d].size(), page_indptr_sliding_window_on_depths_host_[d].back()); + page_indices_sliding_window_on_depths_view_[d] = + aux_data_manager_->CopyPageIndicesOnDepthAsync(&page_indices_sliding_window_on_depths_host_[d], d); + } + } + // 7. length_info_on_depths // last_page_len_on_depths_host_; // sliding_window_offset_on_depths_host_; // sink_size_on_depths_host_; @@ -2212,6 +2346,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { &last_page_len_on_depths_host_[d], &sliding_window_offset_on_depths_host_[d], &sink_size_on_depths_host_[d], d); } + + if (support_layer_sliding_window_) { + layer_sliding_window_length_info_on_depths_view_[d] = aux_data_manager_->CopyLengthInfoOnDepthAsync( + &last_page_len_on_depths_host_[d], &sliding_window_offset_on_depths_host_[d], + &sink_size_on_depths_host_[d], d); + } + } // 6. k_rope_pos_offset_on_depths for (int d = 0; d < num_depths_; ++d) { @@ -2219,6 +2360,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { qo_indptr_on_depths_host_[d].size()); k_rope_pos_offset_view_[d] = aux_data_manager_->CopyKRoPEPosOffsetOnDepthAsync( &k_rope_pos_offset_on_depths_host_[d], d); + if (support_layer_sliding_window_) { + ICHECK_EQ(k_rope_pos_offset_sliding_window_on_depths_host_[d].size() + 1, + qo_indptr_on_depths_host_[d].size()); + k_rope_pos_offset_sliding_window_view_[d] = aux_data_manager_->CopyKRoPEPosOffsetOnDepthAsync( + &k_rope_pos_offset_sliding_window_on_depths_host_[d], d); + } } // 7. cur_append_lengths_indptr cur_append_length_indptr_view_ =