diff --git a/coderd/coderd.go b/coderd/coderd.go index c3c1fb09cc6cc..fa10846a7d0a6 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -790,6 +790,7 @@ func New(options *Options) *API { SessionTokenFunc: nil, // Default behavior PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc, Logger: options.Logger, + AccessURL: options.AccessURL, }) // Same as above but it redirects to the login page. apiKeyMiddlewareRedirect := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ @@ -801,6 +802,7 @@ func New(options *Options) *API { SessionTokenFunc: nil, // Default behavior PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc, Logger: options.Logger, + AccessURL: options.AccessURL, }) // Same as the first but it's optional. apiKeyMiddlewareOptional := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ @@ -812,6 +814,7 @@ func New(options *Options) *API { SessionTokenFunc: nil, // Default behavior PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc, Logger: options.Logger, + AccessURL: options.AccessURL, }) workspaceAgentInfo := httpmw.ExtractWorkspaceAgentAndLatestBuild(httpmw.ExtractWorkspaceAgentAndLatestBuildConfig{ diff --git a/coderd/httpmw/apikey.go b/coderd/httpmw/apikey.go index 67d19a925a685..8fb68579a91e5 100644 --- a/coderd/httpmw/apikey.go +++ b/coderd/httpmw/apikey.go @@ -113,6 +113,10 @@ type ExtractAPIKeyConfig struct { // a user is authenticated to prevent additional CLI invocations. PostAuthAdditionalHeadersFunc func(a rbac.Subject, header http.Header) + // AccessURL is the configured access URL for this Coder deployment. + // Used for generating OAuth2 resource metadata URLs in WWW-Authenticate headers. + AccessURL *url.URL + // Logger is used for logging middleware operations. Logger slog.Logger } @@ -214,29 +218,9 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon return nil, nil, false } - // Add WWW-Authenticate header for 401/403 responses (RFC 6750) + // Add WWW-Authenticate header for 401/403 responses (RFC 6750 + RFC 9728) if code == http.StatusUnauthorized || code == http.StatusForbidden { - var wwwAuth string - - switch code { - case http.StatusUnauthorized: - // Map 401 to invalid_token with specific error descriptions - switch { - case strings.Contains(response.Message, "expired") || strings.Contains(response.Detail, "expired"): - wwwAuth = `Bearer realm="coder", error="invalid_token", error_description="The access token has expired"` - case strings.Contains(response.Message, "audience") || strings.Contains(response.Message, "mismatch"): - wwwAuth = `Bearer realm="coder", error="invalid_token", error_description="The access token audience does not match this resource"` - default: - wwwAuth = `Bearer realm="coder", error="invalid_token", error_description="The access token is invalid"` - } - case http.StatusForbidden: - // Map 403 to insufficient_scope per RFC 6750 - wwwAuth = `Bearer realm="coder", error="insufficient_scope", error_description="The request requires higher privileges than provided by the access token"` - default: - wwwAuth = `Bearer realm="coder"` - } - - rw.Header().Set("WWW-Authenticate", wwwAuth) + rw.Header().Set("WWW-Authenticate", buildWWWAuthenticateHeader(cfg.AccessURL, r, code, response)) } httpapi.Write(ctx, rw, code, response) @@ -272,7 +256,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon // Validate OAuth2 provider app token audience (RFC 8707) if applicable if key.LoginType == database.LoginTypeOAuth2ProviderApp { - if err := validateOAuth2ProviderAppTokenAudience(ctx, cfg.DB, *key, r); err != nil { + if err := validateOAuth2ProviderAppTokenAudience(ctx, cfg.DB, *key, cfg.AccessURL, r); err != nil { // Log the detailed error for debugging but don't expose it to the client cfg.Logger.Debug(ctx, "oauth2 token audience validation failed", slog.Error(err)) return optionalWrite(http.StatusForbidden, codersdk.Response{ @@ -489,7 +473,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon // validateOAuth2ProviderAppTokenAudience validates that an OAuth2 provider app token // is being used with the correct audience/resource server (RFC 8707). -func validateOAuth2ProviderAppTokenAudience(ctx context.Context, db database.Store, key database.APIKey, r *http.Request) error { +func validateOAuth2ProviderAppTokenAudience(ctx context.Context, db database.Store, key database.APIKey, accessURL *url.URL, r *http.Request) error { // Get the OAuth2 provider app token to check its audience //nolint:gocritic // System needs to access token for audience validation token, err := db.GetOAuth2ProviderAppTokenByAPIKeyID(dbauthz.AsSystemRestricted(ctx), key.ID) @@ -502,8 +486,8 @@ func validateOAuth2ProviderAppTokenAudience(ctx context.Context, db database.Sto return nil } - // Extract the expected audience from the request - expectedAudience := extractExpectedAudience(r) + // Extract the expected audience from the access URL + expectedAudience := extractExpectedAudience(accessURL, r) // Normalize both audience values for RFC 3986 compliant comparison normalizedTokenAudience := normalizeAudienceURI(token.Audience.String) @@ -624,18 +608,59 @@ func normalizePathSegments(path string) string { // Test export functions for testing package access +// buildWWWAuthenticateHeader constructs RFC 6750 + RFC 9728 compliant WWW-Authenticate header +func buildWWWAuthenticateHeader(accessURL *url.URL, r *http.Request, code int, response codersdk.Response) string { + // Use the configured access URL for resource metadata + if accessURL == nil { + scheme := "https" + if r.TLS == nil { + scheme = "http" + } + + // Use the Host header to construct the canonical audience URI + accessURL = &url.URL{ + Scheme: scheme, + Host: r.Host, + } + } + + resourceMetadata := accessURL.JoinPath("/.well-known/oauth-protected-resource").String() + + switch code { + case http.StatusUnauthorized: + switch { + case strings.Contains(response.Message, "expired") || strings.Contains(response.Detail, "expired"): + return fmt.Sprintf(`Bearer realm="coder", error="invalid_token", error_description="The access token has expired", resource_metadata=%q`, resourceMetadata) + case strings.Contains(response.Message, "audience") || strings.Contains(response.Message, "mismatch"): + return fmt.Sprintf(`Bearer realm="coder", error="invalid_token", error_description="The access token audience does not match this resource", resource_metadata=%q`, resourceMetadata) + default: + return fmt.Sprintf(`Bearer realm="coder", error="invalid_token", error_description="The access token is invalid", resource_metadata=%q`, resourceMetadata) + } + case http.StatusForbidden: + return fmt.Sprintf(`Bearer realm="coder", error="insufficient_scope", error_description="The request requires higher privileges than provided by the access token", resource_metadata=%q`, resourceMetadata) + default: + return fmt.Sprintf(`Bearer realm="coder", resource_metadata=%q`, resourceMetadata) + } +} + // extractExpectedAudience determines the expected audience for the current request. // This should match the resource parameter used during authorization. -func extractExpectedAudience(r *http.Request) string { +func extractExpectedAudience(accessURL *url.URL, r *http.Request) string { // For MCP compliance, the audience should be the canonical URI of the resource server // This typically matches the access URL of the Coder deployment - scheme := "https" - if r.TLS == nil { - scheme = "http" - } + var audience string + + if accessURL != nil { + audience = accessURL.String() + } else { + scheme := "https" + if r.TLS == nil { + scheme = "http" + } - // Use the Host header to construct the canonical audience URI - audience := fmt.Sprintf("%s://%s", scheme, r.Host) + // Use the Host header to construct the canonical audience URI + audience = fmt.Sprintf("%s://%s", scheme, r.Host) + } // Normalize the URI according to RFC 3986 for consistent comparison return normalizeAudienceURI(audience) diff --git a/coderd/httpmw/cors.go b/coderd/httpmw/cors.go index 2350a7dd3b8a6..218aab6609f60 100644 --- a/coderd/httpmw/cors.go +++ b/coderd/httpmw/cors.go @@ -4,6 +4,7 @@ import ( "net/http" "net/url" "regexp" + "strings" "github.com/go-chi/cors" @@ -28,13 +29,15 @@ const ( func Cors(allowAll bool, origins ...string) func(next http.Handler) http.Handler { if len(origins) == 0 { // The default behavior is '*', so putting the empty string defaults to - // the secure behavior of blocking CORs requests. + // the secure behavior of blocking CORS requests. origins = []string{""} } if allowAll { origins = []string{"*"} } - return cors.Handler(cors.Options{ + + // Standard CORS for most endpoints + standardCors := cors.Handler(cors.Options{ AllowedOrigins: origins, // We only need GET for latency requests AllowedMethods: []string{http.MethodOptions, http.MethodGet}, @@ -42,6 +45,50 @@ func Cors(allowAll bool, origins ...string) func(next http.Handler) http.Handler // Do not send any cookies AllowCredentials: false, }) + + // Permissive CORS for OAuth2 and MCP endpoints + permissiveCors := cors.Handler(cors.Options{ + AllowedOrigins: []string{"*"}, + AllowedMethods: []string{ + http.MethodGet, + http.MethodPost, + http.MethodDelete, + http.MethodOptions, + }, + AllowedHeaders: []string{ + "Content-Type", + "Accept", + "Authorization", + "x-api-key", + "Mcp-Session-Id", + "MCP-Protocol-Version", + "Last-Event-ID", + }, + ExposedHeaders: []string{ + "Content-Type", + "Authorization", + "x-api-key", + "Mcp-Session-Id", + "MCP-Protocol-Version", + }, + MaxAge: 86400, // 24 hours in seconds + AllowCredentials: false, + }) + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Use permissive CORS for OAuth2, MCP, and well-known endpoints + if strings.HasPrefix(r.URL.Path, "/oauth2/") || + strings.HasPrefix(r.URL.Path, "/api/experimental/mcp/") || + strings.HasPrefix(r.URL.Path, "/.well-known/oauth-") { + permissiveCors(next).ServeHTTP(w, r) + return + } + + // Use standard CORS for all other endpoints + standardCors(next).ServeHTTP(w, r) + }) + } } func WorkspaceAppCors(regex *regexp.Regexp, app appurl.ApplicationURL) func(next http.Handler) http.Handler { diff --git a/coderd/httpmw/csp_test.go b/coderd/httpmw/csp_test.go index 7bf8b879ef26f..ba88320e6fac9 100644 --- a/coderd/httpmw/csp_test.go +++ b/coderd/httpmw/csp_test.go @@ -34,7 +34,7 @@ func TestCSP(t *testing.T) { expected := []string{ "frame-src 'self' *.test.com *.coder.com *.coder2.com", - "media-src 'self' media.com media2.com", + "media-src 'self' " + strings.Join(expectedMedia, " "), strings.Join([]string{ "connect-src", "'self'", // Added from host header. diff --git a/coderd/httpmw/httpmw_internal_test.go b/coderd/httpmw/httpmw_internal_test.go index ee2d2ab663c52..7519fe770d922 100644 --- a/coderd/httpmw/httpmw_internal_test.go +++ b/coderd/httpmw/httpmw_internal_test.go @@ -258,7 +258,7 @@ func TestExtractExpectedAudience(t *testing.T) { } req.Host = tc.host - result := extractExpectedAudience(req) + result := extractExpectedAudience(nil, req) assert.Equal(t, tc.expected, result) }) } diff --git a/coderd/oauth2provider/authorize.go b/coderd/oauth2provider/authorize.go index 77be5fc397a8a..29d0c99abc707 100644 --- a/coderd/oauth2provider/authorize.go +++ b/coderd/oauth2provider/authorize.go @@ -33,7 +33,7 @@ func extractAuthorizeParams(r *http.Request, callbackURL *url.URL) (authorizePar p := httpapi.NewQueryParamParser() vals := r.URL.Query() - p.RequiredNotEmpty("state", "response_type", "client_id") + p.RequiredNotEmpty("response_type", "client_id") params := authorizeParams{ clientID: p.String(vals, "", "client_id"), @@ -154,7 +154,9 @@ func ProcessAuthorize(db database.Store, accessURL *url.URL) http.HandlerFunc { newQuery := params.redirectURL.Query() newQuery.Add("code", code.Formatted) - newQuery.Add("state", params.state) + if params.state != "" { + newQuery.Add("state", params.state) + } params.redirectURL.RawQuery = newQuery.Encode() http.Redirect(rw, r, params.redirectURL.String(), http.StatusTemporaryRedirect)