diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py
index e222c052788..c886a062c39 100644
--- a/examples/models/llama/attention.py
+++ b/examples/models/llama/attention.py
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from enum import Enum
 from typing import Any, Dict, Optional, Tuple, Type, TypedDict
 
 import torch
@@ -160,6 +161,112 @@ def forward(
         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
 
 
+class CacheUpdateStrategy(Enum):
+    RING_BUFFER = "RingBuffer"
+    INVALID = "Invalid"
+
+
+class CachePositionsManager(nn.Module):
+    def __init__(
+        self,
+        max_context_length: int,
+        cache_update_strategy: CacheUpdateStrategy = CacheUpdateStrategy.RING_BUFFER,
+    ):
+        super().__init__()
+        assert (
+            cache_update_strategy == CacheUpdateStrategy.RING_BUFFER
+        ), "Only RingBuffer is supported"
+        self.max_context_length = max_context_length
+        self.register_buffer(
+            "cache_positions",
+            torch.zeros((self.max_context_length), dtype=torch.long, device="cpu"),
+        )
+
+    def calculate_positions_and_update_indices(self, input_pos: torch.Tensor, seq_len):
+        """
+        Calculate the indices into k_cache/v_cache where the k_val tensor should be
+        written. Given input_pos and the length of k_val in the sequence dim, the
+        positions may have to wrap around if the update is smaller than the cache
+        capacity. If it is larger than the cache capacity, only the last
+        self.max_context_length entries are kept.
+        Additionally, update the cache_positions buffer with the new indices:
+        given the write locations in the sequence dim, indicated by indices, we can
+        update the cache_positions buffer using orig_indices.
+        For example, given a cache capacity of 4 and an update of length 3 with
+        start_pos = 2, we have
+            indices = [2, 3, 0]
+            orig_indices = [2, 3, 4]
+        so cache_positions after the update will be [4, 1, 2, 3].
+        Note that cache_positions[1] = 1 comes from a previous write to the cache.
+        The corner case is the state of cache_positions before the cache rolls over.
+        For example, when start_pos = 0 and the update has length 2, we have filled
+        positions 0 and 1 in the buffer, while the rest are invalid. In this case
+            indices = [0, 1]
+            orig_indices = [0, 1]
+        but cache_positions = [0, 1, 0, 0] is not valid, so we have to make sure
+        that invalid positions carry a sentinel value of -1.
+ """ + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos + indices = orig_indices % self.max_context_length + + full_t = torch.full((self.max_context_length,), -1, dtype=torch.long) + arange_tensor = torch.arange(self.max_context_length, dtype=torch.long) + cache_positions = torch.where( + arange_tensor < start_pos, self.cache_positions, full_t + ) + self.cache_positions.copy_(cache_positions) + self.cache_positions.index_copy_(0, indices, orig_indices) + + return indices + + +class RingKVCache(KVCache): + def __init__( + self, + max_batch_size: int, + max_context_length: int, + n_heads: int, + head_dim: int, + enable_dynamic_shape: bool, + dtype=torch.float32, + ): + super().__init__( + max_batch_size, + max_context_length, + n_heads, + head_dim, + enable_dynamic_shape, + dtype, + ) + self.cache_positions_manager = CachePositionsManager(max_context_length) + + def update( + self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + # input_pos: [S], k_val: [B, H, S, D] + seq_len = k_val.size(2) + indices = self.cache_positions_manager.calculate_positions_and_update_indices( + input_pos, seq_len + ) + if self.enable_dynamic_shape: + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + + self.k_cache.index_copy_(2, indices, k_val) + self.v_cache.index_copy_(2, indices, v_val) + else: + self.k_cache[:, :, indices] = k_val + self.v_cache[:, :, indices] = v_val + + return self.k_cache, self.v_cache + + @register_attention("mha") class AttentionMHA(Attention): def __init__(self, args: ModelArgs, layer_id: int, rope: Rope): diff --git a/examples/models/llama/tests/TARGETS b/examples/models/llama/tests/TARGETS index 0efaa9635c4..09ca02868ed 100644 --- a/examples/models/llama/tests/TARGETS +++ b/examples/models/llama/tests/TARGETS @@ -38,3 +38,14 @@ python_unittest( "//executorch/examples/models/llama:static_attention", ], ) + +python_unittest( + name = "test_ring_kv_cache", + srcs = [ + "test_ring_kv_cache.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/examples/models/llama:llama_transformer", + ], +) diff --git a/examples/models/llama/tests/test_ring_kv_cache.py b/examples/models/llama/tests/test_ring_kv_cache.py new file mode 100644 index 00000000000..dd9971fa010 --- /dev/null +++ b/examples/models/llama/tests/test_ring_kv_cache.py @@ -0,0 +1,477 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +import torch +from executorch.examples.models.llama.attention import RingKVCache + + +class TestRingKVCache(unittest.TestCase): + def setUp(self): + # Common test parameters + self.max_batch_size = 2 + self.max_context_length = 8 + self.n_heads = 4 + self.head_dim = 16 + self.enable_dynamic_shape = True + self.dtype = torch.float32 + + def test_basic_update(self): + """Test basic update functionality of RingKVCache.""" + cache = RingKVCache( + self.max_batch_size, + self.max_context_length, + self.n_heads, + self.head_dim, + self.enable_dynamic_shape, + self.dtype, + ) + + # Create input tensors + input_pos = torch.tensor([0], dtype=torch.long) + seq_len = 3 + k_val = torch.ones( + (self.max_batch_size, self.n_heads, seq_len, self.head_dim), + dtype=self.dtype, + ) + v_val = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len, self.head_dim), + dtype=self.dtype, + ) + * 2 + ) + + # Update the cache + k_out, v_out = cache.update(input_pos, k_val, v_val) + + # Check that the cache was updated correctly + for i in range(seq_len): + self.assertTrue(torch.all(k_out[:, :, i] == 1.0)) + self.assertTrue(torch.all(v_out[:, :, i] == 2.0)) + + # Check that the rest of the cache is still zeros + for i in range(seq_len, self.max_context_length): + self.assertTrue(torch.all(k_out[:, :, i] == 0.0)) + self.assertTrue(torch.all(v_out[:, :, i] == 0.0)) + + # Check that cache_positions was updated correctly + expected_positions = torch.tensor( + [0, 1, 2, -1, -1, -1, -1, -1], dtype=torch.long + ) + self.assertTrue( + torch.all( + cache.cache_positions_manager.cache_positions == expected_positions + ) + ) + + def test_ring_buffer_wrapping(self): + """Test that the ring buffer wraps around correctly.""" + cache = RingKVCache( + self.max_batch_size, + self.max_context_length, + self.n_heads, + self.head_dim, + self.enable_dynamic_shape, + self.dtype, + ) + + # Create input tensors for first update + input_pos = torch.tensor([6], dtype=torch.long) + seq_len = 4 # This will wrap around from position 6 to positions 6, 7, 0, 1 + k_val = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len, self.head_dim), + dtype=self.dtype, + ) + * 3 + ) + v_val = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len, self.head_dim), + dtype=self.dtype, + ) + * 4 + ) + + # Update the cache + k_out, v_out = cache.update(input_pos, k_val, v_val) + + # Check that the cache was updated correctly with wrapping + # Positions 6, 7 should be updated + for i in range(6, 8): + self.assertTrue(torch.all(k_out[:, :, i] == 3.0)) + self.assertTrue(torch.all(v_out[:, :, i] == 4.0)) + + # Positions 0, 1 should also be updated due to wrapping + for i in range(0, 2): + self.assertTrue(torch.all(k_out[:, :, i] == 3.0)) + self.assertTrue(torch.all(v_out[:, :, i] == 4.0)) + + # The rest should still be zeros + for i in range(2, 6): + self.assertTrue(torch.all(k_out[:, :, i] == 0.0)) + self.assertTrue(torch.all(v_out[:, :, i] == 0.0)) + + # Check that cache_positions was updated correctly + # Note that positions 2, 3, 4, 5 are 0 instead of -1 because in actual ring + # updates those positions would have been updated. + # But CachePositionsManager thinks they are updated because start_pos > (2, 3, 4, 5) + # As a result it does not fill them with -1 and instead uses original values + # which is 0, the value cache_position buffer is initialized with. 
+ expected_positions = torch.tensor([8, 9, 0, 0, 0, 0, 6, 7], dtype=torch.long) + self.assertTrue( + torch.all( + cache.cache_positions_manager.cache_positions == expected_positions + ) + ) + + def test_multiple_updates(self): + """Test multiple updates to the cache.""" + cache = RingKVCache( + self.max_batch_size, + self.max_context_length, + self.n_heads, + self.head_dim, + self.enable_dynamic_shape, + self.dtype, + ) + + # First update + input_pos1 = torch.tensor([0], dtype=torch.long) + seq_len1 = 2 + k_val1 = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len1, self.head_dim), + dtype=self.dtype, + ) + * 5 + ) + v_val1 = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len1, self.head_dim), + dtype=self.dtype, + ) + * 6 + ) + + _, _ = cache.update(input_pos1, k_val1, v_val1) + + # Second update + input_pos2 = torch.tensor([2], dtype=torch.long) + seq_len2 = 3 + k_val2 = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len2, self.head_dim), + dtype=self.dtype, + ) + * 7 + ) + v_val2 = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len2, self.head_dim), + dtype=self.dtype, + ) + * 8 + ) + + k_out2, v_out2 = cache.update(input_pos2, k_val2, v_val2) + + # Check that the cache was updated correctly after both updates + # First update (positions 0, 1) + for i in range(0, 2): + self.assertTrue(torch.all(k_out2[:, :, i] == 5.0)) + self.assertTrue(torch.all(v_out2[:, :, i] == 6.0)) + + # Second update (positions 2, 3, 4) + for i in range(2, 5): + self.assertTrue(torch.all(k_out2[:, :, i] == 7.0)) + self.assertTrue(torch.all(v_out2[:, :, i] == 8.0)) + + # The rest should still be zeros + for i in range(5, 8): + self.assertTrue(torch.all(k_out2[:, :, i] == 0.0)) + self.assertTrue(torch.all(v_out2[:, :, i] == 0.0)) + + # Check that cache_positions was updated correctly + expected_positions = torch.tensor([0, 1, 2, 3, 4, -1, -1, -1], dtype=torch.long) + self.assertTrue( + torch.all( + cache.cache_positions_manager.cache_positions == expected_positions + ) + ) + + # Third update with wrapping + input_pos3 = torch.tensor([6], dtype=torch.long) + seq_len3 = 4 + k_val3 = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len3, self.head_dim), + dtype=self.dtype, + ) + * 9 + ) + v_val3 = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len3, self.head_dim), + dtype=self.dtype, + ) + * 10 + ) + + k_out3, v_out3 = cache.update(input_pos3, k_val3, v_val3) + + # Check final state after third update with wrapping + # Positions 0, 1 should now have values from the third update (due to wrapping) + for i in range(0, 2): + self.assertTrue(torch.all(k_out3[:, :, i] == 9.0)) + self.assertTrue(torch.all(v_out3[:, :, i] == 10.0)) + + # Positions 2, 3, 4 should still have values from the second update + for i in range(2, 5): + self.assertTrue(torch.all(k_out3[:, :, i] == 7.0)) + self.assertTrue(torch.all(v_out3[:, :, i] == 8.0)) + + # Position 5 should still be zero + self.assertTrue(torch.all(k_out3[:, :, 5] == 0.0)) + self.assertTrue(torch.all(v_out3[:, :, 5] == 0.0)) + + # Positions 6, 7 should have values from the third update + for i in range(6, 8): + self.assertTrue(torch.all(k_out3[:, :, i] == 9.0)) + self.assertTrue(torch.all(v_out3[:, :, i] == 10.0)) + + # Check that cache_positions was updated correctly + expected_positions = torch.tensor([8, 9, 2, 3, 4, -1, 6, 7], dtype=torch.long) + self.assertTrue( + torch.all( + cache.cache_positions_manager.cache_positions == expected_positions + ) + ) + + def test_edge_case_input_pos_zero(self): + """Test the 
edge case where input_pos is 0.""" + cache = RingKVCache( + self.max_batch_size, + self.max_context_length, + self.n_heads, + self.head_dim, + self.enable_dynamic_shape, + self.dtype, + ) + + # Create input tensors + input_pos = torch.tensor([0], dtype=torch.long) + seq_len = 1 + k_val = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len, self.head_dim), + dtype=self.dtype, + ) + * 11 + ) + v_val = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len, self.head_dim), + dtype=self.dtype, + ) + * 12 + ) + + # Update the cache + k_out, v_out = cache.update(input_pos, k_val, v_val) + + # Check that position 0 was updated + self.assertTrue(torch.all(k_out[:, :, 0] == 11.0)) + self.assertTrue(torch.all(v_out[:, :, 0] == 12.0)) + + # Check that the rest of the cache is still zeros + for i in range(1, self.max_context_length): + self.assertTrue(torch.all(k_out[:, :, i] == 0.0)) + self.assertTrue(torch.all(v_out[:, :, i] == 0.0)) + + # Check that cache_positions was updated correctly + expected_positions = torch.tensor( + [0, -1, -1, -1, -1, -1, -1, -1], dtype=torch.long + ) + self.assertTrue( + torch.all( + cache.cache_positions_manager.cache_positions == expected_positions + ) + ) + + def test_edge_case_exceeding_context_length(self): + """Test the edge case where input_pos + seq_len > max_context_length.""" + cache = RingKVCache( + self.max_batch_size, + self.max_context_length, + self.n_heads, + self.head_dim, + self.enable_dynamic_shape, + self.dtype, + ) + + # Create input tensors + input_pos = torch.tensor([5], dtype=torch.long) + seq_len = 5 # This will wrap around from position 5 to positions 5, 6, 7, 0, 1 + k_val = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len, self.head_dim), + dtype=self.dtype, + ) + * 13 + ) + v_val = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len, self.head_dim), + dtype=self.dtype, + ) + * 14 + ) + + # Update the cache + k_out, v_out = cache.update(input_pos, k_val, v_val) + + # Check that positions 5, 6, 7 were updated + for i in range(5, 8): + self.assertTrue(torch.all(k_out[:, :, i] == 13.0)) + self.assertTrue(torch.all(v_out[:, :, i] == 14.0)) + + # Check that positions 0, 1 were also updated due to wrapping + for i in range(0, 2): + self.assertTrue(torch.all(k_out[:, :, i] == 13.0)) + self.assertTrue(torch.all(v_out[:, :, i] == 14.0)) + + # Check that positions 2, 3, 4 are still zeros + for i in range(2, 5): + self.assertTrue(torch.all(k_out[:, :, i] == 0.0)) + self.assertTrue(torch.all(v_out[:, :, i] == 0.0)) + + # Check that cache_positions was updated correctly + # Note that positions 2, 3, 4 are 0 instead of -1 because in actual ring + # updates those positions would have been updated. + # But CachePositionsManager thinks they are updated because start_pos > (2, 3, 4) + # As a result it does not fill them with -1 and instead uses original values + # which is 0, the value cache_position buffer is initialized with. 
+ expected_positions = torch.tensor([8, 9, 0, 0, 0, 5, 6, 7], dtype=torch.long) + self.assertTrue( + torch.all( + cache.cache_positions_manager.cache_positions == expected_positions + ) + ) + + def test_original_indices_tracking(self): + """Test that the original indices are tracked correctly in cache_positions.""" + cache = RingKVCache( + self.max_batch_size, + self.max_context_length, + self.n_heads, + self.head_dim, + self.enable_dynamic_shape, + self.dtype, + ) + + # First update at position 10 (will be mapped to position 2 in the ring buffer) + input_pos = torch.tensor([10], dtype=torch.long) + seq_len = 4 + k_val = torch.ones( + (self.max_batch_size, self.n_heads, seq_len, self.head_dim), + dtype=self.dtype, + ) + v_val = torch.ones( + (self.max_batch_size, self.n_heads, seq_len, self.head_dim), + dtype=self.dtype, + ) + + # Update the cache + cache.update(input_pos, k_val, v_val) + + # Check that cache_positions correctly tracks the original indices + # For input_pos=10 and seq_len=4, the original indices should be 10, 11, 12, 13 + # These map to positions 2, 3, 4, 5 in the ring buffer (since max_context_length=8) + # Note that positions 0, 1, 6 and 7 are 0 instead of -1 because in actual ring + # updates those positions would have been updated for start_pos = 0. + # So CachePositionsManager thinks they are updated because start_pos > (0, 1, 6, 7) + # As a result it does not fill them with -1 and instead uses original values + # which is 0, the value cache_position buffer is initialized with. + expected_positions = torch.tensor( + [0, 0, 10, 11, 12, 13, 0, 0], dtype=torch.long + ) + self.assertTrue( + torch.all( + cache.cache_positions_manager.cache_positions == expected_positions + ) + ) + + # Second update at position 14 (will be mapped to position 6 in the ring buffer) + input_pos = torch.tensor([14], dtype=torch.long) + seq_len = 3 + k_val = torch.ones( + (self.max_batch_size, self.n_heads, seq_len, self.head_dim), + dtype=self.dtype, + ) + v_val = torch.ones( + (self.max_batch_size, self.n_heads, seq_len, self.head_dim), + dtype=self.dtype, + ) + + # Update the cache + cache.update(input_pos, k_val, v_val) + + # Check that cache_positions correctly tracks the original indices + # For input_pos=14 and seq_len=3, the original indices should be 14, 15, 16 + # These map to positions 6, 7, 0 in the ring buffer + expected_positions = torch.tensor( + [16, 0, 10, 11, 12, 13, 14, 15], dtype=torch.long + ) + self.assertTrue( + torch.all( + cache.cache_positions_manager.cache_positions == expected_positions + ) + ) + + def test_non_dynamic_shape(self): + """Test RingKVCache with enable_dynamic_shape=False.""" + cache = RingKVCache( + self.max_batch_size, + self.max_context_length, + self.n_heads, + self.head_dim, + enable_dynamic_shape=False, + dtype=self.dtype, + ) + + # Create input tensors + input_pos = torch.tensor([0], dtype=torch.long) + seq_len = 3 + k_val = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len, self.head_dim), + dtype=self.dtype, + ) + * 15 + ) + v_val = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len, self.head_dim), + dtype=self.dtype, + ) + * 16 + ) + + # Update the cache + k_out, v_out = cache.update(input_pos, k_val, v_val) + + # Check that the cache was updated correctly + for i in range(seq_len): + self.assertTrue(torch.all(k_out[:, :, i] == 15.0)) + self.assertTrue(torch.all(v_out[:, :, i] == 16.0)) + + # Check that the rest of the cache is still zeros + for i in range(seq_len, self.max_context_length): + 
self.assertTrue(torch.all(k_out[:, :, i] == 0.0)) + self.assertTrue(torch.all(v_out[:, :, i] == 0.0))
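Note (not part of the patch): the index arithmetic in `CachePositionsManager.calculate_positions_and_update_indices` can be sanity-checked in isolation. The sketch below reproduces the docstring's worked example (cache capacity 4, a previous write of length 2, then an update of length 3 at start_pos = 2) with plain PyTorch ops; the function name `positions_and_indices` is illustrative only, while the tensor logic mirrors the patch.

```python
import torch


def positions_and_indices(cache_positions, start_pos, seq_len, max_context_length):
    # Mirrors the logic of CachePositionsManager.calculate_positions_and_update_indices.
    orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos
    indices = orig_indices % max_context_length

    # Slots at or beyond start_pos have not been written yet; mark them with -1.
    arange_tensor = torch.arange(max_context_length, dtype=torch.long)
    full_t = torch.full((max_context_length,), -1, dtype=torch.long)
    cache_positions = torch.where(arange_tensor < start_pos, cache_positions, full_t)
    cache_positions.index_copy_(0, indices, orig_indices)
    return indices, orig_indices, cache_positions


# Docstring example: capacity 4, slots 0 and 1 already written, update of length 3
# starting at position 2.
cache_positions = torch.tensor([0, 1, 0, 0], dtype=torch.long)
indices, orig_indices, cache_positions = positions_and_indices(cache_positions, 2, 3, 4)
print(indices.tolist())          # [2, 3, 0]
print(orig_indices.tolist())     # [2, 3, 4]
print(cache_positions.tolist())  # [4, 1, 2, 3]
```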
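For context, a minimal usage sketch of `RingKVCache` follows, using the import path from this diff. The shapes (batch 1, 2 heads, head_dim 4, context length 4) are small illustrative assumptions, not values from the patch; the point is to show the wrap-around behaviour the tests above assert on.

```python
import torch
from executorch.examples.models.llama.attention import RingKVCache

# Small illustrative sizes; real models use much larger values.
cache = RingKVCache(
    max_batch_size=1,
    max_context_length=4,
    n_heads=2,
    head_dim=4,
    enable_dynamic_shape=True,
)

# Prefill three tokens at positions 0..2.
k = torch.ones(1, 2, 3, 4)
v = torch.ones(1, 2, 3, 4) * 2
cache.update(torch.tensor([0], dtype=torch.long), k, v)

# Decode two more tokens starting at position 3; position 4 wraps to slot 0.
k = torch.ones(1, 2, 2, 4) * 3
v = torch.ones(1, 2, 2, 4) * 4
k_out, v_out = cache.update(torch.tensor([3], dtype=torch.long), k, v)

# Slot 0 now holds the token originally at position 4.
print(cache.cache_positions_manager.cache_positions.tolist())  # [4, 1, 2, 3]
```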
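The comments in `test_ring_buffer_wrapping`, `test_edge_case_exceeding_context_length`, and `test_original_indices_tracking` all describe the same quirk: slots the manager believes were already written (index < start_pos) keep whatever value the buffer held, which on a fresh cache is the 0 it was initialized with rather than the -1 sentinel. A minimal reproduction, under the same illustrative assumptions as the sketch above:

```python
import torch
from executorch.examples.models.llama.attention import RingKVCache

cache = RingKVCache(
    max_batch_size=1,
    max_context_length=8,
    n_heads=1,
    head_dim=4,
    enable_dynamic_shape=True,
)

# The very first write starts at position 6 on a fresh cache, so slots 2..5 were
# never written. They satisfy index < start_pos and therefore keep their initial
# 0 instead of being marked with the -1 sentinel.
k = torch.ones(1, 1, 4, 4)
v = torch.ones(1, 1, 4, 4)
cache.update(torch.tensor([6], dtype=torch.long), k, v)
print(cache.cache_positions_manager.cache_positions.tolist())
# [8, 9, 0, 0, 0, 0, 6, 7]
```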