diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index a782b58a7..476c5c001 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -10,7 +10,7 @@ from httpx_sse import aconnect_sse import mcp.types as types -from mcp.shared._httpx_utils import create_mcp_http_client +from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client from mcp.shared.message import SessionMessage logger = logging.getLogger(__name__) @@ -26,6 +26,7 @@ async def sse_client( headers: dict[str, Any] | None = None, timeout: float = 5, sse_read_timeout: float = 60 * 5, + httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, auth: httpx.Auth | None = None, ): """ @@ -53,7 +54,7 @@ async def sse_client( async with anyio.create_task_group() as tg: try: logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") - async with create_mcp_http_client(headers=headers, auth=auth) as client: + async with httpx_client_factory(headers=headers, auth=auth) as client: async with aconnect_sse( client, "GET", diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 79b2995e1..61aca4282 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -19,7 +19,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_sse import EventSource, ServerSentEvent, aconnect_sse -from mcp.shared._httpx_utils import create_mcp_http_client +from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.types import ( ErrorData, @@ -430,6 +430,7 @@ async def streamablehttp_client( timeout: timedelta = timedelta(seconds=30), sse_read_timeout: timedelta = timedelta(seconds=60 * 5), terminate_on_close: bool = True, + httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, auth: httpx.Auth | None = None, ) -> AsyncGenerator[ tuple[ @@ -464,7 +465,7 @@ async def streamablehttp_client( try: logger.info(f"Connecting to StreamableHTTP endpoint: {url}") - async with create_mcp_http_client( + async with httpx_client_factory( headers=transport.request_headers, timeout=httpx.Timeout( transport.timeout.seconds, read=transport.sse_read_timeout.seconds diff --git a/src/mcp/shared/_httpx_utils.py b/src/mcp/shared/_httpx_utils.py index 5240c970c..e0611ce73 100644 --- a/src/mcp/shared/_httpx_utils.py +++ b/src/mcp/shared/_httpx_utils.py @@ -1,12 +1,21 @@ """Utilities for creating standardized httpx AsyncClient instances.""" -from typing import Any +from typing import Any, Protocol import httpx __all__ = ["create_mcp_http_client"] +class McpHttpClientFactory(Protocol): + def __call__( + self, + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: ... + + def create_mcp_http_client( headers: dict[str, str] | None = None, timeout: httpx.Timeout | None = None,