Skip to content

Commit

Permalink
Merge branch 'main' into u/#4354
Browse files Browse the repository at this point in the history
  • Loading branch information
MohMaz authored Nov 27, 2024
2 parents f2ca194 + a6ccb6f commit 4e53aca
Show file tree
Hide file tree
Showing 17 changed files with 137 additions and 154 deletions.
13 changes: 13 additions & 0 deletions docs/design/02 - Topics.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,16 @@ Agents are able to handle certain types of messages. This is an internal detail

> [!NOTE]
> This might be revisited based on scaling and performance considerations.
## Well known topic types

Agents should subscribe via a prefix subscription to the `{AgentType}:` topic as a direct message channel for the agent type.

For this subscription source should map directly to agent key.

This subscription will therefore receive all events for the following well known topics:

- `{AgentType}:` - General purpose direct messages. These should be routed to the approriate message handler.
- `{AgentType}:rpc_request` - RPC request messages. These should be routed to the approriate RPC handler.
- `{AgentType}:rpc_response={RequestId}` - RPC response messages. These should be routed back to the response future of the caller.
- `{AgentType}:error={RequestId}` - Error message that corresponds to the given request.
1 change: 1 addition & 0 deletions python/packages/autogen-agentchat/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ include = ["src/**", "tests/*.py"]
[tool.pyright]
extends = "../../pyproject.toml"
include = ["src", "tests"]
reportDeprecated = true

[tool.pytest.ini_options]
minversion = "6.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
MessageContext,
)
from autogen_core.components import ClosureAgent, TypeSubscription
from autogen_core.components._closure_agent import ClosureContext

from ... import EVENT_LOGGER_NAME
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
Expand Down Expand Up @@ -139,8 +140,7 @@ async def _init(self, runtime: AgentRuntime) -> None:
)

async def collect_output_messages(
_runtime: AgentRuntime,
id: AgentId,
_runtime: ClosureContext,
message: GroupChatStart | GroupChatMessage | GroupChatTermination,
ctx: MessageContext,
) -> None:
Expand All @@ -150,7 +150,7 @@ async def collect_output_messages(
return
await self._output_message_queue.put(message.message)

await ClosureAgent.register(
await ClosureAgent.register_closure(
runtime,
type=self._collector_agent_type,
closure=collect_output_messages,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,16 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import asyncio\n",
"from dataclasses import dataclass\n",
"\n",
"from autogen_core.application import SingleThreadedAgentRuntime\n",
"from autogen_core.base import AgentId, AgentRuntime, MessageContext\n",
"from autogen_core.components import ClosureAgent, DefaultSubscription, DefaultTopicId"
"from autogen_core.base import MessageContext\n",
"from autogen_core.components import ClosureAgent, ClosureContext, DefaultSubscription, DefaultTopicId"
]
},
{
Expand Down Expand Up @@ -77,11 +77,11 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"async def output_result(_runtime: AgentRuntime, id: AgentId, message: FinalResult, ctx: MessageContext) -> None:\n",
"async def output_result(_agent: ClosureContext, message: FinalResult, ctx: MessageContext) -> None:\n",
" await queue.put(message)"
]
},
Expand All @@ -94,7 +94,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand All @@ -110,7 +110,9 @@
],
"source": [
"runtime = SingleThreadedAgentRuntime()\n",
"await ClosureAgent.register(runtime, \"output_result\", output_result, subscriptions=lambda: [DefaultSubscription()])"
"await ClosureAgent.register_closure(\n",
" runtime, \"output_result\", output_result, subscriptions=lambda: [DefaultSubscription()]\n",
")"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion python/packages/autogen-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ include = ["src/**", "samples/*.py", "docs/**/*.ipynb", "tests/**"]
extends = "../../pyproject.toml"
include = ["src", "tests", "samples"]
exclude = ["src/autogen_core/application/protos", "tests/protos"]
reportDeprecated = false
reportDeprecated = true

[tool.pytest.ini_options]
minversion = "6.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
WorkerAgentMessage,
)
from autogen_core.application import WorkerAgentRuntime
from autogen_core.base import AgentId, AgentRuntime, MessageContext
from autogen_core.components import ClosureAgent, DefaultSubscription, DefaultTopicId
from autogen_core.base import MessageContext
from autogen_core.components import ClosureAgent, ClosureContext, DefaultSubscription, DefaultTopicId


class MockIntentClassifier(IntentClassifierBase):
Expand All @@ -60,20 +60,20 @@ async def get_agent(self, intent: str) -> str:


async def output_result(
_runtime: AgentRuntime, id: AgentId, message: WorkerAgentMessage | FinalResult, ctx: MessageContext
closure_ctx: ClosureContext, message: WorkerAgentMessage | FinalResult, ctx: MessageContext
) -> None:
if isinstance(message, WorkerAgentMessage):
print(f"{message.source} Agent: {message.content}")
new_message = input("User response: ")
await _runtime.publish_message(
await closure_ctx.publish_message(
UserProxyMessage(content=new_message, source="user"),
topic_id=DefaultTopicId(type=message.source, source="user"),
)
else:
print(f"{message.source} Agent: {message.content}")
print("Conversation ended")
new_message = input("Enter a new conversation start: ")
await _runtime.publish_message(
await closure_ctx.publish_message(
UserProxyMessage(content=new_message, source="user"), topic_id=DefaultTopicId(type="default", source="user")
)

Expand All @@ -95,7 +95,7 @@ async def run_workers():
await agent_runtime.add_subscription(DefaultSubscription(topic_type="user_proxy", agent_type="user_proxy"))

# A closure agent surfaces the final result to external systems (e.g. an API) so that the system can interact with the user
await ClosureAgent.register(
await ClosureAgent.register_closure(
agent_runtime,
"closure_agent",
output_result,
Expand Down
20 changes: 12 additions & 8 deletions python/packages/autogen-core/samples/slow_human_in_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,13 @@
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core.base import AgentId, CancellationToken, MessageContext
from autogen_core.base.intervention import DefaultInterventionHandler
from autogen_core.components import DefaultSubscription, DefaultTopicId, FunctionCall, RoutedAgent, message_handler
from autogen_core.components import (
DefaultTopicId,
FunctionCall,
RoutedAgent,
message_handler,
type_subscription,
)
from autogen_core.components.model_context import BufferedChatCompletionContext
from autogen_core.components.models import (
AssistantMessage,
Expand Down Expand Up @@ -81,6 +87,7 @@ def save_content(self, content: Mapping[str, Any]) -> None:
state_persister = MockPersistence()


@type_subscription("scheduling_assistant_conversation")
class SlowUserProxyAgent(RoutedAgent):
def __init__(
self,
Expand Down Expand Up @@ -132,6 +139,7 @@ async def run(self, args: ScheduleMeetingInput, cancellation_token: Cancellation
return ScheduleMeetingOutput()


@type_subscription("scheduling_assistant_conversation")
class SchedulingAssistantAgent(RoutedAgent):
def __init__(
self,
Expand Down Expand Up @@ -256,24 +264,20 @@ async def main(latest_user_input: Optional[str] = None) -> None | str:
needs_user_input_handler = NeedsUserInputHandler()
runtime = SingleThreadedAgentRuntime(intervention_handlers=[needs_user_input_handler, termination_handler])

await runtime.register(
"User",
lambda: SlowUserProxyAgent("User", "I am a user"),
subscriptions=lambda: [DefaultSubscription("scheduling_assistant_conversation")],
)
await SlowUserProxyAgent.register(runtime, "User", lambda: SlowUserProxyAgent("User", "I am a user"))

initial_schedule_assistant_message = AssistantTextMessage(
content="Hi! How can I help you? I can help schedule meetings", source="User"
)
await runtime.register(
await SchedulingAssistantAgent.register(
runtime,
"SchedulingAssistant",
lambda: SchedulingAssistantAgent(
"SchedulingAssistant",
description="AI that helps you schedule meetings",
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"),
initial_message=initial_schedule_assistant_message,
),
subscriptions=lambda: [DefaultSubscription("scheduling_assistant_conversation")],
)

if latest_user_input is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

from ..base._type_prefix_subscription import TypePrefixSubscription
from ._closure_agent import ClosureAgent
from ._closure_agent import ClosureAgent, ClosureContext
from ._default_subscription import DefaultSubscription, default_subscription, type_subscription
from ._default_topic import DefaultTopicId
from ._image import Image
Expand All @@ -16,6 +16,7 @@
"RoutedAgent",
"TypeRoutedAgent",
"ClosureAgent",
"ClosureContext",
"message_handler",
"event",
"rpc",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,32 +1,38 @@
from __future__ import annotations

import inspect
from typing import Any, Awaitable, Callable, List, Mapping, Sequence, TypeVar, get_type_hints
from typing import Any, Awaitable, Callable, List, Mapping, Protocol, Sequence, TypeVar, get_type_hints

from autogen_core.base._serialization import try_get_known_serializers_for_type
from autogen_core.base._subscription_context import SubscriptionInstantiationContext

from ..base import (
Agent,
AgentId,
AgentInstantiationContext,
AgentMetadata,
AgentRuntime,
AgentType,
BaseAgent,
CancellationToken,
MessageContext,
Subscription,
SubscriptionInstantiationContext,
try_get_known_serializers_for_type,
TopicId,
)
from ..base._type_helpers import get_types
from ..base.exceptions import CantHandleException

T = TypeVar("T")
ClosureAgentType = TypeVar("ClosureAgentType", bound="ClosureAgent")


def get_handled_types_from_closure(
closure: Callable[[AgentRuntime, AgentId, T, MessageContext], Awaitable[Any]],
closure: Callable[[ClosureAgent, T, MessageContext], Awaitable[Any]],
) -> Sequence[type]:
args = inspect.getfullargspec(closure)[0]
if len(args) != 4:
if len(args) != 3:
raise AssertionError("Closure must have 4 arguments")

message_arg_name = args[2]
message_arg_name = args[1]

type_hints = get_type_hints(closure)

Expand All @@ -47,9 +53,30 @@ def get_handled_types_from_closure(
return target_types


class ClosureAgent(Agent):
class ClosureContext(Protocol):
@property
def id(self) -> AgentId: ...

async def send_message(
self,
message: Any,
recipient: AgentId,
*,
cancellation_token: CancellationToken | None = None,
) -> Any: ...

async def publish_message(
self,
message: Any,
topic_id: TopicId,
*,
cancellation_token: CancellationToken | None = None,
) -> None: ...


class ClosureAgent(BaseAgent, ClosureContext):
def __init__(
self, description: str, closure: Callable[[AgentRuntime, AgentId, T, MessageContext], Awaitable[Any]]
self, description: str, closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]]
) -> None:
try:
runtime = AgentInstantiationContext.current_runtime()
Expand All @@ -65,6 +92,7 @@ def __init__(
handled_types = get_handled_types_from_closure(closure)
self._expected_types = handled_types
self._closure = closure
super().__init__(description)

@property
def metadata(self) -> AgentMetadata:
Expand All @@ -88,7 +116,7 @@ async def on_message(self, message: Any, ctx: MessageContext) -> Any:
raise CantHandleException(
f"Message type {type(message)} not in target types {self._expected_types} of {self.id}"
)
return await self._closure(self._runtime, self._id, message, ctx)
return await self._closure(self, message, ctx)

async def save_state(self) -> Mapping[str, Any]:
raise ValueError("save_state not implemented for ClosureAgent")
Expand All @@ -97,16 +125,28 @@ async def load_state(self, state: Mapping[str, Any]) -> None:
raise ValueError("load_state not implemented for ClosureAgent")

@classmethod
async def register(
async def register_closure(
cls,
runtime: AgentRuntime,
type: str,
closure: Callable[[AgentRuntime, AgentId, T, MessageContext], Awaitable[Any]],
closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]],
*,
skip_class_subscriptions: bool = False,
skip_direct_message_subscription: bool = False,
description: str = "",
subscriptions: Callable[[], list[Subscription] | Awaitable[list[Subscription]]] | None = None,
) -> AgentType:
agent_type = AgentType(type)
def factory() -> ClosureAgent:
return ClosureAgent(description=description, closure=closure)

agent_type = await cls.register(
runtime=runtime,
type=type,
factory=factory, # type: ignore
skip_class_subscriptions=skip_class_subscriptions,
skip_direct_message_subscription=skip_direct_message_subscription,
)

subscriptions_list: List[Subscription] = []
if subscriptions is not None:
with SubscriptionInstantiationContext.populate_context(agent_type):
Expand All @@ -117,11 +157,6 @@ async def register(
# just ignore mypy here
subscriptions_list.extend(subscriptions_list_result) # type: ignore

agent_type = await runtime.register_factory(
type=agent_type,
agent_factory=lambda: ClosureAgent(description=description, closure=closure),
expected_class=cls,
)
for subscription in subscriptions_list:
await runtime.add_subscription(subscription)

Expand Down
Loading

0 comments on commit 4e53aca

Please sign in to comment.