Skip to content

Added support for passing tool_call_id via the RunContextWrapper #766

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/agents/_run_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
MCPToolApprovalRequest,
Tool,
)
from .tool_context import ToolContext
from .tracing import (
SpanError,
Trace,
Expand Down Expand Up @@ -539,23 +540,24 @@ async def run_single_tool(
func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
) -> Any:
with function_span(func_tool.name) as span_fn:
tool_context = ToolContext.from_agent_context(context_wrapper, tool_call.call_id)
if config.trace_include_sensitive_data:
span_fn.span_data.input = tool_call.arguments
try:
_, _, result = await asyncio.gather(
hooks.on_tool_start(context_wrapper, agent, func_tool),
hooks.on_tool_start(tool_context, agent, func_tool),
(
agent.hooks.on_tool_start(context_wrapper, agent, func_tool)
agent.hooks.on_tool_start(tool_context, agent, func_tool)
if agent.hooks
else _coro.noop_coroutine()
),
func_tool.on_invoke_tool(context_wrapper, tool_call.arguments),
func_tool.on_invoke_tool(tool_context, tool_call.arguments),
)

await asyncio.gather(
hooks.on_tool_end(context_wrapper, agent, func_tool, result),
hooks.on_tool_end(tool_context, agent, func_tool, result),
(
agent.hooks.on_tool_end(context_wrapper, agent, func_tool, result)
agent.hooks.on_tool_end(tool_context, agent, func_tool, result)
if agent.hooks
else _coro.noop_coroutine()
),
Expand Down
9 changes: 5 additions & 4 deletions src/agents/function_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .exceptions import UserError
from .run_context import RunContextWrapper
from .strict_schema import ensure_strict_json_schema
from .tool_context import ToolContext


@dataclass
Expand Down Expand Up @@ -237,21 +238,21 @@ def function_schema(
ann = type_hints.get(first_name, first_param.annotation)
if ann != inspect._empty:
origin = get_origin(ann) or ann
if origin is RunContextWrapper:
if origin is RunContextWrapper or origin is ToolContext:
takes_context = True # Mark that the function takes context
else:
filtered_params.append((first_name, first_param))
else:
filtered_params.append((first_name, first_param))

# For parameters other than the first, raise error if any use RunContextWrapper.
# For parameters other than the first, raise error if any use RunContextWrapper or ToolContext.
for name, param in params[1:]:
ann = type_hints.get(name, param.annotation)
if ann != inspect._empty:
origin = get_origin(ann) or ann
if origin is RunContextWrapper:
if origin is RunContextWrapper or origin is ToolContext:
raise UserError(
f"RunContextWrapper param found at non-first position in function"
f"RunContextWrapper/ToolContext param found at non-first position in function"
f" {func.__name__}"
)
filtered_params.append((name, param))
Expand Down
14 changes: 10 additions & 4 deletions src/agents/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .items import RunItem
from .logger import logger
from .run_context import RunContextWrapper
from .tool_context import ToolContext
from .tracing import SpanError
from .util import _error_tracing
from .util._types import MaybeAwaitable
Expand All @@ -28,8 +29,13 @@

ToolFunctionWithoutContext = Callable[ToolParams, Any]
ToolFunctionWithContext = Callable[Concatenate[RunContextWrapper[Any], ToolParams], Any]
ToolFunctionWithToolContext = Callable[Concatenate[ToolContext, ToolParams], Any]

ToolFunction = Union[ToolFunctionWithoutContext[ToolParams], ToolFunctionWithContext[ToolParams]]
ToolFunction = Union[
ToolFunctionWithoutContext[ToolParams],
ToolFunctionWithContext[ToolParams],
ToolFunctionWithToolContext[ToolParams],
]


@dataclass
Expand Down Expand Up @@ -59,7 +65,7 @@ class FunctionTool:
params_json_schema: dict[str, Any]
"""The JSON schema for the tool's parameters."""

on_invoke_tool: Callable[[RunContextWrapper[Any], str], Awaitable[Any]]
on_invoke_tool: Callable[[ToolContext[Any], str], Awaitable[Any]]
"""A function that invokes the tool with the given context and parameters. The params passed
are:
1. The tool run context.
Expand Down Expand Up @@ -330,7 +336,7 @@ def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
strict_json_schema=strict_mode,
)

async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> Any:
async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any:
try:
json_data: dict[str, Any] = json.loads(input) if input else {}
except Exception as e:
Expand Down Expand Up @@ -379,7 +385,7 @@ async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> Any:

return result

async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> Any:
async def _on_invoke_tool(ctx: ToolContext[Any], input: str) -> Any:
try:
return await _on_invoke_tool_impl(ctx, input)
except Exception as e:
Expand Down
26 changes: 26 additions & 0 deletions src/agents/tool_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from dataclasses import KW_ONLY, dataclass, fields
from typing import Any

from .run_context import RunContextWrapper, TContext


@dataclass
class ToolContext(RunContextWrapper[TContext]):
"""The context of a tool call."""

_: KW_ONLY
tool_call_id: str
"""The ID of the tool call."""

@classmethod
def from_agent_context(
cls, context: RunContextWrapper[TContext], tool_call_id: str
) -> "ToolContext":
"""
Create a ToolContext from a RunContextWrapper.
"""
# Grab the names of the RunContextWrapper's init=True fields
base_values: dict[str, Any] = {
f.name: getattr(context, f.name) for f in fields(RunContextWrapper) if f.init
}
return cls(tool_call_id=tool_call_id, **base_values)
33 changes: 17 additions & 16 deletions tests/test_function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from agents import FunctionTool, ModelBehaviorError, RunContextWrapper, function_tool
from agents.tool import default_tool_error_function
from agents.tool_context import ToolContext


def argless_function() -> str:
Expand All @@ -18,11 +19,11 @@ async def test_argless_function():
tool = function_tool(argless_function)
assert tool.name == "argless_function"

result = await tool.on_invoke_tool(RunContextWrapper(None), "")
result = await tool.on_invoke_tool(ToolContext(context=None, tool_call_id="1"), "")
assert result == "ok"


def argless_with_context(ctx: RunContextWrapper[str]) -> str:
def argless_with_context(ctx: ToolContext[str]) -> str:
return "ok"


Expand All @@ -31,11 +32,11 @@ async def test_argless_with_context():
tool = function_tool(argless_with_context)
assert tool.name == "argless_with_context"

result = await tool.on_invoke_tool(RunContextWrapper(None), "")
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "")
assert result == "ok"

# Extra JSON should not raise an error
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1}')
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}')
assert result == "ok"


Expand All @@ -48,15 +49,15 @@ async def test_simple_function():
tool = function_tool(simple_function, failure_error_function=None)
assert tool.name == "simple_function"

result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1}')
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}')
assert result == 6

result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1, "b": 2}')
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1, "b": 2}')
assert result == 3

# Missing required argument should raise an error
with pytest.raises(ModelBehaviorError):
await tool.on_invoke_tool(RunContextWrapper(None), "")
await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "")


class Foo(BaseModel):
Expand Down Expand Up @@ -84,7 +85,7 @@ async def test_complex_args_function():
"bar": Bar(x="hello", y=10),
}
)
result = await tool.on_invoke_tool(RunContextWrapper(None), valid_json)
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
assert result == "6 hello10 hello"

valid_json = json.dumps(
Expand All @@ -93,7 +94,7 @@ async def test_complex_args_function():
"bar": Bar(x="hello", y=10),
}
)
result = await tool.on_invoke_tool(RunContextWrapper(None), valid_json)
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
assert result == "3 hello10 hello"

valid_json = json.dumps(
Expand All @@ -103,12 +104,12 @@ async def test_complex_args_function():
"baz": "world",
}
)
result = await tool.on_invoke_tool(RunContextWrapper(None), valid_json)
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
assert result == "3 hello10 world"

# Missing required argument should raise an error
with pytest.raises(ModelBehaviorError):
await tool.on_invoke_tool(RunContextWrapper(None), '{"foo": {"a": 1}}')
await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"foo": {"a": 1}}')


def test_function_config_overrides():
Expand Down Expand Up @@ -168,7 +169,7 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
assert tool.params_json_schema[key] == value
assert tool.strict_json_schema

result = await tool.on_invoke_tool(RunContextWrapper(None), '{"data": "hello"}')
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"data": "hello"}')
assert result == "hello_done"

tool_not_strict = FunctionTool(
Expand All @@ -183,7 +184,7 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
assert "additionalProperties" not in tool_not_strict.params_json_schema

result = await tool_not_strict.on_invoke_tool(
RunContextWrapper(None), '{"data": "hello", "bar": "baz"}'
ToolContext(None, tool_call_id="1"), '{"data": "hello", "bar": "baz"}'
)
assert result == "hello_done"

Expand All @@ -194,7 +195,7 @@ def my_func(a: int, b: int = 5):
raise ValueError("test")

tool = function_tool(my_func)
ctx = RunContextWrapper(None)
ctx = ToolContext(None, tool_call_id="1")

result = await tool.on_invoke_tool(ctx, "")
assert "Invalid JSON" in str(result)
Expand All @@ -218,7 +219,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
return f"error_{error.__class__.__name__}"

tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
ctx = RunContextWrapper(None)
ctx = ToolContext(None, tool_call_id="1")

result = await tool.on_invoke_tool(ctx, "")
assert result == "error_ModelBehaviorError"
Expand All @@ -242,7 +243,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
return f"error_{error.__class__.__name__}"

tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
ctx = RunContextWrapper(None)
ctx = ToolContext(None, tool_call_id="1")

result = await tool.on_invoke_tool(ctx, "")
assert result == "error_ModelBehaviorError"
Expand Down
9 changes: 5 additions & 4 deletions tests/test_function_tool_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,16 @@

from agents import function_tool
from agents.run_context import RunContextWrapper
from agents.tool_context import ToolContext


class DummyContext:
def __init__(self):
self.data = "something"


def ctx_wrapper() -> RunContextWrapper[DummyContext]:
return RunContextWrapper(DummyContext())
def ctx_wrapper() -> ToolContext[DummyContext]:
return ToolContext(context=DummyContext(), tool_call_id="1")


@function_tool
Expand Down Expand Up @@ -44,7 +45,7 @@ async def test_sync_no_context_with_args_invocation():


@function_tool
def sync_with_context(ctx: RunContextWrapper[DummyContext], name: str) -> str:
def sync_with_context(ctx: ToolContext[DummyContext], name: str) -> str:
return f"{name}_{ctx.context.data}"


Expand All @@ -71,7 +72,7 @@ async def test_async_no_context_invocation():


@function_tool
async def async_with_context(ctx: RunContextWrapper[DummyContext], prefix: str, num: int) -> str:
async def async_with_context(ctx: ToolContext[DummyContext], prefix: str, num: int) -> str:
await asyncio.sleep(0)
return f"{prefix}-{num}-{ctx.context.data}"

Expand Down
6 changes: 4 additions & 2 deletions tests/test_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,12 @@ def _foo() -> str:
)


def get_function_tool_call(name: str, arguments: str | None = None) -> ResponseOutputItem:
def get_function_tool_call(
name: str, arguments: str | None = None, call_id: str | None = None
) -> ResponseOutputItem:
return ResponseFunctionToolCall(
id="1",
call_id="2",
call_id=call_id or "2",
type="function_call",
name=name,
arguments=arguments or "",
Expand Down
39 changes: 39 additions & 0 deletions tests/test_run_step_execution.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import json
from typing import Any

import pytest
Expand All @@ -26,6 +27,8 @@
RunImpl,
SingleStepResult,
)
from agents.tool import function_tool
from agents.tool_context import ToolContext

from .test_responses import (
get_final_output_message,
Expand Down Expand Up @@ -158,6 +161,42 @@ async def test_multiple_tool_calls():
assert isinstance(result.next_step, NextStepRunAgain)


@pytest.mark.asyncio
async def test_multiple_tool_calls_with_tool_context():
async def _fake_tool(context: ToolContext[str], value: str) -> str:
return f"{value}-{context.tool_call_id}"

tool = function_tool(_fake_tool, name_override="fake_tool", failure_error_function=None)

agent = Agent(
name="test",
tools=[tool],
)
response = ModelResponse(
output=[
get_function_tool_call("fake_tool", json.dumps({"value": "123"}), call_id="1"),
get_function_tool_call("fake_tool", json.dumps({"value": "456"}), call_id="2"),
],
usage=Usage(),
response_id=None,
)

result = await get_execute_result(agent, response)
assert result.original_input == "hello"

# 4 items: new message, 2 tool calls, 2 tool call outputs
assert len(result.generated_items) == 4
assert isinstance(result.next_step, NextStepRunAgain)

items = result.generated_items
assert_item_is_function_tool_call(items[0], "fake_tool", json.dumps({"value": "123"}))
assert_item_is_function_tool_call(items[1], "fake_tool", json.dumps({"value": "456"}))
assert_item_is_function_tool_call_output(items[2], "123-1")
assert_item_is_function_tool_call_output(items[3], "456-2")

assert isinstance(result.next_step, NextStepRunAgain)


@pytest.mark.asyncio
async def test_handoff_output_leads_to_handoff_next_step():
agent_1 = Agent(name="test_1")
Expand Down