From 60f8a5410961b124bbb8acdff8dfbd870e10af9b Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Thu, 20 Mar 2025 17:42:32 +0000 Subject: [PATCH] refactor: Make types.py strictly typechecked. We are now making the types module typecheck in strict mode. This is mostly achieved by passing the correct types to the generic type variables. We ran into one specific issue with `JSONRPCRequest` and `JSONRPCNotification`. Both are generic classes that take a `dict[str, Any]` as params and just a plain string as method. However, since the TypeVar for `RequestT` and `NotificationT` are bound to `RequestParams` and `NotificationParams` respectively we get into a type issue. There are two ways of solving this: 1. Widen the bound by allowing explicitly for `dict[str, Any]` 2. Make JSONRPCRequest and JSONRPCNotificaiton not part of the type hierarchy with Request and Notification roots. It felt most naturally to keep JSONRPCRequest/JSONRPCNotification part of the type hierarchy and allow for general passing of dict[str, Any]. This now typechecks. --- pyproject.toml | 1 - src/mcp/types.py | 91 ++++++++++++++++++++++++++++------------ tests/shared/test_sse.py | 4 +- tests/shared/test_ws.py | 4 +- 4 files changed, 66 insertions(+), 34 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e400ad7d8..d014bf0c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,7 +77,6 @@ include = ["src/mcp", "tests"] venvPath = "." venv = ".venv" strict = ["src/mcp/**/*.py"] -exclude = ["src/mcp/types.py"] [tool.ruff.lint] select = ["E", "F", "I", "UP"] diff --git a/src/mcp/types.py b/src/mcp/types.py index f043fb10a..4ef111069 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -64,8 +64,10 @@ class Meta(BaseModel): """ -RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams) -NotificationParamsT = TypeVar("NotificationParamsT", bound=NotificationParams) +RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams | dict[str, Any] | None) +NotificationParamsT = TypeVar( + "NotificationParamsT", bound=NotificationParams | dict[str, Any] | None +) MethodT = TypeVar("MethodT", bound=str) @@ -113,15 +115,16 @@ class PaginatedResult(Result): """ -class JSONRPCRequest(Request): +class JSONRPCRequest(Request[dict[str, Any] | None, str]): """A request that expects a response.""" jsonrpc: Literal["2.0"] id: RequestId + method: str params: dict[str, Any] | None = None -class JSONRPCNotification(Notification): +class JSONRPCNotification(Notification[dict[str, Any] | None, str]): """A notification which does not expect a response.""" jsonrpc: Literal["2.0"] @@ -277,7 +280,7 @@ class InitializeRequestParams(RequestParams): model_config = ConfigDict(extra="allow") -class InitializeRequest(Request): +class InitializeRequest(Request[InitializeRequestParams, Literal["initialize"]]): """ This request is sent from the client to the server when it first connects, asking it to begin initialization. @@ -298,7 +301,9 @@ class InitializeResult(Result): """Instructions describing how to use the server and its features.""" -class InitializedNotification(Notification): +class InitializedNotification( + Notification[NotificationParams | None, Literal["notifications/initialized"]] +): """ This notification is sent from the client to the server after initialization has finished. @@ -308,7 +313,7 @@ class InitializedNotification(Notification): params: NotificationParams | None = None -class PingRequest(Request): +class PingRequest(Request[RequestParams | None, Literal["ping"]]): """ A ping, issued by either the server or the client, to check that the other party is still alive. @@ -336,7 +341,9 @@ class ProgressNotificationParams(NotificationParams): model_config = ConfigDict(extra="allow") -class ProgressNotification(Notification): +class ProgressNotification( + Notification[ProgressNotificationParams, Literal["notifications/progress"]] +): """ An out-of-band notification used to inform the receiver of a progress update for a long-running request. @@ -346,7 +353,9 @@ class ProgressNotification(Notification): params: ProgressNotificationParams -class ListResourcesRequest(PaginatedRequest): +class ListResourcesRequest( + PaginatedRequest[RequestParams | None, Literal["resources/list"]] +): """Sent from the client to request a list of resources the server has.""" method: Literal["resources/list"] @@ -408,7 +417,9 @@ class ListResourcesResult(PaginatedResult): resources: list[Resource] -class ListResourceTemplatesRequest(PaginatedRequest): +class ListResourceTemplatesRequest( + PaginatedRequest[RequestParams | None, Literal["resources/templates/list"]] +): """Sent from the client to request a list of resource templates the server has.""" method: Literal["resources/templates/list"] @@ -432,7 +443,9 @@ class ReadResourceRequestParams(RequestParams): model_config = ConfigDict(extra="allow") -class ReadResourceRequest(Request): +class ReadResourceRequest( + Request[ReadResourceRequestParams, Literal["resources/read"]] +): """Sent from the client to the server, to read a specific resource URI.""" method: Literal["resources/read"] @@ -472,7 +485,11 @@ class ReadResourceResult(Result): contents: list[TextResourceContents | BlobResourceContents] -class ResourceListChangedNotification(Notification): +class ResourceListChangedNotification( + Notification[ + NotificationParams | None, Literal["notifications/resources/list_changed"] + ] +): """ An optional notification from the server to the client, informing it that the list of resources it can read from has changed. @@ -493,7 +510,7 @@ class SubscribeRequestParams(RequestParams): model_config = ConfigDict(extra="allow") -class SubscribeRequest(Request): +class SubscribeRequest(Request[SubscribeRequestParams, Literal["resources/subscribe"]]): """ Sent from the client to request resources/updated notifications from the server whenever a particular resource changes. @@ -511,7 +528,9 @@ class UnsubscribeRequestParams(RequestParams): model_config = ConfigDict(extra="allow") -class UnsubscribeRequest(Request): +class UnsubscribeRequest( + Request[UnsubscribeRequestParams, Literal["resources/unsubscribe"]] +): """ Sent from the client to request cancellation of resources/updated notifications from the server. @@ -532,7 +551,11 @@ class ResourceUpdatedNotificationParams(NotificationParams): model_config = ConfigDict(extra="allow") -class ResourceUpdatedNotification(Notification): +class ResourceUpdatedNotification( + Notification[ + ResourceUpdatedNotificationParams, Literal["notifications/resources/updated"] + ] +): """ A notification from the server to the client, informing it that a resource has changed and may need to be read again. @@ -542,7 +565,9 @@ class ResourceUpdatedNotification(Notification): params: ResourceUpdatedNotificationParams -class ListPromptsRequest(PaginatedRequest): +class ListPromptsRequest( + PaginatedRequest[RequestParams | None, Literal["prompts/list"]] +): """Sent from the client to request a list of prompts and prompt templates.""" method: Literal["prompts/list"] @@ -589,7 +614,7 @@ class GetPromptRequestParams(RequestParams): model_config = ConfigDict(extra="allow") -class GetPromptRequest(Request): +class GetPromptRequest(Request[GetPromptRequestParams, Literal["prompts/get"]]): """Used by the client to get a prompt provided by the server.""" method: Literal["prompts/get"] @@ -659,7 +684,11 @@ class GetPromptResult(Result): messages: list[PromptMessage] -class PromptListChangedNotification(Notification): +class PromptListChangedNotification( + Notification[ + NotificationParams | None, Literal["notifications/prompts/list_changed"] + ] +): """ An optional notification from the server to the client, informing it that the list of prompts it offers has changed. @@ -669,7 +698,7 @@ class PromptListChangedNotification(Notification): params: NotificationParams | None = None -class ListToolsRequest(PaginatedRequest): +class ListToolsRequest(PaginatedRequest[RequestParams | None, Literal["tools/list"]]): """Sent from the client to request a list of tools the server has.""" method: Literal["tools/list"] @@ -702,7 +731,7 @@ class CallToolRequestParams(RequestParams): model_config = ConfigDict(extra="allow") -class CallToolRequest(Request): +class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]): """Used by the client to invoke a tool provided by the server.""" method: Literal["tools/call"] @@ -716,7 +745,9 @@ class CallToolResult(Result): isError: bool = False -class ToolListChangedNotification(Notification): +class ToolListChangedNotification( + Notification[NotificationParams | None, Literal["notifications/tools/list_changed"]] +): """ An optional notification from the server to the client, informing it that the list of tools it offers has changed. @@ -739,7 +770,7 @@ class SetLevelRequestParams(RequestParams): model_config = ConfigDict(extra="allow") -class SetLevelRequest(Request): +class SetLevelRequest(Request[SetLevelRequestParams, Literal["logging/setLevel"]]): """A request from the client to the server, to enable or adjust logging.""" method: Literal["logging/setLevel"] @@ -761,7 +792,9 @@ class LoggingMessageNotificationParams(NotificationParams): model_config = ConfigDict(extra="allow") -class LoggingMessageNotification(Notification): +class LoggingMessageNotification( + Notification[LoggingMessageNotificationParams, Literal["notifications/message"]] +): """Notification of a log message passed from server to client.""" method: Literal["notifications/message"] @@ -856,7 +889,9 @@ class CreateMessageRequestParams(RequestParams): model_config = ConfigDict(extra="allow") -class CreateMessageRequest(Request): +class CreateMessageRequest( + Request[CreateMessageRequestParams, Literal["sampling/createMessage"]] +): """A request from the server to sample an LLM via the client.""" method: Literal["sampling/createMessage"] @@ -913,7 +948,7 @@ class CompleteRequestParams(RequestParams): model_config = ConfigDict(extra="allow") -class CompleteRequest(Request): +class CompleteRequest(Request[CompleteRequestParams, Literal["completion/complete"]]): """A request from the client to the server, to ask for completion options.""" method: Literal["completion/complete"] @@ -944,7 +979,7 @@ class CompleteResult(Result): completion: Completion -class ListRootsRequest(Request): +class ListRootsRequest(Request[RequestParams | None, Literal["roots/list"]]): """ Sent from the server to request a list of root URIs from the client. Roots allow servers to ask for specific directories or files to operate on. A common example @@ -987,7 +1022,9 @@ class ListRootsResult(Result): roots: list[Root] -class RootsListChangedNotification(Notification): +class RootsListChangedNotification( + Notification[NotificationParams | None, Literal["notifications/roots/list_changed"]] +): """ A notification from the client to the server, informing it that the list of roots has changed. diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 43107b597..f5158c3c3 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -138,9 +138,7 @@ def server(server_port: int) -> Generator[None, None, None]: time.sleep(0.1) attempt += 1 else: - raise RuntimeError( - f"Server failed to start after {max_attempts} attempts" - ) + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") yield diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py index 2aca97e15..1381c8153 100644 --- a/tests/shared/test_ws.py +++ b/tests/shared/test_ws.py @@ -134,9 +134,7 @@ def server(server_port: int) -> Generator[None, None, None]: time.sleep(0.1) attempt += 1 else: - raise RuntimeError( - f"Server failed to start after {max_attempts} attempts" - ) + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") yield