Commit ac97a13
BUG: Use Cache class instead of raw tuple for transformers continuous batching, compatible with latest `transformers` (#2820)
ChengjieLi28 authored Feb 8, 2025
1 parent 0129b84 commit ac97a13
Showing 2 changed files with 26 additions and 18 deletions.
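For context: recent `transformers` releases expect `past_key_values` to be a `Cache` object (e.g. `DynamicCache`) rather than the legacy tuple of per-layer `(key, value)` tensors, which is why this commit drops the tuple round-trips. A minimal sketch of the two representations, not part of the commit; tensor shapes are invented for illustration:

    import torch
    from transformers.cache_utils import DynamicCache

    # Legacy format: one (key, value) pair per layer,
    # each tensor shaped [batch, num_heads, seq_len, head_dim].
    legacy = tuple(
        (torch.zeros(2, 4, 8, 16), torch.zeros(2, 4, 8, 16)) for _ in range(3)
    )

    cache = DynamicCache.from_legacy_cache(legacy)  # wrap the raw tuple
    print(len(cache), cache.key_cache[0].shape)     # 3 torch.Size([2, 4, 8, 16])
    legacy_again = cache.to_legacy_cache()          # convert back if ever needed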
11 changes: 4 additions & 7 deletions xinference/core/scheduler.py
@@ -269,16 +269,13 @@ def get_generate_configs(
     )
 
 
-def _get_valid_batch_kv_cache(data, skipped_indexes: Set[int]):
-    from transformers.cache_utils import DynamicCache
-
-    cache = DynamicCache.from_legacy_cache(data)
+def _get_valid_batch_kv_cache(cache, skipped_indexes: Set[int]):
     batch_size = cache.key_cache[0].shape[0]
     batch_slices = [num for num in range(batch_size) if num not in skipped_indexes]
     for idx in range(len(cache)):
-        cache.key_cache[idx] = cache.key_cache[idx][batch_slices, ::]
-        cache.value_cache[idx] = cache.value_cache[idx][batch_slices, ::]
-    return cache.to_legacy_cache()
+        cache.key_cache[idx] = cache.key_cache[idx][batch_slices, ::].contiguous()
+        cache.value_cache[idx] = cache.value_cache[idx][batch_slices, ::].contiguous()
+    return cache
 
 
 class SchedulerActor(xo.StatelessActor):
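The updated `_get_valid_batch_kv_cache` now receives and returns a `Cache` object directly, slicing finished requests out of the batch dimension in place. A hypothetical walk-through of that slicing, with batch size, shapes, and skipped indexes invented for illustration:

    import torch
    from transformers.cache_utils import DynamicCache

    # Two layers, batch of 4 requests, tensors shaped [batch, num_heads, seq_len, head_dim].
    cache = DynamicCache.from_legacy_cache(
        tuple((torch.zeros(4, 2, 6, 8), torch.zeros(4, 2, 6, 8)) for _ in range(2))
    )
    skipped = {1, 3}  # pretend requests 1 and 3 have finished

    keep = [i for i in range(cache.key_cache[0].shape[0]) if i not in skipped]
    for idx in range(len(cache)):
        cache.key_cache[idx] = cache.key_cache[idx][keep, ::].contiguous()
        cache.value_cache[idx] = cache.value_cache[idx][keep, ::].contiguous()

    print(cache.key_cache[0].shape)  # torch.Size([2, 2, 6, 8]); rows 0 and 2 remain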
33 changes: 22 additions & 11 deletions xinference/model/llm/transformers/utils.py
@@ -193,16 +193,14 @@ def _get_pad_param(seq_len_idx: int, pad_len: int) -> Tuple:
 
 def _merge_kv_cache(
     xinf_model_obj: "PytorchModel",
-    past_kv: Tuple[Tuple[torch.Tensor]],
-    new_kv: Tuple[Tuple[torch.Tensor]],
-):
+    past_cache: DynamicCache,
+    new_cache: DynamicCache,
+) -> DynamicCache:
     from torch.nn.functional import pad
 
     _, seq_len_idx = xinf_model_obj.get_batch_size_and_seq_len_indexes_from_kv()
-    past_cache = DynamicCache.from_legacy_cache(past_kv)
-    new_cache = DynamicCache.from_legacy_cache(new_kv)
-    past_seq_len = past_kv[0][0].shape[seq_len_idx]
-    new_seq_len = new_kv[0][0].shape[seq_len_idx]
+    past_seq_len = past_cache[0][0].shape[seq_len_idx]
+    new_seq_len = new_cache[0][0].shape[seq_len_idx]
     if past_seq_len != new_seq_len:
         padding_target = new_cache if past_seq_len > new_seq_len else past_cache
         padding_len = abs(past_seq_len - new_seq_len)
@@ -219,15 +217,28 @@ def _merge_kv_cache(
     for idx in range(len(past_cache)):
         k1, k2 = new_cache.key_cache[idx], past_cache.key_cache[idx]
         v1, v2 = new_cache.value_cache[idx], past_cache.value_cache[idx]
-        ret_kv.update(torch.cat((k1, k2), 0), torch.cat((v1, v2), 0), idx)
-    return ret_kv.to_legacy_cache()
+        ret_kv.update(
+            torch.cat((k1, k2), 0).contiguous(),
+            torch.cat((v1, v2), 0).contiguous(),
+            idx,
+        )
+    return ret_kv
 
 
 def get_batch_size_and_seq_len_from_kv_cache(kv, xinf_model_obj: "PytorchModel"):
     bs_idx, seq_len_idx = xinf_model_obj.get_batch_size_and_seq_len_indexes_from_kv()
     return kv[0][0].shape[bs_idx], kv[0][0].shape[seq_len_idx] + 1
 
 
+def convert_to_cache_cls(cache) -> DynamicCache:
+    """
+    Compatible with some old models
+    """
+    if isinstance(cache, tuple):
+        return DynamicCache.from_legacy_cache(cache)
+    return cache
+
+
 @torch.inference_mode()
 def _batch_inference_one_step_internal(
     xinf_model_obj: "PytorchModel",
@@ -269,7 +280,7 @@ def _batch_inference_one_step_internal(
         out = model(**prefill_kws, use_cache=True)
 
         logits = out.logits
-        past_key_values = out.past_key_values
+        past_key_values = convert_to_cache_cls(out.past_key_values)
 
         for i, r in enumerate(prefill_reqs):
             (
@@ -317,7 +328,7 @@ def _batch_inference_one_step_internal(
         )
         out = model(**inf_kws, use_cache=True, past_key_values=past_key_values)
         logits = out.logits
-        past_key_values = out.past_key_values
+        past_key_values = convert_to_cache_cls(out.past_key_values)
 
         for i, r in enumerate(valid_req_list):
             (
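The new `convert_to_cache_cls` helper exists because, per its docstring, some older model implementations still return the legacy tuple from `forward`, while models on current `transformers` return a `Cache` instance. A small sketch of both paths, separate from the commit itself; shapes are invented for illustration:

    import torch
    from transformers.cache_utils import DynamicCache

    def convert_to_cache_cls(cache):
        # Wrap a legacy tuple in a DynamicCache; pass Cache objects through unchanged.
        if isinstance(cache, tuple):
            return DynamicCache.from_legacy_cache(cache)
        return cache

    legacy = tuple((torch.zeros(1, 2, 3, 4), torch.zeros(1, 2, 3, 4)) for _ in range(2))
    assert isinstance(convert_to_cache_cls(legacy), DynamicCache)          # old-style model output
    assert isinstance(convert_to_cache_cls(DynamicCache()), DynamicCache)  # new-style model output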
