Merge branch 'main' into u/#4354
LittleLittleCloud authored Nov 27, 2024
2 parents 4e53aca + a4067f6 commit 2dc5246
Showing 25 changed files with 849 additions and 183 deletions.
3 changes: 1 addition & 2 deletions protos/agent_worker.proto
@@ -117,12 +117,11 @@ message Message {
oneof message {
RpcRequest request = 1;
RpcResponse response = 2;
Event event = 3;
cloudevent.CloudEvent cloudEvent = 3;
RegisterAgentTypeRequest registerAgentTypeRequest = 4;
RegisterAgentTypeResponse registerAgentTypeResponse = 5;
AddSubscriptionRequest addSubscriptionRequest = 6;
AddSubscriptionResponse addSubscriptionResponse = 7;
cloudevent.CloudEvent cloudEvent = 8;
}
}
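The hunk above replaces the custom `Event` message with the CNCF `cloudevent.CloudEvent` type, reusing field number 3 of the `Message` oneof. A minimal sketch of how a worker could populate the new field from Python, assuming both generated stubs live under `autogen_core.application.protos` as the imports added later in this commit suggest:

```python
# Sketch only: stub import paths are assumed from the imports added in this commit.
import uuid

from autogen_core.application.protos import agent_worker_pb2, cloudevent_pb2

cloud_event = cloudevent_pb2.CloudEvent(
    id=str(uuid.uuid4()),            # unique message id
    spec_version="1.0",
    type="example.topic.type",       # hypothetical topic type
    source="example.topic.source",   # hypothetical topic source
)
# The oneof now carries a CloudEvent directly instead of the removed Event message.
runtime_message = agent_worker_pb2.Message(cloudEvent=cloud_event)
```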

@@ -47,6 +47,7 @@ def _create_group_chat_manager_factory(
return lambda: MagenticOneOrchestrator(
group_topic_type,
output_topic_type,
self._team_id,
participant_topic_types,
participant_descriptions,
max_turns,
@@ -1,7 +1,7 @@
import json
from typing import Any, List

from autogen_core.base import MessageContext
from autogen_core.base import MessageContext, AgentId
from autogen_core.components import DefaultTopicId, Image, event, rpc
from autogen_core.components.models import (
AssistantMessage,
@@ -42,6 +42,7 @@ def __init__(
self,
group_topic_type: str,
output_topic_type: str,
team_id: str,
participant_topic_types: List[str],
participant_descriptions: List[str],
max_turns: int | None,
@@ -51,6 +52,7 @@
super().__init__(description="Group chat manager")
self._group_topic_type = group_topic_type
self._output_topic_type = output_topic_type
self._team_id = team_id
if len(participant_topic_types) != len(participant_descriptions):
raise ValueError("The number of participant topic types, agent types, and descriptions must be the same.")
if len(set(participant_topic_types)) != len(participant_topic_types):
@@ -164,10 +166,13 @@ async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:

async def _reenter_inner_loop(self) -> None:
# Reset the agents
await self.publish_message(
GroupChatReset(),
topic_id=DefaultTopicId(type=self._group_topic_type),
)
for participant_topic_type in self._participant_topic_types:
await self._runtime.send_message(
GroupChatReset(),
recipient=AgentId(type=participant_topic_type, key=self._team_id),
)
# Reset the group chat manager
await self.reset()
self._message_thread.clear()

# Prepare the ledger
1 change: 1 addition & 0 deletions python/packages/autogen-core/pyproject.toml
@@ -25,6 +25,7 @@ dependencies = [
"opentelemetry-api~=1.27.0",
"asyncio_atexit",
"jsonref~=1.1.0",
"grpcio~=1.62.0", # TODO: update this once we have a stable version.
]

[project.optional-dependencies]
@@ -0,0 +1,13 @@
GRPC_IMPORT_ERROR_STR = (
"Distributed runtime features require additional dependencies. Install them with: pip install autogen-core[grpc]"
)

DATA_CONTENT_TYPE_ATTR = "datacontenttype"
DATA_SCHEMA_ATTR = "dataschema"
AGENT_SENDER_TYPE_ATTR = "agagentsendertype"
AGENT_SENDER_KEY_ATTR = "agagentsenderkey"
MESSAGE_KIND_ATTR = "agmsgkind"
MESSAGE_KIND_VALUE_PUBLISH = "publish"
MESSAGE_KIND_VALUE_RPC_REQUEST = "rpc_request"
MESSAGE_KIND_VALUE_RPC_RESPONSE = "rpc_response"
MESSAGE_KIND_VALUE_RPC_ERROR = "error"
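This new constants module (presumably `_constants.py`, given the `from ._constants import GRPC_IMPORT_ERROR_STR` imports elsewhere in this commit) names the CloudEvent attributes that carry the payload content type, schema, sender identity, and message kind. A hedged sketch of how these keys pair with attribute values, mirroring the map built in `publish_message` below; the `autogen_core.application._constants` module path is an assumption:

```python
from autogen_core.application import _constants  # module path assumed
from autogen_core.application.protos import cloudevent_pb2
from autogen_core.base import JSON_DATA_CONTENT_TYPE

AttrValue = cloudevent_pb2.CloudEvent.CloudEventAttributeValue

attributes = {
    _constants.DATA_CONTENT_TYPE_ATTR: AttrValue(ce_string=JSON_DATA_CONTENT_TYPE),
    _constants.DATA_SCHEMA_ATTR: AttrValue(ce_string="ExampleMessage"),        # hypothetical message type name
    _constants.AGENT_SENDER_TYPE_ATTR: AttrValue(ce_string="example_sender"),  # hypothetical sender type
    _constants.AGENT_SENDER_KEY_ATTR: AttrValue(ce_string="default"),          # hypothetical sender key
    _constants.MESSAGE_KIND_ATTR: AttrValue(ce_string=_constants.MESSAGE_KIND_VALUE_PUBLISH),
}
```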

This file was deleted.

@@ -28,11 +28,15 @@
cast,
)

from google.protobuf import any_pb2
from opentelemetry.trace import TracerProvider
from typing_extensions import Self, deprecated

from autogen_core.application.protos import cloudevent_pb2

from ..base import (
JSON_DATA_CONTENT_TYPE,
PROTOBUF_DATA_CONTENT_TYPE,
Agent,
AgentId,
AgentInstantiationContext,
@@ -49,8 +53,9 @@
from ..base._serialization import MessageSerializer, SerializationRegistry
from ..base._type_helpers import ChannelArgumentType
from ..components import TypePrefixSubscription, TypeSubscription
from . import _constants
from ._constants import GRPC_IMPORT_ERROR_STR
from ._helpers import SubscriptionManager, get_impl
from ._utils import GRPC_IMPORT_ERROR_STR
from .protos import agent_worker_pb2, agent_worker_pb2_grpc
from .telemetry import MessageRuntimeTracingConfig, TraceHelper, get_telemetry_grpc_metadata

@@ -178,6 +183,7 @@ def __init__(
host_address: str,
tracer_provider: TracerProvider | None = None,
extra_grpc_config: ChannelArgumentType | None = None,
payload_serialization_format: str = JSON_DATA_CONTENT_TYPE,
) -> None:
self._host_address = host_address
self._trace_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("Worker Runtime"))
@@ -198,6 +204,11 @@ def __init__(
self._serialization_registry = SerializationRegistry()
self._extra_grpc_config = extra_grpc_config or []

if payload_serialization_format not in {JSON_DATA_CONTENT_TYPE, PROTOBUF_DATA_CONTENT_TYPE}:
raise ValueError(f"Unsupported payload serialization format: {payload_serialization_format}")

self._payload_serialization_format = payload_serialization_format

def start(self) -> None:
"""Start the runtime in a background task."""
if self._running:
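The constructor above gains a `payload_serialization_format` parameter that defaults to JSON and is validated against the two supported content types. A usage sketch; the class name `WorkerAgentRuntime` and its import path are assumptions, while the constructor parameters come from the diff:

```python
from autogen_core.application import WorkerAgentRuntime  # class name and import path assumed
from autogen_core.base import PROTOBUF_DATA_CONTENT_TYPE

runtime = WorkerAgentRuntime(
    host_address="localhost:50051",  # hypothetical host address
    payload_serialization_format=PROTOBUF_DATA_CONTENT_TYPE,  # default is JSON_DATA_CONTENT_TYPE
)
runtime.start()  # runs the gRPC read loop in a background task
```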
@@ -236,8 +247,10 @@ async def _run_read_loop(self) -> None:
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "event":
task = asyncio.create_task(self._process_event(message.event))
case "cloudEvent":
# The proto typing doesn't resolve this one
cloud_event = cast(cloudevent_pb2.CloudEvent, message.cloudEvent) # type: ignore
task = asyncio.create_task(self._process_event(cloud_event))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
@@ -257,8 +270,6 @@ async def _run_read_loop(self) -> None:
task.add_done_callback(self._background_tasks.discard)
case None:
logger.warning("No message")
case other:
logger.error(f"Unknown message type: {other}")
except Exception as e:
logger.error("Error in read loop", exc_info=e)

@@ -381,30 +392,64 @@ async def publish_message(
if message_id is None:
message_id = str(uuid.uuid4())

# TODO: consume message_id

message_type = self._serialization_registry.type_name(message)
with self._trace_helper.trace_block(
"create", topic_id, parent=None, extraAttributes={"message_type": message_type}
):
serialized_message = self._serialization_registry.serialize(
message, type_name=message_type, data_content_type=JSON_DATA_CONTENT_TYPE
message, type_name=message_type, data_content_type=self._payload_serialization_format
)
telemetry_metadata = get_telemetry_grpc_metadata()
runtime_message = agent_worker_pb2.Message(
event=agent_worker_pb2.Event(
topic_type=topic_id.type,
topic_source=topic_id.source,
source=agent_worker_pb2.AgentId(type=sender.type, key=sender.key) if sender is not None else None,
metadata=telemetry_metadata,
payload=agent_worker_pb2.Payload(
data_type=message_type,
data=serialized_message,
data_content_type=JSON_DATA_CONTENT_TYPE,
),

sender_id = sender or AgentId("unknown", "unknown")
attributes = {
_constants.DATA_CONTENT_TYPE_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(
ce_string=self._payload_serialization_format
),
_constants.DATA_SCHEMA_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(ce_string=message_type),
_constants.AGENT_SENDER_TYPE_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(
ce_string=sender_id.type
),
_constants.AGENT_SENDER_KEY_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(
ce_string=sender_id.key
),
_constants.MESSAGE_KIND_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(
ce_string=_constants.MESSAGE_KIND_VALUE_PUBLISH
),
}

# If sending JSON we fill text_data with the serialized message
# If sending Protobuf we fill proto_data with the serialized message
# TODO: add an encoding field for serializer

if self._payload_serialization_format == JSON_DATA_CONTENT_TYPE:
runtime_message = agent_worker_pb2.Message(
cloudEvent=cloudevent_pb2.CloudEvent(
id=message_id,
spec_version="1.0",
type=topic_id.type,
source=topic_id.source,
attributes=attributes,
# TODO: use text, or proto fields appropriately
binary_data=serialized_message,
)
)
else:
# We need to unpack the serialized proto back into an Any
# TODO: find a way to prevent the roundtrip serialization
any_proto = any_pb2.Any()
any_proto.ParseFromString(serialized_message)
runtime_message = agent_worker_pb2.Message(
cloudEvent=cloudevent_pb2.CloudEvent(
id=message_id,
spec_version="1.0",
type=topic_id.type,
source=topic_id.source,
attributes=attributes,
proto_data=any_proto,
)
)
)

telemetry_metadata = get_telemetry_grpc_metadata()
task = asyncio.create_task(self._send_message(runtime_message, "publish", topic_id, telemetry_metadata))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
@@ -523,28 +568,58 @@ async def _process_response(self, response: agent_worker_pb2.RpcResponse) -> None:
else:
future.set_result(result)

async def _process_event(self, event: agent_worker_pb2.Event) -> None:
message = self._serialization_registry.deserialize(
event.payload.data, type_name=event.payload.data_type, data_content_type=event.payload.data_content_type
)
async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None:
event_attributes = event.attributes
sender: AgentId | None = None
if event.HasField("source"):
sender = AgentId(event.source.type, event.source.key)
topic_id = TopicId(event.topic_type, event.topic_source)
if (
_constants.AGENT_SENDER_TYPE_ATTR in event_attributes
and _constants.AGENT_SENDER_KEY_ATTR in event_attributes
):
sender = AgentId(
event_attributes[_constants.AGENT_SENDER_TYPE_ATTR].ce_string,
event_attributes[_constants.AGENT_SENDER_KEY_ATTR].ce_string,
)
topic_id = TopicId(event.type, event.source)
# Get the recipients for the topic.
recipients = await self._subscription_manager.get_subscribed_recipients(topic_id)

message_content_type = event_attributes[_constants.DATA_CONTENT_TYPE_ATTR].ce_string
message_type = event_attributes[_constants.DATA_SCHEMA_ATTR].ce_string

if message_content_type == JSON_DATA_CONTENT_TYPE:
message = self._serialization_registry.deserialize(
event.binary_data, type_name=message_type, data_content_type=message_content_type
)
elif message_content_type == PROTOBUF_DATA_CONTENT_TYPE:
# TODO: find a way to prevent the roundtrip serialization
proto_binary_data = event.proto_data.SerializeToString()
message = self._serialization_registry.deserialize(
proto_binary_data, type_name=message_type, data_content_type=message_content_type
)
else:
raise ValueError(f"Unsupported message content type: {message_content_type}")

# TODO: dont read these values in the runtime
topic_type_suffix = topic_id.type.split(":", maxsplit=1)[1] if ":" in topic_id.type else ""
is_rpc = topic_type_suffix == _constants.MESSAGE_KIND_VALUE_RPC_REQUEST
is_marked_rpc_type = (
_constants.MESSAGE_KIND_ATTR in event_attributes
and event_attributes[_constants.MESSAGE_KIND_ATTR].ce_string == _constants.MESSAGE_KIND_VALUE_RPC_REQUEST
)
if is_rpc and not is_marked_rpc_type:
warnings.warn("Received RPC request with topic type suffix but not marked as RPC request.", stacklevel=2)

# Send the message to each recipient.
responses: List[Awaitable[Any]] = []
for agent_id in recipients:
if agent_id == sender:
continue
# TODO: consume message_id
message_context = MessageContext(
sender=sender,
topic_id=topic_id,
is_rpc=False,
is_rpc=is_rpc,
cancellation_token=CancellationToken(),
message_id="NOT_DEFINED_TODO_FIX",
message_id=event.id,
)
agent = await self._get_agent(agent_id)
with MessageHandlerContext.populate_context(agent.id):
@@ -554,7 +629,7 @@ async def send_message(agent: Agent, message_context: MessageContext) -> Any:
"process",
agent.id,
parent=event.metadata,
extraAttributes={"message_type": event.payload.data_type},
extraAttributes={"message_type": message_type},
):
await agent.on_message(message, ctx=message_context)
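`_process_event` now reads the sender, payload schema, and content type from CloudEvent attributes, takes the message id from the event's `id` field, and treats a topic type ending in an `:rpc_request` suffix as an RPC request. A small sketch of that suffix convention with a hypothetical topic type:

```python
# Sketch of the suffix check in _process_event; the topic type is hypothetical.
topic_type = "example_agent:rpc_request"
suffix = topic_type.split(":", maxsplit=1)[1] if ":" in topic_type else ""
is_rpc = suffix == "rpc_request"  # _constants.MESSAGE_KIND_VALUE_RPC_REQUEST
# If the suffix indicates RPC but the MESSAGE_KIND attribute does not, the runtime only warns.
```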

@@ -4,7 +4,7 @@
from typing import Optional, Sequence

from ..base._type_helpers import ChannelArgumentType
from ._utils import GRPC_IMPORT_ERROR_STR
from ._constants import GRPC_IMPORT_ERROR_STR
from ._worker_runtime_host_servicer import WorkerAgentRuntimeHostServicer

try: