Skip to content

Commit 1a41608

Browse files
authored
fix: stop extending API key access if OIDC refresh is available (#17878)
fixes #17070 Cleans up our handling of APIKey expiration and OIDC to keep them separate concepts. For an OIDC-login APIKey, both the APIKey and OIDC link must be valid to login. If the OIDC link is expired and we have a refresh token, we will attempt to refresh. OIDC refreshes do not have any effect on APIKey expiry. #17070 (comment) explains why this is the correct behavior.
1 parent ca5a78a commit 1a41608

File tree

4 files changed

+210
-48
lines changed

4 files changed

+210
-48
lines changed

coderd/coderdtest/oidctest/idp.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ func WithCustomClientAuth(hook func(t testing.TB, req *http.Request) (url.Values
307307
// WithLogging is optional, but will log some HTTP calls made to the IDP.
308308
func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) {
309309
return func(f *FakeIDP) {
310-
f.logger = slogtest.Make(t, options)
310+
f.logger = slogtest.Make(t, options).Named("fakeidp")
311311
}
312312
}
313313

@@ -794,6 +794,7 @@ func (f *FakeIDP) newToken(t testing.TB, email string, expires time.Time) string
794794
func (f *FakeIDP) newRefreshTokens(email string) string {
795795
refreshToken := uuid.NewString()
796796
f.refreshTokens.Store(refreshToken, email)
797+
f.logger.Info(context.Background(), "new refresh token", slog.F("email", email), slog.F("token", refreshToken))
797798
return refreshToken
798799
}
799800

@@ -1003,6 +1004,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
10031004
return
10041005
}
10051006

1007+
f.logger.Info(r.Context(), "http idp call refresh_token", slog.F("token", refreshToken))
10061008
_, ok := f.refreshTokens.Load(refreshToken)
10071009
if !assert.True(t, ok, "invalid refresh_token") {
10081010
http.Error(rw, "invalid refresh_token", http.StatusBadRequest)
@@ -1026,6 +1028,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
10261028
f.refreshTokensUsed.Store(refreshToken, true)
10271029
// Always invalidate the refresh token after it is used.
10281030
f.refreshTokens.Delete(refreshToken)
1031+
f.logger.Info(r.Context(), "refresh token invalidated", slog.F("token", refreshToken))
10291032
case "urn:ietf:params:oauth:grant-type:device_code":
10301033
// Device flow
10311034
var resp externalauth.ExchangeDeviceCodeResponse

coderd/httpmw/apikey.go

Lines changed: 48 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -232,16 +232,21 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
232232
return optionalWrite(http.StatusUnauthorized, resp)
233233
}
234234

235-
var (
236-
link database.UserLink
237-
now = dbtime.Now()
238-
// Tracks if the API key has properties updated
239-
changed = false
240-
)
235+
now := dbtime.Now()
236+
if key.ExpiresAt.Before(now) {
237+
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
238+
Message: SignedOutErrorMessage,
239+
Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()),
240+
})
241+
}
242+
243+
// We only check OIDC stuff if we have a valid APIKey. An expired key means we don't trust the requestor
244+
// really is the user whose key they have, and so we shouldn't be doing anything on their behalf including possibly
245+
// refreshing the OIDC token.
241246
if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC {
242247
var err error
243248
//nolint:gocritic // System needs to fetch UserLink to check if it's valid.
244-
link, err = cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
249+
link, err := cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
245250
UserID: key.UserID,
246251
LoginType: key.LoginType,
247252
})
@@ -258,7 +263,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
258263
})
259264
}
260265
// Check if the OAuth token is expired
261-
if link.OAuthExpiry.Before(now) && !link.OAuthExpiry.IsZero() && link.OAuthRefreshToken != "" {
266+
if !link.OAuthExpiry.IsZero() && link.OAuthExpiry.Before(now) {
262267
if cfg.OAuth2Configs.IsZero() {
263268
return write(http.StatusInternalServerError, codersdk.Response{
264269
Message: internalErrorMessage,
@@ -267,12 +272,15 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
267272
})
268273
}
269274

275+
var friendlyName string
270276
var oauthConfig promoauth.OAuth2Config
271277
switch key.LoginType {
272278
case database.LoginTypeGithub:
273279
oauthConfig = cfg.OAuth2Configs.Github
280+
friendlyName = "GitHub"
274281
case database.LoginTypeOIDC:
275282
oauthConfig = cfg.OAuth2Configs.OIDC
283+
friendlyName = "OpenID Connect"
276284
default:
277285
return write(http.StatusInternalServerError, codersdk.Response{
278286
Message: internalErrorMessage,
@@ -292,36 +300,53 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
292300
})
293301
}
294302

295-
// If it is, let's refresh it from the provided config
303+
if link.OAuthRefreshToken == "" {
304+
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
305+
Message: SignedOutErrorMessage,
306+
Detail: fmt.Sprintf("%s session expired at %q. Try signing in again.", friendlyName, link.OAuthExpiry.String()),
307+
})
308+
}
309+
// We have a refresh token, so let's try it
296310
token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{
297311
AccessToken: link.OAuthAccessToken,
298312
RefreshToken: link.OAuthRefreshToken,
299313
Expiry: link.OAuthExpiry,
300314
}).Token()
301315
if err != nil {
302316
return write(http.StatusUnauthorized, codersdk.Response{
303-
Message: "Could not refresh expired Oauth token. Try re-authenticating to resolve this issue.",
304-
Detail: err.Error(),
317+
Message: fmt.Sprintf(
318+
"Could not refresh expired %s token. Try re-authenticating to resolve this issue.",
319+
friendlyName),
320+
Detail: err.Error(),
305321
})
306322
}
307323
link.OAuthAccessToken = token.AccessToken
308324
link.OAuthRefreshToken = token.RefreshToken
309325
link.OAuthExpiry = token.Expiry
310-
key.ExpiresAt = token.Expiry
311-
changed = true
326+
//nolint:gocritic // system needs to update user link
327+
link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
328+
UserID: link.UserID,
329+
LoginType: link.LoginType,
330+
OAuthAccessToken: link.OAuthAccessToken,
331+
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
332+
OAuthRefreshToken: link.OAuthRefreshToken,
333+
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
334+
OAuthExpiry: link.OAuthExpiry,
335+
// Refresh should keep the same debug context because we use
336+
// the original claims for the group/role sync.
337+
Claims: link.Claims,
338+
})
339+
if err != nil {
340+
return write(http.StatusInternalServerError, codersdk.Response{
341+
Message: internalErrorMessage,
342+
Detail: fmt.Sprintf("update user_link: %s.", err.Error()),
343+
})
344+
}
312345
}
313346
}
314347

315-
// Checking if the key is expired.
316-
// NOTE: The `RequireAuth` React component depends on this `Detail` to detect when
317-
// the users token has expired. If you change the text here, make sure to update it
318-
// in site/src/components/RequireAuth/RequireAuth.tsx as well.
319-
if key.ExpiresAt.Before(now) {
320-
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
321-
Message: SignedOutErrorMessage,
322-
Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()),
323-
})
324-
}
348+
// Tracks if the API key has properties updated
349+
changed := false
325350

326351
// Only update LastUsed once an hour to prevent database spam.
327352
if now.Sub(key.LastUsed) > time.Hour {
@@ -363,29 +388,6 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
363388
Detail: fmt.Sprintf("API key couldn't update: %s.", err.Error()),
364389
})
365390
}
366-
// If the API Key is associated with a user_link (e.g. Github/OIDC)
367-
// then we want to update the relevant oauth fields.
368-
if link.UserID != uuid.Nil {
369-
//nolint:gocritic // system needs to update user link
370-
link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
371-
UserID: link.UserID,
372-
LoginType: link.LoginType,
373-
OAuthAccessToken: link.OAuthAccessToken,
374-
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
375-
OAuthRefreshToken: link.OAuthRefreshToken,
376-
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
377-
OAuthExpiry: link.OAuthExpiry,
378-
// Refresh should keep the same debug context because we use
379-
// the original claims for the group/role sync.
380-
Claims: link.Claims,
381-
})
382-
if err != nil {
383-
return write(http.StatusInternalServerError, codersdk.Response{
384-
Message: internalErrorMessage,
385-
Detail: fmt.Sprintf("update user_link: %s.", err.Error()),
386-
})
387-
}
388-
}
389391

390392
// We only want to update this occasionally to reduce DB write
391393
// load. We update alongside the UserLink and APIKey since it's

coderd/httpmw/apikey_test.go

Lines changed: 157 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,102 @@ func TestAPIKey(t *testing.T) {
508508
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
509509
})
510510

511+
t.Run("APIKeyExpiredOAuthExpired", func(t *testing.T) {
512+
t.Parallel()
513+
var (
514+
db = dbmem.New()
515+
user = dbgen.User(t, db, database.User{})
516+
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
517+
UserID: user.ID,
518+
LastUsed: dbtime.Now().AddDate(0, 0, -1),
519+
ExpiresAt: dbtime.Now().AddDate(0, 0, -1),
520+
LoginType: database.LoginTypeOIDC,
521+
})
522+
_ = dbgen.UserLink(t, db, database.UserLink{
523+
UserID: user.ID,
524+
LoginType: database.LoginTypeOIDC,
525+
OAuthExpiry: dbtime.Now().AddDate(0, 0, -1),
526+
})
527+
528+
r = httptest.NewRequest("GET", "/", nil)
529+
rw = httptest.NewRecorder()
530+
)
531+
r.Header.Set(codersdk.SessionTokenHeader, token)
532+
533+
// Include a valid oauth token for refreshing. If this token is invalid,
534+
// it is difficult to tell an auth failure from an expired api key, or
535+
// an expired oauth key.
536+
oauthToken := &oauth2.Token{
537+
AccessToken: "wow",
538+
RefreshToken: "moo",
539+
Expiry: dbtime.Now().AddDate(0, 0, 1),
540+
}
541+
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
542+
DB: db,
543+
OAuth2Configs: &httpmw.OAuth2Configs{
544+
OIDC: &testutil.OAuth2Config{
545+
Token: oauthToken,
546+
},
547+
},
548+
RedirectToLogin: false,
549+
})(successHandler).ServeHTTP(rw, r)
550+
res := rw.Result()
551+
defer res.Body.Close()
552+
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
553+
554+
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID)
555+
require.NoError(t, err)
556+
557+
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
558+
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
559+
})
560+
561+
t.Run("APIKeyExpiredOAuthNotExpired", func(t *testing.T) {
562+
t.Parallel()
563+
var (
564+
db = dbmem.New()
565+
user = dbgen.User(t, db, database.User{})
566+
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
567+
UserID: user.ID,
568+
LastUsed: dbtime.Now().AddDate(0, 0, -1),
569+
ExpiresAt: dbtime.Now().AddDate(0, 0, -1),
570+
LoginType: database.LoginTypeOIDC,
571+
})
572+
_ = dbgen.UserLink(t, db, database.UserLink{
573+
UserID: user.ID,
574+
LoginType: database.LoginTypeOIDC,
575+
})
576+
577+
r = httptest.NewRequest("GET", "/", nil)
578+
rw = httptest.NewRecorder()
579+
)
580+
r.Header.Set(codersdk.SessionTokenHeader, token)
581+
582+
oauthToken := &oauth2.Token{
583+
AccessToken: "wow",
584+
RefreshToken: "moo",
585+
Expiry: dbtime.Now().AddDate(0, 0, 1),
586+
}
587+
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
588+
DB: db,
589+
OAuth2Configs: &httpmw.OAuth2Configs{
590+
OIDC: &testutil.OAuth2Config{
591+
Token: oauthToken,
592+
},
593+
},
594+
RedirectToLogin: false,
595+
})(successHandler).ServeHTTP(rw, r)
596+
res := rw.Result()
597+
defer res.Body.Close()
598+
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
599+
600+
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID)
601+
require.NoError(t, err)
602+
603+
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
604+
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
605+
})
606+
511607
t.Run("OAuthRefresh", func(t *testing.T) {
512608
t.Parallel()
513609
var (
@@ -553,7 +649,67 @@ func TestAPIKey(t *testing.T) {
553649
require.NoError(t, err)
554650

555651
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
556-
require.Equal(t, oauthToken.Expiry, gotAPIKey.ExpiresAt)
652+
// Note that OAuth expiry is independent of APIKey expiry, so an OIDC refresh DOES NOT affect the expiry of the
653+
// APIKey
654+
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
655+
656+
gotLink, err := db.GetUserLinkByUserIDLoginType(r.Context(), database.GetUserLinkByUserIDLoginTypeParams{
657+
UserID: user.ID,
658+
LoginType: database.LoginTypeGithub,
659+
})
660+
require.NoError(t, err)
661+
require.Equal(t, gotLink.OAuthRefreshToken, "moo")
662+
})
663+
664+
t.Run("OAuthExpiredNoRefresh", func(t *testing.T) {
665+
t.Parallel()
666+
var (
667+
ctx = testutil.Context(t, testutil.WaitShort)
668+
db = dbmem.New()
669+
user = dbgen.User(t, db, database.User{})
670+
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
671+
UserID: user.ID,
672+
LastUsed: dbtime.Now(),
673+
ExpiresAt: dbtime.Now().AddDate(0, 0, 1),
674+
LoginType: database.LoginTypeGithub,
675+
})
676+
677+
r = httptest.NewRequest("GET", "/", nil)
678+
rw = httptest.NewRecorder()
679+
)
680+
_, err := db.InsertUserLink(ctx, database.InsertUserLinkParams{
681+
UserID: user.ID,
682+
LoginType: database.LoginTypeGithub,
683+
OAuthExpiry: dbtime.Now().AddDate(0, 0, -1),
684+
OAuthAccessToken: "letmein",
685+
})
686+
require.NoError(t, err)
687+
688+
r.Header.Set(codersdk.SessionTokenHeader, token)
689+
690+
oauthToken := &oauth2.Token{
691+
AccessToken: "wow",
692+
RefreshToken: "moo",
693+
Expiry: dbtime.Now().AddDate(0, 0, 1),
694+
}
695+
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
696+
DB: db,
697+
OAuth2Configs: &httpmw.OAuth2Configs{
698+
Github: &testutil.OAuth2Config{
699+
Token: oauthToken,
700+
},
701+
},
702+
RedirectToLogin: false,
703+
})(successHandler).ServeHTTP(rw, r)
704+
res := rw.Result()
705+
defer res.Body.Close()
706+
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
707+
708+
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID)
709+
require.NoError(t, err)
710+
711+
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
712+
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
557713
})
558714

559715
t.Run("RemoteIPUpdates", func(t *testing.T) {

coderd/oauthpki/okidcpki_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ func TestAzureAKPKIWithCoderd(t *testing.T) {
144144
return values, nil
145145
}),
146146
oidctest.WithServing(),
147+
oidctest.WithLogging(t, nil),
147148
)
148149
cfg := fake.OIDCConfig(t, scopes, func(cfg *coderd.OIDCConfig) {
149150
cfg.AllowSignups = true

0 commit comments

Comments
 (0)