diff --git a/genai_stack/memory/base.py b/genai_stack/memory/base.py
index ce5eedad..67ca07bd 100644
--- a/genai_stack/memory/base.py
+++ b/genai_stack/memory/base.py
@@ -13,7 +13,7 @@ class BaseMemoryConfig(StackComponentConfig):
 
 
 class BaseMemory(StackComponent):
-    
+
     def get_user_text(self) -> str:
         """
         This method returns the user query
@@ -25,21 +25,21 @@ def get_model_text(self) -> str:
         This method returns the model response
         """
         raise NotImplementedError()
-    
+
     def get_text(self) -> dict:
         """
         This method returns both user query and model response
         """
         raise NotImplementedError()
-    
+
     def add_text(self, user_text:str, model_text:str) -> None:
         """
         This method stores both user query and model response
         """
         raise NotImplementedError()
-    
-    def get_chat_history(self, query:str) -> str:
+
+    def get_chat_history(self) -> str:
         """
         This method returns the chat conversation history
         """
-        raise NotImplementedError()
\ No newline at end of file
+        raise NotImplementedError()
diff --git a/genai_stack/memory/langchain.py b/genai_stack/memory/langchain.py
index d2e247f7..d2385c71 100644
--- a/genai_stack/memory/langchain.py
+++ b/genai_stack/memory/langchain.py
@@ -28,17 +28,17 @@ def get_user_text(self):
         if len(self.memory.chat_memory.messages) == 0:
             return None
         return self.memory.chat_memory.messages[-2].content
-    
+
     def get_model_text(self):
         if len(self.memory.chat_memory.messages) == 0:
             return None
         return self.memory.chat_memory.messages[-1].content
-    
+
     def get_text(self):
         return {
-            "user_text":self.get_user_text(), 
+            "user_text":self.get_user_text(),
             "model_text":self.get_model_text()
         }
 
-    def get_chat_history(self, query):
-        return parse_chat_conversation_history(self.memory.chat_memory.messages)
\ No newline at end of file
+    def get_chat_history(self):
+        return parse_chat_conversation_history(self.memory.chat_memory.messages)
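Reviewer note on the two files above: `get_chat_history()` drops its `query` parameter because the buffer-backed memory never used it; it simply replays every stored message in order. A minimal sketch of the new contract, assuming a plain "role: content" rendering (the real `parse_chat_conversation_history` in `genai_stack/memory/utils.py` may format differently):

```python
from langchain.memory import ConversationBufferMemory

# Minimal sketch, not the genai_stack implementation: a buffer-backed
# memory whose get_chat_history() needs no query because it replays
# every stored message in order.
class BufferMemorySketch:
    def __init__(self):
        self.memory = ConversationBufferMemory()

    def add_text(self, user_text: str, model_text: str) -> None:
        self.memory.save_context({"input": user_text}, {"output": model_text})

    def get_chat_history(self) -> str:
        # Assumed rendering; parse_chat_conversation_history() may differ.
        lines = []
        for message in self.memory.chat_memory.messages:
            lines.append(f"{message.type}: {message.content}")
        return "\n".join(lines)

memory = BufferMemorySketch()
memory.add_text("My favourite car is a Mini.", "Noted!")
print(memory.get_chat_history())
# human: My favourite car is a Mini.
# ai: Noted!
```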
diff --git a/genai_stack/memory/vectordb.py b/genai_stack/memory/vectordb.py
index e7202c6f..2c2f88f5 100644
--- a/genai_stack/memory/vectordb.py
+++ b/genai_stack/memory/vectordb.py
@@ -2,7 +2,7 @@
 from langchain.memory import VectorStoreRetrieverMemory
 from genai_stack.memory.base import BaseMemory, BaseMemoryConfig, BaseMemoryConfigModel
 from genai_stack.memory.utils import (
-    parse_chat_conversation_history_search_result, 
+    parse_chat_conversation_history_search_result,
     format_index_name
 )
 
@@ -25,16 +25,16 @@ class VectorDBMemory(BaseMemory):
 
     def _post_init(self, *args, **kwargs):
         config:VectorDBMemoryConfigModel = self.config.config_data
-        # We have to pass the index name in two places, one is to Vectordb and to 
+        # We have to pass the index name in two places: one to the Vectordb and one to the
         # VectorStoreRetriever.
         # in case of weaviate, if we pass the index name in lowercase, the weaviate will
         # internally convert it to pascal for schema/collection
         # eg passed index name => chatting, weaviate converted to Chatting,
-        # But if VectorStoreRetriever use the lowercased index name, 
+        # But if the VectorStoreRetriever uses the lowercased index name,
         # it throws index error, since the weaviate changed to pascal.
-        # To handle this we are converting the index name to pascal before intializing 
-        # the Vectordb and Vectorstoreretriever, and to 
-        # maintain the consistency, we are also converting the chromadb index name to 
+        # To handle this we are converting the index name to pascal before initializing
+        # the Vectordb and VectorStoreRetriever, and to
+        # maintain the consistency, we are also converting the chromadb index name to
         # pascal, instead of conditionally doing only for weaviate.
         kwarg_map, index_name = format_index_name(config=config)
 
@@ -53,8 +53,8 @@ def _post_init(self, *args, **kwargs):
     def add_text(self, user_text: str, model_text: str):
         self.memory.save_context({"input": user_text}, {"output": model_text})
 
-    def get_chat_history(self, query):
-        documents = self.memory.load_memory_variables({
-            "prompt": query
-        })[self.memory.memory_key]
-        return parse_chat_conversation_history_search_result(search_results=documents)
\ No newline at end of file
+    def get_chat_history(self):
+        # load_memory_variables() requires an inputs mapping; pass an empty
+        # prompt so retrieval is not anchored to any particular user query.
+        documents = self.memory.load_memory_variables({"prompt": ""})[self.memory.memory_key]
+        return parse_chat_conversation_history_search_result(search_results=documents)
diff --git a/genai_stack/retriever/base.py b/genai_stack/retriever/base.py
index 7c05e7b4..3a8b544a 100644
--- a/genai_stack/retriever/base.py
+++ b/genai_stack/retriever/base.py
@@ -35,8 +35,8 @@ def get_context(self, query: str):
         """
         raise NotImplementedError()
 
-    def get_chat_history(self, query:str) -> str:
+    def get_chat_history(self) -> str:
         """
         This method returns the chat conversation history
         """
-        return self.mediator.get_chat_history(query=query)
+        return self.mediator.get_chat_history()
diff --git a/genai_stack/retriever/langchain.py b/genai_stack/retriever/langchain.py
index afd00701..6e322a15 100644
--- a/genai_stack/retriever/langchain.py
+++ b/genai_stack/retriever/langchain.py
@@ -32,7 +32,7 @@ def retrieve(self, query: str, context: List[Document] = None):
         metadata = context[0].metadata if context else None
         prompt_dict["context"] = parse_search_results(context)
         if "history" in prompt_template.input_variables:
-            prompt_dict["history"] = self.get_chat_history(query=query)
+            prompt_dict["history"] = self.get_chat_history()
         else:
             # Cache and memory cannot co-exist. Memory is given priority.
             cache = self.mediator.get_cache(query=query, metadata=metadata)
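Reviewer note: with the parameter gone, the retriever fills the prompt's `history` variable straight from memory rather than threading the user query through. A sketch of that branch with a hypothetical prompt template (genai_stack resolves its own template and mediator wiring internally):

```python
from langchain.prompts import PromptTemplate

# Hypothetical template for illustration only.
prompt_template = PromptTemplate(
    input_variables=["context", "history", "query"],
    template="Context:\n{context}\n\nHistory:\n{history}\n\nQuestion: {query}",
)

def build_prompt_dict(retriever, query: str, context_text: str) -> dict:
    """Sketch of how retrieve() now assembles its prompt variables."""
    prompt_dict = {"query": query, "context": context_text}
    if "history" in prompt_template.input_variables:
        # The query is no longer forwarded: memory returns the whole
        # conversation rather than a similarity-searched slice of it.
        prompt_dict["history"] = retriever.get_chat_history()
    return prompt_dict
```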
diff --git a/genai_stack/stack/mediator.py b/genai_stack/stack/mediator.py
index 800496d7..9bd50afa 100644
--- a/genai_stack/stack/mediator.py
+++ b/genai_stack/stack/mediator.py
@@ -67,9 +67,9 @@ def add_text(self, user_text: str, model_text: str) -> None:
         if self._is_component_available("memory"):
             self._stack.memory.add_text(user_text, model_text)
 
-    def get_chat_history(self, query:str) -> str:
+    def get_chat_history(self) -> str:
         if self._is_component_available("memory"):
-            return self._stack.memory.get_chat_history(query=query)
+            return self._stack.memory.get_chat_history()
 
     # Vectordb
     def store_to_vectordb(self, documents: List[LangDocument]):
diff --git a/tests/test_memory.py b/tests/test_memory.py
index 0f1d08a9..ad1c9305 100644
--- a/tests/test_memory.py
+++ b/tests/test_memory.py
@@ -57,9 +57,9 @@ def chromadb_stack(self, index_name = "Chroma"):
         self.memory = VectorDBMemory.from_kwargs(index_name = index_name)
 
         self.chromadb_memory_stack = Stack(
-            model=None, 
-            embedding=self.embedding, 
-            vectordb=self.chromadb, 
+            model=None,
+            embedding=self.embedding,
+            vectordb=self.chromadb,
             memory=self.memory
         )
 
@@ -83,24 +83,24 @@ def weaviatedb_stack(self, index_name = "Weaviate"):
         self.memory = VectorDBMemory.from_kwargs(index_name = index_name)
 
         self.weaviatedb_memory_stack = Stack(
-            model=None, 
-            embedding=self.embedding, 
-            vectordb=self.weaviatedb, 
+            model=None,
+            embedding=self.embedding,
+            vectordb=self.weaviatedb,
             memory=self.memory
-        ) 
-    
+        )
+
     def store_conversation_to_chromadb_memory(self, user_text:str, model_text:str):
         self.chromadb_memory_stack.memory.add_text(
             user_text=user_text,model_text=model_text
         )
-    
+
     def store_conversation_to_weaviate_memory(self, user_text:str, model_text:str):
         self.weaviatedb_memory_stack.memory.add_text(
             user_text=user_text,model_text=model_text
         )
-    
+
     def test_chromadb_memory(self, query:str = "what is my favourite car?"):
-        print(self.chromadb_memory_stack.memory.get_chat_history(query=query))
-    
+        print(self.chromadb_memory_stack.memory.get_chat_history())
+
     def test_weaviatedb_memory(self, query:str = "what is my favourite sport?"):
-        print(self.weaviatedb_memory_stack.memory.get_chat_history(query=query))
+        print(self.weaviatedb_memory_stack.memory.get_chat_history())
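Reviewer note: the updated call sites end to end, distilled from tests/test_memory.py. This is a usage sketch, not runnable verbatim: `embedding` and `chromadb` stand in for components configured in the test fixtures, and the Stack import path is assumed. The now-unused `query` parameters on the test methods could also be dropped in a follow-up.

```python
from genai_stack.memory.vectordb import VectorDBMemory

# `embedding` and `chromadb` are placeholders for components configured
# elsewhere in the test fixtures; Stack's import path is an assumption.
memory = VectorDBMemory.from_kwargs(index_name="Chroma")
stack = Stack(
    model=None,
    embedding=embedding,
    vectordb=chromadb,
    memory=memory,
)

stack.memory.add_text(
    user_text="My favourite car is a Mini.",
    model_text="Noted!",
)

# Before: stack.memory.get_chat_history(query="what is my favourite car?")
# After: the stored conversation comes back without a query.
print(stack.memory.get_chat_history())
```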