Skip to content
This repository has been archived by the owner on Jul 12, 2024. It is now read-only.

Add prompt engine api #85

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Add prompts table

Revision ID: 7dd5490b4f14
Revises: f5cbe001454d
Create Date: 2023-10-05 08:58:16.091864

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '7dd5490b4f14'  # this migration
down_revision: Union[str, None] = 'f5cbe001454d'  # parent migration this one applies on top of
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Create the ``prompts`` table: one stored template per (session, type) pair."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('prompts',
        sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
        # Owning stack session; rows are removed when the session is deleted.
        sa.Column('stack_session', sa.Integer(), nullable=False),
        sa.Column('type', sa.Enum('SIMPLE_CHAT_PROMPT', 'CONTEXTUAL_CHAT_PROMPT', 'CONTEXTUAL_QA_PROMPT', name='prompttypeenum'), nullable=True),
        sa.Column('template', sa.String(), nullable=True),
        sa.Column('meta_data', sa.JSON(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('modified_at', sa.DateTime(), nullable=True),
        sa.ForeignKeyConstraint(['stack_session'], ['stack_sessions.id'], ondelete='CASCADE'),
        sa.PrimaryKeyConstraint('id')
    )
    # ### end Alembic commands ###


def downgrade() -> None:
    """Drop the ``prompts`` table, reverting :func:`upgrade`."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table('prompts')
    # drop_table does not remove the ENUM type on PostgreSQL, which would make
    # a later re-upgrade fail with "type prompttypeenum already exists".
    # checkfirst=True keeps this a no-op on backends without native enums.
    sa.Enum(name='prompttypeenum').drop(op.get_bind(), checkfirst=True)
    # ### end Alembic commands ###
26 changes: 26 additions & 0 deletions genai_stack/genai_server/models/prompt_engine_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Optional

from langchain.prompts import PromptTemplate
from pydantic import BaseModel

from genai_stack.prompt_engine.utils import PromptTypeEnum


class PromptEngineBaseModel(BaseModel):
    """Fields shared by every prompt-engine request/response payload."""

    # Stack session the prompt belongs to.
    session_id: int
    # Which kind of prompt template is being addressed.
    type: PromptTypeEnum


class PromptEngineSetRequestModel(PromptEngineBaseModel):
    """Request body for storing a prompt template."""

    # Template text containing {placeholder} input variables.
    template: str


class PromptEngineSetResponseModel(PromptEngineBaseModel):
    """Response body returned after a template has been stored."""

    # The template text as persisted.
    template: str


class PromptEngineGetRequestModel(PromptEngineBaseModel):
    """Request payload for fetching a session's prompt template."""

    # `str = None` is an implicit-Optional annotation (rejected by mypy and
    # pydantic v2); declare the optionality explicitly. Behavior is unchanged.
    query: Optional[str] = None
    # presumably triggers prompt validation against `query` downstream — TODO confirm
    should_validate: bool = False


class PromptEngineGetResponseModel(PromptEngineBaseModel):
    """Response body carrying the resolved prompt template."""

    template: str
23 changes: 23 additions & 0 deletions genai_stack/genai_server/routers/prompt_engine_routes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from fastapi import APIRouter

from genai_stack.constant import API, PROMPT_ENGINE
from genai_stack.genai_server.settings.settings import settings
from genai_stack.genai_server.models.prompt_engine_models import (
PromptEngineGetRequestModel, PromptEngineGetResponseModel,
PromptEngineSetRequestModel, PromptEngineSetResponseModel
)
from genai_stack.genai_server.services.prompt_engine_service import PromptEngineService

# Module-level service instance shared by all requests; settings.STORE
# supplies the backing database store.
service = PromptEngineService(store=settings.STORE)

# All prompt-engine endpoints are mounted under the API + PROMPT_ENGINE prefix.
router = APIRouter(prefix=API + PROMPT_ENGINE, tags=["prompt-engine"])


@router.get("/prompt")
def get_prompt(data: PromptEngineGetRequestModel) -> PromptEngineGetResponseModel:
    """Fetch the prompt template for a session and prompt type."""
    # NOTE(review): a Pydantic model parameter on a GET route makes FastAPI
    # expect a JSON request *body*; many HTTP clients and proxies drop GET
    # bodies. Consider query parameters (via Depends) or switching to POST.
    return service.get_prompt(data=data)


@router.post("/prompt")
def set_prompt(data: PromptEngineSetRequestModel) -> PromptEngineSetResponseModel:
    """Create or update the stored prompt template for a session/type pair."""
    return service.set_prompt(data=data)
2 changes: 2 additions & 0 deletions genai_stack/genai_server/schemas/components/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .etl import ETLJob, ETLJobStatus
from .prompt_engine import PromptSchema

20 changes: 20 additions & 0 deletions genai_stack/genai_server/schemas/components/prompt_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from sqlalchemy import Column, Integer, JSON, ForeignKey, Enum, String

from genai_stack.genai_server.schemas.base_schemas import TimeStampedSchema
from genai_stack.prompt_engine.utils import PromptTypeEnum


class PromptSchema(TimeStampedSchema):
    """
    Schema for the Prompt model: one stored prompt template per
    (stack_session, type) pair.
    """

    __tablename__ = "prompts"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # Owning session; the row is removed when the session is deleted.
    stack_session = Column(
        Integer, ForeignKey("stack_sessions.id", ondelete="CASCADE"), nullable=False
    )
    # Pass the enum *member* as the default, not `.value`: SQLAlchemy's Enum
    # type maps members of the Python enum, and the original only worked
    # because PromptTypeEnum names happen to equal their values.
    type = Column(Enum(PromptTypeEnum), default=PromptTypeEnum.SIMPLE_CHAT_PROMPT)
    template = Column(String)
    meta_data = Column(JSON, nullable=True)
2 changes: 2 additions & 0 deletions genai_stack/genai_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
retriever_routes,
vectordb_routes,
etl_routes,
prompt_engine_routes,
model_routes,
)

Expand All @@ -27,6 +28,7 @@ def get_genai_server_app():
app.include_router(retriever_routes.router)
app.include_router(vectordb_routes.router)
app.include_router(etl_routes.router)
app.include_router(prompt_engine_routes.router)
app.include_router(model_routes.router)

return app
2 changes: 1 addition & 1 deletion genai_stack/genai_server/services/etl_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def submit_job(self, data: Any, stack_session_id: Optional[int] = None) -> ETLJo

data = ETLUtil(data).save_request(etl_job.id)

stack = get_current_stack(config=stack_config, session=stack_session)
stack = get_current_stack(config=stack_config, engine=session, session=stack_session)
get_etl_platform(stack=stack).handle_job(**data)

etl_job.data = data
Expand Down
69 changes: 69 additions & 0 deletions genai_stack/genai_server/services/prompt_engine_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from fastapi import HTTPException
from langchain.prompts import PromptTemplate
from sqlalchemy.orm import Session

from genai_stack.genai_platform.services import BaseService
from genai_stack.genai_server.models.prompt_engine_models import (
PromptEngineGetRequestModel, PromptEngineGetResponseModel,
PromptEngineSetRequestModel, PromptEngineSetResponseModel
)
from genai_stack.genai_server.schemas import StackSessionSchema
from genai_stack.genai_server.schemas.components.prompt_engine import PromptSchema
from genai_stack.genai_server.settings.config import stack_config
from genai_stack.genai_server.utils import get_current_stack
from genai_stack.prompt_engine.utils import PromptTypeEnum


class PromptEngineService(BaseService):
    """Service layer for per-session prompt templates.

    ``get_prompt`` resolves the effective template for a session (including
    any overrides previously stored via ``set_prompt``); ``set_prompt``
    validates a user-supplied template and upserts it into the ``prompts``
    table.
    """

    def get_prompt(self, data: PromptEngineGetRequestModel) -> PromptEngineGetResponseModel:
        """Return the prompt template of the requested type for a session.

        Raises:
            HTTPException: 404 if the stack session does not exist.
        """
        with Session(self.engine) as session:
            stack_session = session.get(StackSessionSchema, data.session_id)
            if stack_session is None:
                raise HTTPException(status_code=404, detail=f"Session {data.session_id} not found")
            # Rebuild the stack with the DB session so stored prompt
            # overrides are applied to the prompt engine's config.
            stack = get_current_stack(config=stack_config, engine=session, session=stack_session)
            prompt = stack.prompt_engine.get_prompt_template(promptType=data.type, query=data.query)
            return PromptEngineGetResponseModel(
                template=prompt.template,
                session_id=data.session_id,
                type=data.type.value,
            )

    @staticmethod
    def _required_variables(prompt_type) -> list:
        """Input variables a template of the given prompt type must declare."""
        variables = ["context", "history", "query"]
        if prompt_type.value == PromptTypeEnum.SIMPLE_CHAT_PROMPT.value:
            variables.remove("context")  # simple chat has no retrieval context
        elif prompt_type.value == PromptTypeEnum.CONTEXTUAL_QA_PROMPT.value:
            variables.remove("history")  # QA prompts are single-turn
        return variables

    @staticmethod
    def _validate_template(template: str, input_variables: list) -> None:
        """Validate that *template* declares exactly the expected placeholders.

        Placeholders are extracted with ``string.Formatter`` rather than the
        previous naive ``str.split("{")`` scan, which falsely flagged literal
        text containing ``}`` before the first ``{``, mishandled ``{{``
        escapes, and accepted a required variable appearing only as literal
        text (substring test) instead of as a real ``{placeholder}``.

        Raises:
            HTTPException: 400 on a malformed template, a missing required
                variable, or an unknown variable.
        """
        from string import Formatter

        try:
            fields = {
                field_name
                for _, field_name, _, _ in Formatter().parse(template)
                if field_name is not None
            }
        except ValueError as err:
            # e.g. an unmatched single brace
            raise HTTPException(status_code=400, detail=f"Malformed template: {err}") from err
        for variable in input_variables:
            if variable not in fields:
                raise HTTPException(status_code=400, detail=f"Input variable {variable} not found in template")
        for field_name in fields:
            if field_name not in input_variables:
                raise HTTPException(status_code=400, detail=f"Unknown input variable {field_name}")

    def set_prompt(self, data: PromptEngineSetRequestModel) -> PromptEngineSetResponseModel:
        """Validate and upsert a prompt template for a session/type pair.

        Raises:
            HTTPException: 404 if the session is missing, 400 if the template
                fails validation.
        """
        with Session(self.engine) as session:
            stack_session = session.get(StackSessionSchema, data.session_id)
            if stack_session is None:
                raise HTTPException(status_code=404, detail=f"Session {data.session_id} not found")
            self._validate_template(data.template, self._required_variables(data.type))
            prompt_session = (
                session.query(PromptSchema)
                .filter_by(stack_session=data.session_id, type=data.type.value)
                .first()
            )
            if prompt_session is None:
                prompt_session = PromptSchema(
                    stack_session=data.session_id,
                    type=data.type.value,
                    template=data.template,
                    meta_data={},
                )
                session.add(prompt_session)
            else:
                prompt_session.template = data.template
            session.commit()
            return PromptEngineSetResponseModel(
                template=prompt_session.template,
                session_id=data.session_id,
                type=data.type.value,
            )
2 changes: 1 addition & 1 deletion genai_stack/genai_server/services/retriever_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def retrieve(self, data: RetrieverRequestModel) -> RetrieverResponseModel:
stack_session = session.get(StackSessionSchema, data.session_id)
if stack_session is None:
raise HTTPException(status_code=404, detail=f"Session {data.session_id} not found")
stack = get_current_stack(config=stack_config, session=stack_session)
stack = get_current_stack(config=stack_config, engine=session, session=stack_session)
response = stack.retriever.retrieve(data.query)
return RetrieverResponseModel(
output=response['output'],
Expand Down
2 changes: 1 addition & 1 deletion genai_stack/genai_server/services/session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def create_session(self) -> StackSessionResponseModel:
created_at : datetime
modified_at : None
"""
stack = get_current_stack(stack_config, default_session=False)
stack = get_current_stack(config=stack_config, default_session=False)

with Session(self.engine) as session:
stack_session = StackSessionSchema(stack_id=1, meta_data={})
Expand Down
4 changes: 2 additions & 2 deletions genai_stack/genai_server/services/vectordb_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def add_documents(self, data: RetrieverAddDocumentsRequestModel) -> RetrieverAdd
stack_session = session.get(StackSessionSchema, data.session_id)
if stack_session is None:
raise HTTPException(status_code=404, detail=f"Session {data.session_id} not found")
stack = get_current_stack(config=stack_config, session=stack_session)
stack = get_current_stack(config=stack_config, engine=session, session=stack_session)
stack.vectordb.add_documents(data.documents)
return RetrieverAddDocumentsResponseModel(
documents=[
Expand All @@ -34,7 +34,7 @@ def search(self, data: RetrieverSearchRequestModel) -> RetrieverSearchResponseMo

with Session(self.engine) as session:
stack_session = session.get(StackSessionSchema, data.session_id)
stack = get_current_stack(config=stack_config, session=stack_session)
stack = get_current_stack(config=stack_config, engine=session, session=stack_session)
if stack_session is None:
raise HTTPException(status_code=404, detail=f"Session {data.session_id} not found")
documents = stack.vectordb.search(data.query)
Expand Down
59 changes: 57 additions & 2 deletions genai_stack/genai_server/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import string
import random

from langchain.prompts import PromptTemplate

from genai_stack.genai_server.schemas import PromptSchema
from genai_stack.prompt_engine.utils import PromptTypeEnum
from genai_stack.utils import import_class
from genai_stack.enums import StackComponentType
from genai_stack.genai_server.models.session_models import StackSessionResponseModel
Expand Down Expand Up @@ -68,7 +72,56 @@ def create_indexes(stack, stack_id: int, session_id: int) -> dict:
return meta_data


def get_current_stack(config: dict, session=None, default_session: bool = True):
def get_prompt_from_db(session, session_id, stack_config):
    """Overlay DB-stored prompt templates for *session_id* onto *stack_config*.

    The original implementation mutated *stack_config* in place; since the
    caller passes the module-level config dict, one session's prompt
    overrides leaked into every subsequent request. This version never
    modifies its input.

    Args:
        session: active SQLAlchemy session used for the query.
        session_id: stack session whose stored prompts should be loaded.
        stack_config: base stack configuration dict; not modified.

    Returns:
        *stack_config* unchanged when the session has no stored prompts,
        otherwise a copy whose ``components.prompt_engine.config`` is
        populated with ``PromptTemplate`` objects.
    """
    prompt_sessions = (
        session.query(PromptSchema)
        .filter_by(stack_session=session_id)
        .all()
    )
    if not prompt_sessions:
        # No overrides stored: hand back the original config untouched,
        # matching the old behavior of not creating a prompt_engine entry.
        return stack_config

    # Maps a stored prompt type to the engine config field it overrides and
    # the input variables that template kind expects.
    prompt_type_map = {
        PromptTypeEnum.SIMPLE_CHAT_PROMPT.value: {
            "field": "simple_chat_prompt_template",
            "input_variables": ["history", "query"],
        },
        PromptTypeEnum.CONTEXTUAL_CHAT_PROMPT.value: {
            "field": "contextual_chat_prompt_template",
            "input_variables": ["context", "history", "query"],
        },
        PromptTypeEnum.CONTEXTUAL_QA_PROMPT.value: {
            "field": "contextual_qa_prompt_template",
            "input_variables": ["context", "query"],
        },
    }

    # Copy only the levels we modify; other component configs stay shared.
    config = {**stack_config, "components": {**stack_config["components"]}}
    engine_component = dict(config["components"].get("prompt_engine") or {"name": "PromptEngine"})
    engine_component["config"] = dict(engine_component.get("config") or {})
    config["components"]["prompt_engine"] = engine_component

    for prompt_session in prompt_sessions:
        spec = prompt_type_map[prompt_session.type.value]
        engine_component["config"][spec["field"]] = PromptTemplate(
            template=prompt_session.template,
            input_variables=spec["input_variables"],
        )
    return config


def get_current_stack(
config: dict,
engine=None,
session=None,
default_session: bool = True,
overide_config: dict = None
):
if engine is not None:
config = get_prompt_from_db(
session=engine,
session_id=session.id,
stack_config=config
)
components = {}
if session is None and default_session:
from genai_stack.genai_server.settings.settings import settings
Expand All @@ -86,7 +139,9 @@ def get_current_stack(config: dict, session=None, default_session: bool = True):
or component_name == StackComponentType.MEMORY.value
):
configurations["index_name"] = session.meta_data[component_name]["index_name"]

if overide_config:
if component_name in overide_config:
configurations.update(overide_config[component_name])
components[component_name] = cls.from_kwargs(**configurations)
# To avoid circular import error
from genai_stack.stack.stack import Stack
Expand Down
1 change: 1 addition & 0 deletions genai_stack/prompt_engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ class ValidationResponseDict(TypedDict):
reason: str
response: str


class PromptTypeEnum(enum.Enum):
SIMPLE_CHAT_PROMPT = "SIMPLE_CHAT_PROMPT"
CONTEXTUAL_CHAT_PROMPT = "CONTEXTUAL_CHAT_PROMPT"
Expand Down
Loading