Skip to content

Enforce mypy and fix typing issues #145

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=. --name-only --
lint lint_diff:
[ "$(PYTHON_FILES)" = "" ] || uv run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || uv run ruff check $(PYTHON_FILES) --diff
# [ "$(PYTHON_FILES)" = "" ] || uv run mypy $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || uv run mypy $(PYTHON_FILES)

format format_diff:
[ "$(PYTHON_FILES)" = "" ] || uv run ruff check --fix $(PYTHON_FILES)
Expand Down
19 changes: 10 additions & 9 deletions langgraph_supervisor/agent_name.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import Literal
from typing import Literal, TypeGuard, cast

from langchain_core.language_models import LanguageModelLike
from langchain_core.messages import AIMessage, BaseMessage
Expand All @@ -11,7 +11,7 @@
AgentNameMode = Literal["inline"]


def _is_content_blocks_content(content: list[dict] | str) -> bool:
def _is_content_blocks_content(content: list[dict | str] | str) -> TypeGuard[list[dict]]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think there should be a list of `str` here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was needed because of how BaseMessage.content is typed.

return (
isinstance(content, list)
and len(content) > 0
Expand All @@ -35,12 +35,13 @@ def add_inline_agent_name(message: BaseMessage) -> BaseMessage:
return message

formatted_message = message.model_copy()
if _is_content_blocks_content(formatted_message.content):
if _is_content_blocks_content(message.content):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why change this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This allows the new TypeGuard returned by _is_content_blocks_content to treat message.content as list[dict]. message.content is used/iterated within this conditional, so it requires the type guard to be treated as a list[dict].

text_blocks = [block for block in message.content if block["type"] == "text"]
non_text_blocks = [block for block in message.content if block["type"] != "text"]
content = text_blocks[0]["text"] if text_blocks else ""
formatted_content = f"<name>{message.name}</name><content>{content}</content>"
formatted_message.content = [{"type": "text", "text": formatted_content}] + non_text_blocks
formatted_message_content = [{"type": "text", "text": formatted_content}] + non_text_blocks
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why change this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a weird one. Storing it in an intermediate variable avoids this mypy error from assigning the expression directly to formatted_message.content:

Incompatible types in assignment (expression has type "list[dict[str, str]]", variable has type "str | list[str | dict[Any, Any]]")

I think the only alternative would be to explicitly cast to the expected type (which I prefer to avoid if possible).

formatted_message.content = formatted_message_content
else:
formatted_message.content = (
f"<name>{message.name}</name><content>{formatted_message.content}</content>"
Expand All @@ -62,8 +63,7 @@ def remove_inline_agent_name(message: BaseMessage) -> BaseMessage:
if not isinstance(message, AIMessage) or not message.content:
return message

is_content_blocks_content = _is_content_blocks_content(message.content)
if is_content_blocks_content:
if is_content_blocks_content := _is_content_blocks_content(message.content):
text_blocks = [block for block in message.content if block["type"] == "text"]
if not text_blocks:
return message
Expand All @@ -85,7 +85,7 @@ def remove_inline_agent_name(message: BaseMessage) -> BaseMessage:
if parsed_content:
content_blocks = [{"type": "text", "text": parsed_content}] + content_blocks

parsed_message.content = content_blocks
parsed_message.content = cast(list[str | dict], content_blocks)
else:
parsed_message.content = parsed_content
return parsed_message
Expand Down Expand Up @@ -120,9 +120,10 @@ def with_agent_name(
def process_input_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
return [process_input_message(message) for message in messages]

model = (
chain = (
process_input_messages
| model
| RunnableLambda(process_output_message, name="process_output_message")
)
return model

return cast(LanguageModelLike, chain)
13 changes: 9 additions & 4 deletions langgraph_supervisor/handoff.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re
import uuid
from typing import cast
from typing import TypeGuard, cast

from langchain_core.messages import AIMessage, ToolCall, ToolMessage
from langchain_core.tools import BaseTool, InjectedToolCallId, tool
Expand All @@ -18,6 +18,11 @@ def _normalize_agent_name(agent_name: str) -> str:
return WHITESPACE_RE.sub("_", agent_name.strip()).lower()


def _has_multiple_content_blocks(content: str | list[str | dict]) -> TypeGuard[list[dict]]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it should have a list of `str`.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above – it's due to the type of BaseMessage.content.

"""Check if content contains multiple content blocks."""
return isinstance(content, list) and len(content) > 1 and isinstance(content[0], dict)


def _remove_non_handoff_tool_calls(
last_ai_message: AIMessage, handoff_tool_call_id: str
) -> AIMessage:
Expand All @@ -26,7 +31,7 @@ def _remove_non_handoff_tool_calls(
# we need to remove tool calls that are not meant for this agent
# to ensure that the resulting message history is valid
content = last_ai_message.content
if isinstance(content, list) and len(content) > 1 and isinstance(content[0], dict):
if _has_multiple_content_blocks(content):
content = [
content_block
for content_block in content
Expand Down Expand Up @@ -80,7 +85,7 @@ def create_handoff_tool(
def handoff_to_agent(
state: Annotated[dict, InjectedState],
tool_call_id: Annotated[str, InjectedToolCallId],
):
) -> Command:
tool_message = ToolMessage(
content=f"Successfully transferred to {agent_name}",
name=name,
Expand Down Expand Up @@ -166,7 +171,7 @@ def create_forward_message_tool(supervisor_name: str = "supervisor") -> BaseTool
def forward_message(
from_agent: str,
state: Annotated[dict, InjectedState],
):
) -> str | Command:
target_message = next(
(
m
Expand Down
20 changes: 11 additions & 9 deletions langgraph_supervisor/supervisor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import inspect
from typing import Any, Callable, Literal, Optional, Type, Union
from typing import Any, Callable, Literal, Optional, Type, Union, cast, get_args

from langchain_core.language_models import BaseChatModel, LanguageModelLike
from langchain_core.tools import BaseTool
Expand Down Expand Up @@ -55,9 +55,9 @@ def _make_call_agent(
add_handoff_back_messages: bool,
supervisor_name: str,
) -> Callable[[dict], dict] | RunnableCallable:
if output_mode not in OutputMode.__args__:
if output_mode not in get_args(OutputMode):
raise ValueError(
f"Invalid agent output mode: {output_mode}. Needs to be one of {OutputMode.__args__}"
f"Invalid agent output mode: {output_mode}. Needs to be one of {get_args(OutputMode)}"
)

def _process_output(output: dict) -> dict:
Expand Down Expand Up @@ -255,9 +255,9 @@ def web_search(query: str) -> str:

agent_names.add(agent.name)

handoff_destinations = _get_handoff_destinations(tools or [])
if handoff_destinations:
if missing_handoff_destinations := set(agent_names) - set(handoff_destinations):
extracted_handoff_destinations = _get_handoff_destinations(tools or [])
if extracted_handoff_destinations:
if missing_handoff_destinations := set(agent_names) - set(extracted_handoff_destinations):
raise ValueError(
"When providing custom handoff tools, you must provide them for all subagents. "
f"Missing handoff tools for agents '{missing_handoff_destinations}'."
Expand All @@ -278,12 +278,14 @@ def web_search(query: str) -> str:
)
for agent in agents
]
all_tools = (tools or []) + handoff_destinations
all_tools = (tools or []) + list(handoff_destinations)

if _supports_disable_parallel_tool_calls(model):
model = model.bind_tools(all_tools, parallel_tool_calls=parallel_tool_calls)
model = cast(BaseChatModel, model).bind_tools(
all_tools, parallel_tool_calls=parallel_tool_calls
)
else:
model = model.bind_tools(all_tools)
model = cast(BaseChatModel, model).bind_tools(all_tools)

if include_agent_name:
model = with_agent_name(model, include_agent_name)
Expand Down
14 changes: 7 additions & 7 deletions tests/test_agent_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
)


def test_add_inline_agent_name():
def test_add_inline_agent_name() -> None:
# Test that non-AI messages are returned unchanged.
human_message = HumanMessage(content="Hello")
result = add_inline_agent_name(human_message)
Expand All @@ -24,8 +24,8 @@ def test_add_inline_agent_name():
assert result.name == "assistant"


def test_add_inline_agent_name_content_blocks():
content_blocks = [
def test_add_inline_agent_name_content_blocks() -> None:
content_blocks: list[str | dict] = [
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be just a list of `dict`.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

content_blocks is passed as the content parameter to AIMessage below, and that expects list[str | dict]. Because list is invariant, mypy won't accept a list[str] passed to a list[str | dict].

Perhaps it would be better for AIMessage to expect list[str] | list[dict] (or even Sequence[str | dict], since Sequence is covariant) to avoid needing to do this, but that would require a change in that library. Happy to do that if you think that's preferable.

{"type": "text", "text": "Hello world"},
{"type": "image", "image_url": "http://example.com/image.jpg"},
]
Expand All @@ -51,7 +51,7 @@ def test_add_inline_agent_name_content_blocks():
assert result.content == expected_content_blocks


def test_remove_inline_agent_name():
def test_remove_inline_agent_name() -> None:
# Test that non-AI messages are returned unchanged.
human_message = HumanMessage(content="Hello")
result = remove_inline_agent_name(human_message)
Expand All @@ -76,8 +76,8 @@ def test_remove_inline_agent_name():
assert result.name == "assistant"


def test_remove_inline_agent_name_content_blocks():
content_blocks = [
def test_remove_inline_agent_name_content_blocks() -> None:
content_blocks: list[str | dict] = [
{"type": "text", "text": "<name>assistant</name><content>Hello world</content>"},
{"type": "image", "image_url": "http://example.com/image.jpg"},
]
Expand All @@ -103,7 +103,7 @@ def test_remove_inline_agent_name_content_blocks():
assert result.content == expected_content_blocks


def test_remove_inline_agent_name_multiline_content():
def test_remove_inline_agent_name_multiline_content() -> None:
multiline_content = """<name>assistant</name><content>This is
a multiline
message</content>"""
Expand Down
46 changes: 27 additions & 19 deletions tests/test_supervisor.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Tests for the supervisor module."""

from typing import Callable, Optional
from collections.abc import Callable, Sequence
from typing import Any, Optional, cast

import pytest
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.chat_models import BaseChatModel, LanguageModelInput
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool, tool
from langgraph.prebuilt import create_react_agent

Expand All @@ -17,7 +19,7 @@

class FakeChatModel(BaseChatModel):
idx: int = 0
responses: list[BaseMessage]
responses: Sequence[BaseMessage]

@property
def _llm_type(self) -> str:
Expand All @@ -28,16 +30,18 @@ def _generate(
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs,
**kwargs: dict[str, Any],
) -> ChatResult:
generation = ChatGeneration(message=self.responses[self.idx])
self.idx += 1
return ChatResult(generations=[generation])

def bind_tools(self, tools: list[BaseTool]) -> "FakeChatModel":
def bind_tools(
self, tools: Sequence[dict[str, Any] | type | Callable | BaseTool], **kwargs: Any
) -> Runnable[LanguageModelInput, BaseMessage]:
tool_dicts = [
{
"name": tool.name,
"name": tool.name if isinstance(tool, BaseTool) else str(tool),
}
for tool in tools
]
Expand Down Expand Up @@ -180,9 +184,12 @@ def web_search(query: str) -> str:
"5. **Google (Alphabet)**: 181,269 employees."
)

math_model = FakeChatModel(responses=math_agent_messages)
math_model: FakeChatModel = FakeChatModel(responses=math_agent_messages)
if include_individual_agent_name:
math_model = with_agent_name(math_model.bind_tools([add]), include_individual_agent_name)
math_model = cast(
FakeChatModel,
with_agent_name(math_model.bind_tools([add]), include_individual_agent_name),
)

math_agent = create_react_agent(
model=math_model,
Expand All @@ -192,8 +199,9 @@ def web_search(query: str) -> str:

research_model = FakeChatModel(responses=research_agent_messages)
if include_individual_agent_name:
research_model = with_agent_name(
research_model.bind_tools([web_search]), include_individual_agent_name
research_model = cast(
FakeChatModel,
with_agent_name(research_model.bind_tools([web_search]), include_individual_agent_name),
)

research_agent = create_react_agent(
Expand Down Expand Up @@ -286,13 +294,13 @@ def _generate(
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs,
**kwargs: dict[str, Any],
) -> ChatResult:
self.assertion(messages)
return super()._generate(messages, stop, run_manager, **kwargs)


def get_tool_calls(msg):
def get_tool_calls(msg: BaseMessage) -> list[dict[str, Any]] | None:
tool_calls = getattr(msg, "tool_calls", None)
if tool_calls is None:
return None
Expand All @@ -301,7 +309,7 @@ def get_tool_calls(msg):
]


def as_dict(msg):
def as_dict(msg: BaseMessage) -> dict[str, Any]:
return {
"name": msg.name,
"content": msg.content,
Expand All @@ -311,24 +319,24 @@ def as_dict(msg):


class Expectations:
def __init__(self, expected: list[dict]):
def __init__(self, expected: list[list[dict[str, Any]]]) -> None:
self.expected = expected.copy()

def __call__(self, messages: list[BaseMessage]):
def __call__(self, messages: list[BaseMessage]) -> None:
expected = self.expected.pop(0)
received = [as_dict(m) for m in messages]
assert expected == received


def test_worker_hide_handoffs():
def test_worker_hide_handoffs() -> None:
"""Test that the supervisor forwards a message to a specific agent and receives the correct response."""

@tool
def echo_tool(text: str) -> str:
"""Echo the input text."""
return text

expectations = [
expectations: list[list[dict[str, Any]]] = [
[
{
"name": None,
Expand Down Expand Up @@ -415,7 +423,7 @@ def echo_tool(text: str) -> str:
app.invoke({"messages": result["messages"] + [HumanMessage(content="Huh take two?")]})


def test_supervisor_message_forwarding():
def test_supervisor_message_forwarding() -> None:
"""Test that the supervisor forwards a message to a specific agent and receives the correct response."""

@tool
Expand Down Expand Up @@ -470,7 +478,7 @@ def echo_tool(text: str) -> str:

result = app.invoke({"messages": [HumanMessage(content="Scooby-dooby-doo")]})

def get_tool_calls(msg):
def get_tool_calls(msg: BaseMessage) -> list[dict[str, Any]] | None:
tool_calls = getattr(msg, "tool_calls", None)
if tool_calls is None:
return None
Expand Down
8 changes: 4 additions & 4 deletions tests/test_supervisor_functional_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


class FakeModel(GenericFakeChatModel):
def bind_tools(self, *args, **kwargs) -> "FakeModel":
def bind_tools(self, *args: tuple, **kwargs: Any) -> "FakeModel":
"""Do nothing for now."""
return self

Expand All @@ -24,15 +24,15 @@ def test_supervisor_functional_workflow() -> None:

# Create a joke agent using functional API
@task
def generate_joke(messages: List[BaseMessage]) -> AIMessage:
def generate_joke(messages: List[BaseMessage]) -> BaseMessage:
"""Generate a joke using the model."""
return model.invoke([SystemMessage(content="Write a short joke")] + messages)
return model.invoke([SystemMessage(content="Write a short joke")] + list(messages))

@entrypoint()
def joke_agent(state: Dict[str, Any]) -> Dict[str, Any]:
"""Joke agent entrypoint."""
joke = generate_joke(state["messages"]).result()
messages = add_messages(state["messages"], [joke])
messages = add_messages(state["messages"], joke)
return {"messages": messages}

# Set agent name
Expand Down
Loading