From 8c3c593edf57c8af3e01f756487f9882a6dab6ec Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 8 May 2025 12:54:30 +0000 Subject: [PATCH] Fix streaming partial tool call responses from Anthropic --- pydantic_ai_slim/pydantic_ai/models/anthropic.py | 16 ++-------------- tests/models/test_anthropic.py | 11 ++++++++--- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index e176b01cd..9265959d8 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -5,7 +5,6 @@ from contextlib import asynccontextmanager from dataclasses import dataclass, field from datetime import datetime, timezone -from json import JSONDecodeError, loads as json_loads from typing import Any, Literal, Union, cast, overload from typing_extensions import assert_never @@ -441,7 +440,6 @@ class AnthropicStreamedResponse(StreamedResponse): async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: current_block: ContentBlock | None = None - current_json: str = '' async for event in self._response: self._usage += _map_usage(event) @@ -454,7 +452,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: maybe_event = self._parts_manager.handle_tool_call_delta( vendor_part_id=current_block.id, tool_name=current_block.name, - args=cast(dict[str, Any], current_block.input), + args=cast(dict[str, Any], current_block.input) or None, tool_call_id=current_block.id, ) if maybe_event is not None: @@ -466,20 +464,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: elif ( current_block and event.delta.type == 'input_json_delta' and isinstance(current_block, ToolUseBlock) ): - # Try to parse the JSON immediately, otherwise cache the value for later. This handles - # cases where the JSON is not currently valid but will be valid once we stream more tokens. - try: - parsed_args = json_loads(current_json + event.delta.partial_json) - current_json = '' - except JSONDecodeError: - current_json += event.delta.partial_json - continue - - # For tool calls, we need to handle partial JSON updates maybe_event = self._parts_manager.handle_tool_call_delta( vendor_part_id=current_block.id, tool_name='', - args=parsed_args, + args=event.delta.partial_json, tool_call_id=current_block.id, ) if maybe_event is not None: diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 1a82e83b3..13566d06c 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -529,18 +529,23 @@ async def test_stream_structured(allow_model_requests: None): RawContentBlockStartEvent( type='content_block_start', index=0, - content_block=ToolUseBlock(type='tool_use', id='tool_1', name='my_tool', input={'first': 'One'}), + content_block=ToolUseBlock(type='tool_use', id='tool_1', name='my_tool', input={}), ), # Add more data through an incomplete JSON delta RawContentBlockDeltaEvent( type='content_block_delta', index=0, - delta=InputJSONDelta(type='input_json_delta', partial_json='{"second":'), + delta=InputJSONDelta(type='input_json_delta', partial_json='{"first": "One'), ), RawContentBlockDeltaEvent( type='content_block_delta', index=0, - delta=InputJSONDelta(type='input_json_delta', partial_json='"Two"}'), + delta=InputJSONDelta(type='input_json_delta', partial_json='", "second": "Two"'), + ), + RawContentBlockDeltaEvent( + type='content_block_delta', + index=0, + delta=InputJSONDelta(type='input_json_delta', partial_json='}'), ), # Mark tool block as complete RawContentBlockStopEvent(type='content_block_stop', index=0),