From 332e037e7632d6d60bee4f628d966429c86f5f31 Mon Sep 17 00:00:00 2001 From: Rico Schrage Date: Wed, 13 Nov 2024 11:08:23 +0100 Subject: [PATCH] Making AgentAddress JSON serializable --- mango/agent/core.py | 8 +--- mango/messages/codecs.py | 2 + mango/messages/message.py | 18 ++++++++- .../test_message_roundtrip.py | 39 +++++++++++++++---- 4 files changed, 51 insertions(+), 16 deletions(-) diff --git a/mango/agent/core.py b/mango/agent/core.py index 27a5151..d8ca846 100644 --- a/mango/agent/core.py +++ b/mango/agent/core.py @@ -8,22 +8,16 @@ import asyncio import logging from abc import ABC -from dataclasses import dataclass from enum import Enum from typing import Any +from ..messages.message import AgentAddress from ..util.clock import Clock from ..util.scheduling import ScheduledProcessTask, ScheduledTask, Scheduler logger = logging.getLogger(__name__) -@dataclass(frozen=True, order=True) -class AgentAddress: - protocol_addr: Any - aid: str - - class State(Enum): NORMAL = 0 # normal neighbor INACTIVE = ( diff --git a/mango/messages/codecs.py b/mango/messages/codecs.py index f1af688..02df16f 100644 --- a/mango/messages/codecs.py +++ b/mango/messages/codecs.py @@ -16,6 +16,7 @@ from mango.messages.message import ( ACLMessage, + AgentAddress, MangoMessage, Performatives, enum_serializer, @@ -170,6 +171,7 @@ def __init__(self): super().__init__() self.add_serializer(*ACLMessage.__json_serializer__()) self.add_serializer(*MangoMessage.__json_serializer__()) + self.add_serializer(*AgentAddress.__serializer__()) self.add_serializer(*enum_serializer(Performatives)) def encode(self, data): diff --git a/mango/messages/message.py b/mango/messages/message.py index 3400cf3..d386dfa 100644 --- a/mango/messages/message.py +++ b/mango/messages/message.py @@ -14,11 +14,27 @@ from enum import Enum from typing import Any -from ..agent.core import AgentAddress from .acl_message_pb2 import ACLMessage as ACLProto from .mango_message_pb2 import MangoMessage as MangoMsg +@dataclass(frozen=True, order=True) +class AgentAddress: + protocol_addr: Any + aid: str + + def __asdict__(self): + return vars(self) + + @classmethod + def __fromdict__(cls, attrs): + return cls(**attrs) + + @classmethod + def __serializer__(cls): + return cls, cls.__asdict__, cls.__fromdict__ + + class Message(ABC): @abstractmethod def split_content_and_meta(self): diff --git a/tests/integration_tests/test_message_roundtrip.py b/tests/integration_tests/test_message_roundtrip.py index e91c7e2..e26887b 100644 --- a/tests/integration_tests/test_message_roundtrip.py +++ b/tests/integration_tests/test_message_roundtrip.py @@ -1,8 +1,9 @@ import asyncio +from dataclasses import dataclass import pytest -from mango import activate, addr +from mango import AgentAddress, activate, addr, json_serializable from mango.agent.core import Agent from mango.messages.codecs import JSON, PROTOBUF @@ -35,7 +36,7 @@ def string_serializer(): PROTO_CODEC.add_serializer(*string_serializer()) -async def setup_and_run_test_case(connection_type, codec): +async def setup_and_run_test_case(connection_type, codec, message=None): comm_topic = "test_topic" init_addr = ("127.0.0.1", 1555) if connection_type == "tcp" else "c1" repl_addr = ("127.0.0.1", 1556) if connection_type == "tcp" else "c2" @@ -50,8 +51,12 @@ async def setup_and_run_test_case(connection_type, codec): init_target = repl_addr repl_target = init_addr - init_agent = container_1.register(InitiatorAgent(container_1)) - repl_agent = container_2.register(ReplierAgent(container_2)) + init_agent = container_1.register( + InitiatorAgent(container_1, custom_message=message or M3) + ) + repl_agent = container_2.register( + ReplierAgent(container_2, expect_custom_message=message is not None) + ) repl_agent.target = addr(repl_target, init_agent.aid) init_agent.target = addr(init_target, repl_agent.aid) @@ -65,11 +70,12 @@ async def setup_and_run_test_case(connection_type, codec): # - answers to reply # - shuts down class InitiatorAgent(Agent): - def __init__(self, container): + def __init__(self, container, custom_message): super().__init__() self.target = None self.got_reply = asyncio.Future() self.container = container + self.custom_message = custom_message def handle_message(self, content, meta): if content == M2: @@ -90,7 +96,7 @@ async def start(self): await self.got_reply # answer to reply - await self.send_message(M3, self.target) + await self.send_message(self.custom_message, self.target) # ReplierAgent: @@ -99,10 +105,11 @@ async def start(self): # - awaits reply # - shuts down class ReplierAgent(Agent): - def __init__(self, container): + def __init__(self, container, expect_custom_message): super().__init__() self.target = None self.other_aid = None + self.expect_custom_mesage = expect_custom_message self.got_first = asyncio.Future() self.got_second = asyncio.Future() @@ -112,7 +119,7 @@ def __init__(self, container): def handle_message(self, content, meta): if content == M1: self.got_first.set_result(True) - elif content == M3: + elif content == M3 or self.expect_custom_mesage: self.got_second.set_result(True) async def start(self): @@ -136,6 +143,22 @@ async def test_tcp_json(): await setup_and_run_test_case("tcp", JSON_CODEC) +@json_serializable +@dataclass +class ABC: + AA: list[AgentAddress] + + +@pytest.mark.asyncio +async def test_tcp_json_with_complex_agentaddress(): + JSON_CODEC.add_serializer(*ABC.__serializer__()) + await setup_and_run_test_case( + "tcp", + JSON_CODEC, + ABC([AgentAddress(("123", 1), "Test"), AgentAddress(("123", 1), "Test2")]), + ) + + @pytest.mark.asyncio async def test_tcp_proto(): await setup_and_run_test_case("tcp", PROTO_CODEC)