From 19a92418e46aa14d89579ddc199a9ef762518d9e Mon Sep 17 00:00:00 2001 From: Marc Nuri Date: Mon, 11 Aug 2025 14:39:27 +0300 Subject: [PATCH] feat(auth): support for VSCode auth flow (#258) Adds DisableDynamicClientRegistration and OAuthScopes to be able to override the values proxied from the configured authorization server. DisableDynamicClientRegistration removes the registration_endpoint field from the well-known authorization resource metadata. This forces VSCode to show a for to input the Client ID and Client Secret since these can't be discovered. The OAuthScopes allows to override the scopes_supported field. VSCode automatically makes an auth request for all of the supported scopes. In many cases, this is not supported by the auth server. By providing this configuration, the user (MCP Server administrator) is able to set which scopes are effectively supported and force VSCode to only request these. Signed-off-by: Marc Nuri --- pkg/config/config.go | 5 ++ pkg/http/authorization.go | 1 + pkg/http/http_test.go | 108 +++++++++++++++++++++++++++++++++++++- pkg/http/wellknown.go | 29 ++++++++-- 4 files changed, 137 insertions(+), 6 deletions(-) diff --git a/pkg/config/config.go b/pkg/config/config.go index 6e797971..26e007d1 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -33,6 +33,11 @@ type StaticConfig struct { // AuthorizationURL is the URL of the OIDC authorization server. // It is used for token validation and for STS token exchange. AuthorizationURL string `toml:"authorization_url,omitempty"` + // DisableDynamicClientRegistration indicates whether dynamic client registration is disabled. + // If true, the .well-known endpoints will not expose the registration endpoint. + DisableDynamicClientRegistration bool `toml:"disable_dynamic_client_registration,omitempty"` + // OAuthScopes are the supported **client** scopes requested during the **client/frontend** OAuth flow. + OAuthScopes []string `toml:"oauth_scopes,omitempty"` // StsClientId is the OAuth client ID used for backend token exchange StsClientId string `toml:"sts_client_id,omitempty"` // StsClientSecret is the OAuth client secret used for backend token exchange diff --git a/pkg/http/authorization.go b/pkg/http/authorization.go index 6fa81f08..39259d51 100644 --- a/pkg/http/authorization.go +++ b/pkg/http/authorization.go @@ -111,6 +111,7 @@ func AuthorizationMiddleware(staticConfig *config.StaticConfig, oidcProvider *oi } // Token exchange with OIDC provider sts := NewFromConfig(staticConfig, oidcProvider) + // TODO: Maybe the token had already been exchanged, if it has the right audience and scopes, we can skip this step. if err == nil && sts.IsEnabled() { var exchangedToken *oauth2.Token // If the token is valid, we can exchange it for a new token with the specified audience and scopes. diff --git a/pkg/http/http_test.go b/pkg/http/http_test.go index 04cccf47..0ceaf2e8 100644 --- a/pkg/http/http_test.go +++ b/pkg/http/http_test.go @@ -8,6 +8,7 @@ import ( "crypto/rsa" "flag" "fmt" + "io" "net" "net/http" "net/http/httptest" @@ -334,7 +335,28 @@ func TestWellKnownReverseProxy(t *testing.T) { }) } }) - // With Authorization URL configured + // With Authorization URL configured but invalid payload + invalidPayloadServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`NOT A JSON PAYLOAD`)) + })) + t.Cleanup(invalidPayloadServer.Close) + invalidPayloadConfig := &config.StaticConfig{AuthorizationURL: invalidPayloadServer.URL, RequireOAuth: true, ValidateToken: true} + testCaseWithContext(t, &httpContext{StaticConfig: invalidPayloadConfig}, func(ctx *httpContext) { + for _, path := range cases { + resp, err := http.Get(fmt.Sprintf("http://%s/%s", ctx.HttpAddress, path)) + t.Cleanup(func() { _ = resp.Body.Close() }) + t.Run("Protected resource '"+path+"' with invalid Authorization URL payload returns 500 - Internal Server Error", func(t *testing.T) { + if err != nil { + t.Fatalf("Failed to get %s endpoint: %v", path, err) + } + if resp.StatusCode != http.StatusInternalServerError { + t.Errorf("Expected HTTP 500 Internal Server Error, got %d", resp.StatusCode) + } + }) + } + }) + // With Authorization URL configured and valid payload testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !strings.HasPrefix(r.URL.EscapedPath(), "/.well-known/") { http.NotFound(w, r) @@ -344,7 +366,8 @@ func TestWellKnownReverseProxy(t *testing.T) { _, _ = w.Write([]byte(`{"issuer": "https://example.com","scopes_supported":["mcp-server"]}`)) })) t.Cleanup(testServer.Close) - testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{AuthorizationURL: testServer.URL, RequireOAuth: true, ValidateToken: true}}, func(ctx *httpContext) { + staticConfig := &config.StaticConfig{AuthorizationURL: testServer.URL, RequireOAuth: true, ValidateToken: true} + testCaseWithContext(t, &httpContext{StaticConfig: staticConfig}, func(ctx *httpContext) { for _, path := range cases { resp, err := http.Get(fmt.Sprintf("http://%s/%s", ctx.HttpAddress, path)) t.Cleanup(func() { _ = resp.Body.Close() }) @@ -365,6 +388,87 @@ func TestWellKnownReverseProxy(t *testing.T) { }) } +func TestWellKnownOverrides(t *testing.T) { + cases := []string{ + ".well-known/oauth-authorization-server", + ".well-known/oauth-protected-resource", + ".well-known/openid-configuration", + } + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.EscapedPath(), "/.well-known/") { + http.NotFound(w, r) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(` + { + "issuer": "https://localhost", + "registration_endpoint": "https://localhost/clients-registrations/openid-connect", + "require_request_uri_registration": true, + "scopes_supported":["scope-1", "scope-2"] + }`)) + })) + t.Cleanup(testServer.Close) + baseConfig := config.StaticConfig{AuthorizationURL: testServer.URL, RequireOAuth: true, ValidateToken: true} + // With Dynamic Client Registration disabled + disableDynamicRegistrationConfig := baseConfig + disableDynamicRegistrationConfig.DisableDynamicClientRegistration = true + testCaseWithContext(t, &httpContext{StaticConfig: &disableDynamicRegistrationConfig}, func(ctx *httpContext) { + for _, path := range cases { + resp, _ := http.Get(fmt.Sprintf("http://%s/%s", ctx.HttpAddress, path)) + t.Cleanup(func() { _ = resp.Body.Close() }) + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + t.Run("DisableDynamicClientRegistration removes registration_endpoint field", func(t *testing.T) { + if strings.Contains(string(body), "registration_endpoint") { + t.Error("Expected registration_endpoint to be removed, but it was found in the response") + } + }) + t.Run("DisableDynamicClientRegistration sets require_request_uri_registration = false", func(t *testing.T) { + if !strings.Contains(string(body), `"require_request_uri_registration":false`) { + t.Error("Expected require_request_uri_registration to be false, but it was not found in the response") + } + }) + t.Run("DisableDynamicClientRegistration includes/preserves scopes_supported", func(t *testing.T) { + if !strings.Contains(string(body), `"scopes_supported":["scope-1","scope-2"]`) { + t.Error("Expected scopes_supported to be present, but it was not found in the response") + } + }) + } + }) + // With overrides for OAuth scopes (client/frontend) + oAuthScopesConfig := baseConfig + oAuthScopesConfig.OAuthScopes = []string{"openid", "mcp-server"} + testCaseWithContext(t, &httpContext{StaticConfig: &oAuthScopesConfig}, func(ctx *httpContext) { + for _, path := range cases { + resp, _ := http.Get(fmt.Sprintf("http://%s/%s", ctx.HttpAddress, path)) + t.Cleanup(func() { _ = resp.Body.Close() }) + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + t.Run("OAuthScopes overrides scopes_supported", func(t *testing.T) { + if !strings.Contains(string(body), `"scopes_supported":["openid","mcp-server"]`) { + t.Errorf("Expected scopes_supported to be overridden, but original was preserved, response: %s", string(body)) + } + }) + t.Run("OAuthScopes preserves other fields", func(t *testing.T) { + if !strings.Contains(string(body), `"issuer":"https://localhost"`) { + t.Errorf("Expected issuer to be preserved, but got: %s", string(body)) + } + if !strings.Contains(string(body), `"registration_endpoint":"https://localhost`) { + t.Errorf("Expected registration_endpoint to be preserved, but got: %s", string(body)) + } + if !strings.Contains(string(body), `"require_request_uri_registration":true`) { + t.Error("Expected require_request_uri_registration to be true, but it was not found in the response") + } + }) + } + }) +} + func TestMiddlewareLogging(t *testing.T) { testCase(t, func(ctx *httpContext) { _, _ = http.Get(fmt.Sprintf("http://%s/.well-known/oauth-protected-resource", ctx.HttpAddress)) diff --git a/pkg/http/wellknown.go b/pkg/http/wellknown.go index c1e375ea..0d80221e 100644 --- a/pkg/http/wellknown.go +++ b/pkg/http/wellknown.go @@ -1,7 +1,8 @@ package http import ( - "io" + "encoding/json" + "fmt" "net/http" "strings" @@ -21,7 +22,9 @@ var WellKnownEndpoints = []string{ } type WellKnown struct { - authorizationUrl string + authorizationUrl string + scopesSupported []string + disableDynamicClientRegistration bool } var _ http.Handler = &WellKnown{} @@ -31,7 +34,11 @@ func WellKnownHandler(staticConfig *config.StaticConfig) http.Handler { if authorizationUrl != "" && strings.HasSuffix("authorizationUrl", "/") { authorizationUrl = strings.TrimSuffix(authorizationUrl, "/") } - return &WellKnown{authorizationUrl} + return &WellKnown{ + authorizationUrl: authorizationUrl, + disableDynamicClientRegistration: staticConfig.DisableDynamicClientRegistration, + scopesSupported: staticConfig.OAuthScopes, + } } func (w WellKnown) ServeHTTP(writer http.ResponseWriter, request *http.Request) { @@ -50,16 +57,30 @@ func (w WellKnown) ServeHTTP(writer http.ResponseWriter, request *http.Request) return } defer func() { _ = resp.Body.Close() }() - body, err := io.ReadAll(resp.Body) + var resourceMetadata map[string]interface{} + err = json.NewDecoder(resp.Body).Decode(&resourceMetadata) if err != nil { http.Error(writer, "Failed to read response body: "+err.Error(), http.StatusInternalServerError) return } + if w.disableDynamicClientRegistration { + delete(resourceMetadata, "registration_endpoint") + resourceMetadata["require_request_uri_registration"] = false + } + if len(w.scopesSupported) > 0 { + resourceMetadata["scopes_supported"] = w.scopesSupported + } + body, err := json.Marshal(resourceMetadata) + if err != nil { + http.Error(writer, "Failed to marshal response body: "+err.Error(), http.StatusInternalServerError) + return + } for key, values := range resp.Header { for _, value := range values { writer.Header().Add(key, value) } } + writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(body))) writer.WriteHeader(resp.StatusCode) _, _ = writer.Write(body) }