Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ dependencies = [
"pyjwt[crypto]>=2.10.1",
"typing-extensions>=4.9.0",
"typing-inspection>=0.4.1",
"exceptiongroup>=1.0.0; python_version < '3.11'",
]

[project.optional-dependencies]
Expand Down
17 changes: 17 additions & 0 deletions src/mcp/client/sse.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import logging
import sys
from collections.abc import Callable
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import parse_qs, urljoin, urlparse

import anyio

if sys.version_info >= (3, 11):
from builtins import BaseExceptionGroup # pragma: no cover
else:
from exceptiongroup import BaseExceptionGroup # pragma: no cover
import httpx
from anyio.abc import TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
Expand Down Expand Up @@ -157,8 +163,19 @@ async def post_writer(endpoint_url: str):

try:
yield read_stream, write_stream
# Suppress GeneratorExit to prevent "generator didn't stop after athrow()"
# when client code exits the context manager during cancellation.
# See https://github.com/python/cpython/issues/95571
except GeneratorExit:
pass
# anyio wraps GeneratorExit in BaseExceptionGroup; extract and re-raise other exceptions
except BaseExceptionGroup as eg:
_, rest = eg.split(GeneratorExit)
if rest:
raise rest from None
finally:
tg.cancel_scope.cancel()
finally:
await read_stream_writer.aclose()
await read_stream.aclose()
await write_stream.aclose()
17 changes: 17 additions & 0 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import contextlib
import logging
import sys
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass
Expand All @@ -16,6 +17,11 @@
from warnings import warn

import anyio

if sys.version_info >= (3, 11):
from builtins import BaseExceptionGroup # pragma: no cover
else:
from exceptiongroup import BaseExceptionGroup # pragma: no cover
import httpx
from anyio.abc import TaskGroup
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
Expand Down Expand Up @@ -672,12 +678,23 @@ def start_get_stream() -> None:
write_stream,
transport.get_session_id,
)
# Suppress GeneratorExit to prevent "generator didn't stop after athrow()"
# when client code exits the context manager during cancellation.
# See https://github.com/python/cpython/issues/95571
except GeneratorExit:
pass
# anyio wraps GeneratorExit in BaseExceptionGroup; extract and re-raise other exceptions
except BaseExceptionGroup as eg:
_, rest = eg.split(GeneratorExit)
if rest:
raise rest from None
finally:
if transport.session_id and terminate_on_close:
await transport.terminate_session(client)
tg.cancel_scope.cancel()
finally:
await read_stream_writer.aclose()
await read_stream.aclose()
await write_stream.aclose()


Expand Down
74 changes: 73 additions & 1 deletion tests/client/conftest.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,86 @@
from collections.abc import Callable, Generator
import multiprocessing
import socket
from collections.abc import AsyncGenerator, Callable, Generator
from contextlib import asynccontextmanager
from typing import Any
from unittest.mock import patch

import pytest
import uvicorn
from anyio.streams.memory import MemoryObjectSendStream
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Mount, Route

import mcp.shared.memory
from mcp.server import Server
from mcp.server.sse import SseServerTransport
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from mcp.shared.message import SessionMessage
from mcp.types import JSONRPCNotification, JSONRPCRequest
from tests.test_helpers import wait_for_server


def run_server(port: int) -> None: # pragma: no cover
"""Run server with SSE and Streamable HTTP endpoints."""
server = Server(name="cleanup_test_server")
session_manager = StreamableHTTPSessionManager(app=server, json_response=False)
sse_transport = SseServerTransport("/messages/")

async def handle_sse(request: Request) -> Response:
async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams:
if streams:
await server.run(streams[0], streams[1], server.create_initialization_options())
return Response()

@asynccontextmanager
async def lifespan(app: Starlette) -> AsyncGenerator[None, None]:
async with session_manager.run():
yield

app = Starlette(
routes=[
Route("/sse", endpoint=handle_sse),
Mount("/messages/", app=sse_transport.handle_post_message),
Mount("/mcp", app=session_manager.handle_request),
],
lifespan=lifespan,
)
uvicorn.Server(uvicorn.Config(app, host="127.0.0.1", port=port, log_level="error")).run()


@pytest.fixture
def server_port() -> int:
with socket.socket() as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]


@pytest.fixture
def test_server(server_port: int) -> Generator[str, None, None]:
"""Start server with SSE and Streamable HTTP endpoints."""
proc = multiprocessing.Process(target=run_server, kwargs={"port": server_port}, daemon=True)
proc.start()
wait_for_server(server_port)
try:
yield f"http://127.0.0.1:{server_port}"
finally:
proc.terminate()
proc.join(timeout=2)
if proc.is_alive(): # pragma: no cover
proc.kill()
proc.join(timeout=1)


@pytest.fixture
def sse_server_url(test_server: str) -> str:
return f"{test_server}/sse"


@pytest.fixture
def streamable_server_url(test_server: str) -> str:
return f"{test_server}/mcp"


class SpyMemoryObjectSendStream:
Expand Down
92 changes: 73 additions & 19 deletions tests/client/test_resource_cleanup.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,117 @@
import sys
from collections.abc import Callable
from typing import Any

if sys.version_info >= (3, 11):
from builtins import BaseExceptionGroup # pragma: no cover
else:
from exceptiongroup import BaseExceptionGroup # pragma: no cover

from unittest.mock import patch

import anyio
import pytest
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamable_http_client
from mcp.shared.message import SessionMessage
from mcp.shared.session import BaseSession, RequestId, SendResultT
from mcp.types import ClientNotification, ClientRequest, ClientResult, EmptyResult, ErrorData, PingRequest

ClientTransport = tuple[
str,
Callable[..., Any],
Callable[[Any], tuple[MemoryObjectReceiveStream[Any], MemoryObjectSendStream[Any]]],
]


@pytest.mark.anyio
async def test_send_request_stream_cleanup():
"""
Test that send_request properly cleans up streams when an exception occurs.
"""Test that send_request properly cleans up streams when an exception occurs."""

This test mocks out most of the session functionality to focus on stream cleanup.
"""

# Create a mock session with the minimal required functionality
class TestSession(BaseSession[ClientRequest, ClientNotification, ClientResult, Any, Any]):
async def _send_response(
self, request_id: RequestId, response: SendResultT | ErrorData
) -> None: # pragma: no cover
pass

# Create streams
write_stream_send, write_stream_receive = anyio.create_memory_object_stream[SessionMessage](1)
read_stream_send, read_stream_receive = anyio.create_memory_object_stream[SessionMessage](1)

# Create the session
session = TestSession(
read_stream_receive,
write_stream_send,
object, # Request type doesn't matter for this test
object, # Notification type doesn't matter for this test
object,
object,
)

# Create a test request
request = ClientRequest(PingRequest())

# Patch the _write_stream.send method to raise an exception
async def mock_send(*args: Any, **kwargs: Any):
raise RuntimeError("Simulated network error")

# Record the response streams before the test
initial_stream_count = len(session._response_streams)

# Run the test with the patched method
with patch.object(session._write_stream, "send", mock_send):
with pytest.raises(RuntimeError):
await session.send_request(request, EmptyResult)

# Verify that no response streams were leaked
assert len(session._response_streams) == initial_stream_count, (
f"Expected {initial_stream_count} response streams after request, but found {len(session._response_streams)}"
)
assert len(session._response_streams) == initial_stream_count

# Clean up
await write_stream_send.aclose()
await write_stream_receive.aclose()
await read_stream_send.aclose()
await read_stream_receive.aclose()


@pytest.fixture(params=["sse", "streamable"])
def client_transport(
request: pytest.FixtureRequest, sse_server_url: str, streamable_server_url: str
) -> ClientTransport:
if request.param == "sse":
return (sse_server_url, sse_client, lambda x: (x[0], x[1]))
else:
return (streamable_server_url, streamable_http_client, lambda x: (x[0], x[1]))


@pytest.mark.anyio
async def test_generator_exit_on_gc_cleanup(client_transport: ClientTransport) -> None:
"""Suppress GeneratorExit from aclose() during GC cleanup (python/cpython#95571)."""
url, client_func, unpack = client_transport
cm = client_func(url)
result = await cm.__aenter__()
read_stream, write_stream = unpack(result)
await cm.gen.aclose()
await read_stream.aclose()
await write_stream.aclose()


@pytest.mark.anyio
async def test_generator_exit_in_exception_group(client_transport: ClientTransport) -> None:
"""Extract GeneratorExit from BaseExceptionGroup (python/cpython#135736)."""
url, client_func, unpack = client_transport
async with client_func(url) as result:
unpack(result)
raise BaseExceptionGroup("unhandled errors in a TaskGroup", [GeneratorExit()])


@pytest.mark.anyio
async def test_generator_exit_mixed_group(client_transport: ClientTransport) -> None:
"""Extract GeneratorExit from BaseExceptionGroup, re-raise other exceptions (python/cpython#135736)."""
url, client_func, unpack = client_transport
with pytest.raises(BaseExceptionGroup) as exc_info:
async with client_func(url) as result:
unpack(result)
raise BaseExceptionGroup("errors", [GeneratorExit(), ValueError("real error")])

def has_generator_exit(eg: BaseExceptionGroup[Any]) -> bool:
for e in eg.exceptions:
if isinstance(e, GeneratorExit):
return True # pragma: no cover
if isinstance(e, BaseExceptionGroup):
if has_generator_exit(eg=e): # type: ignore[arg-type]
return True # pragma: no cover
return False

assert not has_generator_exit(exc_info.value)
2 changes: 2 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.