diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index c0740f7c..3ad73dc3 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -17,13 +17,13 @@ from pydantic import BaseModel, Field from pydantic.networks import AnyUrl from pydantic_settings import BaseSettings, SettingsConfigDict -from sse_starlette import EventSourceResponse from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware from starlette.requests import Request from starlette.responses import Response -from starlette.routing import Mount, Route, request_response # type: ignore +from starlette.routing import Mount, Route +from starlette.types import Receive, Scope, Send from mcp.server.auth.middleware.auth_context import AuthContextMiddleware from mcp.server.auth.middleware.bearer_auth import ( @@ -576,20 +576,19 @@ def sse_app(self) -> Starlette: sse = SseServerTransport(self.settings.message_path) - async def handle_sse(request: Request) -> EventSourceResponse: + async def handle_sse(scope: Scope, receive: Receive, send: Send): # Add client ID from auth context into request context if available async with sse.connect_sse( - request.scope, - request.receive, - request._send, # type: ignore[reportPrivateUsage] + scope, + receive, + send, ) as streams: await self._mcp_server.run( streams[0], streams[1], self._mcp_server.create_initialization_options(), ) - return streams[2] # Create routes routes: list[Route | Mount] = [] @@ -629,7 +628,7 @@ async def handle_sse(request: Request) -> EventSourceResponse: Route( self.settings.sse_path, endpoint=RequireAuthMiddleware( - request_response(handle_sse), required_scopes + handle_sse, required_scopes ), methods=["GET"], ) diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index f6054c79..9390a7e2 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -120,17 +120,15 @@ async def sse_writer(): } ) - # Ensure all streams are properly closed - async with read_stream, write_stream, read_stream_writer, sse_stream_reader: - async with anyio.create_task_group() as tg: - response = EventSourceResponse( - content=sse_stream_reader, data_sender_callable=sse_writer - ) - logger.debug("Starting SSE response task") - tg.start_soon(response, scope, receive, send) - - logger.debug("Yielding read and write streams") - yield (read_stream, write_stream, response) + async with anyio.create_task_group() as tg: + response = EventSourceResponse( + content=sse_stream_reader, data_sender_callable=sse_writer + ) + logger.debug("Starting SSE response task") + tg.start_soon(response, scope, receive, send) + + logger.debug("Yielding read and write streams") + yield (read_stream, write_stream) async def handle_post_message( self, scope: Scope, receive: Receive, send: Send @@ -175,3 +173,4 @@ async def handle_post_message( response = Response("Accepted", status_code=202) await response(scope, receive, send) await writer.send(message) + \ No newline at end of file diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index e6d82524..b0088c64 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -4,16 +4,13 @@ import base64 import hashlib -import json import secrets import time import unittest.mock from urllib.parse import parse_qs, urlparse -import anyio import httpx import pytest -from httpx_sse import aconnect_sse from pydantic import AnyHttpUrl from starlette.applications import Starlette @@ -30,14 +27,10 @@ RevocationOptions, create_auth_routes, ) -from mcp.server.auth.settings import AuthSettings -from mcp.server.fastmcp import FastMCP -from mcp.server.streaming_asgi_transport import StreamingASGITransport from mcp.shared.auth import ( OAuthClientInformationFull, OAuthToken, ) -from mcp.types import JSONRPCRequest # Mock OAuth provider for testing @@ -230,10 +223,11 @@ def auth_app(mock_oauth_provider): @pytest.fixture -def test_client(auth_app) -> httpx.AsyncClient: - return httpx.AsyncClient( +async def test_client(auth_app): + async with httpx.AsyncClient( transport=httpx.ASGITransport(app=auth_app), base_url="https://mcptest.com" - ) + ) as client: + yield client @pytest.fixture @@ -993,163 +987,7 @@ async def test_client_registration_invalid_grant_type( ) -class TestFastMCPWithAuth: - """Test FastMCP server with authentication.""" - - @pytest.mark.anyio - async def test_fastmcp_with_auth( - self, mock_oauth_provider: MockOAuthProvider, pkce_challenge - ): - """Test creating a FastMCP server with authentication.""" - # Create FastMCP server with auth provider - mcp = FastMCP( - auth_server_provider=mock_oauth_provider, - require_auth=True, - auth=AuthSettings( - issuer_url=AnyHttpUrl("https://auth.example.com"), - client_registration_options=ClientRegistrationOptions(enabled=True), - revocation_options=RevocationOptions(enabled=True), - required_scopes=["read", "write"], - ), - ) - - # Add a test tool - @mcp.tool() - def test_tool(x: int) -> str: - return f"Result: {x}" - - async with anyio.create_task_group() as task_group: - transport = StreamingASGITransport( - app=mcp.sse_app(), - task_group=task_group, - ) - test_client = httpx.AsyncClient( - transport=transport, base_url="http://mcptest.com" - ) - - # Test metadata endpoint - response = await test_client.get("/.well-known/oauth-authorization-server") - assert response.status_code == 200 - # Test that auth is required for protected endpoints - response = await test_client.get("/sse") - assert response.status_code == 401 - - response = await test_client.post("/messages/") - assert response.status_code == 401, response.content - - response = await test_client.post( - "/messages/", - headers={"Authorization": "invalid"}, - ) - assert response.status_code == 401 - - response = await test_client.post( - "/messages/", - headers={"Authorization": "Bearer invalid"}, - ) - assert response.status_code == 401 - - # now, become authenticated and try to go through the flow again - client_metadata = { - "redirect_uris": ["https://client.example.com/callback"], - "client_name": "Test Client", - } - - response = await test_client.post( - "/register", - json=client_metadata, - ) - assert response.status_code == 201 - client_info = response.json() - - # Request authorization using POST with form-encoded data - response = await test_client.post( - "/authorize", - data={ - "response_type": "code", - "client_id": client_info["client_id"], - "redirect_uri": "https://client.example.com/callback", - "code_challenge": pkce_challenge["code_challenge"], - "code_challenge_method": "S256", - "state": "test_state", - }, - ) - assert response.status_code == 302 - - # Extract the authorization code from the redirect URL - redirect_url = response.headers["location"] - parsed_url = urlparse(redirect_url) - query_params = parse_qs(parsed_url.query) - - assert "code" in query_params - auth_code = query_params["code"][0] - - # Exchange the authorization code for tokens - response = await test_client.post( - "/token", - data={ - "grant_type": "authorization_code", - "client_id": client_info["client_id"], - "client_secret": client_info["client_secret"], - "code": auth_code, - "code_verifier": pkce_challenge["code_verifier"], - "redirect_uri": "https://client.example.com/callback", - }, - ) - assert response.status_code == 200 - - token_response = response.json() - assert "access_token" in token_response - authorization = f"Bearer {token_response['access_token']}" - - # Test the authenticated endpoint with valid token - async with aconnect_sse( - test_client, "GET", "/sse", headers={"Authorization": authorization} - ) as event_source: - assert event_source.response.status_code == 200 - events = event_source.aiter_sse() - sse = await events.__anext__() - assert sse.event == "endpoint" - assert sse.data.startswith("/messages/?session_id=") - messages_uri = sse.data - - # verify that we can now post to the /messages endpoint, - # and get a response on the /sse endpoint - response = await test_client.post( - messages_uri, - headers={"Authorization": authorization}, - content=JSONRPCRequest( - jsonrpc="2.0", - id="123", - method="initialize", - params={ - "protocolVersion": "2024-11-05", - "capabilities": { - "roots": {"listChanged": True}, - "sampling": {}, - }, - "clientInfo": {"name": "ExampleClient", "version": "1.0.0"}, - }, - ).model_dump_json(), - ) - assert response.status_code == 202 - assert response.content == b"Accepted" - - sse = await events.__anext__() - assert sse.event == "message" - sse_data = json.loads(sse.data) - assert sse_data["id"] == "123" - assert set(sse_data["result"]["capabilities"].keys()) == { - "experimental", - "prompts", - "resources", - "tools", - } - # the /sse endpoint will never finish; normally, the client could just - # disconnect, but in tests the easiest way to do this is to cancel the - # task group - task_group.cancel_scope.cancel() class TestAuthorizeEndpointErrors: diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index d7a10d09..1d5e12f9 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -150,15 +150,20 @@ def server(server_port: int) -> Generator[None, None, None]: print("server process failed to terminate") +@pytest.fixture() +async def http_client(server, server_url) -> AsyncGenerator[httpx.AsyncClient, None]: + """Create test client""" + async with httpx.AsyncClient(base_url=server_url) as client: + yield client + +# Tests @pytest.mark.anyio -@pytest.mark.skip( - "fails in CI, but works locally. Need to investigate why." -) -async def test_raw_sse_connection(server, server_url) -> None: +async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None: """Test the SSE connection establishment simply with an HTTP client.""" - try: - async with httpx.AsyncClient(base_url=server_url) as http_client: + async with anyio.create_task_group(): + + async def connection_test() -> None: async with http_client.stream("GET", "/sse") as response: assert response.status_code == 200 assert ( @@ -176,8 +181,9 @@ async def test_raw_sse_connection(server, server_url) -> None: return line_number += 1 - except Exception as e: - pytest.fail(f"{e}") + # Add timeout to prevent test from hanging if it fails + with anyio.fail_after(3): + await connection_test() @pytest.mark.anyio @@ -243,4 +249,4 @@ async def test_sse_client_timeout( # we should receive an error here return - pytest.fail("the client should have timed out and returned an error already") + pytest.fail("the client should have timed out and returned an error already") \ No newline at end of file