
Commit 48877ff

[Executorch][llm] Enable leveraging ring kv cache via module swap
Pull Request resolved: #10611

This allows us to make some of the attention modules use a sliding-window kv cache, which will help enable models like gemma3.

Differential Revision: [D73891426](https://our.internmc.facebook.com/intern/diff/D73891426/)
ghstack-source-id: 282013417
1 parent 7fa41a2 commit 48877ff

File tree

5 files changed (+522, -31 lines)

examples/models/llama/attention.py (+13, -5)
@@ -150,6 +150,16 @@ def forward(
         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)


+def _create_causal_mask_for_ring_buffer(
+    cache_positions, window_size, start_pos, seq_len
+):
+    pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
+    delta = pos_q - cache_positions
+    attn_mask = (cache_positions >= 0) & (delta >= 0) & (delta < window_size)
+    attn_mask = torch.where(attn_mask == True, 0, float("-inf"))  # noqa E712
+    return attn_mask
+
+
 class CacheUpdateStrategy(Enum):
     RING_BUFFER = "RingBuffer"
     INVALID = "Invalid"
@@ -283,12 +293,10 @@ def __init__(
         self.is_ring_buffer = True

     def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
-        pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
         cache_positions = self.cache_positions_manager.cache_positions
-        delta = pos_q - cache_positions
-        attn_mask = (cache_positions >= 0) & (delta >= 0) & (delta < self.window_size)
-        attn_mask = torch.where(attn_mask == True, 0, float("-inf"))  # noqa E712
-        return attn_mask
+        return _create_causal_mask_for_ring_buffer(
+            cache_positions, self.window_size, start_pos, seq_len
+        )

     def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
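The two hunks above factor the mask construction out of RingKVCache into a module-level helper so other cache implementations can reuse it. As a quick illustration of what that mask encodes, here is a small standalone sketch that mirrors the helper's logic; the window size, buffer contents, and positions below are made up for illustration and are not taken from the commit.

# Standalone sketch mirroring _create_causal_mask_for_ring_buffer.
# Assumed setup (illustrative only): window_size = 4, so the ring buffer has
# 2 * window_size = 8 slots; cache_positions holds the absolute token position
# stored in each slot (-1 would mark an empty slot); the current query is at
# position 10 and is assumed to already occupy slot 10 % 8 = 2.
import torch

window_size = 4
cache_positions = torch.tensor([8, 9, 10, 3, 4, 5, 6, 7], dtype=torch.long)
start_pos, seq_len = 10, 1

pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
delta = pos_q - cache_positions
# A slot is attendable iff it is populated, not in the future, and within
# the sliding window of the query position.
attendable = (cache_positions >= 0) & (delta >= 0) & (delta < window_size)
attn_mask = torch.where(attendable, 0.0, float("-inf"))
print(attendable)
# tensor([[ True,  True,  True, False, False, False, False,  True]])
# i.e. the query at position 10 can attend only to positions 7, 8, 9, 10.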

examples/models/llama/source_transformation/custom_kv_cache.py (+190, -1)
@@ -10,7 +10,12 @@

 import torch
 import torch.nn as nn
-from executorch.examples.models.llama.attention import KVCache
+from executorch.examples.models.llama.attention import (
+    _create_causal_mask_for_ring_buffer,
+    CachePositionsManager,
+    KVCache,
+    RingKVCache,
+)

 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401

@@ -75,6 +80,7 @@ def __init__(
         self.register_buffer(
             "v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int8)
         )
+        self.cache_type = cache_type

     def _quantize(self, value):
         (
@@ -181,6 +187,7 @@ def update(self, input_pos, k_val, v_val, indices=None):
         However the storage is [B, S, H, D] so we incur transpose in, transpose out
         This shall be removed by subsequent post-export graph pass
         """
+
         k_val = k_val.transpose(1, 2)
         v_val = v_val.transpose(1, 2)

@@ -346,3 +353,185 @@ def _replace_kv_cache_with_custom_kv_cache(module):
         else:
             _replace_kv_cache_with_custom_kv_cache(child)
     return module
+
+
+class QuantizedRingKVCache(QuantizedKVCache):
+    def __init__(
+        self,
+        max_batch_size,
+        max_context_length,
+        n_heads,
+        head_dim,
+        cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
+        use_custom_update_cache_op: bool = False,
+    ):
+        # Look at attention.py for explanation on why max_context_length * 2
+        super().__init__(
+            max_batch_size,
+            max_context_length * 2,
+            n_heads,
+            head_dim,
+            cache_type,
+            use_custom_update_cache_op,
+        )
+        self.cache_positions_manager = CachePositionsManager(self.max_context_length)
+        self.is_ring_buffer = True
+        self.window_size = max_context_length
+
+    def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
+        cache_positions = self.cache_positions_manager.cache_positions
+        return _create_causal_mask_for_ring_buffer(
+            cache_positions, self.window_size, start_pos, seq_len
+        )
+
+    def update(self, input_pos, k_val, v_val):
+        """
+        k_val, v_val: [B, H, S, D]
+        return: [B, H, S, D]
+        However the storage is [B, S, H, D] so we incur transpose in, transpose out
+        This shall be removed by subsequent post-export graph pass
+        """
+        # Need to transpose for two reasons
+        # 1. kv cache is stored as [B, S, H, D]
+        # 2. If seq_len = k_val.size(2), we won't be able to optimize
+        #    away transpose at the output of k, v projection
+        seq_len = k_val.transpose(1, 2).size(1)
+        assert seq_len <= self.k_cache.size(
+            1
+        ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(1)})"
+        indices = self.cache_positions_manager.calculate_positions_and_update_indices(
+            input_pos, seq_len
+        )
+        indices = indices.unsqueeze(0)
+
+        return super().update(input_pos, k_val, v_val, indices)
+
+    @classmethod
+    def from_quantized_kv_cache(
+        cls,
+        kv_cache,
+        sliding_window_size,
+    ):
+        assert isinstance(
+            kv_cache, QuantizedKVCache
+        ), "For QuantizedRingKVCache expect QuantizedKVCache as input kv_cache"
+        max_batch_size, _, n_heads, head_dim = kv_cache.k_cache.shape
+        return cls(
+            max_batch_size,
+            sliding_window_size,
+            n_heads,
+            head_dim,
+            kv_cache.cache_type,
+            kv_cache.use_custom_update_cache_op,
+        )
+
+
+class CustomRingKVCache(CustomKVCache):
+    def __init__(
+        self,
+        max_batch_size,
+        max_context_length,
+        n_heads,
+        head_dim,
+        dtype=torch.float32,
+    ):
+        # Look at attention.py for explanation on why max_context_length * 2
+        super().__init__(
+            max_batch_size, max_context_length * 2, n_heads, head_dim, dtype
+        )
+        self.cache_positions_manager = CachePositionsManager(self.max_context_length)
+        self.is_ring_buffer = True
+        self.window_size = max_context_length
+
+    def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
+        cache_positions = self.cache_positions_manager.cache_positions
+        return _create_causal_mask_for_ring_buffer(
+            cache_positions, self.window_size, start_pos, seq_len
+        )
+
+    def update(self, input_pos, k_val, v_val):
+        """
+        k_val, v_val: [B, H, S, D]
+        return: [B, H, S, D]
+        However the storage is [B, S, H, D] so we incur transpose in, transpose out
+        This shall be removed by subsequent post-export graph pass
+        """
+        # Need to transpose for two reasons
+        # 1. kv cache is stored as [B, S, H, D]
+        # 2. If seq_len = k_val.size(2), we won't be able to optimize
+        #    away transpose at the output of k, v projection
+        seq_len = k_val.transpose(1, 2).size(1)
+        assert seq_len <= self.k_cache.size(
+            1
+        ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(1)})"
+        indices = self.cache_positions_manager.calculate_positions_and_update_indices(
+            input_pos, seq_len
+        )
+        indices = indices.unsqueeze(0)
+
+        return super().update(input_pos, k_val, v_val, indices)
+
+    @classmethod
+    def from_custom_kv_cache(
+        cls,
+        kv_cache,
+        sliding_window_size,
+    ):
+        max_batch_size, n_heads, _, head_dim = kv_cache.k_cache.shape
+        if isinstance(kv_cache, CustomKVCache):
+            # If replacing custom kv cache, then the shape is [B, S, H, D]
+            max_batch_size, _, n_heads, head_dim = kv_cache.k_cache.shape
+        return cls(
+            max_batch_size,
+            sliding_window_size,
+            n_heads,
+            head_dim,
+            dtype=kv_cache.k_cache.dtype,
+        )
+
+
+def _replace_kv_cache_with_ring_kv_cache(attention, layer_size):
+    sliding_window_size = layer_size
+    assert (
+        getattr(attention, "kv_cache", None) is not None
+    ), "Attention module must have kv_cache module"
+    kv_cache = attention.kv_cache
+    if isinstance(kv_cache, KVCache):
+        attention.kv_cache = RingKVCache(
+            kv_cache.max_batch_size,
+            sliding_window_size,
+            kv_cache.n_heads,
+            kv_cache.head_dim,
+            kv_cache.enable_dynamic_shape,
+            kv_cache.k_cache.dtype,
+        )
+    elif isinstance(kv_cache, CustomKVCache):
+        attention.kv_cache = CustomRingKVCache.from_custom_kv_cache(
+            kv_cache, layer_size
+        )
+    elif isinstance(kv_cache, QuantizedKVCache):
+        attention.kv_cache = QuantizedRingKVCache.from_quantized_kv_cache(
+            kv_cache, layer_size
+        )
+
+
+def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
+    # This is needed to ensure that custom ops are registered
+    from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
+
+    logging.info(
+        "Replacing kv cache with ring kv cache. This modifies the model in place."
+    )
+    assert len(layer_sizes) == len(
+        module.layers
+    ), f"Length of layer sizes {len(layer_sizes)} must match the number of layers in the module {len(module.layers)}."
+    for i, transformer_block in enumerate(module.layers):
+        sliding_window_size = layer_sizes[i]
+        if sliding_window_size == 0:
+            continue
+        assert (
+            getattr(transformer_block, "attention", None) is not None
+        ), f"Transformer block must have attention module. Transformer block {transformer_block}"
+        attention = transformer_block.attention
+        _replace_kv_cache_with_ring_kv_cache(attention, sliding_window_size)
+    return module
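The last function above is the public entry point for the module swap described in the commit message. A minimal usage sketch follows; the wrapper name `swap_in_ring_kv_cache`, the `model` object, and the particular local/global layer pattern are hypothetical and only show the calling convention (a size of 0 leaves that layer's kv cache untouched, per the `continue` above).

# Hypothetical usage sketch of replace_kv_cache_with_ring_kv_cache.
# Assumes `model` is a llama-style Transformer whose blocks live in
# `model.layers` and whose attention modules already hold a KVCache,
# CustomKVCache, or QuantizedKVCache.
from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
    replace_kv_cache_with_ring_kv_cache,
)


def swap_in_ring_kv_cache(model):
    # One entry per transformer layer: 0 keeps the existing (global) kv cache,
    # a non-zero value swaps in a ring kv cache with that sliding window size.
    layer_sizes = [256 if i % 2 == 0 else 0 for i in range(len(model.layers))]
    return replace_kv_cache_with_ring_kv_cache(model, layer_sizes)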

examples/models/llama/tests/TARGETS (+25)
@@ -55,8 +55,33 @@ python_unittest(
     srcs = [
         "test_ring_attention.py",
     ],
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+        "//executorch/kernels/quantized:aot_lib",
+    ],
     deps = [
         "//caffe2:torch",
+        "//executorch/examples/models/llama:export_library",
+        "//executorch/examples/models/llama:llama_transformer",
+        "//executorch/examples/models/llama:custom_kv_cache",
+        "//executorch/examples/models/llama:sdpa",
+    ],
+)
+
+python_unittest(
+    name = "test_replace_kv_cache",
+    srcs = [
+        "test_replace_kv_cache.py",
+    ],
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+        "//executorch/kernels/quantized:aot_lib",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/examples/models/llama:export_library",
         "//executorch/examples/models/llama:llama_transformer",
+        "//executorch/examples/models/llama:custom_kv_cache",
+        "//executorch/examples/models/llama:sdpa",
     ],
 )