Skip to content

Added support for "return" handoffs (#1) #869

Open · wants to merge 2 commits into main
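This change threads a `should_return_control` flag through handoff execution: when a handoff marked this way reaches a final output, the run hands control back to the agent that initiated the handoff instead of ending. A minimal usage sketch follows; it assumes the public `handoff()` helper accepts a `should_return_control` keyword argument, which this diff only shows as a field read inside `execute_handoffs`.

```python
import asyncio

from agents import Agent, Runner, handoff

# Assumed API: `should_return_control` as a keyword argument to `handoff()`.
# The diff only shows the internal `handoff.should_return_control` check.
billing_agent = Agent(
    name="Billing agent",
    instructions="Answer billing questions, then finish.",
)

triage_agent = Agent(
    name="Triage agent",
    instructions="Route billing questions to the billing agent.",
    handoffs=[handoff(billing_agent, should_return_control=True)],
)


async def main():
    # When the billing agent produces its final output, control returns to
    # the triage agent rather than ending the run.
    result = await Runner.run(triage_agent, "Why was I charged twice?")
    print(result.final_output)


if __name__ == "__main__":
    asyncio.run(main())
```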
79 changes: 79 additions & 0 deletions examples/basic/prompt_template.py
@@ -0,0 +1,79 @@
import argparse
import asyncio
import random

from agents import Agent, GenerateDynamicPromptData, Runner

"""
NOTE: This example will not work out of the box, because the default prompt ID will not be available
in your project.

To use it, please:
1. Go to https://platform.openai.com/playground/prompts
2. Create a new prompt variable, `poem_style`.
3. Create a system prompt with the content:
```
Write a poem in {{poem_style}}
```
4. Run the example with the `--prompt-id` flag.
"""

DEFAULT_PROMPT_ID = "pmpt_6850729e8ba481939fd439e058c69ee004afaa19c520b78b"


class DynamicContext:
def __init__(self, prompt_id: str):
self.prompt_id = prompt_id
self.poem_style = random.choice(["limerick", "haiku", "ballad"])
print(f"[debug] DynamicContext initialized with poem_style: {self.poem_style}")


async def _get_dynamic_prompt(data: GenerateDynamicPromptData):
ctx: DynamicContext = data.context.context
return {
"id": ctx.prompt_id,
"version": "1",
"variables": {
"poem_style": ctx.poem_style,
},
}


async def dynamic_prompt(prompt_id: str):
context = DynamicContext(prompt_id)

agent = Agent(
name="Assistant",
prompt=_get_dynamic_prompt,
)

result = await Runner.run(agent, "Tell me about recursion in programming.", context=context)
print(result.final_output)


async def static_prompt(prompt_id: str):
agent = Agent(
name="Assistant",
prompt={
"id": prompt_id,
"version": "1",
"variables": {
"poem_style": "limerick",
},
},
)

result = await Runner.run(agent, "Tell me about recursion in programming.")
print(result.final_output)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dynamic", action="store_true")
parser.add_argument("--prompt-id", type=str, default=DEFAULT_PROMPT_ID)
args = parser.parse_args()

if args.dynamic:
asyncio.run(dynamic_prompt(args.prompt_id))
else:
asyncio.run(static_prompt(args.prompt_id))
6 changes: 3 additions & 3 deletions pyproject.toml
@@ -1,19 +1,19 @@
[project]
name = "openai-agents"
version = "0.0.17"
version = "0.0.19"
description = "OpenAI Agents SDK"
readme = "README.md"
requires-python = ">=3.9"
license = "MIT"
authors = [{ name = "OpenAI", email = "[email protected]" }]
dependencies = [
"openai>=1.81.0",
"openai>=1.87.0",
"pydantic>=2.10, <3",
"griffe>=1.5.6, <2",
"typing-extensions>=4.12.2, <5",
"requests>=2.0, <3",
"types-requests>=2.0, <3",
"mcp>=1.8.0, <2; python_version >= '3.10'",
"mcp>=1.9.4, <2; python_version >= '3.10'",
]
classifiers = [
"Typing :: Typed",
6 changes: 6 additions & 0 deletions src/agents/__init__.py
@@ -45,6 +45,7 @@
from .models.openai_chatcompletions import OpenAIChatCompletionsModel
from .models.openai_provider import OpenAIProvider
from .models.openai_responses import OpenAIResponsesModel
from .prompts import DynamicPromptFunction, GenerateDynamicPromptData, Prompt
from .repl import run_demo_loop
from .result import RunResult, RunResultStreaming
from .run import RunConfig, Runner
@@ -103,6 +104,7 @@
handoff_span,
mcp_tools_span,
set_trace_processors,
set_trace_provider,
set_tracing_disabled,
set_tracing_export_api_key,
speech_group_span,
@@ -178,6 +180,9 @@ def enable_verbose_stdout_logging():
"AgentsException",
"InputGuardrailTripwireTriggered",
"OutputGuardrailTripwireTriggered",
"DynamicPromptFunction",
"GenerateDynamicPromptData",
"Prompt",
"MaxTurnsExceeded",
"ModelBehaviorError",
"UserError",
@@ -242,6 +247,7 @@ def enable_verbose_stdout_logging():
"guardrail_span",
"handoff_span",
"set_trace_processors",
"set_trace_provider",
"set_tracing_disabled",
"speech_group_span",
"transcription_span",
46 changes: 36 additions & 10 deletions src/agents/_run_impl.py
@@ -176,6 +176,11 @@ class NextStepHandoff:
new_agent: Agent[Any]


@dataclass
class NextStepHandoffReturnControl:
previous_agent: Agent[Any]


@dataclass
class NextStepFinalOutput:
output: Any
@@ -201,7 +206,9 @@ class SingleStepResult:
new_step_items: list[RunItem]
"""Items generated during this current step."""

next_step: NextStepHandoff | NextStepFinalOutput | NextStepRunAgain
next_step: (
NextStepHandoff | NextStepFinalOutput | NextStepRunAgain | NextStepHandoffReturnControl
)
"""The next step to take."""

@property
@@ -238,6 +245,7 @@ async def execute_tools_and_side_effects(
hooks: RunHooks[TContext],
context_wrapper: RunContextWrapper[TContext],
run_config: RunConfig,
previous_agents: list[Agent],
) -> SingleStepResult:
# Make a copy of the generated items
pre_step_items = list(pre_step_items)
@@ -286,6 +294,7 @@
hooks=hooks,
context_wrapper=context_wrapper,
run_config=run_config,
previous_agents=previous_agents,
)

# Next, we'll check if the tool use should result in a final output
@@ -316,6 +325,7 @@
final_output=check_tool_use.final_output,
hooks=hooks,
context_wrapper=context_wrapper,
previous_agents=previous_agents,
)

# Now we can check if the model also produced a final output
@@ -340,6 +350,7 @@
final_output=final_output,
hooks=hooks,
context_wrapper=context_wrapper,
previous_agents=previous_agents,
)
elif (
not output_schema or output_schema.is_plain_text()
@@ -353,6 +364,7 @@
final_output=potential_final_output_text or "",
hooks=hooks,
context_wrapper=context_wrapper,
previous_agents=previous_agents,
)
else:
# If there's no final output, we can just run again
@@ -663,6 +675,7 @@ async def execute_handoffs(
hooks: RunHooks[TContext],
context_wrapper: RunContextWrapper[TContext],
run_config: RunConfig,
previous_agents: list[Agent[TContext]],
) -> SingleStepResult:
# If there is more than one handoff, add tool responses that reject those handoffs
multiple_handoffs = len(run_handoffs) > 1
Expand All @@ -684,6 +697,8 @@ async def execute_handoffs(
actual_handoff = run_handoffs[0]
with handoff_span(from_agent=agent.name) as span_handoff:
handoff = actual_handoff.handoff
if handoff.should_return_control:
previous_agents.append(agent)
new_agent: Agent[Any] = await handoff.on_invoke_handoff(
context_wrapper, actual_handoff.tool_call.arguments
)
@@ -825,16 +840,21 @@ async def execute_final_output(
final_output: Any,
hooks: RunHooks[TContext],
context_wrapper: RunContextWrapper[TContext],
previous_agents: list[Agent[TContext]],
) -> SingleStepResult:
is_returning_control = len(previous_agents) > 0
# Run the on_end hooks
await cls.run_final_output_hooks(agent, hooks, context_wrapper, final_output)

await cls.run_final_output_hooks(
agent, hooks, context_wrapper, final_output, is_returning_control
)
return SingleStepResult(
original_input=original_input,
model_response=new_response,
pre_step_items=pre_step_items,
new_step_items=new_step_items,
next_step=NextStepFinalOutput(final_output),
next_step=NextStepHandoffReturnControl(previous_agents.pop())
if is_returning_control
else NextStepFinalOutput(final_output),
)

@classmethod
@@ -844,13 +864,19 @@ async def run_final_output_hooks(
hooks: RunHooks[TContext],
context_wrapper: RunContextWrapper[TContext],
final_output: Any,
is_returning_control: bool,
):
await asyncio.gather(
hooks.on_agent_end(context_wrapper, agent, final_output),
agent.hooks.on_end(context_wrapper, agent, final_output)
if agent.hooks
else _coro.noop_coroutine(),
)
# If the agent is not returning control, run the hooks
if not is_returning_control:
await asyncio.gather(
hooks.on_agent_end(context_wrapper, agent, final_output),
agent.hooks.on_end(context_wrapper, agent, final_output)
if agent.hooks
else _coro.noop_coroutine(),
)
# If the agent is returning control, only run the current agent's hooks
elif agent.hooks:
await agent.hooks.on_end(context_wrapper, agent, final_output)

@classmethod
async def run_single_input_guardrail(
14 changes: 14 additions & 0 deletions src/agents/agent.py
@@ -7,6 +7,7 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast

from openai.types.responses.response_prompt_param import ResponsePromptParam
from typing_extensions import NotRequired, TypeAlias, TypedDict

from .agent_output import AgentOutputSchemaBase
@@ -17,6 +18,7 @@
from .mcp import MCPUtil
from .model_settings import ModelSettings
from .models.interface import Model
from .prompts import DynamicPromptFunction, Prompt, PromptUtil
from .run_context import RunContextWrapper, TContext
from .tool import FunctionTool, FunctionToolResult, Tool, function_tool
from .util import _transforms
@@ -95,6 +97,12 @@ class Agent(Generic[TContext]):
return a string.
"""

prompt: Prompt | DynamicPromptFunction | None = None
"""A prompt object (or a function that returns a Prompt). Prompts allow you to dynamically
configure the instructions, tools and other config for an agent outside of your code. Only
usable with OpenAI models, using the Responses API.
"""

handoff_description: str | None = None
"""A description of the agent. This is used when the agent is used as a handoff, so that an
LLM knows what it does and when to invoke it.
@@ -242,6 +250,12 @@ async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> s

return None

async def get_prompt(
self, run_context: RunContextWrapper[TContext]
) -> ResponsePromptParam | None:
"""Get the prompt for the agent."""
return await PromptUtil.to_model_input(self.prompt, run_context, self)

async def get_mcp_tools(self) -> list[Tool]:
"""Fetches the available tools from the MCP servers."""
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
8 changes: 7 additions & 1 deletion src/agents/extensions/models/litellm_model.py
@@ -71,6 +71,7 @@ async def get_response(
handoffs: list[Handoff],
tracing: ModelTracing,
previous_response_id: str | None,
prompt: Any | None = None,
) -> ModelResponse:
with generation_span(
model=str(self.model),
@@ -88,6 +89,7 @@
span_generation,
tracing,
stream=False,
prompt=prompt,
)

assert isinstance(response.choices[0], litellm.types.utils.Choices)
@@ -153,8 +155,8 @@ async def stream_response(
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
tracing: ModelTracing,
*,
previous_response_id: str | None,
prompt: Any | None = None,
) -> AsyncIterator[TResponseStreamEvent]:
with generation_span(
model=str(self.model),
@@ -172,6 +174,7 @@
span_generation,
tracing,
stream=True,
prompt=prompt,
)

final_response: Response | None = None
@@ -202,6 +205,7 @@ async def _fetch_response(
span: Span[GenerationSpanData],
tracing: ModelTracing,
stream: Literal[True],
prompt: Any | None = None,
) -> tuple[Response, AsyncStream[ChatCompletionChunk]]: ...

@overload
@@ -216,6 +220,7 @@ async def _fetch_response(
span: Span[GenerationSpanData],
tracing: ModelTracing,
stream: Literal[False],
prompt: Any | None = None,
) -> litellm.types.utils.ModelResponse: ...

async def _fetch_response(
@@ -229,6 +234,7 @@
span: Span[GenerationSpanData],
tracing: ModelTracing,
stream: bool = False,
prompt: Any | None = None,
) -> litellm.types.utils.ModelResponse | tuple[Response, AsyncStream[ChatCompletionChunk]]:
converted_messages = Converter.items_to_messages(input)

3 changes: 2 additions & 1 deletion src/agents/function_schema.py
@@ -223,7 +223,8 @@ def function_schema(
doc_info = None
param_descs = {}

func_name = name_override or doc_info.name if doc_info else func.__name__
# Ensure name_override takes precedence even if docstring info is disabled.
func_name = name_override or (doc_info.name if doc_info else func.__name__)

# 2. Inspect function signature and get type hints
sig = inspect.signature(func)
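For context on the `function_schema` fix above: Python parses `a or b if c else d` as `(a or b) if c else d`, so without the added parentheses a caller-supplied `name_override` was silently dropped whenever docstring parsing was disabled and `doc_info` was `None`. A small standalone sketch of the difference, using hypothetical values rather than SDK code:

```python
def func():
    ...


name_override = "my_tool"
doc_info = None  # docstring parsing disabled

# Old parse: (name_override or doc_info.name) if doc_info else func.__name__
old_name = name_override or doc_info.name if doc_info else func.__name__
# New parse: name_override wins even when doc_info is None
new_name = name_override or (doc_info.name if doc_info else func.__name__)

print(old_name)  # "func" -- the override is dropped
print(new_name)  # "my_tool"
```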