diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index a75c5e82..089191da 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -176,6 +176,11 @@ class NextStepHandoff: new_agent: Agent[Any] +@dataclass +class NextStepHandoffReturnControl: + previous_agent: Agent[Any] + + @dataclass class NextStepFinalOutput: output: Any @@ -201,7 +206,9 @@ class SingleStepResult: new_step_items: list[RunItem] """Items generated during this current step.""" - next_step: NextStepHandoff | NextStepFinalOutput | NextStepRunAgain + next_step: ( + NextStepHandoff | NextStepFinalOutput | NextStepRunAgain | NextStepHandoffReturnControl + ) """The next step to take.""" @property @@ -238,6 +245,7 @@ async def execute_tools_and_side_effects( hooks: RunHooks[TContext], context_wrapper: RunContextWrapper[TContext], run_config: RunConfig, + previous_agents: list[Agent], ) -> SingleStepResult: # Make a copy of the generated items pre_step_items = list(pre_step_items) @@ -286,6 +294,7 @@ async def execute_tools_and_side_effects( hooks=hooks, context_wrapper=context_wrapper, run_config=run_config, + previous_agents=previous_agents, ) # Next, we'll check if the tool use should result in a final output @@ -316,6 +325,7 @@ async def execute_tools_and_side_effects( final_output=check_tool_use.final_output, hooks=hooks, context_wrapper=context_wrapper, + previous_agents=previous_agents, ) # Now we can check if the model also produced a final output @@ -340,6 +350,7 @@ async def execute_tools_and_side_effects( final_output=final_output, hooks=hooks, context_wrapper=context_wrapper, + previous_agents=previous_agents, ) elif ( not output_schema or output_schema.is_plain_text() @@ -353,6 +364,7 @@ async def execute_tools_and_side_effects( final_output=potential_final_output_text or "", hooks=hooks, context_wrapper=context_wrapper, + previous_agents=previous_agents, ) else: # If there's no final output, we can just run again @@ -663,6 +675,7 @@ async def execute_handoffs( 
hooks: RunHooks[TContext], context_wrapper: RunContextWrapper[TContext], run_config: RunConfig, + previous_agents: list[Agent[TContext]], ) -> SingleStepResult: # If there is more than one handoff, add tool responses that reject those handoffs multiple_handoffs = len(run_handoffs) > 1 @@ -684,6 +697,8 @@ async def execute_handoffs( actual_handoff = run_handoffs[0] with handoff_span(from_agent=agent.name) as span_handoff: handoff = actual_handoff.handoff + if handoff.should_return_control: + previous_agents.append(agent) new_agent: Agent[Any] = await handoff.on_invoke_handoff( context_wrapper, actual_handoff.tool_call.arguments ) @@ -825,16 +840,21 @@ async def execute_final_output( final_output: Any, hooks: RunHooks[TContext], context_wrapper: RunContextWrapper[TContext], + previous_agents: list[Agent[TContext]], ) -> SingleStepResult: + is_returning_control = len(previous_agents) > 0 # Run the on_end hooks - await cls.run_final_output_hooks(agent, hooks, context_wrapper, final_output) - + await cls.run_final_output_hooks( + agent, hooks, context_wrapper, final_output, is_returning_control + ) return SingleStepResult( original_input=original_input, model_response=new_response, pre_step_items=pre_step_items, new_step_items=new_step_items, - next_step=NextStepFinalOutput(final_output), + next_step=NextStepHandoffReturnControl(previous_agents.pop()) + if is_returning_control + else NextStepFinalOutput(final_output), ) @classmethod @@ -844,13 +864,19 @@ async def run_final_output_hooks( hooks: RunHooks[TContext], context_wrapper: RunContextWrapper[TContext], final_output: Any, + is_returning_control: bool, ): - await asyncio.gather( - hooks.on_agent_end(context_wrapper, agent, final_output), - agent.hooks.on_end(context_wrapper, agent, final_output) - if agent.hooks - else _coro.noop_coroutine(), - ) + # If the agent is not returning control, run the hooks + if not is_returning_control: + await asyncio.gather( + hooks.on_agent_end(context_wrapper, agent, final_output), 
+ agent.hooks.on_end(context_wrapper, agent, final_output) + if agent.hooks + else _coro.noop_coroutine(), + ) + # If the agent is returning control, only run the current agent's hooks + elif agent.hooks: + await agent.hooks.on_end(context_wrapper, agent, final_output) @classmethod async def run_single_input_guardrail( diff --git a/src/agents/handoffs.py b/src/agents/handoffs.py index cb2752e4..71f615e4 100644 --- a/src/agents/handoffs.py +++ b/src/agents/handoffs.py @@ -105,6 +105,12 @@ class Handoff(Generic[TContext]): agent and returns whether the handoff is enabled. You can use this to dynamically enable/disable a handoff based on your context/state.""" + should_return_control: bool = False + """Whether the Agent that receives control during a handoff should return control to the + original (previous) Agent upon completion of its work. If False, after the Agent that received + the handoff completes its work, the interaction will end. + """ + def get_transfer_message(self, agent: Agent[Any]) -> str: return json.dumps({"assistant": agent.name}) @@ -128,6 +134,7 @@ def handoff( tool_description_override: str | None = None, input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None, is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, + should_return_control: bool = False, ) -> Handoff[TContext]: ... @@ -141,6 +148,7 @@ def handoff( tool_name_override: str | None = None, input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None, is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, + should_return_control: bool = False, ) -> Handoff[TContext]: ... 
@@ -153,6 +161,7 @@ def handoff( tool_name_override: str | None = None, input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None, is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, + should_return_control: bool = False, ) -> Handoff[TContext]: ... @@ -164,6 +173,7 @@ def handoff( input_type: type[THandoffInput] | None = None, input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None, is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, + should_return_control: bool = False, ) -> Handoff[TContext]: """Create a handoff from an agent. @@ -181,7 +191,7 @@ def handoff( hidden from the LLM at runtime. """ assert (on_handoff and input_type) or not (on_handoff and input_type), ( - "You must provide either both on_handoff and input_type, or neither" + "You must provide either both on_handoff and input_type, or neither" ) type_adapter: TypeAdapter[Any] | None if input_type is not None: @@ -247,4 +257,5 @@ async def _invoke_handoff( input_filter=input_filter, agent_name=agent.name, is_enabled=is_enabled, + should_return_control=should_return_control, ) diff --git a/src/agents/run.py b/src/agents/run.py index e5f9378e..8a20ab64 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -16,6 +16,7 @@ AgentToolUseTracker, NextStepFinalOutput, NextStepHandoff, + NextStepHandoffReturnControl, NextStepRunAgain, QueueCompleteSentinel, RunImpl, @@ -156,6 +157,9 @@ class RunOptions(TypedDict, Generic[TContext]): previous_response_id: NotRequired[str | None] """The ID of the previous response, if any.""" + previous_agents: NotRequired[list[Agent[TContext]] | None] + """The agents that have been run before, and wish to regain control of the run.""" + class Runner: @classmethod @@ -169,6 +173,7 @@ async def run( hooks: RunHooks[TContext] | None = None, run_config: RunConfig | None = None, previous_response_id: str | None = None, + previous_agents: list[Agent[TContext]] 
| None = None, ) -> RunResult: """Run a workflow starting at the given agent. The agent will run in a loop until a final output is generated. The loop runs like so: @@ -205,6 +210,7 @@ async def run( hooks=hooks, run_config=run_config, previous_response_id=previous_response_id, + previous_agents=previous_agents, ) @classmethod @@ -218,6 +224,7 @@ def run_sync( hooks: RunHooks[TContext] | None = None, run_config: RunConfig | None = None, previous_response_id: str | None = None, + previous_agents: list[Agent[TContext]] | None = None, ) -> RunResult: """Run a workflow synchronously, starting at the given agent. Note that this just wraps the `run` method, so it will not work if there's already an event loop (e.g. inside an async @@ -257,6 +264,7 @@ def run_sync( hooks=hooks, run_config=run_config, previous_response_id=previous_response_id, + previous_agents=previous_agents, ) @classmethod @@ -269,6 +277,7 @@ def run_streamed( hooks: RunHooks[TContext] | None = None, run_config: RunConfig | None = None, previous_response_id: str | None = None, + previous_agents: list[Agent[TContext]] | None = None, ) -> RunResultStreaming: """Run a workflow starting at the given agent in streaming mode. The returned result object contains a method you can use to stream semantic events as they are generated. 
@@ -305,6 +314,7 @@ def run_streamed( hooks=hooks, run_config=run_config, previous_response_id=previous_response_id, + previous_agents=previous_agents, ) @@ -325,10 +335,13 @@ async def run( hooks = kwargs.get("hooks") run_config = kwargs.get("run_config") previous_response_id = kwargs.get("previous_response_id") + previous_agents = kwargs.get("previous_agents") if hooks is None: hooks = RunHooks[Any]() if run_config is None: run_config = RunConfig() + if previous_agents is None: + previous_agents = [] tool_use_tracker = AgentToolUseTracker() @@ -413,6 +426,7 @@ async def run( should_run_agent_start_hooks=should_run_agent_start_hooks, tool_use_tracker=tool_use_tracker, previous_response_id=previous_response_id, + previous_agents=previous_agents, ), ) else: @@ -427,6 +441,7 @@ async def run( should_run_agent_start_hooks=should_run_agent_start_hooks, tool_use_tracker=tool_use_tracker, previous_response_id=previous_response_id, + previous_agents=previous_agents, ) should_run_agent_start_hooks = False @@ -451,8 +466,13 @@ async def run( output_guardrail_results=output_guardrail_results, context_wrapper=context_wrapper, ) - elif isinstance(turn_result.next_step, NextStepHandoff): - current_agent = cast(Agent[TContext], turn_result.next_step.new_agent) + elif isinstance(turn_result.next_step, NextStepHandoff) or isinstance( + turn_result.next_step, NextStepHandoffReturnControl + ): + if isinstance(turn_result.next_step, NextStepHandoffReturnControl): + current_agent = turn_result.next_step.previous_agent + else: + current_agent = cast(Agent[TContext], turn_result.next_step.new_agent) current_span.finish(reset_current=True) current_span = None should_run_agent_start_hooks = True @@ -488,6 +508,7 @@ def run_sync( hooks = kwargs.get("hooks") run_config = kwargs.get("run_config") previous_response_id = kwargs.get("previous_response_id") + previous_agents = kwargs.get("previous_agents") return asyncio.get_event_loop().run_until_complete( self.run( starting_agent, @@ -497,6 
+518,7 @@ def run_sync( hooks=hooks, run_config=run_config, previous_response_id=previous_response_id, + previous_agents=previous_agents, ) ) @@ -511,10 +533,13 @@ def run_streamed( hooks = kwargs.get("hooks") run_config = kwargs.get("run_config") previous_response_id = kwargs.get("previous_response_id") + previous_agents = kwargs.get("previous_agents") if hooks is None: hooks = RunHooks[Any]() if run_config is None: run_config = RunConfig() + if previous_agents is None: + previous_agents = [] # If there's already a trace, we don't create a new one. In addition, we can't end the # trace here, because the actual work is done in `stream_events` and this method ends @@ -563,6 +588,7 @@ def run_streamed( context_wrapper=context_wrapper, run_config=run_config, previous_response_id=previous_response_id, + previous_agents=previous_agents, ) ) return streamed_result @@ -621,6 +647,7 @@ async def _start_streaming( context_wrapper: RunContextWrapper[TContext], run_config: RunConfig, previous_response_id: str | None, + previous_agents: list[Agent[TContext]], ): if streamed_result.trace: streamed_result.trace.start(mark_as_current=True) @@ -697,6 +724,7 @@ async def _start_streaming( tool_use_tracker, all_tools, previous_response_id, + previous_agents, ) should_run_agent_start_hooks = False @@ -706,8 +734,14 @@ async def _start_streaming( streamed_result.input = turn_result.original_input streamed_result.new_items = turn_result.generated_items - if isinstance(turn_result.next_step, NextStepHandoff): - current_agent = turn_result.next_step.new_agent + if isinstance(turn_result.next_step, NextStepHandoff) or isinstance( + turn_result.next_step, NextStepHandoffReturnControl + ): + if isinstance(turn_result.next_step, NextStepHandoff): + current_agent = turn_result.next_step.new_agent + else: + current_agent = turn_result.next_step.previous_agent + current_span.finish(reset_current=True) current_span = None should_run_agent_start_hooks = True @@ -782,6 +816,7 @@ async def 
_run_single_turn_streamed( tool_use_tracker: AgentToolUseTracker, all_tools: list[Tool], previous_response_id: str | None, + previous_agents: list[Agent[TContext]], ) -> SingleStepResult: if should_run_agent_start_hooks: await asyncio.gather( @@ -866,6 +901,7 @@ async def _run_single_turn_streamed( context_wrapper=context_wrapper, run_config=run_config, tool_use_tracker=tool_use_tracker, + previous_agents=previous_agents, ) RunImpl.stream_step_result_to_queue(single_step_result, streamed_result._event_queue) @@ -885,6 +921,7 @@ async def _run_single_turn( should_run_agent_start_hooks: bool, tool_use_tracker: AgentToolUseTracker, previous_response_id: str | None, + previous_agents: list[Agent[TContext]], ) -> SingleStepResult: # Ensure we run the hooks before anything else if should_run_agent_start_hooks: @@ -933,6 +970,7 @@ async def _run_single_turn( context_wrapper=context_wrapper, run_config=run_config, tool_use_tracker=tool_use_tracker, + previous_agents=previous_agents, ) @classmethod @@ -950,6 +988,7 @@ async def _get_single_step_result_from_response( context_wrapper: RunContextWrapper[TContext], run_config: RunConfig, tool_use_tracker: AgentToolUseTracker, + previous_agents: list[Agent[TContext]], ) -> SingleStepResult: processed_response = RunImpl.process_model_response( agent=agent, @@ -971,6 +1010,7 @@ async def _get_single_step_result_from_response( hooks=hooks, context_wrapper=context_wrapper, run_config=run_config, + previous_agents=previous_agents, ) @classmethod diff --git a/tests/test_run_step_execution.py b/tests/test_run_step_execution.py index 4cf9ae83..7e82a0ce 100644 --- a/tests/test_run_step_execution.py +++ b/tests/test_run_step_execution.py @@ -22,10 +22,12 @@ from agents._run_impl import ( NextStepFinalOutput, NextStepHandoff, + NextStepHandoffReturnControl, NextStepRunAgain, RunImpl, SingleStepResult, ) +from agents.handoffs import handoff from agents.run import AgentRunner from agents.tool import function_tool from agents.tool_context 
import ToolContext @@ -215,6 +217,46 @@ async def test_handoff_output_leads_to_handoff_next_step(): assert len(result.generated_items) == 3 +@pytest.mark.asyncio +async def test_handoff_output_leads_to_handoff_return_control_next_step(): + agent_1 = Agent(name="test_1") + agent_2 = Agent(name="test_2") + agent_3 = Agent(name="test_3", handoffs=[handoff(agent_1, should_return_control=True), agent_2]) + response = ModelResponse( + output=[get_text_message("Hello, world!"), get_handoff_tool_call(agent_1)], + usage=Usage(), + response_id=None, + ) + previous_agents: list[Agent[Any]] = [] + result = await get_execute_result(agent_3, response, previous_agents=previous_agents) + + assert isinstance(result.next_step, NextStepHandoff) + assert result.next_step.new_agent == agent_1 + assert len(previous_agents) == 1 + assert previous_agents[0] == agent_3 + assert len(result.generated_items) == 3 + + +@pytest.mark.asyncio +async def test_completion_of_handoff_returns_control_to_previous_agent(): + last_agent = Agent(name="last_agent") + sub_agent = Agent(name="sub_agent", handoffs=[last_agent]) + main_agent = Agent(name="main_agent", handoffs=[sub_agent]) + response = ModelResponse( + output=[get_text_message("Completed everything")], + usage=Usage(), + response_id=None, + ) + previous_agents = [main_agent] + result = await get_execute_result(last_agent, response, previous_agents=previous_agents) + + assert isinstance(result.next_step, NextStepHandoffReturnControl) + assert result.next_step.previous_agent == main_agent + assert len(previous_agents) == 0 + assert len(result.generated_items) == 1 + assert_item_is_message(result.generated_items[0], "Completed everything") + + class Foo(BaseModel): bar: str @@ -323,9 +365,11 @@ async def get_execute_result( hooks: RunHooks[Any] | None = None, context_wrapper: RunContextWrapper[Any] | None = None, run_config: RunConfig | None = None, + previous_agents: list[Agent[Any]] | None = None, ) -> SingleStepResult: output_schema = 
AgentRunner._get_output_schema(agent) handoffs = await AgentRunner._get_handoffs(agent, context_wrapper or RunContextWrapper(None)) + previous_agents = previous_agents if previous_agents is not None else [] processed_response = RunImpl.process_model_response( agent=agent, @@ -344,4 +388,5 @@ async def get_execute_result( hooks=hooks or RunHooks(), context_wrapper=context_wrapper or RunContextWrapper(None), run_config=run_config or RunConfig(), + previous_agents=previous_agents, )