Skip to content

Commit

Permalink
Remove require_response, rename broadcast to publish, remove publish …
Browse files Browse the repository at this point in the history
…responses (autogenhub#25)

* rename broadcast to publish

* remove require response, remove responses from publishing
  • Loading branch information
jackgerrits authored May 26, 2024
1 parent b6dd861 commit cb55e00
Show file tree
Hide file tree
Showing 14 changed files with 69 additions and 161 deletions.
12 changes: 3 additions & 9 deletions examples/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@ def __init__(self, name: str, router: AgentRuntime) -> None:
super().__init__(name, router)

@message_handler(MessageType)
async def on_new_message(
self, message: MessageType, require_response: bool, cancellation_token: CancellationToken
) -> MessageType:
assert require_response
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
return MessageType(body=f"Inner: {message.body}", sender=self.name)


Expand All @@ -32,11 +29,8 @@ def __init__(self, name: str, router: AgentRuntime, inner: Agent) -> None:
self._inner = inner

@message_handler(MessageType)
async def on_new_message(
self, message: MessageType, require_response: bool, cancellation_token: CancellationToken
) -> MessageType:
assert require_response
inner_response = self._send_message(message, self._inner, require_response=True)
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
inner_response = self._send_message(message, self._inner)
inner_message = await inner_response
assert isinstance(inner_message, MessageType)
return MessageType(body=f"Outer: {inner_message.body}", sender=self.name)
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ select = ["E", "F", "W", "B", "Q", "I"]
ignore = ["F401", "E501"]

[tool.mypy]
files = ["src", "examples"]
files = ["src", "examples", "tests"]

strict = true
python_version = "3.10"
Expand All @@ -53,7 +53,7 @@ disallow_untyped_decorators = true
disallow_any_unimported = true

[tool.pyright]
include = ["src", "examples"]
include = ["src", "examples", "tests"]
typeCheckingMode = "strict"
reportUnnecessaryIsInstance = false
reportMissingTypeStubs = false
Expand Down
22 changes: 9 additions & 13 deletions src/agnext/agent_components/type_routed_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
def message_handler(
*target_types: Type[ReceivesT],
) -> Callable[
[Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]]],
Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
[Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]]],
Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
]:
def decorator(
func: Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
) -> Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]]:
func: Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
) -> Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]]:
# Convert target_types to list and stash
func._target_types = list(target_types) # type: ignore
return func
Expand All @@ -34,7 +34,7 @@ def __init__(self, name: str, router: AgentRuntime) -> None:
super().__init__(name, router)

# Self is already bound to the handlers
self._handlers: Dict[Type[Any], Callable[[Any, bool, CancellationToken], Coroutine[Any, Any, Any | None]]] = {}
self._handlers: Dict[Type[Any], Callable[[Any, CancellationToken], Coroutine[Any, Any, Any | None]]] = {}

router.add_agent(self)

Expand All @@ -49,17 +49,13 @@ def __init__(self, name: str, router: AgentRuntime) -> None:
def subscriptions(self) -> Sequence[Type[Any]]:
return list(self._handlers.keys())

async def on_message(
self, message: Any, require_response: bool, cancellation_token: CancellationToken
) -> Any | None:
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None:
key_type: Type[Any] = type(message) # type: ignore
handler = self._handlers.get(key_type) # type: ignore
if handler is not None:
return await handler(message, require_response, cancellation_token)
return await handler(message, cancellation_token)
else:
return await self.on_unhandled_message(message, require_response, cancellation_token)
return await self.on_unhandled_message(message, cancellation_token)

async def on_unhandled_message(
self, message: Any, require_response: bool, cancellation_token: CancellationToken
) -> NoReturn:
async def on_unhandled_message(self, message: Any, cancellation_token: CancellationToken) -> NoReturn:
raise CantHandleException(f"Unhandled message: {message}")
107 changes: 26 additions & 81 deletions src/agnext/application_components/single_threaded_agent_runtime.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
from asyncio import Future
from dataclasses import dataclass
from typing import Any, Awaitable, Dict, List, Sequence, Set, cast
from typing import Any, Awaitable, Dict, List, Set

from agnext.core.cancellation_token import CancellationToken
from agnext.core.exceptions import MessageDroppedException
Expand All @@ -12,15 +12,13 @@


@dataclass(kw_only=True)
class BroadcastMessageEnvelope:
"""A message envelope for broadcasting messages to all agents that can handle
class PublishMessageEnvelope:
"""A message envelope for publishing messages to all agents that can handle
the message of the type T."""

message: Any
future: Future[Sequence[Any] | None]
cancellation_token: CancellationToken
sender: Agent | None
require_response: bool


@dataclass(kw_only=True)
Expand All @@ -31,9 +29,8 @@ class SendMessageEnvelope:
message: Any
sender: Agent | None
recipient: Agent
future: Future[Any | None]
future: Future[Any]
cancellation_token: CancellationToken
require_response: bool


@dataclass(kw_only=True)
Expand All @@ -46,20 +43,9 @@ class ResponseMessageEnvelope:
recipient: Agent | None


@dataclass(kw_only=True)
class BroadcastResponseMessageEnvelope:
"""A message envelope for sending a response to a message."""

message: Sequence[Any]
future: Future[Sequence[Any]]
recipient: Agent | None


class SingleThreadedAgentRuntime(AgentRuntime):
def __init__(self, *, before_send: InterventionHandler | None = None) -> None:
self._message_queue: List[
BroadcastMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope | BroadcastResponseMessageEnvelope
] = []
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
self._per_type_subscribers: Dict[type, List[Agent]] = {}
self._agents: Set[Agent] = set()
self._before_send = before_send
Expand All @@ -77,7 +63,6 @@ def send_message(
message: Any,
recipient: Agent,
*,
require_response: bool = True,
sender: Agent | None = None,
cancellation_token: CancellationToken | None = None,
) -> Future[Any | None]:
Expand All @@ -95,36 +80,31 @@ def send_message(
future=future,
cancellation_token=cancellation_token,
sender=sender,
require_response=require_response,
)
)

return future

# send message, require_response=False -> returns after delivery, gives None
# send message, require_response=True -> returns after handling, gives Response
def broadcast_message(
def publish_message(
self,
message: Any,
*,
require_response: bool = True,
sender: Agent | None = None,
cancellation_token: CancellationToken | None = None,
) -> Future[Sequence[Any] | None]:
) -> Future[None]:
if cancellation_token is None:
cancellation_token = CancellationToken()

future = asyncio.get_event_loop().create_future()
self._message_queue.append(
BroadcastMessageEnvelope(
PublishMessageEnvelope(
message=message,
future=future,
cancellation_token=cancellation_token,
sender=sender,
require_response=require_response,
)
)

future = asyncio.get_event_loop().create_future()
future.set_result(None)
return future

async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
Expand All @@ -134,64 +114,41 @@ async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
try:
response = await recipient.on_message(
message_envelope.message,
require_response=message_envelope.require_response,
cancellation_token=message_envelope.cancellation_token,
)
except BaseException as e:
message_envelope.future.set_exception(e)
return

if not message_envelope.require_response and response is not None:
raise Exception("Recipient returned a response for a message that did not request a response")

if message_envelope.require_response and response is None:
raise Exception("Recipient did not return a response for a message that requested a response")

if message_envelope.require_response:
self._message_queue.append(
ResponseMessageEnvelope(
message=response,
future=message_envelope.future,
sender=message_envelope.recipient,
recipient=message_envelope.sender,
)
self._message_queue.append(
ResponseMessageEnvelope(
message=response,
future=message_envelope.future,
sender=message_envelope.recipient,
recipient=message_envelope.sender,
)
else:
message_envelope.future.set_result(None)
)

async def _process_broadcast(self, message_envelope: BroadcastMessageEnvelope) -> None:
async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None:
responses: List[Awaitable[Any]] = []
for agent in self._per_type_subscribers.get(type(message_envelope.message), []): # type: ignore
future = agent.on_message(
message_envelope.message,
require_response=message_envelope.require_response,
cancellation_token=message_envelope.cancellation_token,
)
responses.append(future)

try:
all_responses = await asyncio.gather(*responses)
except BaseException as e:
message_envelope.future.set_exception(e)
_all_responses = await asyncio.gather(*responses)
except BaseException:
# TODO log error
return

if message_envelope.require_response:
self._message_queue.append(
BroadcastResponseMessageEnvelope(
message=all_responses,
future=cast(Future[Sequence[Any]], message_envelope.future),
recipient=message_envelope.sender,
)
)
else:
message_envelope.future.set_result(None)
# TODO if responses are given for a publish

async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None:
message_envelope.future.set_result(message_envelope.message)

async def _process_broadcast_response(self, message_envelope: BroadcastResponseMessageEnvelope) -> None:
message_envelope.future.set_result(message_envelope.message)

async def process_next(self) -> None:
if len(self._message_queue) == 0:
# Yield control to the event loop to allow other tasks to run
Expand All @@ -211,20 +168,19 @@ async def process_next(self) -> None:
message_envelope.message = temp_message

asyncio.create_task(self._process_send(message_envelope))
case BroadcastMessageEnvelope(
case PublishMessageEnvelope(
message=message,
sender=sender,
future=future,
):
if self._before_send is not None:
temp_message = await self._before_send.on_broadcast(message, sender=sender)
temp_message = await self._before_send.on_publish(message, sender=sender)
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
future.set_exception(MessageDroppedException())
# TODO log message dropped
return

message_envelope.message = temp_message

asyncio.create_task(self._process_broadcast(message_envelope))
asyncio.create_task(self._process_publish(message_envelope))
case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
if self._before_send is not None:
temp_message = await self._before_send.on_response(message, sender=sender, recipient=recipient)
Expand All @@ -236,16 +192,5 @@ async def process_next(self) -> None:

asyncio.create_task(self._process_response(message_envelope))

case BroadcastResponseMessageEnvelope(message=message, recipient=recipient, future=future):
if self._before_send is not None:
temp_message_list = await self._before_send.on_broadcast_response(message, recipient=recipient)
if temp_message_list is DropMessage or isinstance(temp_message_list, DropMessage):
future.set_exception(MessageDroppedException())
return

message_envelope.message = list(temp_message_list) # type: ignore

asyncio.create_task(self._process_broadcast_response(message_envelope))

# Yield control to the message loop to allow other tasks to run
await asyncio.sleep(0)
15 changes: 3 additions & 12 deletions src/agnext/chat/agents/oai_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(
# TODO: use require_response
@message_handler(TextMessage)
async def on_chat_message_with_cancellation(
self, message: TextMessage, require_response: bool, cancellation_token: CancellationToken
self, message: TextMessage, cancellation_token: CancellationToken
) -> None:
print("---------------")
print(f"{self.name} received message from {message.source}: {message.content}")
Expand All @@ -41,22 +41,13 @@ async def on_chat_message_with_cancellation(
)
self._current_session_window_length += 1

if require_response:
# TODO ?
...

@message_handler(Reset)
async def on_reset(self, message: Reset, require_response: bool, cancellation_token: CancellationToken) -> None:
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
# Reset the current session window.
self._current_session_window_length = 0

@message_handler(RespondNow)
async def on_respond_now(
self, message: RespondNow, require_response: bool, cancellation_token: CancellationToken
) -> TextMessage | None:
if not require_response:
return None

async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage:
# Create a run and wait until it finishes.
run = await self._client.beta.threads.runs.create_and_poll(
thread_id=self._thread_id,
Expand Down
2 changes: 1 addition & 1 deletion src/agnext/chat/agents/random_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class RandomResponseAgent(BaseChatAgent, TypeRoutedAgent):
# TODO: use require_response
@message_handler(RespondNow)
async def on_chat_message_with_cancellation(
self, message: RespondNow, require_response: bool, cancellation_token: CancellationToken
self, message: RespondNow, cancellation_token: CancellationToken
) -> TextMessage:
# Generate a random response.
response_body = random.choice(
Expand Down
Loading

0 comments on commit cb55e00

Please sign in to comment.