diff --git a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py index bbf3dc64c..2b3c2aa78 100644 --- a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py +++ b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py @@ -43,7 +43,12 @@ def main( @app.call_tool() async def call_tool( name: str, arguments: dict - ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: + ) -> list[ + types.TextContent + | types.ImageContent + | types.AudioContent + | types.EmbeddedResource + ]: ctx = app.request_context interval = arguments.get("interval", 1.0) count = arguments.get("count", 5) diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index bf6f51e5c..5115c1251 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -47,7 +47,12 @@ def main( @app.call_tool() async def call_tool( name: str, arguments: dict - ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: + ) -> list[ + types.TextContent + | types.ImageContent + | types.AudioContent + | types.EmbeddedResource + ]: ctx = app.request_context interval = arguments.get("interval", 1.0) count = arguments.get("count", 5) diff --git a/examples/servers/simple-tool/mcp_simple_tool/server.py b/examples/servers/simple-tool/mcp_simple_tool/server.py index cd574ad5e..46f9bbf7a 100644 --- a/examples/servers/simple-tool/mcp_simple_tool/server.py +++ b/examples/servers/simple-tool/mcp_simple_tool/server.py @@ -7,7 +7,9 @@ async def fetch_website( url: str, -) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: +) -> list[ + types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource +]: headers = { "User-Agent": "MCP Test Server (github.com/modelcontextprotocol/python-sdk)" } @@ -31,7 +33,12 @@ def main(port: int, transport: str) -> int: @app.call_tool() async def fetch_tool( name: str, arguments: dict - ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: + ) -> list[ + types.TextContent + | types.ImageContent + | types.AudioContent + | types.EmbeddedResource + ]: if name != "fetch": raise ValueError(f"Unknown tool: {name}") if "url" not in arguments: diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 2855f606d..7cf4278cc 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -195,7 +195,8 @@ async def handle_get_stream( self.url, headers=headers, timeout=httpx.Timeout( - self.timeout.seconds, read=self.sse_read_timeout.seconds + self.timeout.total_seconds(), + read=self.sse_read_timeout.total_seconds(), ), ) as event_source: event_source.response.raise_for_status() @@ -226,7 +227,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: self.url, headers=headers, timeout=httpx.Timeout( - self.timeout.seconds, read=ctx.sse_read_timeout.seconds + self.timeout.total_seconds(), read=ctx.sse_read_timeout.total_seconds() ), ) as event_source: event_source.response.raise_for_status() @@ -468,7 +469,8 @@ async def streamablehttp_client( async with httpx_client_factory( headers=transport.request_headers, timeout=httpx.Timeout( - transport.timeout.seconds, read=transport.sse_read_timeout.seconds + transport.timeout.total_seconds(), + read=transport.sse_read_timeout.total_seconds(), ), auth=transport.auth, ) as client: diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index 8f3768908..f03afd566 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from typing import Any, Literal -from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError +from pydantic import AnyUrl, BaseModel, Field, RootModel, ValidationError from starlette.datastructures import FormData, QueryParams from starlette.requests import Request from starlette.responses import RedirectResponse, Response @@ -29,7 +29,7 @@ class AuthorizationRequest(BaseModel): # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1 client_id: str = Field(..., description="The client ID") - redirect_uri: AnyHttpUrl | None = Field( + redirect_uri: AnyUrl | None = Field( None, description="URL to redirect to after authorization" ) @@ -68,8 +68,8 @@ def best_effort_extract_string( return None -class AnyHttpUrlModel(RootModel[AnyHttpUrl]): - root: AnyHttpUrl +class AnyUrlModel(RootModel[AnyUrl]): + root: AnyUrl @dataclass @@ -116,7 +116,7 @@ async def error_response( if params is not None and "redirect_uri" not in params: raw_redirect_uri = None else: - raw_redirect_uri = AnyHttpUrlModel.model_validate( + raw_redirect_uri = AnyUrlModel.model_validate( best_effort_extract_string("redirect_uri", params) ).root redirect_uri = client.validate_redirect_uri(raw_redirect_uri) diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 94a5c4de3..abea2bd41 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from typing import Annotated, Any, Literal -from pydantic import AnyHttpUrl, BaseModel, Field, RootModel, ValidationError +from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError from starlette.requests import Request from mcp.server.auth.errors import ( @@ -27,7 +27,7 @@ class AuthorizationCodeRequest(BaseModel): # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3 grant_type: Literal["authorization_code"] code: str = Field(..., description="The authorization code") - redirect_uri: AnyHttpUrl | None = Field( + redirect_uri: AnyUrl | None = Field( None, description="Must be the same as redirect URI provided in /authorize" ) client_id: str diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index be1ac1dbc..9f107f71b 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -2,7 +2,7 @@ from typing import Generic, Literal, Protocol, TypeVar from urllib.parse import parse_qs, urlencode, urlparse, urlunparse -from pydantic import AnyHttpUrl, BaseModel +from pydantic import AnyUrl, BaseModel from mcp.shared.auth import ( OAuthClientInformationFull, @@ -14,7 +14,7 @@ class AuthorizationParams(BaseModel): state: str | None scopes: list[str] | None code_challenge: str - redirect_uri: AnyHttpUrl + redirect_uri: AnyUrl redirect_uri_provided_explicitly: bool @@ -24,7 +24,7 @@ class AuthorizationCode(BaseModel): expires_at: float client_id: str code_challenge: str - redirect_uri: AnyHttpUrl + redirect_uri: AnyUrl redirect_uri_provided_explicitly: bool diff --git a/src/mcp/server/fastmcp/prompts/base.py b/src/mcp/server/fastmcp/prompts/base.py index aa3d1eac9..33bf68025 100644 --- a/src/mcp/server/fastmcp/prompts/base.py +++ b/src/mcp/server/fastmcp/prompts/base.py @@ -7,9 +7,9 @@ import pydantic_core from pydantic import BaseModel, Field, TypeAdapter, validate_call -from mcp.types import EmbeddedResource, ImageContent, TextContent +from mcp.types import AudioContent, EmbeddedResource, ImageContent, TextContent -CONTENT_TYPES = TextContent | ImageContent | EmbeddedResource +CONTENT_TYPES = TextContent | ImageContent | AudioContent | EmbeddedResource class Message(BaseModel): diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index e5b6c3acc..577ed1b9a 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -52,6 +52,7 @@ from mcp.shared.context import LifespanContextT, RequestContext, RequestT from mcp.types import ( AnyFunction, + AudioContent, EmbeddedResource, GetPromptResult, ImageContent, @@ -275,7 +276,7 @@ def get_context(self) -> Context[ServerSession, object, Request]: async def call_tool( self, name: str, arguments: dict[str, Any] - ) -> Sequence[TextContent | ImageContent | EmbeddedResource]: + ) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]: """Call a tool by name with arguments.""" context = self.get_context() result = await self._tool_manager.call_tool(name, arguments, context=context) @@ -875,12 +876,12 @@ async def get_prompt( def _convert_to_content( result: Any, -) -> Sequence[TextContent | ImageContent | EmbeddedResource]: +) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]: """Convert a result to a sequence of content objects.""" if result is None: return [] - if isinstance(result, TextContent | ImageContent | EmbeddedResource): + if isinstance(result, TextContent | ImageContent | AudioContent | EmbeddedResource): return [result] if isinstance(result, Image): diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index b98e3dd1a..882f5acc3 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -147,7 +147,7 @@ def __init__( } self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} self.notification_options = NotificationOptions() - logger.debug(f"Initializing server '{name}'") + logger.debug("Initializing server %r", name) def create_initialization_options( self, @@ -405,7 +405,10 @@ def decorator( ..., Awaitable[ Iterable[ - types.TextContent | types.ImageContent | types.EmbeddedResource + types.TextContent + | types.ImageContent + | types.AudioContent + | types.EmbeddedResource ] ], ], @@ -510,7 +513,7 @@ async def run( async with anyio.create_task_group() as tg: async for message in session.incoming_messages: - logger.debug(f"Received message: {message}") + logger.debug("Received message: %s", message) tg.start_soon( self._handle_message, @@ -543,7 +546,9 @@ async def _handle_message( await self._handle_notification(notify) for warning in w: - logger.info(f"Warning: {warning.category.__name__}: {warning.message}") + logger.info( + "Warning: %s: %s", warning.category.__name__, warning.message + ) async def _handle_request( self, @@ -553,10 +558,9 @@ async def _handle_request( lifespan_context: LifespanResultT, raise_exceptions: bool, ): - logger.info(f"Processing request of type {type(req).__name__}") - if type(req) in self.request_handlers: - handler = self.request_handlers[type(req)] - logger.debug(f"Dispatching request of type {type(req).__name__}") + logger.info("Processing request of type %s", type(req).__name__) + if handler := self.request_handlers.get(type(req)): # type: ignore + logger.debug("Dispatching request of type %s", type(req).__name__) token = None try: @@ -602,16 +606,13 @@ async def _handle_request( logger.debug("Response sent") async def _handle_notification(self, notify: Any): - if type(notify) in self.notification_handlers: - assert type(notify) in self.notification_handlers - - handler = self.notification_handlers[type(notify)] - logger.debug(f"Dispatching notification of type {type(notify).__name__}") + if handler := self.notification_handlers.get(type(notify)): # type: ignore + logger.debug("Dispatching notification of type %s", type(notify).__name__) try: await handler(notify) - except Exception as err: - logger.error(f"Uncaught exception in notification handler: {err}") + except Exception: + logger.exception("Uncaught exception in notification handler") async def _ping_handler(request: types.PingRequest) -> types.ServerResult: diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index 22f8a971d..1c988a5e2 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -1,6 +1,6 @@ from typing import Any, Literal -from pydantic import AnyHttpUrl, BaseModel, Field +from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field class OAuthToken(BaseModel): @@ -32,7 +32,7 @@ class OAuthClientMetadata(BaseModel): for the full specification. """ - redirect_uris: list[AnyHttpUrl] = Field(..., min_length=1) + redirect_uris: list[AnyUrl] = Field(..., min_length=1) # token_endpoint_auth_method: this implementation only supports none & # client_secret_post; # ie: we do not support client_secret_basic @@ -71,7 +71,7 @@ def validate_scope(self, requested_scope: str | None) -> list[str] | None: raise InvalidScopeError(f"Client was not registered with scope {scope}") return requested_scopes - def validate_redirect_uri(self, redirect_uri: AnyHttpUrl | None) -> AnyHttpUrl: + def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl: if redirect_uri is not None: # Validate redirect_uri against client's registered redirect URIs if redirect_uri not in self.redirect_uris: diff --git a/src/mcp/types.py b/src/mcp/types.py index 4f5af27b9..c5076ee5b 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -350,12 +350,12 @@ class ProgressNotificationParams(NotificationParams): total is unknown. """ total: float | None = None + """Total number of items to process (or total progress required), if known.""" + message: str | None = None """ Message related to progress. This should provide relevant human readable progress information. """ - message: str | None = None - """Total number of items to process (or total progress required), if known.""" model_config = ConfigDict(extra="allow") @@ -657,11 +657,26 @@ class ImageContent(BaseModel): model_config = ConfigDict(extra="allow") +class AudioContent(BaseModel): + """Audio content for a message.""" + + type: Literal["audio"] + data: str + """The base64-encoded audio data.""" + mimeType: str + """ + The MIME type of the audio. Different providers may support different + audio types. + """ + annotations: Annotations | None = None + model_config = ConfigDict(extra="allow") + + class SamplingMessage(BaseModel): """Describes a message issued to or received from an LLM API.""" role: Role - content: TextContent | ImageContent + content: TextContent | ImageContent | AudioContent model_config = ConfigDict(extra="allow") @@ -683,7 +698,7 @@ class PromptMessage(BaseModel): """Describes a message returned as part of a prompt.""" role: Role - content: TextContent | ImageContent | EmbeddedResource + content: TextContent | ImageContent | AudioContent | EmbeddedResource model_config = ConfigDict(extra="allow") @@ -801,7 +816,7 @@ class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]): class CallToolResult(Result): """The server's response to a tool call.""" - content: list[TextContent | ImageContent | EmbeddedResource] + content: list[TextContent | ImageContent | AudioContent | EmbeddedResource] isError: bool = False @@ -965,7 +980,7 @@ class CreateMessageResult(Result): """The client's response to a sampling/create_message request from the server.""" role: Role - content: TextContent | ImageContent + content: TextContent | ImageContent | AudioContent model: str """The name of the model that generated the message.""" stopReason: StopReason | None = None diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 2edaff946..c663bddcc 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -11,7 +11,7 @@ import httpx import pytest from inline_snapshot import snapshot -from pydantic import AnyHttpUrl +from pydantic import AnyHttpUrl, AnyUrl from mcp.client.auth import OAuthClientProvider from mcp.server.auth.routes import build_metadata @@ -52,7 +52,7 @@ def mock_storage(): @pytest.fixture def client_metadata(): return OAuthClientMetadata( - redirect_uris=[AnyHttpUrl("http://localhost:3000/callback")], + redirect_uris=[AnyUrl("http://localhost:3000/callback")], client_name="Test Client", grant_types=["authorization_code", "refresh_token"], response_types=["code"], @@ -79,7 +79,7 @@ def oauth_client_info(): return OAuthClientInformationFull( client_id="test_client_id", client_secret="test_client_secret", - redirect_uris=[AnyHttpUrl("http://localhost:3000/callback")], + redirect_uris=[AnyUrl("http://localhost:3000/callback")], client_name="Test Client", grant_types=["authorization_code", "refresh_token"], response_types=["code"], diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py index 88e41d66d..9b21e4ba1 100644 --- a/tests/issues/test_88_random_error.py +++ b/tests/issues/test_88_random_error.py @@ -12,6 +12,7 @@ from mcp.server.lowlevel import Server from mcp.shared.exceptions import McpError from mcp.types import ( + AudioContent, EmbeddedResource, ImageContent, TextContent, @@ -37,7 +38,7 @@ async def test_notification_validation_error(tmp_path: Path): @server.call_tool() async def slow_tool( name: str, arg - ) -> Sequence[TextContent | ImageContent | EmbeddedResource]: + ) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]: nonlocal request_count request_count += 1 diff --git a/tests/server/fastmcp/test_server.py b/tests/server/fastmcp/test_server.py index b817761ea..71cad7e68 100644 --- a/tests/server/fastmcp/test_server.py +++ b/tests/server/fastmcp/test_server.py @@ -16,6 +16,7 @@ create_connected_server_and_client_session as client_session, ) from mcp.types import ( + AudioContent, BlobResourceContents, ImageContent, TextContent, @@ -207,10 +208,11 @@ def image_tool_fn(path: str) -> Image: return Image(path) -def mixed_content_tool_fn() -> list[TextContent | ImageContent]: +def mixed_content_tool_fn() -> list[TextContent | ImageContent | AudioContent]: return [ TextContent(type="text", text="Hello"), ImageContent(type="image", data="abc", mimeType="image/png"), + AudioContent(type="audio", data="def", mimeType="audio/wav"), ] @@ -312,14 +314,16 @@ async def test_tool_mixed_content(self): mcp.add_tool(mixed_content_tool_fn) async with client_session(mcp._mcp_server) as client: result = await client.call_tool("mixed_content_tool_fn", {}) - assert len(result.content) == 2 - content1 = result.content[0] - content2 = result.content[1] + assert len(result.content) == 3 + content1, content2, content3 = result.content assert isinstance(content1, TextContent) assert content1.text == "Hello" assert isinstance(content2, ImageContent) assert content2.mimeType == "image/png" assert content2.data == "abc" + assert isinstance(content3, AudioContent) + assert content3.mimeType == "audio/wav" + assert content3.data == "def" @pytest.mark.anyio async def test_tool_mixed_list_with_image(self, tmp_path: Path):