From 922461eb668760dc51ce9f5deaafc23f25550e9f Mon Sep 17 00:00:00 2001 From: Hoa Lam Date: Sat, 26 Apr 2025 19:58:42 +0700 Subject: [PATCH 1/2] Add unittest sse disconnect --- tests/server/test_sse_disconnect.py | 54 +++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 tests/server/test_sse_disconnect.py diff --git a/tests/server/test_sse_disconnect.py b/tests/server/test_sse_disconnect.py new file mode 100644 index 000000000..3dc6e0055 --- /dev/null +++ b/tests/server/test_sse_disconnect.py @@ -0,0 +1,54 @@ +import asyncio +from uuid import UUID + +import pytest +from starlette.types import Message, Scope + +from mcp.server.sse import SseServerTransport + + +@pytest.mark.anyio +async def test_sse_disconnect_handle(): + transport = SseServerTransport(endpoint="/sse") + # Create a minimal ASGI scope for an HTTP GET request + scope: Scope = { + "type": "http", + "method": "GET", + "path": "/sse", + "headers": [], + } + send_disconnect = False + + # Dummy receive and send functions + async def receive() -> dict: + nonlocal send_disconnect + if not send_disconnect: + send_disconnect = True + return {"type": "http.request"} + else: + return {"type": "http.disconnect"} + + async def send(message: Message) -> None: + await asyncio.sleep(0) + + # Run the connect_sse context manager + async with transport.connect_sse(scope, receive, send) as ( + read_stream, + write_stream, + ): + # Assert that streams are provided + assert read_stream is not None + assert write_stream is not None + + # There should be exactly one session + assert len(transport._read_stream_writers) == 1 + # Check that the session key is a UUID + session_id = next(iter(transport._read_stream_writers.keys())) + assert isinstance(session_id, UUID) + + # Check that the writer is still open + writer = transport._read_stream_writers[session_id] + assert writer is not None + + # After context exits, session should be cleaned up + assert len(transport._read_stream_writers) == 0 From 7de20f550d58b0c3d2b2345718a7326ebff5f3e5 Mon Sep 17 00:00:00 2001 From: Hoa Lam Date: Tue, 13 May 2025 17:27:47 +0700 Subject: [PATCH 2/2] Fix clean up memory when sse disconnect --- src/mcp/server/sse.py | 12 ++++++++++-- tests/server/test_sse_disconnect.py | 5 ++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index a6350a39b..f03f2a7ce 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -145,7 +145,12 @@ async def sse_writer(): async with anyio.create_task_group() as tg: - async def response_wrapper(scope: Scope, receive: Receive, send: Send): + async def response_wrapper( + scope: Scope, + receive: Receive, + send: Send, + transport: SseServerTransport, + ): """ The EventSourceResponse returning signals a client close / disconnect. In this case we close our side of the streams to signal the client that @@ -156,10 +161,13 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send): )(scope, receive, send) await read_stream_writer.aclose() await write_stream_reader.aclose() + await read_stream.aclose() + await write_stream.aclose() + transport._read_stream_writers.pop(session_id) logging.debug(f"Client session disconnected {session_id}") logger.debug("Starting SSE response task") - tg.start_soon(response_wrapper, scope, receive, send) + tg.start_soon(response_wrapper, scope, receive, send, self) logger.debug("Yielding read and write streams") yield (read_stream, write_stream) diff --git a/tests/server/test_sse_disconnect.py b/tests/server/test_sse_disconnect.py index 3dc6e0055..81a7592b5 100644 --- a/tests/server/test_sse_disconnect.py +++ b/tests/server/test_sse_disconnect.py @@ -46,9 +46,8 @@ async def send(message: Message) -> None: session_id = next(iter(transport._read_stream_writers.keys())) assert isinstance(session_id, UUID) - # Check that the writer is still open - writer = transport._read_stream_writers[session_id] - assert writer is not None + # Check that the session_id should be clean up + assert session_id not in transport._read_stream_writers # After context exits, session should be cleaned up assert len(transport._read_stream_writers) == 0