-
Notifications
You must be signed in to change notification settings - Fork 129
Enforce mypy and fix typing issues #145
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
a9667ae
00aab9e
2155204
5823026
f2a3b1e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
import re | ||
from typing import Literal | ||
from typing import Literal, TypeGuard, cast | ||
|
||
from langchain_core.language_models import LanguageModelLike | ||
from langchain_core.messages import AIMessage, BaseMessage | ||
|
@@ -11,7 +11,7 @@ | |
AgentNameMode = Literal["inline"] | ||
|
||
|
||
def _is_content_blocks_content(content: list[dict] | str) -> bool: | ||
def _is_content_blocks_content(content: list[dict | str] | str) -> TypeGuard[list[dict]]: | ||
return ( | ||
isinstance(content, list) | ||
and len(content) > 0 | ||
|
@@ -35,12 +35,13 @@ def add_inline_agent_name(message: BaseMessage) -> BaseMessage: | |
return message | ||
|
||
formatted_message = message.model_copy() | ||
if _is_content_blocks_content(formatted_message.content): | ||
if _is_content_blocks_content(message.content): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why change this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This allows the new |
||
text_blocks = [block for block in message.content if block["type"] == "text"] | ||
non_text_blocks = [block for block in message.content if block["type"] != "text"] | ||
content = text_blocks[0]["text"] if text_blocks else "" | ||
formatted_content = f"<name>{message.name}</name><content>{content}</content>" | ||
formatted_message.content = [{"type": "text", "text": formatted_content}] + non_text_blocks | ||
formatted_message_content = [{"type": "text", "text": formatted_content}] + non_text_blocks | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why change this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a weird one. Storing it in an intermediate variable avoids this
I think the only alternative would be to explicitly |
||
formatted_message.content = formatted_message_content | ||
else: | ||
formatted_message.content = ( | ||
f"<name>{message.name}</name><content>{formatted_message.content}</content>" | ||
|
@@ -62,8 +63,7 @@ def remove_inline_agent_name(message: BaseMessage) -> BaseMessage: | |
if not isinstance(message, AIMessage) or not message.content: | ||
return message | ||
|
||
is_content_blocks_content = _is_content_blocks_content(message.content) | ||
if is_content_blocks_content: | ||
if is_content_blocks_content := _is_content_blocks_content(message.content): | ||
text_blocks = [block for block in message.content if block["type"] == "text"] | ||
if not text_blocks: | ||
return message | ||
|
@@ -85,7 +85,7 @@ def remove_inline_agent_name(message: BaseMessage) -> BaseMessage: | |
if parsed_content: | ||
content_blocks = [{"type": "text", "text": parsed_content}] + content_blocks | ||
|
||
parsed_message.content = content_blocks | ||
parsed_message.content = cast(list[str | dict], content_blocks) | ||
else: | ||
parsed_message.content = parsed_content | ||
return parsed_message | ||
|
@@ -120,9 +120,10 @@ def with_agent_name( | |
def process_input_messages(messages: list[BaseMessage]) -> list[BaseMessage]: | ||
return [process_input_message(message) for message in messages] | ||
|
||
model = ( | ||
chain = ( | ||
process_input_messages | ||
| model | ||
| RunnableLambda(process_output_message, name="process_output_message") | ||
) | ||
return model | ||
|
||
return cast(LanguageModelLike, chain) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
import re | ||
import uuid | ||
from typing import cast | ||
from typing import TypeGuard, cast | ||
|
||
from langchain_core.messages import AIMessage, ToolCall, ToolMessage | ||
from langchain_core.tools import BaseTool, InjectedToolCallId, tool | ||
|
@@ -18,6 +18,11 @@ def _normalize_agent_name(agent_name: str) -> str: | |
return WHITESPACE_RE.sub("_", agent_name.strip()).lower() | ||
|
||
|
||
def _has_multiple_content_blocks(content: str | list[str | dict]) -> TypeGuard[list[dict]]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. don't think it should have list of str There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above – it's due to the type of |
||
"""Check if content contains multiple content blocks.""" | ||
return isinstance(content, list) and len(content) > 1 and isinstance(content[0], dict) | ||
|
||
|
||
def _remove_non_handoff_tool_calls( | ||
last_ai_message: AIMessage, handoff_tool_call_id: str | ||
) -> AIMessage: | ||
|
@@ -26,7 +31,7 @@ def _remove_non_handoff_tool_calls( | |
# we need to remove tool calls that are not meant for this agent | ||
# to ensure that the resulting message history is valid | ||
content = last_ai_message.content | ||
if isinstance(content, list) and len(content) > 1 and isinstance(content[0], dict): | ||
if _has_multiple_content_blocks(content): | ||
content = [ | ||
content_block | ||
for content_block in content | ||
|
@@ -80,7 +85,7 @@ def create_handoff_tool( | |
def handoff_to_agent( | ||
state: Annotated[dict, InjectedState], | ||
tool_call_id: Annotated[str, InjectedToolCallId], | ||
): | ||
) -> Command: | ||
tool_message = ToolMessage( | ||
content=f"Successfully transferred to {agent_name}", | ||
name=name, | ||
|
@@ -166,7 +171,7 @@ def create_forward_message_tool(supervisor_name: str = "supervisor") -> BaseTool | |
def forward_message( | ||
from_agent: str, | ||
state: Annotated[dict, InjectedState], | ||
): | ||
) -> str | Command: | ||
target_message = next( | ||
( | ||
m | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ | |
) | ||
|
||
|
||
def test_add_inline_agent_name(): | ||
def test_add_inline_agent_name() -> None: | ||
# Test that non-AI messages are returned unchanged. | ||
human_message = HumanMessage(content="Hello") | ||
result = add_inline_agent_name(human_message) | ||
|
@@ -24,8 +24,8 @@ def test_add_inline_agent_name(): | |
assert result.name == "assistant" | ||
|
||
|
||
def test_add_inline_agent_name_content_blocks(): | ||
content_blocks = [ | ||
def test_add_inline_agent_name_content_blocks() -> None: | ||
content_blocks: list[str | dict] = [ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this should be just list of dict There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Perhaps it would be better for |
||
{"type": "text", "text": "Hello world"}, | ||
{"type": "image", "image_url": "http://example.com/image.jpg"}, | ||
] | ||
|
@@ -51,7 +51,7 @@ def test_add_inline_agent_name_content_blocks(): | |
assert result.content == expected_content_blocks | ||
|
||
|
||
def test_remove_inline_agent_name(): | ||
def test_remove_inline_agent_name() -> None: | ||
# Test that non-AI messages are returned unchanged. | ||
human_message = HumanMessage(content="Hello") | ||
result = remove_inline_agent_name(human_message) | ||
|
@@ -76,8 +76,8 @@ def test_remove_inline_agent_name(): | |
assert result.name == "assistant" | ||
|
||
|
||
def test_remove_inline_agent_name_content_blocks(): | ||
content_blocks = [ | ||
def test_remove_inline_agent_name_content_blocks() -> None: | ||
content_blocks: list[str | dict] = [ | ||
{"type": "text", "text": "<name>assistant</name><content>Hello world</content>"}, | ||
{"type": "image", "image_url": "http://example.com/image.jpg"}, | ||
] | ||
|
@@ -103,7 +103,7 @@ def test_remove_inline_agent_name_content_blocks(): | |
assert result.content == expected_content_blocks | ||
|
||
|
||
def test_remove_inline_agent_name_multiline_content(): | ||
def test_remove_inline_agent_name_multiline_content() -> None: | ||
multiline_content = """<name>assistant</name><content>This is | ||
a multiline | ||
message</content>""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i don't think there should be a list of str here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was needed because of how
BaseMessage.content
is typed.