diff --git a/coderd/httpmw/apikey.go b/coderd/httpmw/apikey.go index 5c06abd0fb147..97df5adfec20f 100644 --- a/coderd/httpmw/apikey.go +++ b/coderd/httpmw/apikey.go @@ -209,6 +209,26 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon return nil, nil, false } + // Add WWW-Authenticate header for 401/403 responses (RFC 6750) + if code == http.StatusUnauthorized || code == http.StatusForbidden { + // Basic Bearer challenge with realm + wwwAuth := `Bearer realm="coder"` + + // Add error details based on the type of error + switch { + case strings.Contains(response.Message, "invalid") || strings.Contains(response.Detail, "invalid"): + wwwAuth = `Bearer realm="coder", error="invalid_token", error_description="The access token is invalid"` + case strings.Contains(response.Message, "expired") || strings.Contains(response.Detail, "expired"): + wwwAuth = `Bearer realm="coder", error="invalid_token", error_description="The access token has expired"` + case strings.Contains(response.Message, "audience") || strings.Contains(response.Message, "mismatch"): + wwwAuth = `Bearer realm="coder", error="invalid_token", error_description="The access token audience does not match this resource"` + case code == http.StatusForbidden: + wwwAuth = `Bearer realm="coder", error="insufficient_scope", error_description="The request requires higher privileges than provided by the access token"` + } + + rw.Header().Set("WWW-Authenticate", wwwAuth) + } + httpapi.Write(ctx, rw, code, response) return nil, nil, false } @@ -534,9 +554,14 @@ func UserRBACSubject(ctx context.Context, db database.Store, userID uuid.UUID, s // 1: The cookie // 2. The coder_session_token query parameter // 3. The custom auth header +// 4. RFC 6750 Authorization: Bearer header +// 5. RFC 6750 access_token query parameter // // API tokens for apps are read from workspaceapps/cookies.go. func APITokenFromRequest(r *http.Request) string { + // Prioritize existing Coder custom authentication methods first + // to maintain backward compatibility and existing behavior + cookie, err := r.Cookie(codersdk.SessionTokenCookie) if err == nil && cookie.Value != "" { return cookie.Value @@ -552,7 +577,18 @@ func APITokenFromRequest(r *http.Request) string { return headerValue } - // TODO(ThomasK33): Implement RFC 6750 + // RFC 6750 Bearer Token support (added as fallback methods) + // Check Authorization: Bearer header + authHeader := r.Header.Get("Authorization") + if strings.HasPrefix(authHeader, "Bearer ") { + return strings.TrimPrefix(authHeader, "Bearer ") + } + + // Check access_token query parameter + accessToken := r.URL.Query().Get("access_token") + if accessToken != "" { + return accessToken + } return "" } diff --git a/coderd/httpmw/rfc6750_extended_test.go b/coderd/httpmw/rfc6750_extended_test.go new file mode 100644 index 0000000000000..cd0a33cc838b5 --- /dev/null +++ b/coderd/httpmw/rfc6750_extended_test.go @@ -0,0 +1,619 @@ +package httpmw_test + +import ( + "context" + "crypto/sha256" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" +) + +// TestOAuth2BearerTokenSecurityBoundaries tests RFC 6750 security boundaries +func TestOAuth2BearerTokenSecurityBoundaries(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + t.Cleanup(cancel) + + db, _ := dbtestutil.NewDB(t) + + // Create two different users with different API keys + user1 := dbgen.User(t, db, database.User{}) + user2 := dbgen.User(t, db, database.User{}) + + // Create API keys for both users + key1ID, key1Secret := randomAPIKeyParts() + hashedSecret1 := sha256.Sum256([]byte(key1Secret)) + _, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{ + ID: key1ID, + UserID: user1.ID, + HashedSecret: hashedSecret1[:], + IPAddress: defaultIPAddressForTests(), + LastUsed: dbtime.Now(), + ExpiresAt: dbtime.Now().Add(testutil.WaitLong), + CreatedAt: dbtime.Now(), + UpdatedAt: dbtime.Now(), + LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, + LifetimeSeconds: int64(testutil.WaitLong.Seconds()), + }) + require.NoError(t, err) + + key2ID, key2Secret := randomAPIKeyParts() + hashedSecret2 := sha256.Sum256([]byte(key2Secret)) + _, err = db.InsertAPIKey(ctx, database.InsertAPIKeyParams{ + ID: key2ID, + UserID: user2.ID, + HashedSecret: hashedSecret2[:], + IPAddress: defaultIPAddressForTests(), + LastUsed: dbtime.Now(), + ExpiresAt: dbtime.Now().Add(testutil.WaitLong), + CreatedAt: dbtime.Now(), + UpdatedAt: dbtime.Now(), + LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, + LifetimeSeconds: int64(testutil.WaitLong.Seconds()), + }) + require.NoError(t, err) + + t.Run("TokenIsolation", func(t *testing.T) { + t.Parallel() + + // Create middleware + middleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ + DB: db, + }) + + // Handler that returns the authenticated user ID + handler := middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + apiKey := httpmw.APIKey(r) + rw.Header().Set("X-User-ID", apiKey.UserID.String()) + rw.WriteHeader(http.StatusOK) + })) + + // Test that user1's token only accesses user1's data + req1 := httptest.NewRequest("GET", "/test", nil) + req1.Header.Set("Authorization", "Bearer "+key1ID+"-"+key1Secret) + rec1 := httptest.NewRecorder() + handler.ServeHTTP(rec1, req1) + + require.Equal(t, http.StatusOK, rec1.Code) + require.Equal(t, user1.ID.String(), rec1.Header().Get("X-User-ID")) + + // Test that user2's token only accesses user2's data + req2 := httptest.NewRequest("GET", "/test", nil) + req2.Header.Set("Authorization", "Bearer "+key2ID+"-"+key2Secret) + rec2 := httptest.NewRecorder() + handler.ServeHTTP(rec2, req2) + + require.Equal(t, http.StatusOK, rec2.Code) + require.Equal(t, user2.ID.String(), rec2.Header().Get("X-User-ID")) + + // Verify users can't access each other's data + require.NotEqual(t, rec1.Header().Get("X-User-ID"), rec2.Header().Get("X-User-ID")) + }) + + t.Run("CrossTokenAttempts", func(t *testing.T) { + t.Parallel() + + middleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ + DB: db, + }) + + handler := middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusOK) + })) + + // Try to use user1's key ID with user2's secret (should fail) + invalidToken := key1ID + "-" + key2Secret + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+invalidToken) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + require.Equal(t, http.StatusUnauthorized, rec.Code) + require.Contains(t, rec.Header().Get("WWW-Authenticate"), "Bearer") + require.Contains(t, rec.Header().Get("WWW-Authenticate"), "invalid_token") + }) + + t.Run("TimingAttackResistance", func(t *testing.T) { + t.Parallel() + + middleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ + DB: db, + }) + + handler := middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusOK) + })) + + // Test multiple invalid tokens to ensure consistent timing + invalidTokens := []string{ + "invalid-token-1", + "invalid-token-2-longer", + "a", + strings.Repeat("x", 100), + } + + times := make([]time.Duration, len(invalidTokens)) + + for i, token := range invalidTokens { + start := time.Now() + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+token) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + times[i] = time.Since(start) + + require.Equal(t, http.StatusUnauthorized, rec.Code) + } + + // While we can't guarantee perfect timing consistency in tests, + // we can at least verify that the responses are all unauthorized + // and contain proper WWW-Authenticate headers + for _, token := range invalidTokens { + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+token) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + require.Equal(t, http.StatusUnauthorized, rec.Code) + require.Contains(t, rec.Header().Get("WWW-Authenticate"), "Bearer") + } + }) +} + +// TestOAuth2BearerTokenMalformedHeaders tests handling of malformed Bearer headers per RFC 6750 +func TestOAuth2BearerTokenMalformedHeaders(t *testing.T) { + t.Parallel() + + db, _ := dbtestutil.NewDB(t) + + middleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ + DB: db, + }) + + handler := middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusOK) + })) + + tests := []struct { + name string + authHeader string + expectedStatus int + shouldHaveWWW bool + }{ + { + name: "MissingBearer", + authHeader: "invalid-token", + expectedStatus: http.StatusUnauthorized, + shouldHaveWWW: true, + }, + { + name: "CaseSensitive", + authHeader: "bearer token", // lowercase should still work + expectedStatus: http.StatusUnauthorized, + shouldHaveWWW: true, + }, + { + name: "ExtraSpaces", + authHeader: "Bearer token-with-extra-spaces", + expectedStatus: http.StatusUnauthorized, + shouldHaveWWW: true, + }, + { + name: "EmptyToken", + authHeader: "Bearer ", + expectedStatus: http.StatusUnauthorized, + shouldHaveWWW: true, + }, + { + name: "OnlyBearer", + authHeader: "Bearer", + expectedStatus: http.StatusUnauthorized, + shouldHaveWWW: true, + }, + { + name: "MultipleBearer", + authHeader: "Bearer token1 Bearer token2", + expectedStatus: http.StatusUnauthorized, + shouldHaveWWW: true, + }, + { + name: "InvalidBase64", + authHeader: "Bearer !!!invalid-base64!!!", + expectedStatus: http.StatusUnauthorized, + shouldHaveWWW: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", test.authHeader) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + require.Equal(t, test.expectedStatus, rec.Code) + + if test.shouldHaveWWW { + wwwAuth := rec.Header().Get("WWW-Authenticate") + require.Contains(t, wwwAuth, "Bearer") + require.Contains(t, wwwAuth, "realm=\"coder\"") + } + }) + } +} + +// TestOAuth2BearerTokenPrecedence tests token extraction precedence per RFC 6750 +func TestOAuth2BearerTokenPrecedence(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + t.Cleanup(cancel) + + db, _ := dbtestutil.NewDB(t) + user := dbgen.User(t, db, database.User{}) + + // Create a valid API key + keyID, keySecret := randomAPIKeyParts() + hashedSecret := sha256.Sum256([]byte(keySecret)) + validToken := keyID + "-" + keySecret + + _, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{ + ID: keyID, + UserID: user.ID, + HashedSecret: hashedSecret[:], + IPAddress: defaultIPAddressForTests(), + LastUsed: dbtime.Now(), + ExpiresAt: dbtime.Now().Add(testutil.WaitLong), + CreatedAt: dbtime.Now(), + UpdatedAt: dbtime.Now(), + LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, + LifetimeSeconds: int64(testutil.WaitLong.Seconds()), + }) + require.NoError(t, err) + + middleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ + DB: db, + }) + + handler := middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + apiKey := httpmw.APIKey(r) + rw.Header().Set("X-Key-ID", apiKey.ID) + rw.WriteHeader(http.StatusOK) + })) + + t.Run("CookieTakesPrecedenceOverBearer", func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest("GET", "/test", nil) + // Set both cookie and Bearer header - cookie should take precedence + req.AddCookie(&http.Cookie{ + Name: codersdk.SessionTokenCookie, + Value: validToken, + }) + req.Header.Set("Authorization", "Bearer invalid-token") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, keyID, rec.Header().Get("X-Key-ID")) + }) + + t.Run("QueryParameterTakesPrecedenceOverBearer", func(t *testing.T) { + t.Parallel() + + // Set both query parameter and Bearer header - query should take precedence + u, _ := url.Parse("/test") + q := u.Query() + q.Set(codersdk.SessionTokenCookie, validToken) + u.RawQuery = q.Encode() + + req := httptest.NewRequest("GET", u.String(), nil) + req.Header.Set("Authorization", "Bearer invalid-token") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, keyID, rec.Header().Get("X-Key-ID")) + }) + + t.Run("BearerHeaderFallback", func(t *testing.T) { + t.Parallel() + + // Only set Bearer header - should be used as fallback + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+validToken) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, keyID, rec.Header().Get("X-Key-ID")) + }) + + t.Run("AccessTokenQueryParameterFallback", func(t *testing.T) { + t.Parallel() + + // Only set access_token query parameter - should be used as fallback + u, _ := url.Parse("/test") + q := u.Query() + q.Set("access_token", validToken) + u.RawQuery = q.Encode() + + req := httptest.NewRequest("GET", u.String(), nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, keyID, rec.Header().Get("X-Key-ID")) + }) + + t.Run("MultipleAuthMethodsShouldNotConflict", func(t *testing.T) { + t.Parallel() + + // RFC 6750 says clients shouldn't send tokens in multiple ways, + // but if they do, we should handle it gracefully by using precedence + u, _ := url.Parse("/test") + q := u.Query() + q.Set("access_token", validToken) + q.Set(codersdk.SessionTokenCookie, validToken) + u.RawQuery = q.Encode() + + req := httptest.NewRequest("GET", u.String(), nil) + req.Header.Set("Authorization", "Bearer "+validToken) + req.AddCookie(&http.Cookie{ + Name: codersdk.SessionTokenCookie, + Value: validToken, + }) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + // Should succeed using the highest precedence method (cookie) + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, keyID, rec.Header().Get("X-Key-ID")) + }) +} + +// TestOAuth2WWWAuthenticateCompliance tests WWW-Authenticate header compliance with RFC 6750 +func TestOAuth2WWWAuthenticateCompliance(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + t.Cleanup(cancel) + + db, _ := dbtestutil.NewDB(t) + user := dbgen.User(t, db, database.User{}) + + middleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ + DB: db, + }) + + handler := middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusOK) + })) + + t.Run("UnauthorizedResponse", func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer invalid-token") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + require.Equal(t, http.StatusUnauthorized, rec.Code) + + wwwAuth := rec.Header().Get("WWW-Authenticate") + require.NotEmpty(t, wwwAuth) + + // RFC 6750 requires specific format: Bearer realm="realm" + require.Contains(t, wwwAuth, "Bearer") + require.Contains(t, wwwAuth, "realm=\"coder\"") + require.Contains(t, wwwAuth, "error=\"invalid_token\"") + require.Contains(t, wwwAuth, "error_description=") + }) + + t.Run("ExpiredTokenResponse", func(t *testing.T) { + t.Parallel() + + // Create an expired API key + keyID, keySecret := randomAPIKeyParts() + hashedSecret := sha256.Sum256([]byte(keySecret)) + expiredToken := keyID + "-" + keySecret + + _, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{ + ID: keyID, + UserID: user.ID, + HashedSecret: hashedSecret[:], + IPAddress: defaultIPAddressForTests(), + LastUsed: dbtime.Now(), + ExpiresAt: dbtime.Now().Add(-time.Hour), // Expired 1 hour ago + CreatedAt: dbtime.Now(), + UpdatedAt: dbtime.Now(), + LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, + LifetimeSeconds: int64(testutil.WaitLong.Seconds()), + }) + require.NoError(t, err) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+expiredToken) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + require.Equal(t, http.StatusUnauthorized, rec.Code) + + wwwAuth := rec.Header().Get("WWW-Authenticate") + require.Contains(t, wwwAuth, "Bearer") + require.Contains(t, wwwAuth, "realm=\"coder\"") + require.Contains(t, wwwAuth, "error=\"invalid_token\"") + require.Contains(t, wwwAuth, "error_description=\"The access token has expired\"") + }) + + t.Run("InsufficientScopeResponse", func(t *testing.T) { + t.Parallel() + + // For this test, we'll test with an invalid token to trigger the middleware's + // error handling which does set WWW-Authenticate headers for 403 responses + // In practice, insufficient scope errors would be handled by RBAC middleware + // that comes after authentication, but we can simulate a 403 from the auth middleware + + req := httptest.NewRequest("GET", "/admin", nil) + req.Header.Set("Authorization", "Bearer invalid-token-that-triggers-403") + rec := httptest.NewRecorder() + + // Use a middleware configuration that might trigger a 403 instead of 401 + // for certain types of authentication failures + middleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ + DB: db, + }) + + handler := middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + // This shouldn't be reached due to auth failure + rw.WriteHeader(http.StatusOK) + })) + + handler.ServeHTTP(rec, req) + + // This will be a 401 (unauthorized) rather than 403 (forbidden) for invalid tokens + // which is correct - 403 would come from RBAC after successful authentication + require.Equal(t, http.StatusUnauthorized, rec.Code) + + wwwAuth := rec.Header().Get("WWW-Authenticate") + require.Contains(t, wwwAuth, "Bearer") + require.Contains(t, wwwAuth, "realm=\"coder\"") + require.Contains(t, wwwAuth, "error=\"invalid_token\"") + require.Contains(t, wwwAuth, "error_description=") + }) +} + +// TestOAuth2BearerTokenConcurrency tests Bearer token authentication under concurrent load +func TestOAuth2BearerTokenConcurrency(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + t.Cleanup(cancel) + + db, _ := dbtestutil.NewDB(t) + user := dbgen.User(t, db, database.User{}) + + // Create a valid API key + keyID, keySecret := randomAPIKeyParts() + hashedSecret := sha256.Sum256([]byte(keySecret)) + validToken := keyID + "-" + keySecret + + _, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{ + ID: keyID, + UserID: user.ID, + HashedSecret: hashedSecret[:], + IPAddress: defaultIPAddressForTests(), + LastUsed: dbtime.Now(), + ExpiresAt: dbtime.Now().Add(testutil.WaitLong), + CreatedAt: dbtime.Now(), + UpdatedAt: dbtime.Now(), + LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, + LifetimeSeconds: int64(testutil.WaitLong.Seconds()), + }) + require.NoError(t, err) + + middleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ + DB: db, + }) + + handler := middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + apiKey := httpmw.APIKey(r) + require.Equal(t, keyID, apiKey.ID) + rw.WriteHeader(http.StatusOK) + })) + + t.Run("ConcurrentValidRequests", func(t *testing.T) { + t.Parallel() + + const numGoroutines = 50 + var wg sync.WaitGroup + errors := make(chan error, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(index int) { + defer wg.Done() + + req := httptest.NewRequest("GET", fmt.Sprintf("/test-%d", index), nil) + req.Header.Set("Authorization", "Bearer "+validToken) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + errors <- xerrors.Errorf("goroutine %d: expected 200, got %d", index, rec.Code) + } + }(i) + } + + wg.Wait() + close(errors) + + // Check for any errors + for err := range errors { + require.NoError(t, err) + } + }) + + t.Run("ConcurrentInvalidRequests", func(t *testing.T) { + t.Parallel() + + const numGoroutines = 50 + var wg sync.WaitGroup + errors := make(chan error, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(index int) { + defer wg.Done() + + req := httptest.NewRequest("GET", fmt.Sprintf("/test-%d", index), nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer invalid-token-%d", index)) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + errors <- xerrors.Errorf("goroutine %d: expected 401, got %d", index, rec.Code) + } + + wwwAuth := rec.Header().Get("WWW-Authenticate") + if !strings.Contains(wwwAuth, "Bearer") { + errors <- xerrors.Errorf("goroutine %d: missing WWW-Authenticate header", index) + } + }(i) + } + + wg.Wait() + close(errors) + + // Check for any errors + for err := range errors { + require.NoError(t, err) + } + }) +} diff --git a/coderd/httpmw/rfc6750_test.go b/coderd/httpmw/rfc6750_test.go new file mode 100644 index 0000000000000..2c0dbb0529730 --- /dev/null +++ b/coderd/httpmw/rfc6750_test.go @@ -0,0 +1,296 @@ +package httpmw_test + +import ( + "context" + "crypto/sha256" + "net" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/sqlc-dev/pqtype" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" +) + +// defaultIPAddressForTests returns a default IP address for test API keys +func defaultIPAddressForTests() pqtype.Inet { + return pqtype.Inet{ + IPNet: net.IPNet{ + IP: net.IPv4(127, 0, 0, 1), + Mask: net.IPv4Mask(255, 255, 255, 255), + }, + Valid: true, + } +} + +// TestRFC6750BearerTokenAuthentication tests that RFC 6750 bearer tokens work correctly +// for authentication, including both Authorization header and access_token query parameter methods. +func TestRFC6750BearerTokenAuthentication(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + t.Cleanup(cancel) + + db, _ := dbtestutil.NewDB(t) + + // Create a test user and API key + user := dbgen.User(t, db, database.User{}) + + // Create an OAuth2 provider app token (which should work with bearer token authentication) + keyID, keySecret := randomAPIKeyParts() + hashedSecret := sha256.Sum256([]byte(keySecret)) + + key, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{ + ID: keyID, + UserID: user.ID, + HashedSecret: hashedSecret[:], + IPAddress: defaultIPAddressForTests(), + LastUsed: dbtime.Now(), + ExpiresAt: dbtime.Now().Add(testutil.WaitLong), + CreatedAt: dbtime.Now(), + UpdatedAt: dbtime.Now(), + LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, + LifetimeSeconds: int64(testutil.WaitLong.Seconds()), + }) + require.NoError(t, err) + + token := keyID + "-" + keySecret + + cfg := httpmw.ExtractAPIKeyConfig{ + DB: db, + } + + testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + apiKey := httpmw.APIKey(r) + require.Equal(t, key.ID, apiKey.ID) + rw.WriteHeader(http.StatusOK) + }) + + t.Run("AuthorizationBearerHeader", func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+token) + + rw := httptest.NewRecorder() + + httpmw.ExtractAPIKeyMW(cfg)(testHandler).ServeHTTP(rw, req) + + require.Equal(t, http.StatusOK, rw.Code) + }) + + t.Run("AccessTokenQueryParameter", func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest("GET", "/test?access_token="+url.QueryEscape(token), nil) + + rw := httptest.NewRecorder() + + httpmw.ExtractAPIKeyMW(cfg)(testHandler).ServeHTTP(rw, req) + + require.Equal(t, http.StatusOK, rw.Code) + }) + + t.Run("BearerTokenPriorityAfterCustomMethods", func(t *testing.T) { + t.Parallel() + + // Create a different token for custom header + customKeyID, customKeySecret := randomAPIKeyParts() + customHashedSecret := sha256.Sum256([]byte(customKeySecret)) + + customKey, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{ + ID: customKeyID, + UserID: user.ID, + HashedSecret: customHashedSecret[:], + IPAddress: defaultIPAddressForTests(), + LastUsed: dbtime.Now(), + ExpiresAt: dbtime.Now().Add(testutil.WaitLong), + CreatedAt: dbtime.Now(), + UpdatedAt: dbtime.Now(), + LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, + LifetimeSeconds: int64(testutil.WaitLong.Seconds()), + }) + require.NoError(t, err) + + customToken := customKeyID + "-" + customKeySecret + + // Create handler that checks which token was used + priorityHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + apiKey := httpmw.APIKey(r) + // Should use the custom header token, not the bearer token + require.Equal(t, customKey.ID, apiKey.ID) + rw.WriteHeader(http.StatusOK) + }) + + req := httptest.NewRequest("GET", "/test", nil) + // Set both custom header and bearer header - custom should win + req.Header.Set(codersdk.SessionTokenHeader, customToken) + req.Header.Set("Authorization", "Bearer "+token) + + rw := httptest.NewRecorder() + + httpmw.ExtractAPIKeyMW(cfg)(priorityHandler).ServeHTTP(rw, req) + + require.Equal(t, http.StatusOK, rw.Code) + }) + + t.Run("InvalidBearerToken", func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer invalid-token") + + rw := httptest.NewRecorder() + + httpmw.ExtractAPIKeyMW(cfg)(testHandler).ServeHTTP(rw, req) + + require.Equal(t, http.StatusUnauthorized, rw.Code) + + // Check that WWW-Authenticate header is present + wwwAuth := rw.Header().Get("WWW-Authenticate") + require.NotEmpty(t, wwwAuth) + require.Contains(t, wwwAuth, "Bearer") + require.Contains(t, wwwAuth, `realm="coder"`) + require.Contains(t, wwwAuth, "invalid_token") + }) + + t.Run("ExpiredBearerToken", func(t *testing.T) { + t.Parallel() + + // Create an expired token + expiredKeyID, expiredKeySecret := randomAPIKeyParts() + expiredHashedSecret := sha256.Sum256([]byte(expiredKeySecret)) + + _, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{ + ID: expiredKeyID, + UserID: user.ID, + HashedSecret: expiredHashedSecret[:], + IPAddress: defaultIPAddressForTests(), + LastUsed: dbtime.Now(), + ExpiresAt: dbtime.Now().Add(-testutil.WaitShort), // Expired + CreatedAt: dbtime.Now(), + UpdatedAt: dbtime.Now(), + LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, + LifetimeSeconds: int64(testutil.WaitLong.Seconds()), + }) + require.NoError(t, err) + + expiredToken := expiredKeyID + "-" + expiredKeySecret + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+expiredToken) + + rw := httptest.NewRecorder() + + httpmw.ExtractAPIKeyMW(cfg)(testHandler).ServeHTTP(rw, req) + + require.Equal(t, http.StatusUnauthorized, rw.Code) + + // Check that WWW-Authenticate header contains expired error + wwwAuth := rw.Header().Get("WWW-Authenticate") + require.NotEmpty(t, wwwAuth) + require.Contains(t, wwwAuth, "Bearer") + require.Contains(t, wwwAuth, `realm="coder"`) + require.Contains(t, wwwAuth, "expired") + }) + + t.Run("MissingBearerToken", func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest("GET", "/test", nil) + // No authentication provided + + rw := httptest.NewRecorder() + + httpmw.ExtractAPIKeyMW(cfg)(testHandler).ServeHTTP(rw, req) + + require.Equal(t, http.StatusUnauthorized, rw.Code) + + // Check that WWW-Authenticate header is present + wwwAuth := rw.Header().Get("WWW-Authenticate") + require.NotEmpty(t, wwwAuth) + require.Contains(t, wwwAuth, "Bearer") + require.Contains(t, wwwAuth, `realm="coder"`) + }) +} + +// TestAPITokenFromRequest tests the RFC 6750 bearer token extraction directly +func TestAPITokenFromRequest(t *testing.T) { + t.Parallel() + + token := "test-token-value" + + t.Run("AuthorizationBearerHeader", func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+token) + + extractedToken := httpmw.APITokenFromRequest(req) + require.Equal(t, token, extractedToken) + }) + + t.Run("AccessTokenQueryParameter", func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest("GET", "/test?access_token="+url.QueryEscape(token), nil) + + extractedToken := httpmw.APITokenFromRequest(req) + require.Equal(t, token, extractedToken) + }) + + t.Run("CustomMethodsPriorityOverBearer", func(t *testing.T) { + t.Parallel() + customToken := "custom-token" + + req := httptest.NewRequest("GET", "/test", nil) + // Set both custom header and bearer token - custom should win + req.Header.Set(codersdk.SessionTokenHeader, customToken) + req.Header.Set("Authorization", "Bearer "+token) + + extractedToken := httpmw.APITokenFromRequest(req) + require.Equal(t, customToken, extractedToken) + }) + + t.Run("CookiePriorityOverBearer", func(t *testing.T) { + t.Parallel() + cookieToken := "cookie-token" + + req := httptest.NewRequest("GET", "/test", nil) + req.AddCookie(&http.Cookie{ + Name: codersdk.SessionTokenCookie, + Value: cookieToken, + }) + req.Header.Set("Authorization", "Bearer "+token) + + extractedToken := httpmw.APITokenFromRequest(req) + require.Equal(t, cookieToken, extractedToken) + }) + + t.Run("NoTokenReturnsEmpty", func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest("GET", "/test", nil) + + extractedToken := httpmw.APITokenFromRequest(req) + require.Empty(t, extractedToken) + }) + + t.Run("InvalidAuthorizationHeaderIgnored", func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Basic dXNlcjpwYXNz") // Basic auth, not Bearer + + extractedToken := httpmw.APITokenFromRequest(req) + require.Empty(t, extractedToken) + }) +} diff --git a/coderd/oauth2.go b/coderd/oauth2.go index cc0b84501de21..a53513013a54b 100644 --- a/coderd/oauth2.go +++ b/coderd/oauth2.go @@ -431,9 +431,8 @@ func (api *API) oauth2ProtectedResourceMetadata(rw http.ResponseWriter, r *http. AuthorizationServers: []string{api.AccessURL.String()}, // TODO: Implement scope system based on RBAC permissions ScopesSupported: []string{}, - // Note: Coder uses custom authentication methods, not RFC 6750 bearer tokens - // TODO(ThomasK33): Implement RFC 6750 - // BearerMethodsSupported: []string{}, // Omitted - no standard bearer token support + // RFC 6750 Bearer Token methods supported as fallback methods in api key middleware + BearerMethodsSupported: []string{"header", "query"}, } httpapi.Write(ctx, rw, http.StatusOK, metadata) } diff --git a/coderd/oauth2_metadata_test.go b/coderd/oauth2_metadata_test.go index 163430bc30ef9..9359e48f46ee5 100644 --- a/coderd/oauth2_metadata_test.go +++ b/coderd/oauth2_metadata_test.go @@ -66,9 +66,9 @@ func TestOAuth2ProtectedResourceMetadata(t *testing.T) { require.NotEmpty(t, metadata.AuthorizationServers) require.Len(t, metadata.AuthorizationServers, 1) require.Equal(t, metadata.Resource, metadata.AuthorizationServers[0]) - // BearerMethodsSupported is omitted since Coder uses custom authentication methods - // Standard RFC 6750 bearer tokens are not supported - require.True(t, len(metadata.BearerMethodsSupported) == 0) + // RFC 6750 bearer tokens are now supported as fallback methods + require.Contains(t, metadata.BearerMethodsSupported, "header") + require.Contains(t, metadata.BearerMethodsSupported, "query") // ScopesSupported can be empty until scope system is implemented // Empty slice is marshaled as empty array, but can be nil when unmarshaled require.True(t, len(metadata.ScopesSupported) == 0) diff --git a/scripts/oauth2/test-mcp-oauth2.sh b/scripts/oauth2/test-mcp-oauth2.sh index f53724ae19349..4585cab499114 100755 --- a/scripts/oauth2/test-mcp-oauth2.sh +++ b/scripts/oauth2/test-mcp-oauth2.sh @@ -170,6 +170,53 @@ else echo -e "${RED}✗ Token refresh failed${NC}\n" fi +# Test 6: RFC 6750 Bearer Token Authentication +echo -e "${YELLOW}Test 6: RFC 6750 Bearer Token Authentication${NC}" +ACCESS_TOKEN=$(echo "$TOKEN_RESPONSE" | jq -r '.access_token') + +# Test Authorization: Bearer header +echo -e "${BLUE}Testing Authorization: Bearer header...${NC}" +BEARER_RESPONSE=$(curl -s -w "%{http_code}" "$BASE_URL/api/v2/users/me" \ + -H "Authorization: Bearer $ACCESS_TOKEN") + +HTTP_CODE="${BEARER_RESPONSE: -3}" +if [ "$HTTP_CODE" = "200" ]; then + echo -e "${GREEN}✓ Authorization: Bearer header working${NC}" +else + echo -e "${RED}✗ Authorization: Bearer header failed (HTTP $HTTP_CODE)${NC}" +fi + +# Test access_token query parameter +echo -e "${BLUE}Testing access_token query parameter...${NC}" +QUERY_RESPONSE=$(curl -s -w "%{http_code}" "$BASE_URL/api/v2/users/me?access_token=$ACCESS_TOKEN") + +HTTP_CODE="${QUERY_RESPONSE: -3}" +if [ "$HTTP_CODE" = "200" ]; then + echo -e "${GREEN}✓ access_token query parameter working${NC}" +else + echo -e "${RED}✗ access_token query parameter failed (HTTP $HTTP_CODE)${NC}" +fi + +# Test WWW-Authenticate header on unauthorized request +echo -e "${BLUE}Testing WWW-Authenticate header on 401...${NC}" +UNAUTH_RESPONSE=$(curl -s -I "$BASE_URL/api/v2/users/me") +if echo "$UNAUTH_RESPONSE" | grep -i "WWW-Authenticate.*Bearer" >/dev/null; then + echo -e "${GREEN}✓ WWW-Authenticate header present${NC}" +else + echo -e "${RED}✗ WWW-Authenticate header missing${NC}" +fi + +# Test 7: Protected Resource Metadata +echo -e "${YELLOW}Test 7: Protected Resource Metadata (RFC 9728)${NC}" +PROTECTED_METADATA=$(curl -s "$BASE_URL/.well-known/oauth-protected-resource") +echo "$PROTECTED_METADATA" | jq . + +if echo "$PROTECTED_METADATA" | jq -e '.bearer_methods_supported[]' | grep -q "header"; then + echo -e "${GREEN}✓ Protected Resource Metadata indicates bearer token support${NC}\n" +else + echo -e "${RED}✗ Protected Resource Metadata missing bearer token support${NC}\n" +fi + # Cleanup echo -e "${YELLOW}Cleaning up...${NC}" curl -s -X DELETE "$BASE_URL/api/v2/oauth2-provider/apps/$CLIENT_ID" \