Skip to content

Commit

Permalink
refactor(querybot)🔄: Enhance QueryBot to support additional query inp…
Browse files Browse the repository at this point in the history
…ut types

- Updated the QueryBot's __call__ method to accept Union types for query inputs.
- Added logic to handle BaseMessage and HumanMessage types for query content extraction.
- Refactored the retrieval process to accommodate the new input types.
  • Loading branch information
ericmjl committed Jan 19, 2025
1 parent b28f30f commit 314f78a
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions llamabot/bot/querybot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

import contextvars
from pathlib import Path
from typing import Optional
from typing import Optional, Union
from dotenv import load_dotenv

from llamabot.config import default_language_model

from llamabot.bot.simplebot import SimpleBot
from llamabot.components.messages import AIMessage, HumanMessage
from llamabot.components.messages import AIMessage, BaseMessage, HumanMessage
from llamabot.components.docstore import LanceDBDocStore
from llamabot.components.chatui import ChatUIMixin
from llamabot.components.messages import (
Expand Down Expand Up @@ -76,7 +76,9 @@ def __init__(

ChatUIMixin.__init__(self, initial_message)

def __call__(self, query: str, n_results: int = 20) -> AIMessage:
def __call__(
self, query: Union[str, HumanMessage, BaseMessage], n_results: int = 20
) -> AIMessage:
"""Query documents within QueryBot's document store.
We use RAG to query out documents.
Expand All @@ -86,12 +88,16 @@ def __call__(self, query: str, n_results: int = 20) -> AIMessage:
messages = [self.system_prompt]

retreived_messages = set()

q = query
if isinstance(query, (BaseMessage, HumanMessage)):
q = query.content
retrieved_messages = retreived_messages.union(
self.docstore.retrieve(query, n_results)
self.docstore.retrieve(q, n_results)
)
retrieved = [RetrievedMessage(content=chunk) for chunk in retrieved_messages]
messages.extend(retrieved)
messages.append(HumanMessage(content=query))
messages.append(HumanMessage(content=q))
if self.stream_target == "stdout":
response: AIMessage = self.stream_stdout(messages)
return response
Expand Down

0 comments on commit 314f78a

Please sign in to comment.