|
| 1 | +import asyncio |
| 2 | +from uuid import UUID |
| 3 | + |
| 4 | +import pytest |
| 5 | +from starlette.types import Message, Scope |
| 6 | + |
| 7 | +from mcp.server.sse import SseServerTransport |
| 8 | + |
| 9 | + |
| 10 | +@pytest.mark.anyio |
| 11 | +async def test_sse_disconnect_handle(): |
| 12 | + transport = SseServerTransport(endpoint="/sse") |
| 13 | + # Create a minimal ASGI scope for an HTTP GET request |
| 14 | + scope: Scope = { |
| 15 | + "type": "http", |
| 16 | + "method": "GET", |
| 17 | + "path": "/sse", |
| 18 | + "headers": [], |
| 19 | + } |
| 20 | + send_disconnect = False |
| 21 | + |
| 22 | + # Dummy receive and send functions |
| 23 | + async def receive() -> dict: |
| 24 | + nonlocal send_disconnect |
| 25 | + if not send_disconnect: |
| 26 | + send_disconnect = True |
| 27 | + return {"type": "http.request"} |
| 28 | + else: |
| 29 | + return {"type": "http.disconnect"} |
| 30 | + |
| 31 | + async def send(message: Message) -> None: |
| 32 | + await asyncio.sleep(0) |
| 33 | + |
| 34 | + # Run the connect_sse context manager |
| 35 | + async with transport.connect_sse(scope, receive, send) as ( |
| 36 | + read_stream, |
| 37 | + write_stream, |
| 38 | + ): |
| 39 | + # Assert that streams are provided |
| 40 | + assert read_stream is not None |
| 41 | + assert write_stream is not None |
| 42 | + |
| 43 | + # There should be exactly one session |
| 44 | + assert len(transport._read_stream_writers) == 1 |
| 45 | + # Check that the session key is a UUID |
| 46 | + session_id = next(iter(transport._read_stream_writers.keys())) |
| 47 | + assert isinstance(session_id, UUID) |
| 48 | + |
| 49 | + # Check that the writer is still open |
| 50 | + writer = transport._read_stream_writers[session_id] |
| 51 | + assert writer is not None |
| 52 | + |
| 53 | + # After context exits, session should be cleaned up |
| 54 | + assert len(transport._read_stream_writers) == 0 |
0 commit comments