From 4fe3e0a439e19f57e123e132b0d9bb8cd6dc43b8 Mon Sep 17 00:00:00 2001
From: leoguillaume
Date: Thu, 19 Sep 2024 18:27:29 +0200
Subject: [PATCH] feat: change search response schemas

---
 app/endpoints/search.py     |  7 +++----
 app/helpers/_vectorstore.py | 10 +++++++---
 app/schemas/search.py       | 14 +++++++++++++-
 app/tests/test_search.py    | 11 ++++++-----
 app/tools/_baserag.py       |  4 ++--
 5 files changed, 31 insertions(+), 15 deletions(-)

diff --git a/app/endpoints/search.py b/app/endpoints/search.py
index ac4fa71c..a5417601 100644
--- a/app/endpoints/search.py
+++ b/app/endpoints/search.py
@@ -1,8 +1,7 @@
 from fastapi import APIRouter, Security
 
 from app.helpers import VectorStore
-from app.schemas.chunks import Chunks
-from app.schemas.search import SearchRequest
+from app.schemas.search import SearchRequest, Searches
 from app.utils.lifespan import clients
 from app.utils.security import check_api_key
 
@@ -10,7 +9,7 @@
 
 
 @router.post("/search")
-async def search(request: SearchRequest, user: str = Security(check_api_key)) -> Chunks:
+async def search(request: SearchRequest, user: str = Security(check_api_key)) -> Searches:
     """
     Similarity search for chunks in the vector store.
 
@@ -27,4 +26,4 @@ async def search(request: SearchRequest, user: str = Security(check_api_key)) ->
         prompt=request.prompt, model=request.model, collection_names=request.collections, k=request.k, score_threshold=request.score_threshold
     )
 
-    return Chunks(data=data)
+    return Searches(data=data)
diff --git a/app/helpers/_vectorstore.py b/app/helpers/_vectorstore.py
index 83f3ff0e..5cd72b5e 100644
--- a/app/helpers/_vectorstore.py
+++ b/app/helpers/_vectorstore.py
@@ -9,6 +9,7 @@
 from app.schemas.chunks import Chunk
 from app.schemas.collections import CollectionMetadata
 from app.schemas.config import EMBEDDINGS_MODEL_TYPE, METADATA_COLLECTION, PRIVATE_COLLECTION_TYPE, PUBLIC_COLLECTION_TYPE
+from app.schemas.search import Search
 
 
 class VectorStore:
@@ -64,7 +65,7 @@ def search(
         k: Optional[int] = 4,
         score_threshold: Optional[float] = None,
         filter: Optional[Filter] = None,
-    ) -> List[Chunk]:
+    ) -> List[Search]:
         response = self.models[model].embeddings.create(input=[prompt], model=model)
         vector = response.data[0].embedding
 
@@ -88,9 +89,12 @@
 
         # sort by similarity score and get top k
         chunks = sorted(chunks, key=lambda x: x.score, reverse=True)[:k]
-        chunks = [Chunk(id=chunk.id, content=chunk.payload["page_content"], metadata=chunk.payload["metadata"]) for chunk in chunks]
+        data = [
+            Search(score=chunk.score, chunk=Chunk(id=chunk.id, content=chunk.payload["page_content"], metadata=chunk.payload["metadata"]))
+            for chunk in chunks
+        ]
 
-        return chunks
+        return data
 
     def get_collection_metadata(self, collection_names: List[str] = [], type: str = "all", errors: str = "raise") -> List[CollectionMetadata]:
         """
diff --git a/app/schemas/search.py b/app/schemas/search.py
index d44e1fb7..4f20bd5a 100644
--- a/app/schemas/search.py
+++ b/app/schemas/search.py
@@ -1,7 +1,9 @@
-from typing import List, Optional
+from typing import List, Literal, Optional
 
 from pydantic import BaseModel, Field, field_validator
 
+from app.schemas.chunks import Chunk
+
 
 class SearchRequest(BaseModel):
     prompt: str
@@ -15,3 +17,13 @@ def blank_string(value):
         if value.strip() == "":
             raise ValueError("Prompt cannot be empty")
         return value
+
+
+class Search(BaseModel):
+    score: float
+    chunk: Chunk
+
+
+class Searches(BaseModel):
+    object: Literal["list"] = "list"
+    data: List[Search]
diff --git a/app/tests/test_search.py b/app/tests/test_search.py
index 439f1d45..e1ec2a4f 100644
--- a/app/tests/test_search.py
+++ b/app/tests/test_search.py
@@ -1,10 +1,11 @@
-import os
 import logging
+import os
+
 import pytest
 import wget
 
-from app.schemas.chunks import Chunk, Chunks
 from app.schemas.config import EMBEDDINGS_MODEL_TYPE
+from app.schemas.search import Search, Searches
 
 
 @pytest.fixture(scope="function")
@@ -66,9 +67,9 @@ def test_search_response_status_code(self, args, session, setup):
         response = session.post(f"{args['base_url']}/search", json=data)
         assert response.status_code == 200, f"error: search request ({response.status_code} - {response.text})"
 
-        chunks = Chunks(**response.json())
-        assert isinstance(chunks, Chunks)
-        assert all(isinstance(chunk, Chunk) for chunk in chunks.data)
+        searches = Searches(**response.json())
+        assert isinstance(searches, Searches)
+        assert all(isinstance(search, Search) for search in searches.data)
 
     def test_search_with_score_threshold(self, args, session, setup):
         """Test search with a score threshold."""
diff --git a/app/tools/_baserag.py b/app/tools/_baserag.py
index a2052a3e..3fcf1f21 100644
--- a/app/tools/_baserag.py
+++ b/app/tools/_baserag.py
@@ -39,8 +39,8 @@ async def get_prompt(
         vectorstore = VectorStore(clients=self.clients, user=request["user"])
         prompt = request["messages"][-1]["content"]
 
-        chunks = vectorstore.search(model=embeddings_model, prompt=prompt, collection_names=collections, k=k)
-
+        results = vectorstore.search(model=embeddings_model, prompt=prompt, collection_names=collections, k=k)
+        chunks = [result.chunk for result in results]
         metadata = {"chunks": [chunk.metadata for chunk in chunks]}
         documents = "\n\n".join([chunk.content for chunk in chunks])
         prompt = prompt_template.format(documents=documents, prompt=prompt)
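
For reference, a minimal sketch of how a client might consume the new Searches response shape introduced by this patch. The /search path and the object, data, score, and chunk fields come from the diff above; the base URL, bearer-style API key header, model name, collection name, and prompt are illustrative assumptions, not values taken from the repository.

    import requests

    # Assumed deployment details (not specified in the patch): base URL and bearer auth.
    BASE_URL = "http://localhost:8000"
    API_KEY = "sk-example"

    payload = {
        "prompt": "What is the capital of France?",  # example query
        "model": "my-embeddings-model",              # assumed embeddings model id
        "collections": ["my-collection"],            # assumed collection name
        "k": 4,
    }

    response = requests.post(
        f"{BASE_URL}/search",
        json=payload,
        headers={"Authorization": f"Bearer {API_KEY}"},  # assumed auth scheme
    )
    response.raise_for_status()

    # Searches wraps each Chunk in a Search object that carries its similarity score.
    body = response.json()
    assert body["object"] == "list"
    for search in body["data"]:
        chunk = search["chunk"]
        print(f"{search['score']:.3f}  {chunk['id']}  {chunk['content'][:80]}")

Compared with the previous Chunks response, callers now read the chunk from search["chunk"] and get the similarity score alongside it, as exercised in the updated test_search.py.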