
[Executorch][llm] Add support for ring kv cache and ring attention #10608


Open · wants to merge 7 commits into base: gh/kimishpatel/185/base
107 changes: 107 additions & 0 deletions examples/models/llama/attention.py
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, Optional, Tuple, Type, TypedDict

import torch
@@ -160,6 +161,112 @@ def forward(
return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)


class CacheUpdateStrategy(Enum):
RING_BUFFER = "RingBuffer"
INVALID = "Invalid"


class CachePositionsManager(nn.Module):
def __init__(
self,
max_context_length: int,
cache_update_strategy: CacheUpdateStrategy = CacheUpdateStrategy.RING_BUFFER,
):
super().__init__()
assert (
cache_update_strategy == CacheUpdateStrategy.RING_BUFFER
), "Only RingBuffer is supported"
self.max_context_length = max_context_length
self.register_buffer(
"cache_positions",
torch.zeros((self.max_context_length), dtype=torch.long, device="cpu"),
)

def calculate_positions_and_update_indices(self, input_pos: torch.Tensor, seq_len):
"""
Calculate indices, into k_cache, v_cache, where to put k_val tensor.
Given the input_pos and length of k_val at sequence dim, the input pos may
have to wrap around if it is smaller than the cache capacity.
If it is larger than the cache capacity then just pick the last
self.max_context_length entries.

Additionally:
Update the cache positions buffer with the new indices.
Given the cache positions in sequence dim, indicated by indices,
we can just update cache_positions buffer using orig_indices.
For example
Given cache capacity of 4 and update of length 3 with start_pos = 2
will have following values
indices = [2, 3, 0]
orig_indices = [2, 3, 4]
So cache_positions after the update will be [4, 1, 2, 3]
Note cache_positions[1] = 1 that is from previous write to the cache.
The corner case here is cache positions before cache rolls over.
For example when start_pos = 0 and update is of length 2, then we have
filled positions 0 and 1 in the buffer, while the rest are invalid. In this case
we have
indices = [0, 1]
orig_indices = [0, 1]
But if we have cache_positins = [0, 1, 0, 0] that is not valid. Hence we have
to make sure that invalid positions have a sentinel value of - 1.
"""
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)
orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos
indices = orig_indices % self.max_context_length

full_t = torch.full((self.max_context_length,), -1, dtype=torch.long)
arange_tensor = torch.arange(self.max_context_length, dtype=torch.long)
cache_positions = torch.where(
arange_tensor < start_pos, self.cache_positions, full_t
)
self.cache_positions.copy_(cache_positions)
self.cache_positions.index_copy_(0, indices, orig_indices)

return indices


class RingKVCache(KVCache):
def __init__(
self,
max_batch_size: int,
max_context_length: int,
n_heads: int,
head_dim: int,
enable_dynamic_shape: bool,
dtype=torch.float32,
):
super().__init__(
max_batch_size,
max_context_length,
n_heads,
head_dim,
enable_dynamic_shape,
dtype,
)
self.cache_positions_manager = CachePositionsManager(max_context_length)

def update(
self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# input_pos: [S], k_val: [B, H, S, D]
seq_len = k_val.size(2)
indices = self.cache_positions_manager.calculate_positions_and_update_indices(
input_pos, seq_len
)
if self.enable_dynamic_shape:
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)

self.k_cache.index_copy_(2, indices, k_val)
self.v_cache.index_copy_(2, indices, v_val)
else:
self.k_cache[:, :, indices] = k_val
self.v_cache[:, :, indices] = v_val

return self.k_cache, self.v_cache


@register_attention("mha")
class AttentionMHA(Attention):
def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
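To make the ring-buffer index math in CachePositionsManager concrete, here is a small standalone sketch (plain PyTorch, not part of the PR) that reproduces the worked example from the docstring: a capacity of 4, a previous write of length 2 at position 0, then an update of length 3 starting at position 2. The variable names are chosen for illustration only.

import torch

max_context_length = 4  # cache capacity from the docstring example

# cache_positions after a previous write of length 2 at start_pos = 0:
# positions 0 and 1 are valid, the rest hold the sentinel -1.
cache_positions = torch.tensor([0, 1, -1, -1], dtype=torch.long)

# New update: seq_len = 3 tokens starting at start_pos = 2.
start_pos, seq_len = 2, 3
orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos  # [2, 3, 4]
indices = orig_indices % max_context_length                         # [2, 3, 0]

# Entries at or beyond start_pos are reset to the sentinel, then the new
# original positions are scattered into their ring-buffer slots.
arange_tensor = torch.arange(max_context_length, dtype=torch.long)
cache_positions = torch.where(
    arange_tensor < start_pos,
    cache_positions,
    torch.full_like(cache_positions, -1),
)
cache_positions.index_copy_(0, indices, orig_indices)

print(indices.tolist())          # [2, 3, 0]
print(cache_positions.tolist())  # [4, 1, 2, 3]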
11 changes: 11 additions & 0 deletions examples/models/llama/tests/TARGETS
@@ -38,3 +38,14 @@ python_unittest(
"//executorch/examples/models/llama:static_attention",
],
)

python_unittest(
name = "test_ring_kv_cache",
srcs = [
"test_ring_kv_cache.py",
],
deps = [
"//caffe2:torch",
"//executorch/examples/models/llama:llama_transformer",
],
)
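The contents of test_ring_kv_cache.py are not shown in this diff. As a rough illustration only, the sketch below shows how RingKVCache could be exercised, assuming the constructor shown in the attention.py hunk, an import path of executorch.examples.models.llama.attention, and a KVCache base class that allocates k_cache/v_cache buffers of shape [B, H, max_context_length, D] as the in-line comment in update() suggests; all sizes are hypothetical.

import torch
from executorch.examples.models.llama.attention import RingKVCache  # assumed import path

batch, heads, head_dim, capacity = 1, 2, 8, 4

cache = RingKVCache(
    max_batch_size=batch,
    max_context_length=capacity,
    n_heads=heads,
    head_dim=head_dim,
    enable_dynamic_shape=False,
)

# Prefill 2 tokens at positions [0, 1], then decode 3 more starting at
# position 2, which wraps around the 4-entry ring buffer.
k = torch.randn(batch, heads, 2, head_dim)
v = torch.randn(batch, heads, 2, head_dim)
cache.update(torch.tensor([0], dtype=torch.long), k, v)

k2 = torch.randn(batch, heads, 3, head_dim)
v2 = torch.randn(batch, heads, 3, head_dim)
k_out, v_out = cache.update(torch.tensor([2], dtype=torch.long), k2, v2)

# cache_positions should now read [4, 1, 2, 3], matching the docstring example.
print(cache.cache_positions_manager.cache_positions.tolist())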