From d36abe936888e7aed9b13e20237a7498fefe53b7 Mon Sep 17 00:00:00 2001 From: Sri Laasya Nutheti Date: Fri, 21 Feb 2025 15:43:54 -0800 Subject: [PATCH 1/2] Add chromadb with openai for RAG --- agentstack/_tools/chroma_vectordb/__init__.py | 137 ++++++++++++++++++ agentstack/_tools/chroma_vectordb/config.json | 21 +++ 2 files changed, 158 insertions(+) create mode 100644 agentstack/_tools/chroma_vectordb/__init__.py create mode 100644 agentstack/_tools/chroma_vectordb/config.json diff --git a/agentstack/_tools/chroma_vectordb/__init__.py b/agentstack/_tools/chroma_vectordb/__init__.py new file mode 100644 index 00000000..af16e843 --- /dev/null +++ b/agentstack/_tools/chroma_vectordb/__init__.py @@ -0,0 +1,137 @@ +import os +import chromadb +from chromadb.config import Settings +from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction +from typing import List, Dict, Any, Optional +from dotenv import load_dotenv + +# Load environment variables +load_dotenv() + +def create_collection( + collection_name: str = "default_collection", + persist_directory: str = "chroma_db" +) -> str: + """ + Creates a new Chroma collection with OpenAI embeddings. + + Args: + collection_name: Name for the collection + persist_directory: Directory to store the database + + Returns: + str: Success message with collection details + """ + openai_api_key = os.getenv("OPENAI_API_KEY") + if not openai_api_key: + raise ValueError("OPENAI_API_KEY environment variable is not set!") + + client = chromadb.Client(Settings( + persist_directory=persist_directory + )) + + embedding_function = OpenAIEmbeddingFunction( + model_name="text-embedding-ada-002", + api_key=openai_api_key + ) + + collection = client.get_or_create_collection( + name=collection_name, + embedding_function=embedding_function + ) + + return f"Created collection '{collection_name}' in {persist_directory}" + +def add_documents( + collection_name: str, + documents: List[Dict[str, str]], + persist_directory: str = "chroma_db" +) -> str: + """ + Adds documents to a Chroma collection. + + Args: + collection_name: Name of the collection to add documents to + documents: List of documents, each with "content" and "url" keys + persist_directory: Directory where the database is stored + + Returns: + str: Success message with number of documents added + """ + openai_api_key = os.getenv("OPENAI_API_KEY") + if not openai_api_key: + raise ValueError("OPENAI_API_KEY environment variable is not set!") + + client = chromadb.Client(Settings( + persist_directory=persist_directory + )) + + embedding_function = OpenAIEmbeddingFunction( + model_name="text-embedding-ada-002", + api_key=openai_api_key + ) + + collection = client.get_collection( + name=collection_name, + embedding_function=embedding_function + ) + + docs = [] + metadatas = [] + ids = [] + + for i, doc in enumerate(documents): + docs.append(doc.get("content", "")) + metadatas.append({"url": doc.get("url", "")}) + ids.append(f"doc_{i}") + + collection.add( + documents=docs, + metadatas=metadatas, + ids=ids + ) + + return f"Added {len(documents)} documents to collection '{collection_name}'" + +def query_collection( + collection_name: str, + query_text: str, + n_results: int = 3, + persist_directory: str = "chroma_db" +) -> str: + """ + Query a Chroma collection using natural language. + + Args: + collection_name: Name of the collection to query + query_text: The search query + n_results: Number of results to return + persist_directory: Directory where the database is stored + + Returns: + str: Query results including document content and metadata + """ + openai_api_key = os.getenv("OPENAI_API_KEY") + if not openai_api_key: + raise ValueError("OPENAI_API_KEY environment variable is not set!") + + client = chromadb.Client(Settings( + persist_directory=persist_directory + )) + + embedding_function = OpenAIEmbeddingFunction( + model_name="text-embedding-ada-002", + api_key=openai_api_key + ) + + collection = client.get_collection( + name=collection_name, + embedding_function=embedding_function + ) + + results = collection.query( + query_texts=[query_text], + n_results=n_results + ) + + return str(results) diff --git a/agentstack/_tools/chroma_vectordb/config.json b/agentstack/_tools/chroma_vectordb/config.json new file mode 100644 index 00000000..5302f281 --- /dev/null +++ b/agentstack/_tools/chroma_vectordb/config.json @@ -0,0 +1,21 @@ +{ + "name": "chroma_vectordb", + "url": "https://www.trychroma.com/", + "category": "vector-store", + "env": { + "OPENAI_API_KEY": null + }, + "dependencies": [ + "chromadb>=0.4.0", + "openai>=1.0.0", + "python-dotenv>=1.0.0", + "pytest>=7.0.0", + "pytest-mock>=3.10.0" + ], + "tools": [ + "create_collection", + "add_documents", + "query_collection" + ], + "cta": "Make sure you have an OpenAI API key set in your environment variables." +} \ No newline at end of file From 7b87860f5fd8849d6f322e6e17add89929c2128b Mon Sep 17 00:00:00 2001 From: Sri Laasya Nutheti Date: Fri, 21 Feb 2025 16:23:01 -0800 Subject: [PATCH 2/2] Resolved mypy attr error --- agentstack/_tools/chroma_vectordb/__init__.py | 55 ++++++++++++++----- 1 file changed, 40 insertions(+), 15 deletions(-) diff --git a/agentstack/_tools/chroma_vectordb/__init__.py b/agentstack/_tools/chroma_vectordb/__init__.py index af16e843..bc444187 100644 --- a/agentstack/_tools/chroma_vectordb/__init__.py +++ b/agentstack/_tools/chroma_vectordb/__init__.py @@ -1,9 +1,11 @@ import os import chromadb from chromadb.config import Settings -from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction -from typing import List, Dict, Any, Optional +from chromadb.utils.embedding_functions import EmbeddingFunction +from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction +from typing import List, Dict, Any, Optional, Sequence, Mapping, Union from dotenv import load_dotenv +from typing import cast # Load environment variables load_dotenv() @@ -30,9 +32,12 @@ def create_collection( persist_directory=persist_directory )) - embedding_function = OpenAIEmbeddingFunction( - model_name="text-embedding-ada-002", - api_key=openai_api_key + embedding_function = cast( + EmbeddingFunction, + OpenAIEmbeddingFunction( + model_name="text-embedding-ada-002", + api_key=openai_api_key + ) ) collection = client.get_or_create_collection( @@ -66,9 +71,12 @@ def add_documents( persist_directory=persist_directory )) - embedding_function = OpenAIEmbeddingFunction( - model_name="text-embedding-ada-002", - api_key=openai_api_key + embedding_function = cast( + EmbeddingFunction, + OpenAIEmbeddingFunction( + model_name="text-embedding-ada-002", + api_key=openai_api_key + ) ) collection = client.get_collection( @@ -76,9 +84,9 @@ def add_documents( embedding_function=embedding_function ) - docs = [] - metadatas = [] - ids = [] + docs: List[str] = [] + metadatas: List[Mapping[str, Union[str, int, float, bool]]] = [] + ids: List[str] = [] for i, doc in enumerate(documents): docs.append(doc.get("content", "")) @@ -119,9 +127,12 @@ def query_collection( persist_directory=persist_directory )) - embedding_function = OpenAIEmbeddingFunction( - model_name="text-embedding-ada-002", - api_key=openai_api_key + embedding_function = cast( + EmbeddingFunction, + OpenAIEmbeddingFunction( + model_name="text-embedding-ada-002", + api_key=openai_api_key + ) ) collection = client.get_collection( @@ -134,4 +145,18 @@ def query_collection( n_results=n_results ) - return str(results) + # Format results nicely + formatted_results = [] + for i, (doc, metadata, distance) in enumerate(zip( + results['documents'][0], # type: ignore + results['metadatas'][0], # type: ignore + results['distances'][0] # type: ignore + )): + formatted_results.append( + f"Result {i+1}:\n" + f"Content: {doc}\n" + f"URL: {metadata['url']}\n" + f"Relevance Score: {1 - float(distance):.2f}\n" + ) + + return "\n".join(formatted_results)