From ac97a13a831de6debda52e6fdb8c1bf9366be57c Mon Sep 17 00:00:00 2001
From: Chengjie Li <109656400+ChengjieLi28@users.noreply.github.com>
Date: Sat, 8 Feb 2025 17:06:47 +0800
Subject: [PATCH] BUG: Use `Cache` class instead of raw `tuple` for
 transformers continuous batching, compatible with latest `transformers`
 (#2820)

---
 xinference/core/scheduler.py               | 11 +++-----
 xinference/model/llm/transformers/utils.py | 33 ++++++++++++++--------
 2 files changed, 26 insertions(+), 18 deletions(-)

diff --git a/xinference/core/scheduler.py b/xinference/core/scheduler.py
index 8b91855daa..703bf0a693 100644
--- a/xinference/core/scheduler.py
+++ b/xinference/core/scheduler.py
@@ -269,16 +269,13 @@ def get_generate_configs(
     )
 
 
-def _get_valid_batch_kv_cache(data, skipped_indexes: Set[int]):
-    from transformers.cache_utils import DynamicCache
-
-    cache = DynamicCache.from_legacy_cache(data)
+def _get_valid_batch_kv_cache(cache, skipped_indexes: Set[int]):
     batch_size = cache.key_cache[0].shape[0]
     batch_slices = [num for num in range(batch_size) if num not in skipped_indexes]
     for idx in range(len(cache)):
-        cache.key_cache[idx] = cache.key_cache[idx][batch_slices, ::]
-        cache.value_cache[idx] = cache.value_cache[idx][batch_slices, ::]
-    return cache.to_legacy_cache()
+        cache.key_cache[idx] = cache.key_cache[idx][batch_slices, ::].contiguous()
+        cache.value_cache[idx] = cache.value_cache[idx][batch_slices, ::].contiguous()
+    return cache
 
 
 class SchedulerActor(xo.StatelessActor):
diff --git a/xinference/model/llm/transformers/utils.py b/xinference/model/llm/transformers/utils.py
index 614df6256d..ea81770307 100644
--- a/xinference/model/llm/transformers/utils.py
+++ b/xinference/model/llm/transformers/utils.py
@@ -193,16 +193,14 @@ def _get_pad_param(seq_len_idx: int, pad_len: int) -> Tuple:
 
 def _merge_kv_cache(
     xinf_model_obj: "PytorchModel",
-    past_kv: Tuple[Tuple[torch.Tensor]],
-    new_kv: Tuple[Tuple[torch.Tensor]],
-):
+    past_cache: DynamicCache,
+    new_cache: DynamicCache,
+) -> DynamicCache:
     from torch.nn.functional import pad
 
     _, seq_len_idx = xinf_model_obj.get_batch_size_and_seq_len_indexes_from_kv()
-    past_cache = DynamicCache.from_legacy_cache(past_kv)
-    new_cache = DynamicCache.from_legacy_cache(new_kv)
-    past_seq_len = past_kv[0][0].shape[seq_len_idx]
-    new_seq_len = new_kv[0][0].shape[seq_len_idx]
+    past_seq_len = past_cache[0][0].shape[seq_len_idx]
+    new_seq_len = new_cache[0][0].shape[seq_len_idx]
     if past_seq_len != new_seq_len:
         padding_target = new_cache if past_seq_len > new_seq_len else past_cache
         padding_len = abs(past_seq_len - new_seq_len)
@@ -219,8 +217,12 @@ def _merge_kv_cache(
     for idx in range(len(past_cache)):
         k1, k2 = new_cache.key_cache[idx], past_cache.key_cache[idx]
         v1, v2 = new_cache.value_cache[idx], past_cache.value_cache[idx]
-        ret_kv.update(torch.cat((k1, k2), 0), torch.cat((v1, v2), 0), idx)
-    return ret_kv.to_legacy_cache()
+        ret_kv.update(
+            torch.cat((k1, k2), 0).contiguous(),
+            torch.cat((v1, v2), 0).contiguous(),
+            idx,
+        )
+    return ret_kv
 
 
 def get_batch_size_and_seq_len_from_kv_cache(kv, xinf_model_obj: "PytorchModel"):
@@ -228,6 +230,15 @@ def get_batch_size_and_seq_len_from_kv_cache(kv, xinf_model_obj: "PytorchModel")
     return kv[0][0].shape[bs_idx], kv[0][0].shape[seq_len_idx] + 1
 
 
+def convert_to_cache_cls(cache) -> DynamicCache:
+    """
+    Compatible with some old models
+    """
+    if isinstance(cache, tuple):
+        return DynamicCache.from_legacy_cache(cache)
+    return cache
+
+
 @torch.inference_mode()
 def _batch_inference_one_step_internal(
     xinf_model_obj: "PytorchModel",
@@ -269,7 +280,7 @@ def _batch_inference_one_step_internal(
 
         out = model(**prefill_kws, use_cache=True)
         logits = out.logits
-        past_key_values = out.past_key_values
+        past_key_values = convert_to_cache_cls(out.past_key_values)
 
         for i, r in enumerate(prefill_reqs):
             (
@@ -317,7 +328,7 @@ def _batch_inference_one_step_internal(
             )
             out = model(**inf_kws, use_cache=True, past_key_values=past_key_values)
             logits = out.logits
-            past_key_values = out.past_key_values
+            past_key_values = convert_to_cache_cls(out.past_key_values)
 
             for i, r in enumerate(valid_req_list):
                 (
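
Illustrative sketch (not part of the patch; tensor shapes are made up and
assume the common (batch, num_heads, seq_len, head_dim) KV layout): it shows
what the new convert_to_cache_cls helper does for models that still return
legacy tuple caches, and how _get_valid_batch_kv_cache now slices finished
requests out of a DynamicCache directly instead of converting back to a tuple.

    import torch
    from transformers.cache_utils import DynamicCache

    def convert_to_cache_cls(cache):
        # Same logic as the helper added in this patch: older models may still
        # return the legacy tuple-of-tuples format from past_key_values.
        if isinstance(cache, tuple):
            return DynamicCache.from_legacy_cache(cache)
        return cache

    # A fake 2-layer legacy cache: one (key, value) tensor pair per layer,
    # batch size 4, 8 heads, sequence length 16, head dim 64.
    legacy = tuple(
        (torch.zeros(4, 8, 16, 64), torch.zeros(4, 8, 16, 64)) for _ in range(2)
    )
    cache = convert_to_cache_cls(legacy)
    assert len(cache) == 2                   # one entry per layer
    assert cache.key_cache[0].shape[0] == 4  # batch dimension preserved

    # Drop finished requests (e.g. batch indexes 1 and 3), mirroring
    # _get_valid_batch_kv_cache: plain index slicing on every layer.
    skipped = {1, 3}
    keep = [i for i in range(cache.key_cache[0].shape[0]) if i not in skipped]
    for idx in range(len(cache)):
        cache.key_cache[idx] = cache.key_cache[idx][keep, ::].contiguous()
        cache.value_cache[idx] = cache.value_cache[idx][keep, ::].contiguous()
    assert cache.key_cache[0].shape[0] == 2  # two requests remain in the batch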