Skip to content

Commit

Permalink
use 4o-mini in more places
Browse files Browse the repository at this point in the history
  • Loading branch information
bassner committed Dec 20, 2024
1 parent fad216c commit ddd77de
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 25 deletions.
33 changes: 20 additions & 13 deletions app/pipeline/chat/course_chat_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ def get_mastery(progress, confidence):
class CourseChatPipeline(Pipeline):
"""Course chat pipeline that answers course related questions from students."""

llm: IrisLangchainChatModel
llm_big: IrisLangchainChatModel
llm_small: IrisLangchainChatModel
pipeline: Runnable
lecture_pipeline: LectureChatPipeline
suggestion_pipeline: InteractionSuggestionPipeline
Expand All @@ -93,14 +94,20 @@ def __init__(
self.event = event

# Set the langchain chat model
request_handler = CapabilityRequestHandler(
requirements=RequirementList(
gpt_version_equivalent=4.5,
)
)
completion_args = CompletionArguments(temperature=0.5, max_tokens=2000)
self.llm = IrisLangchainChatModel(
request_handler=request_handler, completion_args=completion_args
self.llm_big = IrisLangchainChatModel(
request_handler=CapabilityRequestHandler(
requirements=RequirementList(
gpt_version_equivalent=4.5,
)
), completion_args=completion_args
)
self.llm_small = IrisLangchainChatModel(
request_handler=CapabilityRequestHandler(
requirements=RequirementList(
gpt_version_equivalent=4.25,
)
), completion_args=completion_args
)
self.callback = callback

Expand All @@ -110,14 +117,14 @@ def __init__(
self.citation_pipeline = CitationPipeline()

# Create the pipeline
self.pipeline = self.llm | JsonOutputParser()
self.pipeline = self.llm_big | JsonOutputParser()
self.tokens = []

def __repr__(self):
return f"{self.__class__.__name__}(llm={self.llm})"
return f"{self.__class__.__name__}(llm_big={self.llm_big}, llm_small={self.llm_small})"

def __str__(self):
return f"{self.__class__.__name__}(llm={self.llm})"
return f"{self.__class__.__name__}(llm_big={self.llm_big}, llm_small={self.llm_small})"

@traceable(name="Course Chat Pipeline")
def __call__(self, dto: CourseChatPipelineExecutionDTO, **kwargs):
Expand Down Expand Up @@ -395,7 +402,7 @@ def lecture_content_retrieval() -> str:
# No idea why we need this extra contrary to exercise chat agent in this case, but solves the issue.
params.update({"tools": tools})
agent = create_tool_calling_agent(
llm=self.llm, tools=tools, prompt=self.prompt
llm=self.llm_big, tools=tools, prompt=self.prompt
)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=False)

Expand All @@ -404,7 +411,7 @@ def lecture_content_retrieval() -> str:
for step in agent_executor.iter(params):
print("STEP:", step)
self._append_tokens(
self.llm.tokens, PipelineEnum.IRIS_CHAT_COURSE_MESSAGE
self.llm_big.tokens, PipelineEnum.IRIS_CHAT_COURSE_MESSAGE
)
if step.get("output", None):
out = step["output"]
Expand Down
28 changes: 18 additions & 10 deletions app/pipeline/chat/exercise_chat_agent_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ def map_message_role(role: IrisMessageRole) -> str:
class ExerciseChatAgentPipeline(Pipeline):
"""Exercise chat agent pipeline that answers exercises related questions from students."""

llm: IrisLangchainChatModel
llm_big: IrisLangchainChatModel
llm_small: IrisLangchainChatModel
pipeline: Runnable
callback: ExerciseChatStatusCallback
suggestion_pipeline: InteractionSuggestionPipeline
Expand All @@ -112,14 +113,22 @@ def __init__(
super().__init__(implementation_id="exercise_chat_pipeline")
# Set the langchain chat model
completion_args = CompletionArguments(temperature=0.5, max_tokens=2000)
self.llm = IrisLangchainChatModel(
self.llm_big = IrisLangchainChatModel(
request_handler=CapabilityRequestHandler(
requirements=RequirementList(
gpt_version_equivalent=4.5,
),
),
completion_args=completion_args,
)
self.llm_small = IrisLangchainChatModel(
request_handler=CapabilityRequestHandler(
requirements=RequirementList(
gpt_version_equivalent=4.25,
),
),
completion_args=completion_args,
)
self.variant = variant
self.event = event
self.callback = callback
Expand All @@ -130,15 +139,15 @@ def __init__(
self.retriever = LectureRetrieval(self.db.client)
self.reranker_pipeline = RerankerPipeline()
self.code_feedback_pipeline = CodeFeedbackPipeline()
self.pipeline = self.llm | JsonOutputParser()
self.pipeline = self.llm_big | JsonOutputParser()
self.citation_pipeline = CitationPipeline()
self.tokens = []

def __repr__(self):
return f"{self.__class__.__name__}(llm={self.llm})"
return f"{self.__class__.__name__}(llm_big={self.llm_big}, llm_small={self.llm_small})"

def __str__(self):
return f"{self.__class__.__name__}(llm={self.llm})"
return f"{self.__class__.__name__}(llm_big={self.llm_big}, llm_small={self.llm_small})"

@traceable(name="Exercise Chat Agent Pipeline")
def __call__(self, dto: ExerciseChatPipelineExecutionDTO):
Expand Down Expand Up @@ -504,15 +513,14 @@ def lecture_content_retrieval() -> str:
tool_list.append(lecture_content_retrieval)
tools = generate_structured_tools_from_functions(tool_list)
agent = create_tool_calling_agent(
llm=self.llm, tools=tools, prompt=self.prompt
llm=self.llm_big, tools=tools, prompt=self.prompt
)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=False)
self.callback.in_progress()
out = None
for step in agent_executor.iter(params):
print("STEP:", step)
self._append_tokens(
self.llm.tokens, PipelineEnum.IRIS_CHAT_EXERCISE_AGENT_MESSAGE
self.llm_big.tokens, PipelineEnum.IRIS_CHAT_EXERCISE_AGENT_MESSAGE
)
if step.get("output", None):
out = step["output"]
Expand All @@ -525,13 +533,13 @@ def lecture_content_retrieval() -> str:
]
)

guide_response = (self.prompt | self.llm | StrOutputParser()).invoke(
guide_response = (self.prompt | self.llm_small | StrOutputParser()).invoke(
{
"response": out,
}
)
self._append_tokens(
self.llm.tokens, PipelineEnum.IRIS_CHAT_EXERCISE_AGENT_MESSAGE
self.llm_small.tokens, PipelineEnum.IRIS_CHAT_EXERCISE_AGENT_MESSAGE
)
if "!ok!" in guide_response:
print("Response is ok and not rewritten!!!")
Expand Down
2 changes: 1 addition & 1 deletion app/pipeline/chat/interaction_suggestion_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(self, variant: str = "default"):
# Set the langchain chat model
request_handler = CapabilityRequestHandler(
requirements=RequirementList(
gpt_version_equivalent=4.5,
gpt_version_equivalent=4.25,
json_mode=True,
)
)
Expand Down
2 changes: 1 addition & 1 deletion app/pipeline/prompts/citation_prompt.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@ Answer without citations:
Paragraphs with their Lecture Names, Unit Names, Links and Page Numbers:
{Paragraphs}

Answer with citations (ensure empty line between the message and the citations):
If the answer actually does not contain any information from the paragraphs, please do not include any citations and return '!NONE!'.
Original answer with citations (ensure two empty lines between the message and the citations):

0 comments on commit ddd77de

Please sign in to comment.