Skip to content

Commit

Permalink
Refactor chat (autogenhub#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
ekzhu authored Jun 9, 2024
1 parent e99ad51 commit 7ef502a
Show file tree
Hide file tree
Showing 18 changed files with 99 additions and 103 deletions.
5 changes: 2 additions & 3 deletions examples/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import aiofiles
import openai
from agnext.application import SingleThreadedAgentRuntime
from agnext.chat.agents.base import BaseChatAgent
from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
from agnext.chat.patterns.group_chat import GroupChatOutput
from agnext.chat.patterns.two_agent_chat import TwoAgentChat
Expand Down Expand Up @@ -38,7 +37,7 @@ def reset(self) -> None:
sep = "-" * 50


class UserProxyAgent(BaseChatAgent, TypeRoutedAgent): # type: ignore
class UserProxyAgent(TypeRoutedAgent): # type: ignore
def __init__(
self,
name: str,
Expand All @@ -52,7 +51,7 @@ def __init__(
name=name,
description="A human user",
runtime=runtime,
)
) # type: ignore
self._client = client
self._assistant_id = assistant_id
self._thread_id = thread_id
Expand Down
20 changes: 10 additions & 10 deletions examples/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,22 @@ class MessageType:
sender: str


class Inner(TypeRoutedAgent):
def __init__(self, name: str, router: AgentRuntime) -> None:
super().__init__(name, router)
class Inner(TypeRoutedAgent): # type: ignore
def __init__(self, name: str, router: AgentRuntime) -> None: # type: ignore
super().__init__(name, "The inner agent", router)

@message_handler()
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
@message_handler() # type: ignore
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: # type: ignore
return MessageType(body=f"Inner: {message.body}", sender=self.name)


class Outer(TypeRoutedAgent):
def __init__(self, name: str, router: AgentRuntime, inner: Agent) -> None:
super().__init__(name, router)
class Outer(TypeRoutedAgent): # type: ignore
def __init__(self, name: str, router: AgentRuntime, inner: Agent) -> None: # type: ignore
super().__init__(name, "The outter agent", router)
self._inner = inner

@message_handler()
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
@message_handler() # type: ignore
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: # type: ignore
inner_response = self._send_message(message, self._inner)
inner_message = await inner_response
assert isinstance(inner_message, MessageType)
Expand Down
4 changes: 4 additions & 0 deletions src/agnext/chat/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .chat_completion_agent import ChatCompletionAgent
from .oai_assistant import OpenAIAssistantAgent

__all__ = ["ChatCompletionAgent", "OpenAIAssistantAgent"]
14 changes: 0 additions & 14 deletions src/agnext/chat/agents/base.py

This file was deleted.

4 changes: 2 additions & 2 deletions src/agnext/chat/agents/chat_completion_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,9 @@
TextMessage,
)
from ..utils import convert_messages_to_llm_messages
from .base import BaseChatAgent


class ChatCompletionAgent(BaseChatAgent, TypeRoutedAgent):
class ChatCompletionAgent(TypeRoutedAgent):
def __init__(
self,
name: str,
Expand All @@ -40,6 +39,7 @@ def __init__(
tools: Sequence[Tool] = [],
) -> None:
super().__init__(name, description, runtime)
self._description = description
self._system_messages = system_messages
self._client = model_client
self._memory = memory
Expand Down
9 changes: 4 additions & 5 deletions src/agnext/chat/agents/oai_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
from openai import AsyncAssistantEventHandler
from openai.types.beta import AssistantResponseFormatParam

from agnext.chat.agents.base import BaseChatAgent
from agnext.chat.types import Reset, RespondNow, ResponseFormat, TextMessage
from agnext.components import TypeRoutedAgent, message_handler
from agnext.core import AgentRuntime, CancellationToken
from ...components import TypeRoutedAgent, message_handler
from ...core import AgentRuntime, CancellationToken
from ..types import Reset, RespondNow, ResponseFormat, TextMessage


class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
class OpenAIAssistantAgent(TypeRoutedAgent):
def __init__(
self,
name: str,
Expand Down
16 changes: 8 additions & 8 deletions src/agnext/chat/patterns/group_chat.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import Any, List, Protocol, Sequence

from ...components import TypeRoutedAgent, message_handler
from ...core import AgentRuntime, CancellationToken
from ..agents.base import BaseChatAgent
from ...core import Agent, AgentRuntime, CancellationToken
from ..types import Reset, RespondNow, TextMessage


Expand All @@ -14,25 +13,26 @@ def get_output(self) -> Any: ...
def reset(self) -> None: ...


class GroupChat(BaseChatAgent, TypeRoutedAgent):
class GroupChat(TypeRoutedAgent):
def __init__(
self,
name: str,
description: str,
runtime: AgentRuntime,
agents: Sequence[BaseChatAgent],
participants: Sequence[Agent],
num_rounds: int,
output: GroupChatOutput,
) -> None:
self._agents = agents
self._description = description
self._participants = participants
self._num_rounds = num_rounds
self._history: List[Any] = []
self._output = output
super().__init__(name, description, runtime)

@property
def subscriptions(self) -> Sequence[type]:
agent_sublists = [agent.subscriptions for agent in self._agents]
agent_sublists = [agent.subscriptions for agent in self._participants]
return [Reset, RespondNow] + [item for sublist in agent_sublists for item in sublist]

@message_handler()
Expand All @@ -55,10 +55,10 @@ async def on_text_message(self, message: TextMessage, cancellation_token: Cancel
while round < self._num_rounds:
# TODO: add support for advanced speaker selection.
# Select speaker (round-robin for now).
speaker = self._agents[round % len(self._agents)]
speaker = self._participants[round % len(self._participants)]

# Send the last message to all agents except the previous speaker.
for agent in [agent for agent in self._agents if agent is not prev_speaker]:
for agent in [agent for agent in self._participants if agent is not prev_speaker]:
# TODO gather and await
_ = await self._send_message(
self._history[-1],
Expand Down
11 changes: 5 additions & 6 deletions src/agnext/chat/patterns/orchestrator_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,21 @@
from typing import Any, Sequence, Tuple

from ...components import TypeRoutedAgent, message_handler
from ...core import AgentRuntime, CancellationToken
from ..agents.base import BaseChatAgent
from ...core import Agent, AgentRuntime, CancellationToken
from ..types import Reset, RespondNow, ResponseFormat, TextMessage

__all__ = ["OrchestratorChat"]


class OrchestratorChat(BaseChatAgent, TypeRoutedAgent):
class OrchestratorChat(TypeRoutedAgent):
def __init__(
self,
name: str,
description: str,
runtime: AgentRuntime,
orchestrator: BaseChatAgent,
planner: BaseChatAgent,
specialists: Sequence[BaseChatAgent],
orchestrator: Agent,
planner: Agent,
specialists: Sequence[Agent],
max_turns: int = 30,
max_stalled_turns_before_retry: int = 2,
max_retry_attempts: int = 1,
Expand Down
10 changes: 4 additions & 6 deletions src/agnext/chat/patterns/two_agent_chat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from agnext.chat.patterns.group_chat import GroupChat, GroupChatOutput

from ...core import AgentRuntime
from ..agents.base import BaseChatAgent
from ...core import Agent, AgentRuntime
from .group_chat import GroupChat, GroupChatOutput


# TODO: rewrite this with a new message type calling for add to message
Expand All @@ -12,8 +10,8 @@ def __init__(
name: str,
description: str,
runtime: AgentRuntime,
first_speaker: BaseChatAgent,
second_speaker: BaseChatAgent,
first_speaker: Agent,
second_speaker: Agent,
num_rounds: int,
output: GroupChatOutput,
) -> None:
Expand Down
4 changes: 2 additions & 2 deletions src/agnext/chat/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from enum import Enum
from typing import List, Union

from agnext.components import FunctionCall, Image
from agnext.components.models import FunctionExecutionResultMessage
from ..components import FunctionCall, Image
from ..components.models import FunctionExecutionResultMessage


@dataclass(kw_only=True)
Expand Down
14 changes: 7 additions & 7 deletions src/agnext/chat/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@

from typing_extensions import Literal

from agnext.chat.types import (
FunctionCallMessage,
Message,
MultiModalMessage,
TextMessage,
)
from agnext.components.models import (
from ..components.models import (
AssistantMessage,
FunctionExecutionResult,
FunctionExecutionResultMessage,
LLMMessage,
UserMessage,
)
from .types import (
FunctionCallMessage,
Message,
MultiModalMessage,
TextMessage,
)


def convert_content_message_to_assistant_message(
Expand Down
8 changes: 4 additions & 4 deletions src/agnext/components/_type_routed_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
runtime_checkable,
)

from agnext.core import AgentRuntime, BaseAgent, CancellationToken
from agnext.core.exceptions import CantHandleException
from ..core import AgentRuntime, BaseAgent, CancellationToken
from ..core.exceptions import CantHandleException

logger = logging.getLogger("agnext")

Expand Down Expand Up @@ -132,7 +132,7 @@ async def wrapper(self: Any, message: ReceivesT, cancellation_token: Cancellatio


class TypeRoutedAgent(BaseAgent):
def __init__(self, name: str, router: AgentRuntime) -> None:
def __init__(self, name: str, description: str, runtime: AgentRuntime) -> None:
# Self is already bound to the handlers
self._handlers: Dict[
Type[Any],
Expand All @@ -147,7 +147,7 @@ def __init__(self, name: str, router: AgentRuntime) -> None:
for target_type in message_handler.target_types:
self._handlers[target_type] = message_handler

super().__init__(name, router)
super().__init__(name, description, runtime)

@property
def subscriptions(self) -> Sequence[Type[Any]]:
Expand Down
9 changes: 8 additions & 1 deletion src/agnext/core/_agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Mapping, Protocol, Sequence, runtime_checkable

from agnext.core._cancellation_token import CancellationToken
from ._cancellation_token import CancellationToken


@runtime_checkable
Expand All @@ -14,6 +14,13 @@ def name(self) -> str:
"""
...

@property
def description(self) -> str:
"""Description of the agent.
A human-readable description of the agent."""
...

@property
def subscriptions(self) -> Sequence[type]:
"""Types of messages that this agent can receive."""
Expand Down
12 changes: 8 additions & 4 deletions src/agnext/core/_base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
from asyncio import Future
from typing import Any, Mapping, Sequence, TypeVar

from agnext.core._agent_runtime import AgentRuntime
from agnext.core._cancellation_token import CancellationToken

from ._agent import Agent
from ._agent_runtime import AgentRuntime
from ._cancellation_token import CancellationToken

ConsumesT = TypeVar("ConsumesT")
ProducesT = TypeVar("ProducesT", covariant=True)
Expand All @@ -16,15 +15,20 @@


class BaseAgent(ABC, Agent):
def __init__(self, name: str, router: AgentRuntime) -> None:
def __init__(self, name: str, description: str, router: AgentRuntime) -> None:
self._name = name
self._description = description
self._router = router
router.add_agent(self)

@property
def name(self) -> str:
return self._name

@property
def description(self) -> str:
return self._description

@property
@abstractmethod
def subscriptions(self) -> Sequence[type]:
Expand Down
20 changes: 10 additions & 10 deletions tests/test_cancellation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ class MessageType:
# To do cancellation, only the token should be interacted with as a user
# If you cancel a future, it may not work as you expect.

class LongRunningAgent(TypeRoutedAgent):
def __init__(self, name: str, router: AgentRuntime) -> None:
super().__init__(name, router)
class LongRunningAgent(TypeRoutedAgent): # type: ignore
def __init__(self, name: str, router: AgentRuntime) -> None: # type: ignore
super().__init__(name, "A long running agent", router)
self.called = False
self.cancelled = False

@message_handler()
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
@message_handler() # type: ignore
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: # type: ignore
self.called = True
sleep = asyncio.ensure_future(asyncio.sleep(100))
cancellation_token.link_future(sleep)
Expand All @@ -33,15 +33,15 @@ async def on_new_message(self, message: MessageType, cancellation_token: Cancell
self.cancelled = True
raise

class NestingLongRunningAgent(TypeRoutedAgent):
def __init__(self, name: str, router: AgentRuntime, nested_agent: Agent) -> None:
super().__init__(name, router)
class NestingLongRunningAgent(TypeRoutedAgent): # type: ignore
def __init__(self, name: str, router: AgentRuntime, nested_agent: Agent) -> None: # type: ignore
super().__init__(name, "A nesting long running agent", router)
self.called = False
self.cancelled = False
self._nested_agent = nested_agent

@message_handler()
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
@message_handler() # type: ignore
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: # type: ignore
self.called = True
response = self._send_message(message, self._nested_agent, cancellation_token=cancellation_token)
try:
Expand Down
Loading

0 comments on commit 7ef502a

Please sign in to comment.