From b5002242e04634bca7e75cac9df0cdc6c0bf407a Mon Sep 17 00:00:00 2001
From: amumu96 <128140880+amumu96@users.noreply.github.com>
Date: Fri, 23 Aug 2024 18:14:53 +0800
Subject: [PATCH] FEAT: Support LMDeploy for internvl2 and fix missing finish
 reason in internvl stream (#2145)

Co-authored-by: wuzhaoxin <15667065080@162.com>
---
 xinference/core/model.py                      |   5 +-
 xinference/deploy/docker/Dockerfile           |   2 +
 xinference/model/llm/__init__.py              |   4 +
 xinference/model/llm/llm_family.json          |  18 +-
 xinference/model/llm/llm_family.py            |   2 +
 .../model/llm/llm_family_modelscope.json      |  20 +-
 xinference/model/llm/lmdeploy/__init__.py     |   0
 xinference/model/llm/lmdeploy/core.py         | 557 ++++++++++++++++++
 .../model/llm/lmdeploy/tests/__init__.py      |  13 +
 .../model/llm/transformers/intern_vl.py       |  18 +-
 xinference/model/llm/utils.py                 |  11 +-
 xinference/model/llm/vllm/core.py             |   2 +-
 12 files changed, 629 insertions(+), 23 deletions(-)
 create mode 100644 xinference/model/llm/lmdeploy/__init__.py
 create mode 100644 xinference/model/llm/lmdeploy/core.py
 create mode 100644 xinference/model/llm/lmdeploy/tests/__init__.py

diff --git a/xinference/core/model.py b/xinference/core/model.py
index 2650266efe..602f712514 100644
--- a/xinference/core/model.py
+++ b/xinference/core/model.py
@@ -177,6 +177,7 @@ def __init__(
         request_limits: Optional[int] = None,
     ):
         super().__init__()
+        from ..model.llm.lmdeploy.core import LMDeployModel
         from ..model.llm.sglang.core import SGLANGModel
         from ..model.llm.transformers.core import PytorchModel
         from ..model.llm.vllm.core import VLLMModel
@@ -192,7 +193,9 @@ def __init__(
         self._current_generator = lambda: None
         self._lock = (
             None
-            if isinstance(self._model, (PytorchModel, VLLMModel, SGLANGModel))
+            if isinstance(
+                self._model, (PytorchModel, VLLMModel, SGLANGModel, LMDeployModel)
+            )
             else asyncio.locks.Lock()
         )
         self._worker_ref = None
diff --git a/xinference/deploy/docker/Dockerfile b/xinference/deploy/docker/Dockerfile
index 7c5f583af0..1975adb5eb 100644
--- a/xinference/deploy/docker/Dockerfile
+++ b/xinference/deploy/docker/Dockerfile
@@ -30,6 +30,8 @@ RUN pip install --upgrade -i "$PIP_INDEX" pip && \
     pip install "llama-cpp-python>=0.2.82" -i https://abetlen.github.io/llama-cpp-python/whl/cu124 && \
     pip install -i "$PIP_INDEX" --upgrade-strategy only-if-needed -r /opt/inference/xinference/deploy/docker/requirements.txt && \
     pip install -i "$PIP_INDEX" --no-deps sglang && \
+    pip uninstall flashinfer -y && \
+    pip install flashinfer -i https://flashinfer.ai/whl/cu124/torch2.4 && \
     cd /opt/inference && \
     python3 setup.py build_web && \
     git restore . && \
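
For reviewers, a quick sketch of how the engine added by this patch is expected to be exercised from the client side. The endpoint URL, launch parameters, and image URL below are illustrative assumptions, not part of the patch; the prompt is passed as an OpenAI-style content list, matching the Union[str, List[Dict]] signature that async_chat accepts in the new lmdeploy/core.py:

    from xinference.client import Client

    client = Client("http://localhost:9997")  # assumed local supervisor endpoint
    # "LMDEPLOY" matches the key registered in SUPPORTED_ENGINES below
    model_uid = client.launch_model(
        model_name="internvl2",
        model_engine="LMDEPLOY",
        model_format="awq",
        model_size_in_billions=8,
        quantization="Int4",
    )
    model = client.get_model(model_uid)
    content = [
        {"type": "text", "text": "Describe this image."},
        {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
    ]
    print(model.chat(prompt=content))
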
diff --git a/xinference/model/llm/__init__.py b/xinference/model/llm/__init__.py
index 8e6f0d7b5b..9909addebb 100644
--- a/xinference/model/llm/__init__.py
+++ b/xinference/model/llm/__init__.py
@@ -34,6 +34,7 @@
     BUILTIN_MODELSCOPE_LLM_FAMILIES,
     LLAMA_CLASSES,
     LLM_ENGINES,
+    LMDEPLOY_CLASSES,
     MLX_CLASSES,
     SGLANG_CLASSES,
     SUPPORTED_ENGINES,
@@ -113,6 +114,7 @@ def generate_engine_config_by_model_family(model_family):

 def _install():
     from .llama_cpp.core import LlamaCppChatModel, LlamaCppModel
+    from .lmdeploy.core import LMDeployChatModel, LMDeployModel
     from .mlx.core import MLXChatModel, MLXModel
     from .sglang.core import SGLANGChatModel, SGLANGModel
     from .transformers.chatglm import ChatglmPytorchChatModel
@@ -148,6 +150,7 @@ def _install():
     SGLANG_CLASSES.extend([SGLANGModel, SGLANGChatModel])
     VLLM_CLASSES.extend([VLLMModel, VLLMChatModel, VLLMVisionModel])
     MLX_CLASSES.extend([MLXModel, MLXChatModel])
+    LMDEPLOY_CLASSES.extend([LMDeployModel, LMDeployChatModel])
     TRANSFORMERS_CLASSES.extend(
         [
             ChatglmPytorchChatModel,
@@ -176,6 +179,7 @@ def _install():
     SUPPORTED_ENGINES["Transformers"] = TRANSFORMERS_CLASSES
     SUPPORTED_ENGINES["llama.cpp"] = LLAMA_CLASSES
     SUPPORTED_ENGINES["MLX"] = MLX_CLASSES
+    SUPPORTED_ENGINES["LMDEPLOY"] = LMDEPLOY_CLASSES

     json_path = os.path.join(
         os.path.dirname(os.path.abspath(__file__)), "llm_family.json"
diff --git a/xinference/model/llm/llm_family.json b/xinference/model/llm/llm_family.json
index 89c725ef4d..26f1d599a8 100644
--- a/xinference/model/llm/llm_family.json
+++ b/xinference/model/llm/llm_family.json
@@ -7189,15 +7189,6 @@
             "model_id": "OpenGVLab/InternVL2-4B",
             "model_revision": "b50544dafada6c41e80bfde2f57cc9b0140fc21c"
         },
-        {
-            "model_format": "awq",
-            "model_size_in_billions": 4,
-            "quantizations": [
-                "Int4"
-            ],
-            "model_id": "OpenGVLab/InternVL2-8B-AWQ",
-            "model_revision": "9f1a4756b7ae18eb26d8a22b618dfc283e8193b3"
-        },
         {
             "model_format": "pytorch",
             "model_size_in_billions": 8,
@@ -7209,6 +7200,15 @@
             "model_id": "OpenGVLab/InternVL2-8B",
             "model_revision": "3bfd3664dea4f3da628785f5125d30f889701253"
         },
+        {
+            "model_format": "awq",
+            "model_size_in_billions": 8,
+            "quantizations": [
+                "Int4"
+            ],
+            "model_id": "OpenGVLab/InternVL2-8B-AWQ",
+            "model_revision": "9f1a4756b7ae18eb26d8a22b618dfc283e8193b3"
+        },
         {
             "model_format": "pytorch",
             "model_size_in_billions": 26,
diff --git a/xinference/model/llm/llm_family.py b/xinference/model/llm/llm_family.py
index 8618155f28..c2ea4d7b98 100644
--- a/xinference/model/llm/llm_family.py
+++ b/xinference/model/llm/llm_family.py
@@ -271,6 +271,8 @@ def parse_raw(

 MLX_CLASSES: List[Type[LLM]] = []

+LMDEPLOY_CLASSES: List[Type[LLM]] = []
+
 LLM_ENGINES: Dict[str, Dict[str, List[Dict[str, Any]]]] = {}
 SUPPORTED_ENGINES: Dict[str, List[Type[LLM]]] = {}

diff --git a/xinference/model/llm/llm_family_modelscope.json b/xinference/model/llm/llm_family_modelscope.json
index 1cbcbfbfa2..44ac3e7794 100644
--- a/xinference/model/llm/llm_family_modelscope.json
+++ b/xinference/model/llm/llm_family_modelscope.json
@@ -4778,10 +4778,10 @@
             "model_revision": "master"
         },
         {
-            "model_format": "pytorch",
+            "model_format": "awq",
             "model_size_in_billions": 2,
             "quantizations": [
-                "none"
+                "Int4"
             ],
             "model_hub": "modelscope",
             "model_id": "OpenGVLab/InternVL2-2B-AWQ",
@@ -4812,10 +4812,10 @@
             "model_revision": "master"
         },
         {
-            "model_format": "pytorch",
+            "model_format": "awq",
             "model_size_in_billions": 8,
             "quantizations": [
-                "none"
+                "Int4"
             ],
             "model_hub": "modelscope",
             "model_id": "OpenGVLab/InternVL2-8B-AWQ",
@@ -4834,10 +4834,10 @@
"model_revision": "master" }, { - "model_format": "pytorch", + "model_format": "awq", "model_size_in_billions": 26, "quantizations": [ - "none" + "Int4" ], "model_hub": "modelscope", "model_id": "OpenGVLab/InternVL2-26B-AWQ", @@ -4856,10 +4856,10 @@ "model_revision": "master" }, { - "model_format": "pytorch", + "model_format": "awq", "model_size_in_billions": 40, "quantizations": [ - "none" + "Int4" ], "model_hub": "modelscope", "model_id": "OpenGVLab/InternVL2-40B-AWQ", @@ -4878,10 +4878,10 @@ "model_revision": "master" }, { - "model_format": "pytorch", + "model_format": "awq", "model_size_in_billions": 76, "quantizations": [ - "none" + "Int4" ], "model_hub": "modelscope", "model_id": "OpenGVLab/InternVL2-Llama3-76B-AWQ", diff --git a/xinference/model/llm/lmdeploy/__init__.py b/xinference/model/llm/lmdeploy/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/xinference/model/llm/lmdeploy/core.py b/xinference/model/llm/lmdeploy/core.py new file mode 100644 index 0000000000..22fbd53e72 --- /dev/null +++ b/xinference/model/llm/lmdeploy/core.py @@ -0,0 +1,557 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import time +import uuid +from typing import AsyncGenerator, Dict, Iterator, List, Optional, TypedDict, Union + +import torch + +from ....types import ( + ChatCompletion, + ChatCompletionChunk, + ChatCompletionChunkChoice, + ChatCompletionMessage, + Completion, + CompletionChoice, + CompletionUsage, + LoRA, +) +from ..core import LLM +from ..llm_family import LLMFamilyV1, LLMSpecV1 +from ..utils import ChatModelMixin + +logger = logging.getLogger(__name__) + +try: + import lmdeploy # noqa: F401 + + LMDEPLOY_INSTALLED = True +except ImportError: + LMDEPLOY_INSTALLED = False + +LMDEPLOY_SUPPORTED_CHAT_MODELS = ["internvl2"] +LMDEPLOY_MODEL_CHAT_TEMPLATE_NAME = { + "internvl2": "internvl-internlm2", +} + + +class LMDeployModelConfig(TypedDict, total=False): + model_format: Optional[str] + tp: Optional[int] + session_len: Optional[int] + max_batch_size: Optional[int] + cache_max_entry_count: Optional[float] + cache_block_seq_len: Optional[int] + enable_prefix_caching: Optional[bool] + quant_policy: Optional[int] + rope_scaling_factor: Optional[float] + use_logn_attn: Optional[bool] + download_dir: Optional[str] + revision: Optional[str] + max_prefill_token_num: Optional[int] + num_tokens_per_iter: Optional[int] + max_prefill_iters: Optional[int] + + +class LMDeployGenerateConfig(TypedDict, total=False): + n: Optional[int] + max_new_tokens: Optional[int] + top_p: Optional[float] + top_k: Optional[int] + temperature: Optional[float] + repetition_penalty: Optional[float] + ignore_eos: Optional[bool] + random_seed: Optional[int] + stop_words: Optional[List[str]] + bad_words: Optional[List[str]] + min_new_tokens: Optional[int] + skip_special_tokens: Optional[bool] + logprobs: Optional[int] + + +class LMDeployModel(LLM): + def __init__( + self, + model_uid: str, + model_family: "LLMFamilyV1", + model_spec: "LLMSpecV1", + 
+        quantization: str,
+        model_path: str,
+        model_config: Optional[LMDeployModelConfig] = None,
+        peft_model: Optional[List[LoRA]] = None,
+    ):
+        super().__init__(model_uid, model_family, model_spec, quantization, model_path)
+        self._model_config: LMDeployModelConfig = self._sanitize_model_config(
+            model_config
+        )
+        if peft_model is not None:
+            raise ValueError("The LMDEPLOY engine does not support LoRA yet.")
+
+    def _sanitize_model_config(
+        self, model_config: Optional[LMDeployModelConfig]
+    ) -> LMDeployModelConfig:
+        if model_config is None:
+            model_config = LMDeployModelConfig()
+        model_config.setdefault("session_len", 8192)
+        if self.model_spec.model_format == "awq":
+            model_config.setdefault("model_format", "awq")
+        return model_config
+
+    def load(self):
+        try:
+            import lmdeploy  # noqa: F401, F811
+        except ImportError:
+            error_message = "Failed to import module 'lmdeploy'"
+            installation_guide = [
+                "Please make sure 'lmdeploy' is installed. ",
+                "You can install it by `pip install lmdeploy`\n",
+            ]
+
+            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+        raise ValueError("The LMDEPLOY engine does not support generate yet.")
+
+    @classmethod
+    def match(
+        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        return False
+
+    def generate(
+        self,
+        prompt: str,
+        generate_config: Optional[Dict] = None,
+    ) -> Union[Completion, Iterator[ChatCompletionChunk]]:
+        raise NotImplementedError("LMDeploy does not support the generate interface yet.")
+
+
+class LMDeployChatModel(LMDeployModel, ChatModelMixin):
+    def load(self):
+        try:
+            from lmdeploy import (
+                ChatTemplateConfig,
+                TurbomindEngineConfig,
+                VisionConfig,
+                pipeline,
+            )
+        except ImportError:
+            error_message = "Failed to import module 'lmdeploy'"
+            installation_guide = [
+                "Please make sure 'lmdeploy' is installed. ",
+                "You can install it by `pip install lmdeploy`\n",
+            ]
+
+            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+        chat_temp_name = ""
+        family = self.model_family.model_family or self.model_family.model_name
+        for key in LMDEPLOY_MODEL_CHAT_TEMPLATE_NAME.keys():
+            if family in key:
+                chat_temp_name = LMDEPLOY_MODEL_CHAT_TEMPLATE_NAME[key]
+                break
+        if chat_temp_name == "":
+            raise ValueError("Cannot find a matching chat template.")
+
+        chat_template_config = ChatTemplateConfig(chat_temp_name)
+        chat_template_config.meta_instruction = (
+            self.model_family.prompt_style.system_prompt
+        )
+        count = torch.cuda.device_count()
+        if count > 1:
+            self._model_config.setdefault("tp", torch.cuda.device_count())
+
+        self._model = pipeline(
+            self.model_path,
+            chat_template_config=chat_template_config,
+            backend_config=TurbomindEngineConfig(**self._model_config),
+            vision_config=VisionConfig(thread_safe=True),
+        )
+
+    @classmethod
+    def match(
+        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        if llm_spec.model_format == "awq":
+            # Currently, only 4-bit weight quantization is supported for AWQ.
+ if "4" not in quantization: + return False + if llm_family.model_name not in LMDEPLOY_SUPPORTED_CHAT_MODELS: + return False + return LMDEPLOY_INSTALLED + + async def async_chat( + self, + prompt: Union[str, List[Dict]], + system_prompt: Optional[str] = None, + chat_history: Optional[List[ChatCompletionMessage]] = None, + generate_config: Optional[Dict] = None, + ) -> Union[ChatCompletion, AsyncGenerator[ChatCompletionChunk, None]]: + stream = ( + generate_config.get("stream", False) + if isinstance(generate_config, dict) + else False + ) + stream_options = ( + generate_config.get("stream_options", None) + if isinstance(generate_config, dict) + else False + ) + include_usage = ( + stream_options["include_usage"] + if isinstance(stream_options, dict) + else False + ) + + chat_history = chat_history or [] + + if stream: + chunk = self._chat_stream(prompt, chat_history, include_usage) + return self._async_to_chat_completion_chunks(chunk) + else: + chunk = await self._chat(prompt, chat_history) + return self._to_chat_completion(chunk) + + async def _chat_stream(self, prompt, chat_history, include_usage): + from lmdeploy.messages import Response + + prompt_tokens, completion_tokens, total_tokens = 0, 0, 0 + completion_id = str(uuid.uuid1()) + async for output in self._generate( + prompt, + chat_history, + session_id=-1, + stream_response=True, + ): + new_text = output.text if isinstance(output, Response) else output.response + + completion_choice = ChatCompletionChunkChoice( + text=new_text, + index=0, + logprobs=None, + finish_reason=output.finish_reason, + ) + chunk = ChatCompletionChunk( + id=completion_id, + object="chat.completion", + created=int(time.time()), + model=self.model_uid, + choices=[completion_choice], + ) + prompt_tokens = output.input_token_len + completion_tokens = output.generate_token_len + total_tokens = prompt_tokens + completion_tokens + completion_usage = CompletionUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + ) + chunk["usage"] = completion_usage + print(chunk) + yield chunk + if include_usage: + chunk = ChatCompletionChunk( + id=completion_id, + object="chat.completion", + created=int(time.time()), + model=self.model_uid, + choices=[], + ) + chunk["usage"] = CompletionUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + ) + yield chunk + + async def _chat(self, prompt, chat_history): + from lmdeploy.messages import Response + + response, finish_reason = "", "" + prompt_tokens, completion_tokens, total_tokens = 0, 0, 0 + async for output in self._generate( + prompt, + chat_history, + session_id=-1, + stream_response=False, + ): + response += output.text if isinstance(output, Response) else output.response + prompt_tokens = output.input_token_len + completion_tokens = output.generate_token_len + total_tokens = output.input_token_len + output.generate_token_len + finish_reason = output.finish_reason + + chunk = ChatCompletion( + id=str(uuid.uuid1()), + object="chat.completion", + created=int(time.time()), + model=self.model_uid, + choices=[ + CompletionChoice( + index=0, text=response, finish_reason=finish_reason, logprobs=None + ) + ], + usage=CompletionUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + ), + ) + return chunk + + # copy from lmdeploy + # Reference: lmdeploy.serve.async_engine.py + async def _generate( + self, + prompt, + chat_history, + session_id: int, + generate_config: 
+        tools: Optional[List[object]] = None,
+        stream_response: bool = True,
+        sequence_start: bool = True,
+        sequence_end: bool = True,  # no interactive mode by default
+        step: int = 0,
+        do_preprocess: bool = False,
+        adapter_name: Optional[str] = None,
+        **kwargs,
+    ):
+        import random
+
+        from lmdeploy.messages import EngineGenerationConfig, GenerationConfig
+        from lmdeploy.serve.async_engine import GenOut
+        from lmdeploy.tokenizer import DetokenizeState
+
+        session_id = -1
+
+        if str(session_id) not in self._model.id2step:
+            self._model.id2step[str(session_id)] = 0
+        if generate_config is None:
+            generate_config = GenerationConfig()
+        if type(generate_config) is GenerationConfig:
+            generate_config = EngineGenerationConfig.From(
+                generate_config, self._model.tokenizer
+            )
+        if generate_config.stop_words is None:  # type: ignore
+            generate_config.stop_words = self._model.stop_words  # type: ignore
+        if generate_config.random_seed is None and sequence_start:  # type: ignore
+            generate_config.random_seed = random.getrandbits(64)  # type: ignore
+        if generate_config.n > 1:  # type: ignore
+            logger.warning(
+                f"n({generate_config.n}) > 1 hasn't been supported yet. "  # type: ignore
+                f"Fallback to 1"
+            )
+            generate_config.n = 1  # type: ignore
+
+        prompt_input = await self._get_prompt_input(prompt, chat_history)
+        prompt = prompt_input["prompt"]
+        input_ids = prompt_input["input_ids"]
+        finish_reason = None
+        logger.info(
+            f"prompt={prompt!r}, "
+            f"gen_config={generate_config}, "
+            f"prompt_token_id={input_ids}, "
+            f"adapter_name={adapter_name}."
+        )
+        logger.info(
+            f"session_id={session_id}, "  # type: ignore
+            f"history_tokens={self._model.id2step[str(session_id)]}, "
+            f"input_tokens={len(input_ids)}, "
+            f"max_new_tokens={generate_config.max_new_tokens}, "
+            f"seq_start={sequence_start}, seq_end={sequence_end}, "
+            f"step={step}, prep={do_preprocess}"
+        )
+
+        if generate_config.max_new_tokens is None:  # type: ignore
+            # for interactive endpoint, will try maximum possible token num
+            generate_config.max_new_tokens = max(  # type: ignore
+                128,
+                self._model.session_len
+                - self._model.id2step[str(session_id)]
+                - len(input_ids),
+            )
+        elif (
+            self._model.id2step[str(session_id)]
+            + len(input_ids)
+            + generate_config.max_new_tokens  # type: ignore
+            > self._model.session_len
+        ):
+            generate_config.max_new_tokens = max(  # type: ignore
+                self._model.session_len
+                - self._model.id2step[str(session_id)]
+                - len(input_ids),
+                128,
+            )
+            logger.error(f"Truncate max_new_tokens to {generate_config.max_new_tokens}")  # type: ignore
+
+        if (
+            self._model.id2step[str(session_id)]
+            + len(input_ids)
+            + generate_config.max_new_tokens  # type: ignore
+            > self._model.session_len
+        ):
+            logger.error(f"run out of tokens. session_id={session_id}.")
+            yield GenOut(
+                "", self._model.id2step[str(session_id)], len(input_ids), 0, "length"
+            )
+            if sequence_end is True and sequence_start is False:
+                await self._model.end_session(session_id)
+        else:
+            generator = await self._model.get_generator(False, session_id)
+            async with self._model.safe_run(session_id):
+                state = DetokenizeState(len(input_ids))
+                start_ids_offset = state.ids_offset
+                response = ""
+                async for outputs in generator.async_stream_infer(
+                    session_id=session_id,
+                    **prompt_input,
+                    gen_config=generate_config,
+                    adapter_name=adapter_name,
+                    stream_output=stream_response,
+                    sequence_start=sequence_start,
+                    sequence_end=sequence_end,
+                    step=self._model.id2step[str(session_id)],
+                ):
+                    # decode res
+                    res, tokens = (
+                        input_ids + outputs.token_ids,
+                        outputs.num_token,
+                    )  # noqa
+                    if len(res) <= state.ids_offset:
+                        continue
+
+                    ids_offset = state.ids_offset
+                    response, state = self._model.tokenizer.detokenize_incrementally(
+                        res,
+                        state,
+                        skip_special_tokens=generate_config.skip_special_tokens,  # type: ignore
+                    )
+
+                    res = res[ids_offset:]
+                    logprobs = None
+                    if outputs.logprobs:
+                        log_offset = ids_offset - start_ids_offset
+                        logprobs = outputs.logprobs[log_offset:]
+
+                    # response, history token len,
+                    # input token len, gen token len
+                    yield GenOut(
+                        response,
+                        self._model.id2step[str(session_id)],
+                        len(input_ids),
+                        tokens,
+                        finish_reason,
+                        res,
+                        logprobs,
+                    )
+
+                finish_reason = (
+                    "length" if tokens >= generate_config.max_new_tokens else "stop"  # type: ignore
+                )
+                # utf-8 char at the end means it's a potential unfinished
+                # byte sequence
+                if not response.endswith("�"):
+                    response = ""  # avoid returning the last response twice
+                yield GenOut(
+                    response,
+                    self._model.id2step[str(session_id)],
+                    len(input_ids),
+                    tokens,
+                    finish_reason,
+                )
+                # update step
+                self._model.id2step[str(session_id)] += len(input_ids) + tokens
+                if sequence_end:
+                    self._model.id2step[str(session_id)] = 0
+                # manually end pytorch session
+                # TODO modify pytorch or turbomind api
+                if self._model.backend == "pytorch" and sequence_end:
+                    await self._model.end_session(session_id)
+
+    # copied from lmdeploy
+    # Reference: lmdeploy.serve.vl_async_engine.py
+    async def _get_prompt_input(
+        self,
+        prompt: Union[str, List[Dict]],
+        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        sequence_start: bool = True,
+        tools: Optional[List[object]] = None,
+        **kwargs,
+    ):
+        """get input_ids, embeddings and offsets."""
+        IMAGE_TOKEN = "<IMAGE_TOKEN>"
+        IMAGE_DUMMY_TOKEN_INDEX = 0
+        import numpy as np
+
+        assert self.model_family.prompt_style is not None
+        prompt_style = self.model_family.prompt_style.copy()
+        chat_history = chat_history or []
+
+        decorated, _ = self.get_prompt(prompt, chat_history, prompt_style)  # type: ignore
+        chat_history.append(ChatCompletionMessage(role="user", content=prompt))  # type: ignore
+        prompt = chat_history  # type: ignore
+
+        decorated = decorated.replace("<image>", IMAGE_TOKEN)
+
+        segs = decorated.split(IMAGE_TOKEN)
+
+        results = {}
+        input_ids = []  # type: ignore
+        if len(segs) > 1:
+            images = await self._model.vl_prompt_template.async_collect_pil_images(
+                prompt
+            )
+
+            features = await self._model.vl_encoder.async_infer(images)
+
+            from lmdeploy.vl.templates import MiniCPMVTempateWrapper
+
+            if isinstance(self._model.vl_prompt_template, MiniCPMVTempateWrapper):
+                (
+                    decorated,
+                    features,
+                ) = self._model.vl_prompt_template.update_image_token(  # noqa: E501
+                    decorated, features
+                )
+                segs = decorated.split(IMAGE_TOKEN)
+
+            features = [x.cpu().numpy() for x in features]
+            input_ids = []
+            begins = []
+            ends = []
+            if len(segs) != len(features) + 1:
+                logger.error(
+                    f"the number of {IMAGE_TOKEN} is not equal "
+                    f"to input images, {len(segs) - 1} vs {len(features)}"
+                )
+                features = features[: len(segs) - 1]
+            for i, seg in enumerate(segs):
+                if i > 0 and i <= len(features):
+                    image_dim = features[i - 1].shape[0]
+                    begins.append(len(input_ids))
+                    ends.append(begins[-1] + image_dim)
+                    input_ids.extend([IMAGE_DUMMY_TOKEN_INDEX] * image_dim)
+                seg_ids = self._model.tokenizer.encode(
+                    seg, add_bos=((i == 0) and sequence_start)
+                )
+                input_ids.extend(seg_ids)
+            ranges = np.stack([begins, ends], axis=1).tolist()
+            results["input_embeddings"] = features
+            results["input_embedding_ranges"] = ranges
+        else:
+            input_ids = self._model.tokenizer.encode(decorated, add_bos=sequence_start)
+
+        results["input_ids"] = input_ids
+        results["prompt"] = decorated
+
+        return results
diff --git a/xinference/model/llm/lmdeploy/tests/__init__.py b/xinference/model/llm/lmdeploy/tests/__init__.py
new file mode 100644
index 0000000000..37f6558d95
--- /dev/null
+++ b/xinference/model/llm/lmdeploy/tests/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/xinference/model/llm/transformers/intern_vl.py b/xinference/model/llm/transformers/intern_vl.py
index dc90cd0516..02632e2af8 100644
--- a/xinference/model/llm/transformers/intern_vl.py
+++ b/xinference/model/llm/transformers/intern_vl.py
@@ -507,7 +507,23 @@ def _generate_stream(self, generate_kwargs, input_ids, include_usage):
             )
             chunk["usage"] = completion_usage
             yield chunk
-
+        completion_choice = CompletionChoice(
+            text="", index=0, logprobs=None, finish_reason="stop"
+        )
+        chunk = CompletionChunk(
+            id=completion_id,
+            object="text_completion",
+            created=int(time.time()),
+            model=self.model_uid,
+            choices=[completion_choice],
+        )
+        completion_usage = CompletionUsage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
+        )
+        chunk["usage"] = completion_usage
+        yield chunk
         if include_usage:
             chunk = CompletionChunk(
                 id=completion_id,
diff --git a/xinference/model/llm/utils.py b/xinference/model/llm/utils.py
index 2caaa83c71..7f203d0c21 100644
--- a/xinference/model/llm/utils.py
+++ b/xinference/model/llm/utils.py
@@ -459,7 +459,16 @@ def get_role(role_name: str):
                 role = get_role(message["role"])
                 content = message["content"]
                 if isinstance(content, str):
-                    ret += role + "\n" + content + prompt_style.intra_message_sep + "\n"
+                    if content:
+                        ret += (
+                            role
+                            + "\n"
+                            + content
+                            + prompt_style.intra_message_sep
+                            + "\n"
+                        )
+                    else:
+                        ret += role + "\n"
                 elif isinstance(content, list):
                     text = ""
                     image_urls = []
diff --git a/xinference/model/llm/vllm/core.py b/xinference/model/llm/vllm/core.py
index 49a324e5d9..4b009aa646 100644
--- a/xinference/model/llm/vllm/core.py
+++ b/xinference/model/llm/vllm/core.py
@@ -721,7 +721,7 @@ async def async_chat(
         prompt_style = self.model_family.prompt_style.copy()
         chat_history = chat_history or []
         prompt, images = self.get_prompt(prompt, chat_history, prompt_style)
-        logger.info(f"messages:{prompt}")
+
         if len(images) == 0:
             inputs = {
                 "prompt": prompt,
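
For context, LMDeployChatModel.load() in the new lmdeploy/core.py is a thin wrapper around lmdeploy's pipeline API. Below is a minimal standalone sketch of the same call path, useful when reproducing engine behaviour outside Xinference; the local model path and image URL are placeholders, and the backend options mirror what _sanitize_model_config() sets for an AWQ spec:

    from lmdeploy import ChatTemplateConfig, TurbomindEngineConfig, VisionConfig, pipeline
    from lmdeploy.vl import load_image

    # AWQ weights plus the default 8192-token session, as in _sanitize_model_config()
    pipe = pipeline(
        "/path/to/InternVL2-8B-AWQ",  # placeholder model path
        chat_template_config=ChatTemplateConfig("internvl-internlm2"),
        backend_config=TurbomindEngineConfig(model_format="awq", session_len=8192),
        vision_config=VisionConfig(thread_safe=True),
    )
    image = load_image("https://example.com/cat.png")  # placeholder image URL
    response = pipe(("Describe this image.", image))
    print(response.text, response.finish_reason)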