Skip to content

Commit

Permalink
feat: change search response schemas
Browse files Browse the repository at this point in the history
  • Loading branch information
leoguillaumegouv committed Sep 19, 2024
1 parent 9de4aac commit 4fe3e0a
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 15 deletions.
7 changes: 3 additions & 4 deletions app/endpoints/search.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
from fastapi import APIRouter, Security

from app.helpers import VectorStore
from app.schemas.chunks import Chunks
from app.schemas.search import SearchRequest
from app.schemas.search import SearchRequest, Searches
from app.utils.lifespan import clients
from app.utils.security import check_api_key

router = APIRouter()


@router.post("/search")
async def search(request: SearchRequest, user: str = Security(check_api_key)) -> Chunks:
async def search(request: SearchRequest, user: str = Security(check_api_key)) -> Searches:
"""
Similarity search for chunks in the vector store.
Expand All @@ -27,4 +26,4 @@ async def search(request: SearchRequest, user: str = Security(check_api_key)) ->
prompt=request.prompt, model=request.model, collection_names=request.collections, k=request.k, score_threshold=request.score_threshold
)

return Chunks(data=data)
return Searches(data=data)
10 changes: 7 additions & 3 deletions app/helpers/_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from app.schemas.chunks import Chunk
from app.schemas.collections import CollectionMetadata
from app.schemas.config import EMBEDDINGS_MODEL_TYPE, METADATA_COLLECTION, PRIVATE_COLLECTION_TYPE, PUBLIC_COLLECTION_TYPE
from app.schemas.search import Search


class VectorStore:
Expand Down Expand Up @@ -64,7 +65,7 @@ def search(
k: Optional[int] = 4,
score_threshold: Optional[float] = None,
filter: Optional[Filter] = None,
) -> List[Chunk]:
) -> List[Search]:
response = self.models[model].embeddings.create(input=[prompt], model=model)
vector = response.data[0].embedding

Expand All @@ -88,9 +89,12 @@ def search(

# sort by similarity score and get top k
chunks = sorted(chunks, key=lambda x: x.score, reverse=True)[:k]
chunks = [Chunk(id=chunk.id, content=chunk.payload["page_content"], metadata=chunk.payload["metadata"]) for chunk in chunks]
data = [
Search(score=chunk.score, chunk=Chunk(id=chunk.id, content=chunk.payload["page_content"], metadata=chunk.payload["metadata"]))
for chunk in chunks
]

return chunks
return data

def get_collection_metadata(self, collection_names: List[str] = [], type: str = "all", errors: str = "raise") -> List[CollectionMetadata]:
"""
Expand Down
14 changes: 13 additions & 1 deletion app/schemas/search.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import List, Optional
from typing import List, Literal, Optional

from pydantic import BaseModel, Field, field_validator

from app.schemas.chunks import Chunk


class SearchRequest(BaseModel):
prompt: str
Expand All @@ -15,3 +17,13 @@ def blank_string(value):
if value.strip() == "":
raise ValueError("Prompt cannot be empty")
return value


# One search hit: a retrieved chunk paired with its similarity score.
class Search(BaseModel):
# Similarity score of the hit; VectorStore.search sorts hits on this value, highest first.
score: float
# The matched chunk (built from the hit's id, page content and metadata).
chunk: Chunk


# Response body returned by the /search endpoint: a list envelope around the individual hits.
class Searches(BaseModel):
# Constant type tag; "list" is the only accepted value and also the default.
object: Literal["list"] = "list"
# The search hits, one Search entry per retrieved chunk.
data: List[Search]
11 changes: 6 additions & 5 deletions app/tests/test_search.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import os
import logging
import os

import pytest
import wget

from app.schemas.chunks import Chunk, Chunks
from app.schemas.config import EMBEDDINGS_MODEL_TYPE
from app.schemas.search import Search, Searches


@pytest.fixture(scope="function")
Expand Down Expand Up @@ -66,9 +67,9 @@ def test_search_response_status_code(self, args, session, setup):
response = session.post(f"{args['base_url']}/search", json=data)
assert response.status_code == 200, f"error: search request ({response.status_code} - {response.text})"

chunks = Chunks(**response.json())
assert isinstance(chunks, Chunks)
assert all(isinstance(chunk, Chunk) for chunk in chunks.data)
searches = Searches(**response.json())
assert isinstance(searches, Searches)
assert all(isinstance(search, Search) for search in searches.data)

def test_search_with_score_threshold(self, args, session, setup):
"""Test search with a score threshold."""
Expand Down
4 changes: 2 additions & 2 deletions app/tools/_baserag.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ async def get_prompt(
vectorstore = VectorStore(clients=self.clients, user=request["user"])
prompt = request["messages"][-1]["content"]

chunks = vectorstore.search(model=embeddings_model, prompt=prompt, collection_names=collections, k=k)

results = vectorstore.search(model=embeddings_model, prompt=prompt, collection_names=collections, k=k)
chunks = [result.chunk for result in results]
metadata = {"chunks": [chunk.metadata for chunk in chunks]}
documents = "\n\n".join([chunk.content for chunk in chunks])
prompt = prompt_template.format(documents=documents, prompt=prompt)
Expand Down

0 comments on commit 4fe3e0a

Please sign in to comment.