Skip to content

Commit 922461e

Browse files
author
Hoa Lam
committed
Add unittest sse disconnect
1 parent 7b6a903 commit 922461e

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

tests/server/test_sse_disconnect.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import asyncio
2+
from uuid import UUID
3+
4+
import pytest
5+
from starlette.types import Message, Scope
6+
7+
from mcp.server.sse import SseServerTransport
8+
9+
10+
@pytest.mark.anyio
11+
async def test_sse_disconnect_handle():
12+
transport = SseServerTransport(endpoint="/sse")
13+
# Create a minimal ASGI scope for an HTTP GET request
14+
scope: Scope = {
15+
"type": "http",
16+
"method": "GET",
17+
"path": "/sse",
18+
"headers": [],
19+
}
20+
send_disconnect = False
21+
22+
# Dummy receive and send functions
23+
async def receive() -> dict:
24+
nonlocal send_disconnect
25+
if not send_disconnect:
26+
send_disconnect = True
27+
return {"type": "http.request"}
28+
else:
29+
return {"type": "http.disconnect"}
30+
31+
async def send(message: Message) -> None:
32+
await asyncio.sleep(0)
33+
34+
# Run the connect_sse context manager
35+
async with transport.connect_sse(scope, receive, send) as (
36+
read_stream,
37+
write_stream,
38+
):
39+
# Assert that streams are provided
40+
assert read_stream is not None
41+
assert write_stream is not None
42+
43+
# There should be exactly one session
44+
assert len(transport._read_stream_writers) == 1
45+
# Check that the session key is a UUID
46+
session_id = next(iter(transport._read_stream_writers.keys()))
47+
assert isinstance(session_id, UUID)
48+
49+
# Check that the writer is still open
50+
writer = transport._read_stream_writers[session_id]
51+
assert writer is not None
52+
53+
# After context exits, session should be cleaned up
54+
assert len(transport._read_stream_writers) == 0

0 commit comments

Comments
 (0)