From 3058bafcf25dddd1c0a5258816b130a68ec6ccb9 Mon Sep 17 00:00:00 2001
From: Eric Zhu
Date: Wed, 27 Nov 2024 10:40:34 -0800
Subject: [PATCH 1/2] Propagate team cancellation token in agentchat (#4400)

* Propagate team cancellation token in agentchat

* Docs

---------

Co-authored-by: Ryan Sweet
---
 .../teams/_group_chat/_base_group_chat.py     | 107 +++++++++++++++++-
 .../_group_chat/_base_group_chat_manager.py   |  28 ++++-
 .../_group_chat/_chat_agent_container.py      |   1 +
 .../_magentic_one/_magentic_one_group_chat.py |   1 -
 .../_magentic_one_orchestrator.py             |  45 ++++----
 .../tests/test_group_chat.py                  |  38 ++++++-
 .../tests/test_magentic_one_group_chat.py     |  80 +++++++++++++
 7 files changed, 271 insertions(+), 29 deletions(-)
 create mode 100644 python/packages/autogen-agentchat/tests/test_magentic_one_group_chat.py

diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py
index b038d857305..0990ddc6652 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py
@@ -170,6 +170,13 @@ async def run(
         :meth:`run_stream` to run the team and then returns the final result.
         Once the team is stopped, the termination condition is reset.

+        Args:
+            task (str | ChatMessage | None): The task to run the team with.
+            cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
+                Setting the cancellation token potentially puts the team in an inconsistent state,
+                and it may not reset the termination condition.
+                To gracefully stop the team, use :class:`~autogen_agentchat.task.ExternalTermination` instead.
+
         Example using the :class:`~autogen_agentchat.teams.RoundRobinGroupChat` team:

@@ -198,6 +205,47 @@ async def main() -> None:
                 print(result)
+
             asyncio.run(main())
+
+        Example using the :class:`~autogen_core.base.CancellationToken` to cancel the task:
+
+        .. code-block:: python
+
+            import asyncio
+            from autogen_agentchat.agents import AssistantAgent
+            from autogen_agentchat.task import MaxMessageTermination
+            from autogen_agentchat.teams import RoundRobinGroupChat
+            from autogen_core.base import CancellationToken
+            from autogen_ext.models import OpenAIChatCompletionClient
+
+
+            async def main() -> None:
+                model_client = OpenAIChatCompletionClient(model="gpt-4o")
+
+                agent1 = AssistantAgent("Assistant1", model_client=model_client)
+                agent2 = AssistantAgent("Assistant2", model_client=model_client)
+                termination = MaxMessageTermination(3)
+                team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)
+
+                cancellation_token = CancellationToken()
+
+                # Create a task to run the team in the background.
+                run_task = asyncio.create_task(
+                    team.run(
+                        task="Count from 1 to 10, respond one at a time.",
+                        cancellation_token=cancellation_token,
+                    )
+                )
+
+                # Wait for 1 second and then cancel the task.
+                await asyncio.sleep(1)
+                cancellation_token.cancel()
+
+                # This will raise a cancellation error.
+                await run_task
+
+            asyncio.run(main())
         """
         result: TaskResult | None = None
@@ -221,6 +269,13 @@ async def run_stream(
         of the type :class:`TaskResult` as the last item in the stream. Once the
         team is stopped, the termination condition is reset.

+        Args:
+            task (str | ChatMessage | None): The task to run the team with.
+            cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
+                Setting the cancellation token potentially puts the team in an inconsistent state,
+                and it may not reset the termination condition.
+                To gracefully stop the team, use :class:`~autogen_agentchat.task.ExternalTermination` instead.
+
         Example using the :class:`~autogen_agentchat.teams.RoundRobinGroupChat` team:

         .. code-block:: python

             import asyncio
             from autogen_agentchat.agents import AssistantAgent
             from autogen_agentchat.task import MaxMessageTermination
             from autogen_agentchat.teams import RoundRobinGroupChat
             from autogen_ext.models import OpenAIChatCompletionClient


             async def main() -> None:
                 model_client = OpenAIChatCompletionClient(model="gpt-4o")

                 agent1 = AssistantAgent("Assistant1", model_client=model_client)
                 agent2 = AssistantAgent("Assistant2", model_client=model_client)
                 termination = MaxMessageTermination(3)
                 team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)

                 stream = team.run_stream(task="Count from 1 to 10, respond one at a time.")
                 async for message in stream:
                     print(message)

             asyncio.run(main())

@@ -251,7 +306,52 @@ async def main() -> None:

             asyncio.run(main())
+
+
+        Example using the :class:`~autogen_core.base.CancellationToken` to cancel the task:
+
+        .. code-block:: python
+
+            import asyncio
+            from autogen_agentchat.agents import AssistantAgent
+            from autogen_agentchat.task import MaxMessageTermination, Console
+            from autogen_agentchat.teams import RoundRobinGroupChat
+            from autogen_core.base import CancellationToken
+            from autogen_ext.models import OpenAIChatCompletionClient
+
+
+            async def main() -> None:
+                model_client = OpenAIChatCompletionClient(model="gpt-4o")
+
+                agent1 = AssistantAgent("Assistant1", model_client=model_client)
+                agent2 = AssistantAgent("Assistant2", model_client=model_client)
+                termination = MaxMessageTermination(3)
+                team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)
+
+                cancellation_token = CancellationToken()
+
+                # Create a task to run the team in the background.
+                run_task = asyncio.create_task(
+                    Console(
+                        team.run_stream(
+                            task="Count from 1 to 10, respond one at a time.",
+                            cancellation_token=cancellation_token,
+                        )
+                    )
+                )
+
+                # Wait for 1 second and then cancel the task.
+                await asyncio.sleep(1)
+                cancellation_token.cancel()
+
+                # This will raise a cancellation error.
+                await run_task
+
+            asyncio.run(main())
+
         """
         # Create the first chat message if the task is a string or a chat message.
         first_chat_message: ChatMessage | None = None
         if task is None:
@@ -288,12 +388,17 @@ async def stop_runtime() -> None:
         await self._runtime.send_message(
             GroupChatStart(message=first_chat_message),
             recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id),
+            cancellation_token=cancellation_token,
         )
         # Collect the output messages in order.
         output_messages: List[AgentMessage] = []
         # Yield the messages until the queue is empty.
         while True:
-            message = await self._output_message_queue.get()
+            message_future = asyncio.ensure_future(self._output_message_queue.get())
+            if cancellation_token is not None:
+                cancellation_token.link_future(message_future)
+            # Wait for the next message; this will raise an exception if the task is cancelled.
+            message = await message_future
             if message is None:
                 break
             yield message
diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py
index d2a2b917690..201db26bfba 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py
@@ -1,3 +1,4 @@
+import asyncio
 from abc import ABC, abstractmethod
 from typing import Any, List

@@ -78,7 +79,9 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No
             await self.publish_message(message, topic_id=DefaultTopicId(type=self._output_topic_type))

         # Relay the start message to the participants.
- await self.publish_message(message, topic_id=DefaultTopicId(type=self._group_topic_type)) + await self.publish_message( + message, topic_id=DefaultTopicId(type=self._group_topic_type), cancellation_token=ctx.cancellation_token + ) # Append the user message to the message thread. self._message_thread.append(message.message) @@ -95,8 +98,16 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No await self._termination_condition.reset() return - speaker_topic_type = await self.select_speaker(self._message_thread) - await self.publish_message(GroupChatRequestPublish(), topic_id=DefaultTopicId(type=speaker_topic_type)) + # Select a speaker to start the conversation. + speaker_topic_type_future = asyncio.ensure_future(self.select_speaker(self._message_thread)) + # Link the select speaker future to the cancellation token. + ctx.cancellation_token.link_future(speaker_topic_type_future) + speaker_topic_type = await speaker_topic_type_future + await self.publish_message( + GroupChatRequestPublish(), + topic_id=DefaultTopicId(type=speaker_topic_type), + cancellation_token=ctx.cancellation_token, + ) @event async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None: @@ -140,8 +151,15 @@ async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: Mess return # Select a speaker to continue the conversation. - speaker_topic_type = await self.select_speaker(self._message_thread) - await self.publish_message(GroupChatRequestPublish(), topic_id=DefaultTopicId(type=speaker_topic_type)) + speaker_topic_type_future = asyncio.ensure_future(self.select_speaker(self._message_thread)) + # Link the select speaker future to the cancellation token. + ctx.cancellation_token.link_future(speaker_topic_type_future) + speaker_topic_type = await speaker_topic_type_future + await self.publish_message( + GroupChatRequestPublish(), + topic_id=DefaultTopicId(type=speaker_topic_type), + cancellation_token=ctx.cancellation_token, + ) @rpc async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> None: diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py index 31570803286..17c9830086b 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py @@ -71,6 +71,7 @@ async def handle_request(self, message: GroupChatRequestPublish, ctx: MessageCon await self.publish_message( GroupChatAgentResponse(agent_response=response), topic_id=DefaultTopicId(type=self._parent_topic_type), + cancellation_token=ctx.cancellation_token, ) async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None: diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_group_chat.py index cd67ced11e5..d199cbfd712 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_group_chat.py @@ -47,7 +47,6 @@ def _create_group_chat_manager_factory( return lambda: MagenticOneOrchestrator( 
group_topic_type, output_topic_type, - self._team_id, participant_topic_types, participant_descriptions, max_turns, diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py index f6963016234..e1f80a09805 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py @@ -1,7 +1,7 @@ import json from typing import Any, List -from autogen_core.base import MessageContext, AgentId +from autogen_core.base import AgentId, CancellationToken, MessageContext from autogen_core.components import DefaultTopicId, Image, event, rpc from autogen_core.components.models import ( AssistantMessage, @@ -42,7 +42,6 @@ def __init__( self, group_topic_type: str, output_topic_type: str, - team_id: str, participant_topic_types: List[str], participant_descriptions: List[str], max_turns: int | None, @@ -52,7 +51,6 @@ def __init__( super().__init__(description="Group chat manager") self._group_topic_type = group_topic_type self._output_topic_type = output_topic_type - self._team_id = team_id if len(participant_topic_types) != len(participant_descriptions): raise ValueError("The number of participant topic types, agent types, and descriptions must be the same.") if len(set(participant_topic_types)) != len(participant_topic_types): @@ -122,7 +120,7 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No planning_conversation.append( UserMessage(content=self._get_task_ledger_facts_prompt(self._task), source=self._name) ) - response = await self._model_client.create(planning_conversation) + response = await self._model_client.create(planning_conversation, cancellation_token=ctx.cancellation_token) assert isinstance(response.content, str) self._facts = response.content @@ -133,19 +131,19 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No planning_conversation.append( UserMessage(content=self._get_task_ledger_plan_prompt(self._team_description), source=self._name) ) - response = await self._model_client.create(planning_conversation) + response = await self._model_client.create(planning_conversation, cancellation_token=ctx.cancellation_token) assert isinstance(response.content, str) self._plan = response.content # Kick things off self._n_stalls = 0 - await self._reenter_inner_loop() + await self._reenter_inner_loop(ctx.cancellation_token) @event async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None: self._message_thread.append(message.agent_response.chat_message) - await self._orchestrate_step() + await self._orchestrate_step(ctx.cancellation_token) @rpc async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> None: @@ -164,12 +162,13 @@ async def reset(self) -> None: async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None: raise ValueError(f"Unhandled message in group chat manager: {type(message)}") - async def _reenter_inner_loop(self) -> None: + async def _reenter_inner_loop(self, cancellation_token: CancellationToken) -> None: # Reset the agents for participant_topic_type in self._participant_topic_types: await self._runtime.send_message( GroupChatReset(), - 
recipient=AgentId(type=participant_topic_type, key=self._team_id),
+            recipient=AgentId(type=participant_topic_type, key=self.id.key),
+            cancellation_token=cancellation_token,
         )
         # Reset the group chat manager
         await self.reset()
@@ -197,12 +196,12 @@ async def _reenter_inner_loop(self) -> None:
         )

         # Restart the inner loop
-        await self._orchestrate_step()
+        await self._orchestrate_step(cancellation_token=cancellation_token)

-    async def _orchestrate_step(self) -> None:
+    async def _orchestrate_step(self, cancellation_token: CancellationToken) -> None:
         # Check if we reached the maximum number of rounds
         if self._max_turns is not None and self._n_rounds > self._max_turns:
-            await self._prepare_final_answer("Max rounds reached.")
+            await self._prepare_final_answer("Max rounds reached.", cancellation_token)
             return
         self._n_rounds += 1
@@ -221,7 +220,7 @@ async def _orchestrate_step(self) -> None:

         # Check for task completion
         if progress_ledger["is_request_satisfied"]["answer"]:
-            await self._prepare_final_answer(progress_ledger["is_request_satisfied"]["reason"])
+            await self._prepare_final_answer(progress_ledger["is_request_satisfied"]["reason"], cancellation_token)
             return

         # Check for stalling
@@ -234,8 +233,8 @@ async def _orchestrate_step(self) -> None:

         # Too much stalling
         if self._n_stalls >= self._max_stalls:
-            await self._update_task_ledger()
-            await self._reenter_inner_loop()
+            await self._update_task_ledger(cancellation_token)
+            await self._reenter_inner_loop(cancellation_token)
             return

         # Broadcast the next step
@@ -252,20 +251,23 @@ async def _orchestrate_step(self) -> None:
         await self.publish_message(  # Broadcast
             GroupChatAgentResponse(agent_response=Response(chat_message=message)),
             topic_id=DefaultTopicId(type=self._group_topic_type),
+            cancellation_token=cancellation_token,
         )

         # Request that the step be completed
         next_speaker = progress_ledger["next_speaker"]["answer"]
-        await self.publish_message(GroupChatRequestPublish(), topic_id=DefaultTopicId(type=next_speaker))
+        await self.publish_message(
+            GroupChatRequestPublish(), topic_id=DefaultTopicId(type=next_speaker), cancellation_token=cancellation_token
+        )

-    async def _update_task_ledger(self) -> None:
+    async def _update_task_ledger(self, cancellation_token: CancellationToken) -> None:
         context = self._thread_to_context()

         # Update the facts
         update_facts_prompt = self._get_task_ledger_facts_update_prompt(self._task, self._facts)
         context.append(UserMessage(content=update_facts_prompt, source=self._name))

-        response = await self._model_client.create(context)
+        response = await self._model_client.create(context, cancellation_token=cancellation_token)

         assert isinstance(response.content, str)
         self._facts = response.content
@@ -275,19 +277,19 @@ async def _update_task_ledger(self) -> None:
         update_plan_prompt = self._get_task_ledger_plan_update_prompt(self._team_description)
         context.append(UserMessage(content=update_plan_prompt, source=self._name))

-        response = await self._model_client.create(context)
+        response = await self._model_client.create(context, cancellation_token=cancellation_token)

         assert isinstance(response.content, str)
         self._plan = response.content

-    async def _prepare_final_answer(self, reason: str) -> None:
+    async def _prepare_final_answer(self, reason: str, cancellation_token: CancellationToken) -> None:
         context = self._thread_to_context()

         # Get the final answer
         final_answer_prompt = self._get_final_answer_prompt(self._task)
         context.append(UserMessage(content=final_answer_prompt, source=self._name))

-        response = await 
self._model_client.create(context)
+        response = await self._model_client.create(context, cancellation_token=cancellation_token)

         assert isinstance(response.content, str)
         message = TextMessage(content=response.content, source=self._name)

@@ -303,6 +305,7 @@ async def _prepare_final_answer(self, reason: str) -> None:
         await self.publish_message(
             GroupChatAgentResponse(agent_response=Response(chat_message=message)),
             topic_id=DefaultTopicId(type=self._group_topic_type),
+            cancellation_token=cancellation_token,
         )

         # Signal termination
diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py
index 7df27abcbcd..00ac5ee90fe 100644
--- a/python/packages/autogen-agentchat/tests/test_group_chat.py
+++ b/python/packages/autogen-agentchat/tests/test_group_chat.py
@@ -69,18 +69,25 @@ class _EchoAgent(BaseChatAgent):
     def __init__(self, name: str, description: str) -> None:
         super().__init__(name, description)
         self._last_message: str | None = None
+        self._total_messages = 0

     @property
     def produced_message_types(self) -> List[type[ChatMessage]]:
         return [TextMessage]

+    @property
+    def total_messages(self) -> int:
+        return self._total_messages
+
     async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
         if len(messages) > 0:
             assert isinstance(messages[0], TextMessage)
             self._last_message = messages[0].content
+            self._total_messages += 1
             return Response(chat_message=TextMessage(content=messages[0].content, source=self.name))
         else:
             assert self._last_message is not None
+            self._total_messages += 1
             return Response(chat_message=TextMessage(content=self._last_message, source=self.name))

     async def on_reset(self, cancellation_token: CancellationToken) -> None:
@@ -358,7 +365,7 @@ async def test_round_robin_group_chat_with_resume_and_reset() -> None:

 @pytest.mark.asyncio
-async def test_round_group_chat_max_turn() -> None:
+async def test_round_robin_group_chat_max_turn() -> None:
     agent_1 = _EchoAgent("agent_1", description="echo agent 1")
     agent_2 = _EchoAgent("agent_2", description="echo agent 2")
     agent_3 = _EchoAgent("agent_3", description="echo agent 3")
@@ -391,6 +398,35 @@ async def test_round_robin_group_chat_max_turn() -> None:
     assert result.stop_reason is not None


+@pytest.mark.asyncio
+async def test_round_robin_group_chat_cancellation() -> None:
+    agent_1 = _EchoAgent("agent_1", description="echo agent 1")
+    agent_2 = _EchoAgent("agent_2", description="echo agent 2")
+    agent_3 = _EchoAgent("agent_3", description="echo agent 3")
+    agent_4 = _EchoAgent("agent_4", description="echo agent 4")
+    # Set max_turns to a large number to avoid stopping due to max_turns before cancellation.
+    team = RoundRobinGroupChat(participants=[agent_1, agent_2, agent_3, agent_4], max_turns=1000)
+    cancellation_token = CancellationToken()
+    run_task = asyncio.create_task(
+        team.run(
+            task="Write a program that prints 'Hello, world!'",
+            cancellation_token=cancellation_token,
+        )
+    )
+    await asyncio.sleep(0.1)
+    # Cancel the task.
+    cancellation_token.cancel()
+    with pytest.raises(asyncio.CancelledError):
+        await run_task
+
+    # Total messages produced so far.
+    total_messages = agent_1.total_messages + agent_2.total_messages + agent_3.total_messages + agent_4.total_messages
+
+    # The team can still run again and finish the task.
+    result = await team.run()
+    assert len(result.messages) + total_messages == 1000
+
+
 @pytest.mark.asyncio
 async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
     model = "gpt-4o-2024-05-13"
diff --git a/python/packages/autogen-agentchat/tests/test_magentic_one_group_chat.py b/python/packages/autogen-agentchat/tests/test_magentic_one_group_chat.py
new file mode 100644
index 00000000000..2a8931a93d6
--- /dev/null
+++ b/python/packages/autogen-agentchat/tests/test_magentic_one_group_chat.py
@@ -0,0 +1,80 @@
+import asyncio
+import json
+import logging
+from typing import List, Sequence
+
+import pytest
+from autogen_agentchat import EVENT_LOGGER_NAME
+from autogen_agentchat.agents import (
+    BaseChatAgent,
+)
+from autogen_agentchat.base import Response
+from autogen_agentchat.logging import FileLogHandler
+from autogen_agentchat.messages import (
+    ChatMessage,
+    TextMessage,
+)
+from autogen_agentchat.teams import (
+    MagenticOneGroupChat,
+)
+from autogen_core.base import CancellationToken
+from autogen_ext.models import ReplayChatCompletionClient
+
+logger = logging.getLogger(EVENT_LOGGER_NAME)
+logger.setLevel(logging.DEBUG)
+logger.addHandler(FileLogHandler("test_magentic_one_group_chat.log"))
+
+
+class _EchoAgent(BaseChatAgent):
+    def __init__(self, name: str, description: str) -> None:
+        super().__init__(name, description)
+        self._last_message: str | None = None
+        self._total_messages = 0
+
+    @property
+    def produced_message_types(self) -> List[type[ChatMessage]]:
+        return [TextMessage]
+
+    @property
+    def total_messages(self) -> int:
+        return self._total_messages
+
+    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
+        if len(messages) > 0:
+            assert isinstance(messages[0], TextMessage)
+            self._last_message = messages[0].content
+            self._total_messages += 1
+            return Response(chat_message=TextMessage(content=messages[0].content, source=self.name))
+        else:
+            assert self._last_message is not None
+            self._total_messages += 1
+            return Response(chat_message=TextMessage(content=self._last_message, source=self.name))
+
+    async def on_reset(self, cancellation_token: CancellationToken) -> None:
+        self._last_message = None
+
+
+@pytest.mark.asyncio
+async def test_magentic_one_group_chat_cancellation() -> None:
+    agent_1 = _EchoAgent("agent_1", description="echo agent 1")
+    agent_2 = _EchoAgent("agent_2", description="echo agent 2")
+    agent_3 = _EchoAgent("agent_3", description="echo agent 3")
+    agent_4 = _EchoAgent("agent_4", description="echo agent 4")
+
+    model_client = ReplayChatCompletionClient(
+        chat_completions=["test", "test", json.dumps({"is_request_satisfied": {"answer": True, "reason": "test"}})],
+    )
+
+    # The run is stopped by cancellation before the orchestrator can finish the task.
+    team = MagenticOneGroupChat(participants=[agent_1, agent_2, agent_3, agent_4], model_client=model_client)
+    cancellation_token = CancellationToken()
+    run_task = asyncio.create_task(
+        team.run(
+            task="Write a program that prints 'Hello, world!'",
+            cancellation_token=cancellation_token,
+        )
+    )
+    # Cancel the task.
+    cancellation_token.cancel()
+    with pytest.raises(asyncio.CancelledError):
+        await run_task

From 52790a8de74bedf41e7b5279c02ffdc1c30770ac Mon Sep 17 00:00:00 2001
From: Eric Zhu
Date: Wed, 27 Nov 2024 10:45:51 -0800
Subject: [PATCH 2/2] o1 support for agent chat, and validate model capabilities (#4397)

---
 .../agents/_assistant_agent.py                | 58 +++++++++++++++++--
 .../tests/test_assistant_agent.py             | 24 ++++++++
 .../autogen_ext/models/_openai/_model_info.py | 14 +++++
 3 files changed, 92 insertions(+), 4 deletions(-)

diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
index 0870a6c2f3b..1edf86f0061 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
@@ -23,6 +23,7 @@
     AgentMessage,
     ChatMessage,
     HandoffMessage,
+    MultiModalMessage,
     TextMessage,
     ToolCallMessage,
     ToolCallResultMessage,
@@ -113,7 +114,10 @@ class AssistantAgent(BaseChatAgent):


             async def main() -> None:
-                model_client = OpenAIChatCompletionClient(model="gpt-4o")
+                model_client = OpenAIChatCompletionClient(
+                    model="gpt-4o",
+                    # api_key = "your_openai_api_key"
+                )
                 agent = AssistantAgent(name="assistant", model_client=model_client)

                 response = await agent.on_messages(
@@ -144,7 +148,10 @@ async def get_current_time() -> str:


             async def main() -> None:
-                model_client = OpenAIChatCompletionClient(model="gpt-4o")
+                model_client = OpenAIChatCompletionClient(
+                    model="gpt-4o",
+                    # api_key = "your_openai_api_key"
+                )
                 agent = AssistantAgent(name="assistant", model_client=model_client, tools=[get_current_time])

                 await Console(
@@ -156,6 +163,39 @@ async def main() -> None:

             asyncio.run(main())

+
+        The following example shows how to use the `o1-mini` model with the assistant agent.
+
+        .. code-block:: python
+
+            import asyncio
+            from autogen_core.base import CancellationToken
+            from autogen_ext.models import OpenAIChatCompletionClient
+            from autogen_agentchat.agents import AssistantAgent
+            from autogen_agentchat.messages import TextMessage
+
+
+            async def main() -> None:
+                model_client = OpenAIChatCompletionClient(
+                    model="o1-mini",
+                    # api_key = "your_openai_api_key"
+                )
+                # The system message is not supported by the o1 series model.
+                agent = AssistantAgent(name="assistant", model_client=model_client, system_message=None)
+
+                response = await agent.on_messages(
+                    [TextMessage(content="What is the capital of France?", source="user")], CancellationToken()
+                )
+                print(response)
+
+
+            asyncio.run(main())
+
+        .. note::
+
+            The `o1-preview` and `o1-mini` models do not support system messages or function calling,
+            so `system_message` should be set to `None`, and `tools` and `handoffs` should not be set.
+            See `o1 beta limitations <https://platform.openai.com/docs/guides/reasoning>`_ for more details.
     """

     def __init__(
@@ -166,13 +206,19 @@ def __init__(
         tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
         handoffs: List[Handoff | str] | None = None,
         description: str = "An agent that provides assistance with ability to use tools.",
-        system_message: str = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
+        system_message: str
+        | None = "You are a helpful AI assistant. Solve tasks using your tools. 
Reply with TERMINATE when the task has been completed.", ): super().__init__(name=name, description=description) self._model_client = model_client - self._system_messages = [SystemMessage(content=system_message)] + if system_message is None: + self._system_messages = [] + else: + self._system_messages = [SystemMessage(content=system_message)] self._tools: List[Tool] = [] if tools is not None: + if model_client.capabilities["function_calling"] is False: + raise ValueError("The model does not support function calling.") for tool in tools: if isinstance(tool, Tool): self._tools.append(tool) @@ -192,6 +238,8 @@ def __init__( self._handoff_tools: List[Tool] = [] self._handoffs: Dict[str, Handoff] = {} if handoffs is not None: + if model_client.capabilities["function_calling"] is False: + raise ValueError("The model does not support function calling, which is needed for handoffs.") for handoff in handoffs: if isinstance(handoff, str): handoff = Handoff(target=handoff) @@ -229,6 +277,8 @@ async def on_messages_stream( ) -> AsyncGenerator[AgentMessage | Response, None]: # Add messages to the model context. for msg in messages: + if isinstance(msg, MultiModalMessage) and self._model_client.capabilities["vision"] is False: + raise ValueError("The model does not support vision.") self._model_context.append(UserMessage(content=msg.content, source=msg.source)) # Inner messages. diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index 98ee8c3990d..086ea62ae42 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -233,3 +233,27 @@ async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None: img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC" result = await agent.run(task=MultiModalMessage(source="user", content=["Test", Image.from_base64(img_base64)])) assert len(result.messages) == 2 + + +@pytest.mark.asyncio +async def test_invalid_model_capabilities() -> None: + model = "random-model" + model_client = OpenAIChatCompletionClient( + model=model, api_key="", model_capabilities={"vision": False, "function_calling": False, "json_output": False} + ) + + with pytest.raises(ValueError): + agent = AssistantAgent( + name="assistant", + model_client=model_client, + tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")], + ) + + with pytest.raises(ValueError): + agent = AssistantAgent(name="assistant", model_client=model_client, handoffs=["agent2"]) + + with pytest.raises(ValueError): + agent = AssistantAgent(name="assistant", model_client=model_client) + # Generate a random base64 image. 
+        img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
+        await agent.run(task=MultiModalMessage(source="user", content=["Test", Image.from_base64(img_base64)]))
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/_openai/_model_info.py b/python/packages/autogen-ext/src/autogen_ext/models/_openai/_model_info.py
index aea2bfb5d1c..3a837915f42 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/_openai/_model_info.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/_openai/_model_info.py
@@ -5,6 +5,8 @@
 # Based on: https://platform.openai.com/docs/models/continuous-model-upgrades
 # This is a moving target, so correctness is checked by the model value returned by openai against expected values at runtime
 _MODEL_POINTERS = {
+    "o1-preview": "o1-preview-2024-09-12",
+    "o1-mini": "o1-mini-2024-09-12",
     "gpt-4o": "gpt-4o-2024-08-06",
     "gpt-4o-mini": "gpt-4o-mini-2024-07-18",
     "gpt-4-turbo": "gpt-4-turbo-2024-04-09",
@@ -16,6 +18,16 @@
 }

 _MODEL_CAPABILITIES: Dict[str, ModelCapabilities] = {
+    "o1-preview-2024-09-12": {
+        "vision": False,
+        "function_calling": False,
+        "json_output": False,
+    },
+    "o1-mini-2024-09-12": {
+        "vision": False,
+        "function_calling": False,
+        "json_output": False,
+    },
     "gpt-4o-2024-08-06": {
         "vision": True,
         "function_calling": True,
@@ -89,6 +101,8 @@
 }

 _MODEL_TOKEN_LIMITS: Dict[str, int] = {
+    "o1-preview-2024-09-12": 128000,
+    "o1-mini-2024-09-12": 128000,
     "gpt-4o-2024-08-06": 128000,
     "gpt-4o-2024-05-13": 128000,
     "gpt-4o-mini-2024-07-18": 128000,