diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 891eff0a6..4ff8d1391 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -344,7 +344,7 @@ async def list_prompts(self, cursor: str | None = None) -> types.ListPromptsResu types.ListPromptsResult, ) - async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult: + async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> types.GetPromptResult: """Send a prompts/get request.""" return await self.send_request( types.ClientRequest( diff --git a/src/mcp/types.py b/src/mcp/types.py index 4a9c2bf1a..83e811e38 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -639,7 +639,7 @@ class GetPromptRequestParams(RequestParams): name: str """The name of the prompt or prompt template.""" - arguments: dict[str, str] | None = None + arguments: dict[str, Any] | None = None """Arguments to use for templating the prompt.""" model_config = ConfigDict(extra="allow") diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 327d1a9e4..e577d435d 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -495,3 +495,72 @@ async def mock_server(): assert received_capabilities.roots is not None # Custom list_roots callback provided assert isinstance(received_capabilities.roots, types.RootsCapability) assert received_capabilities.roots.listChanged is True # Should be True for custom callback + + +@pytest.mark.anyio +async def test_get_prompt_with_non_string_argument(): + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + async def mock_server(): + await client_to_server_receive.receive() + + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=0, + result=InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + await client_to_server_receive.receive() + + # Receive get_prompt request + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, types.GetPromptRequest) + assert request.root.params.arguments == {"employee_id": 77} + + # Send get_prompt result + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=types.GetPromptResult( + messages=[ + types.PromptMessage(role="user", content=types.TextContent(type="text", text="...")) + ] + ).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + await session.initialize() + await session.get_prompt("get_employee_profile", {"employee_id": 77})