From c31514780f6c5c1a0a146234097cb0408c91d60d Mon Sep 17 00:00:00 2001 From: akshaj000 Date: Fri, 6 Oct 2023 17:37:38 +0530 Subject: [PATCH] Refactor get prompt service --- .../services/prompt_engine_service.py | 32 +------------------ 1 file changed, 1 insertion(+), 31 deletions(-) diff --git a/genai_stack/genai_server/services/prompt_engine_service.py b/genai_stack/genai_server/services/prompt_engine_service.py index 47364323..9231bd30 100644 --- a/genai_stack/genai_server/services/prompt_engine_service.py +++ b/genai_stack/genai_server/services/prompt_engine_service.py @@ -21,37 +21,7 @@ def get_prompt(self, data: PromptEngineGetRequestModel) -> PromptEngineGetRespon stack_session = session.get(StackSessionSchema, data.session_id) if stack_session is None: raise HTTPException(status_code=404, detail=f"Session {data.session_id} not found") - prompt_session = ( - session.query(PromptSchema) - .filter_by(stack_session=data.session_id, type=data.type.value) - .first() - ) - if prompt_session is not None: - template = prompt_session.template - prompt_type_map = { - PromptTypeEnum.SIMPLE_CHAT_PROMPT.value: "simple_chat_prompt_template", - PromptTypeEnum.CONTEXTUAL_CHAT_PROMPT.value: "contextual_chat_prompt_template", - PromptTypeEnum.CONTEXTUAL_QA_PROMPT.value: "contextual_qa_prompt_template", - } - input_variables = ["context", "history", "query"] - if data.type == PromptTypeEnum.SIMPLE_CHAT_PROMPT: - input_variables.remove("context") - elif data.type == PromptTypeEnum.CONTEXTUAL_QA_PROMPT: - input_variables.remove("history") - prompt = PromptTemplate(template=template, input_variables=input_variables) - stack = get_current_stack( - engine=session, - config=stack_config, - session=stack_session, - overide_config={ - "prompt_engine": { - "should_validate": data.should_validate, - prompt_type_map[data.type.value]: prompt - } - } - ) - else: - stack = get_current_stack(config=stack_config, engine=session, session=stack_session) + stack = get_current_stack(config=stack_config, engine=session, session=stack_session) prompt = stack.prompt_engine.get_prompt_template(promptType=data.type, query=data.query) return PromptEngineGetResponseModel( template=prompt.template,