Skip to content

fix: stop extending API key access if OIDC refresh is available #17878

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion coderd/coderdtest/oidctest/idp.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ func WithCustomClientAuth(hook func(t testing.TB, req *http.Request) (url.Values
// WithLogging is optional, but will log some HTTP calls made to the IDP.
func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) {
return func(f *FakeIDP) {
f.logger = slogtest.Make(t, options)
f.logger = slogtest.Make(t, options).Named("fakeidp")
}
}

Expand Down Expand Up @@ -794,6 +794,7 @@ func (f *FakeIDP) newToken(t testing.TB, email string, expires time.Time) string
func (f *FakeIDP) newRefreshTokens(email string) string {
refreshToken := uuid.NewString()
f.refreshTokens.Store(refreshToken, email)
f.logger.Info(context.Background(), "new refresh token", slog.F("email", email), slog.F("token", refreshToken))
return refreshToken
}

Expand Down Expand Up @@ -1003,6 +1004,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
return
}

f.logger.Info(r.Context(), "http idp call refresh_token", slog.F("token", refreshToken))
_, ok := f.refreshTokens.Load(refreshToken)
if !assert.True(t, ok, "invalid refresh_token") {
http.Error(rw, "invalid refresh_token", http.StatusBadRequest)
Expand All @@ -1026,6 +1028,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
f.refreshTokensUsed.Store(refreshToken, true)
// Always invalidate the refresh token after it is used.
f.refreshTokens.Delete(refreshToken)
f.logger.Info(r.Context(), "refresh token invalidated", slog.F("token", refreshToken))
case "urn:ietf:params:oauth:grant-type:device_code":
// Device flow
var resp externalauth.ExchangeDeviceCodeResponse
Expand Down
94 changes: 48 additions & 46 deletions coderd/httpmw/apikey.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,16 +232,21 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
return optionalWrite(http.StatusUnauthorized, resp)
}

var (
link database.UserLink
now = dbtime.Now()
// Tracks if the API key has properties updated
changed = false
)
now := dbtime.Now()
if key.ExpiresAt.Before(now) {
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()),
})
}

// We only check OIDC stuff if we have a valid APIKey. An expired key means we don't trust the requestor
// really is the user whose key they have, and so we shouldn't be doing anything on their behalf including possibly
// refreshing the OIDC token.
if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC {
var err error
//nolint:gocritic // System needs to fetch UserLink to check if it's valid.
link, err = cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
link, err := cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
UserID: key.UserID,
LoginType: key.LoginType,
})
Expand All @@ -258,7 +263,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
})
}
// Check if the OAuth token is expired
if link.OAuthExpiry.Before(now) && !link.OAuthExpiry.IsZero() && link.OAuthRefreshToken != "" {
if !link.OAuthExpiry.IsZero() && link.OAuthExpiry.Before(now) {
if cfg.OAuth2Configs.IsZero() {
return write(http.StatusInternalServerError, codersdk.Response{
Message: internalErrorMessage,
Expand All @@ -267,12 +272,15 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
})
}

var friendlyName string
var oauthConfig promoauth.OAuth2Config
switch key.LoginType {
case database.LoginTypeGithub:
oauthConfig = cfg.OAuth2Configs.Github
friendlyName = "GitHub"
case database.LoginTypeOIDC:
oauthConfig = cfg.OAuth2Configs.OIDC
friendlyName = "OpenID Connect"
default:
return write(http.StatusInternalServerError, codersdk.Response{
Message: internalErrorMessage,
Expand All @@ -292,36 +300,53 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
})
}

// If it is, let's refresh it from the provided config
if link.OAuthRefreshToken == "" {
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: fmt.Sprintf("%s session expired at %q. Try signing in again.", friendlyName, link.OAuthExpiry.String()),
})
}
// We have a refresh token, so let's try it
token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{
AccessToken: link.OAuthAccessToken,
RefreshToken: link.OAuthRefreshToken,
Expiry: link.OAuthExpiry,
}).Token()
if err != nil {
return write(http.StatusUnauthorized, codersdk.Response{
Message: "Could not refresh expired Oauth token. Try re-authenticating to resolve this issue.",
Detail: err.Error(),
Message: fmt.Sprintf(
"Could not refresh expired %s token. Try re-authenticating to resolve this issue.",
friendlyName),
Detail: err.Error(),
})
}
link.OAuthAccessToken = token.AccessToken
link.OAuthRefreshToken = token.RefreshToken
link.OAuthExpiry = token.Expiry
key.ExpiresAt = token.Expiry
changed = true
//nolint:gocritic // system needs to update user link
link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
UserID: link.UserID,
LoginType: link.LoginType,
OAuthAccessToken: link.OAuthAccessToken,
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthRefreshToken: link.OAuthRefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthExpiry: link.OAuthExpiry,
// Refresh should keep the same debug context because we use
// the original claims for the group/role sync.
Claims: link.Claims,
})
if err != nil {
return write(http.StatusInternalServerError, codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("update user_link: %s.", err.Error()),
})
}
}
}

// Checking if the key is expired.
// NOTE: The `RequireAuth` React component depends on this `Detail` to detect when
// the users token has expired. If you change the text here, make sure to update it
// in site/src/components/RequireAuth/RequireAuth.tsx as well.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that RequireAuth.tsx was modified to not have this string match dependency in #9442

if key.ExpiresAt.Before(now) {
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()),
})
}
// Tracks if the API key has properties updated
changed := false

// Only update LastUsed once an hour to prevent database spam.
if now.Sub(key.LastUsed) > time.Hour {
Expand Down Expand Up @@ -363,29 +388,6 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
Detail: fmt.Sprintf("API key couldn't update: %s.", err.Error()),
})
}
// If the API Key is associated with a user_link (e.g. Github/OIDC)
// then we want to update the relevant oauth fields.
if link.UserID != uuid.Nil {
//nolint:gocritic // system needs to update user link
link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
UserID: link.UserID,
LoginType: link.LoginType,
OAuthAccessToken: link.OAuthAccessToken,
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthRefreshToken: link.OAuthRefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthExpiry: link.OAuthExpiry,
// Refresh should keep the same debug context because we use
// the original claims for the group/role sync.
Claims: link.Claims,
})
if err != nil {
return write(http.StatusInternalServerError, codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("update user_link: %s.", err.Error()),
})
}
}

// We only want to update this occasionally to reduce DB write
// load. We update alongside the UserLink and APIKey since it's
Expand Down
158 changes: 157 additions & 1 deletion coderd/httpmw/apikey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,102 @@ func TestAPIKey(t *testing.T) {
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
})

t.Run("APIKeyExpiredOAuthExpired", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
user = dbgen.User(t, db, database.User{})
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
UserID: user.ID,
LastUsed: dbtime.Now().AddDate(0, 0, -1),
ExpiresAt: dbtime.Now().AddDate(0, 0, -1),
LoginType: database.LoginTypeOIDC,
})
_ = dbgen.UserLink(t, db, database.UserLink{
UserID: user.ID,
LoginType: database.LoginTypeOIDC,
OAuthExpiry: dbtime.Now().AddDate(0, 0, -1),
})

r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.Header.Set(codersdk.SessionTokenHeader, token)

// Include a valid oauth token for refreshing. If this token is invalid,
// it is difficult to tell an auth failure from an expired api key, or
// an expired oauth key.
oauthToken := &oauth2.Token{
AccessToken: "wow",
RefreshToken: "moo",
Expiry: dbtime.Now().AddDate(0, 0, 1),
}
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
DB: db,
OAuth2Configs: &httpmw.OAuth2Configs{
OIDC: &testutil.OAuth2Config{
Token: oauthToken,
},
},
RedirectToLogin: false,
})(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)

gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID)
require.NoError(t, err)

require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
})

t.Run("APIKeyExpiredOAuthNotExpired", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
user = dbgen.User(t, db, database.User{})
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
UserID: user.ID,
LastUsed: dbtime.Now().AddDate(0, 0, -1),
ExpiresAt: dbtime.Now().AddDate(0, 0, -1),
LoginType: database.LoginTypeOIDC,
})
_ = dbgen.UserLink(t, db, database.UserLink{
UserID: user.ID,
LoginType: database.LoginTypeOIDC,
})

r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.Header.Set(codersdk.SessionTokenHeader, token)

oauthToken := &oauth2.Token{
AccessToken: "wow",
RefreshToken: "moo",
Expiry: dbtime.Now().AddDate(0, 0, 1),
}
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
DB: db,
OAuth2Configs: &httpmw.OAuth2Configs{
OIDC: &testutil.OAuth2Config{
Token: oauthToken,
},
},
RedirectToLogin: false,
})(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)

gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID)
require.NoError(t, err)

require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
})

t.Run("OAuthRefresh", func(t *testing.T) {
t.Parallel()
var (
Expand Down Expand Up @@ -553,7 +649,67 @@ func TestAPIKey(t *testing.T) {
require.NoError(t, err)

require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
require.Equal(t, oauthToken.Expiry, gotAPIKey.ExpiresAt)
// Note that OAuth expiry is independent of APIKey expiry, so an OIDC refresh DOES NOT affect the expiry of the
// APIKey
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)

gotLink, err := db.GetUserLinkByUserIDLoginType(r.Context(), database.GetUserLinkByUserIDLoginTypeParams{
UserID: user.ID,
LoginType: database.LoginTypeGithub,
})
require.NoError(t, err)
require.Equal(t, gotLink.OAuthRefreshToken, "moo")
})

t.Run("OAuthExpiredNoRefresh", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
db = dbmem.New()
user = dbgen.User(t, db, database.User{})
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
UserID: user.ID,
LastUsed: dbtime.Now(),
ExpiresAt: dbtime.Now().AddDate(0, 0, 1),
LoginType: database.LoginTypeGithub,
})

r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
_, err := db.InsertUserLink(ctx, database.InsertUserLinkParams{
UserID: user.ID,
LoginType: database.LoginTypeGithub,
OAuthExpiry: dbtime.Now().AddDate(0, 0, -1),
OAuthAccessToken: "letmein",
})
require.NoError(t, err)

r.Header.Set(codersdk.SessionTokenHeader, token)

oauthToken := &oauth2.Token{
AccessToken: "wow",
RefreshToken: "moo",
Expiry: dbtime.Now().AddDate(0, 0, 1),
}
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
DB: db,
OAuth2Configs: &httpmw.OAuth2Configs{
Github: &testutil.OAuth2Config{
Token: oauthToken,
},
},
RedirectToLogin: false,
})(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)

gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID)
require.NoError(t, err)

require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
})

t.Run("RemoteIPUpdates", func(t *testing.T) {
Expand Down
1 change: 1 addition & 0 deletions coderd/oauthpki/okidcpki_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ func TestAzureAKPKIWithCoderd(t *testing.T) {
return values, nil
}),
oidctest.WithServing(),
oidctest.WithLogging(t, nil),
)
cfg := fake.OIDCConfig(t, scopes, func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true
Expand Down
Loading