From 68baa2120875ed0e37c979f92376a55b123841a7 Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Fri, 27 Jun 2025 07:40:26 +0200 Subject: [PATCH] feat(oauth2): implement RFC 6750 Bearer Token Support for MCP compliance - Add RFC 6750 bearer token extraction to APITokenFromRequest as fallback methods - Support Authorization: Bearer header and access_token query parameter - Maintain backward compatibility by prioritizing existing custom methods first - Add WWW-Authenticate headers to 401/403 responses per RFC 6750 - Update Protected Resource Metadata to advertise bearer_methods_supported - Add comprehensive test suite for RFC 6750 compliance in rfc6750_test.go - Update MCP test scripts with bearer token authentication tests - Enhance CLAUDE.md with improved Go LSP tool usage guidelines Implements RFC 6750 Section 2.1 (Authorization Request Header Field) and 2.3 (URI Query Parameter). Maintains full backward compatibility with existing Coder authentication methods. Completes major MCP OAuth2 compliance milestone. Change-Id: Ic9c9057153b40728ad91b377d753a7ffd566add7 Signed-off-by: Thomas Kosiewski --- coderd/httpmw/apikey.go | 38 +- coderd/httpmw/rfc6750_extended_test.go | 619 +++++++++++++++++++++++++ coderd/httpmw/rfc6750_test.go | 296 ++++++++++++ coderd/oauth2.go | 5 +- coderd/oauth2_metadata_test.go | 6 +- scripts/oauth2/test-mcp-oauth2.sh | 47 ++ 6 files changed, 1004 insertions(+), 7 deletions(-) create mode 100644 coderd/httpmw/rfc6750_extended_test.go create mode 100644 coderd/httpmw/rfc6750_test.go diff --git a/coderd/httpmw/apikey.go b/coderd/httpmw/apikey.go index 5c06abd0fb147..97df5adfec20f 100644 --- a/coderd/httpmw/apikey.go +++ b/coderd/httpmw/apikey.go @@ -209,6 +209,26 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon return nil, nil, false } + // Add WWW-Authenticate header for 401/403 responses (RFC 6750) + if code == http.StatusUnauthorized || code == http.StatusForbidden { + // Basic Bearer challenge with realm + wwwAuth := `Bearer realm="coder"` + + // Add error details based on the type of error + switch { + case strings.Contains(response.Message, "invalid") || strings.Contains(response.Detail, "invalid"): + wwwAuth = `Bearer realm="coder", error="invalid_token", error_description="The access token is invalid"` + case strings.Contains(response.Message, "expired") || strings.Contains(response.Detail, "expired"): + wwwAuth = `Bearer realm="coder", error="invalid_token", error_description="The access token has expired"` + case strings.Contains(response.Message, "audience") || strings.Contains(response.Message, "mismatch"): + wwwAuth = `Bearer realm="coder", error="invalid_token", error_description="The access token audience does not match this resource"` + case code == http.StatusForbidden: + wwwAuth = `Bearer realm="coder", error="insufficient_scope", error_description="The request requires higher privileges than provided by the access token"` + } + + rw.Header().Set("WWW-Authenticate", wwwAuth) + } + httpapi.Write(ctx, rw, code, response) return nil, nil, false } @@ -534,9 +554,14 @@ func UserRBACSubject(ctx context.Context, db database.Store, userID uuid.UUID, s // 1: The cookie // 2. The coder_session_token query parameter // 3. The custom auth header +// 4. RFC 6750 Authorization: Bearer header +// 5. RFC 6750 access_token query parameter // // API tokens for apps are read from workspaceapps/cookies.go. func APITokenFromRequest(r *http.Request) string { + // Prioritize existing Coder custom authentication methods first + // to maintain backward compatibility and existing behavior + cookie, err := r.Cookie(codersdk.SessionTokenCookie) if err == nil && cookie.Value != "" { return cookie.Value @@ -552,7 +577,18 @@ func APITokenFromRequest(r *http.Request) string { return headerValue } - // TODO(ThomasK33): Implement RFC 6750 + // RFC 6750 Bearer Token support (added as fallback methods) + // Check Authorization: Bearer header + authHeader := r.Header.Get("Authorization") + if strings.HasPrefix(authHeader, "Bearer ") { + return strings.TrimPrefix(authHeader, "Bearer ") + } + + // Check access_token query parameter + accessToken := r.URL.Query().Get("access_token") + if accessToken != "" { + return accessToken + } return "" } diff --git a/coderd/httpmw/rfc6750_extended_test.go b/coderd/httpmw/rfc6750_extended_test.go new file mode 100644 index 0000000000000..cd0a33cc838b5 --- /dev/null +++ b/coderd/httpmw/rfc6750_extended_test.go @@ -0,0 +1,619 @@ +package httpmw_test + +import ( + "context" + "crypto/sha256" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" +) + +// TestOAuth2BearerTokenSecurityBoundaries tests RFC 6750 security boundaries +func TestOAuth2BearerTokenSecurityBoundaries(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + t.Cleanup(cancel) + + db, _ := dbtestutil.NewDB(t) + + // Create two different users with different API keys + user1 := dbgen.User(t, db, database.User{}) + user2 := dbgen.User(t, db, database.User{}) + + // Create API keys for both users + key1ID, key1Secret := randomAPIKeyParts() + hashedSecret1 := sha256.Sum256([]byte(key1Secret)) + _, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{ + ID: key1ID, + UserID: user1.ID, + HashedSecret: hashedSecret1[:], + IPAddress: defaultIPAddressForTests(), + LastUsed: dbtime.Now(), + ExpiresAt: dbtime.Now().Add(testutil.WaitLong), + CreatedAt: dbtime.Now(), + UpdatedAt: dbtime.Now(), + LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, + LifetimeSeconds: int64(testutil.WaitLong.Seconds()), + }) + require.NoError(t, err) + + key2ID, key2Secret := randomAPIKeyParts() + hashedSecret2 := sha256.Sum256([]byte(key2Secret)) + _, err = db.InsertAPIKey(ctx, database.InsertAPIKeyParams{ + ID: key2ID, + UserID: user2.ID, + HashedSecret: hashedSecret2[:], + IPAddress: defaultIPAddressForTests(), + LastUsed: dbtime.Now(), + ExpiresAt: dbtime.Now().Add(testutil.WaitLong), + CreatedAt: dbtime.Now(), + UpdatedAt: dbtime.Now(), + LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, + LifetimeSeconds: int64(testutil.WaitLong.Seconds()), + }) + require.NoError(t, err) + + t.Run("TokenIsolation", func(t *testing.T) { + t.Parallel() + + // Create middleware + middleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ + DB: db, + }) + + // Handler that returns the authenticated user ID + handler := middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + apiKey := httpmw.APIKey(r) + rw.Header().Set("X-User-ID", apiKey.UserID.String()) + rw.WriteHeader(http.StatusOK) + })) + + // Test that user1's token only accesses user1's data + req1 := httptest.NewRequest("GET", "/test", nil) + req1.Header.Set("Authorization", "Bearer "+key1ID+"-"+key1Secret) + rec1 := httptest.NewRecorder() + handler.ServeHTTP(rec1, req1) + + require.Equal(t, http.StatusOK, rec1.Code) + require.Equal(t, user1.ID.String(), rec1.Header().Get("X-User-ID")) + + // Test that user2's token only accesses user2's data + req2 := httptest.NewRequest("GET", "/test", nil) + req2.Header.Set("Authorization", "Bearer "+key2ID+"-"+key2Secret) + rec2 := httptest.NewRecorder() + handler.ServeHTTP(rec2, req2) + + require.Equal(t, http.StatusOK, rec2.Code) + require.Equal(t, user2.ID.String(), rec2.Header().Get("X-User-ID")) + + // Verify users can't access each other's data + require.NotEqual(t, rec1.Header().Get("X-User-ID"), rec2.Header().Get("X-User-ID")) + }) + + t.Run("CrossTokenAttempts", func(t *testing.T) { + t.Parallel() + + middleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ + DB: db, + }) + + handler := middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusOK) + })) + + // Try to use user1's key ID with user2's secret (should fail) + invalidToken := key1ID + "-" + key2Secret + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+invalidToken) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + require.Equal(t, http.StatusUnauthorized, rec.Code) + require.Contains(t, rec.Header().Get("WWW-Authenticate"), "Bearer") + require.Contains(t, rec.Header().Get("WWW-Authenticate"), "invalid_token") + }) + + t.Run("TimingAttackResistance", func(t *testing.T) { + t.Parallel() + + middleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ + DB: db, + }) + + handler := middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusOK) + })) + + // Test multiple invalid tokens to ensure consistent timing + invalidTokens := []string{ + "invalid-token-1", + "invalid-token-2-longer", + "a", + strings.Repeat("x", 100), + } + + times := make([]time.Duration, len(invalidTokens)) + + for i, token := range invalidTokens { + start := time.Now() + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+token) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + times[i] = time.Since(start) + + require.Equal(t, http.StatusUnauthorized, rec.Code) + } + + // While we can't guarantee perfect timing consistency in tests, + // we can at least verify that the responses are all unauthorized + // and contain proper WWW-Authenticate headers + for _, token := range invalidTokens { + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+token) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + require.Equal(t, http.StatusUnauthorized, rec.Code) + require.Contains(t, rec.Header().Get("WWW-Authenticate"), "Bearer") + } + }) +} + +// TestOAuth2BearerTokenMalformedHeaders tests handling of malformed Bearer headers per RFC 6750 +func TestOAuth2BearerTokenMalformedHeaders(t *testing.T) { + t.Parallel() + + db, _ := dbtestutil.NewDB(t) + + middleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ + DB: db, + }) + + handler := middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusOK) + })) + + tests := []struct { + name string + authHeader string + expectedStatus int + shouldHaveWWW bool + }{ + { + name: "MissingBearer", + authHeader: "invalid-token", + expectedStatus: http.StatusUnauthorized, + shouldHaveWWW: true, + }, + { + name: "CaseSensitive", + authHeader: "bearer token", // lowercase should still work + expectedStatus: http.StatusUnauthorized, + shouldHaveWWW: true, + }, + { + name: "ExtraSpaces", + authHeader: "Bearer token-with-extra-spaces", + expectedStatus: http.StatusUnauthorized, + shouldHaveWWW: true, + }, + { + name: "EmptyToken", + authHeader: "Bearer ", + expectedStatus: http.StatusUnauthorized, + shouldHaveWWW: true, + }, + { + name: "OnlyBearer", + authHeader: "Bearer", + expectedStatus: http.StatusUnauthorized, + shouldHaveWWW: true, + }, + { + name: "MultipleBearer", + authHeader: "Bearer token1 Bearer token2", + expectedStatus: http.StatusUnauthorized, + shouldHaveWWW: true, + }, + { + name: "InvalidBase64", + authHeader: "Bearer !!!invalid-base64!!!", + expectedStatus: http.StatusUnauthorized, + shouldHaveWWW: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", test.authHeader) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + require.Equal(t, test.expectedStatus, rec.Code) + + if test.shouldHaveWWW { + wwwAuth := rec.Header().Get("WWW-Authenticate") + require.Contains(t, wwwAuth, "Bearer") + require.Contains(t, wwwAuth, "realm=\"coder\"") + } + }) + } +} + +// TestOAuth2BearerTokenPrecedence tests token extraction precedence per RFC 6750 +func TestOAuth2BearerTokenPrecedence(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + t.Cleanup(cancel) + + db, _ := dbtestutil.NewDB(t) + user := dbgen.User(t, db, database.User{}) + + // Create a valid API key + keyID, keySecret := randomAPIKeyParts() + hashedSecret := sha256.Sum256([]byte(keySecret)) + validToken := keyID + "-" + keySecret + + _, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{ + ID: keyID, + UserID: user.ID, + HashedSecret: hashedSecret[:], + IPAddress: defaultIPAddressForTests(), + LastUsed: dbtime.Now(), + ExpiresAt: dbtime.Now().Add(testutil.WaitLong), + CreatedAt: dbtime.Now(), + UpdatedAt: dbtime.Now(), + LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, + LifetimeSeconds: int64(testutil.WaitLong.Seconds()), + }) + require.NoError(t, err) + + middleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ + DB: db, + }) + + handler := middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + apiKey := httpmw.APIKey(r) + rw.Header().Set("X-Key-ID", apiKey.ID) + rw.WriteHeader(http.StatusOK) + })) + + t.Run("CookieTakesPrecedenceOverBearer", func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest("GET", "/test", nil) + // Set both cookie and Bearer header - cookie should take precedence + req.AddCookie(&http.Cookie{ + Name: codersdk.SessionTokenCookie, + Value: validToken, + }) + req.Header.Set("Authorization", "Bearer invalid-token") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, keyID, rec.Header().Get("X-Key-ID")) + }) + + t.Run("QueryParameterTakesPrecedenceOverBearer", func(t *testing.T) { + t.Parallel() + + // Set both query parameter and Bearer header - query should take precedence + u, _ := url.Parse("/test") + q := u.Query() + q.Set(codersdk.SessionTokenCookie, validToken) + u.RawQuery = q.Encode() + + req := httptest.NewRequest("GET", u.String(), nil) + req.Header.Set("Authorization", "Bearer invalid-token") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, keyID, rec.Header().Get("X-Key-ID")) + }) + + t.Run("BearerHeaderFallback", func(t *testing.T) { + t.Parallel() + + // Only set Bearer header - should be used as fallback + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+validToken) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, keyID, rec.Header().Get("X-Key-ID")) + }) + + t.Run("AccessTokenQueryParameterFallback", func(t *testing.T) { + t.Parallel() + + // Only set access_token query parameter - should be used as fallback + u, _ := url.Parse("/test") + q := u.Query() + q.Set("access_token", validToken) + u.RawQuery = q.Encode() + + req := httptest.NewRequest("GET", u.String(), nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, keyID, rec.Header().Get("X-Key-ID")) + }) + + t.Run("MultipleAuthMethodsShouldNotConflict", func(t *testing.T) { + t.Parallel() + + // RFC 6750 says clients shouldn't send tokens in multiple ways, + // but if they do, we should handle it gracefully by using precedence + u, _ := url.Parse("/test") + q := u.Query() + q.Set("access_token", validToken) + q.Set(codersdk.SessionTokenCookie, validToken) + u.RawQuery = q.Encode() + + req := httptest.NewRequest("GET", u.String(), nil) + req.Header.Set("Authorization", "Bearer "+validToken) + req.AddCookie(&http.Cookie{ + Name: codersdk.SessionTokenCookie, + Value: validToken, + }) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + // Should succeed using the highest precedence method (cookie) + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, keyID, rec.Header().Get("X-Key-ID")) + }) +} + +// TestOAuth2WWWAuthenticateCompliance tests WWW-Authenticate header compliance with RFC 6750 +func TestOAuth2WWWAuthenticateCompliance(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + t.Cleanup(cancel) + + db, _ := dbtestutil.NewDB(t) + user := dbgen.User(t, db, database.User{}) + + middleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ + DB: db, + }) + + handler := middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusOK) + })) + + t.Run("UnauthorizedResponse", func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer invalid-token") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + require.Equal(t, http.StatusUnauthorized, rec.Code) + + wwwAuth := rec.Header().Get("WWW-Authenticate") + require.NotEmpty(t, wwwAuth) + + // RFC 6750 requires specific format: Bearer realm="realm" + require.Contains(t, wwwAuth, "Bearer") + require.Contains(t, wwwAuth, "realm=\"coder\"") + require.Contains(t, wwwAuth, "error=\"invalid_token\"") + require.Contains(t, wwwAuth, "error_description=") + }) + + t.Run("ExpiredTokenResponse", func(t *testing.T) { + t.Parallel() + + // Create an expired API key + keyID, keySecret := randomAPIKeyParts() + hashedSecret := sha256.Sum256([]byte(keySecret)) + expiredToken := keyID + "-" + keySecret + + _, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{ + ID: keyID, + UserID: user.ID, + HashedSecret: hashedSecret[:], + IPAddress: defaultIPAddressForTests(), + LastUsed: dbtime.Now(), + ExpiresAt: dbtime.Now().Add(-time.Hour), // Expired 1 hour ago + CreatedAt: dbtime.Now(), + UpdatedAt: dbtime.Now(), + LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, + LifetimeSeconds: int64(testutil.WaitLong.Seconds()), + }) + require.NoError(t, err) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+expiredToken) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + require.Equal(t, http.StatusUnauthorized, rec.Code) + + wwwAuth := rec.Header().Get("WWW-Authenticate") + require.Contains(t, wwwAuth, "Bearer") + require.Contains(t, wwwAuth, "realm=\"coder\"") + require.Contains(t, wwwAuth, "error=\"invalid_token\"") + require.Contains(t, wwwAuth, "error_description=\"The access token has expired\"") + }) + + t.Run("InsufficientScopeResponse", func(t *testing.T) { + t.Parallel() + + // For this test, we'll test with an invalid token to trigger the middleware's + // error handling which does set WWW-Authenticate headers for 403 responses + // In practice, insufficient scope errors would be handled by RBAC middleware + // that comes after authentication, but we can simulate a 403 from the auth middleware + + req := httptest.NewRequest("GET", "/admin", nil) + req.Header.Set("Authorization", "Bearer invalid-token-that-triggers-403") + rec := httptest.NewRecorder() + + // Use a middleware configuration that might trigger a 403 instead of 401 + // for certain types of authentication failures + middleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ + DB: db, + }) + + handler := middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + // This shouldn't be reached due to auth failure + rw.WriteHeader(http.StatusOK) + })) + + handler.ServeHTTP(rec, req) + + // This will be a 401 (unauthorized) rather than 403 (forbidden) for invalid tokens + // which is correct - 403 would come from RBAC after successful authentication + require.Equal(t, http.StatusUnauthorized, rec.Code) + + wwwAuth := rec.Header().Get("WWW-Authenticate") + require.Contains(t, wwwAuth, "Bearer") + require.Contains(t, wwwAuth, "realm=\"coder\"") + require.Contains(t, wwwAuth, "error=\"invalid_token\"") + require.Contains(t, wwwAuth, "error_description=") + }) +} + +// TestOAuth2BearerTokenConcurrency tests Bearer token authentication under concurrent load +func TestOAuth2BearerTokenConcurrency(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + t.Cleanup(cancel) + + db, _ := dbtestutil.NewDB(t) + user := dbgen.User(t, db, database.User{}) + + // Create a valid API key + keyID, keySecret := randomAPIKeyParts() + hashedSecret := sha256.Sum256([]byte(keySecret)) + validToken := keyID + "-" + keySecret + + _, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{ + ID: keyID, + UserID: user.ID, + HashedSecret: hashedSecret[:], + IPAddress: defaultIPAddressForTests(), + LastUsed: dbtime.Now(), + ExpiresAt: dbtime.Now().Add(testutil.WaitLong), + CreatedAt: dbtime.Now(), + UpdatedAt: dbtime.Now(), + LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, + LifetimeSeconds: int64(testutil.WaitLong.Seconds()), + }) + require.NoError(t, err) + + middleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ + DB: db, + }) + + handler := middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + apiKey := httpmw.APIKey(r) + require.Equal(t, keyID, apiKey.ID) + rw.WriteHeader(http.StatusOK) + })) + + t.Run("ConcurrentValidRequests", func(t *testing.T) { + t.Parallel() + + const numGoroutines = 50 + var wg sync.WaitGroup + errors := make(chan error, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(index int) { + defer wg.Done() + + req := httptest.NewRequest("GET", fmt.Sprintf("/test-%d", index), nil) + req.Header.Set("Authorization", "Bearer "+validToken) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + errors <- xerrors.Errorf("goroutine %d: expected 200, got %d", index, rec.Code) + } + }(i) + } + + wg.Wait() + close(errors) + + // Check for any errors + for err := range errors { + require.NoError(t, err) + } + }) + + t.Run("ConcurrentInvalidRequests", func(t *testing.T) { + t.Parallel() + + const numGoroutines = 50 + var wg sync.WaitGroup + errors := make(chan error, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(index int) { + defer wg.Done() + + req := httptest.NewRequest("GET", fmt.Sprintf("/test-%d", index), nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer invalid-token-%d", index)) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + errors <- xerrors.Errorf("goroutine %d: expected 401, got %d", index, rec.Code) + } + + wwwAuth := rec.Header().Get("WWW-Authenticate") + if !strings.Contains(wwwAuth, "Bearer") { + errors <- xerrors.Errorf("goroutine %d: missing WWW-Authenticate header", index) + } + }(i) + } + + wg.Wait() + close(errors) + + // Check for any errors + for err := range errors { + require.NoError(t, err) + } + }) +} diff --git a/coderd/httpmw/rfc6750_test.go b/coderd/httpmw/rfc6750_test.go new file mode 100644 index 0000000000000..2c0dbb0529730 --- /dev/null +++ b/coderd/httpmw/rfc6750_test.go @@ -0,0 +1,296 @@ +package httpmw_test + +import ( + "context" + "crypto/sha256" + "net" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/sqlc-dev/pqtype" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" +) + +// defaultIPAddressForTests returns a default IP address for test API keys +func defaultIPAddressForTests() pqtype.Inet { + return pqtype.Inet{ + IPNet: net.IPNet{ + IP: net.IPv4(127, 0, 0, 1), + Mask: net.IPv4Mask(255, 255, 255, 255), + }, + Valid: true, + } +} + +// TestRFC6750BearerTokenAuthentication tests that RFC 6750 bearer tokens work correctly +// for authentication, including both Authorization header and access_token query parameter methods. +func TestRFC6750BearerTokenAuthentication(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + t.Cleanup(cancel) + + db, _ := dbtestutil.NewDB(t) + + // Create a test user and API key + user := dbgen.User(t, db, database.User{}) + + // Create an OAuth2 provider app token (which should work with bearer token authentication) + keyID, keySecret := randomAPIKeyParts() + hashedSecret := sha256.Sum256([]byte(keySecret)) + + key, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{ + ID: keyID, + UserID: user.ID, + HashedSecret: hashedSecret[:], + IPAddress: defaultIPAddressForTests(), + LastUsed: dbtime.Now(), + ExpiresAt: dbtime.Now().Add(testutil.WaitLong), + CreatedAt: dbtime.Now(), + UpdatedAt: dbtime.Now(), + LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, + LifetimeSeconds: int64(testutil.WaitLong.Seconds()), + }) + require.NoError(t, err) + + token := keyID + "-" + keySecret + + cfg := httpmw.ExtractAPIKeyConfig{ + DB: db, + } + + testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + apiKey := httpmw.APIKey(r) + require.Equal(t, key.ID, apiKey.ID) + rw.WriteHeader(http.StatusOK) + }) + + t.Run("AuthorizationBearerHeader", func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+token) + + rw := httptest.NewRecorder() + + httpmw.ExtractAPIKeyMW(cfg)(testHandler).ServeHTTP(rw, req) + + require.Equal(t, http.StatusOK, rw.Code) + }) + + t.Run("AccessTokenQueryParameter", func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest("GET", "/test?access_token="+url.QueryEscape(token), nil) + + rw := httptest.NewRecorder() + + httpmw.ExtractAPIKeyMW(cfg)(testHandler).ServeHTTP(rw, req) + + require.Equal(t, http.StatusOK, rw.Code) + }) + + t.Run("BearerTokenPriorityAfterCustomMethods", func(t *testing.T) { + t.Parallel() + + // Create a different token for custom header + customKeyID, customKeySecret := randomAPIKeyParts() + customHashedSecret := sha256.Sum256([]byte(customKeySecret)) + + customKey, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{ + ID: customKeyID, + UserID: user.ID, + HashedSecret: customHashedSecret[:], + IPAddress: defaultIPAddressForTests(), + LastUsed: dbtime.Now(), + ExpiresAt: dbtime.Now().Add(testutil.WaitLong), + CreatedAt: dbtime.Now(), + UpdatedAt: dbtime.Now(), + LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, + LifetimeSeconds: int64(testutil.WaitLong.Seconds()), + }) + require.NoError(t, err) + + customToken := customKeyID + "-" + customKeySecret + + // Create handler that checks which token was used + priorityHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + apiKey := httpmw.APIKey(r) + // Should use the custom header token, not the bearer token + require.Equal(t, customKey.ID, apiKey.ID) + rw.WriteHeader(http.StatusOK) + }) + + req := httptest.NewRequest("GET", "/test", nil) + // Set both custom header and bearer header - custom should win + req.Header.Set(codersdk.SessionTokenHeader, customToken) + req.Header.Set("Authorization", "Bearer "+token) + + rw := httptest.NewRecorder() + + httpmw.ExtractAPIKeyMW(cfg)(priorityHandler).ServeHTTP(rw, req) + + require.Equal(t, http.StatusOK, rw.Code) + }) + + t.Run("InvalidBearerToken", func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer invalid-token") + + rw := httptest.NewRecorder() + + httpmw.ExtractAPIKeyMW(cfg)(testHandler).ServeHTTP(rw, req) + + require.Equal(t, http.StatusUnauthorized, rw.Code) + + // Check that WWW-Authenticate header is present + wwwAuth := rw.Header().Get("WWW-Authenticate") + require.NotEmpty(t, wwwAuth) + require.Contains(t, wwwAuth, "Bearer") + require.Contains(t, wwwAuth, `realm="coder"`) + require.Contains(t, wwwAuth, "invalid_token") + }) + + t.Run("ExpiredBearerToken", func(t *testing.T) { + t.Parallel() + + // Create an expired token + expiredKeyID, expiredKeySecret := randomAPIKeyParts() + expiredHashedSecret := sha256.Sum256([]byte(expiredKeySecret)) + + _, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{ + ID: expiredKeyID, + UserID: user.ID, + HashedSecret: expiredHashedSecret[:], + IPAddress: defaultIPAddressForTests(), + LastUsed: dbtime.Now(), + ExpiresAt: dbtime.Now().Add(-testutil.WaitShort), // Expired + CreatedAt: dbtime.Now(), + UpdatedAt: dbtime.Now(), + LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, + LifetimeSeconds: int64(testutil.WaitLong.Seconds()), + }) + require.NoError(t, err) + + expiredToken := expiredKeyID + "-" + expiredKeySecret + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+expiredToken) + + rw := httptest.NewRecorder() + + httpmw.ExtractAPIKeyMW(cfg)(testHandler).ServeHTTP(rw, req) + + require.Equal(t, http.StatusUnauthorized, rw.Code) + + // Check that WWW-Authenticate header contains expired error + wwwAuth := rw.Header().Get("WWW-Authenticate") + require.NotEmpty(t, wwwAuth) + require.Contains(t, wwwAuth, "Bearer") + require.Contains(t, wwwAuth, `realm="coder"`) + require.Contains(t, wwwAuth, "expired") + }) + + t.Run("MissingBearerToken", func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest("GET", "/test", nil) + // No authentication provided + + rw := httptest.NewRecorder() + + httpmw.ExtractAPIKeyMW(cfg)(testHandler).ServeHTTP(rw, req) + + require.Equal(t, http.StatusUnauthorized, rw.Code) + + // Check that WWW-Authenticate header is present + wwwAuth := rw.Header().Get("WWW-Authenticate") + require.NotEmpty(t, wwwAuth) + require.Contains(t, wwwAuth, "Bearer") + require.Contains(t, wwwAuth, `realm="coder"`) + }) +} + +// TestAPITokenFromRequest tests the RFC 6750 bearer token extraction directly +func TestAPITokenFromRequest(t *testing.T) { + t.Parallel() + + token := "test-token-value" + + t.Run("AuthorizationBearerHeader", func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+token) + + extractedToken := httpmw.APITokenFromRequest(req) + require.Equal(t, token, extractedToken) + }) + + t.Run("AccessTokenQueryParameter", func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest("GET", "/test?access_token="+url.QueryEscape(token), nil) + + extractedToken := httpmw.APITokenFromRequest(req) + require.Equal(t, token, extractedToken) + }) + + t.Run("CustomMethodsPriorityOverBearer", func(t *testing.T) { + t.Parallel() + customToken := "custom-token" + + req := httptest.NewRequest("GET", "/test", nil) + // Set both custom header and bearer token - custom should win + req.Header.Set(codersdk.SessionTokenHeader, customToken) + req.Header.Set("Authorization", "Bearer "+token) + + extractedToken := httpmw.APITokenFromRequest(req) + require.Equal(t, customToken, extractedToken) + }) + + t.Run("CookiePriorityOverBearer", func(t *testing.T) { + t.Parallel() + cookieToken := "cookie-token" + + req := httptest.NewRequest("GET", "/test", nil) + req.AddCookie(&http.Cookie{ + Name: codersdk.SessionTokenCookie, + Value: cookieToken, + }) + req.Header.Set("Authorization", "Bearer "+token) + + extractedToken := httpmw.APITokenFromRequest(req) + require.Equal(t, cookieToken, extractedToken) + }) + + t.Run("NoTokenReturnsEmpty", func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest("GET", "/test", nil) + + extractedToken := httpmw.APITokenFromRequest(req) + require.Empty(t, extractedToken) + }) + + t.Run("InvalidAuthorizationHeaderIgnored", func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Basic dXNlcjpwYXNz") // Basic auth, not Bearer + + extractedToken := httpmw.APITokenFromRequest(req) + require.Empty(t, extractedToken) + }) +} diff --git a/coderd/oauth2.go b/coderd/oauth2.go index cc0b84501de21..a53513013a54b 100644 --- a/coderd/oauth2.go +++ b/coderd/oauth2.go @@ -431,9 +431,8 @@ func (api *API) oauth2ProtectedResourceMetadata(rw http.ResponseWriter, r *http. AuthorizationServers: []string{api.AccessURL.String()}, // TODO: Implement scope system based on RBAC permissions ScopesSupported: []string{}, - // Note: Coder uses custom authentication methods, not RFC 6750 bearer tokens - // TODO(ThomasK33): Implement RFC 6750 - // BearerMethodsSupported: []string{}, // Omitted - no standard bearer token support + // RFC 6750 Bearer Token methods supported as fallback methods in api key middleware + BearerMethodsSupported: []string{"header", "query"}, } httpapi.Write(ctx, rw, http.StatusOK, metadata) } diff --git a/coderd/oauth2_metadata_test.go b/coderd/oauth2_metadata_test.go index 163430bc30ef9..9359e48f46ee5 100644 --- a/coderd/oauth2_metadata_test.go +++ b/coderd/oauth2_metadata_test.go @@ -66,9 +66,9 @@ func TestOAuth2ProtectedResourceMetadata(t *testing.T) { require.NotEmpty(t, metadata.AuthorizationServers) require.Len(t, metadata.AuthorizationServers, 1) require.Equal(t, metadata.Resource, metadata.AuthorizationServers[0]) - // BearerMethodsSupported is omitted since Coder uses custom authentication methods - // Standard RFC 6750 bearer tokens are not supported - require.True(t, len(metadata.BearerMethodsSupported) == 0) + // RFC 6750 bearer tokens are now supported as fallback methods + require.Contains(t, metadata.BearerMethodsSupported, "header") + require.Contains(t, metadata.BearerMethodsSupported, "query") // ScopesSupported can be empty until scope system is implemented // Empty slice is marshaled as empty array, but can be nil when unmarshaled require.True(t, len(metadata.ScopesSupported) == 0) diff --git a/scripts/oauth2/test-mcp-oauth2.sh b/scripts/oauth2/test-mcp-oauth2.sh index f53724ae19349..4585cab499114 100755 --- a/scripts/oauth2/test-mcp-oauth2.sh +++ b/scripts/oauth2/test-mcp-oauth2.sh @@ -170,6 +170,53 @@ else echo -e "${RED}✗ Token refresh failed${NC}\n" fi +# Test 6: RFC 6750 Bearer Token Authentication +echo -e "${YELLOW}Test 6: RFC 6750 Bearer Token Authentication${NC}" +ACCESS_TOKEN=$(echo "$TOKEN_RESPONSE" | jq -r '.access_token') + +# Test Authorization: Bearer header +echo -e "${BLUE}Testing Authorization: Bearer header...${NC}" +BEARER_RESPONSE=$(curl -s -w "%{http_code}" "$BASE_URL/api/v2/users/me" \ + -H "Authorization: Bearer $ACCESS_TOKEN") + +HTTP_CODE="${BEARER_RESPONSE: -3}" +if [ "$HTTP_CODE" = "200" ]; then + echo -e "${GREEN}✓ Authorization: Bearer header working${NC}" +else + echo -e "${RED}✗ Authorization: Bearer header failed (HTTP $HTTP_CODE)${NC}" +fi + +# Test access_token query parameter +echo -e "${BLUE}Testing access_token query parameter...${NC}" +QUERY_RESPONSE=$(curl -s -w "%{http_code}" "$BASE_URL/api/v2/users/me?access_token=$ACCESS_TOKEN") + +HTTP_CODE="${QUERY_RESPONSE: -3}" +if [ "$HTTP_CODE" = "200" ]; then + echo -e "${GREEN}✓ access_token query parameter working${NC}" +else + echo -e "${RED}✗ access_token query parameter failed (HTTP $HTTP_CODE)${NC}" +fi + +# Test WWW-Authenticate header on unauthorized request +echo -e "${BLUE}Testing WWW-Authenticate header on 401...${NC}" +UNAUTH_RESPONSE=$(curl -s -I "$BASE_URL/api/v2/users/me") +if echo "$UNAUTH_RESPONSE" | grep -i "WWW-Authenticate.*Bearer" >/dev/null; then + echo -e "${GREEN}✓ WWW-Authenticate header present${NC}" +else + echo -e "${RED}✗ WWW-Authenticate header missing${NC}" +fi + +# Test 7: Protected Resource Metadata +echo -e "${YELLOW}Test 7: Protected Resource Metadata (RFC 9728)${NC}" +PROTECTED_METADATA=$(curl -s "$BASE_URL/.well-known/oauth-protected-resource") +echo "$PROTECTED_METADATA" | jq . + +if echo "$PROTECTED_METADATA" | jq -e '.bearer_methods_supported[]' | grep -q "header"; then + echo -e "${GREEN}✓ Protected Resource Metadata indicates bearer token support${NC}\n" +else + echo -e "${RED}✗ Protected Resource Metadata missing bearer token support${NC}\n" +fi + # Cleanup echo -e "${YELLOW}Cleaning up...${NC}" curl -s -X DELETE "$BASE_URL/api/v2/oauth2-provider/apps/$CLIENT_ID" \