diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 90b4eb27c..017af2848 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -351,11 +351,21 @@ async def _receive_loop(self) -> None: if isinstance(message, Exception): await self._handle_incoming(message) elif isinstance(message.message.root, JSONRPCRequest): - validated_request = self._receive_request_type.model_validate( - message.message.root.model_dump( - by_alias=True, mode="json", exclude_none=True + try: + validated_request = self._receive_request_type.model_validate( + message.message.root.model_dump( + by_alias=True, mode="json", exclude_none=True + ) ) - ) + except Exception as e: + # For other validation errors, log and continue + logging.warning( + "Failed to validate request: %s. Message was: %s", + e, + message.message.root, + ) + continue + responder = RequestResponder( request_id=message.message.root.id, request_meta=validated_request.root.params.meta @@ -379,32 +389,40 @@ async def _receive_loop(self) -> None: by_alias=True, mode="json", exclude_none=True ) ) - # Handle cancellation notifications - if isinstance(notification.root, CancelledNotification): - cancelled_id = notification.root.params.requestId - if cancelled_id in self._in_flight: - await self._in_flight[cancelled_id].cancel() - else: - # Handle progress notifications callback - if isinstance(notification.root, ProgressNotification): - progress_token = notification.root.params.progressToken - # If there is a progress callback for this token, - # call it with the progress information - if progress_token in self._progress_callbacks: - callback = self._progress_callbacks[progress_token] + except Exception as e: + # For other validation errors, log and continue + logging.warning( + "Failed to validate notification: %s. Message was: %s", + e, + message.message.root, + ) + continue + # Handle cancellation notifications + if isinstance(notification.root, CancelledNotification): + cancelled_id = notification.root.params.requestId + if cancelled_id in self._in_flight: + await self._in_flight[cancelled_id].cancel() + else: + # Handle progress notifications callback + if isinstance(notification.root, ProgressNotification): + progress_token = notification.root.params.progressToken + # If there is a progress callback for this token, + # call it with the progress information + if progress_token in self._progress_callbacks: + callback = self._progress_callbacks[progress_token] + try: await callback( notification.root.params.progress, notification.root.params.total, notification.root.params.message, ) - await self._received_notification(notification) - await self._handle_incoming(notification) - except Exception as e: - # For other validation errors, log and continue - logging.warning( - f"Failed to validate notification: {e}. " - f"Message was: {message.message.root}" - ) + except Exception as e: + logging.warning( + "Progress callback raised an exception: %s", + e, + ) + await self._received_notification(notification) + await self._handle_incoming(notification) else: # Response or error stream = self._response_streams.pop(message.message.root.id, None) if stream: diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index 1e0409e14..6b9be5656 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -1,4 +1,5 @@ from typing import Any, cast +from unittest.mock import patch import anyio import pytest @@ -10,12 +11,16 @@ from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.shared.context import RequestContext +from mcp.shared.memory import create_connected_server_and_client_session from mcp.shared.progress import progress from mcp.shared.session import ( BaseSession, RequestResponder, SessionMessage, ) +from mcp.types import ( + TextContent, +) @pytest.mark.anyio @@ -347,3 +352,78 @@ async def handle_client_message( assert server_progress_updates[3]["progress"] == 100 assert server_progress_updates[3]["total"] == 100 assert server_progress_updates[3]["message"] == "Processing results..." + + +@pytest.mark.anyio +async def test_progress_callback_exception_logging(): + """Test that exceptions in progress callbacks are logged and \ + don't crash the session.""" + # Track logged warnings + logged_warnings = [] + + def mock_warning(msg, *args): + logged_warnings.append(msg % args if args else msg) + + # Create a progress callback that raises an exception + async def failing_progress_callback( + progress: float, total: float | None, message: str | None + ) -> None: + raise ValueError("Progress callback failed!") + + # Create a server with a tool that sends progress notifications + server = Server(name="TestProgressServer") + + @server.call_tool() + async def handle_call_tool( + name: str, arguments: dict | None + ) -> list[types.TextContent]: + if name == "progress_tool": + # Send a progress notification + await server.request_context.session.send_progress_notification( + progress_token=server.request_context.request_id, + progress=50.0, + total=100.0, + message="Halfway done", + ) + return [types.TextContent(type="text", text="progress_result")] + raise ValueError(f"Unknown tool: {name}") + + @server.list_tools() + async def handle_list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="progress_tool", + description="A tool that sends progress notifications", + inputSchema={}, + ) + ] + + # Test with mocked logging + with patch("mcp.shared.session.logging.warning", side_effect=mock_warning): + async with create_connected_server_and_client_session(server) as client_session: + # Send a request with a failing progress callback + result = await client_session.send_request( + types.ClientRequest( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams( + name="progress_tool", arguments={} + ), + ) + ), + types.CallToolResult, + progress_callback=failing_progress_callback, + ) + + # Verify the request completed successfully despite the callback failure + assert len(result.content) == 1 + content = result.content[0] + assert isinstance(content, TextContent) + assert content.text == "progress_result" + + # Check that a warning was logged for the progress callback exception + assert len(logged_warnings) > 0 + assert any( + "Progress callback raised an exception" in warning + for warning in logged_warnings + )