This repository has been archived by the owner on Jul 12, 2024. It is now read-only.

fix(query): removed query argument from get_chat_history #109

Open · wants to merge 2 commits into main
12 changes: 6 additions & 6 deletions genai_stack/memory/base.py
@@ -13,7 +13,7 @@ class BaseMemoryConfig(StackComponentConfig):
 
 
 class BaseMemory(StackComponent):
-    
+
     def get_user_text(self) -> str:
         """
         This method returns the user query
@@ -25,21 +25,21 @@ def get_model_text(self) -> str:
         This method returns the model response
         """
         raise NotImplementedError()
-    
+
     def get_text(self) -> dict:
         """
         This method returns both user query and model response
         """
         raise NotImplementedError()
-    
+
     def add_text(self, user_text:str, model_text:str) -> None:
         """
         This method stores both user query and model response
         """
         raise NotImplementedError()
-    
-    def get_chat_history(self, query:str) -> str:
+
+    def get_chat_history(self) -> str:
         """
         This method returns the chat conversation history
         """
-        raise NotImplementedError()
+        raise NotImplementedError()
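
For orientation, this is the contract a concrete memory now has to satisfy. A minimal sketch against the updated interface; ListMemory and its internals are illustrative and not part of this PR, only the method signatures come from base.py above:

    # Illustrative sketch only: a toy in-process memory written against the
    # updated BaseMemory interface. ListMemory is a made-up name; _post_init
    # is the initialization hook used elsewhere in this diff (see vectordb.py).
    from genai_stack.memory.base import BaseMemory

    class ListMemory(BaseMemory):
        def _post_init(self, *args, **kwargs):
            self._turns = []

        def add_text(self, user_text: str, model_text: str) -> None:
            self._turns.append((user_text, model_text))

        def get_chat_history(self) -> str:
            # No query parameter any more: the whole history is returned.
            return "\n".join(f"user: {u}\nmodel: {m}" for u, m in self._turns)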
10 changes: 5 additions & 5 deletions genai_stack/memory/langchain.py
@@ -28,17 +28,17 @@ def get_user_text(self):
         if len(self.memory.chat_memory.messages) == 0:
             return None
         return self.memory.chat_memory.messages[-2].content
 
     def get_model_text(self):
         if len(self.memory.chat_memory.messages) == 0:
             return None
         return self.memory.chat_memory.messages[-1].content
-    
+
     def get_text(self):
         return {
-            "user_text":self.get_user_text(), 
+            "user_text":self.get_user_text(),
            "model_text":self.get_model_text()
         }
-    
-    def get_chat_history(self, query):
-        return parse_chat_conversation_history(self.memory.chat_memory.messages)
+
+    def get_chat_history(self):
+        return parse_chat_conversation_history(self.memory.chat_memory.messages)
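
Call sites simply drop the argument. A usage sketch, assuming the component is built with from_kwargs the way the tests below build VectorDBMemory; LangChainMemory stands in for whatever class genai_stack/memory/langchain.py actually exports, which this diff does not show:

    # Usage sketch; LangChainMemory is an assumed name for the component
    # defined in genai_stack/memory/langchain.py.
    memory = LangChainMemory.from_kwargs()
    memory.add_text(user_text="My favourite car is a Mini.", model_text="Noted!")
    print(memory.get_chat_history())  # full buffered conversation, no query argument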
20 changes: 9 additions & 11 deletions genai_stack/memory/vectordb.py
@@ -2,7 +2,7 @@
 from langchain.memory import VectorStoreRetrieverMemory
 from genai_stack.memory.base import BaseMemory, BaseMemoryConfig, BaseMemoryConfigModel
 from genai_stack.memory.utils import (
-    parse_chat_conversation_history_search_result, 
+    parse_chat_conversation_history_search_result,
     format_index_name
 )
 
@@ -25,16 +25,16 @@ class VectorDBMemory(BaseMemory):
     def _post_init(self, *args, **kwargs):
         config:VectorDBMemoryConfigModel = self.config.config_data
 
-        # We have to pass the index name in two places, one is to Vectordb and to 
+        # We have to pass the index name in two places, one is to Vectordb and to
         # VectorStoreRetriever.
         # in case of weaviate, if we pass the index name in lowercase, the weaviate will
         # internally convert it to pascal for schema/collection
         # eg passed index name => chatting, weaviate converted to Chatting,
-        # But if VectorStoreRetriever use the lowercased index name, 
+        # But if VectorStoreRetriever use the lowercased index name,
         # it throws index error, since the weaviate changed to pascal.
-        # To handle this we are converting the index name to pascal before intializing 
-        # the Vectordb and Vectorstoreretriever, and to 
-        # maintain the consistency, we are also converting the chromadb index name to 
+        # To handle this we are converting the index name to pascal before intializing
+        # the Vectordb and Vectorstoreretriever, and to
+        # maintain the consistency, we are also converting the chromadb index name to
         # pascal, instead of conditionally doing only for weaviate.
         kwarg_map, index_name = format_index_name(config=config)
 
@@ -53,8 +53,6 @@ def _post_init(self, *args, **kwargs):
     def add_text(self, user_text: str, model_text: str):
         self.memory.save_context({"input": user_text}, {"output": model_text})
 
-    def get_chat_history(self, query):
-        documents = self.memory.load_memory_variables({
-            "prompt": query
-        })[self.memory.memory_key]
-        return parse_chat_conversation_history_search_result(search_results=documents)
+    def get_chat_history(self):
+        documents = self.memory.load_memory_variables()[self.memory.memory_key]
+        return parse_chat_conversation_history_search_result(search_results=documents)
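
One caveat worth flagging on the new get_chat_history: in the LangChain releases this project builds on, VectorStoreRetrieverMemory.load_memory_variables takes an inputs mapping and derives the retrieval query from it, so the zero-argument call above may raise a TypeError at runtime. A hedged sketch of what a query-less call could look like instead; the empty-string placeholder is an assumption about intent, not something this PR contains:

    # Sketch: load_memory_variables expects an inputs dict in LangChain's
    # memory API; an explicit placeholder keeps the call valid while
    # preserving the query-less signature. The "" placeholder is assumed.
    def get_chat_history(self):
        documents = self.memory.load_memory_variables(
            {"prompt": ""}
        )[self.memory.memory_key]
        return parse_chat_conversation_history_search_result(search_results=documents)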
4 changes: 2 additions & 2 deletions genai_stack/retriever/base.py
@@ -35,8 +35,8 @@ def get_context(self, query: str):
         """
         raise NotImplementedError()
 
-    def get_chat_history(self, query:str) -> str:
+    def get_chat_history(self) -> str:
         """
         This method returns the chat conversation history
         """
-        return self.mediator.get_chat_history(query=query)
+        return self.mediator.get_chat_history()
2 changes: 1 addition & 1 deletion genai_stack/retriever/langchain.py
@@ -32,7 +32,7 @@ def retrieve(self, query: str, context: List[Document] = None):
         metadata = context[0].metadata if context else None
         prompt_dict["context"] = parse_search_results(context)
         if "history" in prompt_template.input_variables:
-            prompt_dict["history"] = self.get_chat_history(query=query)
+            prompt_dict["history"] = self.get_chat_history()
         else:
             # Cache and memory cannot co-exist. Memory is given priority.
             cache = self.mediator.get_cache(query=query, metadata=metadata)
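
The history branch is gated on the prompt template declaring a history variable. A sketch of a template that takes this branch; the template text itself is illustrative:

    # Sketch: because "history" appears in input_variables, retrieve() fills
    # prompt_dict["history"] via get_chat_history() instead of consulting the cache.
    from langchain.prompts import PromptTemplate

    template = PromptTemplate(
        input_variables=["context", "history", "query"],
        template="{history}\n\nContext:\n{context}\n\nQuestion: {query}",
    )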
4 changes: 2 additions & 2 deletions genai_stack/stack/mediator.py
@@ -67,9 +67,9 @@ def add_text(self, user_text: str, model_text: str) -> None:
         if self._is_component_available("memory"):
             self._stack.memory.add_text(user_text, model_text)
 
-    def get_chat_history(self, query:str) -> str:
+    def get_chat_history(self) -> str:
         if self._is_component_available("memory"):
-            return self._stack.memory.get_chat_history(query=query)
+            return self._stack.memory.get_chat_history()
 
     # Vectordb
     def store_to_vectordb(self, documents: List[LangDocument]):
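
Note the guard: when no memory component is configured, the mediator's get_chat_history falls through and implicitly returns None, so callers still need a truthiness check. A sketch of defensive call-site handling; the surrounding names are assumed:

    # Sketch: history is None when the stack has no memory component.
    history = mediator.get_chat_history()
    if history:
        prompt_dict["history"] = history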
26 changes: 13 additions & 13 deletions tests/test_memory.py
@@ -57,9 +57,9 @@ def chromadb_stack(self, index_name = "Chroma"):
         self.memory = VectorDBMemory.from_kwargs(index_name = index_name)
 
         self.chromadb_memory_stack = Stack(
-            model=None, 
-            embedding=self.embedding, 
-            vectordb=self.chromadb, 
+            model=None,
+            embedding=self.embedding,
+            vectordb=self.chromadb,
             memory=self.memory
         )
 
@@ -83,24 +83,24 @@ def weaviatedb_stack(self, index_name = "Weaviate"):
         self.memory = VectorDBMemory.from_kwargs(index_name = index_name)
 
         self.weaviatedb_memory_stack = Stack(
-            model=None, 
-            embedding=self.embedding, 
-            vectordb=self.weaviatedb, 
+            model=None,
+            embedding=self.embedding,
+            vectordb=self.weaviatedb,
             memory=self.memory
-        )
-    
+        )
+
     def store_conversation_to_chromadb_memory(self, user_text:str, model_text:str):
         self.chromadb_memory_stack.memory.add_text(
             user_text=user_text,model_text=model_text
         )
-    
+
     def store_conversation_to_weaviate_memory(self, user_text:str, model_text:str):
         self.weaviatedb_memory_stack.memory.add_text(
             user_text=user_text,model_text=model_text
         )
-    
+
     def test_chromadb_memory(self, query:str = "what is my favourite car?"):
-        print(self.chromadb_memory_stack.memory.get_chat_history(query=query))
-    
+        print(self.chromadb_memory_stack.memory.get_chat_history())
+
     def test_weaviatedb_memory(self, query:str = "what is my favourite sport?"):
-        print(self.weaviatedb_memory_stack.memory.get_chat_history(query=query))
+        print(self.weaviatedb_memory_stack.memory.get_chat_history())
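
A follow-up the diff stops short of: both test methods keep their now-unused query parameters. A sketch of the matching cleanup:

    # Sketch: the default query strings above no longer feed get_chat_history,
    # so the parameters could be dropped entirely.
    def test_chromadb_memory(self):
        print(self.chromadb_memory_stack.memory.get_chat_history())

    def test_weaviatedb_memory(self):
        print(self.weaviatedb_memory_stack.memory.get_chat_history())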