diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index 3ca8cadbc9ff9..6f060aea2c6b6 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -7,6 +7,7 @@ import ( "crypto/x509" "encoding/json" "encoding/pem" + "errors" "fmt" "io" "net" @@ -41,7 +42,7 @@ import ( type FakeIDP struct { issuer string key *rsa.PrivateKey - provider providerJSON + provider ProviderJSON handler http.Handler cfg *oauth2.Config @@ -66,7 +67,7 @@ type FakeIDP struct { // IDP -> Application. Almost all IDPs have the concept of // "Authorized Redirect URLs". This can be used to emulate that. hookValidRedirectURL func(redirectURL string) error - hookUserInfo func(email string) jwt.MapClaims + hookUserInfo func(email string) (jwt.MapClaims, error) fakeCoderd func(req *http.Request) (*http.Response, error) hookOnRefresh func(email string) error // Custom authentication for the client. This is useful if you want @@ -75,6 +76,26 @@ type FakeIDP struct { serve bool } +func StatusError(code int, err error) error { + return statusHookError{ + Err: err, + HTTPStatusCode: code, + } +} + +// statusHookError allows a hook to change the returned http status code. +type statusHookError struct { + Err error + HTTPStatusCode int +} + +func (s statusHookError) Error() string { + if s.Err == nil { + return "" + } + return s.Err.Error() +} + type FakeIDPOpt func(idp *FakeIDP) func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeIDP) { @@ -83,9 +104,9 @@ func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeID } } -// WithRefreshHook is called when a refresh token is used. The email is +// WithRefresh is called when a refresh token is used. The email is // the email of the user that is being refreshed assuming the claims are correct. -func WithRefreshHook(hook func(email string) error) func(*FakeIDP) { +func WithRefresh(hook func(email string) error) func(*FakeIDP) { return func(f *FakeIDP) { f.hookOnRefresh = hook } @@ -108,13 +129,13 @@ func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) { // every user on the /userinfo endpoint. func WithStaticUserInfo(info jwt.MapClaims) func(*FakeIDP) { return func(f *FakeIDP) { - f.hookUserInfo = func(_ string) jwt.MapClaims { - return info + f.hookUserInfo = func(_ string) (jwt.MapClaims, error) { + return info, nil } } } -func WithDynamicUserInfo(userInfoFunc func(email string) jwt.MapClaims) func(*FakeIDP) { +func WithDynamicUserInfo(userInfoFunc func(email string) (jwt.MapClaims, error)) func(*FakeIDP) { return func(f *FakeIDP) { f.hookUserInfo = userInfoFunc } @@ -160,7 +181,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](), refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](), hookOnRefresh: func(_ string) error { return nil }, - hookUserInfo: func(email string) jwt.MapClaims { return jwt.MapClaims{} }, + hookUserInfo: func(email string) (jwt.MapClaims, error) { return jwt.MapClaims{}, nil }, hookValidRedirectURL: func(redirectURL string) error { return nil }, } @@ -181,6 +202,10 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { return idp } +func (f *FakeIDP) WellknownConfig() ProviderJSON { + return f.provider +} + func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) { t.Helper() @@ -188,9 +213,9 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) { require.NoError(t, err, "invalid issuer URL") f.issuer = issuer - // providerJSON is the JSON representation of the OpenID Connect provider + // ProviderJSON is the JSON representation of the OpenID Connect provider // These are all the urls that the IDP will respond to. - f.provider = providerJSON{ + f.provider = ProviderJSON{ Issuer: issuer, AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(), TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(), @@ -220,6 +245,15 @@ func (f *FakeIDP) realServer(t testing.TB) *httptest.Server { return srv } +// GenerateAuthenticatedToken skips all oauth2 flows, and just generates a +// valid token for some given claims. +func (f *FakeIDP) GenerateAuthenticatedToken(claims jwt.MapClaims) (*oauth2.Token, error) { + state := uuid.NewString() + f.stateToIDTokenClaims.Store(state, claims) + code := f.newCode(state) + return f.cfg.Exchange(oidc.ClientContext(context.Background(), f.HTTPClient(nil)), code) +} + // Login does the full OIDC flow starting at the "LoginButton". // The client argument is just to get the URL of the Coder instance. // @@ -333,7 +367,8 @@ func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.Map return resp, nil } -type providerJSON struct { +// ProviderJSON is the .well-known/configuration JSON +type ProviderJSON struct { Issuer string `json:"issuer"` AuthURL string `json:"authorization_endpoint"` TokenURL string `json:"token_endpoint"` @@ -475,7 +510,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { err := f.hookValidRedirectURL(redirectURI) if err != nil { t.Errorf("not authorized redirect_uri by custom hook %q: %s", redirectURI, err.Error()) - http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), http.StatusBadRequest) + http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err)) return } @@ -501,7 +536,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { slog.F("values", values.Encode()), ) if err != nil { - http.Error(rw, fmt.Sprintf("invalid token request: %s", err.Error()), http.StatusBadRequest) + http.Error(rw, fmt.Sprintf("invalid token request: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err)) return } getEmail := func(claims jwt.MapClaims) string { @@ -562,7 +597,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { claims = idTokenClaims err := f.hookOnRefresh(getEmail(claims)) if err != nil { - http.Error(rw, fmt.Sprintf("refresh hook blocked refresh: %s", err.Error()), http.StatusBadRequest) + http.Error(rw, fmt.Sprintf("refresh hook blocked refresh: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err)) return } @@ -610,7 +645,12 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { http.Error(rw, "invalid access token, missing user info", http.StatusBadRequest) return } - _ = json.NewEncoder(rw).Encode(f.hookUserInfo(email)) + claims, err := f.hookUserInfo(email) + if err != nil { + http.Error(rw, fmt.Sprintf("user info hook returned error: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err)) + return + } + _ = json.NewEncoder(rw).Encode(claims) })) mux.Handle(keysPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { @@ -768,6 +808,15 @@ func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *co return cfg } +func httpErrorCode(defaultCode int, err error) int { + var stautsErr statusHookError + status := defaultCode + if errors.As(err, &stautsErr) { + status = stautsErr.HTTPStatusCode + } + return status +} + type fakeRoundTripper struct { roundTrip func(req *http.Request) (*http.Response, error) } diff --git a/coderd/gitauth/config.go b/coderd/gitauth/config.go index 00aa5a2d23a58..f5810f733181d 100644 --- a/coderd/gitauth/config.go +++ b/coderd/gitauth/config.go @@ -60,17 +60,30 @@ type Config struct { } // RefreshToken automatically refreshes the token if expired and permitted. -// It returns the token and a bool indicating if the token was refreshed. +// It returns the token and a bool indicating if the token is valid. func (c *Config) RefreshToken(ctx context.Context, db database.Store, gitAuthLink database.GitAuthLink) (database.GitAuthLink, bool, error) { // If the token is expired and refresh is disabled, we prompt // the user to authenticate again. - if c.NoRefresh && gitAuthLink.OAuthExpiry.Before(dbtime.Now()) { + if c.NoRefresh && + // If the time is set to 0, then it should never expire. + // This is true for github, which has no expiry. + !gitAuthLink.OAuthExpiry.IsZero() && + gitAuthLink.OAuthExpiry.Before(dbtime.Now()) { return gitAuthLink, false, nil } + // This is additional defensive programming. Because TokenSource is an interface, + // we cannot be sure that the implementation will treat an 'IsZero' time + // as "not-expired". The default implementation does, but a custom implementation + // might not. Removing the refreshToken will guarantee a refresh will fail. + refreshToken := gitAuthLink.OAuthRefreshToken + if c.NoRefresh { + refreshToken = "" + } + token, err := c.TokenSource(ctx, &oauth2.Token{ AccessToken: gitAuthLink.OAuthAccessToken, - RefreshToken: gitAuthLink.OAuthRefreshToken, + RefreshToken: refreshToken, Expiry: gitAuthLink.OAuthExpiry, }).Token() if err != nil { @@ -130,8 +143,13 @@ func (c *Config) ValidateToken(ctx context.Context, token string) (bool, *coders if err != nil { return false, nil, err } + + cli := http.DefaultClient + if v, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok { + cli = v + } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - res, err := http.DefaultClient.Do(req) + res, err := cli.Do(req) if err != nil { return false, nil, err } diff --git a/coderd/gitauth/config_test.go b/coderd/gitauth/config_test.go index bcd650e82ad3a..f6c97a440e3cb 100644 --- a/coderd/gitauth/config_test.go +++ b/coderd/gitauth/config_test.go @@ -3,18 +3,22 @@ package gitauth_test import ( "context" "net/http" - "net/http/httptest" "net/url" "testing" "time" + "github.com/coreos/go-oidc/v3/oidc" + "github.com/golang-jwt/jwt/v4" + "github.com/google/uuid" "github.com/stretchr/testify/require" "golang.org/x/oauth2" "golang.org/x/xerrors" + "github.com/coder/coder/v2/coderd" + "github.com/coder/coder/v2/coderd/coderdtest/oidctest" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbfake" - "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/gitauth" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/testutil" @@ -22,17 +26,70 @@ import ( func TestRefreshToken(t *testing.T) { t.Parallel() - t.Run("FalseIfNoRefresh", func(t *testing.T) { + const providerID = "test-idp" + expired := time.Now().Add(time.Hour * -1) + + t.Run("NoRefreshExpired", func(t *testing.T) { t.Parallel() - config := &gitauth.Config{ - NoRefresh: true, - } - _, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{ - OAuthExpiry: time.Time{}, + fake, config, link := setupOauth2Test(t, testConfig{ + FakeIDPOpts: []oidctest.FakeIDPOpt{ + oidctest.WithRefresh(func(_ string) error { + t.Error("refresh on the IDP was called, but NoRefresh was set") + return xerrors.New("should not be called") + }), + // The IDP should not be contacted since the token is expired. An expired + // token with 'NoRefresh' should early abort. + oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) { + t.Error("token was validated, but it was expired and this should never have happened.") + return nil, xerrors.New("should not be called") + }), + }, + GitConfigOpt: func(cfg *gitauth.Config) { + cfg.NoRefresh = true + }, }) + + ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil)) + // Expire the link + link.OAuthExpiry = expired + + _, refreshed, err := config.RefreshToken(ctx, nil, link) require.NoError(t, err) require.False(t, refreshed) }) + + // NoRefreshNoExpiry tests that an oauth token without an expiry is always valid. + // The "validate url" should be hit, but the refresh endpoint should not. + t.Run("NoRefreshNoExpiry", func(t *testing.T) { + t.Parallel() + + validated := false + fake, config, link := setupOauth2Test(t, testConfig{ + FakeIDPOpts: []oidctest.FakeIDPOpt{ + oidctest.WithRefresh(func(_ string) error { + t.Error("refresh on the IDP was called, but NoRefresh was set") + return xerrors.New("should not be called") + }), + oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) { + validated = true + return jwt.MapClaims{}, nil + }), + }, + GitConfigOpt: func(cfg *gitauth.Config) { + cfg.NoRefresh = true + }, + }) + + ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil)) + + // Zero time used + link.OAuthExpiry = time.Time{} + _, refreshed, err := config.RefreshToken(ctx, nil, link) + require.NoError(t, err) + require.True(t, refreshed, "token without expiry is always valid") + require.True(t, validated, "token should have been validated") + }) + t.Run("FalseIfTokenSourceFails", func(t *testing.T) { t.Parallel() config := &gitauth.Config{ @@ -42,111 +99,167 @@ func TestRefreshToken(t *testing.T) { }, }, } - _, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{}) + _, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{ + OAuthExpiry: expired, + }) require.NoError(t, err) require.False(t, refreshed) }) + t.Run("ValidateServerError", func(t *testing.T) { t.Parallel() - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("Failure")) - })) - config := &gitauth.Config{ - OAuth2Config: &testutil.OAuth2Config{}, - ValidateURL: srv.URL, - } - _, _, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{}) - require.ErrorContains(t, err, "Failure") + + const staticError = "static error" + validated := false + fake, config, link := setupOauth2Test(t, testConfig{ + FakeIDPOpts: []oidctest.FakeIDPOpt{ + oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) { + validated = true + return jwt.MapClaims{}, xerrors.New(staticError) + }), + }, + GitConfigOpt: func(cfg *gitauth.Config) { + }, + }) + + ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil)) + link.OAuthExpiry = expired + + _, _, err := config.RefreshToken(ctx, nil, link) + require.ErrorContains(t, err, staticError) + require.True(t, validated, "token should have been attempted to be validated") }) + + // ValidateFailure tests if the token is no longer valid with a 401 response. t.Run("ValidateFailure", func(t *testing.T) { t.Parallel() - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusUnauthorized) - w.Write([]byte("Not permitted")) - })) - config := &gitauth.Config{ - OAuth2Config: &testutil.OAuth2Config{}, - ValidateURL: srv.URL, - } - _, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{}) - require.NoError(t, err) + + const staticError = "static error" + validated := false + fake, config, link := setupOauth2Test(t, testConfig{ + FakeIDPOpts: []oidctest.FakeIDPOpt{ + oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) { + validated = true + return jwt.MapClaims{}, oidctest.StatusError(http.StatusUnauthorized, xerrors.New(staticError)) + }), + }, + GitConfigOpt: func(cfg *gitauth.Config) { + }, + }) + + ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil)) + link.OAuthExpiry = expired + + _, refreshed, err := config.RefreshToken(ctx, nil, link) + require.NoError(t, err, staticError) require.False(t, refreshed) + require.True(t, validated, "token should have been attempted to be validated") }) + t.Run("ValidateRetryGitHub", func(t *testing.T) { t.Parallel() - hit := false - // We need to ensure that the exponential backoff kicks in properly. - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if !hit { - hit = true - w.WriteHeader(http.StatusUnauthorized) - w.Write([]byte("Not permitted")) - return - } - w.WriteHeader(http.StatusOK) - })) - config := &gitauth.Config{ - ID: "test", - OAuth2Config: &testutil.OAuth2Config{ - Token: &oauth2.Token{ - AccessToken: "updated", - }, + + const staticError = "static error" + validateCalls := 0 + fake, config, link := setupOauth2Test(t, testConfig{ + FakeIDPOpts: []oidctest.FakeIDPOpt{ + oidctest.WithRefresh(func(_ string) error { + t.Error("refresh on the IDP was called, but the token is not expired") + return xerrors.New("should not be called") + }), + oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) { + validateCalls++ + // Make the first call return a 401, subsequent calls should return a 200. + if validateCalls > 1 { + return jwt.MapClaims{}, nil + } + return jwt.MapClaims{}, oidctest.StatusError(http.StatusUnauthorized, xerrors.New(staticError)) + }), + }, + GitConfigOpt: func(cfg *gitauth.Config) { + cfg.Type = codersdk.GitProviderGitHub }, - ValidateURL: srv.URL, - Type: codersdk.GitProviderGitHub, - } - db := dbfake.New() - link := dbgen.GitAuthLink(t, db, database.GitAuthLink{ - ProviderID: config.ID, - OAuthAccessToken: "initial", }) - _, refreshed, err := config.RefreshToken(context.Background(), db, link) + + ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil)) + // Unlimited lifetime, this is what GitHub returns tokens as + link.OAuthExpiry = time.Time{} + + _, ok, err := config.RefreshToken(ctx, nil, link) require.NoError(t, err) - require.True(t, refreshed) - require.True(t, hit) + require.True(t, ok) + require.Equal(t, 2, validateCalls, "token should have been attempted to be validated more than once") }) + t.Run("ValidateNoUpdate", func(t *testing.T) { t.Parallel() - validated := make(chan struct{}) - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - close(validated) - })) - accessToken := "testing" - config := &gitauth.Config{ - OAuth2Config: &testutil.OAuth2Config{ - Token: &oauth2.Token{ - AccessToken: accessToken, - }, + + validateCalls := 0 + fake, config, link := setupOauth2Test(t, testConfig{ + FakeIDPOpts: []oidctest.FakeIDPOpt{ + oidctest.WithRefresh(func(_ string) error { + t.Error("refresh on the IDP was called, but the token is not expired") + return xerrors.New("should not be called") + }), + oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) { + validateCalls++ + return jwt.MapClaims{}, nil + }), + }, + GitConfigOpt: func(cfg *gitauth.Config) { + cfg.Type = codersdk.GitProviderGitHub }, - ValidateURL: srv.URL, - } - _, valid, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{ - OAuthAccessToken: accessToken, }) + + ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil)) + + _, ok, err := config.RefreshToken(ctx, nil, link) require.NoError(t, err) - require.True(t, valid) - <-validated + require.True(t, ok) + require.Equal(t, 1, validateCalls, "token is validated") }) + + // A token update comes from a refresh. t.Run("Updates", func(t *testing.T) { t.Parallel() - config := &gitauth.Config{ - ID: "test", - OAuth2Config: &testutil.OAuth2Config{ - Token: &oauth2.Token{ - AccessToken: "updated", - }, - }, - } + db := dbfake.New() - link := dbgen.GitAuthLink(t, db, database.GitAuthLink{ - ProviderID: config.ID, - OAuthAccessToken: "initial", + validateCalls := 0 + refreshCalls := 0 + fake, config, link := setupOauth2Test(t, testConfig{ + FakeIDPOpts: []oidctest.FakeIDPOpt{ + oidctest.WithRefresh(func(_ string) error { + refreshCalls++ + return nil + }), + oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) { + validateCalls++ + return jwt.MapClaims{}, nil + }), + }, + GitConfigOpt: func(cfg *gitauth.Config) { + cfg.Type = codersdk.GitProviderGitHub + }, + DB: db, }) - _, valid, err := config.RefreshToken(context.Background(), db, link) + + ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil)) + // Force a refresh + link.OAuthExpiry = expired + + updated, ok, err := config.RefreshToken(ctx, db, link) require.NoError(t, err) - require.True(t, valid) + require.True(t, ok) + require.Equal(t, 1, validateCalls, "token is validated") + require.Equal(t, 1, refreshCalls, "token is refreshed") + require.NotEqualf(t, link.OAuthAccessToken, updated.OAuthAccessToken, "token is updated") + //nolint:gocritic // testing + dbLink, err := db.GetGitAuthLink(dbauthz.AsSystemRestricted(context.Background()), database.GetGitAuthLinkParams{ + ProviderID: link.ProviderID, + UserID: link.UserID, + }) + require.NoError(t, err) + require.Equal(t, updated.OAuthAccessToken, dbLink.OAuthAccessToken, "token is updated in the DB") }) } @@ -232,3 +345,65 @@ func TestConvertYAML(t *testing.T) { require.Equal(t, "https://auth.com?client_id=id&redirect_uri=%2Fgitauth%2Fgitlab%2Fcallback&response_type=code&scope=read", config[0].AuthCodeURL("")) }) } + +type testConfig struct { + FakeIDPOpts []oidctest.FakeIDPOpt + CoderOIDCConfigOpts []func(cfg *coderd.OIDCConfig) + GitConfigOpt func(cfg *gitauth.Config) + // If DB is passed in, the link will be inserted into the DB. + DB database.Store +} + +// setupTest will configure a fake IDP and a gitauth.Config for testing. +// The Fake's userinfo endpoint is used for validating tokens. +// No http servers are started so use the fake IDP's HTTPClient to make requests. +// The returned token is a fully valid token for the IDP. Feel free to manipulate it +// to test different scenarios. +func setupOauth2Test(t *testing.T, settings testConfig) (*oidctest.FakeIDP, *gitauth.Config, database.GitAuthLink) { + t.Helper() + + const providerID = "test-idp" + fake := oidctest.NewFakeIDP(t, + append([]oidctest.FakeIDPOpt{}, settings.FakeIDPOpts...)..., + ) + + config := &gitauth.Config{ + OAuth2Config: fake.OIDCConfig(t, nil, settings.CoderOIDCConfigOpts...), + ID: providerID, + ValidateURL: fake.WellknownConfig().UserInfoURL, + } + settings.GitConfigOpt(config) + + oauthToken, err := fake.GenerateAuthenticatedToken(jwt.MapClaims{ + "email": "test@coder.com", + }) + require.NoError(t, err) + + now := time.Now() + link := database.GitAuthLink{ + ProviderID: providerID, + UserID: uuid.New(), + CreatedAt: now, + UpdatedAt: now, + OAuthAccessToken: oauthToken.AccessToken, + OAuthRefreshToken: oauthToken.RefreshToken, + // The caller can manually expire this if they want. + OAuthExpiry: now.Add(time.Hour), + } + + if settings.DB != nil { + // Feel free to insert additional things like the user, etc if required. + link, err = settings.DB.InsertGitAuthLink(context.Background(), database.InsertGitAuthLinkParams{ + ProviderID: link.ProviderID, + UserID: link.UserID, + CreatedAt: link.CreatedAt, + UpdatedAt: link.UpdatedAt, + OAuthAccessToken: link.OAuthAccessToken, + OAuthRefreshToken: link.OAuthRefreshToken, + OAuthExpiry: link.OAuthExpiry, + }) + require.NoError(t, err, "failed to insert link into DB") + } + + return fake, config, link +} diff --git a/coderd/userauth_test.go b/coderd/userauth_test.go index 1f37a0721a1e7..fe6ded1e901b1 100644 --- a/coderd/userauth_test.go +++ b/coderd/userauth_test.go @@ -37,7 +37,7 @@ func TestOIDCOauthLoginWithExisting(t *testing.T) { t.Parallel() fake := oidctest.NewFakeIDP(t, - oidctest.WithRefreshHook(func(_ string) error { + oidctest.WithRefresh(func(_ string) error { return xerrors.New("refreshing token should never occur") }), oidctest.WithServing(), @@ -797,7 +797,7 @@ func TestUserOIDC(t *testing.T) { t.Run(tc.Name, func(t *testing.T) { t.Parallel() fake := oidctest.NewFakeIDP(t, - oidctest.WithRefreshHook(func(_ string) error { + oidctest.WithRefresh(func(_ string) error { return xerrors.New("refreshing token should never occur") }), oidctest.WithServing(), @@ -851,7 +851,7 @@ func TestUserOIDC(t *testing.T) { auditor := audit.NewMock() fake := oidctest.NewFakeIDP(t, - oidctest.WithRefreshHook(func(_ string) error { + oidctest.WithRefresh(func(_ string) error { return xerrors.New("refreshing token should never occur") }), oidctest.WithServing(), @@ -898,7 +898,7 @@ func TestUserOIDC(t *testing.T) { t.Parallel() auditor := audit.NewMock() fake := oidctest.NewFakeIDP(t, - oidctest.WithRefreshHook(func(_ string) error { + oidctest.WithRefresh(func(_ string) error { return xerrors.New("refreshing token should never occur") }), oidctest.WithServing(), @@ -959,7 +959,7 @@ func TestUserOIDC(t *testing.T) { t.Run("NoIDToken", func(t *testing.T) { t.Parallel() fake := oidctest.NewFakeIDP(t, - oidctest.WithRefreshHook(func(_ string) error { + oidctest.WithRefresh(func(_ string) error { return xerrors.New("refreshing token should never occur") }), oidctest.WithServing(), @@ -984,7 +984,7 @@ func TestUserOIDC(t *testing.T) { badProvider := &oidc.Provider{} fake := oidctest.NewFakeIDP(t, - oidctest.WithRefreshHook(func(_ string) error { + oidctest.WithRefresh(func(_ string) error { return xerrors.New("refreshing token should never occur") }), oidctest.WithServing(), diff --git a/enterprise/coderd/userauth_test.go b/enterprise/coderd/userauth_test.go index 2927ea88b9d1e..9d7e2762f005e 100644 --- a/enterprise/coderd/userauth_test.go +++ b/enterprise/coderd/userauth_test.go @@ -365,7 +365,7 @@ func TestUserOIDC(t *testing.T) { runner := setupOIDCTest(t, oidcTestConfig{ FakeOpts: []oidctest.FakeIDPOpt{ - oidctest.WithRefreshHook(func(_ string) error { + oidctest.WithRefresh(func(_ string) error { // Always "expired" refresh token. return xerrors.New("refresh token is expired") }),