diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index 67186a4fd7ddf..bc7f11318a8d5 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -215,7 +215,7 @@ func WithCustomClientAuth(hook func(t testing.TB, req *http.Request) (url.Values // WithLogging is optional, but will log some HTTP calls made to the IDP. func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) { return func(f *FakeIDP) { - f.logger = slogtest.Make(t, options) + f.logger = slogtest.Make(t, options).Named("fakeidp") } } @@ -700,6 +700,7 @@ func (f *FakeIDP) newToken(t testing.TB, email string, expires time.Time) string func (f *FakeIDP) newRefreshTokens(email string) string { refreshToken := uuid.NewString() f.refreshTokens.Store(refreshToken, email) + f.logger.Info(context.Background(), "new refresh token", slog.F("email", email), slog.F("token", refreshToken)) return refreshToken } @@ -909,6 +910,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { return } + f.logger.Info(r.Context(), "http idp call refresh_token", slog.F("token", refreshToken)) _, ok := f.refreshTokens.Load(refreshToken) if !assert.True(t, ok, "invalid refresh_token") { http.Error(rw, "invalid refresh_token", http.StatusBadRequest) @@ -932,6 +934,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { f.refreshTokensUsed.Store(refreshToken, true) // Always invalidate the refresh token after it is used. f.refreshTokens.Delete(refreshToken) + f.logger.Info(r.Context(), "refresh token invalidated", slog.F("token", refreshToken)) case "urn:ietf:params:oauth:grant-type:device_code": // Device flow var resp externalauth.ExchangeDeviceCodeResponse diff --git a/coderd/httpmw/apikey.go b/coderd/httpmw/apikey.go index d614b37a3d897..4b92848b773e2 100644 --- a/coderd/httpmw/apikey.go +++ b/coderd/httpmw/apikey.go @@ -232,16 +232,21 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon return optionalWrite(http.StatusUnauthorized, resp) } - var ( - link database.UserLink - now = dbtime.Now() - // Tracks if the API key has properties updated - changed = false - ) + now := dbtime.Now() + if key.ExpiresAt.Before(now) { + return optionalWrite(http.StatusUnauthorized, codersdk.Response{ + Message: SignedOutErrorMessage, + Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()), + }) + } + + // We only check OIDC stuff if we have a valid APIKey. An expired key means we don't trust the requestor + // really is the user whose key they have, and so we shouldn't be doing anything on their behalf including possibly + // refreshing the OIDC token. if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC { var err error //nolint:gocritic // System needs to fetch UserLink to check if it's valid. - link, err = cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{ + link, err := cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{ UserID: key.UserID, LoginType: key.LoginType, }) @@ -258,7 +263,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon }) } // Check if the OAuth token is expired - if link.OAuthExpiry.Before(now) && !link.OAuthExpiry.IsZero() && link.OAuthRefreshToken != "" { + if !link.OAuthExpiry.IsZero() && link.OAuthExpiry.Before(now) { if cfg.OAuth2Configs.IsZero() { return write(http.StatusInternalServerError, codersdk.Response{ Message: internalErrorMessage, @@ -267,12 +272,15 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon }) } + var friendlyName string var oauthConfig promoauth.OAuth2Config switch key.LoginType { case database.LoginTypeGithub: oauthConfig = cfg.OAuth2Configs.Github + friendlyName = "GitHub" case database.LoginTypeOIDC: oauthConfig = cfg.OAuth2Configs.OIDC + friendlyName = "OpenID Connect" default: return write(http.StatusInternalServerError, codersdk.Response{ Message: internalErrorMessage, @@ -292,7 +300,13 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon }) } - // If it is, let's refresh it from the provided config + if link.OAuthRefreshToken == "" { + return optionalWrite(http.StatusUnauthorized, codersdk.Response{ + Message: SignedOutErrorMessage, + Detail: fmt.Sprintf("%s session expired at %q. Try signing in again.", friendlyName, link.OAuthExpiry.String()), + }) + } + // We have a refresh token, so let's try it token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{ AccessToken: link.OAuthAccessToken, RefreshToken: link.OAuthRefreshToken, @@ -300,28 +314,39 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon }).Token() if err != nil { return write(http.StatusUnauthorized, codersdk.Response{ - Message: "Could not refresh expired Oauth token. Try re-authenticating to resolve this issue.", - Detail: err.Error(), + Message: fmt.Sprintf( + "Could not refresh expired %s token. Try re-authenticating to resolve this issue.", + friendlyName), + Detail: err.Error(), }) } link.OAuthAccessToken = token.AccessToken link.OAuthRefreshToken = token.RefreshToken link.OAuthExpiry = token.Expiry - key.ExpiresAt = token.Expiry - changed = true + //nolint:gocritic // system needs to update user link + link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{ + UserID: link.UserID, + LoginType: link.LoginType, + OAuthAccessToken: link.OAuthAccessToken, + OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required + OAuthRefreshToken: link.OAuthRefreshToken, + OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required + OAuthExpiry: link.OAuthExpiry, + // Refresh should keep the same debug context because we use + // the original claims for the group/role sync. + Claims: link.Claims, + }) + if err != nil { + return write(http.StatusInternalServerError, codersdk.Response{ + Message: internalErrorMessage, + Detail: fmt.Sprintf("update user_link: %s.", err.Error()), + }) + } } } - // Checking if the key is expired. - // NOTE: The `RequireAuth` React component depends on this `Detail` to detect when - // the users token has expired. If you change the text here, make sure to update it - // in site/src/components/RequireAuth/RequireAuth.tsx as well. - if key.ExpiresAt.Before(now) { - return optionalWrite(http.StatusUnauthorized, codersdk.Response{ - Message: SignedOutErrorMessage, - Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()), - }) - } + // Tracks if the API key has properties updated + changed := false // Only update LastUsed once an hour to prevent database spam. if now.Sub(key.LastUsed) > time.Hour { @@ -363,29 +388,6 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon Detail: fmt.Sprintf("API key couldn't update: %s.", err.Error()), }) } - // If the API Key is associated with a user_link (e.g. Github/OIDC) - // then we want to update the relevant oauth fields. - if link.UserID != uuid.Nil { - //nolint:gocritic // system needs to update user link - link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{ - UserID: link.UserID, - LoginType: link.LoginType, - OAuthAccessToken: link.OAuthAccessToken, - OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required - OAuthRefreshToken: link.OAuthRefreshToken, - OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required - OAuthExpiry: link.OAuthExpiry, - // Refresh should keep the same debug context because we use - // the original claims for the group/role sync. - Claims: link.Claims, - }) - if err != nil { - return write(http.StatusInternalServerError, codersdk.Response{ - Message: internalErrorMessage, - Detail: fmt.Sprintf("update user_link: %s.", err.Error()), - }) - } - } // We only want to update this occasionally to reduce DB write // load. We update alongside the UserLink and APIKey since it's diff --git a/coderd/httpmw/apikey_test.go b/coderd/httpmw/apikey_test.go index bd979e88235ad..6e2e75ace9825 100644 --- a/coderd/httpmw/apikey_test.go +++ b/coderd/httpmw/apikey_test.go @@ -508,6 +508,102 @@ func TestAPIKey(t *testing.T) { require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt) }) + t.Run("APIKeyExpiredOAuthExpired", func(t *testing.T) { + t.Parallel() + var ( + db = dbmem.New() + user = dbgen.User(t, db, database.User{}) + sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ + UserID: user.ID, + LastUsed: dbtime.Now().AddDate(0, 0, -1), + ExpiresAt: dbtime.Now().AddDate(0, 0, -1), + LoginType: database.LoginTypeOIDC, + }) + _ = dbgen.UserLink(t, db, database.UserLink{ + UserID: user.ID, + LoginType: database.LoginTypeOIDC, + OAuthExpiry: dbtime.Now().AddDate(0, 0, -1), + }) + + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.Header.Set(codersdk.SessionTokenHeader, token) + + // Include a valid oauth token for refreshing. If this token is invalid, + // it is difficult to tell an auth failure from an expired api key, or + // an expired oauth key. + oauthToken := &oauth2.Token{ + AccessToken: "wow", + RefreshToken: "moo", + Expiry: dbtime.Now().AddDate(0, 0, 1), + } + httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ + DB: db, + OAuth2Configs: &httpmw.OAuth2Configs{ + OIDC: &testutil.OAuth2Config{ + Token: oauthToken, + }, + }, + RedirectToLogin: false, + })(successHandler).ServeHTTP(rw, r) + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) + + gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID) + require.NoError(t, err) + + require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed) + require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt) + }) + + t.Run("APIKeyExpiredOAuthNotExpired", func(t *testing.T) { + t.Parallel() + var ( + db = dbmem.New() + user = dbgen.User(t, db, database.User{}) + sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ + UserID: user.ID, + LastUsed: dbtime.Now().AddDate(0, 0, -1), + ExpiresAt: dbtime.Now().AddDate(0, 0, -1), + LoginType: database.LoginTypeOIDC, + }) + _ = dbgen.UserLink(t, db, database.UserLink{ + UserID: user.ID, + LoginType: database.LoginTypeOIDC, + }) + + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.Header.Set(codersdk.SessionTokenHeader, token) + + oauthToken := &oauth2.Token{ + AccessToken: "wow", + RefreshToken: "moo", + Expiry: dbtime.Now().AddDate(0, 0, 1), + } + httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ + DB: db, + OAuth2Configs: &httpmw.OAuth2Configs{ + OIDC: &testutil.OAuth2Config{ + Token: oauthToken, + }, + }, + RedirectToLogin: false, + })(successHandler).ServeHTTP(rw, r) + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) + + gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID) + require.NoError(t, err) + + require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed) + require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt) + }) + t.Run("OAuthRefresh", func(t *testing.T) { t.Parallel() var ( @@ -553,7 +649,67 @@ func TestAPIKey(t *testing.T) { require.NoError(t, err) require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed) - require.Equal(t, oauthToken.Expiry, gotAPIKey.ExpiresAt) + // Note that OAuth expiry is independent of APIKey expiry, so an OIDC refresh DOES NOT affect the expiry of the + // APIKey + require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt) + + gotLink, err := db.GetUserLinkByUserIDLoginType(r.Context(), database.GetUserLinkByUserIDLoginTypeParams{ + UserID: user.ID, + LoginType: database.LoginTypeGithub, + }) + require.NoError(t, err) + require.Equal(t, gotLink.OAuthRefreshToken, "moo") + }) + + t.Run("OAuthExpiredNoRefresh", func(t *testing.T) { + t.Parallel() + var ( + ctx = testutil.Context(t, testutil.WaitShort) + db = dbmem.New() + user = dbgen.User(t, db, database.User{}) + sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ + UserID: user.ID, + LastUsed: dbtime.Now(), + ExpiresAt: dbtime.Now().AddDate(0, 0, 1), + LoginType: database.LoginTypeGithub, + }) + + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + _, err := db.InsertUserLink(ctx, database.InsertUserLinkParams{ + UserID: user.ID, + LoginType: database.LoginTypeGithub, + OAuthExpiry: dbtime.Now().AddDate(0, 0, -1), + OAuthAccessToken: "letmein", + }) + require.NoError(t, err) + + r.Header.Set(codersdk.SessionTokenHeader, token) + + oauthToken := &oauth2.Token{ + AccessToken: "wow", + RefreshToken: "moo", + Expiry: dbtime.Now().AddDate(0, 0, 1), + } + httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ + DB: db, + OAuth2Configs: &httpmw.OAuth2Configs{ + Github: &testutil.OAuth2Config{ + Token: oauthToken, + }, + }, + RedirectToLogin: false, + })(successHandler).ServeHTTP(rw, r) + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) + + gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID) + require.NoError(t, err) + + require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed) + require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt) }) t.Run("RemoteIPUpdates", func(t *testing.T) { diff --git a/coderd/oauthpki/okidcpki_test.go b/coderd/oauthpki/okidcpki_test.go index 509da563a9145..7f7dda17bcba8 100644 --- a/coderd/oauthpki/okidcpki_test.go +++ b/coderd/oauthpki/okidcpki_test.go @@ -144,6 +144,7 @@ func TestAzureAKPKIWithCoderd(t *testing.T) { return values, nil }), oidctest.WithServing(), + oidctest.WithLogging(t, nil), ) cfg := fake.OIDCConfig(t, scopes, func(cfg *coderd.OIDCConfig) { cfg.AllowSignups = true