Add a static cache that offloads to the CPU or other device (#32161)
* Add a static cache that offloads to the CPU or other device

* Fix PR comments, add unit-tests
gerbenvv authored Aug 29, 2024
1 parent 92a75ff commit 5129671
Showing 7 changed files with 350 additions and 19 deletions.
5 changes: 5 additions & 0 deletions docs/source/en/internal/generation_utils.md
@@ -390,6 +390,11 @@ A [`Constraint`] can be used to force the generation to include specific tokens
- get_seq_length
- reset

[[autodoc]] OffloadedStaticCache
- update
- get_seq_length
- reset

[[autodoc]] HybridCache
- update
- get_seq_length
45 changes: 34 additions & 11 deletions docs/source/en/kv_cache.md
@@ -96,14 +96,15 @@ with the [`~DynamicCache`] class being the default cache for most models. It all

Refer to the table below to see the differences between the cache types and choose the one that best suits your use case.

| Cache Type | Memory Efficient | Supports torch.compile() | Initialization Recommended | Latency | Long Context Generation |
|---------------------|------------------|--------------------------|----------------------------|----------|--------------------------|
| Dynamic Cache | No | No | No | Mid | No |
| Static Cache | No | Yes | Yes | High | No |
| Quantized Cache | Yes | No | No | Low | Yes |
| Offloaded Cache | Yes | No | No | Low | No |
| Sliding Window Cache| No | Yes | Yes | High | No |
| Sink Cache | Yes | No | Yes | Mid | Yes |
| Cache Type | Memory Efficient | Supports torch.compile() | Initialization Recommended | Latency | Long Context Generation |
|------------------------|------------------|--------------------------|----------------------------|---------|-------------------------|
| Dynamic Cache | No | No | No | Mid | No |
| Static Cache | No | Yes | Yes | High | No |
| Offloaded Cache | Yes | No | No | Low | Yes |
| Offloaded Static Cache | No | Yes | Yes | High | Yes |
| Quantized Cache | Yes | No | No | Low | Yes |
| Sliding Window Cache | No | Yes | Yes | High | No |
| Sink Cache | Yes | No | Yes | Mid | Yes |


These cache classes can be set with a `cache_implementation` argument when generating. To learn about the available options for the `cache_implementation` flag, please refer to the [API Documentation](./main_classes/text_generation.md#transformers.GenerationConfig). Now, let's explore each cache type in detail and see how to use them. Note that the examples below are for decoder-only Transformer-based models. We also support ["Model-Specific Cache"] classes for models such as Mamba or Jamba; keep reading for more details.
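For a quick illustration of this flag, here is a minimal sketch; it reuses the `meta-llama/Llama-2-7b-chat-hf` checkpoint from the examples below, and any of the string values accepted by `cache_implementation` can be substituted:

```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto")
>>> inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)

>>> # Any supported value works here, e.g. "offloaded", "static" or "offloaded_static"
>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="offloaded")
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
```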
@@ -142,7 +143,7 @@ I like rock music because it's loud and energetic. It's a great way to express m
I like rock music because it's loud and energetic. I like to listen to it when I'm feeling
```

## OffloadedCache
## Offloaded Cache

Similarly to KV cache quantization, the [`~OffloadedCache`] strategy aims to reduce GPU VRAM usage.
It does so by moving the KV cache for most layers to the CPU.
@@ -154,7 +155,8 @@ Thus, it can serve as a drop-in replacement or a fallback for it.
Depending on your model and the characteristics of your generation task (size of context, number of generated tokens, number of beams, etc.)
you may notice a small degradation in generation throughput compared to the default KV cache implementation.

To enable KV cache offloading, pass `cache_implementation="offloaded"` in the `generation_config` or directky to the `generate()` call.
To enable KV cache offloading, pass `cache_implementation="offloaded"` in the `generation_config` or directly to the `generate()` call.
Use `cache_implementation="offloaded_static"` for an offloaded static cache (see also [Offloaded Static Cache](#offloaded-static-cache) below).

```python
>>> import torch
@@ -216,7 +218,6 @@ retrying with cache_implementation='offloaded'
before successfully generating 40 beams.
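The retry flow quoted above can be written as a small wrapper around `generate()`. A minimal sketch, with an illustrative prompt and beam settings:

```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> def resilient_generate(model, *args, **kwargs):
...     try:
...         return model.generate(*args, **kwargs)
...     except torch.cuda.OutOfMemoryError:
...         print("retrying with cache_implementation='offloaded'")
...         torch.cuda.empty_cache()
...         kwargs["cache_implementation"] = "offloaded"
...         return model.generate(*args, **kwargs)

>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto")
>>> inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
>>> out = resilient_generate(model, **inputs, do_sample=False, num_beams=40, max_new_tokens=50)
```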



### Static Cache

Since the "DynamicCache" dynamically grows with each generation step, it prevents you from taking advantage of JIT optimizations. The [`~StaticCache`] pre-allocates
@@ -238,6 +239,28 @@ For more examples with Static Cache and JIT compilation, take a look at [StaticC
"Hello, my name is [Your Name], and I am a [Your Profession] with [Number of Years] of"
```
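
To actually benefit from JIT compilation with the static cache, the forward pass can be compiled before generating. A minimal sketch; the compile mode and flags here are illustrative, not prescriptive:

```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto")
>>> inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)

>>> # Compile the forward pass; the static cache keeps shapes and data pointers fixed,
>>> # so the compiled graph can be reused across decoding steps
>>> model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="static")
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
```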


## Offloaded Static Cache

Just as [`~OffloadedCache`] exists for offloading a "DynamicCache", there is also an offloaded static cache. It fully supports
JIT optimizations. Pass `cache_implementation="offloaded_static"` in the `generation_config` or directly to the `generate()` call,
and the [`~OffloadedStaticCache`] implementation will be used instead.

```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto")
>>> inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)

>>> # simply pass cache_implementation="offloaded_static"
>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="offloaded_static")
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
"Hello, my name is [Your Name], and I am a [Your Profession] with [Number of Years] of"
```


### Sliding Window Cache

As the name suggests, this cache type implements a sliding window over previous keys and values, retaining only the last `sliding_window` tokens. It should be used with models like Mistral that support sliding window attention. Additionally, similar to Static Cache, this one is JIT-friendly and can be used with the same compile techniques as Static Cache.
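
A minimal sketch of enabling it, assuming a sliding-window model such as `mistralai/Mistral-7B-Instruct-v0.1`:

```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", torch_dtype=torch.float16, device_map="auto")
>>> inputs = tokenizer("Yesterday I was at a rock concert and", return_tensors="pt").to(model.device)

>>> # Only the last `sliding_window` tokens of the context are kept in the cache
>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=30, cache_implementation="sliding_window")
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
```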
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -1246,6 +1246,7 @@
"HybridCache",
"MambaCache",
"OffloadedCache",
"OffloadedStaticCache",
"QuantizedCache",
"QuantizedCacheConfig",
"QuantoQuantizedCache",
@@ -6052,6 +6053,7 @@
HybridCache,
MambaCache,
OffloadedCache,
OffloadedStaticCache,
QuantizedCache,
QuantizedCacheConfig,
QuantoQuantizedCache,
272 changes: 272 additions & 0 deletions src/transformers/cache_utils.py
@@ -1708,3 +1708,275 @@ def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
def reset(self):
self.conv_states.zero_()
self.ssm_states.zero_()


class OffloadedStaticCache(StaticCache):
"""
Static cache class to be used with `torch.compile(model)` that offloads to the CPU or
another device.
Args:
config (`PretrainedConfig`):
The configuration file defining the shape-related attributes required to initialize
the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used.
max_cache_len (`int`):
The maximum sequence length with which the model will be used.
device (`Union[str, torch.device]`):
The device on which the cache should be initialized. Should be the same as the
layer device.
dtype (`torch.dtype`, *optional*):
The default `dtype` to use when initializing the cache.
offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`):
The device to offload to. Defaults to CPU.
Attributes:
key_cache (`List[torch.Tensor]`):
Offloaded key cache tensors. The first one will be on device, whereas the others are
offloaded.
value_cache (`List[torch.Tensor]`):
Offloaded value cache tensors. The first one will be on device, whereas the others are
offloaded.
max_batch_size (`int`):
The maximum batch size with which this cache can be used.
max_cache_len (`int`):
The maximum sequence length with which this cache can be used.
device (`torch.device`):
The device on which the cache is used.
offload_device (`torch.device`):
The device used to offload to.
dtype (`torch.dtype`):
The `dtype` used to initialize the cache.
Example:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
>>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")
>>> # Prepare a cache class and pass it to model's forward
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> max_generated_length = inputs.input_ids.shape[1] + 10
>>> past_key_values = OffloadedStaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> past_kv = outputs.past_key_values # access cache filled with key/values from generation
```
"""

def __init__(
self,
config: PretrainedConfig,
max_batch_size: int,
max_cache_len: Optional[int],
device: Union[str, torch.device],
dtype: Optional[torch.dtype] = None,
offload_device: Union[str, torch.device] = torch.device("cpu"),
) -> None:
self.max_batch_size = max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.device = torch.device(device)
self.offload_device = torch.device(offload_device)
self.dtype = dtype if dtype is not None else torch.float32

# Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads

num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)

cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim)

# Create offloaded CPU tensors.
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []

for i in range(config.num_hidden_layers):
# First layer is always on-device.
device = self.device if i == 0 else self.offload_device

key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, device)

self.key_cache.append(key_cache)
self.value_cache.append(value_cache)

# Create device tensors.
self._device_key_cache: List[torch.Tensor] = []
self._device_value_cache: List[torch.Tensor] = []

for i in range(2):
key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, self.device)

self._device_key_cache.append(key_cache)
self._device_value_cache.append(value_cache)

# For backwards compatibility.
# TODO(gante): Remove this.
self._seen_tokens = 0

# Create new CUDA stream for parallel prefetching.
self._prefetch_stream = torch.cuda.Stream() if self.device.type == "cuda" else None

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
It is VERY important to index using a tensor, otherwise you introduce a copy to the device.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, *optional*):
Additional arguments for the cache subclass. The `OffloadedStaticCache` needs the
`cache_position` input to know where to write in the cache.
Return:
A tuple containing the updated key and value states.
"""

if layer_idx == 0:
# Update seen tokens.
# TODO(gante): Remove this.
self._seen_tokens += key_states.shape[-2]

# Always there.
k_out = self.key_cache[0]
v_out = self.value_cache[0]
else:
# Wait for prefetch stream.
if self._prefetch_stream is not None:
torch.cuda.default_stream(self.device).wait_stream(self._prefetch_stream)

k_out = self._device_key_cache[layer_idx & 1]
v_out = self._device_value_cache[layer_idx & 1]

self._prefetch_layer(layer_idx + 1)

cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
if cache_position is None:
k_out.copy_(key_states)
v_out.copy_(value_states)

# Copy the values to the offloaded device as well.
if layer_idx == 0:
self.key_cache[layer_idx].copy_(key_states.to(self.offload_device))
self.value_cache[layer_idx].copy_(value_states.to(self.offload_device))
else:
# Note: here we use `tensor.index_copy_(dim, index, tensor)`, which is equivalent to
# `tensor[:, :, index] = tensor`, but the former is compile-friendly and explicitly
# performs an in-place operation, avoiding extra copies and using less memory.
try:
k_out.index_copy_(2, cache_position, key_states)
v_out.index_copy_(2, cache_position, value_states)
except NotImplementedError:
# The operator 'aten::index_copy.out' is not currently implemented for the MPS
# device.
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states

# Copy the values to the offloaded device as well.
if layer_idx != 0:
cache_position = cache_position.to(self.offload_device)
key_states = key_states.to(self.offload_device)
value_states = value_states.to(self.offload_device)

try:
self.key_cache[layer_idx].index_copy_(2, cache_position, key_states)
self.value_cache[layer_idx].index_copy_(2, cache_position, value_states)
except NotImplementedError:
# The operator 'aten::index_copy.out' is not currently implemented for the MPS
# device.
self.key_cache[layer_idx][:, :, cache_position] = key_states
self.value_cache[layer_idx][:, :, cache_position] = value_states

return k_out, v_out

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model."""

# TODO(gante): Remove this.
return self._seen_tokens

def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states."""

return self.max_cache_len

def reset(self) -> None:
"""Resets the cache values while preserving the objects."""

# For backwards compatibility.
# TODO(gante): Remove this.
self._seen_tokens = 0

# Zero out cache.
for layer_idx in range(len(self.key_cache)):
# In-place ops prevent breaking the static address.
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()

@property
def seen_tokens(self) -> int:
# For backwards compatibility.
# TODO(gante): Remove this.
return self._seen_tokens

def _create_key_value_cache_tensors(
self, shape: Tuple[int, ...], device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Creates K/V cache tensors on a device. Pins memory for CPU tensors. Marks them as static
addresses for non-CPU tensors.
Args:
shape (`Tuple[int, ...]`): Shape.
device (`torch.device`): Device.
Returns:
Key and value cache tensors as a tuple.
"""

is_cpu_device = device == torch.device("cpu")

key_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device)
value_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device)

# Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
# preventing compiled graph breaks when updating the cache.
torch._dynamo.mark_static_address(key_cache)
torch._dynamo.mark_static_address(value_cache)

return key_cache, value_cache

def _prefetch_layer(self, layer_idx: int) -> None:
"""Prefetch a layer to the device. Needs to be called in order of layer indices."""

# Don't fetch layers that do not exist.
if layer_idx >= len(self.key_cache):
return

# Alternate between two on-device caches.
if self._prefetch_stream is not None:
with torch.cuda.stream(self._prefetch_stream):
self._prefetch_layer_in_context(layer_idx)
else:
self._prefetch_layer_in_context(layer_idx)

def _prefetch_layer_in_context(self, layer_idx: int) -> None:
"""Performs the actual copy of the layer to device cache."""

self._device_key_cache[layer_idx & 1].copy_(self.key_cache[layer_idx], non_blocking=True)
self._device_value_cache[layer_idx & 1].copy_(self.value_cache[layer_idx], non_blocking=True)
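
As a side note on the design: `update()` and `_prefetch_layer()` above implement a double-buffering scheme, alternating between two on-device buffers (`layer_idx & 1`) and copying the next layer's offloaded tensors on a separate CUDA stream while the current layer is being processed. A standalone sketch of the same idea, with hypothetical tensors and a placeholder `work_fn`, might look like this:

```python
import torch

def process_layers(cpu_layers, device, work_fn):
    """Overlap host-to-device copies with per-layer work using two device buffers."""
    stream = torch.cuda.Stream() if device.type == "cuda" else None
    # Two reusable buffers: layer `i` lands in `buffers[i & 1]`.
    buffers = [torch.empty_like(cpu_layers[0], device=device) for _ in range(2)]

    def prefetch(idx):
        if idx >= len(cpu_layers):
            return
        if stream is not None:
            with torch.cuda.stream(stream):
                buffers[idx & 1].copy_(cpu_layers[idx], non_blocking=True)
        else:
            buffers[idx & 1].copy_(cpu_layers[idx], non_blocking=True)

    prefetch(0)
    for idx in range(len(cpu_layers)):
        if stream is not None:
            # Make sure the copy for this layer finished before reading the buffer.
            torch.cuda.default_stream(device).wait_stream(stream)
        prefetch(idx + 1)  # Start copying the next layer while this one is processed.
        work_fn(buffers[idx & 1])
```

Pinned CPU memory, as allocated in `_create_key_value_cache_tensors` above, is what makes the `non_blocking=True` copies truly asynchronous.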
2 changes: 2 additions & 0 deletions src/transformers/generation/utils.py
@@ -33,6 +33,7 @@
HybridCache,
MambaCache,
OffloadedCache,
OffloadedStaticCache,
QuantizedCacheConfig,
QuantoQuantizedCache,
SlidingWindowCache,
@@ -119,6 +120,7 @@

NEED_SETUP_CACHE_CLASSES_MAPPING = {
"static": StaticCache,
"offloaded_static": OffloadedStaticCache,
"sliding_window": SlidingWindowCache,
"hybrid": HybridCache,
"mamba": MambaCache,