diff --git a/coderd/coderd.go b/coderd/coderd.go index 58b6c902c7dbc..3407be5ac8de2 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -693,7 +693,6 @@ func New(options *Options) *API { r.Route("/github", func(r chi.Router) { r.Use( httpmw.ExtractOAuth2(options.GithubOAuth2Config, options.HTTPClient, nil), - apiKeyMiddlewareOptional, ) r.Get("/callback", api.userOAuth2Github) }) @@ -701,7 +700,6 @@ func New(options *Options) *API { r.Route("/oidc/callback", func(r chi.Router) { r.Use( httpmw.ExtractOAuth2(options.OIDCConfig, options.HTTPClient, oidcAuthURLParams), - apiKeyMiddlewareOptional, ) r.Get("/", api.userOIDC) }) diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 71e3336ab2e87..5e2e55d5c032f 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -1022,9 +1022,31 @@ func NewAWSInstanceIdentity(t *testing.T, instanceID string) (awsidentity.Certif type OIDCConfig struct { key *rsa.PrivateKey issuer string + // These are optional + refreshToken string + oidcTokenExpires func() time.Time + tokenSource func() (*oauth2.Token, error) } -func NewOIDCConfig(t *testing.T, issuer string) *OIDCConfig { +func WithRefreshToken(token string) func(cfg *OIDCConfig) { + return func(cfg *OIDCConfig) { + cfg.refreshToken = token + } +} + +func WithTokenExpires(expFunc func() time.Time) func(cfg *OIDCConfig) { + return func(cfg *OIDCConfig) { + cfg.oidcTokenExpires = expFunc + } +} + +func WithTokenSource(src func() (*oauth2.Token, error)) func(cfg *OIDCConfig) { + return func(cfg *OIDCConfig) { + cfg.tokenSource = src + } +} + +func NewOIDCConfig(t *testing.T, issuer string, opts ...func(cfg *OIDCConfig)) *OIDCConfig { t.Helper() block, _ := pem.Decode([]byte(testRSAPrivateKey)) @@ -1035,33 +1057,58 @@ func NewOIDCConfig(t *testing.T, issuer string) *OIDCConfig { issuer = "https://coder.com" } - return &OIDCConfig{ + cfg := &OIDCConfig{ key: pkey, issuer: issuer, } + for _, opt := range opts { + opt(cfg) + } + return cfg } func (*OIDCConfig) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string { return "/?state=" + url.QueryEscape(state) } -func (*OIDCConfig) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource { - return nil +type tokenSource struct { + src func() (*oauth2.Token, error) +} + +func (s tokenSource) Token() (*oauth2.Token, error) { + return s.src() +} + +func (cfg *OIDCConfig) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource { + if cfg.tokenSource == nil { + return nil + } + return tokenSource{ + src: cfg.tokenSource, + } } -func (*OIDCConfig) Exchange(_ context.Context, code string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) { +func (cfg *OIDCConfig) Exchange(_ context.Context, code string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) { token, err := base64.StdEncoding.DecodeString(code) if err != nil { return nil, xerrors.Errorf("decode code: %w", err) } + + var exp time.Time + if cfg.oidcTokenExpires != nil { + exp = cfg.oidcTokenExpires() + } + return (&oauth2.Token{ - AccessToken: "token", + AccessToken: "token", + RefreshToken: cfg.refreshToken, + Expiry: exp, }).WithExtra(map[string]interface{}{ "id_token": string(token), }), nil } -func (o *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string { +func (cfg *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string { t.Helper() if _, ok := claims["exp"]; !ok { @@ -1069,20 +1116,20 @@ func (o *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string { } if _, ok := claims["iss"]; !ok { - claims["iss"] = o.issuer + claims["iss"] = cfg.issuer } if _, ok := claims["sub"]; !ok { claims["sub"] = "testme" } - signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(o.key) + signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(cfg.key) require.NoError(t, err) return base64.StdEncoding.EncodeToString([]byte(signed)) } -func (o *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig { +func (cfg *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig { // By default, the provider can be empty. // This means it won't support any endpoints! provider := &oidc.Provider{} @@ -1099,10 +1146,10 @@ func (o *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts } provider = cfg.NewProvider(context.Background()) } - cfg := &coderd.OIDCConfig{ - OAuth2Config: o, - Verifier: oidc.NewVerifier(o.issuer, &oidc.StaticKeySet{ - PublicKeys: []crypto.PublicKey{o.key.Public()}, + newCFG := &coderd.OIDCConfig{ + OAuth2Config: cfg, + Verifier: oidc.NewVerifier(cfg.issuer, &oidc.StaticKeySet{ + PublicKeys: []crypto.PublicKey{cfg.key.Public()}, }, &oidc.Config{ SkipClientIDCheck: true, }), @@ -1113,9 +1160,9 @@ func (o *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts GroupField: "groups", } for _, opt := range opts { - opt(cfg) + opt(newCFG) } - return cfg + return newCFG } // NewAzureInstanceIdentity returns a metadata client and ID token validator for faking diff --git a/coderd/httpmw/apikey.go b/coderd/httpmw/apikey.go index 5f0ec0dc263c7..f8f809761787c 100644 --- a/coderd/httpmw/apikey.go +++ b/coderd/httpmw/apikey.go @@ -142,6 +142,56 @@ func ExtractAPIKeyMW(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler { } } +func APIKeyFromRequest(ctx context.Context, db database.Store, sessionTokenFunc func(r *http.Request) string, r *http.Request) (*database.APIKey, codersdk.Response, bool) { + tokenFunc := APITokenFromRequest + if sessionTokenFunc != nil { + tokenFunc = sessionTokenFunc + } + + token := tokenFunc(r) + if token == "" { + return nil, codersdk.Response{ + Message: SignedOutErrorMessage, + Detail: fmt.Sprintf("Cookie %q or query parameter must be provided.", codersdk.SessionTokenCookie), + }, false + } + + keyID, keySecret, err := SplitAPIToken(token) + if err != nil { + return nil, codersdk.Response{ + Message: SignedOutErrorMessage, + Detail: "Invalid API key format: " + err.Error(), + }, false + } + + //nolint:gocritic // System needs to fetch API key to check if it's valid. + key, err := db.GetAPIKeyByID(dbauthz.AsSystemRestricted(ctx), keyID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, codersdk.Response{ + Message: SignedOutErrorMessage, + Detail: "API key is invalid.", + }, false + } + + return nil, codersdk.Response{ + Message: internalErrorMessage, + Detail: fmt.Sprintf("Internal error fetching API key by id. %s", err.Error()), + }, false + } + + // Checking to see if the secret is valid. + hashedSecret := sha256.Sum256([]byte(keySecret)) + if subtle.ConstantTimeCompare(key.HashedSecret, hashedSecret[:]) != 1 { + return nil, codersdk.Response{ + Message: SignedOutErrorMessage, + Detail: "API key secret is invalid.", + }, false + } + + return &key, codersdk.Response{}, true +} + // ExtractAPIKey requires authentication using a valid API key. It handles // extending an API key if it comes close to expiry, updating the last used time // in the database. @@ -179,49 +229,9 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon return nil, nil, false } - tokenFunc := APITokenFromRequest - if cfg.SessionTokenFunc != nil { - tokenFunc = cfg.SessionTokenFunc - } - token := tokenFunc(r) - if token == "" { - return optionalWrite(http.StatusUnauthorized, codersdk.Response{ - Message: SignedOutErrorMessage, - Detail: fmt.Sprintf("Cookie %q or query parameter must be provided.", codersdk.SessionTokenCookie), - }) - } - - keyID, keySecret, err := SplitAPIToken(token) - if err != nil { - return optionalWrite(http.StatusUnauthorized, codersdk.Response{ - Message: SignedOutErrorMessage, - Detail: "Invalid API key format: " + err.Error(), - }) - } - - //nolint:gocritic // System needs to fetch API key to check if it's valid. - key, err := cfg.DB.GetAPIKeyByID(dbauthz.AsSystemRestricted(ctx), keyID) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return optionalWrite(http.StatusUnauthorized, codersdk.Response{ - Message: SignedOutErrorMessage, - Detail: "API key is invalid.", - }) - } - - return write(http.StatusInternalServerError, codersdk.Response{ - Message: internalErrorMessage, - Detail: fmt.Sprintf("Internal error fetching API key by id. %s", err.Error()), - }) - } - - // Checking to see if the secret is valid. - hashedSecret := sha256.Sum256([]byte(keySecret)) - if subtle.ConstantTimeCompare(key.HashedSecret, hashedSecret[:]) != 1 { - return optionalWrite(http.StatusUnauthorized, codersdk.Response{ - Message: SignedOutErrorMessage, - Detail: "API key secret is invalid.", - }) + key, resp, ok := APIKeyFromRequest(ctx, cfg.DB, cfg.SessionTokenFunc, r) + if !ok { + return optionalWrite(http.StatusUnauthorized, resp) } var ( @@ -232,7 +242,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon ) if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC { //nolint:gocritic // System needs to fetch UserLink to check if it's valid. - link, err = cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{ + link, err := cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{ UserID: key.UserID, LoginType: key.LoginType, }) @@ -427,7 +437,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon }.WithCachedASTValue(), } - return &key, &authz, true + return key, &authz, true } // APITokenFromRequest returns the api token from the request. diff --git a/coderd/userauth.go b/coderd/userauth.go index 9b6ba7992bad5..f1e110c08bfdc 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -1427,7 +1427,8 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C } var key database.APIKey - if oldKey, ok := httpmw.APIKeyOptional(r); ok && isConvertLoginType { + oldKey, _, ok := httpmw.APIKeyFromRequest(ctx, api.Database, nil, r) + if ok && oldKey != nil && isConvertLoginType { // If this is a convert login type, and it succeeds, then delete the old // session. Force the user to log back in. err := api.Database.DeleteAPIKeyByID(r.Context(), oldKey.ID) @@ -1447,7 +1448,9 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C Secure: api.SecureAuthCookie, HttpOnly: true, }) - key = oldKey + // This is intentional setting the key to the deleted old key, + // as the user needs to be forced to log back in. + key = *oldKey } else { //nolint:gocritic cookie, newKey, err := api.createAPIKey(dbauthz.AsSystemRestricted(ctx), apikey.CreateParams{ diff --git a/coderd/userauth_test.go b/coderd/userauth_test.go index 6f49222ff8764..efa7673890863 100644 --- a/coderd/userauth_test.go +++ b/coderd/userauth_test.go @@ -9,6 +9,7 @@ import ( "net/http/cookiejar" "strings" "testing" + "time" "github.com/coreos/go-oidc/v3/oidc" "github.com/golang-jwt/jwt" @@ -24,12 +25,97 @@ import ( "github.com/coder/coder/coderd/audit" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbauthz" "github.com/coder/coder/coderd/database/dbgen" "github.com/coder/coder/coderd/database/dbtestutil" "github.com/coder/coder/codersdk" "github.com/coder/coder/testutil" ) +// This test specifically tests logging in with OIDC when an expired +// OIDC session token exists. +// The token refreshing should not happen since we are reauthenticating. +func TestOIDCOauthLoginWithExisting(t *testing.T) { + t.Parallel() + + conf := coderdtest.NewOIDCConfig(t, "", + // Provide a refresh token so we use the refresh token flow + coderdtest.WithRefreshToken("refresh_token"), + // We need to set the expire in the future for the first api calls. + coderdtest.WithTokenExpires(func() time.Time { + return time.Now().Add(time.Hour).UTC() + }), + // No refresh should actually happen in this test. + coderdtest.WithTokenSource(func() (*oauth2.Token, error) { + return nil, xerrors.New("token should not require refresh") + }), + ) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + auditor := audit.NewMock() + const username = "alice" + claims := jwt.MapClaims{ + "email": "alice@coder.com", + "email_verified": true, + "preferred_username": username, + } + config := conf.OIDCConfig(t, claims) + + config.AllowSignups = true + config.IgnoreUserInfo = true + client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{ + Auditor: auditor, + OIDCConfig: config, + Logger: &logger, + }) + + // Signup alice + resp := oidcCallback(t, client, conf.EncodeClaims(t, claims)) + // Set the client to use this OIDC context + authCookie := authCookieValue(resp.Cookies()) + client.SetSessionToken(authCookie) + _ = resp.Body.Close() + + ctx := testutil.Context(t, testutil.WaitLong) + // Verify the user and oauth link + user, err := client.User(ctx, "me") + require.NoError(t, err) + require.Equal(t, username, user.Username) + + // nolint:gocritic + link, err := api.Database.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{ + UserID: user.ID, + LoginType: database.LoginType(user.LoginType), + }) + require.NoError(t, err, "failed to get user link") + + // Expire the link + // nolint:gocritic + _, err = api.Database.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{ + OAuthAccessToken: link.OAuthAccessToken, + OAuthRefreshToken: link.OAuthRefreshToken, + OAuthExpiry: time.Now().Add(time.Hour * -1).UTC(), + UserID: link.UserID, + LoginType: link.LoginType, + }) + require.NoError(t, err, "failed to update user link") + + // Log in again with OIDC + loginAgain := oidcCallbackWithState(t, client, conf.EncodeClaims(t, claims), "seconds_login", func(req *http.Request) { + req.AddCookie(&http.Cookie{ + Name: codersdk.SessionTokenCookie, + Value: authCookie, + Path: "/", + }) + }) + require.Equal(t, http.StatusTemporaryRedirect, loginAgain.StatusCode) + _ = loginAgain.Body.Close() + + // Try to use new login + client.SetSessionToken(authCookieValue(resp.Cookies())) + _, err = client.User(ctx, "me") + require.NoError(t, err, "use new session") +} + func TestUserLogin(t *testing.T) { t.Parallel() t.Run("OK", func(t *testing.T) { @@ -819,7 +905,7 @@ func TestUserOIDC(t *testing.T) { }) require.NoError(t, err) - resp := oidcCallbackWithState(t, user, code, convertResponse.StateString) + resp := oidcCallbackWithState(t, user, code, convertResponse.StateString, nil) require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) }) @@ -1045,10 +1131,10 @@ func oauth2Callback(t *testing.T, client *codersdk.Client) *http.Response { } func oidcCallback(t *testing.T, client *codersdk.Client, code string) *http.Response { - return oidcCallbackWithState(t, client, code, "somestate") + return oidcCallbackWithState(t, client, code, "somestate", nil) } -func oidcCallbackWithState(t *testing.T, client *codersdk.Client, code, state string) *http.Response { +func oidcCallbackWithState(t *testing.T, client *codersdk.Client, code, state string, modify func(r *http.Request)) *http.Response { t.Helper() client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { @@ -1062,6 +1148,9 @@ func oidcCallbackWithState(t *testing.T, client *codersdk.Client, code, state st Name: codersdk.OAuth2StateCookie, Value: state, }) + if modify != nil { + modify(req) + } res, err := client.HTTPClient.Do(req) require.NoError(t, err) defer res.Body.Close() diff --git a/enterprise/coderd/userauth_test.go b/enterprise/coderd/userauth_test.go index 428cf91a6fef2..2cb110abe987b 100644 --- a/enterprise/coderd/userauth_test.go +++ b/enterprise/coderd/userauth_test.go @@ -99,6 +99,7 @@ func TestUserOIDC(t *testing.T) { "roles": []string{"random", oidcRoleName, rbac.RoleOwner()}, })) require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + _ = resp.Body.Close() user, err := client.User(ctx, "alice") require.NoError(t, err) @@ -112,6 +113,7 @@ func TestUserOIDC(t *testing.T) { "roles": []string{"random"}, })) require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + _ = resp.Body.Close() user, err = client.User(ctx, "alice") require.NoError(t, err)