This repository has been archived by the owner on Jul 12, 2024. It is now read-only.

Add llm cache APIs #99

Open
wants to merge 3 commits into main
1 change: 1 addition & 0 deletions genai_stack/constant.py
@@ -7,4 +7,5 @@
VECTORDB = "/vectordb"
ETL = "/etl"
PROMPT_ENGINE = "/prompt-engine"
LLM_CACHE = "/llm-cache"
MODEL = "/model"
19 changes: 19 additions & 0 deletions genai_stack/genai_server/models/cache_models.py
@@ -0,0 +1,19 @@
from typing import Optional

from pydantic import BaseModel


class BaseCacheRequestModel(BaseModel):
    session_id: int
    query: str
    metadata: Optional[dict] = None


class GetCacheRequestModel(BaseCacheRequestModel):
    pass


class SetCacheRequestModel(BaseCacheRequestModel):
    response: str


class CacheResponseModel(BaseCacheRequestModel):
    response: str
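
All four models share BaseCacheRequestModel, so every cache operation is keyed by a session id, the query text, and optional metadata. A quick sketch of how the models behave (illustrative values, not part of the diff):

    from genai_stack.genai_server.models.cache_models import SetCacheRequestModel

    # Illustrative payload: metadata is optional and defaults to None.
    model = SetCacheRequestModel(
        session_id=1,
        query="Where is sunil from ?",
        response="Sunil is from Hyderabad.",
    )
    assert model.metadata is None
    assert model.response == "Sunil is from Hyderabad."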
22 changes: 22 additions & 0 deletions genai_stack/genai_server/routers/cache_routes.py
@@ -0,0 +1,22 @@
from fastapi import APIRouter

from genai_stack.constant import API, LLM_CACHE
from genai_stack.genai_server.settings.settings import settings
from genai_stack.genai_server.models.cache_models import (
    GetCacheRequestModel, SetCacheRequestModel, CacheResponseModel
)
from genai_stack.genai_server.services.cache_service import LLMCacheService

service = LLMCacheService(store=settings.STORE)

router = APIRouter(prefix=API + LLM_CACHE, tags=["llm_cache"])


@router.get("/get-cache")
def get_cache(data: GetCacheRequestModel) -> CacheResponseModel:
    return service.get_cache(data=data)


@router.post("/set-cache")
def set_cache(data: SetCacheRequestModel) -> CacheResponseModel:
    return service.set_cache(data=data)
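
Worth noting: get_cache is declared as a GET route whose GetCacheRequestModel parameter FastAPI reads from the request body, and GET requests with bodies are not universally supported by clients and proxies. A minimal client sketch against these routes (assuming the server is running locally on port 5000, as in the tests below):

    import requests

    base_url = "http://127.0.0.1:5000/api/llm-cache"

    # Cache a response for session 1.
    requests.post(base_url + "/set-cache", json={
        "session_id": 1,
        "query": "Where is sunil from ?",
        "response": "Sunil is from Hyderabad.",
    })

    # Read it back; `requests` allows a JSON body on GET via the json= kwarg.
    cached = requests.get(base_url + "/get-cache", json={
        "session_id": 1,
        "query": "Where is sunil from ?",
    }).json()
    print(cached["response"])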
2 changes: 2 additions & 0 deletions genai_stack/genai_server/server.py
@@ -1,6 +1,7 @@
from fastapi import FastAPI

from genai_stack.genai_server.routers import (
    cache_routes,
    session_routes,
    retriever_routes,
    vectordb_routes,
@@ -26,6 +27,7 @@ def get_genai_server_app():
    app.include_router(retriever_routes.router)
    app.include_router(vectordb_routes.router)
    app.include_router(etl_routes.router)
    app.include_router(cache_routes.router)
    app.include_router(model_routes.router)

    return app
46 changes: 46 additions & 0 deletions genai_stack/genai_server/services/cache_service.py
@@ -0,0 +1,46 @@
from fastapi import HTTPException
from sqlalchemy.orm import Session

from genai_stack.genai_platform.services import BaseService
from genai_stack.genai_server.models.cache_models import GetCacheRequestModel, SetCacheRequestModel, CacheResponseModel
from genai_stack.genai_server.settings.config import stack_config
from genai_stack.genai_server.utils import get_current_stack
from genai_stack.genai_store.schemas import StackSessionSchema


class LLMCacheService(BaseService):

    def get_cache(self, data: GetCacheRequestModel) -> CacheResponseModel:
        with Session(self.engine) as session:
            stack_session = session.get(StackSessionSchema, data.session_id)
            if stack_session is None:
                raise HTTPException(status_code=404, detail=f"Session {data.session_id} not found")
            stack = get_current_stack(config=stack_config, session=stack_session)
            response = stack.llm_cache.get_cache(
                query=data.query,
                metadata=data.metadata
            )
            return CacheResponseModel(
                session_id=data.session_id,
                query=data.query,
                metadata=data.metadata,
                response=response
            )

    def set_cache(self, data: SetCacheRequestModel) -> CacheResponseModel:
        with Session(self.engine) as session:
            stack_session = session.get(StackSessionSchema, data.session_id)
            if stack_session is None:
                raise HTTPException(status_code=404, detail=f"Session {data.session_id} not found")
            stack = get_current_stack(config=stack_config, session=stack_session)
            stack.llm_cache.set_cache(
                query=data.query,
                response=data.response,
                metadata=data.metadata
            )
            return CacheResponseModel(
                session_id=data.session_id,
                query=data.query,
                metadata=data.metadata,
                response=data.response
            )
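
One edge case: CacheResponseModel.response is typed str, so if stack.llm_cache.get_cache returns None on a cache miss, constructing the response model raises a validation error and surfaces as a server error rather than a clean miss. A possible guard, assuming a None return signals a miss (this diff does not confirm that contract):

    response = stack.llm_cache.get_cache(query=data.query, metadata=data.metadata)
    if response is None:
        # Hypothetical miss handling; the actual llm_cache return contract is not shown in this diff.
        raise HTTPException(status_code=404, detail="No cached response for this query")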
3 changes: 2 additions & 1 deletion genai_stack/genai_server/utils/utils.py
@@ -56,7 +56,7 @@ def get_component_class(component_name: str, class_name: str):

# Creating indexes provided by user
def create_indexes(stack, stack_id: int, session_id: int) -> dict:
-    components = [StackComponentType.VECTOR_DB, StackComponentType.MEMORY]
+    components = [StackComponentType.VECTOR_DB, StackComponentType.MEMORY, StackComponentType.CACHE]

    meta_data = {}
    for component in components:
@@ -84,6 +84,7 @@ def get_current_stack(config: dict, session=None, default_session: bool = True):
        if (
            component_name == StackComponentType.VECTOR_DB.value
            or component_name == StackComponentType.MEMORY.value
            or component_name == StackComponentType.CACHE.value
        ):
            configurations["index_name"] = session.meta_data[component_name]["index_name"]
99 changes: 99 additions & 0 deletions tests/api/test_genai_server/test_cache.py
@@ -0,0 +1,99 @@
#!/usr/bin/env python

"""Tests for `genai_server`."""
import unittest
import requests


class TestLLMCacheAPIs(unittest.TestCase):

    def setUp(self) -> None:
        self.base_url = "http://127.0.0.1:5000/api/llm-cache"

    def test_set_cache(self):
        response = requests.post(
            url=self.base_url + "/set-cache",
            json={
                "session_id": 1,
                "query": "Where is sunil from ?",
                "response": "Sunil is from Hyderabad.",
                "metadata": {"source": "/path", "page": 1}
            }
        )
        assert response.status_code == 200
        assert response.json()
        data = response.json()
        assert "query" in data.keys()
        assert "metadata" in data.keys()
        assert "response" in data.keys()

    def test_get_cache(self):
        response = requests.get(
            url=self.base_url + "/get-cache",
            json={
                "session_id": 1,
                "query": "Where is sunil from ?"
            }
        )

        assert response.status_code == 200
        assert response.json()
        data = response.json()
        assert "query" in data.keys()
        assert "metadata" in data.keys()
        assert "response" in data.keys()

    def test_get_and_set(self):
        query = "Where is sunil from ?"
        metadata = {"source": "/path", "page": 1}
        output = "Sunil is from Hyderabad."
        response = requests.post(
            url=self.base_url + "/set-cache",
            json={
                "session_id": 1,
                "query": query,
                "response": output,
                "metadata": metadata
            }
        )
        assert response.status_code == 200
        assert response.json()
        data = response.json()
        assert "query" in data.keys() and data.get("query") == query
        assert "metadata" in data.keys() and data.get("metadata") == metadata
        assert "response" in data.keys() and data.get("response") == output

        response = requests.get(
            url=self.base_url + "/get-cache",
            json={
                "session_id": 1,
                "query": query
            }
        )

        assert response.status_code == 200
        assert response.json()
        data = response.json()
        assert "query" in data.keys() and data.get("query") == query
        assert "response" in data.keys() and data.get("response") == output

        response = requests.get(
            url=self.base_url + "/get-cache",
            json={
                "session_id": 1,
                "query": "Where is sunil from ?",
                "metadata": {"source": "/pathdiff", "page": 1}
            }
        )
        assert response.status_code != 200

        response = requests.get(
            url=self.base_url + "/get-cache",
            json={
                "session_id": 1,
                "query": "Where is sunil from ?",
                "metadata": metadata
            }
        )

        assert response.status_code == 200
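
These tests call a live server with `requests` rather than FastAPI's TestClient, so they assume a GenAI Stack server is already listening on 127.0.0.1:5000 with a session with id 1 created beforehand. A typical invocation under those assumptions:

    python -m unittest tests.api.test_genai_server.test_cache -v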