
Feature/kb/worker #43

Open
wants to merge 10 commits into base: main
1 change: 1 addition & 0 deletions pyproject.toml
@@ -41,6 +41,7 @@ pytest = "^8.2.0"
chromadb = "^0.5.0"
sumy = "^0.11.0"
fake-useragent = "^1.5.1"
youtube-transcript-api = "^0.6.2"

[tool.poetry.group.dev.dependencies]
ruff = "^0.1.11"
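Note: the langchain_community `YoutubeLoader` used by the new `load_youtube` loader below fetches transcripts via `youtube-transcript-api` at runtime, which is why it joins the main dependency group. Assuming Poetry manages the environment, the equivalent command is `poetry add youtube-transcript-api@^0.6.2`.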
3 changes: 2 additions & 1 deletion src/openagi/agent.py
@@ -178,7 +178,8 @@ def worker_task_execution(self, query: str, description: str, task_lists: TaskLi
cur_task = task_lists.get_next_unprocessed_task()
worker = self._get_worker_by_id(cur_task.worker_id)
res, task = worker.execute_task(
query=query,
task=cur_task,
context=self.get_previous_task_contexts(task_lists=task_lists),
)
self.memory.update_task(task)
149 changes: 149 additions & 0 deletions src/openagi/data_extractors/data_loaders.py
@@ -0,0 +1,149 @@
import json
import tempfile
from tempfile import NamedTemporaryFile
from typing import Any
from urllib.parse import urlparse

import requests
from bs4 import BeautifulSoup as Soup
from langchain.docstore.document import Document
from langchain_community.document_loaders import (
AirbyteStripeLoader,
GitLoader,
PyPDFLoader,
RecursiveUrlLoader,
TextLoader,
UnstructuredMarkdownLoader,
UnstructuredWordDocumentLoader,
WebBaseLoader,
YoutubeLoader,
)
from openagi.data_extractors.data_types import DataType
from openagi.data_extractors.data_source import DataSource

class DataLoader:
def __init__(self, data_source: DataSource):
self.data_source = data_source

def load(self) -> Any:
loader_methods = {
DataType.TXT: self.load_txt,
DataType.PDF: self.load_pdf,
DataType.PPTX: self.load_pptx,
DataType.DOCX: self.load_docx,
DataType.GOOGLE_DOC: self.load_google_doc,
DataType.MARKDOWN: self.load_markdown,
DataType.GITHUB_REPOSITORY: self.load_github,
DataType.WEBPAGE: self.load_webpage,
DataType.YOUTUBE: self.load_youtube,
DataType.URL: self.load_url,
}

loader = loader_methods.get(self.data_source.type)
if loader:
return loader()
else:
raise ValueError(f"Loader not implemented for type: {self.data_source.type}")

def load_txt(self):
with NamedTemporaryFile(suffix=".txt", delete=True) as temp_file:
if self.data_source.url:
file_response = requests.get(self.data_source.url).text
else:
file_response = self.data_source.content
temp_file.write(file_response.encode())
temp_file.flush()
loader = TextLoader(file_path=temp_file.name)
return loader.load_and_split()

    def load_pdf(self):
        if self.data_source.url:
            loader = PyPDFLoader(file_path=self.data_source.url)
        else:
            with NamedTemporaryFile(suffix=".pdf", delete=True) as temp_file:
                temp_file.write(self.data_source.content)
                temp_file.flush()
                # Use PyPDFLoader here as well; UnstructuredWordDocumentLoader is for DOCX files.
                loader = PyPDFLoader(file_path=temp_file.name)
                return loader.load_and_split()
        return loader.load_and_split()

def load_google_doc(self):
pass

def load_pptx(self):
from pptx import Presentation

with NamedTemporaryFile(suffix=".pptx", delete=True) as temp_file:
if self.data_source.url:
file_response = requests.get(self.data_source.url).content
else:
file_response = self.data_source.content
temp_file.write(file_response)
temp_file.flush()
presentation = Presentation(temp_file.name)
result = ""
            for i, slide in enumerate(presentation.slides, start=1):
                result += f"\n\nSlide #{i}: \n"
for shape in slide.shapes:
if hasattr(shape, "text"):
result += f"{shape.text}\n"
return [Document(page_content=result)]

def load_docx(self):
with NamedTemporaryFile(suffix=".docx", delete=True) as temp_file:
if self.data_source.url:
file_response = requests.get(self.data_source.url).content
else:
file_response = self.data_source.content
temp_file.write(file_response)
temp_file.flush()
loader = UnstructuredWordDocumentLoader(file_path=temp_file.name)
return loader.load_and_split()

def load_markdown(self):
with NamedTemporaryFile(suffix=".md", delete=True) as temp_file:
if self.data_source.url:
file_response = requests.get(self.data_source.url).text
else:
file_response = self.data_source.content
temp_file.write(file_response.encode())
temp_file.flush()
loader = UnstructuredMarkdownLoader(file_path=temp_file.name)
return loader.load()

def load_github(self):
parsed_url = urlparse(self.data_source.url)
path_parts = parsed_url.path.split("/")
repo_name = path_parts[2]
metadata = json.loads(self.data_source.metadata)

with tempfile.TemporaryDirectory() as temp_dir:
repo_path = f"{temp_dir}/{repo_name}/"
loader = GitLoader(
clone_url=self.data_source.url,
repo_path=repo_path,
branch=metadata["branch"],
)
return loader.load_and_split()

def load_webpage(self):
loader = RecursiveUrlLoader(
url=self.data_source.url,
max_depth=2,
extractor=lambda x: Soup(x, "html.parser").text,
)
chunks = loader.load_and_split()
for chunk in chunks:
if "language" in chunk.metadata:
del chunk.metadata["language"]
return chunks

def load_youtube(self):
        # Strip any trailing query parameters (e.g. "&t=30s") from the video id.
        video_id = self.data_source.url.split("youtube.com/watch?v=")[-1].split("&")[0]
loader = YoutubeLoader(video_id=video_id)
return loader.load_and_split()

def load_url(self):
url_list = self.data_source.url.split(",")
loader = WebBaseLoader(url_list)
return loader.load_and_split()
7 changes: 7 additions & 0 deletions src/openagi/data_extractors/data_source.py
@@ -0,0 +1,7 @@
from openagi.data_extractors.data_types import DataType


class DataSource:
def __init__(self, type: DataType, url: str = None, content: str = None, metadata: dict = None):
self.type = type
self.url = url
self.content = content
self.metadata = metadata
23 changes: 23 additions & 0 deletions src/openagi/data_extractors/data_types.py
@@ -0,0 +1,23 @@
from enum import Enum

class DataType(Enum):
TXT = "TXT"
PDF = "PDF"
DOCX = "DOCX"
PPTX = "PPTX"
GOOGLE_DOC = "GOOGLE_DOC"
MARKDOWN = "MARKDOWN"
GITHUB_REPOSITORY = "GITHUB_REPOSITORY"
WEBPAGE = "WEBPAGE"
NOTION = "NOTION"
URL = "URL"
YOUTUBE = "YOUTUBE"
CSV = "CSV"
XLSX = "XLSX"

def __str__(self):
return self.value

@classmethod
def list(cls):
return list(map(lambda c: c.value, cls))
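A quick sketch of the enum's helpers, for reference:

```python
from openagi.data_extractors.data_types import DataType

print(str(DataType.YOUTUBE))  # "YOUTUBE": __str__ returns the raw value
print(DataType.list())        # ["TXT", "PDF", "DOCX", ..., "XLSX"]
```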
6 changes: 6 additions & 0 deletions src/openagi/prompts/worker_task_execution.py
@@ -46,6 +46,12 @@

Context: {context}

# Knowledge available to you for making decisions and performing actions
Knowledge Base: {knowledge_base_info}

Important: Always refer to and utilize the information provided in the Knowledge Base before taking any actions or making decisions.
The Knowledge Base contains crucial information for completing the task accurately.

# Example session:
Question: What is the capital of France?
Thought: I should look up France on DuckDuckGo to find reliable information about its capital city.
85 changes: 83 additions & 2 deletions src/openagi/worker.py
@@ -5,7 +5,7 @@
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Field, field_validator

from datetime import datetime
from openagi.actions.utils import run_action
from openagi.exception import OpenAGIException
from openagi.llms.base import LLMBaseModel
@@ -14,6 +14,10 @@
from openagi.tasks.task import Task
from openagi.utils.extraction import get_act_classes_from_json, get_last_json
from openagi.utils.helper import get_default_id
from openagi.storage.chroma import ChromaStorage
from openagi.data_extractors.data_loaders import DataLoader
from openagi.data_extractors.data_source import DataSource
from openagi.data_extractors.data_types import DataType


class Worker(BaseModel):
@@ -46,6 +50,11 @@ class Worker(BaseModel):
default=True,
description="If set to True, the output will be overwritten even if it exists.",
)
knowledge_base: Optional[ChromaStorage] = Field(
default=None,
description="Knowledge base for worker",
exclude=True
)

    # Validate output_key: only letters and underscores are allowed (no digits).
@field_validator("output_key")
@@ -104,8 +113,69 @@ def _force_output(
def save_to_memory(self, task: Task):
"""Saves the output to the memory."""
return self.memory.update_task(task)

def init_knowledge_base(self, **kwargs):
"""
Initializes the knowledge base using the provided keyword arguments.

:param kwargs: Keyword arguments to configure the ChromaStorage.
"""
self.knowledge_base = ChromaStorage.from_kwargs(**kwargs)

    def load_document(self, id: str, document: List[str], metadata: dict):
"""
Loads a single document into the knowledge base.

:param id: Unique identifier for the document.
:param document: The content of the document to be saved.
:param metadata: Metadata associated with the document.
"""
if self.knowledge_base:
self.knowledge_base.save_document(id, document, metadata)

def load_knowledge_from_source(self, data_source: DataSource):
"""
Loads knowledge from a given data source into the knowledge base.

:param data_source: The source of the data to be loaded.
"""
if not self.knowledge_base:
self.init_knowledge_base(collection_name=f"worker_{self.id}_knowledge")

loader = DataLoader(data_source=data_source)
documents = loader.load()
for i, chunk in enumerate(documents):
self.load_document(
id=f"{self.id}_chunk_{i}",
document=[chunk.page_content],
metadata={
"source": chunk.metadata.get("source", ""),
"page": chunk.metadata.get("page", ""),
"chunk": i
}
)

def update_knowledge_base(self, task: Task):
"""
Updates the knowledge base with the result of a completed task.

:param task: The task whose results are to be saved.
"""
if not self.knowledge_base:
return

self.knowledge_base.save_document(
id=f"task_result_{task.id}",
document=str(task.result),
metadata={
"task_name": task.name,
"task_description": task.description,
"timestamp": str(datetime.now())
}
)
logging.info(f"Updated knowledge base with results from task {task.id}")

def execute_task(self, query: str, task: Task, context: Any = None) -> Any:
"""Executes the specified task."""
logging.info(
f"{'>'*20} Executing Task - {task.name}[{task.id}] with worker - {self.role}[{self.id}] {'<'*20}"
@@ -117,13 +187,24 @@ def execute_task(self, task: Task, context: Any = None) -> Any:

logging.debug("Provoking initial thought observation...")
initial_thought_provokes = self.provoke_thought_obs(None)

knowledge_base_info = "No relevant information found."
if self.knowledge_base:
query_results = self.knowledge_base.query_documents(
query_texts=[query],
n_results=3
)
            knowledge_base_info = "\n".join(query_results["documents"][0])

te_vars = dict(
task_to_execute=task_to_execute,
worker_description=worker_description,
supported_actions=[action.cls_doc() for action in self.actions],
thought_provokes=initial_thought_provokes,
output_key=self.output_key,
context=context,
knowledge_base_info=knowledge_base_info,
max_iterations=self.max_iterations,
)

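Putting the pieces together, a minimal sketch of the new knowledge-base flow. The `worker` here is assumed to be an already-configured openagi `Worker` (role, LLM, and actions wired up elsewhere in the framework), and the URL is illustrative:

```python
from openagi.data_extractors.data_source import DataSource
from openagi.data_extractors.data_types import DataType

def seed_worker_knowledge(worker) -> None:
    # load_knowledge_from_source() lazily creates a Chroma collection named
    # worker_<id>_knowledge, then chunks the source and stores each chunk.
    worker.load_knowledge_from_source(
        DataSource(
            type=DataType.WEBPAGE,
            url="https://example.com/docs",  # illustrative URL, not from this PR
        )
    )

# At execution time the query is now passed explicitly, so the top
# n_results=3 matching chunks are retrieved and injected into the prompt
# as {knowledge_base_info}:
#   res, task = worker.execute_task(query=query, task=cur_task, context=context)
```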