Skip to content

Commit 5d33861

Browse files
authored
Add progress notification callback for client (#721)
1 parent 1bdeed3 commit 5d33861

File tree

6 files changed

+609
-12
lines changed

6 files changed

+609
-12
lines changed

examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__main__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@
33
if __name__ == "__main__":
44
# Click will handle CLI arguments
55
import sys
6-
6+
77
sys.exit(main()) # type: ignore[call-arg]

src/mcp/client/session.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import mcp.types as types
99
from mcp.shared.context import RequestContext
1010
from mcp.shared.message import SessionMessage
11-
from mcp.shared.session import BaseSession, RequestResponder
11+
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
1212
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
1313

1414
DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
@@ -270,18 +270,23 @@ async def call_tool(
270270
name: str,
271271
arguments: dict[str, Any] | None = None,
272272
read_timeout_seconds: timedelta | None = None,
273+
progress_callback: ProgressFnT | None = None,
273274
) -> types.CallToolResult:
274-
"""Send a tools/call request."""
275+
"""Send a tools/call request with optional progress callback support."""
275276

276277
return await self.send_request(
277278
types.ClientRequest(
278279
types.CallToolRequest(
279280
method="tools/call",
280-
params=types.CallToolRequestParams(name=name, arguments=arguments),
281+
params=types.CallToolRequestParams(
282+
name=name,
283+
arguments=arguments,
284+
),
281285
)
282286
),
283287
types.CallToolResult,
284288
request_read_timeout_seconds=read_timeout_seconds,
289+
progress_callback=progress_callback,
285290
)
286291

287292
async def list_prompts(self, cursor: str | None = None) -> types.ListPromptsResult:

src/mcp/server/fastmcp/server.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -963,7 +963,6 @@ async def report_progress(
963963
total: Optional total value e.g. 100
964964
message: Optional message e.g. Starting render...
965965
"""
966-
967966
progress_token = (
968967
self.request_context.meta.progressToken
969968
if self.request_context.meta

src/mcp/shared/session.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from contextlib import AsyncExitStack
44
from datetime import timedelta
55
from types import TracebackType
6-
from typing import Any, Generic, TypeVar
6+
from typing import Any, Generic, Protocol, TypeVar
77

88
import anyio
99
import httpx
@@ -24,6 +24,7 @@
2424
JSONRPCNotification,
2525
JSONRPCRequest,
2626
JSONRPCResponse,
27+
ProgressNotification,
2728
RequestParams,
2829
ServerNotification,
2930
ServerRequest,
@@ -42,6 +43,14 @@
4243
RequestId = str | int
4344

4445

46+
class ProgressFnT(Protocol):
47+
"""Protocol for progress notification callbacks."""
48+
49+
async def __call__(
50+
self, progress: float, total: float | None, message: str | None
51+
) -> None: ...
52+
53+
4554
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
4655
"""Handles responding to MCP requests and manages request lifecycle.
4756
@@ -169,6 +178,7 @@ class BaseSession(
169178
]
170179
_request_id: int
171180
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
181+
_progress_callbacks: dict[RequestId, ProgressFnT]
172182

173183
def __init__(
174184
self,
@@ -187,6 +197,7 @@ def __init__(
187197
self._receive_notification_type = receive_notification_type
188198
self._session_read_timeout_seconds = read_timeout_seconds
189199
self._in_flight = {}
200+
self._progress_callbacks = {}
190201
self._exit_stack = AsyncExitStack()
191202

192203
async def __aenter__(self) -> Self:
@@ -214,6 +225,7 @@ async def send_request(
214225
result_type: type[ReceiveResultT],
215226
request_read_timeout_seconds: timedelta | None = None,
216227
metadata: MessageMetadata = None,
228+
progress_callback: ProgressFnT | None = None,
217229
) -> ReceiveResultT:
218230
"""
219231
Sends a request and wait for a response. Raises an McpError if the
@@ -231,15 +243,25 @@ async def send_request(
231243
](1)
232244
self._response_streams[request_id] = response_stream
233245

246+
# Set up progress token if progress callback is provided
247+
request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True)
248+
if progress_callback is not None:
249+
# Use request_id as progress token
250+
if "params" not in request_data:
251+
request_data["params"] = {}
252+
if "_meta" not in request_data["params"]:
253+
request_data["params"]["_meta"] = {}
254+
request_data["params"]["_meta"]["progressToken"] = request_id
255+
# Store the callback for this request
256+
self._progress_callbacks[request_id] = progress_callback
257+
234258
try:
235259
jsonrpc_request = JSONRPCRequest(
236260
jsonrpc="2.0",
237261
id=request_id,
238-
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
262+
**request_data,
239263
)
240264

241-
# TODO: Support progress callbacks
242-
243265
await self._write_stream.send(
244266
SessionMessage(
245267
message=JSONRPCMessage(jsonrpc_request), metadata=metadata
@@ -275,6 +297,7 @@ async def send_request(
275297

276298
finally:
277299
self._response_streams.pop(request_id, None)
300+
self._progress_callbacks.pop(request_id, None)
278301
await response_stream.aclose()
279302
await response_stream_reader.aclose()
280303

@@ -333,7 +356,6 @@ async def _receive_loop(self) -> None:
333356
by_alias=True, mode="json", exclude_none=True
334357
)
335358
)
336-
337359
responder = RequestResponder(
338360
request_id=message.message.root.id,
339361
request_meta=validated_request.root.params.meta
@@ -363,6 +385,18 @@ async def _receive_loop(self) -> None:
363385
if cancelled_id in self._in_flight:
364386
await self._in_flight[cancelled_id].cancel()
365387
else:
388+
# Handle progress notifications callback
389+
if isinstance(notification.root, ProgressNotification):
390+
progress_token = notification.root.params.progressToken
391+
# If there is a progress callback for this token,
392+
# call it with the progress information
393+
if progress_token in self._progress_callbacks:
394+
callback = self._progress_callbacks[progress_token]
395+
await callback(
396+
notification.root.params.progress,
397+
notification.root.params.total,
398+
notification.root.params.message,
399+
)
366400
await self._received_notification(notification)
367401
await self._handle_incoming(notification)
368402
except Exception as e:

tests/client/test_list_methods_cursor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
async def test_list_tools_cursor_parameter():
1313
"""Test that the cursor parameter is accepted for list_tools.
14-
14+
1515
Note: FastMCP doesn't currently implement pagination, so this test
1616
only verifies that the cursor parameter is accepted by the client.
1717
"""

0 commit comments

Comments
 (0)