diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index d051c25bf..6883925f3 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -41,6 +41,7 @@ async def handle_sse(request): from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import ValidationError from sse_starlette import EventSourceResponse +from starlette.background import BackgroundTask from starlette.requests import Request from starlette.responses import Response from starlette.types import Receive, Scope, Send @@ -78,6 +79,17 @@ def __init__(self, endpoint: str) -> None: self._read_stream_writers = {} logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") + async def _remove_stream_writer(self, session_id: UUID) -> None: + """ + Remove the SSE session with the given session ID. + """ + logger.debug(f"Remove SSE session with ID: {session_id}") + if writer := self._read_stream_writers.pop(session_id, None): + await writer.aclose() + logger.debug(f"Closed SSE session with ID: {session_id}") + else: + logger.warning(f"Session ID {session_id} not found for removal") + @asynccontextmanager async def connect_sse(self, scope: Scope, receive: Receive, send: Send): if scope["type"] != "http": @@ -120,9 +132,12 @@ async def sse_writer(): } ) + background_task = BackgroundTask(self._remove_stream_writer, session_id) async with anyio.create_task_group() as tg: response = EventSourceResponse( - content=sse_stream_reader, data_sender_callable=sse_writer + content=sse_stream_reader, + data_sender_callable=sse_writer, + background=background_task, ) logger.debug("Starting SSE response task") tg.start_soon(response, scope, receive, send)