From 81101518e139fa83c92c7ebe38c284fc895f4835 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20L=C3=B6wenfels?= <282+dfl@users.noreply.github.com> Date: Mon, 3 Jun 2024 14:45:27 +0100 Subject: [PATCH 1/6] WIP; adding mac support --- .gitignore | 3 ++ gradio_app.py | 71 ++++++++++++++++++++++++---------- lib_omost/memory_management.py | 14 +++++-- requirements.txt | 3 +- 4 files changed, 66 insertions(+), 25 deletions(-) diff --git a/.gitignore b/.gitignore index 594e5e6..37e8560 100644 --- a/.gitignore +++ b/.gitignore @@ -162,3 +162,6 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. .idea/ + +# Mac finder directory settings +.DS_Store \ No newline at end of file diff --git a/gradio_app.py b/gradio_app.py index d0eecd8..b8079e8 100644 --- a/gradio_app.py +++ b/gradio_app.py @@ -1,4 +1,6 @@ import os +import platform +is_mac = platform.system() == 'Darwin' os.environ['HF_HOME'] = os.path.join(os.path.dirname(__file__), 'hf_download') HF_TOKEN = None @@ -10,6 +12,8 @@ import numpy as np import gradio as gr import tempfile +if is_mac: + import mlx_lm gradio_temp_dir = os.path.join(tempfile.gettempdir(), 'gradio') os.makedirs(gradio_temp_dir, exist_ok=True) @@ -32,6 +36,9 @@ import lib_omost.canvas as omost_canvas +# https://medium.com/@natsunoyuki/using-civitai-models-with-diffusers-package-45e0c475a67e +# https://huggingface.co/docs/diffusers/en/api/loaders/single_file +# https://github.com/huggingface/diffusers/blob/v0.28.0/scripts/convert_original_stable_diffusion_to_diffusers.py # SDXL @@ -67,25 +74,27 @@ memory_management.unload_all_models([text_encoder, text_encoder_2, vae, unet]) # LLM +if is_mac: + # llm_name = "mlx-community/Phi-3-mini-128k-instruct-8bit" + llm_name = "mlx-community/Meta-Llama-3-8B-4bit" + # llm_name = "mlx-community/dolphin-2.9.1-llama-3-8b-4bit" + llm_model, llm_tokenizer = mlx_lm.load(llm_name) +else: + # llm_name = 'lllyasviel/omost-phi-3-mini-128k-8bits' + llm_name = 'lllyasviel/omost-llama-3-8b-4bits' + # llm_name = 'lllyasviel/omost-dolphin-2.9-llama3-8b-4bits' + + llm_model = AutoModelForCausalLM.from_pretrained( + llm_name, + torch_dtype=torch.bfloat16, # This is computation type, not load/memory type. The loading quant type is baked in config. + token=HF_TOKEN + ) + llm_tokenizer = AutoTokenizer.from_pretrained( + llm_name, + token=HF_TOKEN + ) -# llm_name = 'lllyasviel/omost-phi-3-mini-128k-8bits' -llm_name = 'lllyasviel/omost-llama-3-8b-4bits' -# llm_name = 'lllyasviel/omost-dolphin-2.9-llama3-8b-4bits' - -llm_model = AutoModelForCausalLM.from_pretrained( - llm_name, - torch_dtype=torch.bfloat16, # This is computation type, not load/memory type. The loading quant type is baked in config. 
- token=HF_TOKEN, - device_map="auto" # This will load model to gpu with an offload system -) - -llm_tokenizer = AutoTokenizer.from_pretrained( - llm_name, - token=HF_TOKEN -) - -memory_management.unload_all_models(llm_model) - + memory_management.unload_all_models(llm_model) @torch.inference_mode() def pytorch2numpy(imgs): @@ -111,6 +120,26 @@ def resize_without_crop(image, target_width, target_height): return np.array(resized_image) +def llm_generate(kwargs): + if not is_mac: return llm_model.generate + return (lambda kwargs: mlx_lm.generate(llm_model, llm_tokenizer, prompt, temp, max_tokens, verbose, formatter, repetition_penalty, repetition_context_size, top_p, logit_bias)) +# generate(model: mlx.nn.layers.base.Module, tokenizer: Union[transformers.tokenization_utils.PreTrainedTokenizer, mlx_lm.tokenizer_utils.TokenizerWrapper], prompt: str, temp: float = 0.0, max_tokens: int = 100, verbose: bool = False, formatter: Optional[Callable] = None, repetition_penalty: Optional[float] = None, repetition_context_size: Optional[int] = None, top_p: float = 1.0, logit_bias: Optional[Dict[int, float]] = None) -> str +# Generate text from the model. + +# Args: +# model (nn.Module): The language model. +# tokenizer (PreTrainedTokenizer): The tokenizer. +# prompt (str): The string prompt. +# temp (float): The temperature for sampling (default 0). +# max_tokens (int): The maximum number of tokens (default 100). +# verbose (bool): If ``True``, print tokens and timing information +# (default ``False``). +# formatter (Optional[Callable]): A function which takes a token and a +# probability and displays it. +# repetition_penalty (float, optional): The penalty factor for repeating tokens. +# repetition_context_size (int, optional): The number of tokens to consider for repetition penalty. 
+ + @torch.inference_mode() def chat_fn(message: str, history: list, seed:int, temperature: float, top_p: float, max_new_tokens: int) -> str: np.random.seed(int(seed)) @@ -126,9 +155,9 @@ def chat_fn(message: str, history: list, seed:int, temperature: float, top_p: fl conversation.append({"role": "user", "content": message}) memory_management.load_models_to_gpu(llm_model) - + input_ids = llm_tokenizer.apply_chat_template( - conversation, return_tensors="pt", add_generation_prompt=True).to(llm_model.device) + conversation, return_tensors="pt", add_generation_prompt=True).to( getattr(llm_model, 'device', memory_management.gpu) ) streamer = TextIteratorStreamer(llm_tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) @@ -158,7 +187,7 @@ def interrupter(): if temperature == 0: generate_kwargs['do_sample'] = False - Thread(target=llm_model.generate, kwargs=generate_kwargs).start() + Thread(target=llm_generate, kwargs=generate_kwargs).start() outputs = [] for text in streamer: diff --git a/lib_omost/memory_management.py b/lib_omost/memory_management.py index 05ed6d0..241734a 100644 --- a/lib_omost/memory_management.py +++ b/lib_omost/memory_management.py @@ -1,13 +1,19 @@ import torch from contextlib import contextmanager +import platform +is_mac = platform.system() == 'Darwin' high_vram = False -gpu = torch.device('cuda') +if is_mac: + gpu = torch.device('mps') +else: + gpu = torch.device('cuda') cpu = torch.device('cpu') torch.zeros((1, 1)).to(gpu, torch.float32) -torch.cuda.empty_cache() + +torch.cuda.empty_cache() if not is_mac else torch.mps.empty_cache() models_in_gpu = [] @@ -27,6 +33,8 @@ def movable_bnb_model(m): def load_models_to_gpu(models): + if is_mac: return + global models_in_gpu if not isinstance(models, (tuple, list)): @@ -49,7 +57,7 @@ def load_models_to_gpu(models): print('Load to GPU:', m.__class__.__name__) models_in_gpu = list(set(models_in_gpu + models)) - torch.cuda.empty_cache() + torch.cuda.empty_cache() if not is_mac else torch.mps.empty_cache() return diff --git a/requirements.txt b/requirements.txt index 96681b5..a6c9891 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,8 @@ diffusers==0.28.0 transformers==4.41.1 gradio==4.31.5 -bitsandbytes==0.43.1 +mlx-lm==0.14.1; sys_platform == 'darwin' +bitsandbytes==0.43.1; sys_platform != 'darwin' accelerate==0.30.1 protobuf==3.20 opencv-python From 351b7d0bf171c308996c1cffb5555b1068b8b0a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20L=C3=B6wenfels?= <282+dfl@users.noreply.github.com> Date: Mon, 3 Jun 2024 17:26:57 +0100 Subject: [PATCH 2/6] update mlx-lm --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index a6c9891..32d7b29 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ diffusers==0.28.0 transformers==4.41.1 gradio==4.31.5 -mlx-lm==0.14.1; sys_platform == 'darwin' +mlx-lm==0.14.2; sys_platform == 'darwin' bitsandbytes==0.43.1; sys_platform != 'darwin' accelerate==0.30.1 protobuf==3.20 From 898880f8eb764293f41e810423379d5a2b6cac48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20L=C3=B6wenfels?= <282+dfl@users.noreply.github.com> Date: Tue, 4 Jun 2024 09:17:20 +0100 Subject: [PATCH 3/6] WIP --- gradio_app.py | 41 ++++++++++++---------------------- lib_omost/memory_management.py | 2 ++ requirements.txt | 2 +- 3 files changed, 17 insertions(+), 28 deletions(-) diff --git a/gradio_app.py b/gradio_app.py index b8079e8..55f6f9e 100644 --- a/gradio_app.py +++ b/gradio_app.py @@ -14,6 +14,8 @@ import 
tempfile if is_mac: import mlx_lm + # from mlx_lm import load, stream_generate + gradio_temp_dir = os.path.join(tempfile.gettempdir(), 'gradio') os.makedirs(gradio_temp_dir, exist_ok=True) @@ -94,7 +96,7 @@ token=HF_TOKEN ) - memory_management.unload_all_models(llm_model) +memory_management.unload_all_models(llm_model) @torch.inference_mode() def pytorch2numpy(imgs): @@ -119,27 +121,6 @@ def resize_without_crop(image, target_width, target_height): resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS) return np.array(resized_image) - -def llm_generate(kwargs): - if not is_mac: return llm_model.generate - return (lambda kwargs: mlx_lm.generate(llm_model, llm_tokenizer, prompt, temp, max_tokens, verbose, formatter, repetition_penalty, repetition_context_size, top_p, logit_bias)) -# generate(model: mlx.nn.layers.base.Module, tokenizer: Union[transformers.tokenization_utils.PreTrainedTokenizer, mlx_lm.tokenizer_utils.TokenizerWrapper], prompt: str, temp: float = 0.0, max_tokens: int = 100, verbose: bool = False, formatter: Optional[Callable] = None, repetition_penalty: Optional[float] = None, repetition_context_size: Optional[int] = None, top_p: float = 1.0, logit_bias: Optional[Dict[int, float]] = None) -> str -# Generate text from the model. - -# Args: -# model (nn.Module): The language model. -# tokenizer (PreTrainedTokenizer): The tokenizer. -# prompt (str): The string prompt. -# temp (float): The temperature for sampling (default 0). -# max_tokens (int): The maximum number of tokens (default 100). -# verbose (bool): If ``True``, print tokens and timing information -# (default ``False``). -# formatter (Optional[Callable]): A function which takes a token and a -# probability and displays it. -# repetition_penalty (float, optional): The penalty factor for repeating tokens. -# repetition_context_size (int, optional): The number of tokens to consider for repetition penalty. 
- - @torch.inference_mode() def chat_fn(message: str, history: list, seed:int, temperature: float, top_p: float, max_new_tokens: int) -> str: np.random.seed(int(seed)) @@ -187,13 +168,19 @@ def interrupter(): if temperature == 0: generate_kwargs['do_sample'] = False - Thread(target=llm_generate, kwargs=generate_kwargs).start() outputs = [] - for text in streamer: - outputs.append(text) - # print(outputs) - yield "".join(outputs), interrupter + if is_mac: + for text in mlx_lm.stream_generate(llm_model, llm_tokenizer, input_ids.cpu().numpy(), temp=temperature, top_p=top_p): + outputs.append(text) + # print(outputs) + yield "".join(outputs), interrupter + else: + Thread(target=llm_model.generate, kwargs=generate_kwargs).start() + for text in streamer: + outputs.append(text) + # print(outputs) + yield "".join(outputs), interrupter return diff --git a/lib_omost/memory_management.py b/lib_omost/memory_management.py index 241734a..6d4c1ae 100644 --- a/lib_omost/memory_management.py +++ b/lib_omost/memory_management.py @@ -62,6 +62,8 @@ def load_models_to_gpu(models): def unload_all_models(extra_models=None): + if is_mac: return + global models_in_gpu if extra_models is None: diff --git a/requirements.txt b/requirements.txt index 32d7b29..65bbfec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ diffusers==0.28.0 transformers==4.41.1 gradio==4.31.5 -mlx-lm==0.14.2; sys_platform == 'darwin' +mlx-lm==0.14.3; sys_platform == 'darwin' bitsandbytes==0.43.1; sys_platform != 'darwin' accelerate==0.30.1 protobuf==3.20 From 3cf1cafa61f338ac3ff33c0d0a5403f3773c2052 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20L=C3=B6wenfels?= <282+dfl@users.noreply.github.com> Date: Thu, 6 Jun 2024 17:04:01 +0100 Subject: [PATCH 4/6] WIP --- gradio_app.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/gradio_app.py b/gradio_app.py index 55f6f9e..1a56506 100644 --- a/gradio_app.py +++ b/gradio_app.py @@ -171,10 +171,31 @@ def interrupter(): outputs = [] if is_mac: - for text in mlx_lm.stream_generate(llm_model, llm_tokenizer, input_ids.cpu().numpy(), temp=temperature, top_p=top_p): - outputs.append(text) - # print(outputs) - yield "".join(outputs), interrupter + # for text in mlx_lm.stream_generate(llm_model, llm_tokenizer, input_ids.cpu().numpy(), temp=temperature, top_p=top_p): + # outputs.append(text) + # # print(outputs) + # yield "".join(outputs), interrupter + + max_tokens = 100 + detokenizer = tokenizer.detokenizer + detokenizer.reset() + prompt_tokens = input_ids.cpu().numpy() + for (token, prob), n in zip( + generate_step(prompt_tokens, llm_model, **generate_kwargs), + range(max_tokens), + ): + if token == llm_tokenizer.eos_token_id: + break + detokenizer.add_token(token) + + # Yield the last segment if streaming + # yield detokenizer.last_segment + outputs.append(detokenizer.last_segment) + + detokenizer.finalize() + # yield detokenizer.last_segment + outputs.append(detokenizer.last_segment) + else: Thread(target=llm_model.generate, kwargs=generate_kwargs).start() for text in streamer: From 566103b98690e3379808ca05da283b3eb898f45c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20L=C3=B6wenfels?= <282+dfl@users.noreply.github.com> Date: Thu, 6 Jun 2024 21:50:23 +0100 Subject: [PATCH 5/6] WIP --- gradio_app.py | 41 ++++------------------ mlx_lm_wrapper.py | 89 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 35 deletions(-) create mode 100644 mlx_lm_wrapper.py diff --git a/gradio_app.py b/gradio_app.py index 
1a56506..97abd1c 100644 --- a/gradio_app.py +++ b/gradio_app.py @@ -13,8 +13,7 @@ import gradio as gr import tempfile if is_mac: - import mlx_lm - # from mlx_lm import load, stream_generate + from mlx_lm_wrapper import load_mlx_lm gradio_temp_dir = os.path.join(tempfile.gettempdir(), 'gradio') @@ -80,7 +79,7 @@ # llm_name = "mlx-community/Phi-3-mini-128k-instruct-8bit" llm_name = "mlx-community/Meta-Llama-3-8B-4bit" # llm_name = "mlx-community/dolphin-2.9.1-llama-3-8b-4bit" - llm_model, llm_tokenizer = mlx_lm.load(llm_name) + llm_model, llm_tokenizer = load_mlx_lm(llm_name) else: # llm_name = 'lllyasviel/omost-phi-3-mini-128k-8bits' llm_name = 'lllyasviel/omost-llama-3-8b-4bits' @@ -168,40 +167,12 @@ def interrupter(): if temperature == 0: generate_kwargs['do_sample'] = False + Thread(target=llm_model.generate, kwargs=generate_kwargs).start() outputs = [] - if is_mac: - # for text in mlx_lm.stream_generate(llm_model, llm_tokenizer, input_ids.cpu().numpy(), temp=temperature, top_p=top_p): - # outputs.append(text) - # # print(outputs) - # yield "".join(outputs), interrupter - - max_tokens = 100 - detokenizer = tokenizer.detokenizer - detokenizer.reset() - prompt_tokens = input_ids.cpu().numpy() - for (token, prob), n in zip( - generate_step(prompt_tokens, llm_model, **generate_kwargs), - range(max_tokens), - ): - if token == llm_tokenizer.eos_token_id: - break - detokenizer.add_token(token) - - # Yield the last segment if streaming - # yield detokenizer.last_segment - outputs.append(detokenizer.last_segment) - - detokenizer.finalize() - # yield detokenizer.last_segment - outputs.append(detokenizer.last_segment) - - else: - Thread(target=llm_model.generate, kwargs=generate_kwargs).start() - for text in streamer: - outputs.append(text) - # print(outputs) - yield "".join(outputs), interrupter + for text in streamer: + outputs.append(text) + yield "".join(outputs), interrupter return diff --git a/mlx_lm_wrapper.py b/mlx_lm_wrapper.py new file mode 100644 index 0000000..b7fa86b --- /dev/null +++ b/mlx_lm_wrapper.py @@ -0,0 +1,89 @@ +from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union +# from mlx_lm import load, PreTrainedTokenizer, TokenizerWrapper +import mlx +import mlx_lm +import transformers as tf +# from transformers import AutoTokenizer, TextIteratorStreamer +from transformers.generation.stopping_criteria import StoppingCriteriaList +from transformers.generation.utils import GenerateOutput +import numpy as np +import torch + +def load_mlx_lm(llm_name: str) -> Tuple[mlx.nn.Module, tf.PreTrainedTokenizer]: + llm_model, llm_tokenizer = mlx_lm.load(llm_name) + return MLX_LLM_TransformersWrapper(llm_model, llm_tokenizer), llm_tokenizer + +class MLX_LLM_TransformersWrapper(mlx.nn.Module): + def __init__(self, model: mlx.nn.Module, tokenizer: tf.PreTrainedTokenizer): + self.model = model + self.tokenizer = tokenizer + + def generate(self, + input_ids: np.ndarray, + streamer: tf.TextIteratorStreamer, #Optional["BaseStreamer"] = None, + # inputs: Optional[torch.Tensor] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_new_tokens: int = 100, + do_sample: bool = True, + temperature: float = 1.0, + top_p: float = 1.0, + **kwargs + ) -> Union[GenerateOutput, torch.LongTensor]: + + if streamer is not None: + streamer.put(input_ids.cpu()) + + # has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) + # return self.__stream_generate(self.model, self.tokenizer, input_ids, max_new_tokens, **kwargs) + + + def 
__stream_generate(self, + model: torch.nn.Module, + tokenizer: tf.PreTrainedTokenizer, + prompt: Union[str, np.ndarray], + max_tokens: int = 100, + **kwargs, + ) -> Union[str, Generator[str, None, None]]: + """ + A generator producing text based on the given prompt from the model. + + Args: + prompt (mx.array): The input prompt. + model (nn.Module): The model to use for generation. + max_tokens (int): The ma + kwargs: The remaining options get passed to :func:`generate_step`. + See :func:`generate_step` for more details. + + Yields: + Generator[Tuple[mx.array, mx.array]]: A generator producing text. + """ + # if not isinstance(tokenizer, TokenizerWrapper): + # tokenizer = TokenizerWrapper(tokenizer) + + if isinstance(prompt, str): + prompt_tokens = mx.array(tokenizer.encode(prompt)) + else: + prompt_tokens = mx.array(prompt) + + detokenizer = tokenizer.detokenizer + detokenizer.reset() + print("generating...") + for (token, prob), n in zip( + generate_step( + prompt=prompt_tokens, + model=model, + temp=kwargs.pop("temperature", 1.0), + **kwargs), + range(max_tokens), + ): + print(f"n: {n}") + if token == tokenizer.eos_token_id: + print("EOS") + break + detokenizer.add_token(token) + print(f"Token: {token}") + # Yield the last segment if streaming + yield detokenizer.last_segment + + detokenizer.finalize() + yield detokenizer.last_segment From 4b6330b35cb8525effa134cd9735400bcccbf09c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20L=C3=B6wenfels?= <282+dfl@users.noreply.github.com> Date: Fri, 7 Jun 2024 14:59:42 +0100 Subject: [PATCH 6/6] going for a different approach; using mlm_lm.server with openai API format --- gradio_app.py | 123 +++++++++++++++++++++++------------------------ requirements.txt | 1 + 2 files changed, 62 insertions(+), 62 deletions(-) diff --git a/gradio_app.py b/gradio_app.py index 97abd1c..61e26cd 100644 --- a/gradio_app.py +++ b/gradio_app.py @@ -1,6 +1,7 @@ import os import platform is_mac = platform.system() == 'Darwin' +from huggingface_hub import snapshot_download os.environ['HF_HOME'] = os.path.join(os.path.dirname(__file__), 'hf_download') HF_TOKEN = None @@ -12,15 +13,13 @@ import numpy as np import gradio as gr import tempfile -if is_mac: - from mlx_lm_wrapper import load_mlx_lm +from openai import OpenAI +import subprocess gradio_temp_dir = os.path.join(tempfile.gettempdir(), 'gradio') os.makedirs(gradio_temp_dir, exist_ok=True) -from threading import Thread - # Phi3 Hijack from transformers.models.phi3.modeling_phi3 import Phi3PreTrainedModel @@ -74,28 +73,50 @@ memory_management.unload_all_models([text_encoder, text_encoder_2, vae, unet]) +openai_api_base = "http://127.0.0.1:8080/v1" +client = OpenAI(api_key="EMPTY", base_url=openai_api_base) + # LLM -if is_mac: - # llm_name = "mlx-community/Phi-3-mini-128k-instruct-8bit" - llm_name = "mlx-community/Meta-Llama-3-8B-4bit" - # llm_name = "mlx-community/dolphin-2.9.1-llama-3-8b-4bit" - llm_model, llm_tokenizer = load_mlx_lm(llm_name) -else: - # llm_name = 'lllyasviel/omost-phi-3-mini-128k-8bits' - llm_name = 'lllyasviel/omost-llama-3-8b-4bits' - # llm_name = 'lllyasviel/omost-dolphin-2.9-llama3-8b-4bits' - - llm_model = AutoModelForCausalLM.from_pretrained( - llm_name, - torch_dtype=torch.bfloat16, # This is computation type, not load/memory type. The loading quant type is baked in config. 
- token=HF_TOKEN - ) - llm_tokenizer = AutoTokenizer.from_pretrained( - llm_name, - token=HF_TOKEN +# llm_name = "mlx-community/Phi-3-mini-128k-instruct-8bit" +llm_name = "mlx-community/Meta-Llama-3-8B-4bit" +# llm_name = "mlx-community/dolphin-2.9.1-llama-3-8b-4bit" + +def load_model(model_name): + global process + + local_model_dir = os.path.join( + os.environ['HF_HOME'], llm_name.split("/")[1] ) -memory_management.unload_all_models(llm_model) + if not os.path.exists(local_model_dir): + snapshot_download(repo_id=llm_name, local_dir=local_model_dir) + + command = ["python3", "-m", "mlx_lm.server", "--model", local_model_dir] + + try: + process = subprocess.Popen( + command, stdin=subprocess.PIPE, stderr=subprocess.PIPE, text=True + ) + process.stdin.write("y\n") + process.stdin.flush() + print("Model Loaded") + return True #{model_status: "Model Loaded"} + except Exception as e: + print(f"Exception occurred: {str(e)}") + return False #{model_status: f"Exception occurred: {str(e)}"} + +load_model(llm_name) + +def kill_process(): + global process + process.terminate() + time.sleep(2) + if process.poll() is None: # Check if the process has indeed terminated + process.kill() # Force kill if still running + + print("Model Killed") + return {model_status: "Model Unloaded"} + @torch.inference_mode() def pytorch2numpy(imgs): @@ -134,48 +155,26 @@ def chat_fn(message: str, history: list, seed:int, temperature: float, top_p: fl conversation.append({"role": "user", "content": message}) - memory_management.load_models_to_gpu(llm_model) - - input_ids = llm_tokenizer.apply_chat_template( - conversation, return_tensors="pt", add_generation_prompt=True).to( getattr(llm_model, 'device', memory_management.gpu) ) - - streamer = TextIteratorStreamer(llm_tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) - - def interactive_stopping_criteria(*args, **kwargs) -> bool: - if getattr(streamer, 'user_interrupted', False): - print('User stopped generation') - return True - else: - return False - - stopping_criteria = StoppingCriteriaList([interactive_stopping_criteria]) - - def interrupter(): - streamer.user_interrupted = True - return - - generate_kwargs = dict( - input_ids=input_ids, - streamer=streamer, - stopping_criteria=stopping_criteria, - max_new_tokens=max_new_tokens, - do_sample=True, + response = client.chat.completions.create( + model="gpt", + messages=conversation, temperature=temperature, top_p=top_p, + # frequency_penalty=freq_penalty, + max_tokens=max_new_tokens, + stream=True, ) - - if temperature == 0: - generate_kwargs['do_sample'] = False - - Thread(target=llm_model.generate, kwargs=generate_kwargs).start() - - outputs = [] - for text in streamer: - outputs.append(text) - yield "".join(outputs), interrupter - - return - + stop = ["<|im_end|>", "<|endoftext|>"] + partial_message = "" + for chunk in response: + if len(chunk.choices) != 0: + if chunk.choices[0].delta.content not in stop: + partial_message = partial_message + chunk.choices[0].delta.content + else: + partial_message = partial_message + "" + yield partial_message + + return partial_message @torch.inference_mode() def post_chat(history): diff --git a/requirements.txt b/requirements.txt index 65bbfec..88b566b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ pillow einops torch peft +openai \ No newline at end of file
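
For reference, a minimal standalone sketch of the client side of the PATCH 6/6 approach: it assumes an mlx_lm.server instance is already listening on 127.0.0.1:8080 (the address load_model() targets above) and that the model name sent by the OpenAI client is a placeholder the local server does not dispatch on. The conversation contents and sampling values below are illustrative only, not taken from the patches.

    # client_sketch.py -- stream a chat completion from a local mlx_lm.server
    # Assumes: the server is already running (see load_model() in PATCH 6/6)
    # and the `openai` package added to requirements.txt is installed.
    from openai import OpenAI

    client = OpenAI(api_key="EMPTY", base_url="http://127.0.0.1:8080/v1")

    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},      # illustrative
        {"role": "user", "content": "describe a cat lounging on a sofa"},   # illustrative
    ]

    response = client.chat.completions.create(
        model="gpt",          # placeholder model name, as in the patch
        messages=conversation,
        temperature=0.6,      # example sampling values
        top_p=0.9,
        max_tokens=512,
        stream=True,
    )

    stop = ["<|im_end|>", "<|endoftext|>"]
    partial_message = ""
    for chunk in response:
        # Each streamed chunk carries an incremental delta; skip empty deltas
        # and the model's end-of-turn markers, as chat_fn() does in the patch.
        if chunk.choices and chunk.choices[0].delta.content:
            piece = chunk.choices[0].delta.content
            if piece not in stop:
                partial_message += piece
    print(partial_message)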