diff --git a/pyproject.toml b/pyproject.toml
index e1b175c76..cd4edc2e2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,6 +39,7 @@ dependencies = [
     "pyjwt[crypto]>=2.10.1",
     "typing-extensions>=4.9.0",
     "typing-inspection>=0.4.1",
+    "exceptiongroup>=1.0.0; python_version < '3.11'",
 ]
 
 [project.optional-dependencies]
diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py
index 4b0bbbc1e..8f50d7bd6 100644
--- a/src/mcp/client/sse.py
+++ b/src/mcp/client/sse.py
@@ -1,10 +1,16 @@
 import logging
+import sys
 from collections.abc import Callable
 from contextlib import asynccontextmanager
 from typing import Any
 from urllib.parse import parse_qs, urljoin, urlparse
 
 import anyio
+
+if sys.version_info >= (3, 11):
+    from builtins import BaseExceptionGroup  # pragma: no cover
+else:
+    from exceptiongroup import BaseExceptionGroup  # pragma: no cover
 import httpx
 from anyio.abc import TaskStatus
 from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -157,8 +163,19 @@ async def post_writer(endpoint_url: str):
 
                     try:
                         yield read_stream, write_stream
+                    # Suppress GeneratorExit to prevent "generator didn't stop after athrow()"
+                    # when client code exits the context manager during cancellation.
+                    # See https://github.com/python/cpython/issues/95571
+                    except GeneratorExit:
+                        pass
+                    # anyio wraps GeneratorExit in BaseExceptionGroup; extract and re-raise other exceptions
+                    except BaseExceptionGroup as eg:
+                        _, rest = eg.split(GeneratorExit)
+                        if rest:
+                            raise rest from None
                     finally:
                         tg.cancel_scope.cancel()
         finally:
             await read_stream_writer.aclose()
+            await read_stream.aclose()
             await write_stream.aclose()
diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py
index 22645d3ba..75e9ea603 100644
--- a/src/mcp/client/streamable_http.py
+++ b/src/mcp/client/streamable_http.py
@@ -8,6 +8,7 @@
 
 import contextlib
 import logging
+import sys
 from collections.abc import AsyncGenerator, Awaitable, Callable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass
@@ -16,6 +17,11 @@
 from warnings import warn
 
 import anyio
+
+if sys.version_info >= (3, 11):
+    from builtins import BaseExceptionGroup  # pragma: no cover
+else:
+    from exceptiongroup import BaseExceptionGroup  # pragma: no cover
 import httpx
 from anyio.abc import TaskGroup
 from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -672,12 +678,23 @@ def start_get_stream() -> None:
                         write_stream,
                         transport.get_session_id,
                     )
+                # Suppress GeneratorExit to prevent "generator didn't stop after athrow()"
+                # when client code exits the context manager during cancellation.
+                # See https://github.com/python/cpython/issues/95571
+                except GeneratorExit:
+                    pass
+                # anyio wraps GeneratorExit in BaseExceptionGroup; extract and re-raise other exceptions
+                except BaseExceptionGroup as eg:
+                    _, rest = eg.split(GeneratorExit)
+                    if rest:
+                        raise rest from None
                 finally:
                     if transport.session_id and terminate_on_close:
                         await transport.terminate_session(client)
                     tg.cancel_scope.cancel()
         finally:
             await read_stream_writer.aclose()
+            await read_stream.aclose()
             await write_stream.aclose()
diff --git a/tests/client/conftest.py b/tests/client/conftest.py
index 1e5c4d524..dcf744359 100644
--- a/tests/client/conftest.py
+++ b/tests/client/conftest.py
@@ -1,14 +1,86 @@
-from collections.abc import Callable, Generator
+import multiprocessing
+import socket
+from collections.abc import AsyncGenerator, Callable, Generator
 from contextlib import asynccontextmanager
 from typing import Any
 from unittest.mock import patch
 
 import pytest
+import uvicorn
 from anyio.streams.memory import MemoryObjectSendStream
+from starlette.applications import Starlette
+from starlette.requests import Request
+from starlette.responses import Response
+from starlette.routing import Mount, Route
 
 import mcp.shared.memory
+from mcp.server import Server
+from mcp.server.sse import SseServerTransport
+from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
 from mcp.shared.message import SessionMessage
 from mcp.types import JSONRPCNotification, JSONRPCRequest
+from tests.test_helpers import wait_for_server
+
+
+def run_server(port: int) -> None:  # pragma: no cover
+    """Run server with SSE and Streamable HTTP endpoints."""
+    server = Server(name="cleanup_test_server")
+    session_manager = StreamableHTTPSessionManager(app=server, json_response=False)
+    sse_transport = SseServerTransport("/messages/")
+
+    async def handle_sse(request: Request) -> Response:
+        async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams:
+            if streams:
+                await server.run(streams[0], streams[1], server.create_initialization_options())
+        return Response()
+
+    @asynccontextmanager
+    async def lifespan(app: Starlette) -> AsyncGenerator[None, None]:
+        async with session_manager.run():
+            yield
+
+    app = Starlette(
+        routes=[
+            Route("/sse", endpoint=handle_sse),
+            Mount("/messages/", app=sse_transport.handle_post_message),
+            Mount("/mcp", app=session_manager.handle_request),
+        ],
+        lifespan=lifespan,
+    )
+    uvicorn.Server(uvicorn.Config(app, host="127.0.0.1", port=port, log_level="error")).run()
+
+
+@pytest.fixture
+def server_port() -> int:
+    with socket.socket() as s:
+        s.bind(("127.0.0.1", 0))
+        return s.getsockname()[1]
+
+
+@pytest.fixture
+def test_server(server_port: int) -> Generator[str, None, None]:
+    """Start server with SSE and Streamable HTTP endpoints."""
+    proc = multiprocessing.Process(target=run_server, kwargs={"port": server_port}, daemon=True)
+    proc.start()
+    wait_for_server(server_port)
+    try:
+        yield f"http://127.0.0.1:{server_port}"
+    finally:
+        proc.terminate()
+        proc.join(timeout=2)
+        if proc.is_alive():  # pragma: no cover
+            proc.kill()
+            proc.join(timeout=1)
+
+
+@pytest.fixture
+def sse_server_url(test_server: str) -> str:
+    return f"{test_server}/sse"
+
+
+@pytest.fixture
+def streamable_server_url(test_server: str) -> str:
+    return f"{test_server}/mcp"
 
 
 class SpyMemoryObjectSendStream:
diff --git a/tests/client/test_resource_cleanup.py b/tests/client/test_resource_cleanup.py
index cc6c5059f..78f45353e 100644
--- a/tests/client/test_resource_cleanup.py
+++ b/tests/client/test_resource_cleanup.py
@@ -1,63 +1,117 @@
+import sys
+from collections.abc import Callable
 from typing import Any
+
+if sys.version_info >= (3, 11):
+    from builtins import BaseExceptionGroup  # pragma: no cover
+else:
+    from exceptiongroup import BaseExceptionGroup  # pragma: no cover
+
 from unittest.mock import patch
 
 import anyio
 import pytest
+from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
 
+from mcp.client.sse import sse_client
+from mcp.client.streamable_http import streamable_http_client
 from mcp.shared.message import SessionMessage
 from mcp.shared.session import BaseSession, RequestId, SendResultT
 from mcp.types import ClientNotification, ClientRequest, ClientResult, EmptyResult, ErrorData, PingRequest
 
+ClientTransport = tuple[
+    str,
+    Callable[..., Any],
+    Callable[[Any], tuple[MemoryObjectReceiveStream[Any], MemoryObjectSendStream[Any]]],
+]
+
 
 @pytest.mark.anyio
 async def test_send_request_stream_cleanup():
-    """
-    Test that send_request properly cleans up streams when an exception occurs.
+    """Test that send_request properly cleans up streams when an exception occurs."""
 
-    This test mocks out most of the session functionality to focus on stream cleanup.
-    """
-
-    # Create a mock session with the minimal required functionality
     class TestSession(BaseSession[ClientRequest, ClientNotification, ClientResult, Any, Any]):
         async def _send_response(
             self, request_id: RequestId, response: SendResultT | ErrorData
         ) -> None:  # pragma: no cover
             pass
 
-    # Create streams
     write_stream_send, write_stream_receive = anyio.create_memory_object_stream[SessionMessage](1)
     read_stream_send, read_stream_receive = anyio.create_memory_object_stream[SessionMessage](1)
 
-    # Create the session
     session = TestSession(
         read_stream_receive,
         write_stream_send,
-        object,  # Request type doesn't matter for this test
-        object,  # Notification type doesn't matter for this test
+        object,
+        object,
     )
 
-    # Create a test request
     request = ClientRequest(PingRequest())
 
-    # Patch the _write_stream.send method to raise an exception
     async def mock_send(*args: Any, **kwargs: Any):
         raise RuntimeError("Simulated network error")
 
-    # Record the response streams before the test
     initial_stream_count = len(session._response_streams)
 
-    # Run the test with the patched method
     with patch.object(session._write_stream, "send", mock_send):
         with pytest.raises(RuntimeError):
             await session.send_request(request, EmptyResult)
 
-    # Verify that no response streams were leaked
-    assert len(session._response_streams) == initial_stream_count, (
-        f"Expected {initial_stream_count} response streams after request, but found {len(session._response_streams)}"
-    )
+    assert len(session._response_streams) == initial_stream_count
 
-    # Clean up
     await write_stream_send.aclose()
     await write_stream_receive.aclose()
     await read_stream_send.aclose()
     await read_stream_receive.aclose()
+
+
+@pytest.fixture(params=["sse", "streamable"])
+def client_transport(
+    request: pytest.FixtureRequest, sse_server_url: str, streamable_server_url: str
+) -> ClientTransport:
+    if request.param == "sse":
+        return (sse_server_url, sse_client, lambda x: (x[0], x[1]))
+    else:
+        return (streamable_server_url, streamable_http_client, lambda x: (x[0], x[1]))
+
+
+@pytest.mark.anyio
+async def test_generator_exit_on_gc_cleanup(client_transport: ClientTransport) -> None:
+    """Suppress GeneratorExit from aclose() during GC cleanup (python/cpython#95571)."""
+    url, client_func, unpack = client_transport
+    cm = client_func(url)
+    result = await cm.__aenter__()
+    read_stream, write_stream = unpack(result)
+    await cm.gen.aclose()
+    await read_stream.aclose()
+    await write_stream.aclose()
+
+
+@pytest.mark.anyio
+async def test_generator_exit_in_exception_group(client_transport: ClientTransport) -> None:
+    """Extract GeneratorExit from BaseExceptionGroup (python/cpython#135736)."""
+    url, client_func, unpack = client_transport
+    async with client_func(url) as result:
+        unpack(result)
+        raise BaseExceptionGroup("unhandled errors in a TaskGroup", [GeneratorExit()])
+
+
+@pytest.mark.anyio
+async def test_generator_exit_mixed_group(client_transport: ClientTransport) -> None:
+    """Extract GeneratorExit from BaseExceptionGroup, re-raise other exceptions (python/cpython#135736)."""
+    url, client_func, unpack = client_transport
+    with pytest.raises(BaseExceptionGroup) as exc_info:
+        async with client_func(url) as result:
+            unpack(result)
+            raise BaseExceptionGroup("errors", [GeneratorExit(), ValueError("real error")])
+
+    def has_generator_exit(eg: BaseExceptionGroup[Any]) -> bool:
+        for e in eg.exceptions:
+            if isinstance(e, GeneratorExit):
+                return True  # pragma: no cover
+            if isinstance(e, BaseExceptionGroup):
+                if has_generator_exit(eg=e):  # type: ignore[arg-type]
+                    return True  # pragma: no cover
+        return False
+
+    assert not has_generator_exit(exc_info.value)
diff --git a/uv.lock b/uv.lock
index d2a515863..29e1b9d63 100644
--- a/uv.lock
+++ b/uv.lock
@@ -722,6 +722,7 @@ name = "mcp"
source = { editable = "." }
dependencies = [
    { name = "anyio" },
+    { name = "exceptiongroup", marker = "python_full_version < '3.11'" },
    { name = "httpx" },
    { name = "httpx-sse" },
    { name = "jsonschema" },
@@ -776,6 +777,7 @@ docs = [
[package.metadata]
requires-dist = [
    { name = "anyio", specifier = ">=4.5" },
+    { name = "exceptiongroup", marker = "python_full_version < '3.11'", specifier = ">=1.0.0" },
    { name = "httpx", specifier = ">=0.27.1" },
    { name = "httpx-sse", specifier = ">=0.4" },
    { name = "jsonschema", specifier = ">=4.20.0" },
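For reference, a minimal sketch of the pattern the new `except` blocks rely on: swallow a bare GeneratorExit delivered during teardown, and use `BaseExceptionGroup.split()` to drop GeneratorExit from an anyio exception group while re-raising everything else. The `demo_client` name below is illustrative only and not part of the patch; on Python < 3.11 the `exceptiongroup` backport provides the same `split()` API.

```python
import sys
from contextlib import asynccontextmanager

if sys.version_info >= (3, 11):
    from builtins import BaseExceptionGroup
else:
    from exceptiongroup import BaseExceptionGroup  # backport exposes the same split() API


@asynccontextmanager
async def demo_client():  # illustrative stand-in for sse_client / streamable_http_client
    try:
        yield "streams"
    except GeneratorExit:
        # Delivered by aclose()/GC teardown: swallow it so the generator stops
        # cleanly instead of failing with "generator didn't stop after athrow()".
        pass
    except BaseExceptionGroup as eg:
        # anyio task groups can wrap GeneratorExit; keep only the real errors.
        _, rest = eg.split(GeneratorExit)
        if rest:
            raise rest from None
```

`split()` returns a `(matched, rest)` pair of sub-groups, so raising `rest` preserves any non-GeneratorExit failures while the GeneratorExit half is discarded.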