From 7e7245c221e81ccdd290d27f82fbbeab4ffa8adc Mon Sep 17 00:00:00 2001
From: Mika
Date: Thu, 21 Nov 2024 19:42:11 +0200
Subject: [PATCH] llm cache

---
 test_uralicnlp.py | 16 ++++++++---
 uralicNLP/llm.py  | 67 ++++++++++++++++++++++++++++++++++++++++-------
 2 files changed, 69 insertions(+), 14 deletions(-)

diff --git a/test_uralicnlp.py b/test_uralicnlp.py
index 9db35de..e412815 100644
--- a/test_uralicnlp.py
+++ b/test_uralicnlp.py
@@ -159,8 +159,16 @@
 #print(result)
 #print(llm_output)
 #llm = get_llm("mistral", open_read(os.path.expanduser("~/.mistralapikey")).read().strip(), model="mistral-embed")
-#llm = get_llm("microsoft/Phi-3.5-mini-instruct")
-#print(llm.prompt("What is Livonian?"))
+llm = get_llm("roneneldan/TinyStories-33M")
+#llm.load_cache("cache.bin")
+#"microsoft/Phi-3.5-mini-instruct")
+prompts = ["What is Livonian?", "Look at this dog", "WTF are you talking about", "Yeah, right"]
+for prompt in prompts:
+	print(llm.prompt(prompt))
+	llm.embed(prompt)
+
+#llm.save_cache("cache.bin")
+
 
 #llm = get_llm("claude", open_read(os.path.expanduser("~/.claudeapikey")).read().strip())
 #print(llm.prompt("What is Tundra Nenets?"))
@@ -169,11 +177,11 @@
 #print(llm.embed("Super great text to embed"))
 #print(llm.embed_endangered("Näʹde täävtõõđi âʹtte peeʹlljid pärnnses täävtõõđi.", "sms", "fin"))
 
-llm = get_llm("google-bert/bert-base-uncased")
+#llm = get_llm("google-bert/bert-base-uncased")
 texts = ["dogs are funny", "cats play around", "cars go fast", "planes fly around", "parrots like to eat", "eagles soar in the skies", "moon is big", "saturn is a planet"]
 endangered_texts = ["Ёртозь ёртовсь кудостонть.", "Теке сялгонзояк те касовксонть арасть.", "Истяяк арсеват.", "Атякштне, кунсолан, сыргойсть омбоцеде.", "Вальмаванть неявить ульцява ардыцят.", "Морат эрзянь моро?"]
 #print(semantics.cluster(texts, llm, return_ids=True))
 #print(semantics.cluster(texts, llm))
 #print(semantics.cluster(texts, llm, hierarchical_clustering=True))
 #print(semantics.cluster_endangered(endangered_texts, llm, "myv", "fin"))
-print(semantics.cluster_endangered(endangered_texts, llm, "myv", "fin", hierarchical_clustering=True, method="hdbscan"))
+#print(semantics.cluster_endangered(endangered_texts, llm, "myv", "fin", hierarchical_clustering=True, method="hdbscan"))
diff --git a/uralicNLP/llm.py b/uralicNLP/llm.py
index d88152e..e38a446 100644
--- a/uralicNLP/llm.py
+++ b/uralicNLP/llm.py
@@ -38,6 +38,8 @@
 import json
 
+from mikatools import pickle_dump, pickle_load
+
 class ModuleNotInstalled(Exception):
 	pass
 
@@ -72,14 +74,52 @@ def get_llm(llm_name, *args, **kwargs):
 class LLM(object):
 	"""docstring for LLM"""
 	def __init__(self):
+		self.cache = False
+		self._embed_cache_dict = {}
+		self._prompt_cache_dict = {}
 		super(LLM, self).__init__()
 
+	def _embed_cache(func):
+		def inner(*args, **kwargs):
+			self = args[0]
+			if self.cache and "_".join(args[1:]) in self._embed_cache_dict:
+				return self._embed_cache_dict["_".join(args[1:])]
+			else:
+				r = func(*args, **kwargs)
+				if self.cache:
+					self._embed_cache_dict["_".join(args[1:])] = r
+				return r
+		return inner
+
+	def _prompt_cache(func):
+		def inner(*args, **kwargs):
+			self = args[0]
+			if self.cache and "_".join(args[1:]) in self._prompt_cache_dict:
+				return self._prompt_cache_dict["_".join(args[1:])]
+			else:
+				r = func(*args, **kwargs)
+				if self.cache:
+					self._prompt_cache_dict["_".join(args[1:])] = r
+				return r
+		return inner
+
+
+	@_prompt_cache
 	def prompt(self, text):
+		return self._prompt(text)
+
+	def _prompt(self, text):
 		raise NotImplementedException("LLM does not support prompting")
 
support prompting") + @_embed_cache def embed(self, text): + return self._embed(text) + + def _embed(self, text): raise NotImplementedException("LLM does not support embeddings") + @_embed_cache def embed_endangered(self, text, lang, dict_lang,backend=TinyDictionary): r = [] for word in tokenize_words(text): @@ -96,6 +136,13 @@ def embed_endangered(self, text, lang, dict_lang,backend=TinyDictionary): text = " ".join(r) return self.embed(text) + def save_cache(self, file, *args, **kwargs): + pickle_dump([self._embed_cache_dict, self._prompt_cache_dict], file, *args, **kwargs ) + + def load_cache(self, file, *args, **kwargs): + self.cache = True + self._embed_cache_dict, self._prompt_cache_dict = pickle_load(file, *args, **kwargs) + class ChatGPT(LLM): """docstring for ChatGPT""" @@ -107,7 +154,7 @@ def __init__(self, api_key, model="gpt-4o"): raise ModuleNotInstalled("OpenAI Python library is not installed. Run pip install openai. If you do have the library installed, check your API key.") self.model = model - def prompt(self, prompt, temperature=1): + def _prompt(self, prompt, temperature=1): chat_completion = self.client.chat.completions.create( messages=[ { @@ -120,7 +167,7 @@ def prompt(self, prompt, temperature=1): ) return chat_completion.choices[0].message.content - def embed(self, text): + def _embed(self, text): response = self.client.embeddings.create(input=text, model=self.model) return response.data[0].embedding @@ -136,11 +183,11 @@ def __init__(self, api_key, model="gemini-1.5-flash", task_type="retrieval_docum self.model_name = model self.task_type = task_type - def prompt(self, prompt): + def _prompt(self, prompt): response = self.model.generate_content(prompt) return response.text - def embed(self, text): + def _embed(self, text): result = genai.embed_content(model=self.model_name, content=text, task_type=self.task_type) return result['embedding'] @@ -153,13 +200,13 @@ def __init__(self, model, max_length=1000, device=-1): self.embedder = None self.device = device - def prompt(self, prompt): + def _prompt(self, prompt): if self.model is None: self.model = pipeline('text-generation', model = self.model_name, device = self.device) r = self.model(prompt, max_length=self.max_length, truncation=True) return " ".join([x['generated_text'] for x in r]) - def embed(self, text): + def _embed(self, text): if self.embedder is None: self.embedder = pipeline('feature-extraction', model=self.model_name,device = self.device) r = self.embedder(text, return_tensors="pt")[0].numpy().mean(axis=0) @@ -174,11 +221,11 @@ def __init__(self, api_key, model="mistral-small-latest"): raise ModuleNotInstalled("Mistral library is not installed. Run pip install mistralai. 
 		self.model = model
 
-	def prompt(self, prompt):
+	def _prompt(self, prompt):
 		r = self.s.chat.complete(model=self.model, messages=[{"content": prompt,"role": "user",}])
 		return r.choices[0].message.content
 
-	def embed(self, text):
+	def _embed(self, text):
 		embeddings_batch_response = self.s.embeddings.create(model=self.model, inputs=[text])
 		return embeddings_batch_response.data[0].embedding
 
@@ -194,7 +241,7 @@ def __init__(self, api_key, model="claude-3-5-sonnet-latest", max_length=1024):
 		self.model = model
 		self.max_length = max_length
 
-	def prompt(self, prompt, temperature=1):
+	def _prompt(self, prompt, temperature=1):
 		chat_completion = self.client.messages.create(model=self.model,messages=[{"role": "user", "content": prompt}], max_tokens=self.max_length)
 		return " ".join([x.text for x in chat_completion.content])
 
@@ -209,7 +256,7 @@ def __init__(self, api_key, model="voyage-3"):
 			raise ModuleNotInstalled("Voyage Python library is not installed. Run pip install voyageai. If you do have the library installed, check your API key.")
 		self.model = model
 
-	def embed(self, text):
+	def _embed(self, text):
 		result = self.vo.embed([text], model=self.model, input_type="document")
 		return result.embeddings[0]
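
Usage sketch (not part of the patch): a minimal example of how the cache layer added above can be driven, mirroring the test changes in test_uralicnlp.py. The model name and cache file name are the illustrative ones from that test, and mikatools must be installed for pickle_dump/pickle_load. Note that caching is off by default (self.cache = False) and only load_cache() switches it on, so on a first run the flag has to be enabled by hand before save_cache() has anything to store.

from uralicNLP.llm import get_llm

llm = get_llm("roneneldan/TinyStories-33M")  # a local HuggingFace model, as in the test; no API key needed
llm.cache = True                             # first run: no cache file yet, so enable caching manually
#llm.load_cache("cache.bin")                 # later runs: loads both cache dicts and sets cache = True

for prompt in ["What is Livonian?", "Look at this dog"]:
	print(llm.prompt(prompt))            # a repeated identical prompt is answered from _prompt_cache_dict
	llm.embed(prompt)                    # embedding results are stored in _embed_cache_dict

llm.save_cache("cache.bin")                  # persists both cache dicts via mikatools.pickle_dump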