Skip to content
This repository has been archived by the owner on Jul 12, 2024. It is now read-only.

fix(query): removed from get_chat_history #109

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions genai_stack/memory/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class BaseMemoryConfig(StackComponentConfig):


class BaseMemory(StackComponent):

def get_user_text(self) -> str:
"""
This method returns the user query
Expand All @@ -25,21 +25,21 @@ def get_model_text(self) -> str:
This method returns the model response
"""
raise NotImplementedError()

def get_text(self) -> dict:
"""
This method returns both user query and model response
"""
raise NotImplementedError()

def add_text(self, user_text:str, model_text:str) -> None:
"""
This method stores both user query and model response
"""
raise NotImplementedError()
def get_chat_history(self, query:str) -> str:

def get_chat_history(self) -> str:
"""
This method returns the chat conversation history
"""
raise NotImplementedError()
raise NotImplementedError()
10 changes: 5 additions & 5 deletions genai_stack/memory/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,17 @@ def get_user_text(self):
if len(self.memory.chat_memory.messages) == 0:
return None
return self.memory.chat_memory.messages[-2].content

def get_model_text(self):
if len(self.memory.chat_memory.messages) == 0:
return None
return self.memory.chat_memory.messages[-1].content

def get_text(self):
return {
"user_text":self.get_user_text(),
"user_text":self.get_user_text(),
"model_text":self.get_model_text()
}

def get_chat_history(self, query):
return parse_chat_conversation_history(self.memory.chat_memory.messages)
def get_chat_history(self):
return parse_chat_conversation_history(self.memory.chat_memory.messages)
4 changes: 2 additions & 2 deletions genai_stack/retriever/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def get_context(self, query: str):
"""
raise NotImplementedError()

def get_chat_history(self, query:str) -> str:
def get_chat_history(self) -> str:
"""
This method returns the chat conversation history
"""
return self.mediator.get_chat_history(query=query)
return self.mediator.get_chat_history()
2 changes: 1 addition & 1 deletion genai_stack/retriever/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def retrieve(self, query: str, context: List[Document] = None):
metadata = context[0].metadata if context else None
prompt_dict["context"] = parse_search_results(context)
if "history" in prompt_template.input_variables:
prompt_dict["history"] = self.get_chat_history(query=query)
prompt_dict["history"] = self.get_chat_history()
else:
# Cache and memory cannot co-exist. Memory is given priority.
cache = self.mediator.get_cache(query=query, metadata=metadata)
Expand Down
4 changes: 2 additions & 2 deletions genai_stack/stack/mediator.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ def add_text(self, user_text: str, model_text: str) -> None:
if self._is_component_available("memory"):
self._stack.memory.add_text(user_text, model_text)

def get_chat_history(self, query:str) -> str:
def get_chat_history(self) -> str:
if self._is_component_available("memory"):
return self._stack.memory.get_chat_history(query=query)
return self._stack.memory.get_chat_history()

# Vectordb
def store_to_vectordb(self, documents: List[LangDocument]):
Expand Down
26 changes: 13 additions & 13 deletions tests/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ def chromadb_stack(self, index_name = "Chroma"):
self.memory = VectorDBMemory.from_kwargs(index_name = index_name)

self.chromadb_memory_stack = Stack(
model=None,
embedding=self.embedding,
vectordb=self.chromadb,
model=None,
embedding=self.embedding,
vectordb=self.chromadb,
memory=self.memory
)

Expand All @@ -83,24 +83,24 @@ def weaviatedb_stack(self, index_name = "Weaviate"):
self.memory = VectorDBMemory.from_kwargs(index_name = index_name)

self.weaviatedb_memory_stack = Stack(
model=None,
embedding=self.embedding,
vectordb=self.weaviatedb,
model=None,
embedding=self.embedding,
vectordb=self.weaviatedb,
memory=self.memory
)
)

def store_conversation_to_chromadb_memory(self, user_text:str, model_text:str):
self.chromadb_memory_stack.memory.add_text(
user_text=user_text,model_text=model_text
)

def store_conversation_to_weaviate_memory(self, user_text:str, model_text:str):
self.weaviatedb_memory_stack.memory.add_text(
user_text=user_text,model_text=model_text
)

def test_chromadb_memory(self, query:str = "what is my favourite car?"):
print(self.chromadb_memory_stack.memory.get_chat_history(query=query))
print(self.chromadb_memory_stack.memory.get_chat_history())

def test_weaviatedb_memory(self, query:str = "what is my favourite sport?"):
print(self.weaviatedb_memory_stack.memory.get_chat_history(query=query))
print(self.weaviatedb_memory_stack.memory.get_chat_history())