diff --git a/app/endpoints/search.py b/app/endpoints/search.py new file mode 100644 index 00000000..275d4b0b --- /dev/null +++ b/app/endpoints/search.py @@ -0,0 +1,28 @@ +from fastapi import APIRouter, Security + +from app.helpers import VectorStore +from app.schemas.chunks import Chunks +from app.schemas.search import SearchRequest +from app.utils.lifespan import clients +from app.utils.security import check_api_key + +router = APIRouter() + + +@router.post("/search") +async def search(request: SearchRequest, user: str = Security(check_api_key)) -> Chunks: + """ + Similarity search for chunks in the vector store. + + Parameters: + request (SearchRequest): The search request. + user (str): The user. + + Returns: + Chunks: The chunks. + """ + + vectorstore = VectorStore(clients=clients, user=user) + data = vectorstore.search(prompt=request.prompt, collection_names=request.collections, k=request.k, score_threshold=request.score_threshold) + + return Chunks(data=data) diff --git a/app/helpers/_vectorstore.py b/app/helpers/_vectorstore.py index 107eb545..9b15287d 100644 --- a/app/helpers/_vectorstore.py +++ b/app/helpers/_vectorstore.py @@ -6,7 +6,8 @@ from qdrant_client.http.models import Distance, FieldCondition, Filter, MatchAny, PointIdsList, PointStruct, VectorParams from app.schemas.chunks import Chunk -from app.schemas.collections import CollectionMetadata, Document +from langchain.docstore.document import Document +from app.schemas.collections import CollectionMetadata from app.schemas.config import EMBEDDINGS_MODEL_TYPE, METADATA_COLLECTION, PRIVATE_COLLECTION_TYPE, PUBLIC_COLLECTION_TYPE @@ -23,7 +24,7 @@ def from_documents(self, documents: List[Document], model: str, collection_name: Add documents to a collection. Parameters: - documents (List[Document]): A list of Document objects to add to the collection. + documents (List[Document]): A list of Langchain Document objects to add to the collection. model (str): The model to use for embeddings. collection_name (str): The name of the collection to add the documents to. """ @@ -63,11 +64,11 @@ def search( k: Optional[int] = 4, score_threshold: Optional[float] = None, filter: Optional[Filter] = None, - ) -> List[Document]: + ) -> List[Chunk]: response = self.models[model].embeddings.create(input=[prompt], model=model) vector = response.data[0].embedding - documents = [] + chunks = [] collections = self.get_collection_metadata(collection_names=collection_names) for collection in collections: if collection.model != model: @@ -81,21 +82,20 @@ def search( with_payload=True, query_filter=filter, ) + for i, result in enumerate(results): + results[i] = result.model_dump() + results[i]["collection"] = collection.name - documents.extend(results) + chunks.extend(results) # sort by similarity score and get top k - documents = sorted(documents, key=lambda x: x.score, reverse=True)[:k] - documents = [ - Document( - id=document.id, - page_content=document.payload["page_content"], - metadata=document.payload["metadata"], - ) - for document in documents + chunks = sorted(chunks, key=lambda x: x["score"], reverse=True)[:k] + chunks = [ + Chunk(id=chunk["id"], collection=chunk["collection"], content=chunk["payload"]["page_content"], metadata=chunk["payload"]["metadata"]) + for chunk in chunks ] - return documents + return chunks def get_collection_metadata(self, collection_names: List[str] = [], type: str = "all", errors: str = "raise") -> List[CollectionMetadata]: """ @@ -258,15 +258,9 @@ def get_chunks(self, collection_name: str, filter: Optional[Filter] = None) -> L scroll_filter=filter, limit=100, # @TODO: add pagination )[0] - data = list() - for chunk in chunks: - data.append( - Chunk( - collection=collection_name, - id=chunk.id, - metadata=chunk.payload["metadata"], - content=chunk.payload["page_content"], - ) - ) + chunks = [ + Chunk(collection=collection_name, id=chunk.id, metadata=chunk.payload["metadata"], content=chunk.payload["page_content"]) + for chunk in chunks + ] - return data + return chunks diff --git a/app/main.py b/app/main.py index dd1df39b..46de908c 100644 --- a/app/main.py +++ b/app/main.py @@ -1,6 +1,6 @@ from fastapi import FastAPI, Response, Security -from app.endpoints import chat, chunks, collections, completions, embeddings, files, models, tools +from app.endpoints import chat, chunks, collections, completions, embeddings, files, models, search, tools from app.utils.config import APP_CONTACT_EMAIL, APP_CONTACT_URL, APP_DESCRIPTION, APP_VERSION from app.utils.lifespan import lifespan from app.utils.security import check_api_key @@ -31,4 +31,5 @@ def health(user: str = Security(check_api_key)): app.include_router(collections.router, tags=["Collections"], prefix="/v1") app.include_router(chunks.router, tags=["Chunks"], prefix="/v1") app.include_router(files.router, tags=["Files"], prefix="/v1") +app.include_router(search.router, tags=["Search"], prefix="/v1") app.include_router(tools.router, tags=["Tools"], prefix="/v1") diff --git a/app/schemas/collections.py b/app/schemas/collections.py index aa47d2ab..f1e10166 100644 --- a/app/schemas/collections.py +++ b/app/schemas/collections.py @@ -1,9 +1,8 @@ -from typing import Literal, List, Optional, Dict, Any -from uuid import UUID +from typing import List, Literal, Optional from pydantic import BaseModel -from app.schemas.config import PUBLIC_COLLECTION_TYPE, PRIVATE_COLLECTION_TYPE +from app.schemas.config import PRIVATE_COLLECTION_TYPE, PUBLIC_COLLECTION_TYPE class Collection(BaseModel): @@ -22,12 +21,6 @@ class Collections(BaseModel): data: List[Collection] -class Document(BaseModel): - id: UUID - page_content: str - metadata: Dict[str, Any] - - class CollectionMetadata(BaseModel): id: str name: Optional[str] = None diff --git a/app/schemas/search.py b/app/schemas/search.py new file mode 100644 index 00000000..45b90d69 --- /dev/null +++ b/app/schemas/search.py @@ -0,0 +1,10 @@ +from typing import List, Optional + +from pydantic import BaseModel + + +class SearchRequest(BaseModel): + prompt: str + collections: List[str] + k: int + score_threshold: Optional[float] = None diff --git a/app/tools/_baserag.py b/app/tools/_baserag.py index af635645..a2052a3e 100644 --- a/app/tools/_baserag.py +++ b/app/tools/_baserag.py @@ -39,10 +39,10 @@ async def get_prompt( vectorstore = VectorStore(clients=self.clients, user=request["user"]) prompt = request["messages"][-1]["content"] - documents = vectorstore.search(model=embeddings_model, prompt=prompt, collection_names=collections, k=k) + chunks = vectorstore.search(model=embeddings_model, prompt=prompt, collection_names=collections, k=k) - metadata = {"chunks": [document.metadata for document in documents]} - documents = "\n\n".join([document.page_content for document in documents]) + metadata = {"chunks": [chunk.metadata for chunk in chunks]} + documents = "\n\n".join([chunk.content for chunk in chunks]) prompt = prompt_template.format(documents=documents, prompt=prompt) return ToolOutput(prompt=prompt, metadata=metadata)