From 0c9666f5b7a940011574329ddffea33251fd6534 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Wed, 27 Nov 2024 16:32:04 -0500 Subject: [PATCH 1/7] WIP - rpc over events --- docs/design/02 - Topics.md | 2 +- .../framework/agent-and-agent-runtime.ipynb | 2 +- .../_single_threaded_agent_runtime.py | 101 +++++++++--------- .../application/_worker_runtime.py | 1 - .../src/autogen_core/base/_agent.py | 5 +- .../src/autogen_core/base/_agent_id.py | 5 +- .../src/autogen_core/base/_agent_runtime.py | 1 - .../src/autogen_core/base/_base_agent.py | 51 +++++++-- .../src/autogen_core/base/_message_context.py | 3 +- .../src/autogen_core/base/_rpc.py | 31 ++++++ .../autogen_core/components/_closure_agent.py | 16 +-- .../autogen_core/components/_routed_agent.py | 59 ++++++++-- .../components/send_message_mixin.py | 61 +++++++++++ python/packages/autogen-core/test.py | 35 ++++++ .../packages/autogen-core/tests/test_state.py | 2 +- .../packages/autogen-core/tests/test_types.py | 6 +- .../autogen-core/tests/test_utils/__init__.py | 2 +- 17 files changed, 292 insertions(+), 91 deletions(-) create mode 100644 python/packages/autogen-core/src/autogen_core/base/_rpc.py create mode 100644 python/packages/autogen-core/src/autogen_core/components/send_message_mixin.py create mode 100644 python/packages/autogen-core/test.py diff --git a/docs/design/02 - Topics.md b/docs/design/02 - Topics.md index bf3ed8d9dca..c64ac50a274 100644 --- a/docs/design/02 - Topics.md +++ b/docs/design/02 - Topics.md @@ -61,6 +61,6 @@ For this subscription source should map directly to agent key. This subscription will therefore receive all events for the following well known topics: - `{AgentType}:` - General purpose direct messages. These should be routed to the approriate message handler. -- `{AgentType}:rpc_request` - RPC request messages. These should be routed to the approriate RPC handler. +- `{AgentType}:rpc_request={RequesterAgentType}` - RPC request messages. These should be routed to the approriate RPC handler. - `{AgentType}:rpc_response={RequestId}` - RPC response messages. These should be routed back to the response future of the caller. - `{AgentType}:error={RequestId}` - Error message that corresponds to the given request. diff --git a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/agent-and-agent-runtime.ipynb b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/agent-and-agent-runtime.ipynb index fdd7aed5644..5884967418d 100644 --- a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/agent-and-agent-runtime.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/agent-and-agent-runtime.ipynb @@ -67,7 +67,7 @@ " def __init__(self) -> None:\n", " super().__init__(\"MyAgent\")\n", "\n", - " async def on_message(self, message: MyMessageType, ctx: MessageContext) -> None:\n", + " async def on_message_impl(self, message: MyMessageType, ctx: MessageContext) -> None:\n", " print(f\"Received message: {message.content}\") # type: ignore" ] }, diff --git a/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py index 3d81f15eb33..9b402c11c4a 100644 --- a/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py @@ -16,6 +16,7 @@ from typing_extensions import deprecated from autogen_core.base._serialization import MessageSerializer, SerializationRegistry +from autogen_core.components.send_message_mixin import PublishBasedRpcMixin from ..base import ( Agent, @@ -166,7 +167,7 @@ def _warn_if_none(value: Any, handler_name: str) -> None: ) -class SingleThreadedAgentRuntime(AgentRuntime): +class SingleThreadedAgentRuntime(PublishBasedRpcMixin, AgentRuntime): def __init__( self, *, @@ -202,54 +203,54 @@ def _known_agent_names(self) -> Set[str]: return set(self._agent_factories.keys()) # Returns the response of the message - async def send_message( - self, - message: Any, - recipient: AgentId, - *, - sender: AgentId | None = None, - cancellation_token: CancellationToken | None = None, - ) -> Any: - if cancellation_token is None: - cancellation_token = CancellationToken() - - # event_logger.info( - # MessageEvent( - # payload=message, - # sender=sender, - # receiver=recipient, - # kind=MessageKind.DIRECT, - # delivery_stage=DeliveryStage.SEND, - # ) - # ) - - with self._tracer_helper.trace_block( - "create", - recipient, - parent=None, - extraAttributes={"message_type": type(message).__name__}, - ): - future = asyncio.get_event_loop().create_future() - if recipient.type not in self._known_agent_names: - future.set_exception(Exception("Recipient not found")) - - content = message.__dict__ if hasattr(message, "__dict__") else message - logger.info(f"Sending message of type {type(message).__name__} to {recipient.type}: {content}") - - self._message_queue.append( - SendMessageEnvelope( - message=message, - recipient=recipient, - future=future, - cancellation_token=cancellation_token, - sender=sender, - metadata=get_telemetry_envelope_metadata(), - ) - ) - - cancellation_token.link_future(future) - - return await future + # async def send_message( + # self, + # message: Any, + # recipient: AgentId, + # *, + # sender: AgentId | None = None, + # cancellation_token: CancellationToken | None = None, + # ) -> Any: + # if cancellation_token is None: + # cancellation_token = CancellationToken() + + # # event_logger.info( + # # MessageEvent( + # # payload=message, + # # sender=sender, + # # receiver=recipient, + # # kind=MessageKind.DIRECT, + # # delivery_stage=DeliveryStage.SEND, + # # ) + # # ) + + # with self._tracer_helper.trace_block( + # "create", + # recipient, + # parent=None, + # extraAttributes={"message_type": type(message).__name__}, + # ): + # future = asyncio.get_event_loop().create_future() + # if recipient.type not in self._known_agent_names: + # future.set_exception(Exception("Recipient not found")) + + # content = message.__dict__ if hasattr(message, "__dict__") else message + # logger.info(f"Sending message of type {type(message).__name__} to {recipient.type}: {content}") + + # self._message_queue.append( + # SendMessageEnvelope( + # message=message, + # recipient=recipient, + # future=future, + # cancellation_token=cancellation_token, + # sender=sender, + # metadata=get_telemetry_envelope_metadata(), + # ) + # ) + + # cancellation_token.link_future(future) + + # return await future async def publish_message( self, @@ -332,7 +333,6 @@ async def _process_send(self, message_envelope: SendMessageEnvelope) -> None: message_context = MessageContext( sender=message_envelope.sender, topic_id=None, - is_rpc=True, cancellation_token=message_envelope.cancellation_token, # Will be fixed when send API removed message_id="NOT_DEFINED_TODO_FIX", @@ -392,7 +392,6 @@ async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> No message_context = MessageContext( sender=message_envelope.sender, topic_id=message_envelope.topic_id, - is_rpc=False, cancellation_token=message_envelope.cancellation_token, message_id=message_envelope.message_id, ) diff --git a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py index 24007fadfc7..24714f29155 100644 --- a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py @@ -497,7 +497,6 @@ async def _process_request(self, request: agent_worker_pb2.RpcRequest) -> None: message_context = MessageContext( sender=sender, topic_id=None, - is_rpc=True, cancellation_token=CancellationToken(), message_id=request.request_id, ) diff --git a/python/packages/autogen-core/src/autogen_core/base/_agent.py b/python/packages/autogen-core/src/autogen_core/base/_agent.py index edb5e59b1ce..0202522d08a 100644 --- a/python/packages/autogen-core/src/autogen_core/base/_agent.py +++ b/python/packages/autogen-core/src/autogen_core/base/_agent.py @@ -17,16 +17,13 @@ def id(self) -> AgentId: """ID of the agent.""" ... - async def on_message(self, message: Any, ctx: MessageContext) -> Any: + async def on_message(self, message: Any, ctx: MessageContext) -> None: """Message handler for the agent. This should only be called by the runtime, not by other agents. Args: message (Any): Received message. Type is one of the types in `subscriptions`. ctx (MessageContext): Context of the message. - Returns: - Any: Response to the message. Can be None. - Raises: asyncio.CancelledError: If the message was cancelled. CantHandleException: If the agent cannot handle the message. diff --git a/python/packages/autogen-core/src/autogen_core/base/_agent_id.py b/python/packages/autogen-core/src/autogen_core/base/_agent_id.py index 06f163ed9c3..b3a129ac417 100644 --- a/python/packages/autogen-core/src/autogen_core/base/_agent_id.py +++ b/python/packages/autogen-core/src/autogen_core/base/_agent_id.py @@ -8,8 +8,9 @@ def __init__(self, type: str | AgentType, key: str) -> None: if isinstance(type, AgentType): type = type.type - if type.isidentifier() is False: - raise ValueError(f"Invalid type: {type}") + # TODO: fixme + # if type.isidentifier() is False: + # raise ValueError(f"Invalid type: {type}") self._type = type self._key = key diff --git a/python/packages/autogen-core/src/autogen_core/base/_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/base/_agent_runtime.py index 27c37ad9f34..0cd65ef383b 100644 --- a/python/packages/autogen-core/src/autogen_core/base/_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/base/_agent_runtime.py @@ -26,7 +26,6 @@ async def send_message( message: Any, recipient: AgentId, *, - sender: AgentId | None = None, cancellation_token: CancellationToken | None = None, ) -> Any: """Send a message to an agent and get a response. diff --git a/python/packages/autogen-core/src/autogen_core/base/_base_agent.py b/python/packages/autogen-core/src/autogen_core/base/_base_agent.py index 70481705ca6..25ce6179b43 100644 --- a/python/packages/autogen-core/src/autogen_core/base/_base_agent.py +++ b/python/packages/autogen-core/src/autogen_core/base/_base_agent.py @@ -1,13 +1,17 @@ from __future__ import annotations import inspect +import uuid import warnings from abc import ABC, abstractmethod +from asyncio import Future from collections.abc import Sequence -from typing import Any, Awaitable, Callable, ClassVar, List, Mapping, Tuple, Type, TypeVar +from typing import Any, Awaitable, Callable, ClassVar, Dict, List, Mapping, Tuple, Type, TypeVar, final from typing_extensions import Self +from autogen_core.base._rpc import format_rpc_request_topic, is_rpc_response + from ._agent import Agent from ._agent_id import AgentId from ._agent_instantiation import AgentInstantiationContext @@ -53,7 +57,6 @@ def decorator(cls: Type[BaseAgentType]) -> Type[BaseAgentType]: return decorator - class BaseAgent(ABC, Agent): internal_unbound_subscriptions_list: ClassVar[List[UnboundSubscription]] = [] internal_extra_handles_types: ClassVar[List[Tuple[Type[Any], List[MessageSerializer[Any]]]]] = [] @@ -77,7 +80,17 @@ def metadata(self) -> AgentMetadata: assert self._id is not None return AgentMetadata(key=self._id.key, type=self._id.type, description=self._description) - def __init__(self, description: str) -> None: + def __init__(self, description: str, *, forward_unbound_rpc_responses_to_handler: bool = False) -> None: + """Base agent that all agents should inherit from. Puts in place assumed common functionality. + + Args: + description (str): Description of the agent. + forward_unbound_rpc_responses_to_handler (bool, optional): If an rpc request ID is not know to the agent, should the rpc request be forwarded to the handler. Defaults to False. + + Raises: + RuntimeError: If the agent is not instantiated within the context of an AgentRuntime. + ValueError: If there is an argument type error. + """ try: runtime = AgentInstantiationContext.current_runtime() id = AgentInstantiationContext.current_agent_id() @@ -91,6 +104,8 @@ def __init__(self, description: str) -> None: if not isinstance(description, str): raise ValueError("Agent description must be a string") self._description = description + self._pending_rpc_requests: Dict[str, Future[Any]] = {} + self._forward_unbound_rpc_responses_to_handler = forward_unbound_rpc_responses_to_handler @property def type(self) -> str: @@ -105,7 +120,21 @@ def runtime(self) -> AgentRuntime: return self._runtime @abstractmethod - async def on_message(self, message: Any, ctx: MessageContext) -> Any: ... + async def on_message_impl(self, message: Any, ctx: MessageContext) -> None: ... + + @final + async def on_message(self, message: Any, ctx: MessageContext) -> None: + # Intercept RPC responses + if ctx.topic_id is not None and (request_id := is_rpc_response(ctx.topic_id.type)) is not None: + if request_id in self._pending_rpc_requests: + self._pending_rpc_requests[request_id].set_result(message) + elif self._forward_unbound_rpc_responses_to_handler: + await self.on_message_impl(message, ctx) + else: + warnings.warn(f"Received RPC response for unknown request {request_id}. To forward unbound rpc responses to the handler, set forward_unbound_rpc_responses_to_handler=True", stacklevel=2) + return None + + return await self.on_message_impl(message, ctx) async def send_message( self, @@ -118,13 +147,23 @@ async def send_message( if cancellation_token is None: cancellation_token = CancellationToken() - return await self._runtime.send_message( + recipient_topic = TopicId(type=format_rpc_request_topic(rpc_recipient_agent_type=recipient.type, rpc_sender_agent_type=self.id.type), source=recipient.key) + request_id = str(uuid.uuid4()) + + future = Future[Any]() + + await self._runtime.publish_message( message, sender=self.id, - recipient=recipient, + topic_id=recipient_topic, cancellation_token=cancellation_token, + message_id=request_id, ) + self._pending_rpc_requests[request_id] = future + + return future + async def publish_message( self, message: Any, diff --git a/python/packages/autogen-core/src/autogen_core/base/_message_context.py b/python/packages/autogen-core/src/autogen_core/base/_message_context.py index c5c00559ed0..65cbbb64d4b 100644 --- a/python/packages/autogen-core/src/autogen_core/base/_message_context.py +++ b/python/packages/autogen-core/src/autogen_core/base/_message_context.py @@ -8,7 +8,6 @@ @dataclass class MessageContext: sender: AgentId | None - topic_id: TopicId | None - is_rpc: bool + topic_id: TopicId cancellation_token: CancellationToken message_id: str diff --git a/python/packages/autogen-core/src/autogen_core/base/_rpc.py b/python/packages/autogen-core/src/autogen_core/base/_rpc.py new file mode 100644 index 00000000000..d6b857afd0e --- /dev/null +++ b/python/packages/autogen-core/src/autogen_core/base/_rpc.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import Optional + + + + +def format_rpc_request_topic(rpc_recipient_agent_type: str, rpc_sender_agent_type: str) -> str: + return f"{rpc_recipient_agent_type}:rpc_request={rpc_sender_agent_type}" + +def format_rpc_response_topic(rpc_sender_agent_type: str,request_id: str) -> str: + return f"{rpc_sender_agent_type}:rpc_response={request_id}" + +# If is an rpc response, return the request id +def is_rpc_response(topic_type: str) -> Optional[str]: + topic_segments = topic_type.split(":") + # Find if there is a segment starting with :rpc_response= + for segment in topic_segments: + if segment.startswith("rpc_response="): + return segment[len("rpc_response=") :] + return None + + +# If is an rpc response, return the requestor agent type +def is_rpc_request(topic_type: str) -> Optional[str]: + topic_segments = topic_type.split(":") + # Find if there is a segment starting with :rpc_request= + for segment in topic_segments: + if segment.startswith("rpc_request="): + return segment[len("rpc_request=") :] + return None diff --git a/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py b/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py index 12e5faae6bf..d6b100547bd 100644 --- a/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py +++ b/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py @@ -18,7 +18,7 @@ Subscription, TopicId, ) -from ..base._type_helpers import get_types +from ..base._type_helpers import AnyType, get_types from ..base.exceptions import CantHandleException T = TypeVar("T") @@ -76,7 +76,7 @@ async def publish_message( class ClosureAgent(BaseAgent, ClosureContext): def __init__( - self, description: str, closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]] + self, description: str, closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]], *, forward_unbound_rpc_responses_to_handler: bool = False ) -> None: try: runtime = AgentInstantiationContext.current_runtime() @@ -92,7 +92,7 @@ def __init__( handled_types = get_handled_types_from_closure(closure) self._expected_types = handled_types self._closure = closure - super().__init__(description) + super().__init__(description, forward_unbound_rpc_responses_to_handler=forward_unbound_rpc_responses_to_handler) @property def metadata(self) -> AgentMetadata: @@ -111,8 +111,8 @@ def id(self) -> AgentId: def runtime(self) -> AgentRuntime: return self._runtime - async def on_message(self, message: Any, ctx: MessageContext) -> Any: - if type(message) not in self._expected_types: + async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any: + if AnyType not in self._expected_types and type(message) not in self._expected_types: raise CantHandleException( f"Message type {type(message)} not in target types {self._expected_types} of {self.id}" ) @@ -131,19 +131,19 @@ async def register_closure( type: str, closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]], *, - skip_class_subscriptions: bool = False, skip_direct_message_subscription: bool = False, + forward_unbound_rpc_responses_to_handler: bool = False, description: str = "", subscriptions: Callable[[], list[Subscription] | Awaitable[list[Subscription]]] | None = None, ) -> AgentType: def factory() -> ClosureAgent: - return ClosureAgent(description=description, closure=closure) + return ClosureAgent(description=description, closure=closure, forward_unbound_rpc_responses_to_handler=forward_unbound_rpc_responses_to_handler) agent_type = await cls.register( runtime=runtime, type=type, factory=factory, # type: ignore - skip_class_subscriptions=skip_class_subscriptions, + skip_class_subscriptions=True, skip_direct_message_subscription=skip_direct_message_subscription, ) diff --git a/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py b/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py index e7f266bf49d..50ef43aa67f 100644 --- a/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py +++ b/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py @@ -19,13 +19,16 @@ runtime_checkable, ) +from autogen_core.base._rpc import format_rpc_response_topic, is_rpc_request +from autogen_core.base._topic import TopicId + from ..base import BaseAgent, MessageContext, MessageSerializer, try_get_known_serializers_for_type from ..base._type_helpers import AnyType, get_types from ..base.exceptions import CantHandleException logger = logging.getLogger("autogen_core") -AgentT = TypeVar("AgentT") +AgentT = TypeVar("AgentT", bound=BaseAgent) ReceivesT = TypeVar("ReceivesT") ProducesT = TypeVar("ProducesT", covariant=True) @@ -138,7 +141,7 @@ def decorator( # Convert target_types to list and stash @wraps(func) - async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> ProducesT: + async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> None: if type(message) not in target_types: if strict: raise CantHandleException(f"Message type {type(message)} not in target types {target_types}") @@ -153,7 +156,26 @@ async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> Prod else: logger.warning(f"Return type {type(return_value)} not in return types {return_types}") - return return_value + # Dont return, but publish it if you need to... + # Any return is treated as a response to the RPC request and is published accordingly + + if return_value is not None: + if (requestor_type := is_rpc_request(ctx.topic_id.type)) is not None: + response_topic_id = TopicId( + type=format_rpc_response_topic(rpc_sender_agent_type=requestor_type, request_id=ctx.message_id), + source=self.id.key, + ) + + await self.publish_message( + message=return_value, + topic_id=response_topic_id, + cancellation_token=ctx.cancellation_token, + ) + else: + warnings.warn( + "Returning a value from a message handler that is not an RPC request. This value will be ignored.", + stacklevel=2, + ) wrapper_handler = cast(MessageHandler[AgentT, ReceivesT, ProducesT], wrapper) wrapper_handler.target_types = list(target_types) @@ -278,8 +300,8 @@ async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> None wrapper_handler.target_types = list(target_types) wrapper_handler.produces_types = list(return_types) wrapper_handler.is_message_handler = True - # Wrap the match function with a check on the is_rpc flag. - wrapper_handler.router = lambda _message, _ctx: (not _ctx.is_rpc) and (match(_message, _ctx) if match else True) + # Wrap the match function with a check on the topic for rpc + wrapper_handler.router = lambda _message, _ctx: (is_rpc_request(_ctx.topic_id.type) is None) and (match(_message, _ctx) if match else True) return wrapper_handler @@ -378,7 +400,7 @@ def decorator( # Convert target_types to list and stash @wraps(func) - async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> ProducesT: + async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> None: if type(message) not in target_types: if strict: raise CantHandleException(f"Message type {type(message)} not in target types {target_types}") @@ -393,13 +415,32 @@ async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> Prod else: logger.warning(f"Return type {type(return_value)} not in return types {return_types}") - return return_value + # Dont return, but publish it if you need to... + # Any return is treated as a response to the RPC request and is published accordingly + + if return_value is not None: + if (requestor_type := is_rpc_request(ctx.topic_id.type)) is not None: + response_topic_id = TopicId( + type=format_rpc_response_topic(rpc_sender_agent_type=requestor_type, request_id=ctx.message_id), + source=self.id.key, + ) + + await self.publish_message( + message=return_value, + topic_id=response_topic_id, + cancellation_token=ctx.cancellation_token, + ) + else: + warnings.warn( + "Returning a value from a message handler that is not an RPC request. This value will be ignored.", + stacklevel=2, + ) wrapper_handler = cast(MessageHandler[AgentT, ReceivesT, ProducesT], wrapper) wrapper_handler.target_types = list(target_types) wrapper_handler.produces_types = list(return_types) wrapper_handler.is_message_handler = True - wrapper_handler.router = lambda _message, _ctx: (_ctx.is_rpc) and (match(_message, _ctx) if match else True) + wrapper_handler.router = lambda _message, _ctx: (is_rpc_request(_ctx.topic_id.type) is not None) and (match(_message, _ctx) if match else True) return wrapper_handler @@ -470,7 +511,7 @@ def __init__(self, description: str) -> None: super().__init__(description) - async def on_message(self, message: Any, ctx: MessageContext) -> Any | None: + async def on_message_impl(self, message: Any, ctx: MessageContext): """Handle a message by routing it to the appropriate message handler. Do not override this method in subclasses. Instead, add message handlers as methods decorated with either the :func:`event` or :func:`rpc` decorator.""" diff --git a/python/packages/autogen-core/src/autogen_core/components/send_message_mixin.py b/python/packages/autogen-core/src/autogen_core/components/send_message_mixin.py new file mode 100644 index 00000000000..2ae0e09eeba --- /dev/null +++ b/python/packages/autogen-core/src/autogen_core/components/send_message_mixin.py @@ -0,0 +1,61 @@ + +from typing import Any +import uuid +import warnings + +from autogen_core.base._rpc import format_rpc_request_topic, format_rpc_response_topic +from autogen_core.base._topic import TopicId + +from ..base._message_context import MessageContext + +from ..base._agent_id import AgentId +from ..base._cancellation_token import CancellationToken +from ._closure_agent import ClosureAgent, ClosureContext +from ..base._agent_runtime import AgentRuntime + + +import asyncio + +class PublishBasedRpcMixin(AgentRuntime): + async def send_message( + self: AgentRuntime, + message: Any, + recipient: AgentId, + *, + cancellation_token: CancellationToken | None = None, + ) -> Any: + + rpc_request_id = str(uuid.uuid4()) + # TODO add "-" to topic and agent type allowed characters in spec + closure_agent_type = f"rpc_receiver_{recipient.type}_{rpc_request_id}" + + future: asyncio.Future[Any] = asyncio.Future() + expected_response_topic_type = format_rpc_response_topic(rpc_sender_agent_type=closure_agent_type, request_id=rpc_request_id) + async def set_result(closure_context:ClosureContext, message: Any, ctx: MessageContext) -> None: + assert ctx.topic_id is not None + if ctx.topic_id.type == expected_response_topic_type: + future.set_result(message) + else: + warnings.warn(f"{closure_agent_type} received an unexpected message on topic type {ctx.topic_id.type}. Expected {expected_response_topic_type}", stacklevel=2) + + # TODO: remove agent after response is received + + await ClosureAgent.register_closure( + runtime=self, + type=closure_agent_type, + closure=set_result, + forward_unbound_rpc_responses_to_handler=True, + ) + + rpc_request_topic_id = format_rpc_request_topic(rpc_recipient_agent_type=recipient.type, rpc_sender_agent_type=closure_agent_type) + await self.publish_message( + message=message, + topic_id=TopicId(type=rpc_request_topic_id, source=recipient.key), + cancellation_token=cancellation_token, + message_id=rpc_request_id, + ) + + return await future + + # register a closure agent... + diff --git a/python/packages/autogen-core/test.py b/python/packages/autogen-core/test.py new file mode 100644 index 00000000000..d8c5672d1fe --- /dev/null +++ b/python/packages/autogen-core/test.py @@ -0,0 +1,35 @@ +from dataclasses import dataclass + +from autogen_core.base import MessageContext +from autogen_core.base._agent_id import AgentId +from autogen_core.components import RoutedAgent +from autogen_core.components._routed_agent import rpc + +from autogen_core.application import SingleThreadedAgentRuntime +import asyncio + +@dataclass +class Message: + content: str + +class MyAgent(RoutedAgent): + def __init__(self) -> None: + super().__init__("My agent") + + @rpc + async def handle_message(self, message: Message, ctx: MessageContext) -> Message: + print(f"Received message: {message.content}") + return Message(content=f"I got: {message.content}") + +async def main(): + runtime = SingleThreadedAgentRuntime() + + await MyAgent.register(runtime, "my_agent", MyAgent) + + runtime.start() + print(await runtime.send_message( + Message("I'm sending you this"), recipient=AgentId("my_agent", "default") + )) + await runtime.stop_when_idle() + +asyncio.run(main()) diff --git a/python/packages/autogen-core/tests/test_state.py b/python/packages/autogen-core/tests/test_state.py index 7120a9baab4..ba4fe86cf13 100644 --- a/python/packages/autogen-core/tests/test_state.py +++ b/python/packages/autogen-core/tests/test_state.py @@ -10,7 +10,7 @@ def __init__(self) -> None: super().__init__("A stateful agent") self.state = 0 - async def on_message(self, message: Any, ctx: MessageContext) -> None: + async def on_message_impl(self, message: Any, ctx: MessageContext) -> None: raise NotImplementedError async def save_state(self) -> Mapping[str, Any]: diff --git a/python/packages/autogen-core/tests/test_types.py b/python/packages/autogen-core/tests/test_types.py index 1dbc02c4fa9..3959456b35b 100644 --- a/python/packages/autogen-core/tests/test_types.py +++ b/python/packages/autogen-core/tests/test_types.py @@ -5,7 +5,7 @@ from autogen_core.base import MessageContext from autogen_core.base._serialization import has_nested_base_model from autogen_core.base._type_helpers import AnyType, get_types -from autogen_core.components._routed_agent import message_handler +from autogen_core.components._routed_agent import RoutedAgent, message_handler from pydantic import BaseModel @@ -21,7 +21,7 @@ def test_get_types() -> None: def test_handler() -> None: - class HandlerClass: + class HandlerClass(RoutedAgent): @message_handler() async def handler(self, message: int, ctx: MessageContext) -> Any: return None @@ -37,7 +37,7 @@ async def handler2(self, message: str | bool, ctx: MessageContext) -> None: assert HandlerClass.handler2.produces_types == [NoneType] -class HandlerClass: +class HandlerClass(RoutedAgent): @message_handler() async def handler(self, message: int, ctx: MessageContext) -> Any: return None diff --git a/python/packages/autogen-core/tests/test_utils/__init__.py b/python/packages/autogen-core/tests/test_utils/__init__.py index 5de7519fc49..3b1ac1101fc 100644 --- a/python/packages/autogen-core/tests/test_utils/__init__.py +++ b/python/packages/autogen-core/tests/test_utils/__init__.py @@ -57,5 +57,5 @@ class NoopAgent(BaseAgent): def __init__(self) -> None: super().__init__("A no op agent") - async def on_message(self, message: Any, ctx: MessageContext) -> Any: + async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any: raise NotImplementedError From f7a6d481c7489867717d1e73952dd5aae9a1ed56 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Sun, 1 Dec 2024 14:43:35 -0500 Subject: [PATCH 2/7] remove handled rpc --- .../packages/autogen-core/src/autogen_core/base/_base_agent.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/packages/autogen-core/src/autogen_core/base/_base_agent.py b/python/packages/autogen-core/src/autogen_core/base/_base_agent.py index 25ce6179b43..f2baa6d1a07 100644 --- a/python/packages/autogen-core/src/autogen_core/base/_base_agent.py +++ b/python/packages/autogen-core/src/autogen_core/base/_base_agent.py @@ -128,13 +128,14 @@ async def on_message(self, message: Any, ctx: MessageContext) -> None: if ctx.topic_id is not None and (request_id := is_rpc_response(ctx.topic_id.type)) is not None: if request_id in self._pending_rpc_requests: self._pending_rpc_requests[request_id].set_result(message) + del self._pending_rpc_requests[request_id] elif self._forward_unbound_rpc_responses_to_handler: await self.on_message_impl(message, ctx) else: warnings.warn(f"Received RPC response for unknown request {request_id}. To forward unbound rpc responses to the handler, set forward_unbound_rpc_responses_to_handler=True", stacklevel=2) return None - return await self.on_message_impl(message, ctx) + await self.on_message_impl(message, ctx) async def send_message( self, From 4170415e27e07459202f247a1088a969c5972413 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Sun, 1 Dec 2024 14:54:40 -0500 Subject: [PATCH 3/7] move module --- .../application/_single_threaded_agent_runtime.py | 4 ++-- .../{send_message_mixin.py => _publish_based_rpc.py} | 0 2 files changed, 2 insertions(+), 2 deletions(-) rename python/packages/autogen-core/src/autogen_core/components/{send_message_mixin.py => _publish_based_rpc.py} (100%) diff --git a/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py index 9b402c11c4a..f56618da512 100644 --- a/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py @@ -15,8 +15,8 @@ from opentelemetry.trace import TracerProvider from typing_extensions import deprecated -from autogen_core.base._serialization import MessageSerializer, SerializationRegistry -from autogen_core.components.send_message_mixin import PublishBasedRpcMixin +from ..base._serialization import MessageSerializer, SerializationRegistry +from ..components._publish_based_rpc import PublishBasedRpcMixin from ..base import ( Agent, diff --git a/python/packages/autogen-core/src/autogen_core/components/send_message_mixin.py b/python/packages/autogen-core/src/autogen_core/components/_publish_based_rpc.py similarity index 100% rename from python/packages/autogen-core/src/autogen_core/components/send_message_mixin.py rename to python/packages/autogen-core/src/autogen_core/components/_publish_based_rpc.py From 7e62aad51ab16dcf5c2baf5ca0ad86927b9b21db Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Sun, 1 Dec 2024 14:59:49 -0500 Subject: [PATCH 4/7] remove rpc from single threaded runtime --- .../_single_threaded_agent_runtime.py | 249 ++---------------- 1 file changed, 28 insertions(+), 221 deletions(-) diff --git a/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py index f56618da512..763dada6ab5 100644 --- a/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py @@ -6,7 +6,7 @@ import threading import uuid import warnings -from asyncio import CancelledError, Future, Task +from asyncio import CancelledError, Task from collections.abc import Sequence from dataclasses import dataclass from enum import Enum @@ -32,7 +32,7 @@ SubscriptionInstantiationContext, TopicId, ) -from ..base.exceptions import MessageDroppedException + from ..base.intervention import DropMessage, InterventionHandler from ._helpers import SubscriptionManager, get_impl from .telemetry import EnvelopeMetadata, MessageRuntimeTracingConfig, TraceHelper, get_telemetry_envelope_metadata @@ -58,30 +58,6 @@ class PublishMessageEnvelope: message_id: str -@dataclass(kw_only=True) -class SendMessageEnvelope: - """A message envelope for sending a message to a specific agent that can handle - the message of the type T.""" - - message: Any - sender: AgentId | None - recipient: AgentId - future: Future[Any] - cancellation_token: CancellationToken - metadata: EnvelopeMetadata | None = None - - -@dataclass(kw_only=True) -class ResponseMessageEnvelope: - """A message envelope for sending a response to a message.""" - - message: Any - future: Future[Any] - sender: AgentId - recipient: AgentId | None - metadata: EnvelopeMetadata | None = None - - P = ParamSpec("P") T = TypeVar("T", bound=Agent) @@ -175,7 +151,7 @@ def __init__( tracer_provider: TracerProvider | None = None, ) -> None: self._tracer_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("SingleThreadedAgentRuntime")) - self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = [] + self._message_queue: List[PublishMessageEnvelope] = [] # (namespace, type) -> List[AgentId] self._agent_factories: Dict[ str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]] @@ -191,7 +167,7 @@ def __init__( @property def unprocessed_messages( self, - ) -> Sequence[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope]: + ) -> Sequence[PublishMessageEnvelope]: return self._message_queue @property @@ -202,56 +178,6 @@ def outstanding_tasks(self) -> int: def _known_agent_names(self) -> Set[str]: return set(self._agent_factories.keys()) - # Returns the response of the message - # async def send_message( - # self, - # message: Any, - # recipient: AgentId, - # *, - # sender: AgentId | None = None, - # cancellation_token: CancellationToken | None = None, - # ) -> Any: - # if cancellation_token is None: - # cancellation_token = CancellationToken() - - # # event_logger.info( - # # MessageEvent( - # # payload=message, - # # sender=sender, - # # receiver=recipient, - # # kind=MessageKind.DIRECT, - # # delivery_stage=DeliveryStage.SEND, - # # ) - # # ) - - # with self._tracer_helper.trace_block( - # "create", - # recipient, - # parent=None, - # extraAttributes={"message_type": type(message).__name__}, - # ): - # future = asyncio.get_event_loop().create_future() - # if recipient.type not in self._known_agent_names: - # future.set_exception(Exception("Recipient not found")) - - # content = message.__dict__ if hasattr(message, "__dict__") else message - # logger.info(f"Sending message of type {type(message).__name__} to {recipient.type}: {content}") - - # self._message_queue.append( - # SendMessageEnvelope( - # message=message, - # recipient=recipient, - # future=future, - # cancellation_token=cancellation_token, - # sender=sender, - # metadata=get_telemetry_envelope_metadata(), - # ) - # ) - - # cancellation_token.link_future(future) - - # return await future - async def publish_message( self, message: Any, @@ -308,61 +234,6 @@ async def load_state(self, state: Mapping[str, Any]) -> None: if agent_id.type in self._known_agent_names: await (await self._get_agent(agent_id)).load_state(state[str(agent_id)]) - async def _process_send(self, message_envelope: SendMessageEnvelope) -> None: - with self._tracer_helper.trace_block("send", message_envelope.recipient, parent=message_envelope.metadata): - recipient = message_envelope.recipient - # todo: check if recipient is in the known namespaces - # assert recipient in self._agents - - try: - # TODO use id - sender_name = message_envelope.sender.type if message_envelope.sender is not None else "Unknown" - logger.info( - f"Calling message handler for {recipient} with message type {type(message_envelope.message).__name__} sent by {sender_name}" - ) - # event_logger.info( - # MessageEvent( - # payload=message_envelope.message, - # sender=message_envelope.sender, - # receiver=recipient, - # kind=MessageKind.DIRECT, - # delivery_stage=DeliveryStage.DELIVER, - # ) - # ) - recipient_agent = await self._get_agent(recipient) - message_context = MessageContext( - sender=message_envelope.sender, - topic_id=None, - cancellation_token=message_envelope.cancellation_token, - # Will be fixed when send API removed - message_id="NOT_DEFINED_TODO_FIX", - ) - with MessageHandlerContext.populate_context(recipient_agent.id): - response = await recipient_agent.on_message( - message_envelope.message, - ctx=message_context, - ) - except CancelledError as e: - if not message_envelope.future.cancelled(): - message_envelope.future.set_exception(e) - self._outstanding_tasks.decrement() - return - except BaseException as e: - message_envelope.future.set_exception(e) - self._outstanding_tasks.decrement() - return - - self._message_queue.append( - ResponseMessageEnvelope( - message=response, - future=message_envelope.future, - sender=message_envelope.recipient, - recipient=message_envelope.sender, - metadata=get_telemetry_envelope_metadata(), - ) - ) - self._outstanding_tasks.decrement() - async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None: with self._tracer_helper.trace_block("publish", message_envelope.topic_id, parent=message_envelope.metadata): try: @@ -418,29 +289,6 @@ async def _on_message(agent: Agent, message_context: MessageContext) -> Any: self._outstanding_tasks.decrement() # TODO if responses are given for a publish - async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None: - with self._tracer_helper.trace_block("ack", message_envelope.recipient, parent=message_envelope.metadata): - content = ( - message_envelope.message.__dict__ - if hasattr(message_envelope.message, "__dict__") - else message_envelope.message - ) - logger.info( - f"Resolving response with message type {type(message_envelope.message).__name__} for recipient {message_envelope.recipient} from {message_envelope.sender.type}: {content}" - ) - # event_logger.info( - # MessageEvent( - # payload=message_envelope.message, - # sender=message_envelope.sender, - # receiver=message_envelope.recipient, - # kind=MessageKind.RESPOND, - # delivery_stage=DeliveryStage.DELIVER, - # ) - # ) - self._outstanding_tasks.decrement() - if not message_envelope.future.cancelled(): - message_envelope.future.set_result(message_envelope.message) - async def process_next(self) -> None: """Process the next message in the queue.""" @@ -450,71 +298,30 @@ async def process_next(self) -> None: return message_envelope = self._message_queue.pop(0) - match message_envelope: - case SendMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future): - if self._intervention_handlers is not None: - for handler in self._intervention_handlers: - with self._tracer_helper.trace_block( - "intercept", handler.__class__.__name__, parent=message_envelope.metadata - ): - try: - temp_message = await handler.on_send(message, sender=sender, recipient=recipient) - _warn_if_none(temp_message, "on_send") - except BaseException as e: - future.set_exception(e) - return - if temp_message is DropMessage or isinstance(temp_message, DropMessage): - future.set_exception(MessageDroppedException()) - return - - message_envelope.message = temp_message - self._outstanding_tasks.increment() - task = asyncio.create_task(self._process_send(message_envelope)) - self._background_tasks.add(task) - task.add_done_callback(self._background_tasks.discard) - case PublishMessageEnvelope( - message=message, - sender=sender, - ): - if self._intervention_handlers is not None: - for handler in self._intervention_handlers: - with self._tracer_helper.trace_block( - "intercept", handler.__class__.__name__, parent=message_envelope.metadata - ): - try: - temp_message = await handler.on_publish(message, sender=sender) - _warn_if_none(temp_message, "on_publish") - except BaseException as e: - # TODO: we should raise the intervention exception to the publisher. - logger.error(f"Exception raised in in intervention handler: {e}", exc_info=True) - return - if temp_message is DropMessage or isinstance(temp_message, DropMessage): - # TODO log message dropped - return - - message_envelope.message = temp_message - self._outstanding_tasks.increment() - task = asyncio.create_task(self._process_publish(message_envelope)) - self._background_tasks.add(task) - task.add_done_callback(self._background_tasks.discard) - case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future): - if self._intervention_handlers is not None: - for handler in self._intervention_handlers: - try: - temp_message = await handler.on_response(message, sender=sender, recipient=recipient) - _warn_if_none(temp_message, "on_response") - except BaseException as e: - # TODO: should we raise the exception to sender of the response instead? - future.set_exception(e) - return - if temp_message is DropMessage or isinstance(temp_message, DropMessage): - future.set_exception(MessageDroppedException()) - return - message_envelope.message = temp_message - self._outstanding_tasks.increment() - task = asyncio.create_task(self._process_response(message_envelope)) - self._background_tasks.add(task) - task.add_done_callback(self._background_tasks.discard) + message = message_envelope.message + sender = message_envelope.sender + + if self._intervention_handlers is not None: + for handler in self._intervention_handlers: + with self._tracer_helper.trace_block( + "intercept", handler.__class__.__name__, parent=message_envelope.metadata + ): + try: + temp_message = await handler.on_publish(message, sender=sender) + _warn_if_none(temp_message, "on_publish") + except BaseException as e: + # TODO: we should raise the intervention exception to the publisher. + logger.error(f"Exception raised in in intervention handler: {e}", exc_info=True) + return + if temp_message is DropMessage or isinstance(temp_message, DropMessage): + # TODO log message dropped + return + + message_envelope.message = temp_message + self._outstanding_tasks.increment() + task = asyncio.create_task(self._process_publish(message_envelope)) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) # Yield control to the message loop to allow other tasks to run await asyncio.sleep(0) From 13cc05ac0946063824dee47feba65fbdb43b4100 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Sun, 1 Dec 2024 15:01:58 -0500 Subject: [PATCH 5/7] remove rpc from worker runtime --- protos/agent_worker.proto | 47 +---- .../application/_worker_runtime.py | 162 +------------- .../application/protos/agent_worker_pb2.py | 76 +++---- .../application/protos/agent_worker_pb2.pyi | 198 +----------------- 4 files changed, 38 insertions(+), 445 deletions(-) diff --git a/protos/agent_worker.proto b/protos/agent_worker.proto index 4d346dfecd6..8260a1d77f9 100644 --- a/protos/agent_worker.proto +++ b/protos/agent_worker.proto @@ -7,46 +7,11 @@ option csharp_namespace = "Microsoft.AutoGen.Abstractions"; import "cloudevent.proto"; import "google/protobuf/any.proto"; -message TopicId { - string type = 1; - string source = 2; -} - message AgentId { string type = 1; string key = 2; } -message Payload { - string data_type = 1; - string data_content_type = 2; - bytes data = 3; -} - -message RpcRequest { - string request_id = 1; - optional AgentId source = 2; - AgentId target = 3; - string method = 4; - Payload payload = 5; - map metadata = 6; -} - -message RpcResponse { - string request_id = 1; - Payload payload = 2; - string error = 3; - map metadata = 4; -} - -message Event { - string topic_type = 1; - string topic_source = 2; - optional AgentId source = 3; - Payload payload = 4; - map metadata = 5; -} - message RegisterAgentTypeRequest { string request_id = 1; string type = 2; @@ -115,13 +80,11 @@ message SaveStateResponse { message Message { oneof message { - RpcRequest request = 1; - RpcResponse response = 2; - cloudevent.CloudEvent cloudEvent = 3; - RegisterAgentTypeRequest registerAgentTypeRequest = 4; - RegisterAgentTypeResponse registerAgentTypeResponse = 5; - AddSubscriptionRequest addSubscriptionRequest = 6; - AddSubscriptionResponse addSubscriptionResponse = 7; + cloudevent.CloudEvent cloudEvent = 1; + RegisterAgentTypeRequest registerAgentTypeRequest = 2; + RegisterAgentTypeResponse registerAgentTypeResponse = 3; + AddSubscriptionRequest addSubscriptionRequest = 4; + AddSubscriptionResponse addSubscriptionResponse = 5; } } diff --git a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py index 24714f29155..d72d1555486 100644 --- a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py @@ -28,6 +28,7 @@ cast, ) + from google.protobuf import any_pb2 from opentelemetry.trace import TracerProvider from typing_extensions import Self, deprecated @@ -53,6 +54,7 @@ from ..base._serialization import MessageSerializer, SerializationRegistry from ..base._type_helpers import ChannelArgumentType from ..components import TypePrefixSubscription, TypeSubscription +from ..components._publish_based_rpc import PublishBasedRpcMixin from . import _constants from ._constants import GRPC_IMPORT_ERROR_STR from ._helpers import SubscriptionManager, get_impl @@ -177,7 +179,7 @@ async def recv(self) -> agent_worker_pb2.Message: return await self._recv_queue.get() -class WorkerAgentRuntime(AgentRuntime): +class WorkerAgentRuntime(PublishBasedRpcMixin, AgentRuntime): def __init__( self, host_address: str, @@ -237,16 +239,6 @@ async def _run_read_loop(self) -> None: match oneofcase: case "registerAgentTypeRequest" | "addSubscriptionRequest": logger.warning(f"Cant handle {oneofcase}, skipping.") - case "request": - task = asyncio.create_task(self._process_request(message.request)) - self._background_tasks.add(task) - task.add_done_callback(self._raise_on_exception) - task.add_done_callback(self._background_tasks.discard) - case "response": - task = asyncio.create_task(self._process_response(message.response)) - self._background_tasks.add(task) - task.add_done_callback(self._raise_on_exception) - task.add_done_callback(self._background_tasks.discard) case "cloudEvent": # The proto typing doesnt resolve this one cloud_event = cast(cloudevent_pb2.CloudEvent, message.cloudEvent) # type: ignore @@ -331,51 +323,6 @@ async def _send_message( with self._trace_helper.trace_block(send_type, recipient, parent=telemetry_metadata): await self._host_connection.send(runtime_message) - async def send_message( - self, - message: Any, - recipient: AgentId, - *, - sender: AgentId | None = None, - cancellation_token: CancellationToken | None = None, - ) -> Any: - if not self._running: - raise ValueError("Runtime must be running when sending message.") - if self._host_connection is None: - raise RuntimeError("Host connection is not set.") - data_type = self._serialization_registry.type_name(message) - with self._trace_helper.trace_block( - "create", recipient, parent=None, extraAttributes={"message_type": data_type} - ): - # create a new future for the result - future = asyncio.get_event_loop().create_future() - request_id = await self._get_new_request_id() - self._pending_requests[request_id] = future - serialized_message = self._serialization_registry.serialize( - message, type_name=data_type, data_content_type=JSON_DATA_CONTENT_TYPE - ) - telemetry_metadata = get_telemetry_grpc_metadata() - runtime_message = agent_worker_pb2.Message( - request=agent_worker_pb2.RpcRequest( - request_id=request_id, - target=agent_worker_pb2.AgentId(type=recipient.type, key=recipient.key), - source=agent_worker_pb2.AgentId(type=sender.type, key=sender.key) if sender is not None else None, - metadata=telemetry_metadata, - payload=agent_worker_pb2.Payload( - data_type=data_type, - data=serialized_message, - data_content_type=JSON_DATA_CONTENT_TYPE, - ), - ) - ) - - # TODO: Find a way to handle timeouts/errors - task = asyncio.create_task(self._send_message(runtime_message, "send", recipient, telemetry_metadata)) - self._background_tasks.add(task) - task.add_done_callback(self._raise_on_exception) - task.add_done_callback(self._background_tasks.discard) - return await future - async def publish_message( self, message: Any, @@ -475,98 +422,6 @@ async def _get_new_request_id(self) -> str: self._next_request_id += 1 return str(self._next_request_id) - async def _process_request(self, request: agent_worker_pb2.RpcRequest) -> None: - assert self._host_connection is not None - recipient = AgentId(request.target.type, request.target.key) - sender: AgentId | None = None - if request.HasField("source"): - sender = AgentId(request.source.type, request.source.key) - logging.info(f"Processing request from {sender} to {recipient}") - else: - logging.info(f"Processing request from unknown source to {recipient}") - - # Deserialize the message. - message = self._serialization_registry.deserialize( - request.payload.data, - type_name=request.payload.data_type, - data_content_type=request.payload.data_content_type, - ) - - # Get the receiving agent and prepare the message context. - rec_agent = await self._get_agent(recipient) - message_context = MessageContext( - sender=sender, - topic_id=None, - cancellation_token=CancellationToken(), - message_id=request.request_id, - ) - - # Call the receiving agent. - try: - with MessageHandlerContext.populate_context(rec_agent.id): - with self._trace_helper.trace_block( - "process", - rec_agent.id, - parent=request.metadata, - attributes={"request_id": request.request_id}, - extraAttributes={"message_type": request.payload.data_type}, - ): - result = await rec_agent.on_message(message, ctx=message_context) - except BaseException as e: - response_message = agent_worker_pb2.Message( - response=agent_worker_pb2.RpcResponse( - request_id=request.request_id, - error=str(e), - metadata=get_telemetry_grpc_metadata(), - ), - ) - # Send the error response. - await self._host_connection.send(response_message) - return - - # Serialize the result. - result_type = self._serialization_registry.type_name(result) - serialized_result = self._serialization_registry.serialize( - result, type_name=result_type, data_content_type=JSON_DATA_CONTENT_TYPE - ) - - # Create the response message. - response_message = agent_worker_pb2.Message( - response=agent_worker_pb2.RpcResponse( - request_id=request.request_id, - payload=agent_worker_pb2.Payload( - data_type=result_type, - data=serialized_result, - data_content_type=JSON_DATA_CONTENT_TYPE, - ), - metadata=get_telemetry_grpc_metadata(), - ) - ) - - # Send the response. - await self._host_connection.send(response_message) - - async def _process_response(self, response: agent_worker_pb2.RpcResponse) -> None: - with self._trace_helper.trace_block( - "ack", - None, - parent=response.metadata, - attributes={"request_id": response.request_id}, - extraAttributes={"message_type": response.payload.data_type}, - ): - # Deserialize the result. - result = self._serialization_registry.deserialize( - response.payload.data, - type_name=response.payload.data_type, - data_content_type=response.payload.data_content_type, - ) - # Get the future and set the result. - future = self._pending_requests.pop(response.request_id) - if len(response.error) > 0: - future.set_exception(Exception(response.error)) - else: - future.set_result(result) - async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None: event_attributes = event.attributes sender: AgentId | None = None @@ -598,16 +453,6 @@ async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None: else: raise ValueError(f"Unsupported message content type: {message_content_type}") - # TODO: dont read these values in the runtime - topic_type_suffix = topic_id.type.split(":", maxsplit=1)[1] if ":" in topic_id.type else "" - is_rpc = topic_type_suffix == _constants.MESSAGE_KIND_VALUE_RPC_REQUEST - is_marked_rpc_type = ( - _constants.MESSAGE_KIND_ATTR in event_attributes - and event_attributes[_constants.MESSAGE_KIND_ATTR].ce_string == _constants.MESSAGE_KIND_VALUE_RPC_REQUEST - ) - if is_rpc and not is_marked_rpc_type: - warnings.warn("Received RPC request with topic type suffix but not marked as RPC request.", stacklevel=2) - # Send the message to each recipient. responses: List[Awaitable[Any]] = [] for agent_id in recipients: @@ -616,7 +461,6 @@ async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None: message_context = MessageContext( sender=sender, topic_id=topic_id, - is_rpc=is_rpc, cancellation_token=CancellationToken(), message_id=event.id, ) diff --git a/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.py b/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.py index 319ee2c6365..ca08dcb1db8 100644 --- a/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.py +++ b/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.py @@ -16,7 +16,7 @@ from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\x1a\x10\x63loudevent.proto\x1a\x19google/protobuf/any.proto\"\'\n\x07TopicId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0e\n\x06source\x18\x02 \x01(\t\"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\"E\n\x07Payload\x12\x11\n\tdata_type\x18\x01 \x01(\t\x12\x19\n\x11\x64\x61ta_content_type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\x89\x02\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12$\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12 \n\x07payload\x18\x05 \x01(\x0b\x32\x0f.agents.Payload\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"\xb8\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12 \n\x07payload\x18\x02 \x01(\x0b\x32\x0f.agents.Payload\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xe4\x01\n\x05\x45vent\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x14\n\x0ctopic_source\x18\x02 \x01(\t\x12$\n\x06source\x18\x03 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12 \n\x07payload\x18\x04 \x01(\x0b\x32\x0f.agents.Payload\x12-\n\x08metadata\x18\x05 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"<\n\x18RegisterAgentTypeRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\"^\n\x19RegisterAgentTypeResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"G\n\x16TypePrefixSubscription\x12\x19\n\x11topic_type_prefix\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"\x96\x01\n\x0cSubscription\x12\x34\n\x10typeSubscription\x18\x01 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x12@\n\x16typePrefixSubscription\x18\x02 \x01(\x0b\x32\x1e.agents.TypePrefixSubscriptionH\x00\x42\x0e\n\x0csubscription\"X\n\x16\x41\x64\x64SubscriptionRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12*\n\x0csubscription\x18\x02 \x01(\x0b\x32\x14.agents.Subscription\"\\\n\x17\x41\x64\x64SubscriptionResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"\x9d\x01\n\nAgentState\x12!\n\x08\x61gent_id\x18\x01 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0c\n\x04\x65Tag\x18\x02 \x01(\t\x12\x15\n\x0b\x62inary_data\x18\x03 \x01(\x0cH\x00\x12\x13\n\ttext_data\x18\x04 \x01(\tH\x00\x12*\n\nproto_data\x18\x05 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00\x42\x06\n\x04\x64\x61ta\"j\n\x10GetStateResponse\x12\'\n\x0b\x61gent_state\x18\x01 \x01(\x0b\x32\x12.agents.AgentState\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"B\n\x11SaveStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"\xa6\x03\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12,\n\ncloudEvent\x18\x03 \x01(\x0b\x32\x16.cloudevent.CloudEventH\x00\x12\x44\n\x18registerAgentTypeRequest\x18\x04 \x01(\x0b\x32 .agents.RegisterAgentTypeRequestH\x00\x12\x46\n\x19registerAgentTypeResponse\x18\x05 \x01(\x0b\x32!.agents.RegisterAgentTypeResponseH\x00\x12@\n\x16\x61\x64\x64SubscriptionRequest\x18\x06 \x01(\x0b\x32\x1e.agents.AddSubscriptionRequestH\x00\x12\x42\n\x17\x61\x64\x64SubscriptionResponse\x18\x07 \x01(\x0b\x32\x1f.agents.AddSubscriptionResponseH\x00\x42\t\n\x07message2\xb2\x01\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x12\x35\n\x08GetState\x12\x0f.agents.AgentId\x1a\x18.agents.GetStateResponse\x12:\n\tSaveState\x12\x12.agents.AgentState\x1a\x19.agents.SaveStateResponseB!\xaa\x02\x1eMicrosoft.AutoGen.Abstractionsb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\x1a\x10\x63loudevent.proto\x1a\x19google/protobuf/any.proto\"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\"<\n\x18RegisterAgentTypeRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\"^\n\x19RegisterAgentTypeResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"G\n\x16TypePrefixSubscription\x12\x19\n\x11topic_type_prefix\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"\x96\x01\n\x0cSubscription\x12\x34\n\x10typeSubscription\x18\x01 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x12@\n\x16typePrefixSubscription\x18\x02 \x01(\x0b\x32\x1e.agents.TypePrefixSubscriptionH\x00\x42\x0e\n\x0csubscription\"X\n\x16\x41\x64\x64SubscriptionRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12*\n\x0csubscription\x18\x02 \x01(\x0b\x32\x14.agents.Subscription\"\\\n\x17\x41\x64\x64SubscriptionResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"\x9d\x01\n\nAgentState\x12!\n\x08\x61gent_id\x18\x01 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0c\n\x04\x65Tag\x18\x02 \x01(\t\x12\x15\n\x0b\x62inary_data\x18\x03 \x01(\x0cH\x00\x12\x13\n\ttext_data\x18\x04 \x01(\tH\x00\x12*\n\nproto_data\x18\x05 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00\x42\x06\n\x04\x64\x61ta\"j\n\x10GetStateResponse\x12\'\n\x0b\x61gent_state\x18\x01 \x01(\x0b\x32\x12.agents.AgentState\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"B\n\x11SaveStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"\xd6\x02\n\x07Message\x12,\n\ncloudEvent\x18\x01 \x01(\x0b\x32\x16.cloudevent.CloudEventH\x00\x12\x44\n\x18registerAgentTypeRequest\x18\x02 \x01(\x0b\x32 .agents.RegisterAgentTypeRequestH\x00\x12\x46\n\x19registerAgentTypeResponse\x18\x03 \x01(\x0b\x32!.agents.RegisterAgentTypeResponseH\x00\x12@\n\x16\x61\x64\x64SubscriptionRequest\x18\x04 \x01(\x0b\x32\x1e.agents.AddSubscriptionRequestH\x00\x12\x42\n\x17\x61\x64\x64SubscriptionResponse\x18\x05 \x01(\x0b\x32\x1f.agents.AddSubscriptionResponseH\x00\x42\t\n\x07message2\xb2\x01\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x12\x35\n\x08GetState\x12\x0f.agents.AgentId\x1a\x18.agents.GetStateResponse\x12:\n\tSaveState\x12\x12.agents.AgentState\x1a\x19.agents.SaveStateResponseB!\xaa\x02\x1eMicrosoft.AutoGen.Abstractionsb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -24,52 +24,30 @@ if _descriptor._USE_C_DESCRIPTORS == False: _globals['DESCRIPTOR']._options = None _globals['DESCRIPTOR']._serialized_options = b'\252\002\036Microsoft.AutoGen.Abstractions' - _globals['_RPCREQUEST_METADATAENTRY']._options = None - _globals['_RPCREQUEST_METADATAENTRY']._serialized_options = b'8\001' - _globals['_RPCRESPONSE_METADATAENTRY']._options = None - _globals['_RPCRESPONSE_METADATAENTRY']._serialized_options = b'8\001' - _globals['_EVENT_METADATAENTRY']._options = None - _globals['_EVENT_METADATAENTRY']._serialized_options = b'8\001' - _globals['_TOPICID']._serialized_start=75 - _globals['_TOPICID']._serialized_end=114 - _globals['_AGENTID']._serialized_start=116 - _globals['_AGENTID']._serialized_end=152 - _globals['_PAYLOAD']._serialized_start=154 - _globals['_PAYLOAD']._serialized_end=223 - _globals['_RPCREQUEST']._serialized_start=226 - _globals['_RPCREQUEST']._serialized_end=491 - _globals['_RPCREQUEST_METADATAENTRY']._serialized_start=433 - _globals['_RPCREQUEST_METADATAENTRY']._serialized_end=480 - _globals['_RPCRESPONSE']._serialized_start=494 - _globals['_RPCRESPONSE']._serialized_end=678 - _globals['_RPCRESPONSE_METADATAENTRY']._serialized_start=433 - _globals['_RPCRESPONSE_METADATAENTRY']._serialized_end=480 - _globals['_EVENT']._serialized_start=681 - _globals['_EVENT']._serialized_end=909 - _globals['_EVENT_METADATAENTRY']._serialized_start=433 - _globals['_EVENT_METADATAENTRY']._serialized_end=480 - _globals['_REGISTERAGENTTYPEREQUEST']._serialized_start=911 - _globals['_REGISTERAGENTTYPEREQUEST']._serialized_end=971 - _globals['_REGISTERAGENTTYPERESPONSE']._serialized_start=973 - _globals['_REGISTERAGENTTYPERESPONSE']._serialized_end=1067 - _globals['_TYPESUBSCRIPTION']._serialized_start=1069 - _globals['_TYPESUBSCRIPTION']._serialized_end=1127 - _globals['_TYPEPREFIXSUBSCRIPTION']._serialized_start=1129 - _globals['_TYPEPREFIXSUBSCRIPTION']._serialized_end=1200 - _globals['_SUBSCRIPTION']._serialized_start=1203 - _globals['_SUBSCRIPTION']._serialized_end=1353 - _globals['_ADDSUBSCRIPTIONREQUEST']._serialized_start=1355 - _globals['_ADDSUBSCRIPTIONREQUEST']._serialized_end=1443 - _globals['_ADDSUBSCRIPTIONRESPONSE']._serialized_start=1445 - _globals['_ADDSUBSCRIPTIONRESPONSE']._serialized_end=1537 - _globals['_AGENTSTATE']._serialized_start=1540 - _globals['_AGENTSTATE']._serialized_end=1697 - _globals['_GETSTATERESPONSE']._serialized_start=1699 - _globals['_GETSTATERESPONSE']._serialized_end=1805 - _globals['_SAVESTATERESPONSE']._serialized_start=1807 - _globals['_SAVESTATERESPONSE']._serialized_end=1873 - _globals['_MESSAGE']._serialized_start=1876 - _globals['_MESSAGE']._serialized_end=2298 - _globals['_AGENTRPC']._serialized_start=2301 - _globals['_AGENTRPC']._serialized_end=2479 + _globals['_AGENTID']._serialized_start=75 + _globals['_AGENTID']._serialized_end=111 + _globals['_REGISTERAGENTTYPEREQUEST']._serialized_start=113 + _globals['_REGISTERAGENTTYPEREQUEST']._serialized_end=173 + _globals['_REGISTERAGENTTYPERESPONSE']._serialized_start=175 + _globals['_REGISTERAGENTTYPERESPONSE']._serialized_end=269 + _globals['_TYPESUBSCRIPTION']._serialized_start=271 + _globals['_TYPESUBSCRIPTION']._serialized_end=329 + _globals['_TYPEPREFIXSUBSCRIPTION']._serialized_start=331 + _globals['_TYPEPREFIXSUBSCRIPTION']._serialized_end=402 + _globals['_SUBSCRIPTION']._serialized_start=405 + _globals['_SUBSCRIPTION']._serialized_end=555 + _globals['_ADDSUBSCRIPTIONREQUEST']._serialized_start=557 + _globals['_ADDSUBSCRIPTIONREQUEST']._serialized_end=645 + _globals['_ADDSUBSCRIPTIONRESPONSE']._serialized_start=647 + _globals['_ADDSUBSCRIPTIONRESPONSE']._serialized_end=739 + _globals['_AGENTSTATE']._serialized_start=742 + _globals['_AGENTSTATE']._serialized_end=899 + _globals['_GETSTATERESPONSE']._serialized_start=901 + _globals['_GETSTATERESPONSE']._serialized_end=1007 + _globals['_SAVESTATERESPONSE']._serialized_start=1009 + _globals['_SAVESTATERESPONSE']._serialized_end=1075 + _globals['_MESSAGE']._serialized_start=1078 + _globals['_MESSAGE']._serialized_end=1420 + _globals['_AGENTRPC']._serialized_start=1423 + _globals['_AGENTRPC']._serialized_end=1601 # @@protoc_insertion_point(module_scope) diff --git a/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.pyi b/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.pyi index 79e384ab948..7c9baa5e9ca 100644 --- a/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.pyi +++ b/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.pyi @@ -5,33 +5,13 @@ isort:skip_file import builtins import cloudevent_pb2 -import collections.abc import google.protobuf.any_pb2 import google.protobuf.descriptor -import google.protobuf.internal.containers import google.protobuf.message import typing DESCRIPTOR: google.protobuf.descriptor.FileDescriptor -@typing.final -class TopicId(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - TYPE_FIELD_NUMBER: builtins.int - SOURCE_FIELD_NUMBER: builtins.int - type: builtins.str - source: builtins.str - def __init__( - self, - *, - type: builtins.str = ..., - source: builtins.str = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["source", b"source", "type", b"type"]) -> None: ... - -global___TopicId = TopicId - @typing.final class AgentId(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -50,170 +30,6 @@ class AgentId(google.protobuf.message.Message): global___AgentId = AgentId -@typing.final -class Payload(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - DATA_TYPE_FIELD_NUMBER: builtins.int - DATA_CONTENT_TYPE_FIELD_NUMBER: builtins.int - DATA_FIELD_NUMBER: builtins.int - data_type: builtins.str - data_content_type: builtins.str - data: builtins.bytes - def __init__( - self, - *, - data_type: builtins.str = ..., - data_content_type: builtins.str = ..., - data: builtins.bytes = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["data", b"data", "data_content_type", b"data_content_type", "data_type", b"data_type"]) -> None: ... - -global___Payload = Payload - -@typing.final -class RpcRequest(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - @typing.final - class MetadataEntry(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - KEY_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int - key: builtins.str - value: builtins.str - def __init__( - self, - *, - key: builtins.str = ..., - value: builtins.str = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ... - - REQUEST_ID_FIELD_NUMBER: builtins.int - SOURCE_FIELD_NUMBER: builtins.int - TARGET_FIELD_NUMBER: builtins.int - METHOD_FIELD_NUMBER: builtins.int - PAYLOAD_FIELD_NUMBER: builtins.int - METADATA_FIELD_NUMBER: builtins.int - request_id: builtins.str - method: builtins.str - @property - def source(self) -> global___AgentId: ... - @property - def target(self) -> global___AgentId: ... - @property - def payload(self) -> global___Payload: ... - @property - def metadata(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ... - def __init__( - self, - *, - request_id: builtins.str = ..., - source: global___AgentId | None = ..., - target: global___AgentId | None = ..., - method: builtins.str = ..., - payload: global___Payload | None = ..., - metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["_source", b"_source", "payload", b"payload", "source", b"source", "target", b"target"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["_source", b"_source", "metadata", b"metadata", "method", b"method", "payload", b"payload", "request_id", b"request_id", "source", b"source", "target", b"target"]) -> None: ... - def WhichOneof(self, oneof_group: typing.Literal["_source", b"_source"]) -> typing.Literal["source"] | None: ... - -global___RpcRequest = RpcRequest - -@typing.final -class RpcResponse(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - @typing.final - class MetadataEntry(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - KEY_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int - key: builtins.str - value: builtins.str - def __init__( - self, - *, - key: builtins.str = ..., - value: builtins.str = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ... - - REQUEST_ID_FIELD_NUMBER: builtins.int - PAYLOAD_FIELD_NUMBER: builtins.int - ERROR_FIELD_NUMBER: builtins.int - METADATA_FIELD_NUMBER: builtins.int - request_id: builtins.str - error: builtins.str - @property - def payload(self) -> global___Payload: ... - @property - def metadata(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ... - def __init__( - self, - *, - request_id: builtins.str = ..., - payload: global___Payload | None = ..., - error: builtins.str = ..., - metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["payload", b"payload"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["error", b"error", "metadata", b"metadata", "payload", b"payload", "request_id", b"request_id"]) -> None: ... - -global___RpcResponse = RpcResponse - -@typing.final -class Event(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - @typing.final - class MetadataEntry(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - KEY_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int - key: builtins.str - value: builtins.str - def __init__( - self, - *, - key: builtins.str = ..., - value: builtins.str = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ... - - TOPIC_TYPE_FIELD_NUMBER: builtins.int - TOPIC_SOURCE_FIELD_NUMBER: builtins.int - SOURCE_FIELD_NUMBER: builtins.int - PAYLOAD_FIELD_NUMBER: builtins.int - METADATA_FIELD_NUMBER: builtins.int - topic_type: builtins.str - topic_source: builtins.str - @property - def source(self) -> global___AgentId: ... - @property - def payload(self) -> global___Payload: ... - @property - def metadata(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ... - def __init__( - self, - *, - topic_type: builtins.str = ..., - topic_source: builtins.str = ..., - source: global___AgentId | None = ..., - payload: global___Payload | None = ..., - metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["_source", b"_source", "payload", b"payload", "source", b"source"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["_source", b"_source", "metadata", b"metadata", "payload", b"payload", "source", b"source", "topic_source", b"topic_source", "topic_type", b"topic_type"]) -> None: ... - def WhichOneof(self, oneof_group: typing.Literal["_source", b"_source"]) -> typing.Literal["source"] | None: ... - -global___Event = Event - @typing.final class RegisterAgentTypeRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -435,18 +251,12 @@ global___SaveStateResponse = SaveStateResponse class Message(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - REQUEST_FIELD_NUMBER: builtins.int - RESPONSE_FIELD_NUMBER: builtins.int CLOUDEVENT_FIELD_NUMBER: builtins.int REGISTERAGENTTYPEREQUEST_FIELD_NUMBER: builtins.int REGISTERAGENTTYPERESPONSE_FIELD_NUMBER: builtins.int ADDSUBSCRIPTIONREQUEST_FIELD_NUMBER: builtins.int ADDSUBSCRIPTIONRESPONSE_FIELD_NUMBER: builtins.int @property - def request(self) -> global___RpcRequest: ... - @property - def response(self) -> global___RpcResponse: ... - @property def cloudEvent(self) -> cloudevent_pb2.CloudEvent: ... @property def registerAgentTypeRequest(self) -> global___RegisterAgentTypeRequest: ... @@ -459,16 +269,14 @@ class Message(google.protobuf.message.Message): def __init__( self, *, - request: global___RpcRequest | None = ..., - response: global___RpcResponse | None = ..., cloudEvent: cloudevent_pb2.CloudEvent | None = ..., registerAgentTypeRequest: global___RegisterAgentTypeRequest | None = ..., registerAgentTypeResponse: global___RegisterAgentTypeResponse | None = ..., addSubscriptionRequest: global___AddSubscriptionRequest | None = ..., addSubscriptionResponse: global___AddSubscriptionResponse | None = ..., ) -> None: ... - def HasField(self, field_name: typing.Literal["addSubscriptionRequest", b"addSubscriptionRequest", "addSubscriptionResponse", b"addSubscriptionResponse", "cloudEvent", b"cloudEvent", "message", b"message", "registerAgentTypeRequest", b"registerAgentTypeRequest", "registerAgentTypeResponse", b"registerAgentTypeResponse", "request", b"request", "response", b"response"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["addSubscriptionRequest", b"addSubscriptionRequest", "addSubscriptionResponse", b"addSubscriptionResponse", "cloudEvent", b"cloudEvent", "message", b"message", "registerAgentTypeRequest", b"registerAgentTypeRequest", "registerAgentTypeResponse", b"registerAgentTypeResponse", "request", b"request", "response", b"response"]) -> None: ... - def WhichOneof(self, oneof_group: typing.Literal["message", b"message"]) -> typing.Literal["request", "response", "cloudEvent", "registerAgentTypeRequest", "registerAgentTypeResponse", "addSubscriptionRequest", "addSubscriptionResponse"] | None: ... + def HasField(self, field_name: typing.Literal["addSubscriptionRequest", b"addSubscriptionRequest", "addSubscriptionResponse", b"addSubscriptionResponse", "cloudEvent", b"cloudEvent", "message", b"message", "registerAgentTypeRequest", b"registerAgentTypeRequest", "registerAgentTypeResponse", b"registerAgentTypeResponse"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["addSubscriptionRequest", b"addSubscriptionRequest", "addSubscriptionResponse", b"addSubscriptionResponse", "cloudEvent", b"cloudEvent", "message", b"message", "registerAgentTypeRequest", b"registerAgentTypeRequest", "registerAgentTypeResponse", b"registerAgentTypeResponse"]) -> None: ... + def WhichOneof(self, oneof_group: typing.Literal["message", b"message"]) -> typing.Literal["cloudEvent", "registerAgentTypeRequest", "registerAgentTypeResponse", "addSubscriptionRequest", "addSubscriptionResponse"] | None: ... global___Message = Message From 88113d3aa75ddf756a8fb032593a95aae0f9f551 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Sun, 1 Dec 2024 15:02:14 -0500 Subject: [PATCH 6/7] fmt --- .../_single_threaded_agent_runtime.py | 6 +-- .../application/_worker_runtime.py | 1 - .../src/autogen_core/base/_base_agent.py | 11 ++++- .../src/autogen_core/base/_rpc.py | 6 +-- .../autogen_core/components/_closure_agent.py | 12 +++++- .../components/_publish_based_rpc.py | 41 ++++++++++--------- .../autogen_core/components/_routed_agent.py | 8 +++- 7 files changed, 52 insertions(+), 33 deletions(-) diff --git a/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py index 763dada6ab5..06eabddf4f9 100644 --- a/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py @@ -15,9 +15,6 @@ from opentelemetry.trace import TracerProvider from typing_extensions import deprecated -from ..base._serialization import MessageSerializer, SerializationRegistry -from ..components._publish_based_rpc import PublishBasedRpcMixin - from ..base import ( Agent, AgentId, @@ -32,8 +29,9 @@ SubscriptionInstantiationContext, TopicId, ) - +from ..base._serialization import MessageSerializer, SerializationRegistry from ..base.intervention import DropMessage, InterventionHandler +from ..components._publish_based_rpc import PublishBasedRpcMixin from ._helpers import SubscriptionManager, get_impl from .telemetry import EnvelopeMetadata, MessageRuntimeTracingConfig, TraceHelper, get_telemetry_envelope_metadata diff --git a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py index d72d1555486..a9d2a3c3570 100644 --- a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py @@ -28,7 +28,6 @@ cast, ) - from google.protobuf import any_pb2 from opentelemetry.trace import TracerProvider from typing_extensions import Self, deprecated diff --git a/python/packages/autogen-core/src/autogen_core/base/_base_agent.py b/python/packages/autogen-core/src/autogen_core/base/_base_agent.py index f2baa6d1a07..318178913f0 100644 --- a/python/packages/autogen-core/src/autogen_core/base/_base_agent.py +++ b/python/packages/autogen-core/src/autogen_core/base/_base_agent.py @@ -57,6 +57,7 @@ def decorator(cls: Type[BaseAgentType]) -> Type[BaseAgentType]: return decorator + class BaseAgent(ABC, Agent): internal_unbound_subscriptions_list: ClassVar[List[UnboundSubscription]] = [] internal_extra_handles_types: ClassVar[List[Tuple[Type[Any], List[MessageSerializer[Any]]]]] = [] @@ -132,7 +133,10 @@ async def on_message(self, message: Any, ctx: MessageContext) -> None: elif self._forward_unbound_rpc_responses_to_handler: await self.on_message_impl(message, ctx) else: - warnings.warn(f"Received RPC response for unknown request {request_id}. To forward unbound rpc responses to the handler, set forward_unbound_rpc_responses_to_handler=True", stacklevel=2) + warnings.warn( + f"Received RPC response for unknown request {request_id}. To forward unbound rpc responses to the handler, set forward_unbound_rpc_responses_to_handler=True", + stacklevel=2, + ) return None await self.on_message_impl(message, ctx) @@ -148,7 +152,10 @@ async def send_message( if cancellation_token is None: cancellation_token = CancellationToken() - recipient_topic = TopicId(type=format_rpc_request_topic(rpc_recipient_agent_type=recipient.type, rpc_sender_agent_type=self.id.type), source=recipient.key) + recipient_topic = TopicId( + type=format_rpc_request_topic(rpc_recipient_agent_type=recipient.type, rpc_sender_agent_type=self.id.type), + source=recipient.key, + ) request_id = str(uuid.uuid4()) future = Future[Any]() diff --git a/python/packages/autogen-core/src/autogen_core/base/_rpc.py b/python/packages/autogen-core/src/autogen_core/base/_rpc.py index d6b857afd0e..d6554e71844 100644 --- a/python/packages/autogen-core/src/autogen_core/base/_rpc.py +++ b/python/packages/autogen-core/src/autogen_core/base/_rpc.py @@ -3,14 +3,14 @@ from typing import Optional - - def format_rpc_request_topic(rpc_recipient_agent_type: str, rpc_sender_agent_type: str) -> str: return f"{rpc_recipient_agent_type}:rpc_request={rpc_sender_agent_type}" -def format_rpc_response_topic(rpc_sender_agent_type: str,request_id: str) -> str: + +def format_rpc_response_topic(rpc_sender_agent_type: str, request_id: str) -> str: return f"{rpc_sender_agent_type}:rpc_response={request_id}" + # If is an rpc response, return the request id def is_rpc_response(topic_type: str) -> Optional[str]: topic_segments = topic_type.split(":") diff --git a/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py b/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py index d6b100547bd..36566f4ecbc 100644 --- a/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py +++ b/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py @@ -76,7 +76,11 @@ async def publish_message( class ClosureAgent(BaseAgent, ClosureContext): def __init__( - self, description: str, closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]], *, forward_unbound_rpc_responses_to_handler: bool = False + self, + description: str, + closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]], + *, + forward_unbound_rpc_responses_to_handler: bool = False, ) -> None: try: runtime = AgentInstantiationContext.current_runtime() @@ -137,7 +141,11 @@ async def register_closure( subscriptions: Callable[[], list[Subscription] | Awaitable[list[Subscription]]] | None = None, ) -> AgentType: def factory() -> ClosureAgent: - return ClosureAgent(description=description, closure=closure, forward_unbound_rpc_responses_to_handler=forward_unbound_rpc_responses_to_handler) + return ClosureAgent( + description=description, + closure=closure, + forward_unbound_rpc_responses_to_handler=forward_unbound_rpc_responses_to_handler, + ) agent_type = await cls.register( runtime=runtime, diff --git a/python/packages/autogen-core/src/autogen_core/components/_publish_based_rpc.py b/python/packages/autogen-core/src/autogen_core/components/_publish_based_rpc.py index 2ae0e09eeba..a091fe61e78 100644 --- a/python/packages/autogen-core/src/autogen_core/components/_publish_based_rpc.py +++ b/python/packages/autogen-core/src/autogen_core/components/_publish_based_rpc.py @@ -1,42 +1,44 @@ - -from typing import Any +import asyncio import uuid import warnings +from typing import Any from autogen_core.base._rpc import format_rpc_request_topic, format_rpc_response_topic from autogen_core.base._topic import TopicId -from ..base._message_context import MessageContext - from ..base._agent_id import AgentId +from ..base._agent_runtime import AgentRuntime from ..base._cancellation_token import CancellationToken +from ..base._message_context import MessageContext from ._closure_agent import ClosureAgent, ClosureContext -from ..base._agent_runtime import AgentRuntime - -import asyncio class PublishBasedRpcMixin(AgentRuntime): async def send_message( - self: AgentRuntime, - message: Any, - recipient: AgentId, - *, - cancellation_token: CancellationToken | None = None, - ) -> Any: - + self: AgentRuntime, + message: Any, + recipient: AgentId, + *, + cancellation_token: CancellationToken | None = None, + ) -> Any: rpc_request_id = str(uuid.uuid4()) # TODO add "-" to topic and agent type allowed characters in spec closure_agent_type = f"rpc_receiver_{recipient.type}_{rpc_request_id}" future: asyncio.Future[Any] = asyncio.Future() - expected_response_topic_type = format_rpc_response_topic(rpc_sender_agent_type=closure_agent_type, request_id=rpc_request_id) - async def set_result(closure_context:ClosureContext, message: Any, ctx: MessageContext) -> None: + expected_response_topic_type = format_rpc_response_topic( + rpc_sender_agent_type=closure_agent_type, request_id=rpc_request_id + ) + + async def set_result(closure_context: ClosureContext, message: Any, ctx: MessageContext) -> None: assert ctx.topic_id is not None if ctx.topic_id.type == expected_response_topic_type: future.set_result(message) else: - warnings.warn(f"{closure_agent_type} received an unexpected message on topic type {ctx.topic_id.type}. Expected {expected_response_topic_type}", stacklevel=2) + warnings.warn( + f"{closure_agent_type} received an unexpected message on topic type {ctx.topic_id.type}. Expected {expected_response_topic_type}", + stacklevel=2, + ) # TODO: remove agent after response is received @@ -47,7 +49,9 @@ async def set_result(closure_context:ClosureContext, message: Any, ctx: MessageC forward_unbound_rpc_responses_to_handler=True, ) - rpc_request_topic_id = format_rpc_request_topic(rpc_recipient_agent_type=recipient.type, rpc_sender_agent_type=closure_agent_type) + rpc_request_topic_id = format_rpc_request_topic( + rpc_recipient_agent_type=recipient.type, rpc_sender_agent_type=closure_agent_type + ) await self.publish_message( message=message, topic_id=TopicId(type=rpc_request_topic_id, source=recipient.key), @@ -58,4 +62,3 @@ async def set_result(closure_context:ClosureContext, message: Any, ctx: MessageC return await future # register a closure agent... - diff --git a/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py b/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py index 50ef43aa67f..c8e6da3b2c9 100644 --- a/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py +++ b/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py @@ -301,7 +301,9 @@ async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> None wrapper_handler.produces_types = list(return_types) wrapper_handler.is_message_handler = True # Wrap the match function with a check on the topic for rpc - wrapper_handler.router = lambda _message, _ctx: (is_rpc_request(_ctx.topic_id.type) is None) and (match(_message, _ctx) if match else True) + wrapper_handler.router = lambda _message, _ctx: (is_rpc_request(_ctx.topic_id.type) is None) and ( + match(_message, _ctx) if match else True + ) return wrapper_handler @@ -440,7 +442,9 @@ async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> None wrapper_handler.target_types = list(target_types) wrapper_handler.produces_types = list(return_types) wrapper_handler.is_message_handler = True - wrapper_handler.router = lambda _message, _ctx: (is_rpc_request(_ctx.topic_id.type) is not None) and (match(_message, _ctx) if match else True) + wrapper_handler.router = lambda _message, _ctx: (is_rpc_request(_ctx.topic_id.type) is not None) and ( + match(_message, _ctx) if match else True + ) return wrapper_handler From 11dea884fbf9d18f466d5bbce91043ad4c3e7440 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Sun, 1 Dec 2024 15:21:32 -0500 Subject: [PATCH 7/7] lint, type, fmt fixes --- .../_group_chat/_sequential_routed_agent.py | 4 +- .../cookbook/local-llms-ollama-litellm.ipynb | 3 +- .../_worker_runtime_host_servicer.py | 49 ------------------- .../src/autogen_core/base/_agent_proxy.py | 2 - .../src/autogen_core/base/_base_agent.py | 2 +- .../autogen_core/components/_routed_agent.py | 6 +-- .../autogen-core/tests/test_routed_agent.py | 5 +- .../headless_web_surfer/test_web_surfer.py | 5 -- 8 files changed, 10 insertions(+), 66 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_sequential_routed_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_sequential_routed_agent.py index fe80c9e9392..6b92b21e883 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_sequential_routed_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_sequential_routed_agent.py @@ -43,9 +43,9 @@ def __init__(self, description: str) -> None: super().__init__(description=description) self._fifo_lock = FIFOLock() - async def on_message(self, message: Any, ctx: MessageContext) -> Any | None: + async def on_message_impl(self, message: Any, ctx: MessageContext) -> None: await self._fifo_lock.acquire() try: - return await super().on_message(message, ctx) + await super().on_message_impl(message, ctx) finally: self._fifo_lock.release() diff --git a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/local-llms-ollama-litellm.ipynb b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/local-llms-ollama-litellm.ipynb index 80fde2b7101..e9274ae0a52 100644 --- a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/local-llms-ollama-litellm.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/local-llms-ollama-litellm.ipynb @@ -174,7 +174,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -211,7 +211,6 @@ "await runtime.send_message(\n", " Message(\"Joe, tell me a joke.\"),\n", " recipient=AgentId(joe, \"default\"),\n", - " sender=AgentId(cathy, \"default\"),\n", ")\n", "await runtime.stop_when_idle()" ] diff --git a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host_servicer.py b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host_servicer.py index e24a7db3f30..4dfd52b9949 100644 --- a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host_servicer.py +++ b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host_servicer.py @@ -102,18 +102,6 @@ async def _receive_messages( logger.info(f"Received message from client {client_id}: {message}") oneofcase = message.WhichOneof("message") match oneofcase: - case "request": - request: agent_worker_pb2.RpcRequest = message.request - task = asyncio.create_task(self._process_request(request, client_id)) - self._background_tasks.add(task) - task.add_done_callback(self._raise_on_exception) - task.add_done_callback(self._background_tasks.discard) - case "response": - response: agent_worker_pb2.RpcResponse = message.response - task = asyncio.create_task(self._process_response(response, client_id)) - self._background_tasks.add(task) - task.add_done_callback(self._raise_on_exception) - task.add_done_callback(self._background_tasks.discard) case "cloudEvent": # The proto typing doesnt resolve this one event = cast(cloudevent_pb2.CloudEvent, message.cloudEvent) # type: ignore @@ -140,43 +128,6 @@ async def _receive_messages( case None: logger.warning("Received empty message") - async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: int) -> None: - # Deliver the message to a client given the target agent type. - async with self._agent_type_to_client_id_lock: - target_client_id = self._agent_type_to_client_id.get(request.target.type) - if target_client_id is None: - logger.error(f"Agent {request.target.type} not found, failed to deliver message.") - return - target_send_queue = self._send_queues.get(target_client_id) - if target_send_queue is None: - logger.error(f"Client {target_client_id} not found, failed to deliver message.") - return - await target_send_queue.put(agent_worker_pb2.Message(request=request)) - - # Create a future to wait for the response from the target. - future = asyncio.get_event_loop().create_future() - self._pending_responses.setdefault(target_client_id, {})[request.request_id] = future - - # Create a task to wait for the response and send it back to the client. - send_response_task = asyncio.create_task(self._wait_and_send_response(future, client_id)) - self._background_tasks.add(send_response_task) - send_response_task.add_done_callback(self._raise_on_exception) - send_response_task.add_done_callback(self._background_tasks.discard) - - async def _wait_and_send_response(self, future: Future[agent_worker_pb2.RpcResponse], client_id: int) -> None: - response = await future - message = agent_worker_pb2.Message(response=response) - send_queue = self._send_queues.get(client_id) - if send_queue is None: - logger.error(f"Client {client_id} not found, failed to send response message.") - return - await send_queue.put(message) - - async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: int) -> None: - # Setting the result of the future will send the response back to the original sender. - future = self._pending_responses[client_id].pop(response.request_id) - future.set_result(response) - async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None: topic_id = TopicId(type=event.type, source=event.source) recipients = await self._subscription_manager.get_subscribed_recipients(topic_id) diff --git a/python/packages/autogen-core/src/autogen_core/base/_agent_proxy.py b/python/packages/autogen-core/src/autogen_core/base/_agent_proxy.py index f3eb70f2827..5b8ddc314a0 100644 --- a/python/packages/autogen-core/src/autogen_core/base/_agent_proxy.py +++ b/python/packages/autogen-core/src/autogen_core/base/_agent_proxy.py @@ -29,13 +29,11 @@ async def send_message( self, message: Any, *, - sender: AgentId, cancellation_token: CancellationToken | None = None, ) -> Any: return await self._runtime.send_message( message, recipient=self._agent, - sender=sender, cancellation_token=cancellation_token, ) diff --git a/python/packages/autogen-core/src/autogen_core/base/_base_agent.py b/python/packages/autogen-core/src/autogen_core/base/_base_agent.py index 318178913f0..0df1cc25caa 100644 --- a/python/packages/autogen-core/src/autogen_core/base/_base_agent.py +++ b/python/packages/autogen-core/src/autogen_core/base/_base_agent.py @@ -126,7 +126,7 @@ async def on_message_impl(self, message: Any, ctx: MessageContext) -> None: ... @final async def on_message(self, message: Any, ctx: MessageContext) -> None: # Intercept RPC responses - if ctx.topic_id is not None and (request_id := is_rpc_response(ctx.topic_id.type)) is not None: + if (request_id := is_rpc_response(ctx.topic_id.type)) is not None: if request_id in self._pending_rpc_requests: self._pending_rpc_requests[request_id].set_result(message) del self._pending_rpc_requests[request_id] diff --git a/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py b/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py index c8e6da3b2c9..b7a169f184c 100644 --- a/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py +++ b/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py @@ -515,7 +515,7 @@ def __init__(self, description: str) -> None: super().__init__(description) - async def on_message_impl(self, message: Any, ctx: MessageContext): + async def on_message_impl(self, message: Any, ctx: MessageContext) -> None: """Handle a message by routing it to the appropriate message handler. Do not override this method in subclasses. Instead, add message handlers as methods decorated with either the :func:`event` or :func:`rpc` decorator.""" @@ -526,8 +526,8 @@ async def on_message_impl(self, message: Any, ctx: MessageContext): # Call the first handler whose router returns True and then return the result. for h in handlers: if h.router(message, ctx): - return await h(self, message, ctx) - return await self.on_unhandled_message(message, ctx) # type: ignore + await h(self, message, ctx) + await self.on_unhandled_message(message, ctx) async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None: """Called when a message is received that does not have a matching message handler. diff --git a/python/packages/autogen-core/tests/test_routed_agent.py b/python/packages/autogen-core/tests/test_routed_agent.py index cab1b1d467f..55de2432f4f 100644 --- a/python/packages/autogen-core/tests/test_routed_agent.py +++ b/python/packages/autogen-core/tests/test_routed_agent.py @@ -5,6 +5,7 @@ import pytest from autogen_core.application import SingleThreadedAgentRuntime from autogen_core.base import AgentId, MessageContext, TopicId +from autogen_core.base._rpc import is_rpc_request from autogen_core.components import RoutedAgent, TypeSubscription, event, message_handler, rpc from test_utils import LoopbackAgent @@ -23,12 +24,12 @@ def __init__(self) -> None: self.num_calls_rpc = 0 self.num_calls_broadcast = 0 - @message_handler(match=lambda _, ctx: ctx.is_rpc) + @message_handler(match=lambda _, ctx: is_rpc_request(ctx.topic_id.type) is not None) async def on_rpc_message(self, message: MessageType, ctx: MessageContext) -> MessageType: self.num_calls_rpc += 1 return message - @message_handler(match=lambda _, ctx: not ctx.is_rpc) + @message_handler(match=lambda _, ctx: is_rpc_request(ctx.topic_id.type) is None) async def on_broadcast_message(self, message: MessageType, ctx: MessageContext) -> None: self.num_calls_broadcast += 1 diff --git a/python/packages/autogen-magentic-one/tests/headless_web_surfer/test_web_surfer.py b/python/packages/autogen-magentic-one/tests/headless_web_surfer/test_web_surfer.py index 769ac5080e8..6106a9f219a 100644 --- a/python/packages/autogen-magentic-one/tests/headless_web_surfer/test_web_surfer.py +++ b/python/packages/autogen-magentic-one/tests/headless_web_surfer/test_web_surfer.py @@ -218,27 +218,22 @@ async def test_web_surfer_oai() -> None: ) ), recipient=web_surfer.id, - sender=user_proxy.id, ) await runtime.send_message( BroadcastMessage(content=UserMessage(content="Please scroll down.", source="user")), recipient=web_surfer.id, - sender=user_proxy.id, ) await runtime.send_message( BroadcastMessage(content=UserMessage(content="Please scroll up.", source="user")), recipient=web_surfer.id, - sender=user_proxy.id, ) await runtime.send_message( BroadcastMessage(content=UserMessage(content="When was it founded?", source="user")), recipient=web_surfer.id, - sender=user_proxy.id, ) await runtime.send_message( BroadcastMessage(content=UserMessage(content="What's this page about?", source="user")), recipient=web_surfer.id, - sender=user_proxy.id, ) await runtime.stop_when_idle()