From 51296712900f28023cecea6ff7333af4b3c81d30 Mon Sep 17 00:00:00 2001 From: Gerben van V Date: Thu, 29 Aug 2024 11:51:09 +0200 Subject: [PATCH] Add a static cache that offloads to the CPU or other device (#32161) * Add a static cache that offloads to the CPU or other device * Fix PR comments, add unit-tests --- docs/source/en/internal/generation_utils.md | 5 + docs/source/en/kv_cache.md | 45 +++- src/transformers/__init__.py | 2 + src/transformers/cache_utils.py | 272 ++++++++++++++++++++ src/transformers/generation/utils.py | 2 + src/transformers/utils/dummy_pt_objects.py | 7 + tests/utils/test_cache_utils.py | 36 ++- 7 files changed, 350 insertions(+), 19 deletions(-) diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index 936e4bfb95da..a81d202c6634 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -390,6 +390,11 @@ A [`Constraint`] can be used to force the generation to include specific tokens - get_seq_length - reset +[[autodoc]] OffloadedStaticCache + - update + - get_seq_length + - reset + [[autodoc]] HybridCache - update - get_seq_length diff --git a/docs/source/en/kv_cache.md b/docs/source/en/kv_cache.md index c0ccc49d41e6..1ae97497d2ff 100644 --- a/docs/source/en/kv_cache.md +++ b/docs/source/en/kv_cache.md @@ -96,14 +96,15 @@ with the [`~DynamicCache`] class being the default cache for most models. It all Refer to the table below to see the difference between cache types and choose the one that suits best for your use-case. -| Cache Type | Memory Efficient | Supports torch.compile() | Initialization Recommended | Latency | Long Context Generation | -|---------------------|------------------|--------------------------|----------------------------|----------|--------------------------| -| Dynamic Cache | No | No | No | Mid | No | -| Static Cache | No | Yes | Yes | High | No | -| Quantized Cache | Yes | No | No | Low | Yes | -| Offloaded Cache | Yes | No | No | Low | No | -| Sliding Window Cache| No | Yes | Yes | High | No | -| Sink Cache | Yes | No | Yes | Mid | Yes | +| Cache Type | Memory Efficient | Supports torch.compile() | Initialization Recommended | Latency | Long Context Generation | +|------------------------|------------------|--------------------------|----------------------------|---------|-------------------------| +| Dynamic Cache | No | No | No | Mid | No | +| Static Cache | No | Yes | Yes | High | No | +| Offloaded Cache | Yes | No | No | Low | Yes | +| Offloaded Static Cache | No | Yes | Yes | High | Yes | +| Quantized Cache | Yes | No | No | Low | Yes | +| Sliding Window Cache | No | Yes | Yes | High | No | +| Sink Cache | Yes | No | Yes | Mid | Yes | These cache classes can be set with a `cache_implementation` argument when generating. To learn about the available options for the cache_implementation flag, please refer to the [API Documentation](./main_classes/text_generation.md#transformers.GenerationConfig). Now, let's explore each cache type in detail and see how to use them. Note that the below examples are for decoder-only Tranformer-based models. We also support ["Model-Specific Cache"] classes for models such as Mamba or Jamba, keep reading for more details. @@ -142,7 +143,7 @@ I like rock music because it's loud and energetic. It's a great way to express m I like rock music because it's loud and energetic. I like to listen to it when I'm feeling ``` -## OffloadedCache +## Offloaded Cache Similarly to KV cache quantization, [`~OffloadedCache`] strategy aims to reduce GPU VRAM usage. It does so by moving the KV cache for most layers to the CPU. @@ -154,7 +155,8 @@ Thus, it can serve as a drop-in replacement or a fallback for it. Depending on your model and the characteristics of your generation task (size of context, number of generated tokens, number of beams, etc.) you may notice a small degradation in generation throughput compared to the default KV cache implementation. -To enable KV cache offloading, pass `cache_implementation="offloaded"` in the `generation_config` or directky to the `generate()` call. +To enable KV cache offloading, pass `cache_implementation="offloaded"` in the `generation_config` or directly to the `generate()` call. +Use `cache_implementation="offloaded_static"` for an offloaded static cache (see also [Offloaded Static Cache](#offloaded-static-cache) below). ```python >>> import torch @@ -216,7 +218,6 @@ retrying with cache_implementation='offloaded' before successfully generating 40 beams. - ### Static Cache Since the "DynamicCache" dynamically grows with each generation step, it prevents you from taking advantage of JIT optimizations. The [`~StaticCache`] pre-allocates @@ -238,6 +239,28 @@ For more examples with Static Cache and JIT compilation, take a look at [StaticC "Hello, my name is [Your Name], and I am a [Your Profession] with [Number of Years] of" ``` + +## Offloaded Static Cache + +Like [`~OffloadedCache`] exists for offloading a "DynamicCache", there is also an offloaded static cache. It fully supports +JIT optimizations. Just pass `cache_implementation="offloaded_static"` in the `generation_config` or directly to the `generate()` call. +This will use the [`~OffloadedStaticCache`] implementation instead. + +```python +>>> import torch +>>> from transformers import AutoTokenizer, AutoModelForCausalLM + +>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") +>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto") +>>> inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device) + +>>> # simply pass the cache implementation="static" +>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="offloaded_static") +>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0] +"Hello, my name is [Your Name], and I am a [Your Profession] with [Number of Years] of" +``` + + ### Sliding Window Cache As the name suggests, this cache type implements a sliding window over previous keys and values, retaining only the last `sliding_window` tokens. It should be used with models like Mistral that support sliding window attention. Additionally, similar to Static Cache, this one is JIT-friendly and can be used with the same compile tecniques as Static Cache. diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 41d932c0ff96..74870501f798 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1246,6 +1246,7 @@ "HybridCache", "MambaCache", "OffloadedCache", + "OffloadedStaticCache", "QuantizedCache", "QuantizedCacheConfig", "QuantoQuantizedCache", @@ -6052,6 +6053,7 @@ HybridCache, MambaCache, OffloadedCache, + OffloadedStaticCache, QuantizedCache, QuantizedCacheConfig, QuantoQuantizedCache, diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 33bcbcda64b1..80c36b9f68ee 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1708,3 +1708,275 @@ def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): def reset(self): self.conv_states.zero_() self.ssm_states.zero_() + + +class OffloadedStaticCache(StaticCache): + """ + Static cache class to be used with `torch.compile(model)` that offloads to the CPU or + another device. + + Args: + config (`PretrainedConfig): + The configuration file defining the shape-related attributes required to initialize + the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. + max_cache_len (`int`): + The maximum sequence length with which the model will be used. + device (`Union[str, torch.device]`): + The device on which the cache should be initialized. Should be the same as the + layer device. + dtype (`torch.dtype`, *optional*): + The default `dtype` to use when initializing the cache. + offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`): + The device to offload to. Defaults to CPU. + + Attributes: + key_cache (`List[torch.Tensor]`): + Off-loaded key cache tensors. First one will be on device, where-as the others are + off-loaded. + value_cache (`List[torch.Tensor]`): + Off-loaded value cache tensors. First one will be on device, where-as the others are + off-loaded. + max_batch_size (`int`): + The maximum batch size with which this cache can be used. + max_cache_len (`int`): + The maximum sequence length with which this cache can be used. + device (`torch.device`): + The device on which the cache is used. + offload_device (`torch.device`): + The device used to offload to. + dtype (`torch.dtype`): + The `dtype` used to initializing the cache. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache + + >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + + >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = OffloadedStaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + ``` + """ + + def __init__( + self, + config: PretrainedConfig, + max_batch_size: int, + max_cache_len: Optional[int], + device: Union[str, torch.device], + dtype: Optional[torch.dtype] = None, + offload_device: Union[str, torch.device] = torch.device("cpu"), + ) -> None: + self.max_batch_size = max_batch_size + self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len + self.device = torch.device(device) + self.offload_device = torch.device(offload_device) + self.dtype = dtype if dtype is not None else torch.float32 + + # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads + head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + + num_key_value_heads = ( + config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads + ) + + cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim) + + # Create offloaded CPU tensors. + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + + for i in range(config.num_hidden_layers): + # First layer is always on-device. + device = self.device if i == 0 else self.offload_device + + key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, device) + + self.key_cache.append(key_cache) + self.value_cache.append(value_cache) + + # Create device tensors. + self._device_key_cache: List[torch.Tensor] = [] + self._device_value_cache: List[torch.Tensor] = [] + + for i in range(2): + key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, self.device) + + self._device_key_cache.append(key_cache) + self._device_value_cache.append(value_cache) + + # For backwards compatibility. + # TODO(gante): Remove this. + self._seen_tokens = 0 + + # Create new CUDA stream for parallel prefetching. + self._prefetch_stream = torch.cuda.Stream() if self.device.type == "cuda" else None + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + It is VERY important to index using a tensor, otherwise you introduce a copy to the device. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, *optional*): + Additional arguments for the cache subclass. The `OffloadedStaticCache` needs the + `cache_position` input to know how where to write in the cache. + + Return: + A tuple containing the updated key and value states. + """ + + if layer_idx == 0: + # Update seen tokens. + # TODO(gante): Remove this. + self._seen_tokens += key_states.shape[-2] + + # Always there. + k_out = self.key_cache[0] + v_out = self.value_cache[0] + else: + # Wait for prefetch stream. + if self._prefetch_stream is not None: + torch.cuda.default_stream(self.device).wait_stream(self._prefetch_stream) + + k_out = self._device_key_cache[layer_idx & 1] + v_out = self._device_value_cache[layer_idx & 1] + + self._prefetch_layer(layer_idx + 1) + + cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None + if cache_position is None: + k_out.copy_(key_states) + v_out.copy_(value_states) + + # Copy the values to the offloaded device as well. + if layer_idx == 0: + self.key_cache[layer_idx].copy_(key_states.to(self.offload_device)) + self.value_cache[layer_idx].copy_(value_states.to(self.offload_device)) + else: + # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to + # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does + # explicitly an in-place operation, that avoids copies and uses less memory. + try: + k_out.index_copy_(2, cache_position, key_states) + v_out.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # The operator 'aten::index_copy.out' is not currently implemented for the MPS + # device. + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + # Copy the values to the offloaded device as well. + if layer_idx != 0: + cache_position = cache_position.to(self.offload_device) + key_states = key_states.to(self.offload_device) + value_states = value_states.to(self.offload_device) + + try: + self.key_cache[layer_idx].index_copy_(2, cache_position, key_states) + self.value_cache[layer_idx].index_copy_(2, cache_position, value_states) + except NotImplementedError: + # The operator 'aten::index_copy.out' is not currently implemented for the MPS + # device. + self.key_cache[layer_idx][:, :, cache_position] = key_states + self.value_cache[layer_idx][:, :, cache_position] = value_states + + return k_out, v_out + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states that were seen by the model.""" + + # TODO(gante): Remove this. + return self._seen_tokens + + def get_max_length(self) -> Optional[int]: + """Returns the maximum sequence length of the cached states.""" + + return self.max_cache_len + + def reset(self) -> None: + """Resets the cache values while preserving the objects.""" + + # For backwards compatibility. + # TODO(gante): Remove this. + self._seen_tokens = 0 + + # Zero out cache. + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address. + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + @property + def seen_tokens(self) -> int: + # For backwards compatibility. + # TODO(gante): Remove this. + return self._seen_tokens + + def _create_key_value_cache_tensors( + self, shape: Tuple[int, ...], device: torch.device + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Creates K/V cache tensors on a device. Pins memory for CPU tensors. Marks them as static + addresses for non-CPU tensors. + + Args: + shape (`Tuple[int, ...]`): Shape. + device (`torch.device`): Device. + + Returns: + Key and value cache tensors as a tuple. + """ + + is_cpu_device = device == torch.device("cpu") + + key_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device) + value_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device) + + # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, + # preventing compiled graph breaks when updating the cache. + torch._dynamo.mark_static_address(key_cache) + torch._dynamo.mark_static_address(value_cache) + + return key_cache, value_cache + + def _prefetch_layer(self, layer_idx: int) -> None: + """Prefetch a layer to the device. Needs to be called in order of layer indices.""" + + # Don't fetch layers that do not exist. + if layer_idx >= len(self.key_cache): + return + + # Alternate between two on-device caches. + if self._prefetch_stream is not None: + with torch.cuda.stream(self._prefetch_stream): + self._prefetch_layer_in_context(layer_idx) + else: + self._prefetch_layer_in_context(layer_idx) + + def _prefetch_layer_in_context(self, layer_idx: int) -> None: + """Performs the actual copy of the layer to device cache.""" + + self._device_key_cache[layer_idx & 1].copy_(self.key_cache[layer_idx], non_blocking=True) + self._device_value_cache[layer_idx & 1].copy_(self.value_cache[layer_idx], non_blocking=True) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 0d2baea6d85f..c0fe3acb9eb3 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -33,6 +33,7 @@ HybridCache, MambaCache, OffloadedCache, + OffloadedStaticCache, QuantizedCacheConfig, QuantoQuantizedCache, SlidingWindowCache, @@ -119,6 +120,7 @@ NEED_SETUP_CACHE_CLASSES_MAPPING = { "static": StaticCache, + "offloaded_static": OffloadedStaticCache, "sliding_window": SlidingWindowCache, "hybrid": HybridCache, "mamba": MambaCache, diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index c298fdb9697a..26f6c8a4b56b 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -72,6 +72,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class OffloadedStaticCache(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class QuantizedCache(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 4a9acf4a271f..0bb604c96f8c 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -380,8 +380,15 @@ def test_sink_cache_iterative_prompts(self): self.assertTrue(decoded[0].endswith(last_output)) @require_torch_gpu - @parameterized.expand(["eager", "sdpa"]) - def test_static_cache_greedy_decoding_pad_left(self, attn_implementation): + @parameterized.expand( + [ + ("eager", "static"), + ("sdpa", "static"), + ("eager", "offloaded-static"), + ("sdpa", "offloaded-static"), + ] + ) + def test_static_cache_greedy_decoding_pad_left(self, attn_implementation, cache_implementation): EXPECTED_GENERATION = [ "The best color is the one that complements the skin tone of the", "We should not undermind the issues at hand.\nWe should not undermind the issues", @@ -406,7 +413,7 @@ def test_static_cache_greedy_decoding_pad_left(self, attn_implementation): self.assertListEqual(decoded, EXPECTED_GENERATION) set_seed(0) - model.generation_config.cache_implementation = "static" + model.generation_config.cache_implementation = cache_implementation gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) with self.subTest(f"{attn_implementation}, static, eager"): @@ -420,8 +427,15 @@ def test_static_cache_greedy_decoding_pad_left(self, attn_implementation): self.assertListEqual(decoded, EXPECTED_GENERATION) @require_torch_gpu - @parameterized.expand(["eager", "sdpa"]) - def test_static_cache_greedy_decoding_pad_right(self, attn_implementation): + @parameterized.expand( + [ + ("eager", "static"), + ("sdpa", "static"), + ("eager", "offloaded-static"), + ("sdpa", "offloaded-static"), + ] + ) + def test_static_cache_greedy_decoding_pad_right(self, attn_implementation, cache_implementation): EXPECTED_GENERATION = [ "The best color isЋ the one that complements the skin tone of", "We should not undermind the issues at hand.\nWe should not undermind the issues", @@ -446,7 +460,7 @@ def test_static_cache_greedy_decoding_pad_right(self, attn_implementation): self.assertListEqual(decoded, EXPECTED_GENERATION) set_seed(0) - model.generation_config.cache_implementation = "static" + model.generation_config.cache_implementation = cache_implementation gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) with self.subTest(f"{attn_implementation}, static, eager"): @@ -506,7 +520,13 @@ def test_dynamic_cache_extra_left_padding(self): decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) self.assertListEqual(decoded, EXPECTED_GENERATION) - def test_static_cache_extra_left_padding(self): + @parameterized.expand( + [ + "static", + "offloaded-static", + ] + ) + def test_static_cache_extra_left_padding(self, cache_implementation): """Tests that adding extra left-padding does not affect the generation with the static cache""" EXPECTED_GENERATION = [ "The best color is the one that complements the skin tone of the", @@ -524,7 +544,7 @@ def test_static_cache_extra_left_padding(self): ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt" ).to(model.device) - model.generation_config.cache_implementation = "static" + model.generation_config.cache_implementation = cache_implementation gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)