Skip to content

[OpenAI Codex PR] Fix cleanup order issue with task groups #773

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/mcp/client/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import mcp.types as types
from mcp.shared._httpx_utils import create_mcp_http_client
from mcp.shared.message import SessionMessage
from mcp.shared.taskgroup import CompatTaskGroup

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -50,7 +51,7 @@ async def sse_client(
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

async with anyio.create_task_group() as tg:
async with CompatTaskGroup() as tg:
try:
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
async with create_mcp_http_client(headers=headers, auth=auth) as client:
Expand Down
3 changes: 2 additions & 1 deletion src/mcp/client/stdio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import mcp.types as types
from mcp.shared.message import SessionMessage
from mcp.shared.taskgroup import CompatTaskGroup

from .win32 import (
create_windows_process,
Expand Down Expand Up @@ -168,7 +169,7 @@ async def stdin_writer():
await anyio.lowlevel.checkpoint()

async with (
anyio.create_task_group() as tg,
CompatTaskGroup() as tg,
process,
):
tg.start_soon(stdout_reader)
Expand Down
6 changes: 3 additions & 3 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@

import anyio
import httpx
from anyio.abc import TaskGroup
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse

from mcp.shared._httpx_utils import create_mcp_http_client
from mcp.shared.message import ClientMessageMetadata, SessionMessage
from mcp.shared.taskgroup import CompatTaskGroup
from mcp.types import (
ErrorData,
JSONRPCError,
Expand Down Expand Up @@ -352,7 +352,7 @@ async def post_writer(
read_stream_writer: StreamWriter,
write_stream: MemoryObjectSendStream[SessionMessage],
start_get_stream: Callable[[], None],
tg: TaskGroup,
tg: CompatTaskGroup,
) -> None:
"""Handle writing requests to the server."""
try:
Expand Down Expand Up @@ -460,7 +460,7 @@ async def streamablehttp_client(
SessionMessage
](0)

async with anyio.create_task_group() as tg:
async with CompatTaskGroup() as tg:
try:
logger.info(f"Connecting to StreamableHTTP endpoint: {url}")

Expand Down
3 changes: 2 additions & 1 deletion src/mcp/client/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import mcp.types as types
from mcp.shared.message import SessionMessage
from mcp.shared.taskgroup import CompatTaskGroup

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -79,7 +80,7 @@ async def ws_writer():
)
await ws.send(json.dumps(msg_dict))

async with anyio.create_task_group() as tg:
async with CompatTaskGroup() as tg:
# Start reader and writer tasks
tg.start_soon(ws_reader)
tg.start_soon(ws_writer)
Expand Down
4 changes: 2 additions & 2 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ async def main():
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
from typing import Any, Generic, TypeVar

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl

Expand All @@ -87,6 +86,7 @@ async def main():
from mcp.shared.exceptions import McpError
from mcp.shared.message import SessionMessage
from mcp.shared.session import RequestResponder
from mcp.shared.taskgroup import CompatTaskGroup

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -503,7 +503,7 @@ async def run(
)
)

async with anyio.create_task_group() as tg:
async with CompatTaskGroup() as tg:
async for message in session.incoming_messages:
logger.debug(f"Received message: {message}")

Expand Down
3 changes: 2 additions & 1 deletion src/mcp/server/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ async def handle_sse(request):

import mcp.types as types
from mcp.shared.message import SessionMessage
from mcp.shared.taskgroup import CompatTaskGroup

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -143,7 +144,7 @@ async def sse_writer():
}
)

async with anyio.create_task_group() as tg:
async with CompatTaskGroup() as tg:

async def response_wrapper(scope: Scope, receive: Receive, send: Send):
"""
Expand Down
3 changes: 2 additions & 1 deletion src/mcp/server/stdio.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ async def run_server():

import mcp.types as types
from mcp.shared.message import SessionMessage
from mcp.shared.taskgroup import CompatTaskGroup


@asynccontextmanager
Expand Down Expand Up @@ -84,7 +85,7 @@ async def stdout_writer():
except anyio.ClosedResourceError:
await anyio.lowlevel.checkpoint()

async with anyio.create_task_group() as tg:
async with CompatTaskGroup() as tg:
tg.start_soon(stdin_reader)
tg.start_soon(stdout_writer)
yield read_stream, write_stream
5 changes: 3 additions & 2 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from starlette.types import Receive, Scope, Send

from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.taskgroup import CompatTaskGroup
from mcp.types import (
INTERNAL_ERROR,
INVALID_PARAMS,
Expand Down Expand Up @@ -508,7 +509,7 @@ async def sse_writer():
# Start the SSE response (this will send headers immediately)
try:
# First send the response to establish the SSE connection
async with anyio.create_task_group() as tg:
async with CompatTaskGroup() as tg:
tg.start_soon(response, scope, receive, send)
# Then send the message to be processed by the server
session_message = SessionMessage(message)
Expand Down Expand Up @@ -840,7 +841,7 @@ async def connect(
self._write_stream = write_stream

# Start a task group for message routing
async with anyio.create_task_group() as tg:
async with CompatTaskGroup() as tg:
# Create a message router that distributes messages to request streams
async def message_router():
try:
Expand Down
3 changes: 2 additions & 1 deletion src/mcp/server/streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
EventStore,
StreamableHTTPServerTransport,
)
from mcp.shared.taskgroup import CompatTaskGroup

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -103,7 +104,7 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]:
)
self._has_started = True

async with anyio.create_task_group() as tg:
async with CompatTaskGroup() as tg:
# Store the task group for later use
self._task_group = tg
logger.info("StreamableHTTP session manager started")
Expand Down
3 changes: 2 additions & 1 deletion src/mcp/server/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import mcp.types as types
from mcp.shared.message import SessionMessage
from mcp.shared.taskgroup import CompatTaskGroup

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -58,7 +59,7 @@ async def ws_writer():
except anyio.ClosedResourceError:
await websocket.close()

async with anyio.create_task_group() as tg:
async with CompatTaskGroup() as tg:
tg.start_soon(ws_reader)
tg.start_soon(ws_writer)
yield (read_stream, write_stream)
3 changes: 2 additions & 1 deletion src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from mcp.shared.exceptions import McpError
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage
from mcp.shared.taskgroup import CompatTaskGroup
from mcp.types import (
CancelledNotification,
ClientNotification,
Expand Down Expand Up @@ -201,7 +202,7 @@ def __init__(
self._exit_stack = AsyncExitStack()

async def __aenter__(self) -> Self:
self._task_group = anyio.create_task_group()
self._task_group = CompatTaskGroup()
await self._task_group.__aenter__()
self._task_group.start_soon(self._receive_loop)
return self
Expand Down
76 changes: 76 additions & 0 deletions src/mcp/shared/taskgroup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from __future__ import annotations

import asyncio
import sys
from collections.abc import Awaitable, Callable
from contextlib import AbstractAsyncContextManager
from typing import Any, TypeVar

import anyio

_T = TypeVar("_T")

class _AsyncioCancelScope:
def __init__(self, tasks: set[asyncio.Task[Any]]):
self._tasks = tasks

def cancel(self) -> None:
for task in list(self._tasks):
task.cancel()

class CompatTaskGroup(AbstractAsyncContextManager):
"""Minimal compatibility layer mimicking ``anyio.TaskGroup``."""

def __init__(self) -> None:
self._use_asyncio = sys.version_info >= (3, 11)
if self._use_asyncio:
self._tg = asyncio.TaskGroup()
self._tasks: set[asyncio.Task[Any]] = set()
self.cancel_scope = _AsyncioCancelScope(self._tasks)
else:
self._tg = anyio.create_task_group()
self.cancel_scope = self._tg.cancel_scope # type: ignore[attr-defined]

async def __aenter__(self) -> CompatTaskGroup:
await self._tg.__aenter__()
return self

async def __aexit__(self, exc_type, exc, tb) -> bool | None:
return await self._tg.__aexit__(exc_type, exc, tb)

def start_soon(
self,
func: Callable[..., Awaitable[Any]],
*args: Any,
name: Any | None = None,
) -> None:
if self._use_asyncio:
task = self._tg.create_task(func(*args))
self._tasks.add(task)
else:
self._tg.start_soon(func, *args, name=name)

async def start(
self,
func: Callable[..., Awaitable[Any]],
*args: Any,
name: Any | None = None,
) -> Any:
if self._use_asyncio:
fut: asyncio.Future[Any] = asyncio.get_running_loop().create_future()

async def runner() -> None:
try:
result = await func(*args, task_status=fut)
if not fut.done():
fut.set_result(result)
except BaseException as exc:
if not fut.done():
fut.set_exception(exc)
raise

task = self._tg.create_task(runner())
self._tasks.add(task)
return await fut
else:
return await self._tg.start(func, *args, name=name)