diff --git a/docs/design/02 - Topics.md b/docs/design/02 - Topics.md index 7d7149c37c9..bf3ed8d9dca 100644 --- a/docs/design/02 - Topics.md +++ b/docs/design/02 - Topics.md @@ -51,3 +51,16 @@ Agents are able to handle certain types of messages. This is an internal detail > [!NOTE] > This might be revisited based on scaling and performance considerations. + +## Well known topic types + +Agents should subscribe via a prefix subscription to the `{AgentType}:` topic as a direct message channel for the agent type. + +For this subscription source should map directly to agent key. + +This subscription will therefore receive all events for the following well known topics: + +- `{AgentType}:` - General purpose direct messages. These should be routed to the approriate message handler. +- `{AgentType}:rpc_request` - RPC request messages. These should be routed to the approriate RPC handler. +- `{AgentType}:rpc_response={RequestId}` - RPC response messages. These should be routed back to the response future of the caller. +- `{AgentType}:error={RequestId}` - Error message that corresponds to the given request. diff --git a/python/packages/autogen-agentchat/pyproject.toml b/python/packages/autogen-agentchat/pyproject.toml index c2336a6eeba..705a125528a 100644 --- a/python/packages/autogen-agentchat/pyproject.toml +++ b/python/packages/autogen-agentchat/pyproject.toml @@ -29,6 +29,7 @@ include = ["src/**", "tests/*.py"] [tool.pyright] extends = "../../pyproject.toml" include = ["src", "tests"] +reportDeprecated = true [tool.pytest.ini_options] minversion = "6.0" diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index fbca2644920..b038d857305 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ -14,6 +14,7 @@ MessageContext, ) from autogen_core.components import ClosureAgent, TypeSubscription +from autogen_core.components._closure_agent import ClosureContext from ... import EVENT_LOGGER_NAME from ...base import ChatAgent, TaskResult, Team, TerminationCondition @@ -139,8 +140,7 @@ async def _init(self, runtime: AgentRuntime) -> None: ) async def collect_output_messages( - _runtime: AgentRuntime, - id: AgentId, + _runtime: ClosureContext, message: GroupChatStart | GroupChatMessage | GroupChatTermination, ctx: MessageContext, ) -> None: @@ -150,7 +150,7 @@ async def collect_output_messages( return await self._output_message_queue.put(message.message) - await ClosureAgent.register( + await ClosureAgent.register_closure( runtime, type=self._collector_agent_type, closure=collect_output_messages, diff --git a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/extracting-results-with-an-agent.ipynb b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/extracting-results-with-an-agent.ipynb index c386699f0d7..e86b807d783 100644 --- a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/extracting-results-with-an-agent.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/extracting-results-with-an-agent.ipynb @@ -18,7 +18,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -26,8 +26,8 @@ "from dataclasses import dataclass\n", "\n", "from autogen_core.application import SingleThreadedAgentRuntime\n", - "from autogen_core.base import AgentId, AgentRuntime, MessageContext\n", - "from autogen_core.components import ClosureAgent, DefaultSubscription, DefaultTopicId" + "from autogen_core.base import MessageContext\n", + "from autogen_core.components import ClosureAgent, ClosureContext, DefaultSubscription, DefaultTopicId" ] }, { @@ -77,11 +77,11 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "async def output_result(_runtime: AgentRuntime, id: AgentId, message: FinalResult, ctx: MessageContext) -> None:\n", + "async def output_result(_agent: ClosureContext, message: FinalResult, ctx: MessageContext) -> None:\n", " await queue.put(message)" ] }, @@ -94,7 +94,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -110,7 +110,9 @@ ], "source": [ "runtime = SingleThreadedAgentRuntime()\n", - "await ClosureAgent.register(runtime, \"output_result\", output_result, subscriptions=lambda: [DefaultSubscription()])" + "await ClosureAgent.register_closure(\n", + " runtime, \"output_result\", output_result, subscriptions=lambda: [DefaultSubscription()]\n", + ")" ] }, { diff --git a/python/packages/autogen-core/pyproject.toml b/python/packages/autogen-core/pyproject.toml index 8727f5ee0c2..8f70a64ebeb 100644 --- a/python/packages/autogen-core/pyproject.toml +++ b/python/packages/autogen-core/pyproject.toml @@ -92,7 +92,7 @@ include = ["src/**", "samples/*.py", "docs/**/*.ipynb", "tests/**"] extends = "../../pyproject.toml" include = ["src", "tests", "samples"] exclude = ["src/autogen_core/application/protos", "tests/protos"] -reportDeprecated = false +reportDeprecated = true [tool.pytest.ini_options] minversion = "6.0" diff --git a/python/packages/autogen-core/samples/semantic_router/run_semantic_router.py b/python/packages/autogen-core/samples/semantic_router/run_semantic_router.py index 48f57eece71..35d057fb994 100644 --- a/python/packages/autogen-core/samples/semantic_router/run_semantic_router.py +++ b/python/packages/autogen-core/samples/semantic_router/run_semantic_router.py @@ -32,8 +32,8 @@ WorkerAgentMessage, ) from autogen_core.application import WorkerAgentRuntime -from autogen_core.base import AgentId, AgentRuntime, MessageContext -from autogen_core.components import ClosureAgent, DefaultSubscription, DefaultTopicId +from autogen_core.base import MessageContext +from autogen_core.components import ClosureAgent, ClosureContext, DefaultSubscription, DefaultTopicId class MockIntentClassifier(IntentClassifierBase): @@ -60,12 +60,12 @@ async def get_agent(self, intent: str) -> str: async def output_result( - _runtime: AgentRuntime, id: AgentId, message: WorkerAgentMessage | FinalResult, ctx: MessageContext + closure_ctx: ClosureContext, message: WorkerAgentMessage | FinalResult, ctx: MessageContext ) -> None: if isinstance(message, WorkerAgentMessage): print(f"{message.source} Agent: {message.content}") new_message = input("User response: ") - await _runtime.publish_message( + await closure_ctx.publish_message( UserProxyMessage(content=new_message, source="user"), topic_id=DefaultTopicId(type=message.source, source="user"), ) @@ -73,7 +73,7 @@ async def output_result( print(f"{message.source} Agent: {message.content}") print("Conversation ended") new_message = input("Enter a new conversation start: ") - await _runtime.publish_message( + await closure_ctx.publish_message( UserProxyMessage(content=new_message, source="user"), topic_id=DefaultTopicId(type="default", source="user") ) @@ -95,7 +95,7 @@ async def run_workers(): await agent_runtime.add_subscription(DefaultSubscription(topic_type="user_proxy", agent_type="user_proxy")) # A closure agent surfaces the final result to external systems (e.g. an API) so that the system can interact with the user - await ClosureAgent.register( + await ClosureAgent.register_closure( agent_runtime, "closure_agent", output_result, diff --git a/python/packages/autogen-core/samples/slow_human_in_loop.py b/python/packages/autogen-core/samples/slow_human_in_loop.py index cc8012f601c..348b156493b 100644 --- a/python/packages/autogen-core/samples/slow_human_in_loop.py +++ b/python/packages/autogen-core/samples/slow_human_in_loop.py @@ -33,7 +33,13 @@ from autogen_core.application import SingleThreadedAgentRuntime from autogen_core.base import AgentId, CancellationToken, MessageContext from autogen_core.base.intervention import DefaultInterventionHandler -from autogen_core.components import DefaultSubscription, DefaultTopicId, FunctionCall, RoutedAgent, message_handler +from autogen_core.components import ( + DefaultTopicId, + FunctionCall, + RoutedAgent, + message_handler, + type_subscription, +) from autogen_core.components.model_context import BufferedChatCompletionContext from autogen_core.components.models import ( AssistantMessage, @@ -81,6 +87,7 @@ def save_content(self, content: Mapping[str, Any]) -> None: state_persister = MockPersistence() +@type_subscription("scheduling_assistant_conversation") class SlowUserProxyAgent(RoutedAgent): def __init__( self, @@ -132,6 +139,7 @@ async def run(self, args: ScheduleMeetingInput, cancellation_token: Cancellation return ScheduleMeetingOutput() +@type_subscription("scheduling_assistant_conversation") class SchedulingAssistantAgent(RoutedAgent): def __init__( self, @@ -256,16 +264,13 @@ async def main(latest_user_input: Optional[str] = None) -> None | str: needs_user_input_handler = NeedsUserInputHandler() runtime = SingleThreadedAgentRuntime(intervention_handlers=[needs_user_input_handler, termination_handler]) - await runtime.register( - "User", - lambda: SlowUserProxyAgent("User", "I am a user"), - subscriptions=lambda: [DefaultSubscription("scheduling_assistant_conversation")], - ) + await SlowUserProxyAgent.register(runtime, "User", lambda: SlowUserProxyAgent("User", "I am a user")) initial_schedule_assistant_message = AssistantTextMessage( content="Hi! How can I help you? I can help schedule meetings", source="User" ) - await runtime.register( + await SchedulingAssistantAgent.register( + runtime, "SchedulingAssistant", lambda: SchedulingAssistantAgent( "SchedulingAssistant", @@ -273,7 +278,6 @@ async def main(latest_user_input: Optional[str] = None) -> None | str: model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"), initial_message=initial_schedule_assistant_message, ), - subscriptions=lambda: [DefaultSubscription("scheduling_assistant_conversation")], ) if latest_user_input is not None: diff --git a/python/packages/autogen-core/src/autogen_core/components/__init__.py b/python/packages/autogen-core/src/autogen_core/components/__init__.py index 75bb5eabcbe..37d1ad48a06 100644 --- a/python/packages/autogen-core/src/autogen_core/components/__init__.py +++ b/python/packages/autogen-core/src/autogen_core/components/__init__.py @@ -3,7 +3,7 @@ """ from ..base._type_prefix_subscription import TypePrefixSubscription -from ._closure_agent import ClosureAgent +from ._closure_agent import ClosureAgent, ClosureContext from ._default_subscription import DefaultSubscription, default_subscription, type_subscription from ._default_topic import DefaultTopicId from ._image import Image @@ -16,6 +16,7 @@ "RoutedAgent", "TypeRoutedAgent", "ClosureAgent", + "ClosureContext", "message_handler", "event", "rpc", diff --git a/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py b/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py index 1123c3ee40e..12e5faae6bf 100644 --- a/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py +++ b/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py @@ -1,32 +1,38 @@ +from __future__ import annotations + import inspect -from typing import Any, Awaitable, Callable, List, Mapping, Sequence, TypeVar, get_type_hints +from typing import Any, Awaitable, Callable, List, Mapping, Protocol, Sequence, TypeVar, get_type_hints + +from autogen_core.base._serialization import try_get_known_serializers_for_type +from autogen_core.base._subscription_context import SubscriptionInstantiationContext from ..base import ( - Agent, AgentId, AgentInstantiationContext, AgentMetadata, AgentRuntime, AgentType, + BaseAgent, + CancellationToken, MessageContext, Subscription, - SubscriptionInstantiationContext, - try_get_known_serializers_for_type, + TopicId, ) from ..base._type_helpers import get_types from ..base.exceptions import CantHandleException T = TypeVar("T") +ClosureAgentType = TypeVar("ClosureAgentType", bound="ClosureAgent") def get_handled_types_from_closure( - closure: Callable[[AgentRuntime, AgentId, T, MessageContext], Awaitable[Any]], + closure: Callable[[ClosureAgent, T, MessageContext], Awaitable[Any]], ) -> Sequence[type]: args = inspect.getfullargspec(closure)[0] - if len(args) != 4: + if len(args) != 3: raise AssertionError("Closure must have 4 arguments") - message_arg_name = args[2] + message_arg_name = args[1] type_hints = get_type_hints(closure) @@ -47,9 +53,30 @@ def get_handled_types_from_closure( return target_types -class ClosureAgent(Agent): +class ClosureContext(Protocol): + @property + def id(self) -> AgentId: ... + + async def send_message( + self, + message: Any, + recipient: AgentId, + *, + cancellation_token: CancellationToken | None = None, + ) -> Any: ... + + async def publish_message( + self, + message: Any, + topic_id: TopicId, + *, + cancellation_token: CancellationToken | None = None, + ) -> None: ... + + +class ClosureAgent(BaseAgent, ClosureContext): def __init__( - self, description: str, closure: Callable[[AgentRuntime, AgentId, T, MessageContext], Awaitable[Any]] + self, description: str, closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]] ) -> None: try: runtime = AgentInstantiationContext.current_runtime() @@ -65,6 +92,7 @@ def __init__( handled_types = get_handled_types_from_closure(closure) self._expected_types = handled_types self._closure = closure + super().__init__(description) @property def metadata(self) -> AgentMetadata: @@ -88,7 +116,7 @@ async def on_message(self, message: Any, ctx: MessageContext) -> Any: raise CantHandleException( f"Message type {type(message)} not in target types {self._expected_types} of {self.id}" ) - return await self._closure(self._runtime, self._id, message, ctx) + return await self._closure(self, message, ctx) async def save_state(self) -> Mapping[str, Any]: raise ValueError("save_state not implemented for ClosureAgent") @@ -97,16 +125,28 @@ async def load_state(self, state: Mapping[str, Any]) -> None: raise ValueError("load_state not implemented for ClosureAgent") @classmethod - async def register( + async def register_closure( cls, runtime: AgentRuntime, type: str, - closure: Callable[[AgentRuntime, AgentId, T, MessageContext], Awaitable[Any]], + closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]], *, + skip_class_subscriptions: bool = False, + skip_direct_message_subscription: bool = False, description: str = "", subscriptions: Callable[[], list[Subscription] | Awaitable[list[Subscription]]] | None = None, ) -> AgentType: - agent_type = AgentType(type) + def factory() -> ClosureAgent: + return ClosureAgent(description=description, closure=closure) + + agent_type = await cls.register( + runtime=runtime, + type=type, + factory=factory, # type: ignore + skip_class_subscriptions=skip_class_subscriptions, + skip_direct_message_subscription=skip_direct_message_subscription, + ) + subscriptions_list: List[Subscription] = [] if subscriptions is not None: with SubscriptionInstantiationContext.populate_context(agent_type): @@ -117,11 +157,6 @@ async def register( # just ignore mypy here subscriptions_list.extend(subscriptions_list_result) # type: ignore - agent_type = await runtime.register_factory( - type=agent_type, - agent_factory=lambda: ClosureAgent(description=description, closure=closure), - expected_class=cls, - ) for subscription in subscriptions_list: await runtime.add_subscription(subscription) diff --git a/python/packages/autogen-core/tests/test_cancellation.py b/python/packages/autogen-core/tests/test_cancellation.py index d971ef50fc3..67852636f6b 100644 --- a/python/packages/autogen-core/tests/test_cancellation.py +++ b/python/packages/autogen-core/tests/test_cancellation.py @@ -59,7 +59,7 @@ async def on_new_message(self, message: MessageType, ctx: MessageContext) -> Mes async def test_cancellation_with_token() -> None: runtime = SingleThreadedAgentRuntime() - await runtime.register("long_running", LongRunningAgent) + await LongRunningAgent.register(runtime, "long_running", LongRunningAgent) agent_id = AgentId("long_running", key="default") token = CancellationToken() response = asyncio.create_task(runtime.send_message(MessageType(), recipient=agent_id, cancellation_token=token)) @@ -85,8 +85,9 @@ async def test_cancellation_with_token() -> None: async def test_nested_cancellation_only_outer_called() -> None: runtime = SingleThreadedAgentRuntime() - await runtime.register("long_running", LongRunningAgent) - await runtime.register( + await LongRunningAgent.register(runtime, "long_running", LongRunningAgent) + await NestingLongRunningAgent.register( + runtime, "nested", lambda: NestingLongRunningAgent(AgentId("long_running", key=AgentInstantiationContext.current_agent_id().key)), ) @@ -119,8 +120,9 @@ async def test_nested_cancellation_only_outer_called() -> None: async def test_nested_cancellation_inner_called() -> None: runtime = SingleThreadedAgentRuntime() - await runtime.register("long_running", LongRunningAgent) - await runtime.register( + await LongRunningAgent.register(runtime, "long_running", LongRunningAgent) + await NestingLongRunningAgent.register( + runtime, "nested", lambda: NestingLongRunningAgent(AgentId("long_running", key=AgentInstantiationContext.current_agent_id().key)), ) diff --git a/python/packages/autogen-core/tests/test_closure_agent.py b/python/packages/autogen-core/tests/test_closure_agent.py index a8731bca3b8..328fe237434 100644 --- a/python/packages/autogen-core/tests/test_closure_agent.py +++ b/python/packages/autogen-core/tests/test_closure_agent.py @@ -3,9 +3,8 @@ import pytest from autogen_core.application import SingleThreadedAgentRuntime -from autogen_core.base import AgentId, AgentRuntime, MessageContext -from autogen_core.components import ClosureAgent, DefaultSubscription -from autogen_core.components._default_topic import DefaultTopicId +from autogen_core.base import MessageContext +from autogen_core.components import ClosureAgent, ClosureContext, DefaultSubscription, DefaultTopicId @dataclass @@ -19,11 +18,11 @@ async def test_register_receives_publish() -> None: queue = asyncio.Queue[tuple[str, str]]() - async def log_message(_runtime: AgentRuntime, id: AgentId, message: Message, ctx: MessageContext) -> None: - key = id.key + async def log_message(closure_ctx: ClosureContext, message: Message, ctx: MessageContext) -> None: + key = closure_ctx.id.key await queue.put((key, message.content)) - await ClosureAgent.register(runtime, "name", log_message, subscriptions=lambda: [DefaultSubscription()]) + await ClosureAgent.register_closure(runtime, "name", log_message, subscriptions=lambda: [DefaultSubscription()]) runtime.start() await runtime.publish_message(Message("first message"), topic_id=DefaultTopicId()) diff --git a/python/packages/autogen-core/tests/test_intervention.py b/python/packages/autogen-core/tests/test_intervention.py index 6b3d18c7e0a..105df32988b 100644 --- a/python/packages/autogen-core/tests/test_intervention.py +++ b/python/packages/autogen-core/tests/test_intervention.py @@ -18,7 +18,7 @@ async def on_send(self, message: MessageType, *, sender: AgentId | None, recipie handler = DebugInterventionHandler() runtime = SingleThreadedAgentRuntime(intervention_handlers=[handler]) - await runtime.register("name", LoopbackAgent) + await LoopbackAgent.register(runtime, "name", LoopbackAgent) loopback = AgentId("name", key="default") runtime.start() @@ -42,7 +42,7 @@ async def on_send( handler = DropSendInterventionHandler() runtime = SingleThreadedAgentRuntime(intervention_handlers=[handler]) - await runtime.register("name", LoopbackAgent) + await LoopbackAgent.register(runtime, "name", LoopbackAgent) loopback = AgentId("name", key="default") runtime.start() @@ -66,7 +66,7 @@ async def on_response( handler = DropResponseInterventionHandler() runtime = SingleThreadedAgentRuntime(intervention_handlers=[handler]) - await runtime.register("name", LoopbackAgent) + await LoopbackAgent.register(runtime, "name", LoopbackAgent) loopback = AgentId("name", key="default") runtime.start() @@ -90,7 +90,7 @@ async def on_send( handler = ExceptionInterventionHandler() runtime = SingleThreadedAgentRuntime(intervention_handlers=[handler]) - await runtime.register("name", LoopbackAgent) + await LoopbackAgent.register(runtime, "name", LoopbackAgent) loopback = AgentId("name", key="default") runtime.start() @@ -117,7 +117,7 @@ async def on_response( handler = ExceptionInterventionHandler() runtime = SingleThreadedAgentRuntime(intervention_handlers=[handler]) - await runtime.register("name", LoopbackAgent) + await LoopbackAgent.register(runtime, "name", LoopbackAgent) loopback = AgentId("name", key="default") runtime.start() with pytest.raises(InterventionException): diff --git a/python/packages/autogen-core/tests/test_routed_agent.py b/python/packages/autogen-core/tests/test_routed_agent.py index 8da5bf6073f..cab1b1d467f 100644 --- a/python/packages/autogen-core/tests/test_routed_agent.py +++ b/python/packages/autogen-core/tests/test_routed_agent.py @@ -37,7 +37,8 @@ async def on_broadcast_message(self, message: MessageType, ctx: MessageContext) async def test_routed_agent(caplog: pytest.LogCaptureFixture) -> None: runtime = SingleThreadedAgentRuntime() with caplog.at_level(logging.INFO): - await runtime.register("loopback", LoopbackAgent, lambda: [TypeSubscription("default", "loopback")]) + await LoopbackAgent.register(runtime, "loopback", LoopbackAgent) + await runtime.add_subscription(TypeSubscription("default", "loopback")) runtime.start() await runtime.publish_message(UnhandledMessageType(), topic_id=TopicId("default", "default")) await runtime.stop_when_idle() @@ -47,7 +48,8 @@ async def test_routed_agent(caplog: pytest.LogCaptureFixture) -> None: @pytest.mark.asyncio async def test_message_handler_router() -> None: runtime = SingleThreadedAgentRuntime() - await runtime.register("counter", CounterAgent, lambda: [TypeSubscription("default", "counter")]) + await CounterAgent.register(runtime, "counter", CounterAgent) + await runtime.add_subscription(TypeSubscription("default", "counter")) agent_id = AgentId(type="counter", key="default") # Send a broadcast message. @@ -94,7 +96,7 @@ async def handler_two(self, message: TestMessage, ctx: MessageContext) -> None: @pytest.mark.asyncio async def test_routed_agent_message_matching() -> None: runtime = SingleThreadedAgentRuntime() - await runtime.register("message_match", RoutedAgentMessageCustomMatch) + await RoutedAgentMessageCustomMatch.register(runtime, "message_match", RoutedAgentMessageCustomMatch) agent_id = AgentId(type="message_match", key="default") agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RoutedAgentMessageCustomMatch) @@ -134,7 +136,8 @@ async def on_event_two(self, message: TestMessage, ctx: MessageContext) -> None: @pytest.mark.asyncio async def test_event() -> None: runtime = SingleThreadedAgentRuntime() - await runtime.register("counter", EventAgent, lambda: [TypeSubscription("default", "counter")]) + await EventAgent.register(runtime, "counter", EventAgent) + await runtime.add_subscription(TypeSubscription("default", "counter")) agent_id = AgentId(type="counter", key="default") # Send a broadcast message. @@ -181,7 +184,8 @@ async def on_rpc_two(self, message: TestMessage, ctx: MessageContext) -> TestMes @pytest.mark.asyncio async def test_rpc() -> None: runtime = SingleThreadedAgentRuntime() - await runtime.register("counter", RPCAgent, lambda: [TypeSubscription("default", "counter")]) + await RPCAgent.register(runtime, "counter", RPCAgent) + await runtime.add_subscription(TypeSubscription("default", "counter")) agent_id = AgentId(type="counter", key="default") # Send an RPC message. diff --git a/python/packages/autogen-core/tests/test_runtime.py b/python/packages/autogen-core/tests/test_runtime.py index 0f56e42a7e0..b327be1461e 100644 --- a/python/packages/autogen-core/tests/test_runtime.py +++ b/python/packages/autogen-core/tests/test_runtime.py @@ -1,4 +1,3 @@ -import asyncio import logging import pytest @@ -7,16 +6,10 @@ AgentId, AgentInstantiationContext, AgentType, - Subscription, - SubscriptionInstantiationContext, TopicId, try_get_known_serializers_for_type, ) -from autogen_core.components import ( - DefaultTopicId, - TypeSubscription, - type_subscription, -) +from autogen_core.components import DefaultTopicId, TypeSubscription, type_subscription from opentelemetry.sdk.trace import TracerProvider from test_utils import ( CascadingAgent, @@ -146,82 +139,9 @@ async def test_register_receives_publish_cascade() -> None: async def test_register_factory_explicit_name() -> None: runtime = SingleThreadedAgentRuntime() - await runtime.register("name", LoopbackAgent, lambda: [TypeSubscription("default", "name")]) - runtime.start() - agent_id = AgentId("name", key="default") - topic_id = TopicId("default", "default") - await runtime.publish_message(MessageType(), topic_id=topic_id) - - await runtime.stop_when_idle() - - # Agent in default namespace should have received the message - long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent) - assert long_running_agent.num_calls == 1 - - # Agent in other namespace should not have received the message - other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance( - AgentId("name", key="other"), type=LoopbackAgent - ) - assert other_long_running_agent.num_calls == 0 - - -@pytest.mark.asyncio -async def test_register_factory_context_var_name() -> None: - runtime = SingleThreadedAgentRuntime() - - await runtime.register( - "name", LoopbackAgent, lambda: [TypeSubscription("default", SubscriptionInstantiationContext.agent_type().type)] - ) - runtime.start() - agent_id = AgentId("name", key="default") - topic_id = TopicId("default", "default") - await runtime.publish_message(MessageType(), topic_id=topic_id) - - await runtime.stop_when_idle() - - # Agent in default namespace should have received the message - long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent) - assert long_running_agent.num_calls == 1 - - # Agent in other namespace should not have received the message - other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance( - AgentId("name", key="other"), type=LoopbackAgent - ) - assert other_long_running_agent.num_calls == 0 - - -@pytest.mark.asyncio -async def test_register_factory_async() -> None: - runtime = SingleThreadedAgentRuntime() - - async def sub_factory() -> list[Subscription]: - await asyncio.sleep(0.1) - return [TypeSubscription("default", SubscriptionInstantiationContext.agent_type().type)] - - await runtime.register("name", LoopbackAgent, sub_factory) - runtime.start() - agent_id = AgentId("name", key="default") - topic_id = TopicId("default", "default") - await runtime.publish_message(MessageType(), topic_id=topic_id) - - await runtime.stop_when_idle() - - # Agent in default namespace should have received the message - long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent) - assert long_running_agent.num_calls == 1 - - # Agent in other namespace should not have received the message - other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance( - AgentId("name", key="other"), type=LoopbackAgent - ) - assert other_long_running_agent.num_calls == 0 - - -@pytest.mark.asyncio -async def test_register_factory_direct_list() -> None: - runtime = SingleThreadedAgentRuntime() + await LoopbackAgent.register(runtime, "name", LoopbackAgent) + await runtime.add_subscription(TypeSubscription("default", "name")) - await runtime.register("name", LoopbackAgent, [TypeSubscription("default", "name")]) runtime.start() agent_id = AgentId("name", key="default") topic_id = TopicId("default", "default") diff --git a/python/packages/autogen-core/tests/test_state.py b/python/packages/autogen-core/tests/test_state.py index 5d68447365f..7120a9baab4 100644 --- a/python/packages/autogen-core/tests/test_state.py +++ b/python/packages/autogen-core/tests/test_state.py @@ -24,7 +24,7 @@ async def load_state(self, state: Mapping[str, Any]) -> None: async def test_agent_can_save_state() -> None: runtime = SingleThreadedAgentRuntime() - await runtime.register("name1", StatefulAgent) + await StatefulAgent.register(runtime, "name1", StatefulAgent) agent1_id = AgentId("name1", key="default") agent1: StatefulAgent = await runtime.try_get_underlying_agent_instance(agent1_id, type=StatefulAgent) assert agent1.state == 0 @@ -44,7 +44,7 @@ async def test_agent_can_save_state() -> None: async def test_runtime_can_save_state() -> None: runtime = SingleThreadedAgentRuntime() - await runtime.register("name1", StatefulAgent) + await StatefulAgent.register(runtime, "name1", StatefulAgent) agent1_id = AgentId("name1", key="default") agent1: StatefulAgent = await runtime.try_get_underlying_agent_instance(agent1_id, type=StatefulAgent) assert agent1.state == 0 @@ -54,7 +54,7 @@ async def test_runtime_can_save_state() -> None: runtime_state = await runtime.save_state() runtime2 = SingleThreadedAgentRuntime() - await runtime2.register("name1", StatefulAgent) + await StatefulAgent.register(runtime2, "name1", StatefulAgent) agent2_id = AgentId("name1", key="default") agent2: StatefulAgent = await runtime2.try_get_underlying_agent_instance(agent2_id, type=StatefulAgent) diff --git a/python/packages/autogen-core/tests/test_subscription.py b/python/packages/autogen-core/tests/test_subscription.py index c339d549be5..91223acbbc8 100644 --- a/python/packages/autogen-core/tests/test_subscription.py +++ b/python/packages/autogen-core/tests/test_subscription.py @@ -27,7 +27,7 @@ def test_type_subscription_map() -> None: async def test_non_default_default_subscription() -> None: runtime = SingleThreadedAgentRuntime() - await runtime.register("MyAgent", LoopbackAgent) + await LoopbackAgent.register(runtime, "MyAgent", LoopbackAgent, skip_class_subscriptions=True) runtime.start() await runtime.publish_message(MessageType(), topic_id=DefaultTopicId()) await runtime.stop_when_idle() diff --git a/python/packages/autogen-core/tests/test_tool_agent.py b/python/packages/autogen-core/tests/test_tool_agent.py index 6184e9c78c8..bdbd3b96b72 100644 --- a/python/packages/autogen-core/tests/test_tool_agent.py +++ b/python/packages/autogen-core/tests/test_tool_agent.py @@ -43,7 +43,8 @@ async def _async_sleep_function(input: str) -> str: @pytest.mark.asyncio async def test_tool_agent() -> None: runtime = SingleThreadedAgentRuntime() - await runtime.register( + await ToolAgent.register( + runtime, "tool_agent", lambda: ToolAgent( description="Tool agent", @@ -143,7 +144,8 @@ def capabilities(self) -> ModelCapabilities: client = MockChatCompletionClient() tools: List[Tool] = [FunctionTool(_pass_function, name="pass", description="Pass function")] runtime = SingleThreadedAgentRuntime() - await runtime.register( + await ToolAgent.register( + runtime, "tool_agent", lambda: ToolAgent( description="Tool agent",