MagenticOne Orchestrator Fixes #4430

Merged: 8 commits, Nov 30, 2024
Changes from 2 commits
@@ -18,6 +18,7 @@ def __init__(
         participants: List[ChatAgent],
         model_client: ChatCompletionClient,
         *,
+        termination_condition: TerminationCondition | None = None,
         max_turns: int | None = 20,
         max_stalls: int = 3,
     ):
@@ -52,4 +53,5 @@ def _create_group_chat_manager_factory(
             max_turns,
             self._model_client,
             self._max_stalls,
+            termination_condition
         )
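With this change, a caller-supplied termination condition is threaded through the group chat class to the orchestrator. A minimal usage sketch follows; it is not part of the diff, and the class name `MagenticOneGroupChat`, the agent setup, and the import paths are assumptions that may differ across versions:

```python
# Hypothetical usage of the new termination_condition parameter (names and
# import paths assumed; check your installed version).
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.task import MaxMessageTermination
from autogen_agentchat.teams import MagenticOneGroupChat
from autogen_ext.models import OpenAIChatCompletionClient


async def main() -> None:
    model_client = OpenAIChatCompletionClient(model="gpt-4o")
    helper = AssistantAgent("helper", model_client=model_client)
    team = MagenticOneGroupChat(
        participants=[helper],
        model_client=model_client,
        # New in this PR: stop the run once the condition triggers.
        termination_condition=MaxMessageTermination(max_messages=30),
        max_turns=20,
        max_stalls=3,
    )
    result = await team.run(task="Write a haiku about orchestration.")
    print(result)


asyncio.run(main())
```

The remaining hunks below are from the orchestrator module itself.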
@@ -1,5 +1,5 @@
 import json
-from typing import Any, List
+from typing import Any, List, Dict

 from autogen_core.base import AgentId, CancellationToken, MessageContext
 from autogen_core.components import DefaultTopicId, Image, event, rpc
@@ -10,12 +10,13 @@
     UserMessage,
 )

-from ....base import Response
+from ....base import Response, TerminationCondition
 from ....messages import (
     AgentMessage,
     MultiModalMessage,
     StopMessage,
     TextMessage,
+    ChatMessage
 )
 from .._events import (
     GroupChatAgentResponse,
@@ -25,7 +26,7 @@
     GroupChatStart,
     GroupChatTermination,
 )
-from .._sequential_routed_agent import SequentialRoutedAgent
+from .._base_group_chat_manager import BaseGroupChatManager
 from ._prompts import (
     ORCHESTRATOR_FINAL_ANSWER_PROMPT,
     ORCHESTRATOR_PROGRESS_LEDGER_PROMPT,
@@ -37,7 +38,7 @@
 )


-class MagenticOneOrchestrator(SequentialRoutedAgent):
+class MagenticOneOrchestrator(BaseGroupChatManager):
     def __init__(
         self,
         group_topic_type: str,
@@ -47,32 +48,26 @@ def __init__(
         max_turns: int | None,
         model_client: ChatCompletionClient,
         max_stalls: int,
+        termination_condition: TerminationCondition | None,
     ):
-        super().__init__(description="Group chat manager")
-        self._group_topic_type = group_topic_type
-        self._output_topic_type = output_topic_type
-        if len(participant_topic_types) != len(participant_descriptions):
-            raise ValueError("The number of participant topic types, agent types, and descriptions must be the same.")
-        if len(set(participant_topic_types)) != len(participant_topic_types):
-            raise ValueError("The participant topic types must be unique.")
-        if group_topic_type in participant_topic_types:
-            raise ValueError("The group topic type must not be in the participant topic types.")
-        self._participant_topic_types = participant_topic_types
-        self._participant_descriptions = participant_descriptions
-        self._message_thread: List[AgentMessage] = []
-
-        self._name: str = "orchestrator"
-        self._model_client: ChatCompletionClient = model_client
-        self._max_turns: int | None = max_turns
-        self._max_stalls: int = max_stalls
-
-        self._task: str = ""
-        self._facts: str = ""
-        self._plan: str = ""
-        self._n_rounds: int = 0
-        self._n_stalls: int = 0
-
-        self._team_description: str = "\n".join(
+        super().__init__(
+            group_topic_type,
+            output_topic_type,
+            participant_topic_types,
+            participant_descriptions,
+            termination_condition,
+            max_turns,
+        )
+        self._model_client = model_client
+        self._max_stalls = max_stalls
+        self._name = "MagenticOneOrchestrator"
+        self._max_json_retries = 10
+        self._task = ""
+        self._facts = ""
+        self._plan = ""
+        self._n_rounds = 0
+        self._n_stalls = 0
+        self._team_description = "\n".join(
             [
                 f"{topic_type}: {description}".strip()
                 for topic_type, description in zip(
@@ -135,6 +130,7 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None:

         assert isinstance(response.content, str)
         self._plan = response.content
+        # TODO: add inner message for facts or plan

         # Kick things off
         self._n_stalls = 0
@@ -150,14 +146,24 @@ async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> None:
         # Reset the group chat manager.
         await self.reset()

+    async def validate_group_state(self, message: ChatMessage | None) -> None:
+        pass
+
+    async def select_speaker(self, thread: List[AgentMessage]) -> str:
+        """Select a speaker from the participants and return the
+        topic type of the selected speaker."""
+        return ""
+
     async def reset(self) -> None:
         """Reset the group chat manager."""
-        pass
+        self._message_thread.clear()
+        if self._termination_condition is not None:
+            await self._termination_condition.reset()
+        self._n_rounds = 0
+        self._n_stalls = 0
+        self._task = ""
+        self._facts = ""
+        self._plan = ""

     async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
         raise ValueError(f"Unhandled message in group chat manager: {type(message)}")
@@ -170,8 +176,7 @@ async def _reenter_inner_loop(self, cancellation_token: CancellationToken) -> None:
                 recipient=AgentId(type=participant_topic_type, key=self.id.key),
                 cancellation_token=cancellation_token,
             )
-        # Reset the group chat manager
-        await self.reset()
+        # Partially reset the group chat manager
+        self._message_thread.clear()

         # Prepare the ledger
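Note the asymmetry this hunk introduces: `reset()` (above) now performs a full reset, clearing the message thread, the termination condition, the round and stall counters, and the task ledger state, whereas `_reenter_inner_loop` clears only the message thread, so the task, facts, and plan survive re-planning after a stall.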
@@ -212,12 +217,35 @@ async def _orchestrate_step(self, cancellation_token: CancellationToken) -> None:
             self._task, self._team_description, self._participant_topic_types
         )
         context.append(UserMessage(content=progress_ledger_prompt, source=self._name))

-        response = await self._model_client.create(context, json_output=True)
-
-        assert isinstance(response.content, str)
-        progress_ledger = json.loads(response.content)
-
+        progress_ledger: Dict[str, Any] = {}
+        assert self._max_json_retries > 0
+        key_error: bool = False
+        for _ in range(self._max_json_retries):
+            response = await self._model_client.create(context, json_output=True)
+            ledger_str = response.content
+            try:
+                assert isinstance(ledger_str, str)
+                progress_ledger = json.loads(ledger_str)
+                required_keys = [
+                    "is_request_satisfied",
+                    "is_progress_being_made",
+                    "is_in_loop",
+                    "instruction_or_question",
+                    "next_speaker",
+                ]
+                key_error = False
+                for key in required_keys:
+                    if key not in progress_ledger or "answer" not in progress_ledger[key]:
+                        key_error = True
+                        break
+                if not key_error:
+                    break
+                # TODO: add logging that we are retrying
+            except json.JSONDecodeError:
+                continue
+        if key_error:
+            raise ValueError("Failed to parse ledger information after multiple retries.")
+        # TODO: add logging of the ledger
         # Check for task completion
         if progress_ledger["is_request_satisfied"]["answer"]:
             await self._prepare_final_answer(progress_ledger["is_request_satisfied"]["reason"], cancellation_token)
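The loop above retries the model call until it returns JSON with the expected shape. The same retry-and-validate pattern, reduced to a standalone sketch (the helper name and the stubbed model call are illustrative, not from this PR):

```python
import json
from typing import Any, Callable, Dict, List


def parse_with_retries(
    request_json: Callable[[], str], required_keys: List[str], max_retries: int = 10
) -> Dict[str, Any]:
    """Call a JSON-producing function until the result has the expected shape."""
    for _ in range(max_retries):
        raw = request_json()
        try:
            ledger = json.loads(raw)
        except json.JSONDecodeError:
            continue  # Malformed JSON: ask again.
        # Each required key must be present and carry an "answer" field.
        if all(key in ledger and "answer" in ledger[key] for key in required_keys):
            return ledger
    raise ValueError("Failed to parse ledger information after multiple retries.")


# Example with a stubbed model call:
ledger = parse_with_retries(
    lambda: '{"next_speaker": {"answer": "coder", "reason": "coding task"}}',
    required_keys=["next_speaker"],
)
print(ledger["next_speaker"]["answer"])  # -> coder
```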
@@ -233,11 +261,12 @@ async def _orchestrate_step(self, cancellation_token: CancellationToken) -> None:

         # Too much stalling
         if self._n_stalls >= self._max_stalls:
+            # TODO: add logging
             await self._update_task_ledger(cancellation_token)
             await self._reenter_inner_loop(cancellation_token)
             return

-        # Broadcst the next step
+        # Broadcast the next step
         message = TextMessage(content=progress_ledger["instruction_or_question"]["answer"], source=self._name)
         self._message_thread.append(message) # My copy
@@ -255,10 +284,21 @@ async def _orchestrate_step(self, cancellation_token: CancellationToken) -> None:
         )

         # Request that the step be completed
+        valid_next_speaker: bool = False
         next_speaker = progress_ledger["next_speaker"]["answer"]
-        await self.publish_message(
-            GroupChatRequestPublish(), topic_id=DefaultTopicId(type=next_speaker), cancellation_token=cancellation_token
-        )
+        for participant_topic_type in self._participant_topic_types:
+            if participant_topic_type == next_speaker:
+                await self.publish_message(
+                    GroupChatRequestPublish(),
+                    topic_id=DefaultTopicId(type=next_speaker),
+                    cancellation_token=cancellation_token,
+                )
+                valid_next_speaker = True
+                break
+        if not valid_next_speaker:
+            raise ValueError(
+                f"Invalid next speaker: {next_speaker} from the ledger, participants are: {self._participant_topic_types}"
+            )

     async def _update_task_ledger(self, cancellation_token: CancellationToken) -> None:
         context = self._thread_to_context()
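Design note: the loop doubles as validation and dispatch. Since participant topic types are unique, it behaves the same as checking `next_speaker in self._participant_topic_types` and then publishing once; the explicit loop simply keeps both steps in a single pass over the participants.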