3
3
from contextlib import AsyncExitStack
4
4
from datetime import timedelta
5
5
from types import TracebackType
6
- from typing import Any , Generic , TypeVar
6
+ from typing import Any , Generic , Protocol , TypeVar
7
7
8
8
import anyio
9
9
import httpx
24
24
JSONRPCNotification ,
25
25
JSONRPCRequest ,
26
26
JSONRPCResponse ,
27
+ ProgressNotification ,
27
28
RequestParams ,
28
29
ServerNotification ,
29
30
ServerRequest ,
42
43
RequestId = str | int
43
44
44
45
46
+ class ProgressFnT (Protocol ):
47
+ """Protocol for progress notification callbacks."""
48
+
49
+ async def __call__ (
50
+ self , progress : float , total : float | None , message : str | None
51
+ ) -> None : ...
52
+
53
+
45
54
class RequestResponder (Generic [ReceiveRequestT , SendResultT ]):
46
55
"""Handles responding to MCP requests and manages request lifecycle.
47
56
@@ -169,6 +178,7 @@ class BaseSession(
169
178
]
170
179
_request_id : int
171
180
_in_flight : dict [RequestId , RequestResponder [ReceiveRequestT , SendResultT ]]
181
+ _progress_callbacks : dict [RequestId , ProgressFnT ]
172
182
173
183
def __init__ (
174
184
self ,
@@ -187,6 +197,7 @@ def __init__(
187
197
self ._receive_notification_type = receive_notification_type
188
198
self ._session_read_timeout_seconds = read_timeout_seconds
189
199
self ._in_flight = {}
200
+ self ._progress_callbacks = {}
190
201
self ._exit_stack = AsyncExitStack ()
191
202
192
203
async def __aenter__ (self ) -> Self :
@@ -214,6 +225,7 @@ async def send_request(
214
225
result_type : type [ReceiveResultT ],
215
226
request_read_timeout_seconds : timedelta | None = None ,
216
227
metadata : MessageMetadata = None ,
228
+ progress_callback : ProgressFnT | None = None ,
217
229
) -> ReceiveResultT :
218
230
"""
219
231
Sends a request and wait for a response. Raises an McpError if the
@@ -231,15 +243,25 @@ async def send_request(
231
243
](1 )
232
244
self ._response_streams [request_id ] = response_stream
233
245
246
+ # Set up progress token if progress callback is provided
247
+ request_data = request .model_dump (by_alias = True , mode = "json" , exclude_none = True )
248
+ if progress_callback is not None :
249
+ # Use request_id as progress token
250
+ if "params" not in request_data :
251
+ request_data ["params" ] = {}
252
+ if "_meta" not in request_data ["params" ]:
253
+ request_data ["params" ]["_meta" ] = {}
254
+ request_data ["params" ]["_meta" ]["progressToken" ] = request_id
255
+ # Store the callback for this request
256
+ self ._progress_callbacks [request_id ] = progress_callback
257
+
234
258
try :
235
259
jsonrpc_request = JSONRPCRequest (
236
260
jsonrpc = "2.0" ,
237
261
id = request_id ,
238
- ** request . model_dump ( by_alias = True , mode = "json" , exclude_none = True ) ,
262
+ ** request_data ,
239
263
)
240
264
241
- # TODO: Support progress callbacks
242
-
243
265
await self ._write_stream .send (
244
266
SessionMessage (
245
267
message = JSONRPCMessage (jsonrpc_request ), metadata = metadata
@@ -275,6 +297,7 @@ async def send_request(
275
297
276
298
finally :
277
299
self ._response_streams .pop (request_id , None )
300
+ self ._progress_callbacks .pop (request_id , None )
278
301
await response_stream .aclose ()
279
302
await response_stream_reader .aclose ()
280
303
@@ -333,7 +356,6 @@ async def _receive_loop(self) -> None:
333
356
by_alias = True , mode = "json" , exclude_none = True
334
357
)
335
358
)
336
-
337
359
responder = RequestResponder (
338
360
request_id = message .message .root .id ,
339
361
request_meta = validated_request .root .params .meta
@@ -363,6 +385,18 @@ async def _receive_loop(self) -> None:
363
385
if cancelled_id in self ._in_flight :
364
386
await self ._in_flight [cancelled_id ].cancel ()
365
387
else :
388
+ # Handle progress notifications callback
389
+ if isinstance (notification .root , ProgressNotification ):
390
+ progress_token = notification .root .params .progressToken
391
+ # If there is a progress callback for this token,
392
+ # call it with the progress information
393
+ if progress_token in self ._progress_callbacks :
394
+ callback = self ._progress_callbacks [progress_token ]
395
+ await callback (
396
+ notification .root .params .progress ,
397
+ notification .root .params .total ,
398
+ notification .root .params .message ,
399
+ )
366
400
await self ._received_notification (notification )
367
401
await self ._handle_incoming (notification )
368
402
except Exception as e :
0 commit comments