Skip to content

Commit cfab842

Browse files
committed
Move around validation logic
1 parent c2c3c54 commit cfab842

File tree

7 files changed

+76
-70
lines changed

7 files changed

+76
-70
lines changed

src/mcp/server/auth/errors.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,22 +28,6 @@ def error_response(self) -> ErrorResponse:
2828
)
2929

3030

31-
class InvalidRequestError(OAuthError):
32-
"""
33-
Invalid request error.
34-
"""
35-
36-
error_code = "invalid_request"
37-
38-
39-
class InvalidClientError(OAuthError):
40-
"""
41-
Invalid client error.
42-
"""
43-
44-
error_code = "invalid_client"
45-
46-
4731
def stringify_pydantic_error(validation_error: ValidationError) -> str:
4832
return "\n".join(
4933
f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}"

src/mcp/server/auth/handlers/authorize.py

Lines changed: 12 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from starlette.responses import RedirectResponse, Response
1010

1111
from mcp.server.auth.errors import (
12-
InvalidRequestError,
1312
OAuthError,
1413
stringify_pydantic_error,
1514
)
@@ -19,7 +18,10 @@
1918
OAuthServerProvider,
2019
construct_redirect_uri,
2120
)
22-
from mcp.shared.auth import OAuthClientInformationFull
21+
from mcp.shared.auth import (
22+
InvalidRedirectUriError,
23+
InvalidScopeError,
24+
)
2325

2426
logger = logging.getLogger(__name__)
2527

@@ -66,37 +68,6 @@ class AuthorizationErrorResponse(BaseModel):
6668
state: str | None = None
6769

6870

69-
def validate_scope(
70-
requested_scope: str | None, client: OAuthClientInformationFull
71-
) -> list[str] | None:
72-
if requested_scope is None:
73-
return None
74-
requested_scopes = requested_scope.split(" ")
75-
allowed_scopes = [] if client.scope is None else client.scope.split(" ")
76-
for scope in requested_scopes:
77-
if scope not in allowed_scopes:
78-
raise InvalidRequestError(f"Client was not registered with scope {scope}")
79-
return requested_scopes
80-
81-
82-
def validate_redirect_uri(
83-
redirect_uri: AnyHttpUrl | None, client: OAuthClientInformationFull
84-
) -> AnyHttpUrl:
85-
if redirect_uri is not None:
86-
# Validate redirect_uri against client's registered redirect URIs
87-
if redirect_uri not in client.redirect_uris:
88-
raise InvalidRequestError(
89-
f"Redirect URI '{redirect_uri}' not registered for client"
90-
)
91-
return redirect_uri
92-
elif len(client.redirect_uris) == 1:
93-
return client.redirect_uris[0]
94-
else:
95-
raise InvalidRequestError(
96-
"redirect_uri must be specified when client has multiple registered URIs"
97-
)
98-
99-
10071
def best_effort_extract_string(
10172
key: str, params: None | FormData | QueryParams
10273
) -> str | None:
@@ -146,8 +117,8 @@ async def error_response(
146117
best_effort_extract_string("redirect_uri", params)
147118
).root
148119
try:
149-
redirect_uri = validate_redirect_uri(raw_redirect_uri, client)
150-
except (ValidationError, InvalidRequestError):
120+
redirect_uri = client.validate_redirect_uri(raw_redirect_uri)
121+
except (ValidationError, InvalidRedirectUriError):
151122
pass
152123
if state is None:
153124
# make last-ditch effort to load state
@@ -213,22 +184,22 @@ async def error_response(
213184

214185
# Validate redirect_uri against client's registered URIs
215186
try:
216-
redirect_uri = validate_redirect_uri(auth_request.redirect_uri, client)
217-
except InvalidRequestError as validation_error:
187+
redirect_uri = client.validate_redirect_uri(auth_request.redirect_uri)
188+
except InvalidRedirectUriError as validation_error:
218189
# For redirect_uri validation errors, return direct error (no redirect)
219190
return await error_response(
220191
error="invalid_request",
221-
error_description=validation_error.error_description,
192+
error_description=validation_error.message,
222193
)
223194

224195
# Validate scope - for scope errors, we can redirect
225196
try:
226-
scopes = validate_scope(auth_request.scope, client)
227-
except InvalidRequestError as validation_error:
197+
scopes = client.validate_scope(auth_request.scope)
198+
except InvalidScopeError as validation_error:
228199
# For scope errors, redirect with error parameters
229200
return await error_response(
230201
error="invalid_scope",
231-
error_description=validation_error.error_description,
202+
error_description=validation_error.message,
232203
)
233204

234205
# Setup authorization parameters

src/mcp/server/auth/handlers/revoke.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
from starlette.responses import Response
88

99
from mcp.server.auth.errors import (
10-
InvalidClientError,
1110
stringify_pydantic_error,
1211
)
1312
from mcp.server.auth.json_response import PydanticJSONResponse
1413
from mcp.server.auth.middleware.client_auth import (
14+
AuthenticationError,
1515
ClientAuthenticator,
1616
)
1717
from mcp.server.auth.provider import AuthInfo, OAuthServerProvider, RefreshToken
@@ -29,7 +29,7 @@ class RevocationRequest(BaseModel):
2929

3030

3131
class RevocationErrorResponse(BaseModel):
32-
error: Literal["invalid_request",]
32+
error: Literal["invalid_request", "unauthorized_client"]
3333
error_description: str | None = None
3434

3535

@@ -59,8 +59,14 @@ async def handle(self, request: Request) -> Response:
5959
client = await self.client_authenticator.authenticate(
6060
revocation_request.client_id, revocation_request.client_secret
6161
)
62-
except InvalidClientError as e:
63-
return PydanticJSONResponse(status_code=401, content=e.error_response())
62+
except AuthenticationError as e:
63+
return PydanticJSONResponse(
64+
status_code=401,
65+
content=RevocationErrorResponse(
66+
error="unauthorized_client",
67+
error_description=e.message,
68+
),
69+
)
6470

6571
loaders = [
6672
self.provider.load_access_token,

src/mcp/server/auth/handlers/token.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99

1010
from mcp.server.auth.errors import (
1111
ErrorResponse,
12-
InvalidClientError,
1312
stringify_pydantic_error,
1413
)
1514
from mcp.server.auth.json_response import PydanticJSONResponse
1615
from mcp.server.auth.middleware.client_auth import (
16+
AuthenticationError,
1717
ClientAuthenticator,
1818
)
1919
from mcp.server.auth.provider import OAuthServerProvider
@@ -111,8 +111,13 @@ async def handle(self, request: Request):
111111
client_id=token_request.client_id,
112112
client_secret=token_request.client_secret,
113113
)
114-
except InvalidClientError as e:
115-
return self.response(e.error_response())
114+
except AuthenticationError as e:
115+
return self.response(
116+
TokenErrorResponse(
117+
error="unauthorized_client",
118+
error_description=e.message,
119+
)
120+
)
116121

117122
if token_request.grant_type not in client_info.grant_types:
118123
return self.response(

src/mcp/server/auth/middleware/client_auth.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
import time
22

3-
from mcp.server.auth.errors import InvalidClientError
43
from mcp.server.auth.provider import OAuthRegisteredClientsStore
54
from mcp.shared.auth import OAuthClientInformationFull
65

76

7+
class AuthenticationError(Exception):
8+
def __init__(self, message: str):
9+
self.message = message
10+
11+
812
class ClientAuthenticator:
913
"""
1014
ClientAuthenticator is a callable which validates requests from a client
@@ -31,21 +35,21 @@ async def authenticate(
3135
# Look up client information
3236
client = await self.clients_store.get_client(client_id)
3337
if not client:
34-
raise InvalidClientError("Invalid client_id")
38+
raise AuthenticationError("Invalid client_id")
3539

3640
# If client from the store expects a secret, validate that the request provides
3741
# that secret
3842
if client.client_secret:
3943
if not client_secret:
40-
raise InvalidClientError("Client secret is required")
44+
raise AuthenticationError("Client secret is required")
4145

4246
if client.client_secret != client_secret:
43-
raise InvalidClientError("Invalid client_secret")
47+
raise AuthenticationError("Invalid client_secret")
4448

4549
if (
4650
client.client_secret_expires_at
4751
and client.client_secret_expires_at < int(time.time())
4852
):
49-
raise InvalidClientError("Client secret has expired")
53+
raise AuthenticationError("Client secret has expired")
5054

5155
return client

src/mcp/shared/auth.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,16 @@ class OAuthToken(BaseModel):
1515
refresh_token: str | None = None
1616

1717

18+
class InvalidScopeError(Exception):
19+
def __init__(self, message: str):
20+
self.message = message
21+
22+
23+
class InvalidRedirectUriError(Exception):
24+
def __init__(self, message: str):
25+
self.message = message
26+
27+
1828
class OAuthClientMetadata(BaseModel):
1929
"""
2030
RFC 7591 OAuth 2.0 Dynamic Client Registration metadata.
@@ -50,6 +60,32 @@ class OAuthClientMetadata(BaseModel):
5060
software_id: str | None = None
5161
software_version: str | None = None
5262

63+
def validate_scope(self, requested_scope: str | None) -> list[str] | None:
64+
if requested_scope is None:
65+
return None
66+
requested_scopes = requested_scope.split(" ")
67+
allowed_scopes = [] if self.scope is None else self.scope.split(" ")
68+
for scope in requested_scopes:
69+
if scope not in allowed_scopes:
70+
raise InvalidScopeError(f"Client was not registered with scope {scope}")
71+
return requested_scopes
72+
73+
def validate_redirect_uri(self, redirect_uri: AnyHttpUrl | None) -> AnyHttpUrl:
74+
if redirect_uri is not None:
75+
# Validate redirect_uri against client's registered redirect URIs
76+
if redirect_uri not in self.redirect_uris:
77+
raise InvalidRedirectUriError(
78+
f"Redirect URI '{redirect_uri}' not registered for client"
79+
)
80+
return redirect_uri
81+
elif len(self.redirect_uris) == 1:
82+
return self.redirect_uris[0]
83+
else:
84+
raise InvalidRedirectUriError(
85+
"redirect_uri must be specified when client "
86+
"has multiple registered URIs"
87+
)
88+
5389

5490
class OAuthClientInformationFull(OAuthClientMetadata):
5591
"""

tests/server/fastmcp/auth/test_auth_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ async def load_access_token(self, token: str) -> AuthInfo | None:
198198
expires_at=token_info.expires_at,
199199
)
200200

201-
async def revoke_token(self, token: OAuthToken | RefreshToken) -> None:
201+
async def revoke_token(self, token: AuthInfo | RefreshToken) -> None:
202202
match token:
203203
case RefreshToken():
204204
# Remove the refresh token

0 commit comments

Comments
 (0)