From 673ae93a8698beb285df31bef063646a922213a9 Mon Sep 17 00:00:00 2001 From: rkondra-eightfold Date: Tue, 15 Apr 2025 23:07:40 +0530 Subject: [PATCH 1/2] fix: incorrect resolution of the /messages endpoint URL in the SSE client when the FastAPI app is mounted under a base path (e.g., /mcp). --- src/mcp/client/sse.py | 44 +++++++++++++++++++++++++++++++++---------- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 4f6241a7..c6663ea4 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -14,8 +14,31 @@ logger = logging.getLogger(__name__) +# TODO: move these to utils/url_utils.py +def get_origin(url: str) -> str: + parsed_url = urlparse(url) + return f"{parsed_url.scheme}://{parsed_url.netloc}" + + +def get_path(url: str) -> str: + parsed_url = urlparse(url) + return parsed_url.path + + +def get_endpoint_url( + base_url: str, sse_relative_url: str, server_mount_path: str = "" +) -> str: + endpoint_url = urljoin(base_url, sse_relative_url) + if server_mount_path: + origin, path = get_origin(endpoint_url), get_path(endpoint_url) + endpoint_url = urljoin( + f"{origin}/{server_mount_path.strip('/')}/", path.lstrip("/") + ) + return endpoint_url + + def remove_request_params(url: str) -> str: - return urljoin(url, urlparse(url).path) + return urljoin(url, get_path(url)) @asynccontextmanager @@ -24,12 +47,16 @@ async def sse_client( headers: dict[str, Any] | None = None, timeout: float = 5, sse_read_timeout: float = 60 * 5, + server_mount_path: str = "", ): """ Client transport for SSE. `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. + + `server_mount_path` provides the relative mount path of the MCP server + (used if it is mounted relatively on another ASGI server). """ read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] @@ -61,18 +88,15 @@ async def sse_reader( logger.debug(f"Received SSE event: {sse.event}") match sse.event: case "endpoint": - endpoint_url = urljoin(url, sse.data) + endpoint_url = get_endpoint_url( + base_url=url, + sse_relative_url=sse.data, + server_mount_path=server_mount_path, + ) logger.info( f"Received endpoint URL: {endpoint_url}" ) - - url_parsed = urlparse(url) - endpoint_parsed = urlparse(endpoint_url) - if ( - url_parsed.netloc != endpoint_parsed.netloc - or url_parsed.scheme - != endpoint_parsed.scheme - ): + if get_origin(url) != get_origin(endpoint_url): error_msg = ( "Endpoint origin does not match " f"connection origin: {endpoint_url}" From 0729e923b3c27dffdbe9baf04ff76c748b757acf Mon Sep 17 00:00:00 2001 From: rkondra-eightfold Date: Tue, 15 Apr 2025 23:25:11 +0530 Subject: [PATCH 2/2] fix: preserve session_id --- src/mcp/client/sse.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index c6663ea4..0aef8c15 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -20,9 +20,16 @@ def get_origin(url: str) -> str: return f"{parsed_url.scheme}://{parsed_url.netloc}" -def get_path(url: str) -> str: +def get_relative_path(url: str, remove_params: bool = False) -> str: parsed_url = urlparse(url) - return parsed_url.path + if remove_params: + return parsed_url.path + relative_path = parsed_url.path + if parsed_url.query: + relative_path += f"?{parsed_url.query}" + if parsed_url.fragment: + relative_path += f"#{parsed_url.fragment}" + return relative_path def get_endpoint_url( @@ -30,7 +37,7 @@ def get_endpoint_url( ) -> str: endpoint_url = urljoin(base_url, sse_relative_url) if server_mount_path: - origin, path = get_origin(endpoint_url), get_path(endpoint_url) + origin, path = get_origin(endpoint_url), get_relative_path(endpoint_url) endpoint_url = urljoin( f"{origin}/{server_mount_path.strip('/')}/", path.lstrip("/") ) @@ -38,7 +45,7 @@ def get_endpoint_url( def remove_request_params(url: str) -> str: - return urljoin(url, get_path(url)) + return urljoin(url, get_relative_path(url, remove_params=True)) @asynccontextmanager