From 2c7bd8343eb92bf6d3d3dbf3e66687b8d47cbd5f Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Tue, 11 Feb 2025 12:14:58 +0000 Subject: [PATCH 1/6] feat: add lifespan support to low-level MCP server Adds a context manager based lifespan API in mcp.server.lowlevel.server to manage server lifecycles in a type-safe way. This enables servers to: - Initialize resources on startup and clean them up on shutdown - Pass context data from startup to request handlers - Support async startup/shutdown operations --- src/mcp/server/lowlevel/server.py | 85 +++++++++++++++++++++---------- src/mcp/shared/context.py | 3 +- 2 files changed, 61 insertions(+), 27 deletions(-) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 3d9172260..28942cf8e 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -68,7 +68,8 @@ async def main(): import logging import warnings from collections.abc import Awaitable, Callable -from typing import Any, Sequence +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import Any, AsyncIterator, Generic, Sequence, TypeVar from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl @@ -101,13 +102,36 @@ def __init__( self.tools_changed = tools_changed -class Server: +LifespanResultT = TypeVar("LifespanResultT") + + +@asynccontextmanager +async def lifespan(server: "Server") -> AsyncIterator[object]: + """Default lifespan context manager that does nothing. + + Args: + server: The server instance this lifespan is managing + + Returns: + An empty context object + """ + yield {} + + +class Server(Generic[LifespanResultT]): def __init__( - self, name: str, version: str | None = None, instructions: str | None = None + self, + name: str, + version: str | None = None, + instructions: str | None = None, + lifespan: Callable[ + ["Server"], AbstractAsyncContextManager[LifespanResultT] + ] = lifespan, ): self.name = name self.version = version self.instructions = instructions + self.lifespan = lifespan self.request_handlers: dict[ type, Callable[..., Awaitable[types.ServerResult]] ] = { @@ -446,35 +470,43 @@ async def run( raise_exceptions: bool = False, ): with warnings.catch_warnings(record=True) as w: - async with ServerSession( - read_stream, write_stream, initialization_options - ) as session: - async for message in session.incoming_messages: - logger.debug(f"Received message: {message}") - - match message: - case ( - RequestResponder( - request=types.ClientRequest(root=req) - ) as responder - ): - with responder: - await self._handle_request( - message, req, session, raise_exceptions - ) - case types.ClientNotification(root=notify): - await self._handle_notification(notify) - - for warning in w: - logger.info( - f"Warning: {warning.category.__name__}: {warning.message}" - ) + async with self.lifespan(self) as lifespan_context: + async with ServerSession( + read_stream, write_stream, initialization_options + ) as session: + async for message in session.incoming_messages: + logger.debug(f"Received message: {message}") + + match message: + case ( + RequestResponder( + request=types.ClientRequest(root=req) + ) as responder + ): + with responder: + await self._handle_request( + message, + req, + session, + lifespan_context, + raise_exceptions, + ) + case types.ClientNotification(root=notify): + await self._handle_notification(notify) + + for warning in w: + logger.info( + "Warning: %s: %s", + warning.category.__name__, + warning.message, + ) async def _handle_request( self, message: RequestResponder, req: Any, session: ServerSession, + lifespan_context: object, raise_exceptions: bool, ): logger.info(f"Processing request of type {type(req).__name__}") @@ -491,6 +523,7 @@ async def _handle_request( message.request_id, message.request_meta, session, + lifespan_context, ) ) response = await handler(req) diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index 760d55877..50e5d5194 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Generic, TypeVar +from typing import Any, Generic, TypeVar from mcp.shared.session import BaseSession from mcp.types import RequestId, RequestParams @@ -12,3 +12,4 @@ class RequestContext(Generic[SessionT]): request_id: RequestId meta: RequestParams.Meta | None session: SessionT + lifespan_context: Any From d3ea9009b0414c713084dc8688459120af6a67cc Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Tue, 11 Feb 2025 12:15:08 +0000 Subject: [PATCH 2/6] feat: add lifespan support to FastMCP server Adds support for the lifespan API to FastMCP server, enabling: - Simple setup with FastMCP constructor - Type-safe context passing to tools and handlers - Configuration via Settings class --- src/mcp/server/fastmcp/server.py | 44 +++++++++++++++++++++++++++++--- 1 file changed, 40 insertions(+), 4 deletions(-) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index aa7c79bcb..bc341b404 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -3,8 +3,13 @@ import inspect import json import re +from collections.abc import AsyncIterator +from contextlib import ( + AbstractAsyncContextManager, + asynccontextmanager, +) from itertools import chain -from typing import Any, Callable, Literal, Sequence +from typing import Any, Callable, Generic, Literal, Sequence import anyio import pydantic_core @@ -19,8 +24,16 @@ from mcp.server.fastmcp.tools import ToolManager from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger from mcp.server.fastmcp.utilities.types import Image -from mcp.server.lowlevel import Server as MCPServer from mcp.server.lowlevel.helper_types import ReadResourceContents +from mcp.server.lowlevel.server import ( + LifespanResultT, +) +from mcp.server.lowlevel.server import ( + Server as MCPServer, +) +from mcp.server.lowlevel.server import ( + lifespan as default_lifespan, +) from mcp.server.sse import SseServerTransport from mcp.server.stdio import stdio_server from mcp.shared.context import RequestContext @@ -50,7 +63,7 @@ logger = get_logger(__name__) -class Settings(BaseSettings): +class Settings(BaseSettings, Generic[LifespanResultT]): """FastMCP server settings. All settings can be configured via environment variables with the prefix FASTMCP_. @@ -85,13 +98,36 @@ class Settings(BaseSettings): description="List of dependencies to install in the server environment", ) + lifespan: ( + Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]] | None + ) = Field(None, description="Lifespan contexte manager") + + +def lifespan_wrapper( + app: "FastMCP", + lifespan: Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]], +) -> Callable[[MCPServer], AbstractAsyncContextManager[object]]: + @asynccontextmanager + async def wrap(s: MCPServer) -> AsyncIterator[object]: + async with lifespan(app) as context: + yield context + + return wrap + class FastMCP: def __init__( self, name: str | None = None, instructions: str | None = None, **settings: Any ): self.settings = Settings(**settings) - self._mcp_server = MCPServer(name=name or "FastMCP", instructions=instructions) + + self._mcp_server = MCPServer( + name=name or "FastMCP", + instructions=instructions, + lifespan=lifespan_wrapper(self, self.settings.lifespan) + if self.settings.lifespan + else default_lifespan, + ) self._tool_manager = ToolManager( warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools ) From e598750cbaec0dcc48be1f6561c0062553ac9b2e Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Tue, 11 Feb 2025 12:15:18 +0000 Subject: [PATCH 3/6] test: add tests for server lifespan support Adds comprehensive tests for lifespan functionality: - Tests for both low-level Server and FastMCP classes - Coverage for startup, shutdown, and context access - Verifies context passing to request handlers --- tests/issues/test_176_progress_token.py | 5 +- tests/server/test_lifespan.py | 207 ++++++++++++++++++++++++ 2 files changed, 211 insertions(+), 1 deletion(-) create mode 100644 tests/server/test_lifespan.py diff --git a/tests/issues/test_176_progress_token.py b/tests/issues/test_176_progress_token.py index ed8ab128a..7f9131a1e 100644 --- a/tests/issues/test_176_progress_token.py +++ b/tests/issues/test_176_progress_token.py @@ -20,7 +20,10 @@ async def test_progress_token_zero_first_call(): mock_meta.progressToken = 0 # This is the key test case - token is 0 request_context = RequestContext( - request_id="test-request", session=mock_session, meta=mock_meta + request_id="test-request", + session=mock_session, + meta=mock_meta, + lifespan_context=None, ) # Create context with our mocks diff --git a/tests/server/test_lifespan.py b/tests/server/test_lifespan.py new file mode 100644 index 000000000..14afb6b06 --- /dev/null +++ b/tests/server/test_lifespan.py @@ -0,0 +1,207 @@ +"""Tests for lifespan functionality in both low-level and FastMCP servers.""" + +from contextlib import asynccontextmanager +from typing import AsyncIterator + +import anyio +import pytest +from pydantic import TypeAdapter + +from mcp.server.fastmcp import Context, FastMCP +from mcp.server.lowlevel.server import NotificationOptions, Server +from mcp.server.models import InitializationOptions +from mcp.types import ( + ClientCapabilities, + Implementation, + InitializeRequestParams, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, +) + + +@pytest.mark.anyio +async def test_lowlevel_server_lifespan(): + """Test that lifespan works in low-level server.""" + + @asynccontextmanager + async def test_lifespan(server: Server) -> AsyncIterator[dict]: + """Test lifespan context that tracks startup/shutdown.""" + context = {"started": False, "shutdown": False} + try: + context["started"] = True + yield context + finally: + context["shutdown"] = True + + server = Server("test", lifespan=test_lifespan) + + # Create memory streams for testing + send_stream1, receive_stream1 = anyio.create_memory_object_stream(100) + send_stream2, receive_stream2 = anyio.create_memory_object_stream(100) + + # Create a tool that accesses lifespan context + @server.call_tool() + async def check_lifespan(name: str, arguments: dict) -> list: + ctx = server.request_context + assert isinstance(ctx.lifespan_context, dict) + assert ctx.lifespan_context["started"] + assert not ctx.lifespan_context["shutdown"] + return [{"type": "text", "text": "true"}] + + # Run server in background task + async with anyio.create_task_group() as tg: + + async def run_server(): + await server.run( + receive_stream1, + send_stream2, + InitializationOptions( + server_name="test", + server_version="0.1.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + raise_exceptions=True, + ) + + tg.start_soon(run_server) + + # Initialize the server + params = InitializeRequestParams( + protocolVersion="2024-11-05", + capabilities=ClientCapabilities(), + clientInfo=Implementation(name="test-client", version="0.1.0"), + ) + await send_stream1.send( + JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=TypeAdapter(InitializeRequestParams).dump_python(params), + ) + ) + ) + response = await receive_stream2.receive() + + # Send initialized notification + await send_stream1.send( + JSONRPCMessage( + root=JSONRPCNotification( + jsonrpc="2.0", + method="notifications/initialized", + ) + ) + ) + + # Call the tool to verify lifespan context + await send_stream1.send( + JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params={"name": "check_lifespan", "arguments": {}}, + ) + ) + ) + + # Get response and verify + response = await receive_stream2.receive() + assert response.root.result["content"][0]["text"] == "true" + + # Cancel server task + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_fastmcp_server_lifespan(): + """Test that lifespan works in FastMCP server.""" + + @asynccontextmanager + async def test_lifespan(server: FastMCP) -> AsyncIterator[dict]: + """Test lifespan context that tracks startup/shutdown.""" + context = {"started": False, "shutdown": False} + try: + context["started"] = True + yield context + finally: + context["shutdown"] = True + + server = FastMCP("test", lifespan=test_lifespan) + + # Create memory streams for testing + send_stream1, receive_stream1 = anyio.create_memory_object_stream(100) + send_stream2, receive_stream2 = anyio.create_memory_object_stream(100) + + # Add a tool that checks lifespan context + @server.tool() + def check_lifespan(ctx: Context) -> bool: + """Tool that checks lifespan context.""" + assert isinstance(ctx.request_context.lifespan_context, dict) + assert ctx.request_context.lifespan_context["started"] + assert not ctx.request_context.lifespan_context["shutdown"] + return True + + # Run server in background task + async with anyio.create_task_group() as tg: + + async def run_server(): + await server._mcp_server.run( + receive_stream1, + send_stream2, + server._mcp_server.create_initialization_options(), + raise_exceptions=True, + ) + + tg.start_soon(run_server) + + # Initialize the server + params = InitializeRequestParams( + protocolVersion="2024-11-05", + capabilities=ClientCapabilities(), + clientInfo=Implementation(name="test-client", version="0.1.0"), + ) + await send_stream1.send( + JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=TypeAdapter(InitializeRequestParams).dump_python(params), + ) + ) + ) + response = await receive_stream2.receive() + + # Send initialized notification + await send_stream1.send( + JSONRPCMessage( + root=JSONRPCNotification( + jsonrpc="2.0", + method="notifications/initialized", + ) + ) + ) + + # Call the tool to verify lifespan context + await send_stream1.send( + JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params={"name": "check_lifespan", "arguments": {}}, + ) + ) + ) + + # Get response and verify + response = await receive_stream2.receive() + assert response.root.result["content"][0]["text"] == "true" + + # Cancel server task + tg.cancel_scope.cancel() From e5815bd162c490741a3c4769bcbadc77f8961c21 Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Tue, 11 Feb 2025 12:15:29 +0000 Subject: [PATCH 4/6] docs: update README with lifespan examples and usage Add comprehensive documentation for lifespan support: - Add usage examples for both Server and FastMPC classes - Document startup/shutdown patterns - Show context access in tools and handlers - Clean up spacing in test files --- README.md | 56 +++++++++++++++++++++- tests/server/fastmcp/test_func_metadata.py | 22 ++++----- 2 files changed, 66 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 310bb35b5..8de0d7988 100644 --- a/README.md +++ b/README.md @@ -135,6 +135,28 @@ mcp = FastMCP("My App") # Specify dependencies for deployment and development mcp = FastMCP("My App", dependencies=["pandas", "numpy"]) + +# Add lifespan support for startup/shutdown +@asynccontextmanager +async def app_lifespan(server: FastMCP) -> AsyncIterator[dict]: + """Manage application lifecycle""" + try: + # Initialize on startup + await db.connect() + yield {"db": db} + finally: + # Cleanup on shutdown + await db.disconnect() + +# Pass lifespan to server +mcp = FastMCP("My App", lifespan=app_lifespan) + +# Access lifespan context in tools +@mcp.tool() +def query_db(ctx: Context) -> str: + """Tool that uses initialized resources""" + db = ctx.request_context.lifespan_context["db"] + return db.query() ``` ### Resources @@ -334,7 +356,39 @@ def query_data(sql: str) -> str: ### Low-Level Server -For more control, you can use the low-level server implementation directly. This gives you full access to the protocol and allows you to customize every aspect of your server: +For more control, you can use the low-level server implementation directly. This gives you full access to the protocol and allows you to customize every aspect of your server, including lifecycle management through the lifespan API: + +```python +from contextlib import asynccontextmanager +from typing import AsyncIterator + +@asynccontextmanager +async def server_lifespan(server: Server) -> AsyncIterator[dict]: + """Manage server startup and shutdown lifecycle.""" + try: + # Initialize resources on startup + await db.connect() + yield {"db": db} + finally: + # Clean up on shutdown + await db.disconnect() + +# Pass lifespan to server +server = Server("example-server", lifespan=server_lifespan) + +# Access lifespan context in handlers +@server.call_tool() +async def query_db(name: str, arguments: dict) -> list: + ctx = server.request_context + db = ctx.lifespan_context["db"] + return await db.query(arguments["query"]) +``` + +The lifespan API provides: +- A way to initialize resources when the server starts and clean them up when it stops +- Access to initialized resources through the request context in handlers +- Support for both low-level Server and FastMCP classes +- Type-safe context passing between lifespan and request handlers ```python from mcp.server.lowlevel import Server, NotificationOptions diff --git a/tests/server/fastmcp/test_func_metadata.py b/tests/server/fastmcp/test_func_metadata.py index b68fb9025..6461648eb 100644 --- a/tests/server/fastmcp/test_func_metadata.py +++ b/tests/server/fastmcp/test_func_metadata.py @@ -236,7 +236,7 @@ async def check_call(args): def test_complex_function_json_schema(): """Test JSON schema generation for complex function arguments. - + Note: Different versions of pydantic output slightly different JSON Schema formats for model fields with defaults. The format changed in 2.9.0: @@ -245,16 +245,16 @@ def test_complex_function_json_schema(): "allOf": [{"$ref": "#/$defs/Model"}], "default": {} } - + 2. Since 2.9.0: { "$ref": "#/$defs/Model", "default": {} } - + Both formats are valid and functionally equivalent. This test accepts either format to ensure compatibility across our supported pydantic versions. - + This change in format does not affect runtime behavior since: 1. Both schemas validate the same way 2. The actual model classes and validation logic are unchanged @@ -262,17 +262,17 @@ def test_complex_function_json_schema(): """ meta = func_metadata(complex_arguments_fn) actual_schema = meta.arg_model.model_json_schema() - + # Create a copy of the actual schema to normalize normalized_schema = actual_schema.copy() - + # Normalize the my_model_a_with_default field to handle both pydantic formats - if 'allOf' in actual_schema['properties']['my_model_a_with_default']: - normalized_schema['properties']['my_model_a_with_default'] = { - '$ref': '#/$defs/SomeInputModelA', - 'default': {} + if "allOf" in actual_schema["properties"]["my_model_a_with_default"]: + normalized_schema["properties"]["my_model_a_with_default"] = { + "$ref": "#/$defs/SomeInputModelA", + "default": {}, } - + assert normalized_schema == { "$defs": { "InnerModel": { From fddba007230bf9fb94a6e2595f0c67ef2f857815 Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Tue, 11 Feb 2025 12:26:32 +0000 Subject: [PATCH 5/6] refactor: improve server context management with AsyncExitStack Replace nested context managers with AsyncExitStack to ensure proper cleanup order during server shutdown and make the code more maintainable. --- src/mcp/server/lowlevel/server.py | 64 ++++++++++++++++--------------- 1 file changed, 34 insertions(+), 30 deletions(-) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 28942cf8e..a4a8510c4 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -470,36 +470,40 @@ async def run( raise_exceptions: bool = False, ): with warnings.catch_warnings(record=True) as w: - async with self.lifespan(self) as lifespan_context: - async with ServerSession( - read_stream, write_stream, initialization_options - ) as session: - async for message in session.incoming_messages: - logger.debug(f"Received message: {message}") - - match message: - case ( - RequestResponder( - request=types.ClientRequest(root=req) - ) as responder - ): - with responder: - await self._handle_request( - message, - req, - session, - lifespan_context, - raise_exceptions, - ) - case types.ClientNotification(root=notify): - await self._handle_notification(notify) - - for warning in w: - logger.info( - "Warning: %s: %s", - warning.category.__name__, - warning.message, - ) + from contextlib import AsyncExitStack + + async with AsyncExitStack() as stack: + lifespan_context = await stack.enter_async_context(self.lifespan(self)) + session = await stack.enter_async_context( + ServerSession(read_stream, write_stream, initialization_options) + ) + + async for message in session.incoming_messages: + logger.debug(f"Received message: {message}") + + match message: + case ( + RequestResponder( + request=types.ClientRequest(root=req) + ) as responder + ): + with responder: + await self._handle_request( + message, + req, + session, + lifespan_context, + raise_exceptions, + ) + case types.ClientNotification(root=notify): + await self._handle_notification(notify) + + for warning in w: + logger.info( + "Warning: %s: %s", + warning.category.__name__, + warning.message, + ) async def _handle_request( self, From 4d3e05f6f6104e7e189f29ea02068a3f2025324f Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Wed, 12 Feb 2025 22:12:09 +0000 Subject: [PATCH 6/6] refactor: improve lifespan context typing and documentation - Add proper generic parameter for lifespan context type - Update README with TypedDict example for strong typing - Fix context variable initialization in server - Improve property return type safety - Remove redundant documentation - Ensure compatibility with existing tests --- README.md | 17 +++++++++++------ src/mcp/server/fastmcp/server.py | 2 +- src/mcp/server/lowlevel/server.py | 12 ++++++------ src/mcp/shared/context.py | 7 ++++--- 4 files changed, 22 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 8de0d7988..370b4f334 100644 --- a/README.md +++ b/README.md @@ -128,6 +128,9 @@ The [Model Context Protocol (MCP)](https://modelcontextprotocol.io) lets you bui The FastMCP server is your core interface to the MCP protocol. It handles connection management, protocol compliance, and message routing: ```python +# Add lifespan support for startup/shutdown with strong typing +from dataclasses import dataclass +from typing import AsyncIterator from mcp.server.fastmcp import FastMCP # Create a named server @@ -136,14 +139,17 @@ mcp = FastMCP("My App") # Specify dependencies for deployment and development mcp = FastMCP("My App", dependencies=["pandas", "numpy"]) -# Add lifespan support for startup/shutdown +@dataclass +class AppContext: + db: Database # Replace with your actual DB type + @asynccontextmanager -async def app_lifespan(server: FastMCP) -> AsyncIterator[dict]: - """Manage application lifecycle""" +async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]: + """Manage application lifecycle with type-safe context""" try: # Initialize on startup await db.connect() - yield {"db": db} + yield AppContext(db=db) finally: # Cleanup on shutdown await db.disconnect() @@ -151,7 +157,7 @@ async def app_lifespan(server: FastMCP) -> AsyncIterator[dict]: # Pass lifespan to server mcp = FastMCP("My App", lifespan=app_lifespan) -# Access lifespan context in tools +# Access type-safe lifespan context in tools @mcp.tool() def query_db(ctx: Context) -> str: """Tool that uses initialized resources""" @@ -387,7 +393,6 @@ async def query_db(name: str, arguments: dict) -> list: The lifespan API provides: - A way to initialize resources when the server starts and clean them up when it stops - Access to initialized resources through the request context in handlers -- Support for both low-level Server and FastMCP classes - Type-safe context passing between lifespan and request handlers ```python diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index bc341b404..5ae30a5ca 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -100,7 +100,7 @@ class Settings(BaseSettings, Generic[LifespanResultT]): lifespan: ( Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]] | None - ) = Field(None, description="Lifespan contexte manager") + ) = Field(None, description="Lifespan context manager") def lifespan_wrapper( diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index a4a8510c4..643e1a272 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -85,7 +85,10 @@ async def main(): logger = logging.getLogger(__name__) -request_ctx: contextvars.ContextVar[RequestContext[ServerSession]] = ( +LifespanResultT = TypeVar("LifespanResultT") + +# This will be properly typed in each Server instance's context +request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any]] = ( contextvars.ContextVar("request_ctx") ) @@ -102,9 +105,6 @@ def __init__( self.tools_changed = tools_changed -LifespanResultT = TypeVar("LifespanResultT") - - @asynccontextmanager async def lifespan(server: "Server") -> AsyncIterator[object]: """Default lifespan context manager that does nothing. @@ -212,7 +212,7 @@ def get_capabilities( ) @property - def request_context(self) -> RequestContext[ServerSession]: + def request_context(self) -> RequestContext[ServerSession, LifespanResultT]: """If called outside of a request context, this will raise a LookupError.""" return request_ctx.get() @@ -510,7 +510,7 @@ async def _handle_request( message: RequestResponder, req: Any, session: ServerSession, - lifespan_context: object, + lifespan_context: LifespanResultT, raise_exceptions: bool, ): logger.info(f"Processing request of type {type(req).__name__}") diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index 50e5d5194..a45fdacd4 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -1,15 +1,16 @@ from dataclasses import dataclass -from typing import Any, Generic, TypeVar +from typing import Generic, TypeVar from mcp.shared.session import BaseSession from mcp.types import RequestId, RequestParams SessionT = TypeVar("SessionT", bound=BaseSession) +LifespanContextT = TypeVar("LifespanContextT") @dataclass -class RequestContext(Generic[SessionT]): +class RequestContext(Generic[SessionT, LifespanContextT]): request_id: RequestId meta: RequestParams.Meta | None session: SessionT - lifespan_context: Any + lifespan_context: LifespanContextT