diff --git a/examples/futures.py b/examples/futures.py index c60d9d74c66..e82d881691f 100644 --- a/examples/futures.py +++ b/examples/futures.py @@ -19,10 +19,7 @@ def __init__(self, name: str, router: AgentRuntime) -> None: super().__init__(name, router) @message_handler(MessageType) - async def on_new_message( - self, message: MessageType, require_response: bool, cancellation_token: CancellationToken - ) -> MessageType: - assert require_response + async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: return MessageType(body=f"Inner: {message.body}", sender=self.name) @@ -32,11 +29,8 @@ def __init__(self, name: str, router: AgentRuntime, inner: Agent) -> None: self._inner = inner @message_handler(MessageType) - async def on_new_message( - self, message: MessageType, require_response: bool, cancellation_token: CancellationToken - ) -> MessageType: - assert require_response - inner_response = self._send_message(message, self._inner, require_response=True) + async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: + inner_response = self._send_message(message, self._inner) inner_message = await inner_response assert isinstance(inner_message, MessageType) return MessageType(body=f"Outer: {inner_message.body}", sender=self.name) diff --git a/pyproject.toml b/pyproject.toml index 9417bb00845..b9a20b913de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ select = ["E", "F", "W", "B", "Q", "I"] ignore = ["F401", "E501"] [tool.mypy] -files = ["src", "examples"] +files = ["src", "examples", "tests"] strict = true python_version = "3.10" @@ -53,7 +53,7 @@ disallow_untyped_decorators = true disallow_any_unimported = true [tool.pyright] -include = ["src", "examples"] +include = ["src", "examples", "tests"] typeCheckingMode = "strict" reportUnnecessaryIsInstance = false reportMissingTypeStubs = false diff --git a/src/agnext/agent_components/type_routed_agent.py b/src/agnext/agent_components/type_routed_agent.py index af6042d87b2..d6fa1030996 100644 --- a/src/agnext/agent_components/type_routed_agent.py +++ b/src/agnext/agent_components/type_routed_agent.py @@ -16,12 +16,12 @@ def message_handler( *target_types: Type[ReceivesT], ) -> Callable[ - [Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]]], - Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]], + [Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]]], + Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]], ]: def decorator( - func: Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]], - ) -> Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]]: + func: Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]], + ) -> Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]]: # Convert target_types to list and stash func._target_types = list(target_types) # type: ignore return func @@ -34,7 +34,7 @@ def __init__(self, name: str, router: AgentRuntime) -> None: super().__init__(name, router) # Self is already bound to the handlers - self._handlers: Dict[Type[Any], Callable[[Any, bool, CancellationToken], Coroutine[Any, Any, Any | None]]] = {} + self._handlers: Dict[Type[Any], Callable[[Any, CancellationToken], Coroutine[Any, Any, Any | None]]] = {} router.add_agent(self) @@ -49,17 +49,13 @@ def __init__(self, name: str, router: AgentRuntime) -> None: def subscriptions(self) -> Sequence[Type[Any]]: return list(self._handlers.keys()) - async def on_message( - self, message: Any, require_response: bool, cancellation_token: CancellationToken - ) -> Any | None: + async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None: key_type: Type[Any] = type(message) # type: ignore handler = self._handlers.get(key_type) # type: ignore if handler is not None: - return await handler(message, require_response, cancellation_token) + return await handler(message, cancellation_token) else: - return await self.on_unhandled_message(message, require_response, cancellation_token) + return await self.on_unhandled_message(message, cancellation_token) - async def on_unhandled_message( - self, message: Any, require_response: bool, cancellation_token: CancellationToken - ) -> NoReturn: + async def on_unhandled_message(self, message: Any, cancellation_token: CancellationToken) -> NoReturn: raise CantHandleException(f"Unhandled message: {message}") diff --git a/src/agnext/application_components/single_threaded_agent_runtime.py b/src/agnext/application_components/single_threaded_agent_runtime.py index b21a382b87b..9d9320c1b2f 100644 --- a/src/agnext/application_components/single_threaded_agent_runtime.py +++ b/src/agnext/application_components/single_threaded_agent_runtime.py @@ -1,7 +1,7 @@ import asyncio from asyncio import Future from dataclasses import dataclass -from typing import Any, Awaitable, Dict, List, Sequence, Set, cast +from typing import Any, Awaitable, Dict, List, Set from agnext.core.cancellation_token import CancellationToken from agnext.core.exceptions import MessageDroppedException @@ -12,15 +12,13 @@ @dataclass(kw_only=True) -class BroadcastMessageEnvelope: - """A message envelope for broadcasting messages to all agents that can handle +class PublishMessageEnvelope: + """A message envelope for publishing messages to all agents that can handle the message of the type T.""" message: Any - future: Future[Sequence[Any] | None] cancellation_token: CancellationToken sender: Agent | None - require_response: bool @dataclass(kw_only=True) @@ -31,9 +29,8 @@ class SendMessageEnvelope: message: Any sender: Agent | None recipient: Agent - future: Future[Any | None] + future: Future[Any] cancellation_token: CancellationToken - require_response: bool @dataclass(kw_only=True) @@ -46,20 +43,9 @@ class ResponseMessageEnvelope: recipient: Agent | None -@dataclass(kw_only=True) -class BroadcastResponseMessageEnvelope: - """A message envelope for sending a response to a message.""" - - message: Sequence[Any] - future: Future[Sequence[Any]] - recipient: Agent | None - - class SingleThreadedAgentRuntime(AgentRuntime): def __init__(self, *, before_send: InterventionHandler | None = None) -> None: - self._message_queue: List[ - BroadcastMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope | BroadcastResponseMessageEnvelope - ] = [] + self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = [] self._per_type_subscribers: Dict[type, List[Agent]] = {} self._agents: Set[Agent] = set() self._before_send = before_send @@ -77,7 +63,6 @@ def send_message( message: Any, recipient: Agent, *, - require_response: bool = True, sender: Agent | None = None, cancellation_token: CancellationToken | None = None, ) -> Future[Any | None]: @@ -95,36 +80,31 @@ def send_message( future=future, cancellation_token=cancellation_token, sender=sender, - require_response=require_response, ) ) return future - # send message, require_response=False -> returns after delivery, gives None - # send message, require_response=True -> returns after handling, gives Response - def broadcast_message( + def publish_message( self, message: Any, *, - require_response: bool = True, sender: Agent | None = None, cancellation_token: CancellationToken | None = None, - ) -> Future[Sequence[Any] | None]: + ) -> Future[None]: if cancellation_token is None: cancellation_token = CancellationToken() - future = asyncio.get_event_loop().create_future() self._message_queue.append( - BroadcastMessageEnvelope( + PublishMessageEnvelope( message=message, - future=future, cancellation_token=cancellation_token, sender=sender, - require_response=require_response, ) ) + future = asyncio.get_event_loop().create_future() + future.set_result(None) return future async def _process_send(self, message_envelope: SendMessageEnvelope) -> None: @@ -134,64 +114,41 @@ async def _process_send(self, message_envelope: SendMessageEnvelope) -> None: try: response = await recipient.on_message( message_envelope.message, - require_response=message_envelope.require_response, cancellation_token=message_envelope.cancellation_token, ) except BaseException as e: message_envelope.future.set_exception(e) return - if not message_envelope.require_response and response is not None: - raise Exception("Recipient returned a response for a message that did not request a response") - - if message_envelope.require_response and response is None: - raise Exception("Recipient did not return a response for a message that requested a response") - - if message_envelope.require_response: - self._message_queue.append( - ResponseMessageEnvelope( - message=response, - future=message_envelope.future, - sender=message_envelope.recipient, - recipient=message_envelope.sender, - ) + self._message_queue.append( + ResponseMessageEnvelope( + message=response, + future=message_envelope.future, + sender=message_envelope.recipient, + recipient=message_envelope.sender, ) - else: - message_envelope.future.set_result(None) + ) - async def _process_broadcast(self, message_envelope: BroadcastMessageEnvelope) -> None: + async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None: responses: List[Awaitable[Any]] = [] for agent in self._per_type_subscribers.get(type(message_envelope.message), []): # type: ignore future = agent.on_message( message_envelope.message, - require_response=message_envelope.require_response, cancellation_token=message_envelope.cancellation_token, ) responses.append(future) try: - all_responses = await asyncio.gather(*responses) - except BaseException as e: - message_envelope.future.set_exception(e) + _all_responses = await asyncio.gather(*responses) + except BaseException: + # TODO log error return - if message_envelope.require_response: - self._message_queue.append( - BroadcastResponseMessageEnvelope( - message=all_responses, - future=cast(Future[Sequence[Any]], message_envelope.future), - recipient=message_envelope.sender, - ) - ) - else: - message_envelope.future.set_result(None) + # TODO if responses are given for a publish async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None: message_envelope.future.set_result(message_envelope.message) - async def _process_broadcast_response(self, message_envelope: BroadcastResponseMessageEnvelope) -> None: - message_envelope.future.set_result(message_envelope.message) - async def process_next(self) -> None: if len(self._message_queue) == 0: # Yield control to the event loop to allow other tasks to run @@ -211,20 +168,19 @@ async def process_next(self) -> None: message_envelope.message = temp_message asyncio.create_task(self._process_send(message_envelope)) - case BroadcastMessageEnvelope( + case PublishMessageEnvelope( message=message, sender=sender, - future=future, ): if self._before_send is not None: - temp_message = await self._before_send.on_broadcast(message, sender=sender) + temp_message = await self._before_send.on_publish(message, sender=sender) if temp_message is DropMessage or isinstance(temp_message, DropMessage): - future.set_exception(MessageDroppedException()) + # TODO log message dropped return message_envelope.message = temp_message - asyncio.create_task(self._process_broadcast(message_envelope)) + asyncio.create_task(self._process_publish(message_envelope)) case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future): if self._before_send is not None: temp_message = await self._before_send.on_response(message, sender=sender, recipient=recipient) @@ -236,16 +192,5 @@ async def process_next(self) -> None: asyncio.create_task(self._process_response(message_envelope)) - case BroadcastResponseMessageEnvelope(message=message, recipient=recipient, future=future): - if self._before_send is not None: - temp_message_list = await self._before_send.on_broadcast_response(message, recipient=recipient) - if temp_message_list is DropMessage or isinstance(temp_message_list, DropMessage): - future.set_exception(MessageDroppedException()) - return - - message_envelope.message = list(temp_message_list) # type: ignore - - asyncio.create_task(self._process_broadcast_response(message_envelope)) - # Yield control to the message loop to allow other tasks to run await asyncio.sleep(0) diff --git a/src/agnext/chat/agents/oai_assistant.py b/src/agnext/chat/agents/oai_assistant.py index 8ee8980313d..27c6d01d3d2 100644 --- a/src/agnext/chat/agents/oai_assistant.py +++ b/src/agnext/chat/agents/oai_assistant.py @@ -26,7 +26,7 @@ def __init__( # TODO: use require_response @message_handler(TextMessage) async def on_chat_message_with_cancellation( - self, message: TextMessage, require_response: bool, cancellation_token: CancellationToken + self, message: TextMessage, cancellation_token: CancellationToken ) -> None: print("---------------") print(f"{self.name} received message from {message.source}: {message.content}") @@ -41,22 +41,13 @@ async def on_chat_message_with_cancellation( ) self._current_session_window_length += 1 - if require_response: - # TODO ? - ... - @message_handler(Reset) - async def on_reset(self, message: Reset, require_response: bool, cancellation_token: CancellationToken) -> None: + async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None: # Reset the current session window. self._current_session_window_length = 0 @message_handler(RespondNow) - async def on_respond_now( - self, message: RespondNow, require_response: bool, cancellation_token: CancellationToken - ) -> TextMessage | None: - if not require_response: - return None - + async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage: # Create a run and wait until it finishes. run = await self._client.beta.threads.runs.create_and_poll( thread_id=self._thread_id, diff --git a/src/agnext/chat/agents/random_agent.py b/src/agnext/chat/agents/random_agent.py index f96a2d949d1..1df2e17db6e 100644 --- a/src/agnext/chat/agents/random_agent.py +++ b/src/agnext/chat/agents/random_agent.py @@ -11,7 +11,7 @@ class RandomResponseAgent(BaseChatAgent, TypeRoutedAgent): # TODO: use require_response @message_handler(RespondNow) async def on_chat_message_with_cancellation( - self, message: RespondNow, require_response: bool, cancellation_token: CancellationToken + self, message: RespondNow, cancellation_token: CancellationToken ) -> TextMessage: # Generate a random response. response_body = random.choice( diff --git a/src/agnext/chat/patterns/group_chat.py b/src/agnext/chat/patterns/group_chat.py index 0120522d9b7..08a6f762d02 100644 --- a/src/agnext/chat/patterns/group_chat.py +++ b/src/agnext/chat/patterns/group_chat.py @@ -36,9 +36,7 @@ def subscriptions(self) -> Sequence[type]: agent_sublists = [agent.subscriptions for agent in self._agents] return [Reset, RespondNow] + [item for sublist in agent_sublists for item in sublist] - async def on_message( - self, message: Any, require_response: bool, cancellation_token: CancellationToken - ) -> Any | None: + async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None: if isinstance(message, Reset): # Reset the history. self._history = [] @@ -48,10 +46,8 @@ async def on_message( # TODO reset... return self._output.get_output() - # TODO: should we do nothing here? - # Perhaps it should be saved into the message history? - if not require_response: - return None + # TODO: how should we handle the group chat receiving a message while in the middle of a conversation? + # Should this class disallow it? self._history.append(message) round = 0 @@ -67,14 +63,13 @@ async def on_message( _ = await self._send_message( self._history[-1], agent, - require_response=False, cancellation_token=cancellation_token, ) + # TODO handle if response is not None response = await self._send_message( RespondNow(), speaker, - require_response=True, cancellation_token=cancellation_token, ) @@ -88,4 +83,5 @@ async def on_message( output = self._output.get_output() self._output.reset() + self._history.clear() return output diff --git a/src/agnext/chat/patterns/orchestrator.py b/src/agnext/chat/patterns/orchestrator.py index fbf597a97f3..ca0902945c4 100644 --- a/src/agnext/chat/patterns/orchestrator.py +++ b/src/agnext/chat/patterns/orchestrator.py @@ -34,7 +34,6 @@ def __init__( async def on_chat_message( self, message: ChatMessage, - require_response: bool, cancellation_token: CancellationToken, ) -> ChatMessage | None: # A task is received. diff --git a/src/agnext/core/agent.py b/src/agnext/core/agent.py index 038ffd2a33c..2fd60ec6381 100644 --- a/src/agnext/core/agent.py +++ b/src/agnext/core/agent.py @@ -11,6 +11,4 @@ def name(self) -> str: ... @property def subscriptions(self) -> Sequence[type]: ... - async def on_message( - self, message: Any, require_response: bool, cancellation_token: CancellationToken - ) -> Any | None: ... + async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None: ... diff --git a/src/agnext/core/agent_runtime.py b/src/agnext/core/agent_runtime.py index 67f558c6f65..7bd4875fe9f 100644 --- a/src/agnext/core/agent_runtime.py +++ b/src/agnext/core/agent_runtime.py @@ -1,5 +1,5 @@ from asyncio import Future -from typing import Any, Protocol, Sequence +from typing import Any, Protocol from agnext.core.agent import Agent from agnext.core.cancellation_token import CancellationToken @@ -16,17 +16,15 @@ def send_message( message: Any, recipient: Agent, *, - require_response: bool = True, sender: Agent | None = None, cancellation_token: CancellationToken | None = None, - ) -> Future[Any | None]: ... + ) -> Future[Any]: ... - # Returns the response of all handling agents - def broadcast_message( + # No responses from publishing + def publish_message( self, message: Any, *, - require_response: bool = True, sender: Agent | None = None, cancellation_token: CancellationToken | None = None, - ) -> Future[Sequence[Any] | None]: ... + ) -> Future[None]: ... diff --git a/src/agnext/core/base_agent.py b/src/agnext/core/base_agent.py index 9560328ac2f..38c13ab879c 100644 --- a/src/agnext/core/base_agent.py +++ b/src/agnext/core/base_agent.py @@ -29,9 +29,7 @@ def subscriptions(self) -> Sequence[type]: return [] @abstractmethod - async def on_message( - self, message: Any, require_response: bool, cancellation_token: CancellationToken - ) -> Any | None: ... + async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None: ... # Returns the response of the message def _send_message( @@ -39,9 +37,8 @@ def _send_message( message: Any, recipient: Agent, *, - require_response: bool = True, cancellation_token: CancellationToken | None = None, - ) -> Future[Any | None]: + ) -> Future[Any]: if cancellation_token is None: cancellation_token = CancellationToken() @@ -49,23 +46,18 @@ def _send_message( message, sender=self, recipient=recipient, - require_response=require_response, cancellation_token=cancellation_token, ) cancellation_token.link_future(future) return future - # Returns the response of all handling agents - def _broadcast_message( + def _publish_message( self, message: Any, *, - require_response: bool = True, cancellation_token: CancellationToken | None = None, - ) -> Future[Sequence[Any] | None]: + ) -> Future[None]: if cancellation_token is None: cancellation_token = CancellationToken() - future = self._router.broadcast_message( - message, sender=self, require_response=require_response, cancellation_token=cancellation_token - ) + future = self._router.publish_message(message, sender=self, cancellation_token=cancellation_token) return future diff --git a/src/agnext/core/intervention.py b/src/agnext/core/intervention.py index 5a002c33129..4ceb016f75f 100644 --- a/src/agnext/core/intervention.py +++ b/src/agnext/core/intervention.py @@ -12,9 +12,9 @@ class DropMessage: ... class InterventionHandler(Protocol): async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]: ... - async def on_broadcast(self, message: Any, *, sender: Agent | None) -> Any | type[DropMessage]: ... + async def on_publish(self, message: Any, *, sender: Agent | None) -> Any | type[DropMessage]: ... async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]: ... - async def on_broadcast_response( + async def on_publish_response( self, message: Sequence[Any], *, recipient: Agent | None ) -> Sequence[Any] | type[DropMessage]: ... @@ -23,13 +23,13 @@ class DefaultInterventionHandler(InterventionHandler): async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]: return message - async def on_broadcast(self, message: Any, *, sender: Agent | None) -> Any | type[DropMessage]: + async def on_publish(self, message: Any, *, sender: Agent | None) -> Any | type[DropMessage]: return message async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]: return message - async def on_broadcast_response( + async def on_publish_response( self, message: Sequence[Any], *, recipient: Agent | None ) -> Sequence[Any] | type[DropMessage]: return message diff --git a/tests/test_cancellation.py b/tests/test_cancellation.py index de256ce40bc..5473261b52e 100644 --- a/tests/test_cancellation.py +++ b/tests/test_cancellation.py @@ -23,7 +23,7 @@ def __init__(self, name: str, router: AgentRuntime) -> None: self.cancelled = False @message_handler(MessageType) - async def on_new_message(self, message: MessageType, require_response: bool, cancellation_token: CancellationToken) -> MessageType: + async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: self.called = True sleep = asyncio.ensure_future(asyncio.sleep(100)) cancellation_token.link_future(sleep) @@ -42,10 +42,9 @@ def __init__(self, name: str, router: AgentRuntime, nested_agent: Agent) -> None self._nested_agent = nested_agent @message_handler(MessageType) - async def on_new_message(self, message: MessageType, require_response: bool, cancellation_token: CancellationToken) -> MessageType: - assert require_response == True + async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: self.called = True - response = self._send_message(message, self._nested_agent, require_response=require_response, cancellation_token=cancellation_token) + response = self._send_message(message, self._nested_agent, cancellation_token=cancellation_token) try: val = await response assert isinstance(val, MessageType) diff --git a/tests/test_intervention.py b/tests/test_intervention.py index 7750eedf749..c712db97fe6 100644 --- a/tests/test_intervention.py +++ b/tests/test_intervention.py @@ -20,7 +20,7 @@ def __init__(self, name: str, router: AgentRuntime) -> None: @message_handler(MessageType) - async def on_new_message(self, message: MessageType, require_response: bool, cancellation_token: CancellationToken) -> MessageType: + async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: self.num_calls += 1 return message @@ -28,7 +28,7 @@ async def on_new_message(self, message: MessageType, require_response: bool, can async def test_intervention_count_messages() -> None: class DebugInterventionHandler(DefaultInterventionHandler): - def __init__(self): + def __init__(self) -> None: self.num_messages = 0 async def on_send(self, message: MessageType, *, sender: Agent | None, recipient: Agent) -> MessageType: