Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RPC over events #4414

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/design/02 - Topics.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,6 @@ For this subscription source should map directly to agent key.
This subscription will therefore receive all events for the following well known topics:

- `{AgentType}:` - General purpose direct messages. These should be routed to the approriate message handler.
- `{AgentType}:rpc_request` - RPC request messages. These should be routed to the approriate RPC handler.
- `{AgentType}:rpc_request={RequesterAgentType}` - RPC request messages. These should be routed to the approriate RPC handler.
- `{AgentType}:rpc_response={RequestId}` - RPC response messages. These should be routed back to the response future of the caller.
- `{AgentType}:error={RequestId}` - Error message that corresponds to the given request.
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
" def __init__(self) -> None:\n",
" super().__init__(\"MyAgent\")\n",
"\n",
" async def on_message(self, message: MyMessageType, ctx: MessageContext) -> None:\n",
" async def on_message_impl(self, message: MyMessageType, ctx: MessageContext) -> None:\n",
" print(f\"Received message: {message.content}\") # type: ignore"
]
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing_extensions import deprecated

from autogen_core.base._serialization import MessageSerializer, SerializationRegistry
from autogen_core.components.send_message_mixin import PublishBasedRpcMixin

from ..base import (
Agent,
Expand Down Expand Up @@ -166,7 +167,7 @@ def _warn_if_none(value: Any, handler_name: str) -> None:
)


class SingleThreadedAgentRuntime(AgentRuntime):
class SingleThreadedAgentRuntime(PublishBasedRpcMixin, AgentRuntime):
def __init__(
self,
*,
Expand Down Expand Up @@ -202,54 +203,54 @@ def _known_agent_names(self) -> Set[str]:
return set(self._agent_factories.keys())

# Returns the response of the message
async def send_message(
self,
message: Any,
recipient: AgentId,
*,
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
) -> Any:
if cancellation_token is None:
cancellation_token = CancellationToken()

# event_logger.info(
# MessageEvent(
# payload=message,
# sender=sender,
# receiver=recipient,
# kind=MessageKind.DIRECT,
# delivery_stage=DeliveryStage.SEND,
# )
# )

with self._tracer_helper.trace_block(
"create",
recipient,
parent=None,
extraAttributes={"message_type": type(message).__name__},
):
future = asyncio.get_event_loop().create_future()
if recipient.type not in self._known_agent_names:
future.set_exception(Exception("Recipient not found"))

content = message.__dict__ if hasattr(message, "__dict__") else message
logger.info(f"Sending message of type {type(message).__name__} to {recipient.type}: {content}")

self._message_queue.append(
SendMessageEnvelope(
message=message,
recipient=recipient,
future=future,
cancellation_token=cancellation_token,
sender=sender,
metadata=get_telemetry_envelope_metadata(),
)
)

cancellation_token.link_future(future)

return await future
# async def send_message(
# self,
# message: Any,
# recipient: AgentId,
# *,
# sender: AgentId | None = None,
# cancellation_token: CancellationToken | None = None,
# ) -> Any:
# if cancellation_token is None:
# cancellation_token = CancellationToken()

# # event_logger.info(
# # MessageEvent(
# # payload=message,
# # sender=sender,
# # receiver=recipient,
# # kind=MessageKind.DIRECT,
# # delivery_stage=DeliveryStage.SEND,
# # )
# # )

# with self._tracer_helper.trace_block(
# "create",
# recipient,
# parent=None,
# extraAttributes={"message_type": type(message).__name__},
# ):
# future = asyncio.get_event_loop().create_future()
# if recipient.type not in self._known_agent_names:
# future.set_exception(Exception("Recipient not found"))

# content = message.__dict__ if hasattr(message, "__dict__") else message
# logger.info(f"Sending message of type {type(message).__name__} to {recipient.type}: {content}")

# self._message_queue.append(
# SendMessageEnvelope(
# message=message,
# recipient=recipient,
# future=future,
# cancellation_token=cancellation_token,
# sender=sender,
# metadata=get_telemetry_envelope_metadata(),
# )
# )

# cancellation_token.link_future(future)

# return await future

async def publish_message(
self,
Expand Down Expand Up @@ -332,7 +333,6 @@ async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
message_context = MessageContext(
sender=message_envelope.sender,
topic_id=None,
is_rpc=True,
cancellation_token=message_envelope.cancellation_token,
# Will be fixed when send API removed
message_id="NOT_DEFINED_TODO_FIX",
Expand Down Expand Up @@ -392,7 +392,6 @@ async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> No
message_context = MessageContext(
sender=message_envelope.sender,
topic_id=message_envelope.topic_id,
is_rpc=False,
cancellation_token=message_envelope.cancellation_token,
message_id=message_envelope.message_id,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,6 @@ async def _process_request(self, request: agent_worker_pb2.RpcRequest) -> None:
message_context = MessageContext(
sender=sender,
topic_id=None,
is_rpc=True,
cancellation_token=CancellationToken(),
message_id=request.request_id,
)
Expand Down
5 changes: 1 addition & 4 deletions python/packages/autogen-core/src/autogen_core/base/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,13 @@ def id(self) -> AgentId:
"""ID of the agent."""
...

async def on_message(self, message: Any, ctx: MessageContext) -> Any:
async def on_message(self, message: Any, ctx: MessageContext) -> None:
"""Message handler for the agent. This should only be called by the runtime, not by other agents.

Args:
message (Any): Received message. Type is one of the types in `subscriptions`.
ctx (MessageContext): Context of the message.

Returns:
Any: Response to the message. Can be None.

Raises:
asyncio.CancelledError: If the message was cancelled.
CantHandleException: If the agent cannot handle the message.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ def __init__(self, type: str | AgentType, key: str) -> None:
if isinstance(type, AgentType):
type = type.type

if type.isidentifier() is False:
raise ValueError(f"Invalid type: {type}")
# TODO: fixme
# if type.isidentifier() is False:
# raise ValueError(f"Invalid type: {type}")

self._type = type
self._key = key
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ async def send_message(
message: Any,
recipient: AgentId,
*,
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
) -> Any:
"""Send a message to an agent and get a response.
Expand Down
51 changes: 45 additions & 6 deletions python/packages/autogen-core/src/autogen_core/base/_base_agent.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from __future__ import annotations

import inspect
import uuid
import warnings
from abc import ABC, abstractmethod
from asyncio import Future
from collections.abc import Sequence
from typing import Any, Awaitable, Callable, ClassVar, List, Mapping, Tuple, Type, TypeVar
from typing import Any, Awaitable, Callable, ClassVar, Dict, List, Mapping, Tuple, Type, TypeVar, final

from typing_extensions import Self

from autogen_core.base._rpc import format_rpc_request_topic, is_rpc_response

from ._agent import Agent
from ._agent_id import AgentId
from ._agent_instantiation import AgentInstantiationContext
Expand Down Expand Up @@ -53,7 +57,6 @@ def decorator(cls: Type[BaseAgentType]) -> Type[BaseAgentType]:

return decorator


class BaseAgent(ABC, Agent):
internal_unbound_subscriptions_list: ClassVar[List[UnboundSubscription]] = []
internal_extra_handles_types: ClassVar[List[Tuple[Type[Any], List[MessageSerializer[Any]]]]] = []
Expand All @@ -77,7 +80,17 @@ def metadata(self) -> AgentMetadata:
assert self._id is not None
return AgentMetadata(key=self._id.key, type=self._id.type, description=self._description)

def __init__(self, description: str) -> None:
def __init__(self, description: str, *, forward_unbound_rpc_responses_to_handler: bool = False) -> None:
"""Base agent that all agents should inherit from. Puts in place assumed common functionality.

Args:
description (str): Description of the agent.
forward_unbound_rpc_responses_to_handler (bool, optional): If an rpc request ID is not know to the agent, should the rpc request be forwarded to the handler. Defaults to False.

Raises:
RuntimeError: If the agent is not instantiated within the context of an AgentRuntime.
ValueError: If there is an argument type error.
"""
try:
runtime = AgentInstantiationContext.current_runtime()
id = AgentInstantiationContext.current_agent_id()
Expand All @@ -91,6 +104,8 @@ def __init__(self, description: str) -> None:
if not isinstance(description, str):
raise ValueError("Agent description must be a string")
self._description = description
self._pending_rpc_requests: Dict[str, Future[Any]] = {}
self._forward_unbound_rpc_responses_to_handler = forward_unbound_rpc_responses_to_handler

@property
def type(self) -> str:
Expand All @@ -105,7 +120,21 @@ def runtime(self) -> AgentRuntime:
return self._runtime

@abstractmethod
async def on_message(self, message: Any, ctx: MessageContext) -> Any: ...
async def on_message_impl(self, message: Any, ctx: MessageContext) -> None: ...

@final
async def on_message(self, message: Any, ctx: MessageContext) -> None:
# Intercept RPC responses
if ctx.topic_id is not None and (request_id := is_rpc_response(ctx.topic_id.type)) is not None:
if request_id in self._pending_rpc_requests:
self._pending_rpc_requests[request_id].set_result(message)
elif self._forward_unbound_rpc_responses_to_handler:
await self.on_message_impl(message, ctx)
else:
warnings.warn(f"Received RPC response for unknown request {request_id}. To forward unbound rpc responses to the handler, set forward_unbound_rpc_responses_to_handler=True", stacklevel=2)
return None

return await self.on_message_impl(message, ctx)

async def send_message(
self,
Expand All @@ -118,13 +147,23 @@ async def send_message(
if cancellation_token is None:
cancellation_token = CancellationToken()

return await self._runtime.send_message(
recipient_topic = TopicId(type=format_rpc_request_topic(rpc_recipient_agent_type=recipient.type, rpc_sender_agent_type=self.id.type), source=recipient.key)
request_id = str(uuid.uuid4())

future = Future[Any]()

await self._runtime.publish_message(
message,
sender=self.id,
recipient=recipient,
topic_id=recipient_topic,
cancellation_token=cancellation_token,
message_id=request_id,
)

self._pending_rpc_requests[request_id] = future

return future

async def publish_message(
self,
message: Any,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
@dataclass
class MessageContext:
sender: AgentId | None
topic_id: TopicId | None
is_rpc: bool
topic_id: TopicId
cancellation_token: CancellationToken
message_id: str
31 changes: 31 additions & 0 deletions python/packages/autogen-core/src/autogen_core/base/_rpc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from __future__ import annotations

from typing import Optional




def format_rpc_request_topic(rpc_recipient_agent_type: str, rpc_sender_agent_type: str) -> str:
return f"{rpc_recipient_agent_type}:rpc_request={rpc_sender_agent_type}"

def format_rpc_response_topic(rpc_sender_agent_type: str,request_id: str) -> str:
return f"{rpc_sender_agent_type}:rpc_response={request_id}"

# If is an rpc response, return the request id
def is_rpc_response(topic_type: str) -> Optional[str]:
topic_segments = topic_type.split(":")
# Find if there is a segment starting with :rpc_response=
for segment in topic_segments:
if segment.startswith("rpc_response="):
return segment[len("rpc_response=") :]
return None


# If is an rpc response, return the requestor agent type
def is_rpc_request(topic_type: str) -> Optional[str]:
topic_segments = topic_type.split(":")
# Find if there is a segment starting with :rpc_request=
for segment in topic_segments:
if segment.startswith("rpc_request="):
return segment[len("rpc_request=") :]
return None
Loading
Loading