Skip to content

Commit 473d5a4

Browse files
committed
fix: stop extending API key access if OIDC refresh is available
1 parent 90e93a2 commit 473d5a4

File tree

4 files changed

+204
-42
lines changed

4 files changed

+204
-42
lines changed

coderd/coderdtest/oidctest/idp.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ func WithCustomClientAuth(hook func(t testing.TB, req *http.Request) (url.Values
307307
// WithLogging is optional, but will log some HTTP calls made to the IDP.
308308
func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) {
309309
return func(f *FakeIDP) {
310-
f.logger = slogtest.Make(t, options)
310+
f.logger = slogtest.Make(t, options).Named("fakeidp")
311311
}
312312
}
313313

@@ -794,6 +794,7 @@ func (f *FakeIDP) newToken(t testing.TB, email string, expires time.Time) string
794794
func (f *FakeIDP) newRefreshTokens(email string) string {
795795
refreshToken := uuid.NewString()
796796
f.refreshTokens.Store(refreshToken, email)
797+
f.logger.Info(context.Background(), "new refresh token", slog.F("email", email), slog.F("token", refreshToken))
797798
return refreshToken
798799
}
799800

@@ -1003,6 +1004,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
10031004
return
10041005
}
10051006

1007+
f.logger.Info(r.Context(), "http idp call refresh_token", slog.F("token", refreshToken))
10061008
_, ok := f.refreshTokens.Load(refreshToken)
10071009
if !assert.True(t, ok, "invalid refresh_token") {
10081010
http.Error(rw, "invalid refresh_token", http.StatusBadRequest)
@@ -1026,6 +1028,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
10261028
f.refreshTokensUsed.Store(refreshToken, true)
10271029
// Always invalidate the refresh token after it is used.
10281030
f.refreshTokens.Delete(refreshToken)
1031+
f.logger.Info(r.Context(), "refresh token invalidated", slog.F("token", refreshToken))
10291032
case "urn:ietf:params:oauth:grant-type:device_code":
10301033
// Device flow
10311034
var resp externalauth.ExchangeDeviceCodeResponse

coderd/httpmw/apikey.go

Lines changed: 45 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,17 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
238238
// Tracks if the API key has properties updated
239239
changed = false
240240
)
241+
242+
if key.ExpiresAt.Before(now) {
243+
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
244+
Message: SignedOutErrorMessage,
245+
Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()),
246+
})
247+
}
248+
249+
// We only check OIDC stuff if we have a valid APIKey. An expired key means we don't trust the requestor
250+
// really is the user whose key they have, and so we shouldn't be doing anything on their behalf including possibly
251+
// refreshing the OIDC token.
241252
if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC {
242253
var err error
243254
//nolint:gocritic // System needs to fetch UserLink to check if it's valid.
@@ -258,7 +269,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
258269
})
259270
}
260271
// Check if the OAuth token is expired
261-
if link.OAuthExpiry.Before(now) && !link.OAuthExpiry.IsZero() && link.OAuthRefreshToken != "" {
272+
if !link.OAuthExpiry.IsZero() && link.OAuthExpiry.Before(now) {
262273
if cfg.OAuth2Configs.IsZero() {
263274
return write(http.StatusInternalServerError, codersdk.Response{
264275
Message: internalErrorMessage,
@@ -267,12 +278,15 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
267278
})
268279
}
269280

281+
var friendlyName string
270282
var oauthConfig promoauth.OAuth2Config
271283
switch key.LoginType {
272284
case database.LoginTypeGithub:
273285
oauthConfig = cfg.OAuth2Configs.Github
286+
friendlyName = "GitHub"
274287
case database.LoginTypeOIDC:
275288
oauthConfig = cfg.OAuth2Configs.OIDC
289+
friendlyName = "OpenID Connect"
276290
default:
277291
return write(http.StatusInternalServerError, codersdk.Response{
278292
Message: internalErrorMessage,
@@ -292,37 +306,51 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
292306
})
293307
}
294308

295-
// If it is, let's refresh it from the provided config
309+
if link.OAuthRefreshToken == "" {
310+
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
311+
Message: SignedOutErrorMessage,
312+
Detail: fmt.Sprintf("%s session expired at %q.", friendlyName, link.OAuthExpiry.String()),
313+
})
314+
}
315+
// We have a refresh token, so let's try it
296316
token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{
297317
AccessToken: link.OAuthAccessToken,
298318
RefreshToken: link.OAuthRefreshToken,
299319
Expiry: link.OAuthExpiry,
300320
}).Token()
301321
if err != nil {
302322
return write(http.StatusUnauthorized, codersdk.Response{
303-
Message: "Could not refresh expired Oauth token. Try re-authenticating to resolve this issue.",
304-
Detail: err.Error(),
323+
Message: fmt.Sprintf(
324+
"Could not refresh expired %s token. Try re-authenticating to resolve this issue.",
325+
friendlyName),
326+
Detail: err.Error(),
305327
})
306328
}
307329
link.OAuthAccessToken = token.AccessToken
308330
link.OAuthRefreshToken = token.RefreshToken
309331
link.OAuthExpiry = token.Expiry
310-
key.ExpiresAt = token.Expiry
311-
changed = true
332+
//nolint:gocritic // system needs to update user link
333+
link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
334+
UserID: link.UserID,
335+
LoginType: link.LoginType,
336+
OAuthAccessToken: link.OAuthAccessToken,
337+
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
338+
OAuthRefreshToken: link.OAuthRefreshToken,
339+
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
340+
OAuthExpiry: link.OAuthExpiry,
341+
// Refresh should keep the same debug context because we use
342+
// the original claims for the group/role sync.
343+
Claims: link.Claims,
344+
})
345+
if err != nil {
346+
return write(http.StatusInternalServerError, codersdk.Response{
347+
Message: internalErrorMessage,
348+
Detail: fmt.Sprintf("update user_link: %s.", err.Error()),
349+
})
350+
}
312351
}
313352
}
314353

315-
// Checking if the key is expired.
316-
// NOTE: The `RequireAuth` React component depends on this `Detail` to detect when
317-
// the users token has expired. If you change the text here, make sure to update it
318-
// in site/src/components/RequireAuth/RequireAuth.tsx as well.
319-
if key.ExpiresAt.Before(now) {
320-
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
321-
Message: SignedOutErrorMessage,
322-
Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()),
323-
})
324-
}
325-
326354
// Only update LastUsed once an hour to prevent database spam.
327355
if now.Sub(key.LastUsed) > time.Hour {
328356
key.LastUsed = now
@@ -363,29 +391,6 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
363391
Detail: fmt.Sprintf("API key couldn't update: %s.", err.Error()),
364392
})
365393
}
366-
// If the API Key is associated with a user_link (e.g. Github/OIDC)
367-
// then we want to update the relevant oauth fields.
368-
if link.UserID != uuid.Nil {
369-
//nolint:gocritic // system needs to update user link
370-
link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
371-
UserID: link.UserID,
372-
LoginType: link.LoginType,
373-
OAuthAccessToken: link.OAuthAccessToken,
374-
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
375-
OAuthRefreshToken: link.OAuthRefreshToken,
376-
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
377-
OAuthExpiry: link.OAuthExpiry,
378-
// Refresh should keep the same debug context because we use
379-
// the original claims for the group/role sync.
380-
Claims: link.Claims,
381-
})
382-
if err != nil {
383-
return write(http.StatusInternalServerError, codersdk.Response{
384-
Message: internalErrorMessage,
385-
Detail: fmt.Sprintf("update user_link: %s.", err.Error()),
386-
})
387-
}
388-
}
389394

390395
// We only want to update this occasionally to reduce DB write
391396
// load. We update alongside the UserLink and APIKey since it's

coderd/httpmw/apikey_test.go

Lines changed: 154 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,99 @@ func TestAPIKey(t *testing.T) {
508508
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
509509
})
510510

511+
t.Run("APIKeyExpiredOAuthExpired", func(t *testing.T) {
512+
t.Parallel()
513+
var (
514+
db = dbmem.New()
515+
user = dbgen.User(t, db, database.User{})
516+
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
517+
UserID: user.ID,
518+
LastUsed: dbtime.Now().AddDate(0, 0, -1),
519+
ExpiresAt: dbtime.Now().AddDate(0, 0, -1),
520+
LoginType: database.LoginTypeOIDC,
521+
})
522+
_ = dbgen.UserLink(t, db, database.UserLink{
523+
UserID: user.ID,
524+
LoginType: database.LoginTypeOIDC,
525+
OAuthExpiry: dbtime.Now().AddDate(0, 0, -1),
526+
})
527+
528+
r = httptest.NewRequest("GET", "/", nil)
529+
rw = httptest.NewRecorder()
530+
)
531+
r.Header.Set(codersdk.SessionTokenHeader, token)
532+
533+
oauthToken := &oauth2.Token{
534+
AccessToken: "wow",
535+
RefreshToken: "moo",
536+
Expiry: dbtime.Now().AddDate(0, 0, 1),
537+
}
538+
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
539+
DB: db,
540+
OAuth2Configs: &httpmw.OAuth2Configs{
541+
OIDC: &testutil.OAuth2Config{
542+
Token: oauthToken,
543+
},
544+
},
545+
RedirectToLogin: false,
546+
})(successHandler).ServeHTTP(rw, r)
547+
res := rw.Result()
548+
defer res.Body.Close()
549+
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
550+
551+
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID)
552+
require.NoError(t, err)
553+
554+
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
555+
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
556+
})
557+
558+
t.Run("APIKeyExpiredOAuthNotExpired", func(t *testing.T) {
559+
t.Parallel()
560+
var (
561+
db = dbmem.New()
562+
user = dbgen.User(t, db, database.User{})
563+
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
564+
UserID: user.ID,
565+
LastUsed: dbtime.Now().AddDate(0, 0, -1),
566+
ExpiresAt: dbtime.Now().AddDate(0, 0, -1),
567+
LoginType: database.LoginTypeOIDC,
568+
})
569+
_ = dbgen.UserLink(t, db, database.UserLink{
570+
UserID: user.ID,
571+
LoginType: database.LoginTypeOIDC,
572+
})
573+
574+
r = httptest.NewRequest("GET", "/", nil)
575+
rw = httptest.NewRecorder()
576+
)
577+
r.Header.Set(codersdk.SessionTokenHeader, token)
578+
579+
oauthToken := &oauth2.Token{
580+
AccessToken: "wow",
581+
RefreshToken: "moo",
582+
Expiry: dbtime.Now().AddDate(0, 0, 1),
583+
}
584+
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
585+
DB: db,
586+
OAuth2Configs: &httpmw.OAuth2Configs{
587+
OIDC: &testutil.OAuth2Config{
588+
Token: oauthToken,
589+
},
590+
},
591+
RedirectToLogin: false,
592+
})(successHandler).ServeHTTP(rw, r)
593+
res := rw.Result()
594+
defer res.Body.Close()
595+
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
596+
597+
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID)
598+
require.NoError(t, err)
599+
600+
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
601+
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
602+
})
603+
511604
t.Run("OAuthRefresh", func(t *testing.T) {
512605
t.Parallel()
513606
var (
@@ -553,7 +646,67 @@ func TestAPIKey(t *testing.T) {
553646
require.NoError(t, err)
554647

555648
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
556-
require.Equal(t, oauthToken.Expiry, gotAPIKey.ExpiresAt)
649+
// Note that OAuth expiry is independent of APIKey expiry, so an OIDC refresh DOES NOT affect the expiry of the
650+
// APIKey
651+
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
652+
653+
gotLink, err := db.GetUserLinkByUserIDLoginType(r.Context(), database.GetUserLinkByUserIDLoginTypeParams{
654+
UserID: user.ID,
655+
LoginType: database.LoginTypeGithub,
656+
})
657+
require.NoError(t, err)
658+
require.Equal(t, gotLink.OAuthRefreshToken, "moo")
659+
})
660+
661+
t.Run("OAuthExpiredNoRefresh", func(t *testing.T) {
662+
t.Parallel()
663+
var (
664+
ctx = testutil.Context(t, testutil.WaitShort)
665+
db = dbmem.New()
666+
user = dbgen.User(t, db, database.User{})
667+
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
668+
UserID: user.ID,
669+
LastUsed: dbtime.Now(),
670+
ExpiresAt: dbtime.Now().AddDate(0, 0, 1),
671+
LoginType: database.LoginTypeGithub,
672+
})
673+
674+
r = httptest.NewRequest("GET", "/", nil)
675+
rw = httptest.NewRecorder()
676+
)
677+
_, err := db.InsertUserLink(ctx, database.InsertUserLinkParams{
678+
UserID: user.ID,
679+
LoginType: database.LoginTypeGithub,
680+
OAuthExpiry: dbtime.Now().AddDate(0, 0, -1),
681+
OAuthAccessToken: "letmein",
682+
})
683+
require.NoError(t, err)
684+
685+
r.Header.Set(codersdk.SessionTokenHeader, token)
686+
687+
oauthToken := &oauth2.Token{
688+
AccessToken: "wow",
689+
RefreshToken: "moo",
690+
Expiry: dbtime.Now().AddDate(0, 0, 1),
691+
}
692+
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
693+
DB: db,
694+
OAuth2Configs: &httpmw.OAuth2Configs{
695+
Github: &testutil.OAuth2Config{
696+
Token: oauthToken,
697+
},
698+
},
699+
RedirectToLogin: false,
700+
})(successHandler).ServeHTTP(rw, r)
701+
res := rw.Result()
702+
defer res.Body.Close()
703+
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
704+
705+
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID)
706+
require.NoError(t, err)
707+
708+
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
709+
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
557710
})
558711

559712
t.Run("RemoteIPUpdates", func(t *testing.T) {

coderd/oauthpki/okidcpki_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ func TestAzureAKPKIWithCoderd(t *testing.T) {
144144
return values, nil
145145
}),
146146
oidctest.WithServing(),
147+
oidctest.WithLogging(t, nil),
147148
)
148149
cfg := fake.OIDCConfig(t, scopes, func(cfg *coderd.OIDCConfig) {
149150
cfg.AllowSignups = true

0 commit comments

Comments
 (0)