diff --git a/genai_stack/constant.py b/genai_stack/constant.py
index 0e7b405b..df420131 100644
--- a/genai_stack/constant.py
+++ b/genai_stack/constant.py
@@ -7,4 +7,5 @@
 VECTORDB = "/vectordb"
 ETL = "/etl"
 PROMPT_ENGINE = "/prompt-engine"
+LLM_CACHE = "/llm-cache"
 MODEL = "/model"
diff --git a/genai_stack/genai_server/models/cache_models.py b/genai_stack/genai_server/models/cache_models.py
new file mode 100644
index 00000000..5fe55e6d
--- /dev/null
+++ b/genai_stack/genai_server/models/cache_models.py
@@ -0,0 +1,21 @@
+from typing import Optional
+
+from pydantic import BaseModel
+
+
+class BaseCacheRequestModel(BaseModel):
+    session_id: int
+    query: str
+    metadata: Optional[dict] = None
+
+
+class GetCacheRequestModel(BaseCacheRequestModel):
+    pass
+
+
+class SetCacheRequestModel(BaseCacheRequestModel):
+    response: str
+
+
+class CacheResponseModel(BaseCacheRequestModel):
+    response: str
diff --git a/genai_stack/genai_server/routers/cache_routes.py b/genai_stack/genai_server/routers/cache_routes.py
new file mode 100644
index 00000000..aea02a0b
--- /dev/null
+++ b/genai_stack/genai_server/routers/cache_routes.py
@@ -0,0 +1,22 @@
+from fastapi import APIRouter
+
+from genai_stack.constant import API, LLM_CACHE
+from genai_stack.genai_server.settings.settings import settings
+from genai_stack.genai_server.models.cache_models import (
+    GetCacheRequestModel, SetCacheRequestModel, CacheResponseModel
+)
+from genai_stack.genai_server.services.cache_service import LLMCacheService
+
+service = LLMCacheService(store=settings.STORE)
+
+router = APIRouter(prefix=API + LLM_CACHE, tags=["llm_cache"])
+
+
+@router.get("/get-cache")
+def get_cache(data: GetCacheRequestModel) -> CacheResponseModel:
+    return service.get_cache(data=data)
+
+
+@router.post("/set-cache")
+def set_cache(data: SetCacheRequestModel) -> CacheResponseModel:
+    return service.set_cache(data=data)
diff --git a/genai_stack/genai_server/server.py b/genai_stack/genai_server/server.py
index d519736c..33ae6eb7 100644
--- a/genai_stack/genai_server/server.py
+++ b/genai_stack/genai_server/server.py
@@ -1,6 +1,7 @@
 from fastapi import FastAPI
 
 from genai_stack.genai_server.routers import (
+    cache_routes,
     session_routes,
     retriever_routes,
     vectordb_routes,
@@ -26,6 +27,7 @@ def get_genai_server_app():
     app.include_router(retriever_routes.router)
     app.include_router(vectordb_routes.router)
     app.include_router(etl_routes.router)
+    app.include_router(cache_routes.router)
     app.include_router(model_routes.router)
 
     return app
diff --git a/genai_stack/genai_server/services/cache_service.py b/genai_stack/genai_server/services/cache_service.py
new file mode 100644
index 00000000..864b29ce
--- /dev/null
+++ b/genai_stack/genai_server/services/cache_service.py
@@ -0,0 +1,46 @@
+from fastapi import HTTPException
+from sqlalchemy.orm import Session
+
+from genai_stack.genai_platform.services import BaseService
+from genai_stack.genai_server.models.cache_models import GetCacheRequestModel, SetCacheRequestModel, CacheResponseModel
+from genai_stack.genai_server.settings.config import stack_config
+from genai_stack.genai_server.utils import get_current_stack
+from genai_stack.genai_store.schemas import StackSessionSchema
+
+
+class LLMCacheService(BaseService):
+
+    def get_cache(self, data: GetCacheRequestModel) -> CacheResponseModel:
+        with Session(self.engine) as session:
+            stack_session = session.get(StackSessionSchema, data.session_id)
+            if stack_session is None:
+                raise HTTPException(status_code=404, detail=f"Session {data.session_id} not found")
+            stack = get_current_stack(config=stack_config, session=stack_session)
+            response = stack.llm_cache.get_cache(
+                query=data.query,
+                metadata=data.metadata
+            )
+            return CacheResponseModel(
+                session_id=data.session_id,
+                query=data.query,
+                metadata=data.metadata,
+                response=response
+            )
+
+    def set_cache(self, data: SetCacheRequestModel) -> CacheResponseModel:
+        with Session(self.engine) as session:
+            stack_session = session.get(StackSessionSchema, data.session_id)
+            if stack_session is None:
+                raise HTTPException(status_code=404, detail=f"Session {data.session_id} not found")
+            stack = get_current_stack(config=stack_config, session=stack_session)
+            stack.llm_cache.set_cache(
+                query=data.query,
+                response=data.response,
+                metadata=data.metadata
+            )
+            return CacheResponseModel(
+                session_id=data.session_id,
+                query=data.query,
+                metadata=data.metadata,
+                response=data.response
+            )
diff --git a/genai_stack/genai_server/utils/utils.py b/genai_stack/genai_server/utils/utils.py
index 678ccd9e..e5be4c44 100644
--- a/genai_stack/genai_server/utils/utils.py
+++ b/genai_stack/genai_server/utils/utils.py
@@ -56,7 +56,7 @@ def get_component_class(component_name: str, class_name: str):
 
 # Creating indexes provided by user
 def create_indexes(stack, stack_id: int, session_id: int) -> dict:
-    components = [StackComponentType.VECTOR_DB, StackComponentType.MEMORY]
+    components = [StackComponentType.VECTOR_DB, StackComponentType.MEMORY, StackComponentType.CACHE]
     meta_data = {}
 
     for component in components:
@@ -84,6 +84,7 @@ def get_current_stack(config: dict, session=None, default_session: bool = True):
         if (
             component_name == StackComponentType.VECTOR_DB.value
             or component_name == StackComponentType.MEMORY.value
+            or component_name == StackComponentType.CACHE.value
         ):
             configurations["index_name"] = session.meta_data[component_name]["index_name"]
 
diff --git a/tests/api/test_genai_server/test_cache.py b/tests/api/test_genai_server/test_cache.py
new file mode 100644
index 00000000..25f47aa9
--- /dev/null
+++ b/tests/api/test_genai_server/test_cache.py
@@ -0,0 +1,99 @@
+#!/usr/bin/env python
+
+"""Tests for `genai_server`."""
+import unittest
+import requests
+
+
+class TestLLMCacheAPIs(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.base_url = "http://127.0.0.1:5000/api/llm-cache"
+
+    def test_set_cache(self):
+        response = requests.post(
+            url=self.base_url + "/set-cache",
+            json={
+                "session_id": 1,
+                "query": "Where is sunil from ?",
+                "response": "Sunil is from Hyderabad.",
+                "metadata": {"source": "/path", "page": 1}
+            }
+        )
+        assert response.status_code == 200
+        assert response.json()
+        data = response.json()
+        assert "query" in data.keys()
+        assert "metadata" in data.keys()
+        assert "response" in data.keys()
+
+    def test_get_cache(self):
+        response = requests.get(
+            url=self.base_url + "/get-cache",
+            json={
+                "session_id": 1,
+                "query": "Where is sunil from ?"
+            }
+        )
+
+        assert response.status_code == 200
+        assert response.json()
+        data = response.json()
+        assert "query" in data.keys()
+        assert "metadata" in data.keys()
+        assert "response" in data.keys()
+
+    def test_get_and_set(self):
+        query = "Where is sunil from ?"
+        metadata = {"source": "/path", "page": 1}
+        output = "Sunil is from Hyderabad."
+        response = requests.post(
+            url=self.base_url + "/set-cache",
+            json={
+                "session_id": 1,
+                "query": query,
+                "response": output,
+                "metadata": metadata
+            }
+        )
+        assert response.status_code == 200
+        assert response.json()
+        data = response.json()
+        assert "query" in data.keys() and data.get("query") == query
+        assert "metadata" in data.keys() and data.get("metadata") == metadata
+        assert "response" in data.keys() and data.get("response") == output
+
+        response = requests.get(
+            url=self.base_url + "/get-cache",
+            json={
+                "session_id": 1,
+                "query": query
+            }
+        )
+
+        assert response.status_code == 200
+        assert response.json()
+        data = response.json()
+        assert "query" in data.keys() and data.get("query") == query
+        assert "response" in data.keys() and data.get("response") == output
+
+        response = requests.get(
+            url=self.base_url + "/get-cache",
+            json={
+                "session_id": 1,
+                "query": "Where is sunil from ?",
+                "metadata": {"source": "/pathdiff", "page": 1}
+            }
+        )
+        assert response.status_code != 200
+
+        response = requests.get(
+            url=self.base_url + "/get-cache",
+            json={
+                "session_id": 1,
+                "query": "Where is sunil from ?",
+                "metadata": metadata
+            }
+        )
+
+        assert response.status_code == 200
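Note: the tests above assume a server already running on 127.0.0.1:5000. For reference, the same set/get round trip can be sketched in-process with FastAPI's TestClient against the app factory added in server.py. This is a minimal sketch, not part of the patch; it assumes the store and stack configuration referenced by settings.STORE and stack_config are available in the environment, and uses the endpoint paths defined in cache_routes.py:

from fastapi.testclient import TestClient

from genai_stack.genai_server.server import get_genai_server_app

client = TestClient(get_genai_server_app())

# Cache a response for a query under session 1.
set_resp = client.post(
    "/api/llm-cache/set-cache",
    json={
        "session_id": 1,
        "query": "Where is sunil from ?",
        "response": "Sunil is from Hyderabad.",
        "metadata": {"source": "/path", "page": 1},
    },
)
assert set_resp.status_code == 200

# Look the same query up again. The GET route reads a JSON body,
# so the generic request() method is used to attach one.
get_resp = client.request(
    "GET",
    "/api/llm-cache/get-cache",
    json={"session_id": 1, "query": "Where is sunil from ?"},
)
assert get_resp.status_code == 200
assert get_resp.json()["response"] == "Sunil is from Hyderabad."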