From 46603d13ff35102d96e96703cd045f287fb93184 Mon Sep 17 00:00:00 2001 From: Basil Hosmer Date: Mon, 14 Apr 2025 23:09:46 -0400 Subject: [PATCH 1/3] Properly clean up response streams in BaseSession Wraps the request handling in a try/finally block to ensure that response streams are properly closed and removed from the tracking dictionary, even if an exception occurs during request processing. This change also prevents response_stream and response_stream_reader instances from piling up on _exit_stack over the course of the session. Github-Issue:#169 --- pyproject.toml | 3 +- src/mcp/shared/session.py | 70 ++++++++++++++------------- tests/client/test_resource_cleanup.py | 68 ++++++++++++++++++++++++++ uv.lock | 14 +++--- 4 files changed, 113 insertions(+), 42 deletions(-) create mode 100644 tests/client/test_resource_cleanup.py diff --git a/pyproject.toml b/pyproject.toml index 25514cd6b..fd7814226 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ "httpx>=0.27", "httpx-sse>=0.4", "pydantic>=2.7.2,<3.0.0", - "starlette>=0.27", + "starlette>=0.46.2", "sse-starlette>=1.6.1", "pydantic-settings>=2.5.2", "uvicorn>=0.23.1", @@ -53,6 +53,7 @@ dev = [ "pytest-flakefinder>=1.1.0", "pytest-xdist>=3.6.1", "pytest-examples>=0.0.14", + "starlette>=0.46.2", ] docs = [ "mkdocs>=1.6.1", diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 05fd3ce37..67256045e 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -187,7 +187,6 @@ def __init__( self._receive_notification_type = receive_notification_type self._read_timeout_seconds = read_timeout_seconds self._in_flight = {} - self._exit_stack = AsyncExitStack() async def __aenter__(self) -> Self: @@ -230,42 +229,45 @@ async def send_request( ](1) self._response_streams[request_id] = response_stream - self._exit_stack.push_async_callback(lambda: response_stream.aclose()) - self._exit_stack.push_async_callback(lambda: response_stream_reader.aclose()) - - jsonrpc_request = JSONRPCRequest( - jsonrpc="2.0", - id=request_id, - **request.model_dump(by_alias=True, mode="json", exclude_none=True), - ) - - # TODO: Support progress callbacks - - await self._write_stream.send(JSONRPCMessage(jsonrpc_request)) - try: - with anyio.fail_after( - None - if self._read_timeout_seconds is None - else self._read_timeout_seconds.total_seconds() - ): - response_or_error = await response_stream_reader.receive() - except TimeoutError: - raise McpError( - ErrorData( - code=httpx.codes.REQUEST_TIMEOUT, - message=( - f"Timed out while waiting for response to " - f"{request.__class__.__name__}. Waited " - f"{self._read_timeout_seconds} seconds." - ), - ) + jsonrpc_request = JSONRPCRequest( + jsonrpc="2.0", + id=request_id, + **request.model_dump(by_alias=True, mode="json", exclude_none=True), ) - if isinstance(response_or_error, JSONRPCError): - raise McpError(response_or_error.error) - else: - return result_type.model_validate(response_or_error.result) + # TODO: Support progress callbacks + + await self._write_stream.send(JSONRPCMessage(jsonrpc_request)) + + try: + with anyio.fail_after( + None + if self._read_timeout_seconds is None + else self._read_timeout_seconds.total_seconds() + ): + response_or_error = await response_stream_reader.receive() + except TimeoutError: + raise McpError( + ErrorData( + code=httpx.codes.REQUEST_TIMEOUT, + message=( + f"Timed out while waiting for response to " + f"{request.__class__.__name__}. Waited " + f"{self._read_timeout_seconds} seconds." + ), + ) + ) + + if isinstance(response_or_error, JSONRPCError): + raise McpError(response_or_error.error) + else: + return result_type.model_validate(response_or_error.result) + + finally: + self._response_streams.pop(request_id, None) + await response_stream.aclose() + await response_stream_reader.aclose() async def send_notification(self, notification: SendNotificationT) -> None: """ diff --git a/tests/client/test_resource_cleanup.py b/tests/client/test_resource_cleanup.py new file mode 100644 index 000000000..990b3a89a --- /dev/null +++ b/tests/client/test_resource_cleanup.py @@ -0,0 +1,68 @@ +from unittest.mock import patch + +import anyio +import pytest + +from mcp.shared.session import BaseSession +from mcp.types import ( + ClientRequest, + EmptyResult, + PingRequest, +) + + +@pytest.mark.anyio +async def test_send_request_stream_cleanup(): + """ + Test that send_request properly cleans up streams when an exception occurs. + + This test mocks out most of the session functionality to focus on stream cleanup. + """ + + # Create a mock session with the minimal required functionality + class TestSession(BaseSession): + async def _send_response(self, request_id, response): + pass + + # Create streams + write_stream_send, write_stream_receive = anyio.create_memory_object_stream(1) + read_stream_send, read_stream_receive = anyio.create_memory_object_stream(1) + + # Create the session + session = TestSession( + read_stream_receive, + write_stream_send, + object, # Request type doesn't matter for this test + object, # Notification type doesn't matter for this test + ) + + # Create a test request + request = ClientRequest( + PingRequest( + method="ping", + ) + ) + + # Patch the _write_stream.send method to raise an exception + async def mock_send(*args, **kwargs): + raise RuntimeError("Simulated network error") + + # Record the response streams before the test + initial_stream_count = len(session._response_streams) + + # Run the test with the patched method + with patch.object(session._write_stream, "send", mock_send): + with pytest.raises(RuntimeError): + await session.send_request(request, EmptyResult) + + # Verify that no response streams were leaked + assert len(session._response_streams) == initial_stream_count, ( + f"Expected {initial_stream_count} response streams after request, " + f"but found {len(session._response_streams)}" + ) + + # Clean up + await write_stream_send.aclose() + await write_stream_receive.aclose() + await read_stream_send.aclose() + await read_stream_receive.aclose() diff --git a/uv.lock b/uv.lock index 424e2d482..ae059e2bd 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,4 @@ version = 1 -revision = 1 requires-python = ">=3.10" [options] @@ -519,6 +518,7 @@ dev = [ { name = "pytest-flakefinder" }, { name = "pytest-xdist" }, { name = "ruff" }, + { name = "starlette" }, { name = "trio" }, ] docs = [ @@ -538,12 +538,11 @@ requires-dist = [ { name = "python-dotenv", marker = "extra == 'cli'", specifier = ">=1.0.0" }, { name = "rich", marker = "extra == 'rich'", specifier = ">=13.9.4" }, { name = "sse-starlette", specifier = ">=1.6.1" }, - { name = "starlette", specifier = ">=0.27" }, + { name = "starlette", specifier = ">=0.46.2" }, { name = "typer", marker = "extra == 'cli'", specifier = ">=0.12.4" }, { name = "uvicorn", specifier = ">=0.23.1" }, { name = "websockets", marker = "extra == 'ws'", specifier = ">=15.0.1" }, ] -provides-extras = ["cli", "rich", "ws"] [package.metadata.requires-dev] dev = [ @@ -553,6 +552,7 @@ dev = [ { name = "pytest-flakefinder", specifier = ">=1.1.0" }, { name = "pytest-xdist", specifier = ">=3.6.1" }, { name = "ruff", specifier = ">=0.8.5" }, + { name = "starlette", specifier = ">=0.46.2" }, { name = "trio", specifier = ">=0.26.2" }, ] docs = [ @@ -1394,14 +1394,14 @@ wheels = [ [[package]] name = "starlette" -version = "0.27.0" +version = "0.46.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/06/68/559bed5484e746f1ab2ebbe22312f2c25ec62e4b534916d41a8c21147bf8/starlette-0.27.0.tar.gz", hash = "sha256:6a6b0d042acb8d469a01eba54e9cda6cbd24ac602c4cd016723117d6a7e73b75", size = 51394 } +sdist = { url = "https://files.pythonhosted.org/packages/ce/20/08dfcd9c983f6a6f4a1000d934b9e6d626cff8d2eeb77a89a68eef20a2b7/starlette-0.46.2.tar.gz", hash = "sha256:7f7361f34eed179294600af672f565727419830b54b7b084efe44bb82d2fccd5", size = 2580846 } wheels = [ - { url = "https://files.pythonhosted.org/packages/58/f8/e2cca22387965584a409795913b774235752be4176d276714e15e1a58884/starlette-0.27.0-py3-none-any.whl", hash = "sha256:918416370e846586541235ccd38a474c08b80443ed31c578a418e2209b3eef91", size = 66978 }, + { url = "https://files.pythonhosted.org/packages/8b/0c/9d30a4ebeb6db2b25a841afbb80f6ef9a854fc3b41be131d249a977b4959/starlette-0.46.2-py3-none-any.whl", hash = "sha256:595633ce89f8ffa71a015caed34a5b2dc1c0cdb3f0f1fbd1e69339cf2abeec35", size = 72037 }, ] [[package]] @@ -1618,4 +1618,4 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2b/fb/c492d6daa5ec067c2988ac80c61359ace5c4c674c532985ac5a123436cec/websockets-15.0.1-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b359ed09954d7c18bbc1680f380c7301f92c60bf924171629c5db97febb12f04", size = 174155 }, { url = "https://files.pythonhosted.org/packages/68/a1/dcb68430b1d00b698ae7a7e0194433bce4f07ded185f0ee5fb21e2a2e91e/websockets-15.0.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:cad21560da69f4ce7658ca2cb83138fb4cf695a2ba3e475e0559e05991aa8122", size = 176884 }, { url = "https://files.pythonhosted.org/packages/fa/a8/5b41e0da817d64113292ab1f8247140aac61cbf6cfd085d6a0fa77f4984f/websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f", size = 169743 }, -] \ No newline at end of file +] From 7a2e64627d20dd7c21add6644dcdb642c23c7b95 Mon Sep 17 00:00:00 2001 From: Basil Hosmer Date: Tue, 29 Apr 2025 15:35:21 -0400 Subject: [PATCH 2/3] back out starlette upgrade, add warning suppression --- pyproject.toml | 7 ++++--- uv.lock | 14 +++++++------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fd7814226..214d353ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ "httpx>=0.27", "httpx-sse>=0.4", "pydantic>=2.7.2,<3.0.0", - "starlette>=0.46.2", + "starlette>=0.27", "sse-starlette>=1.6.1", "pydantic-settings>=2.5.2", "uvicorn>=0.23.1", @@ -53,7 +53,6 @@ dev = [ "pytest-flakefinder>=1.1.0", "pytest-xdist>=3.6.1", "pytest-examples>=0.0.14", - "starlette>=0.46.2", ] docs = [ "mkdocs>=1.6.1", @@ -114,5 +113,7 @@ filterwarnings = [ # This should be fixed on Uvicorn's side. "ignore::DeprecationWarning:websockets", "ignore:websockets.server.WebSocketServerProtocol is deprecated:DeprecationWarning", - "ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel" + "ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel", + # this is a problem in starlette 0.27, which we're currently pinned to + "ignore:Please use `import python_multipart` instead.:PendingDeprecationWarning", ] diff --git a/uv.lock b/uv.lock index ae059e2bd..424e2d482 100644 --- a/uv.lock +++ b/uv.lock @@ -1,4 +1,5 @@ version = 1 +revision = 1 requires-python = ">=3.10" [options] @@ -518,7 +519,6 @@ dev = [ { name = "pytest-flakefinder" }, { name = "pytest-xdist" }, { name = "ruff" }, - { name = "starlette" }, { name = "trio" }, ] docs = [ @@ -538,11 +538,12 @@ requires-dist = [ { name = "python-dotenv", marker = "extra == 'cli'", specifier = ">=1.0.0" }, { name = "rich", marker = "extra == 'rich'", specifier = ">=13.9.4" }, { name = "sse-starlette", specifier = ">=1.6.1" }, - { name = "starlette", specifier = ">=0.46.2" }, + { name = "starlette", specifier = ">=0.27" }, { name = "typer", marker = "extra == 'cli'", specifier = ">=0.12.4" }, { name = "uvicorn", specifier = ">=0.23.1" }, { name = "websockets", marker = "extra == 'ws'", specifier = ">=15.0.1" }, ] +provides-extras = ["cli", "rich", "ws"] [package.metadata.requires-dev] dev = [ @@ -552,7 +553,6 @@ dev = [ { name = "pytest-flakefinder", specifier = ">=1.1.0" }, { name = "pytest-xdist", specifier = ">=3.6.1" }, { name = "ruff", specifier = ">=0.8.5" }, - { name = "starlette", specifier = ">=0.46.2" }, { name = "trio", specifier = ">=0.26.2" }, ] docs = [ @@ -1394,14 +1394,14 @@ wheels = [ [[package]] name = "starlette" -version = "0.46.2" +version = "0.27.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ce/20/08dfcd9c983f6a6f4a1000d934b9e6d626cff8d2eeb77a89a68eef20a2b7/starlette-0.46.2.tar.gz", hash = "sha256:7f7361f34eed179294600af672f565727419830b54b7b084efe44bb82d2fccd5", size = 2580846 } +sdist = { url = "https://files.pythonhosted.org/packages/06/68/559bed5484e746f1ab2ebbe22312f2c25ec62e4b534916d41a8c21147bf8/starlette-0.27.0.tar.gz", hash = "sha256:6a6b0d042acb8d469a01eba54e9cda6cbd24ac602c4cd016723117d6a7e73b75", size = 51394 } wheels = [ - { url = "https://files.pythonhosted.org/packages/8b/0c/9d30a4ebeb6db2b25a841afbb80f6ef9a854fc3b41be131d249a977b4959/starlette-0.46.2-py3-none-any.whl", hash = "sha256:595633ce89f8ffa71a015caed34a5b2dc1c0cdb3f0f1fbd1e69339cf2abeec35", size = 72037 }, + { url = "https://files.pythonhosted.org/packages/58/f8/e2cca22387965584a409795913b774235752be4176d276714e15e1a58884/starlette-0.27.0-py3-none-any.whl", hash = "sha256:918416370e846586541235ccd38a474c08b80443ed31c578a418e2209b3eef91", size = 66978 }, ] [[package]] @@ -1618,4 +1618,4 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2b/fb/c492d6daa5ec067c2988ac80c61359ace5c4c674c532985ac5a123436cec/websockets-15.0.1-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b359ed09954d7c18bbc1680f380c7301f92c60bf924171629c5db97febb12f04", size = 174155 }, { url = "https://files.pythonhosted.org/packages/68/a1/dcb68430b1d00b698ae7a7e0194433bce4f07ded185f0ee5fb21e2a2e91e/websockets-15.0.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:cad21560da69f4ce7658ca2cb83138fb4cf695a2ba3e475e0559e05991aa8122", size = 176884 }, { url = "https://files.pythonhosted.org/packages/fa/a8/5b41e0da817d64113292ab1f8247140aac61cbf6cfd085d6a0fa77f4984f/websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f", size = 169743 }, -] +] \ No newline at end of file From db6a17fc36b12c9d99ce52cb4f2464a5a723d7ab Mon Sep 17 00:00:00 2001 From: Basil Hosmer Date: Wed, 30 Apr 2025 09:42:56 -0400 Subject: [PATCH 3/3] remove deprecation warning --- pyproject.toml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 095212757..1aaf15593 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,7 +113,5 @@ filterwarnings = [ # This should be fixed on Uvicorn's side. "ignore::DeprecationWarning:websockets", "ignore:websockets.server.WebSocketServerProtocol is deprecated:DeprecationWarning", - "ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel", - # this is a problem in starlette 0.27, which we're currently pinned to - "ignore:Please use `import python_multipart` instead.:PendingDeprecationWarning", + "ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel" ]