 
 import torch
 import torch.nn as nn
-from executorch.examples.models.llama.attention import KVCache
+from executorch.examples.models.llama.attention import (
+    _create_causal_mask_for_ring_buffer,
+    CachePositionsManager,
+    KVCache,
+    RingKVCache,
+)
 
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
 
@@ -75,6 +80,7 @@ def __init__(
         self.register_buffer(
             "v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int8)
         )
+        self.cache_type = cache_type
 
     def _quantize(self, value):
         (
@@ -181,6 +187,7 @@ def update(self, input_pos, k_val, v_val, indices=None):
         However the storage is [B, S, H, D] so we incur transpose in, transpose out
         This shall be removed by subsequent post-export graph pass
         """
+
         k_val = k_val.transpose(1, 2)
         v_val = v_val.transpose(1, 2)
 
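The docstring above captures the layout convention behind these transposes: attention passes k/v as [B, H, S, D], while the cache buffers are stored as [B, S, H, D], and the sequence length is read from the transposed tensor so the exported graph can later remove the transpose. A minimal sketch of that convention; all sizes below are made up for illustration and are not part of this diff:

```python
import torch

# Illustrative sizes only (not from this diff): B=1, H=8 heads, S=4 tokens, D=64.
k_val = torch.randn(1, 8, 4, 64)   # [B, H, S, D], as produced by the k projection
k_store = k_val.transpose(1, 2)    # [B, S, H, D], the layout used by the cache buffers
assert k_store.shape == (1, 4, 8, 64)

# seq_len is taken from the transposed tensor (dim 1) rather than k_val.size(2),
# mirroring the update() methods in this file, so the transpose stays removable
# by the post-export graph pass.
seq_len = k_store.size(1)
```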
@@ -346,3 +353,185 @@ def _replace_kv_cache_with_custom_kv_cache(module):
         else:
             _replace_kv_cache_with_custom_kv_cache(child)
     return module
+
+
+class QuantizedRingKVCache(QuantizedKVCache):
+    def __init__(
+        self,
+        max_batch_size,
+        max_context_length,
+        n_heads,
+        head_dim,
+        cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
+        use_custom_update_cache_op: bool = False,
+    ):
+        # Look at attention.py for explanation on why max_context_length * 2
+        super().__init__(
+            max_batch_size,
+            max_context_length * 2,
+            n_heads,
+            head_dim,
+            cache_type,
+            use_custom_update_cache_op,
+        )
+        self.cache_positions_manager = CachePositionsManager(self.max_context_length)
+        self.is_ring_buffer = True
+        self.window_size = max_context_length
+
+    def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
+        cache_positions = self.cache_positions_manager.cache_positions
+        return _create_causal_mask_for_ring_buffer(
+            cache_positions, self.window_size, start_pos, seq_len
+        )
+
+    def update(self, input_pos, k_val, v_val):
+        """
+        k_val, v_val: [B, H, S, D]
+        return: [B, H, S, D]
+        However the storage is [B, S, H, D] so we incur transpose in, transpose out
+        This shall be removed by subsequent post-export graph pass
+        """
+        # Need to transpose for two reasons
+        # 1. kv cache is stored as [B, S, H, D]
+        # 2. If seq_len = k_val.size(2), we won't be able to optimize
+        #    away the transpose at the output of the k, v projection
+        seq_len = k_val.transpose(1, 2).size(1)
+        assert seq_len <= self.k_cache.size(
+            1
+        ), f"Update sequence length ({seq_len}) for kv cache must not exceed the cache size ({self.k_cache.size(1)})"
+        indices = self.cache_positions_manager.calculate_positions_and_update_indices(
+            input_pos, seq_len
+        )
+        indices = indices.unsqueeze(0)
+
+        return super().update(input_pos, k_val, v_val, indices)
+
+    @classmethod
+    def from_quantized_kv_cache(
+        cls,
+        kv_cache,
+        sliding_window_size,
+    ):
+        assert isinstance(
+            kv_cache, QuantizedKVCache
+        ), "For QuantizedRingKVCache expect QuantizedKVCache as input kv_cache"
+        max_batch_size, _, n_heads, head_dim = kv_cache.k_cache.shape
+        return cls(
+            max_batch_size,
+            sliding_window_size,
+            n_heads,
+            head_dim,
+            kv_cache.cache_type,
+            kv_cache.use_custom_update_cache_op,
+        )
+
+
+class CustomRingKVCache(CustomKVCache):
+    def __init__(
+        self,
+        max_batch_size,
+        max_context_length,
+        n_heads,
+        head_dim,
+        dtype=torch.float32,
+    ):
+        # Look at attention.py for explanation on why max_context_length * 2
+        super().__init__(
+            max_batch_size, max_context_length * 2, n_heads, head_dim, dtype
+        )
+        self.cache_positions_manager = CachePositionsManager(self.max_context_length)
+        self.is_ring_buffer = True
+        self.window_size = max_context_length
+
+    def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
+        cache_positions = self.cache_positions_manager.cache_positions
+        return _create_causal_mask_for_ring_buffer(
+            cache_positions, self.window_size, start_pos, seq_len
+        )
+
+    def update(self, input_pos, k_val, v_val):
+        """
+        k_val, v_val: [B, H, S, D]
+        return: [B, H, S, D]
+        However the storage is [B, S, H, D] so we incur transpose in, transpose out
+        This shall be removed by subsequent post-export graph pass
+        """
+        # Need to transpose for two reasons
+        # 1. kv cache is stored as [B, S, H, D]
+        # 2. If seq_len = k_val.size(2), we won't be able to optimize
+        #    away the transpose at the output of the k, v projection
+        seq_len = k_val.transpose(1, 2).size(1)
+        assert seq_len <= self.k_cache.size(
+            1
+        ), f"Update sequence length ({seq_len}) for kv cache must not exceed the cache size ({self.k_cache.size(1)})"
+        indices = self.cache_positions_manager.calculate_positions_and_update_indices(
+            input_pos, seq_len
+        )
+        indices = indices.unsqueeze(0)
+
+        return super().update(input_pos, k_val, v_val, indices)
+
+    @classmethod
+    def from_custom_kv_cache(
+        cls,
+        kv_cache,
+        sliding_window_size,
+    ):
+        max_batch_size, n_heads, _, head_dim = kv_cache.k_cache.shape
+        if isinstance(kv_cache, CustomKVCache):
+            # If replacing custom kv cache, then the shape is [B, S, H, D]
+            max_batch_size, _, n_heads, head_dim = kv_cache.k_cache.shape
+        return cls(
+            max_batch_size,
+            sliding_window_size,
+            n_heads,
+            head_dim,
+            dtype=kv_cache.k_cache.dtype,
+        )
+
+
+def _replace_kv_cache_with_ring_kv_cache(attention, layer_size):
+    sliding_window_size = layer_size
+    assert (
+        getattr(attention, "kv_cache", None) is not None
+    ), "Attention module must have kv_cache module"
+    kv_cache = attention.kv_cache
+    if isinstance(kv_cache, KVCache):
+        attention.kv_cache = RingKVCache(
+            kv_cache.max_batch_size,
+            sliding_window_size,
+            kv_cache.n_heads,
+            kv_cache.head_dim,
+            kv_cache.enable_dynamic_shape,
+            kv_cache.k_cache.dtype,
+        )
+    elif isinstance(kv_cache, CustomKVCache):
+        attention.kv_cache = CustomRingKVCache.from_custom_kv_cache(
+            kv_cache, layer_size
+        )
+    elif isinstance(kv_cache, QuantizedKVCache):
+        attention.kv_cache = QuantizedRingKVCache.from_quantized_kv_cache(
+            kv_cache, layer_size
+        )
+
+
+def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
+    # This is needed to ensure that custom ops are registered
+    from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
+
+    logging.info(
+        "Replacing kv cache with ring kv cache. This modifies the model in place."
+    )
+    assert len(layer_sizes) == len(
+        module.layers
+    ), f"Length of layer sizes {len(layer_sizes)} must match the number of layers in the module {len(module.layers)}."
+    for i, transformer_block in enumerate(module.layers):
+        sliding_window_size = layer_sizes[i]
+        if sliding_window_size == 0:
+            continue
+        assert (
+            getattr(transformer_block, "attention", None) is not None
+        ), f"Transformer block must have attention module. Transformer block {transformer_block}"
+        attention = transformer_block.attention
+        _replace_kv_cache_with_ring_kv_cache(attention, sliding_window_size)
+    return module
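
For orientation, here is a hedged usage sketch of the ring-buffer caches this patch adds. The import path, sizes, and tensors below are assumptions made for illustration and are not part of the diff; the sketch wraps a CustomKVCache in a CustomRingKVCache with a sliding window and performs one single-token update.

```python
import torch

# Registers the llama custom ops used by the custom cache update
# (the same import replace_kv_cache_with_ring_kv_cache performs above).
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401

# Assumed import path for the file touched by this diff; adjust to your checkout.
from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
    CustomKVCache,
    CustomRingKVCache,
)

# Plain cache sized for a 2048-token context (illustrative: B=1, H=8, D=64).
kv_cache = CustomKVCache(1, 2048, 8, 64)

# Re-wrap it as a ring buffer with a 512-token sliding window; per the
# "max_context_length * 2" comment above, the backing storage holds 1024 slots.
ring_cache = CustomRingKVCache.from_custom_kv_cache(kv_cache, sliding_window_size=512)

# Single-token decode step at position 0; k/v follow the [B, H, S, D] convention,
# and per the update() docstring the returned caches are viewed the same way.
k = torch.randn(1, 8, 1, 64)
v = torch.randn(1, 8, 1, 64)
k_out, v_out = ring_cache.update(torch.tensor([0]), k, v)

# Attention mask over the ring buffer for start_pos=0, seq_len=1.
mask = ring_cache.create_causal_mask_for_ring_buffer(0, 1)
```

In the full flow, replace_kv_cache_with_ring_kv_cache(module, layer_sizes) applies the same swap per transformer block, with a 0 entry in layer_sizes leaving that layer's cache untouched.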