Fix: strip reasoningContent from messages before sending to Bedrock to avoid ValidationException #652

Open · wants to merge 2 commits into main
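At its core, this PR filters reasoningContent blocks out of the conversation history before the Converse request is built. A minimal standalone sketch of the idea (assuming the message shape used throughout this file: a dict with a "role" and a "content" list of blocks, where reasoning blocks are keyed by "reasoningContent"; the strip_reasoning_content name here is illustrative, the real helper is BedrockModel._strip_reasoning_content_from_message in the diff below):

import copy

def strip_reasoning_content(message: dict) -> dict:
    # Work on a deep copy so the caller's history is never mutated.
    cleaned = copy.deepcopy(message)
    cleaned["content"] = [
        block for block in cleaned.get("content", [])
        if "reasoningContent" not in block
    ]
    return cleaned

history = [{
    "role": "assistant",
    "content": [
        {"reasoningContent": {"reasoningText": {"text": "model reasoning", "signature": "sig"}}},
        {"text": "final answer"},
    ],
}]
cleaned_history = [strip_reasoning_content(m) for m in history]
assert cleaned_history[0]["content"] == [{"text": "final answer"}]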
160 changes: 131 additions & 29 deletions src/strands/models/bedrock.py
@@ -7,7 +7,17 @@
import json
import logging
import os
from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
from typing import (
Any,
AsyncGenerator,
Callable,
Iterable,
Literal,
Optional,
Type,
TypeVar,
Union,
)

import boto3
from botocore.config import Config as BotocoreConfig
@@ -131,19 +141,28 @@ def __init__(
else:
new_user_agent = "strands-agents"

client_config = boto_client_config.merge(BotocoreConfig(user_agent_extra=new_user_agent))
client_config = boto_client_config.merge(
BotocoreConfig(user_agent_extra=new_user_agent)
)
else:
client_config = BotocoreConfig(user_agent_extra="strands-agents")

resolved_region = region_name or session.region_name or os.environ.get("AWS_REGION") or DEFAULT_BEDROCK_REGION
resolved_region = (
region_name
or session.region_name
or os.environ.get("AWS_REGION")
or DEFAULT_BEDROCK_REGION
)

self.client = session.client(
service_name="bedrock-runtime",
config=client_config,
region_name=resolved_region,
)

logger.debug("region=<%s> | bedrock client created", self.client.meta.region_name)
logger.debug(
"region=<%s> | bedrock client created", self.client.meta.region_name
)

@override
def update_config(self, **model_config: Unpack[BedrockConfig]) -> None: # type: ignore
@@ -184,7 +203,11 @@ def format_request(
"messages": self._format_bedrock_messages(messages),
"system": [
*([{"text": system_prompt}] if system_prompt else []),
*([{"cachePoint": {"type": self.config["cache_prompt"]}}] if self.config.get("cache_prompt") else []),
*(
[{"cachePoint": {"type": self.config["cache_prompt"]}}]
if self.config.get("cache_prompt")
else []
),
],
**(
{
@@ -204,12 +227,20 @@
else {}
),
**(
{"additionalModelRequestFields": self.config["additional_request_fields"]}
{
"additionalModelRequestFields": self.config[
"additional_request_fields"
]
}
if self.config.get("additional_request_fields")
else {}
),
**(
{"additionalModelResponseFieldPaths": self.config["additional_response_field_paths"]}
{
"additionalModelResponseFieldPaths": self.config[
"additional_response_field_paths"
]
}
if self.config.get("additional_response_field_paths")
else {}
),
@@ -220,13 +251,18 @@
"guardrailVersion": self.config["guardrail_version"],
"trace": self.config.get("guardrail_trace", "enabled"),
**(
{"streamProcessingMode": self.config.get("guardrail_stream_processing_mode")}
{
"streamProcessingMode": self.config.get(
"guardrail_stream_processing_mode"
)
}
if self.config.get("guardrail_stream_processing_mode")
else {}
),
}
}
if self.config.get("guardrail_id") and self.config.get("guardrail_version")
if self.config.get("guardrail_id")
and self.config.get("guardrail_version")
else {}
),
"inferenceConfig": {
@@ -241,7 +277,8 @@
},
**(
self.config["additional_args"]
if "additional_args" in self.config and self.config["additional_args"] is not None
if "additional_args" in self.config
and self.config["additional_args"] is not None
else {}
),
}
@@ -278,7 +315,9 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages:

# Keep only the required fields for Bedrock
cleaned_tool_result = ToolResult(
content=tool_result["content"], toolUseId=tool_result["toolUseId"], status=tool_result["status"]
content=tool_result["content"],
toolUseId=tool_result["toolUseId"],
status=tool_result["status"],
)

cleaned_block: ContentBlock = {"toolResult": cleaned_tool_result}
@@ -288,7 +327,9 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages:
cleaned_content.append(content_block)

# Create new message with cleaned content
cleaned_message: Message = Message(content=cleaned_content, role=message["role"])
cleaned_message: Message = Message(
content=cleaned_content, role=message["role"]
)
cleaned_messages.append(cleaned_message)

return cleaned_messages
@@ -306,11 +347,17 @@ def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool:
output_assessments = guardrail_data.get("outputAssessments", {})

# Check input assessments
if any(self._find_detected_and_blocked_policy(assessment) for assessment in input_assessment.values()):
if any(
self._find_detected_and_blocked_policy(assessment)
for assessment in input_assessment.values()
):
return True

# Check output assessments
if any(self._find_detected_and_blocked_policy(assessment) for assessment in output_assessments.values()):
if any(
self._find_detected_and_blocked_policy(assessment)
for assessment in output_assessments.values()
):
return True

return False
@@ -341,7 +388,8 @@ def _generate_redaction_events(self) -> list[StreamEvent]:
{
"redactContent": {
"redactAssistantContentMessage": self.config.get(
"guardrail_redact_output_message", "[Assistant output redacted.]"
"guardrail_redact_output_message",
"[Assistant output redacted.]",
)
}
}
@@ -384,7 +432,9 @@ def callback(event: Optional[StreamEvent] = None) -> None:
loop = asyncio.get_event_loop()
queue: asyncio.Queue[Optional[StreamEvent]] = asyncio.Queue()

thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt)
thread = asyncio.to_thread(
self._stream, callback, messages, tool_specs, system_prompt
)
task = asyncio.create_task(thread)

while True:
@@ -396,6 +446,18 @@ def callback(event: Optional[StreamEvent] = None) -> None:

await task

def _strip_reasoning_content_from_message(self, message: Message) -> Message:
    """Return a copy of a message with all reasoningContent blocks removed.

    Bedrock raises a ValidationException when reasoningContent blocks are
    replayed in the request history, so they are stripped out here.

    Args:
        message: Message whose content may include reasoningContent blocks.

    Returns:
        A deep copy of the message without reasoningContent blocks.
    """
    import copy

    # Deep copy the message to avoid mutating the caller's original
    msg_copy = copy.deepcopy(message)

    content = msg_copy.get("content", [])
    # Filter out any content blocks that carry reasoningContent
    msg_copy["content"] = [block for block in content if "reasoningContent" not in block]
    return msg_copy

def _stream(
self,
callback: Callable[..., None],
@@ -418,8 +480,14 @@ def _stream(
ContextWindowOverflowException: If the input exceeds the model's context window.
ModelThrottledException: If the model service is throttling requests.
"""
logger.debug("formatting request")
request = self.format_request(messages, tool_specs, system_prompt)
logger.debug("stripping reasoning content from messages")
cleaned_messages = [
self._strip_reasoning_content_from_message(m) for m in messages
]

logger.debug("formatting request with cleaned messages")
request = self.format_request(cleaned_messages, tool_specs, system_prompt)

logger.debug("request=<%s>", request)

logger.debug("invoking model")
@@ -461,7 +529,10 @@ def _stream(
if e.response["Error"]["Code"] == "ThrottlingException":
raise ModelThrottledException(error_message) from e

if any(overflow_message in error_message for overflow_message in BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES):
if any(
overflow_message in error_message
for overflow_message in BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES
):
logger.warning("bedrock threw context window overflow error")
raise ContextWindowOverflowException(e) from e

@@ -497,7 +568,9 @@ def _stream(
callback()
logger.debug("finished streaming response from model")

def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]:
def _convert_non_streaming_to_streaming(
self, response: dict[str, Any]
) -> Iterable[StreamEvent]:
"""Convert a non-streaming response to the streaming format.

Args:
@@ -527,7 +600,9 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]:
# For tool use, we need to yield the input as a delta
input_value = json.dumps(content["toolUse"]["input"])

yield {"contentBlockDelta": {"delta": {"toolUse": {"input": input_value}}}}
yield {
"contentBlockDelta": {"delta": {"toolUse": {"input": input_value}}}
}
elif "text" in content:
# Then yield the text as a delta
yield {
@@ -539,7 +614,13 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]:
# Then yield the reasoning content as a delta
yield {
"contentBlockDelta": {
"delta": {"reasoningContent": {"text": content["reasoningContent"]["reasoningText"]["text"]}}
"delta": {
"reasoningContent": {
"text": content["reasoningContent"]["reasoningText"][
"text"
]
}
}
}
}

@@ -548,7 +629,9 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]:
"contentBlockDelta": {
"delta": {
"reasoningContent": {
"signature": content["reasoningContent"]["reasoningText"]["signature"]
"signature": content["reasoningContent"][
"reasoningText"
]["signature"]
}
}
}
@@ -561,7 +644,9 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]:
yield {
"messageStop": {
"stopReason": response["stopReason"],
"additionalModelResponseFields": response.get("additionalModelResponseFields"),
"additionalModelResponseFields": response.get(
"additionalModelResponseFields"
),
}
}

@@ -589,7 +674,11 @@ def _find_detected_and_blocked_policy(self, input: Any) -> bool:
# Check if input is a dictionary
if isinstance(input, dict):
# Check if current dictionary has action: BLOCKED and detected: true
if input.get("action") == "BLOCKED" and input.get("detected") and isinstance(input.get("detected"), bool):
if (
input.get("action") == "BLOCKED"
and input.get("detected")
and isinstance(input.get("detected"), bool)
):
return True

# Recursively check all values in the dictionary
@@ -609,7 +698,11 @@ def _find_detected_and_blocked_policy(self, input: Any) -> bool:

@override
async def structured_output(
self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
self,
output_model: Type[T],
prompt: Messages,
system_prompt: Optional[str] = None,
**kwargs: Any,
) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
"""Get structured output from the model.

@@ -624,14 +717,21 @@
"""
tool_spec = convert_pydantic_to_tool_spec(output_model)

response = self.stream(messages=prompt, tool_specs=[tool_spec], system_prompt=system_prompt, **kwargs)
response = self.stream(
messages=prompt,
tool_specs=[tool_spec],
system_prompt=system_prompt,
**kwargs,
)
async for event in streaming.process_stream(response):
yield event

stop_reason, messages, _, _ = event["stop"]

if stop_reason != "tool_use":
raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".')
raise ValueError(
f'Model returned stop_reason: {stop_reason} instead of "tool_use".'
)

content = messages["content"]
output_response: dict[str, Any] | None = None
Expand All @@ -644,6 +744,8 @@ async def structured_output(
continue

if output_response is None:
raise ValueError("No valid tool use or tool use input was found in the Bedrock response.")
raise ValueError(
"No valid tool use or tool use input was found in the Bedrock response."
)

yield {"output": output_model(**output_response)}
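
A quick way to exercise the new helper in isolation is a unit test along these lines (hypothetical sketch, assuming pytest and that the package is importable as strands.models.bedrock; object.__new__ bypasses __init__ so no boto3 session is created):

import copy

from strands.models.bedrock import BedrockModel

def test_strip_reasoning_content_from_message():
    # The helper touches no instance state, so an uninitialized instance suffices.
    model = object.__new__(BedrockModel)
    message = {
        "role": "assistant",
        "content": [
            {"reasoningContent": {"reasoningText": {"text": "thoughts", "signature": "sig"}}},
            {"text": "final answer"},
        ],
    }
    original = copy.deepcopy(message)
    cleaned = model._strip_reasoning_content_from_message(message)
    assert cleaned["content"] == [{"text": "final answer"}]
    assert message == original  # the input message is left unmodified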