Skip to content

feat: oauth2 - add RFC 8707 resource indicators and audience validation #18575

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: thomask33/06-24-feat_oauth2_add_authorization_server_metadata_endpoint_and_pkce_support
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ Read [cursor rules](.cursorrules).
- Format: `{number}_{description}.{up|down}.sql`
- Number must be unique and sequential
- Always include both up and down migrations
- **Use helper scripts**:
- `./coderd/database/migrations/create_migration.sh "migration name"` - Creates new migration files
- `./coderd/database/migrations/fix_migration_numbers.sh` - Renumbers migrations to avoid conflicts
- `./coderd/database/migrations/create_fixture.sh "fixture name"` - Creates test fixtures for migrations

2. **Update database queries**:
- MUST DO! Any changes to database - adding queries, modifying queries should be done in the `coderd/database/queries/*.sql` files
Expand Down Expand Up @@ -125,6 +129,29 @@ Read [cursor rules](.cursorrules).
4. Run `make gen` again
5. Run `make lint` to catch any remaining issues

### In-Memory Database Testing

When adding new database fields:

- **CRITICAL**: Update `coderd/database/dbmem/dbmem.go` in-memory implementations
- The `Insert*` functions must include ALL new fields, not just basic ones
- Common issue: Tests pass with real database but fail with in-memory database due to missing field mappings
- Always verify in-memory database functions match the real database schema after migrations

Example pattern:

```go
// In dbmem.go - ensure ALL fields are included
code := database.OAuth2ProviderAppCode{
ID: arg.ID,
CreatedAt: arg.CreatedAt,
// ... existing fields ...
ResourceUri: arg.ResourceUri, // New field
CodeChallenge: arg.CodeChallenge, // New field
CodeChallengeMethod: arg.CodeChallengeMethod, // New field
}
```

## Architecture

### Core Components
Expand Down Expand Up @@ -209,6 +236,12 @@ When working on OAuth2 provider features:
- Avoid dependency on referer headers for security decisions
- Support proper state parameter validation

6. **RFC 8707 Resource Indicators**:
- Store resource parameters in database for server-side validation (opaque tokens)
- Validate resource consistency between authorization and token requests
- Support audience validation in refresh token flows
- Resource parameter is optional but must be consistent when provided

### OAuth2 Error Handling Pattern

```go
Expand Down Expand Up @@ -265,3 +298,6 @@ Always run the full test suite after OAuth2 changes:
4. **Missing newlines** - Ensure files end with newline character
5. **Tests passing locally but failing in CI** - Check if `dbmem` implementation needs updating
6. **OAuth2 endpoints returning wrong error format** - Ensure OAuth2 endpoints return RFC 6749 compliant errors
7. **OAuth2 tests failing but scripts working** - Check in-memory database implementations in `dbmem.go`
8. **Resource indicator validation failing** - Ensure database stores and retrieves resource parameters correctly
9. **PKCE tests failing** - Verify both authorization code storage and token exchange handle PKCE fields
3 changes: 3 additions & 0 deletions coderd/coderd.go
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,7 @@ func New(options *Options) *API {
Optional: false,
SessionTokenFunc: nil, // Default behavior
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
Logger: options.Logger,
})
// Same as above but it redirects to the login page.
apiKeyMiddlewareRedirect := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
Expand All @@ -791,6 +792,7 @@ func New(options *Options) *API {
Optional: false,
SessionTokenFunc: nil, // Default behavior
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
Logger: options.Logger,
})
// Same as the first but it's optional.
apiKeyMiddlewareOptional := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
Expand All @@ -801,6 +803,7 @@ func New(options *Options) *API {
Optional: true,
SessionTokenFunc: nil, // Default behavior
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
Logger: options.Logger,
})

workspaceAgentInfo := httpmw.ExtractWorkspaceAgentAndLatestBuild(httpmw.ExtractWorkspaceAgentAndLatestBuildConfig{
Expand Down
19 changes: 19 additions & 0 deletions coderd/database/dbauthz/dbauthz.go
Original file line number Diff line number Diff line change
Expand Up @@ -2165,6 +2165,25 @@ func (q *querier) GetOAuth2ProviderAppSecretsByAppID(ctx context.Context, appID
return q.db.GetOAuth2ProviderAppSecretsByAppID(ctx, appID)
}

func (q *querier) GetOAuth2ProviderAppTokenByAPIKeyID(ctx context.Context, apiKeyID string) (database.OAuth2ProviderAppToken, error) {
token, err := q.db.GetOAuth2ProviderAppTokenByAPIKeyID(ctx, apiKeyID)
if err != nil {
return database.OAuth2ProviderAppToken{}, err
}

// Get the associated API key to check ownership
apiKey, err := q.db.GetAPIKeyByID(ctx, token.APIKeyID)
if err != nil {
return database.OAuth2ProviderAppToken{}, err
}

if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceOauth2AppCodeToken.WithOwner(apiKey.UserID.String())); err != nil {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are oauth2 app tokens organization-scoped?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OAuth2 apps are not organization-scoped. They are site-wide resources.

The RBAC resource ResourceOauth2AppCodeToken uses user ownership (.WithOwner(apiKey.UserID.String())), not organization scoping.

return database.OAuth2ProviderAppToken{}, err
}

return token, nil
}

func (q *querier) GetOAuth2ProviderAppTokenByPrefix(ctx context.Context, hashPrefix []byte) (database.OAuth2ProviderAppToken, error) {
token, err := q.db.GetOAuth2ProviderAppTokenByPrefix(ctx, hashPrefix)
if err != nil {
Expand Down
15 changes: 15 additions & 0 deletions coderd/database/dbauthz/dbauthz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5370,6 +5370,21 @@ func (s *MethodTestSuite) TestOAuth2ProviderAppTokens() {
})
check.Args(token.HashPrefix).Asserts(rbac.ResourceOauth2AppCodeToken.WithOwner(user.ID.String()), policy.ActionRead)
}))
s.Run("GetOAuth2ProviderAppTokenByAPIKeyID", s.Subtest(func(db database.Store, check *expects) {
user := dbgen.User(s.T(), db, database.User{})
key, _ := dbgen.APIKey(s.T(), db, database.APIKey{
UserID: user.ID,
})
app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{})
secret := dbgen.OAuth2ProviderAppSecret(s.T(), db, database.OAuth2ProviderAppSecret{
AppID: app.ID,
})
token := dbgen.OAuth2ProviderAppToken(s.T(), db, database.OAuth2ProviderAppToken{
AppSecretID: secret.ID,
APIKeyID: key.ID,
})
check.Args(token.APIKeyID).Asserts(rbac.ResourceOauth2AppCodeToken.WithOwner(user.ID.String()), policy.ActionRead).Returns(token)
}))
s.Run("DeleteOAuth2ProviderAppTokensByAppAndUserID", s.Subtest(func(db database.Store, check *expects) {
dbtestutil.DisableForeignKeysAndTriggers(s.T(), db)
user := dbgen.User(s.T(), db, database.User{})
Expand Down
13 changes: 13 additions & 0 deletions coderd/database/dbmem/dbmem.go
Original file line number Diff line number Diff line change
Expand Up @@ -4050,6 +4050,19 @@ func (q *FakeQuerier) GetOAuth2ProviderAppSecretsByAppID(_ context.Context, appI
return []database.OAuth2ProviderAppSecret{}, sql.ErrNoRows
}

func (q *FakeQuerier) GetOAuth2ProviderAppTokenByAPIKeyID(_ context.Context, apiKeyID string) (database.OAuth2ProviderAppToken, error) {
q.mutex.Lock()
defer q.mutex.Unlock()

for _, token := range q.oauth2ProviderAppTokens {
if token.APIKeyID == apiKeyID {
return token, nil
}
}

return database.OAuth2ProviderAppToken{}, sql.ErrNoRows
}

func (q *FakeQuerier) GetOAuth2ProviderAppTokenByPrefix(_ context.Context, hashPrefix []byte) (database.OAuth2ProviderAppToken, error) {
q.mutex.Lock()
defer q.mutex.Unlock()
Expand Down
7 changes: 7 additions & 0 deletions coderd/database/dbmetrics/querymetrics.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions coderd/database/dbmock/dbmock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions coderd/database/querier.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 20 additions & 0 deletions coderd/database/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions coderd/database/queries/oauth2.sql
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ INSERT INTO oauth2_provider_app_tokens (
-- name: GetOAuth2ProviderAppTokenByPrefix :one
SELECT * FROM oauth2_provider_app_tokens WHERE hash_prefix = $1;

-- name: GetOAuth2ProviderAppTokenByAPIKeyID :one
SELECT * FROM oauth2_provider_app_tokens WHERE api_key_id = $1;

-- name: GetOAuth2ProviderAppsByUserID :many
SELECT
COUNT(DISTINCT oauth2_provider_app_tokens.id) as token_count,
Expand Down
56 changes: 56 additions & 0 deletions coderd/httpmw/apikey.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"golang.org/x/oauth2"
"golang.org/x/xerrors"

"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime"
Expand Down Expand Up @@ -110,6 +111,9 @@ type ExtractAPIKeyConfig struct {
// This is originally implemented to send entitlement warning headers after
// a user is authenticated to prevent additional CLI invocations.
PostAuthAdditionalHeadersFunc func(a rbac.Subject, header http.Header)

// Logger is used for logging middleware operations.
Logger slog.Logger
}

// ExtractAPIKeyMW calls ExtractAPIKey with the given config on each request,
Expand Down Expand Up @@ -240,6 +244,17 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
})
}

// Validate OAuth2 provider app token audience (RFC 8707) if applicable
if key.LoginType == database.LoginTypeOAuth2ProviderApp {
if err := validateOAuth2ProviderAppTokenAudience(ctx, cfg.DB, *key, r); err != nil {
// Log the detailed error for debugging but don't expose it to the client
cfg.Logger.Info(ctx, "oauth2 token audience validation failed", slog.Error(err))
return optionalWrite(http.StatusForbidden, codersdk.Response{
Message: "Token audience validation failed",
})
}
}

// We only check OIDC stuff if we have a valid APIKey. An expired key means we don't trust the requestor
// really is the user whose key they have, and so we shouldn't be doing anything on their behalf including possibly
// refreshing the OIDC token.
Expand Down Expand Up @@ -446,6 +461,47 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
return key, &actor, true
}

// validateOAuth2ProviderAppTokenAudience validates that an OAuth2 provider app token
// is being used with the correct audience/resource server (RFC 8707).
func validateOAuth2ProviderAppTokenAudience(ctx context.Context, db database.Store, key database.APIKey, r *http.Request) error {
// Get the OAuth2 provider app token to check its audience
//nolint:gocritic // System needs to access token for audience validation
token, err := db.GetOAuth2ProviderAppTokenByAPIKeyID(dbauthz.AsSystemRestricted(ctx), key.ID)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

review: This is a legitimate use of dbauthz.SystemRestricted.

if err != nil {
return xerrors.Errorf("failed to get OAuth2 token: %w", err)
}

// If no audience is set, allow the request (for backward compatibility)
if !token.Audience.Valid || token.Audience.String == "" {
return nil
}

// Extract the expected audience from the request
expectedAudience := extractExpectedAudience(r)

// Validate that the token's audience matches the expected audience
if token.Audience.String != expectedAudience {
return xerrors.Errorf("token audience %q does not match expected audience %q",
token.Audience.String, expectedAudience)
}

return nil
}

// extractExpectedAudience determines the expected audience for the current request.
// This should match the resource parameter used during authorization.
func extractExpectedAudience(r *http.Request) string {
// For MCP compliance, the audience should be the canonical URI of the resource server
// This typically matches the access URL of the Coder deployment
scheme := "https"
if r.TLS == nil {
scheme = "http"
}

// Use the Host header to construct the canonical audience URI
return fmt.Sprintf("%s://%s", scheme, r.Host)
}

// UserRBACSubject fetches a user's rbac.Subject from the database. It pulls all roles from both
// site and organization scopes. It also pulls the groups, and the user's status.
func UserRBACSubject(ctx context.Context, db database.Store, userID uuid.UUID, scope rbac.ExpandableScope) (rbac.Subject, database.UserStatus, error) {
Expand Down
9 changes: 9 additions & 0 deletions coderd/identityprovider/authorize.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ func extractAuthorizeParams(r *http.Request, callbackURL *url.URL) (authorizePar
codeChallenge: p.String(vals, "", "code_challenge"),
codeChallengeMethod: p.String(vals, "", "code_challenge_method"),
}
// Validate resource indicator syntax (RFC 8707): must be absolute URI without fragment
if params.resource != "" {
if u, err := url.Parse(params.resource); err != nil || u.Scheme == "" || u.Fragment != "" {
p.Errors = append(p.Errors, codersdk.ValidationError{
Field: "resource",
Detail: "must be an absolute URI without fragment",
})
}
}

p.ErrorExcessParams(vals)
if len(p.Errors) > 0 {
Expand Down
Loading
Loading