Skip to content

Commit

Permalink
Merge pull request #138 from OFFIS-DAI/enh-json-ser-agentaddress
Browse files Browse the repository at this point in the history
Making AgentAddress JSON serializable
  • Loading branch information
rcschrg authored Nov 13, 2024
2 parents ed859ce + 332e037 commit 40c2dd9
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 16 deletions.
8 changes: 1 addition & 7 deletions mango/agent/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,16 @@
import asyncio
import logging
from abc import ABC
from dataclasses import dataclass
from enum import Enum
from typing import Any

from ..messages.message import AgentAddress
from ..util.clock import Clock
from ..util.scheduling import ScheduledProcessTask, ScheduledTask, Scheduler

logger = logging.getLogger(__name__)


@dataclass(frozen=True, order=True)
class AgentAddress:
protocol_addr: Any
aid: str


class State(Enum):
NORMAL = 0 # normal neighbor
INACTIVE = (
Expand Down
2 changes: 2 additions & 0 deletions mango/messages/codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from mango.messages.message import (
ACLMessage,
AgentAddress,
MangoMessage,
Performatives,
enum_serializer,
Expand Down Expand Up @@ -170,6 +171,7 @@ def __init__(self):
super().__init__()
self.add_serializer(*ACLMessage.__json_serializer__())
self.add_serializer(*MangoMessage.__json_serializer__())
self.add_serializer(*AgentAddress.__serializer__())
self.add_serializer(*enum_serializer(Performatives))

def encode(self, data):
Expand Down
18 changes: 17 additions & 1 deletion mango/messages/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,27 @@
from enum import Enum
from typing import Any

from ..agent.core import AgentAddress
from .acl_message_pb2 import ACLMessage as ACLProto
from .mango_message_pb2 import MangoMessage as MangoMsg


@dataclass(frozen=True, order=True)
class AgentAddress:
protocol_addr: Any
aid: str

def __asdict__(self):
return vars(self)

@classmethod
def __fromdict__(cls, attrs):
return cls(**attrs)

@classmethod
def __serializer__(cls):
return cls, cls.__asdict__, cls.__fromdict__


class Message(ABC):
@abstractmethod
def split_content_and_meta(self):
Expand Down
39 changes: 31 additions & 8 deletions tests/integration_tests/test_message_roundtrip.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import asyncio
from dataclasses import dataclass

import pytest

from mango import activate, addr
from mango import AgentAddress, activate, addr, json_serializable
from mango.agent.core import Agent
from mango.messages.codecs import JSON, PROTOBUF

Expand Down Expand Up @@ -35,7 +36,7 @@ def string_serializer():
PROTO_CODEC.add_serializer(*string_serializer())


async def setup_and_run_test_case(connection_type, codec):
async def setup_and_run_test_case(connection_type, codec, message=None):
comm_topic = "test_topic"
init_addr = ("127.0.0.1", 1555) if connection_type == "tcp" else "c1"
repl_addr = ("127.0.0.1", 1556) if connection_type == "tcp" else "c2"
Expand All @@ -50,8 +51,12 @@ async def setup_and_run_test_case(connection_type, codec):
init_target = repl_addr
repl_target = init_addr

init_agent = container_1.register(InitiatorAgent(container_1))
repl_agent = container_2.register(ReplierAgent(container_2))
init_agent = container_1.register(
InitiatorAgent(container_1, custom_message=message or M3)
)
repl_agent = container_2.register(
ReplierAgent(container_2, expect_custom_message=message is not None)
)
repl_agent.target = addr(repl_target, init_agent.aid)
init_agent.target = addr(init_target, repl_agent.aid)

Expand All @@ -65,11 +70,12 @@ async def setup_and_run_test_case(connection_type, codec):
# - answers to reply
# - shuts down
class InitiatorAgent(Agent):
def __init__(self, container):
def __init__(self, container, custom_message):
super().__init__()
self.target = None
self.got_reply = asyncio.Future()
self.container = container
self.custom_message = custom_message

def handle_message(self, content, meta):
if content == M2:
Expand All @@ -90,7 +96,7 @@ async def start(self):
await self.got_reply

# answer to reply
await self.send_message(M3, self.target)
await self.send_message(self.custom_message, self.target)


# ReplierAgent:
Expand All @@ -99,10 +105,11 @@ async def start(self):
# - awaits reply
# - shuts down
class ReplierAgent(Agent):
def __init__(self, container):
def __init__(self, container, expect_custom_message):
super().__init__()
self.target = None
self.other_aid = None
self.expect_custom_mesage = expect_custom_message

self.got_first = asyncio.Future()
self.got_second = asyncio.Future()
Expand All @@ -112,7 +119,7 @@ def __init__(self, container):
def handle_message(self, content, meta):
if content == M1:
self.got_first.set_result(True)
elif content == M3:
elif content == M3 or self.expect_custom_mesage:
self.got_second.set_result(True)

async def start(self):
Expand All @@ -136,6 +143,22 @@ async def test_tcp_json():
await setup_and_run_test_case("tcp", JSON_CODEC)


@json_serializable
@dataclass
class ABC:
AA: list[AgentAddress]


@pytest.mark.asyncio
async def test_tcp_json_with_complex_agentaddress():
JSON_CODEC.add_serializer(*ABC.__serializer__())
await setup_and_run_test_case(
"tcp",
JSON_CODEC,
ABC([AgentAddress(("123", 1), "Test"), AgentAddress(("123", 1), "Test2")]),
)


@pytest.mark.asyncio
async def test_tcp_proto():
await setup_and_run_test_case("tcp", PROTO_CODEC)
Expand Down

0 comments on commit 40c2dd9

Please sign in to comment.