From 212429b33a5a7cd18e1d7ce4f68e0ffe8570b06d Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 31 Aug 2023 15:03:16 -0500 Subject: [PATCH 1/5] chore: fix NoRefresh to honor unlimited tokens - improve testing coverage of gitauth --- coderd/coderdtest/oidctest/idp.go | 22 +++++-- coderd/gitauth/config.go | 24 +++++++- coderd/gitauth/config_test.go | 96 ++++++++++++++++++++++++++++--- 3 files changed, 128 insertions(+), 14 deletions(-) diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index 3ca8cadbc9ff9..2050efdce3ffa 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -41,7 +41,7 @@ import ( type FakeIDP struct { issuer string key *rsa.PrivateKey - provider providerJSON + provider ProviderJSON handler http.Handler cfg *oauth2.Config @@ -181,6 +181,10 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { return idp } +func (f *FakeIDP) WellknownConfig() ProviderJSON { + return f.provider +} + func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) { t.Helper() @@ -188,9 +192,9 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) { require.NoError(t, err, "invalid issuer URL") f.issuer = issuer - // providerJSON is the JSON representation of the OpenID Connect provider + // ProviderJSON is the JSON representation of the OpenID Connect provider // These are all the urls that the IDP will respond to. - f.provider = providerJSON{ + f.provider = ProviderJSON{ Issuer: issuer, AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(), TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(), @@ -220,6 +224,15 @@ func (f *FakeIDP) realServer(t testing.TB) *httptest.Server { return srv } +// GenerateAuthenticatedToken skips all oauth2 flows, and just generates a +// valid token for some given claims. +func (f *FakeIDP) GenerateAuthenticatedToken(claims jwt.MapClaims) (*oauth2.Token, error) { + state := uuid.NewString() + f.stateToIDTokenClaims.Store(state, claims) + code := f.newCode(state) + return f.cfg.Exchange(oidc.ClientContext(context.Background(), f.HTTPClient(nil)), code) +} + // Login does the full OIDC flow starting at the "LoginButton". // The client argument is just to get the URL of the Coder instance. // @@ -333,7 +346,8 @@ func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.Map return resp, nil } -type providerJSON struct { +// ProviderJSON is the .well-known/configuration JSON +type ProviderJSON struct { Issuer string `json:"issuer"` AuthURL string `json:"authorization_endpoint"` TokenURL string `json:"token_endpoint"` diff --git a/coderd/gitauth/config.go b/coderd/gitauth/config.go index 31b0f052fcd9e..f0c51945f0766 100644 --- a/coderd/gitauth/config.go +++ b/coderd/gitauth/config.go @@ -63,13 +63,26 @@ type Config struct { func (c *Config) RefreshToken(ctx context.Context, db database.Store, gitAuthLink database.GitAuthLink) (database.GitAuthLink, bool, error) { // If the token is expired and refresh is disabled, we prompt // the user to authenticate again. - if c.NoRefresh && gitAuthLink.OAuthExpiry.Before(database.Now()) { + if c.NoRefresh && + // If the time is set to 0, then it should never expire. + // This is true for github, which has no expiry. + !gitAuthLink.OAuthExpiry.IsZero() && + gitAuthLink.OAuthExpiry.Before(database.Now()) { return gitAuthLink, false, nil } + // This is additional defensive programming. Because TokenSource is an interface, + // we cannot be sure that the implementation will treat an 'IsZero' time + // as "not-expired". The default implementation does, but a custom implementation + // might not. Removing the refreshToken will guarantee a refresh will fail. + refreshToken := gitAuthLink.OAuthRefreshToken + if c.NoRefresh { + refreshToken = "" + } + token, err := c.TokenSource(ctx, &oauth2.Token{ AccessToken: gitAuthLink.OAuthAccessToken, - RefreshToken: gitAuthLink.OAuthRefreshToken, + RefreshToken: refreshToken, Expiry: gitAuthLink.OAuthExpiry, }).Token() if err != nil { @@ -129,8 +142,13 @@ func (c *Config) ValidateToken(ctx context.Context, token string) (bool, *coders if err != nil { return false, nil, err } + + cli := http.DefaultClient + if v, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok { + cli = v + } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - res, err := http.DefaultClient.Do(req) + res, err := cli.Do(req) if err != nil { return false, nil, err } diff --git a/coderd/gitauth/config_test.go b/coderd/gitauth/config_test.go index bcd650e82ad3a..1b556d6644920 100644 --- a/coderd/gitauth/config_test.go +++ b/coderd/gitauth/config_test.go @@ -8,6 +8,14 @@ import ( "testing" "time" + "github.com/golang-jwt/jwt/v4" + + "github.com/google/uuid" + + "github.com/coreos/go-oidc/v3/oidc" + + "github.com/coder/coder/v2/coderd/coderdtest/oidctest" + "github.com/stretchr/testify/require" "golang.org/x/oauth2" "golang.org/x/xerrors" @@ -22,17 +30,82 @@ import ( func TestRefreshToken(t *testing.T) { t.Parallel() - t.Run("FalseIfNoRefresh", func(t *testing.T) { + const providerID = "test-idp" + expired := time.Now().Add(time.Hour * -1) + t.Run("NoRefreshExpired", func(t *testing.T) { t.Parallel() + + fake := oidctest.NewFakeIDP(t, + // The IDP should not be contacted since the token is expired. An expired + // token with 'NoRefresh' should early abort. + oidctest.WithRefreshHook(func(_ string) error { + t.Error("refresh on the IDP was called, but NoRefresh was set") + return xerrors.New("should not be called") + }), + oidctest.WithDynamicUserInfo(func(_ string) jwt.MapClaims { + t.Error("token was validated, but it was expired and this should never have happened.") + return nil + }), + ) + + ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil)) config := &gitauth.Config{ - NoRefresh: true, + ID: providerID, + OAuth2Config: fake.OIDCConfig(t, nil), + NoRefresh: true, + ValidateURL: fake.WellknownConfig().UserInfoURL, } - _, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{ - OAuthExpiry: time.Time{}, + _, refreshed, err := config.RefreshToken(ctx, nil, database.GitAuthLink{ + ProviderID: providerID, + UserID: uuid.New(), + OAuthAccessToken: uuid.NewString(), + OAuthRefreshToken: uuid.NewString(), + OAuthExpiry: expired, }) require.NoError(t, err) require.False(t, refreshed) }) + t.Run("NoRefreshNoExpiry", func(t *testing.T) { + t.Parallel() + + validated := false + fake := oidctest.NewFakeIDP(t, + // The IDP should not be contacted since the token is expired. An expired + // token with 'NoRefresh' should early abort. + oidctest.WithRefreshHook(func(_ string) error { + t.Error("refresh on the IDP was called, but NoRefresh was set") + return xerrors.New("should not be called") + }), + oidctest.WithDynamicUserInfo(func(_ string) jwt.MapClaims { + validated = true + return jwt.MapClaims{} + }), + ) + + ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil)) + config := &gitauth.Config{ + ID: providerID, + OAuth2Config: fake.OIDCConfig(t, nil), + NoRefresh: true, + ValidateURL: fake.WellknownConfig().UserInfoURL, + } + + token, err := fake.GenerateAuthenticatedToken(jwt.MapClaims{}) + require.NoError(t, err) + + _, refreshed, err := config.RefreshToken(ctx, nil, database.GitAuthLink{ + ProviderID: providerID, + UserID: uuid.New(), + OAuthAccessToken: token.AccessToken, + // Pass a refresh token, but this should be ignored in this test! + OAuthRefreshToken: token.RefreshToken, + // Zero time used + OAuthExpiry: time.Time{}, + }) + require.NoError(t, err) + require.True(t, refreshed, "token without expiry is always valid") + require.True(t, validated, "token should have been validated") + }) t.Run("FalseIfTokenSourceFails", func(t *testing.T) { t.Parallel() config := &gitauth.Config{ @@ -42,7 +115,9 @@ func TestRefreshToken(t *testing.T) { }, }, } - _, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{}) + _, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{ + OAuthExpiry: expired, + }) require.NoError(t, err) require.False(t, refreshed) }) @@ -56,7 +131,9 @@ func TestRefreshToken(t *testing.T) { OAuth2Config: &testutil.OAuth2Config{}, ValidateURL: srv.URL, } - _, _, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{}) + _, _, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{ + OAuthExpiry: expired, + }) require.ErrorContains(t, err, "Failure") }) t.Run("ValidateFailure", func(t *testing.T) { @@ -69,7 +146,9 @@ func TestRefreshToken(t *testing.T) { OAuth2Config: &testutil.OAuth2Config{}, ValidateURL: srv.URL, } - _, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{}) + _, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{ + OAuthExpiry: expired, + }) require.NoError(t, err) require.False(t, refreshed) }) @@ -100,6 +179,7 @@ func TestRefreshToken(t *testing.T) { link := dbgen.GitAuthLink(t, db, database.GitAuthLink{ ProviderID: config.ID, OAuthAccessToken: "initial", + OAuthExpiry: expired, }) _, refreshed, err := config.RefreshToken(context.Background(), db, link) require.NoError(t, err) @@ -124,6 +204,7 @@ func TestRefreshToken(t *testing.T) { } _, valid, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{ OAuthAccessToken: accessToken, + OAuthExpiry: expired, }) require.NoError(t, err) require.True(t, valid) @@ -143,6 +224,7 @@ func TestRefreshToken(t *testing.T) { link := dbgen.GitAuthLink(t, db, database.GitAuthLink{ ProviderID: config.ID, OAuthAccessToken: "initial", + OAuthExpiry: expired, }) _, valid, err := config.RefreshToken(context.Background(), db, link) require.NoError(t, err) From 4637ddb714f2f5495ad594154e9f801d007995e8 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 31 Aug 2023 15:47:27 -0500 Subject: [PATCH 2/5] refactor rest of gitauth tests --- coderd/coderdtest/oidctest/idp.go | 53 +++- coderd/gitauth/config.go | 2 +- coderd/gitauth/config_test.go | 385 +++++++++++++++++++----------- 3 files changed, 284 insertions(+), 156 deletions(-) diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index 2050efdce3ffa..0970b4a10b486 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -7,6 +7,7 @@ import ( "crypto/x509" "encoding/json" "encoding/pem" + "errors" "fmt" "io" "net" @@ -66,7 +67,7 @@ type FakeIDP struct { // IDP -> Application. Almost all IDPs have the concept of // "Authorized Redirect URLs". This can be used to emulate that. hookValidRedirectURL func(redirectURL string) error - hookUserInfo func(email string) jwt.MapClaims + hookUserInfo func(email string) (jwt.MapClaims, error) fakeCoderd func(req *http.Request) (*http.Response, error) hookOnRefresh func(email string) error // Custom authentication for the client. This is useful if you want @@ -75,6 +76,26 @@ type FakeIDP struct { serve bool } +func StatusError(code int, err error) error { + return statusHookError{ + Err: err, + HTTPStatusCode: code, + } +} + +// statusHookError allows a hook to change the returned http status code. +type statusHookError struct { + Err error + HTTPStatusCode int +} + +func (s statusHookError) Error() string { + if s.Err == nil { + return "" + } + return s.Err.Error() +} + type FakeIDPOpt func(idp *FakeIDP) func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeIDP) { @@ -108,13 +129,13 @@ func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) { // every user on the /userinfo endpoint. func WithStaticUserInfo(info jwt.MapClaims) func(*FakeIDP) { return func(f *FakeIDP) { - f.hookUserInfo = func(_ string) jwt.MapClaims { - return info + f.hookUserInfo = func(_ string) (jwt.MapClaims, error) { + return info, nil } } } -func WithDynamicUserInfo(userInfoFunc func(email string) jwt.MapClaims) func(*FakeIDP) { +func WithDynamicUserInfo(userInfoFunc func(email string) (jwt.MapClaims, error)) func(*FakeIDP) { return func(f *FakeIDP) { f.hookUserInfo = userInfoFunc } @@ -160,7 +181,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](), refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](), hookOnRefresh: func(_ string) error { return nil }, - hookUserInfo: func(email string) jwt.MapClaims { return jwt.MapClaims{} }, + hookUserInfo: func(email string) (jwt.MapClaims, error) { return jwt.MapClaims{}, nil }, hookValidRedirectURL: func(redirectURL string) error { return nil }, } @@ -489,7 +510,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { err := f.hookValidRedirectURL(redirectURI) if err != nil { t.Errorf("not authorized redirect_uri by custom hook %q: %s", redirectURI, err.Error()) - http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), http.StatusBadRequest) + http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err)) return } @@ -515,7 +536,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { slog.F("values", values.Encode()), ) if err != nil { - http.Error(rw, fmt.Sprintf("invalid token request: %s", err.Error()), http.StatusBadRequest) + http.Error(rw, fmt.Sprintf("invalid token request: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err)) return } getEmail := func(claims jwt.MapClaims) string { @@ -576,7 +597,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { claims = idTokenClaims err := f.hookOnRefresh(getEmail(claims)) if err != nil { - http.Error(rw, fmt.Sprintf("refresh hook blocked refresh: %s", err.Error()), http.StatusBadRequest) + http.Error(rw, fmt.Sprintf("refresh hook blocked refresh: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err)) return } @@ -624,7 +645,12 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { http.Error(rw, "invalid access token, missing user info", http.StatusBadRequest) return } - _ = json.NewEncoder(rw).Encode(f.hookUserInfo(email)) + claims, err := f.hookUserInfo(email) + if err != nil { + http.Error(rw, fmt.Sprintf("user info hook returned error: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err)) + return + } + _ = json.NewEncoder(rw).Encode(claims) })) mux.Handle(keysPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { @@ -782,6 +808,15 @@ func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *co return cfg } +func httpErrorCode(defaultCode int, err error) int { + var stautsErr statusHookError + var status = defaultCode + if errors.As(err, &stautsErr) { + status = stautsErr.HTTPStatusCode + } + return status +} + type fakeRoundTripper struct { roundTrip func(req *http.Request) (*http.Response, error) } diff --git a/coderd/gitauth/config.go b/coderd/gitauth/config.go index f0c51945f0766..5bb726f531941 100644 --- a/coderd/gitauth/config.go +++ b/coderd/gitauth/config.go @@ -59,7 +59,7 @@ type Config struct { } // RefreshToken automatically refreshes the token if expired and permitted. -// It returns the token and a bool indicating if the token was refreshed. +// It returns the token and a bool indicating if the token is valid. func (c *Config) RefreshToken(ctx context.Context, db database.Store, gitAuthLink database.GitAuthLink) (database.GitAuthLink, bool, error) { // If the token is expired and refresh is disabled, we prompt // the user to authenticate again. diff --git a/coderd/gitauth/config_test.go b/coderd/gitauth/config_test.go index 1b556d6644920..6915dfc6b4933 100644 --- a/coderd/gitauth/config_test.go +++ b/coderd/gitauth/config_test.go @@ -3,26 +3,22 @@ package gitauth_test import ( "context" "net/http" - "net/http/httptest" "net/url" "testing" "time" + "github.com/coreos/go-oidc/v3/oidc" "github.com/golang-jwt/jwt/v4" - "github.com/google/uuid" - - "github.com/coreos/go-oidc/v3/oidc" - - "github.com/coder/coder/v2/coderd/coderdtest/oidctest" - "github.com/stretchr/testify/require" "golang.org/x/oauth2" "golang.org/x/xerrors" + "github.com/coder/coder/v2/coderd" + "github.com/coder/coder/v2/coderd/coderdtest/oidctest" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbfake" - "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/gitauth" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/testutil" @@ -32,80 +28,68 @@ func TestRefreshToken(t *testing.T) { t.Parallel() const providerID = "test-idp" expired := time.Now().Add(time.Hour * -1) + t.Run("NoRefreshExpired", func(t *testing.T) { t.Parallel() - - fake := oidctest.NewFakeIDP(t, - // The IDP should not be contacted since the token is expired. An expired - // token with 'NoRefresh' should early abort. - oidctest.WithRefreshHook(func(_ string) error { - t.Error("refresh on the IDP was called, but NoRefresh was set") - return xerrors.New("should not be called") - }), - oidctest.WithDynamicUserInfo(func(_ string) jwt.MapClaims { - t.Error("token was validated, but it was expired and this should never have happened.") - return nil - }), - ) + fake, config, link := setupOauth2Test(t, testConfig{ + FakeIDPOpts: []oidctest.FakeIDPOpt{ + oidctest.WithRefreshHook(func(_ string) error { + t.Error("refresh on the IDP was called, but NoRefresh was set") + return xerrors.New("should not be called") + }), + // The IDP should not be contacted since the token is expired. An expired + // token with 'NoRefresh' should early abort. + oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) { + t.Error("token was validated, but it was expired and this should never have happened.") + return nil, xerrors.New("should not be called") + }), + }, + GitConfigOpt: func(cfg *gitauth.Config) { + cfg.NoRefresh = true + }, + }) ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil)) - config := &gitauth.Config{ - ID: providerID, - OAuth2Config: fake.OIDCConfig(t, nil), - NoRefresh: true, - ValidateURL: fake.WellknownConfig().UserInfoURL, - } - _, refreshed, err := config.RefreshToken(ctx, nil, database.GitAuthLink{ - ProviderID: providerID, - UserID: uuid.New(), - OAuthAccessToken: uuid.NewString(), - OAuthRefreshToken: uuid.NewString(), - OAuthExpiry: expired, - }) + // Expire the link + link.OAuthExpiry = expired + + _, refreshed, err := config.RefreshToken(ctx, nil, link) require.NoError(t, err) require.False(t, refreshed) }) + + // NoRefreshNoExpiry tests that an oauth token without an expiry is always valid. + // The "validate url" should be hit, but the refresh endpoint should not. t.Run("NoRefreshNoExpiry", func(t *testing.T) { t.Parallel() validated := false - fake := oidctest.NewFakeIDP(t, - // The IDP should not be contacted since the token is expired. An expired - // token with 'NoRefresh' should early abort. - oidctest.WithRefreshHook(func(_ string) error { - t.Error("refresh on the IDP was called, but NoRefresh was set") - return xerrors.New("should not be called") - }), - oidctest.WithDynamicUserInfo(func(_ string) jwt.MapClaims { - validated = true - return jwt.MapClaims{} - }), - ) + fake, config, link := setupOauth2Test(t, testConfig{ + FakeIDPOpts: []oidctest.FakeIDPOpt{ + oidctest.WithRefreshHook(func(_ string) error { + t.Error("refresh on the IDP was called, but NoRefresh was set") + return xerrors.New("should not be called") + }), + oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) { + validated = true + return jwt.MapClaims{}, nil + }), + }, + GitConfigOpt: func(cfg *gitauth.Config) { + cfg.NoRefresh = true + }, + }) ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil)) - config := &gitauth.Config{ - ID: providerID, - OAuth2Config: fake.OIDCConfig(t, nil), - NoRefresh: true, - ValidateURL: fake.WellknownConfig().UserInfoURL, - } - token, err := fake.GenerateAuthenticatedToken(jwt.MapClaims{}) - require.NoError(t, err) - - _, refreshed, err := config.RefreshToken(ctx, nil, database.GitAuthLink{ - ProviderID: providerID, - UserID: uuid.New(), - OAuthAccessToken: token.AccessToken, - // Pass a refresh token, but this should be ignored in this test! - OAuthRefreshToken: token.RefreshToken, - // Zero time used - OAuthExpiry: time.Time{}, - }) + // Zero time used + link.OAuthExpiry = time.Time{} + _, refreshed, err := config.RefreshToken(ctx, nil, link) require.NoError(t, err) require.True(t, refreshed, "token without expiry is always valid") require.True(t, validated, "token should have been validated") }) + t.Run("FalseIfTokenSourceFails", func(t *testing.T) { t.Parallel() config := &gitauth.Config{ @@ -121,114 +105,161 @@ func TestRefreshToken(t *testing.T) { require.NoError(t, err) require.False(t, refreshed) }) + t.Run("ValidateServerError", func(t *testing.T) { t.Parallel() - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("Failure")) - })) - config := &gitauth.Config{ - OAuth2Config: &testutil.OAuth2Config{}, - ValidateURL: srv.URL, - } - _, _, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{ - OAuthExpiry: expired, + + const staticError = "static error" + validated := false + fake, config, link := setupOauth2Test(t, testConfig{ + FakeIDPOpts: []oidctest.FakeIDPOpt{ + oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) { + validated = true + return jwt.MapClaims{}, xerrors.New(staticError) + }), + }, + GitConfigOpt: func(cfg *gitauth.Config) { + }, }) - require.ErrorContains(t, err, "Failure") + + ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil)) + link.OAuthExpiry = expired + + _, _, err := config.RefreshToken(ctx, nil, link) + require.ErrorContains(t, err, staticError) + require.True(t, validated, "token should have been attempted to be validated") }) + + // ValidateFailure tests if the token is no longer valid with a 401 response. t.Run("ValidateFailure", func(t *testing.T) { t.Parallel() - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusUnauthorized) - w.Write([]byte("Not permitted")) - })) - config := &gitauth.Config{ - OAuth2Config: &testutil.OAuth2Config{}, - ValidateURL: srv.URL, - } - _, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{ - OAuthExpiry: expired, + + const staticError = "static error" + validated := false + fake, config, link := setupOauth2Test(t, testConfig{ + FakeIDPOpts: []oidctest.FakeIDPOpt{ + oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) { + validated = true + return jwt.MapClaims{}, oidctest.StatusError(http.StatusUnauthorized, xerrors.New(staticError)) + }), + }, + GitConfigOpt: func(cfg *gitauth.Config) { + }, }) - require.NoError(t, err) + + ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil)) + link.OAuthExpiry = expired + + _, refreshed, err := config.RefreshToken(ctx, nil, link) + require.NoError(t, err, staticError) require.False(t, refreshed) + require.True(t, validated, "token should have been attempted to be validated") }) + t.Run("ValidateRetryGitHub", func(t *testing.T) { t.Parallel() - hit := false - // We need to ensure that the exponential backoff kicks in properly. - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if !hit { - hit = true - w.WriteHeader(http.StatusUnauthorized) - w.Write([]byte("Not permitted")) - return - } - w.WriteHeader(http.StatusOK) - })) - config := &gitauth.Config{ - ID: "test", - OAuth2Config: &testutil.OAuth2Config{ - Token: &oauth2.Token{ - AccessToken: "updated", - }, + + const staticError = "static error" + validateCalls := 0 + fake, config, link := setupOauth2Test(t, testConfig{ + FakeIDPOpts: []oidctest.FakeIDPOpt{ + oidctest.WithRefreshHook(func(_ string) error { + t.Error("refresh on the IDP was called, but the token is not expired") + return xerrors.New("should not be called") + }), + oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) { + validateCalls++ + // Make the first call return a 401, subsequent calls should return a 200. + if validateCalls > 1 { + return jwt.MapClaims{}, nil + } + return jwt.MapClaims{}, oidctest.StatusError(http.StatusUnauthorized, xerrors.New(staticError)) + }), + }, + GitConfigOpt: func(cfg *gitauth.Config) { + cfg.Type = codersdk.GitProviderGitHub }, - ValidateURL: srv.URL, - Type: codersdk.GitProviderGitHub, - } - db := dbfake.New() - link := dbgen.GitAuthLink(t, db, database.GitAuthLink{ - ProviderID: config.ID, - OAuthAccessToken: "initial", - OAuthExpiry: expired, }) - _, refreshed, err := config.RefreshToken(context.Background(), db, link) + + ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil)) + // Unlimited lifetime, this is what GitHub returns tokens as + link.OAuthExpiry = time.Time{} + + _, ok, err := config.RefreshToken(ctx, nil, link) require.NoError(t, err) - require.True(t, refreshed) - require.True(t, hit) + require.True(t, ok) + require.Equal(t, 2, validateCalls, "token should have been attempted to be validated more than once") }) + t.Run("ValidateNoUpdate", func(t *testing.T) { t.Parallel() - validated := make(chan struct{}) - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - close(validated) - })) - accessToken := "testing" - config := &gitauth.Config{ - OAuth2Config: &testutil.OAuth2Config{ - Token: &oauth2.Token{ - AccessToken: accessToken, - }, + + validateCalls := 0 + fake, config, link := setupOauth2Test(t, testConfig{ + FakeIDPOpts: []oidctest.FakeIDPOpt{ + oidctest.WithRefreshHook(func(_ string) error { + t.Error("refresh on the IDP was called, but the token is not expired") + return xerrors.New("should not be called") + }), + oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) { + validateCalls++ + return jwt.MapClaims{}, nil + }), + }, + GitConfigOpt: func(cfg *gitauth.Config) { + cfg.Type = codersdk.GitProviderGitHub }, - ValidateURL: srv.URL, - } - _, valid, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{ - OAuthAccessToken: accessToken, - OAuthExpiry: expired, }) + + ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil)) + + _, ok, err := config.RefreshToken(ctx, nil, link) require.NoError(t, err) - require.True(t, valid) - <-validated + require.True(t, ok) + require.Equal(t, 1, validateCalls, "token is validated") }) + + // A token update comes from a refresh. t.Run("Updates", func(t *testing.T) { t.Parallel() - config := &gitauth.Config{ - ID: "test", - OAuth2Config: &testutil.OAuth2Config{ - Token: &oauth2.Token{ - AccessToken: "updated", - }, - }, - } + db := dbfake.New() - link := dbgen.GitAuthLink(t, db, database.GitAuthLink{ - ProviderID: config.ID, - OAuthAccessToken: "initial", - OAuthExpiry: expired, + validateCalls := 0 + refreshCalls := 0 + fake, config, link := setupOauth2Test(t, testConfig{ + FakeIDPOpts: []oidctest.FakeIDPOpt{ + oidctest.WithRefreshHook(func(_ string) error { + refreshCalls++ + return nil + }), + oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) { + validateCalls++ + return jwt.MapClaims{}, nil + }), + }, + GitConfigOpt: func(cfg *gitauth.Config) { + cfg.Type = codersdk.GitProviderGitHub + }, + DB: db, + }) + + ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil)) + // Force a refresh + link.OAuthExpiry = expired + + updated, ok, err := config.RefreshToken(ctx, db, link) + require.NoError(t, err) + require.True(t, ok) + require.Equal(t, 1, validateCalls, "token is validated") + require.Equal(t, 1, refreshCalls, "token is refreshed") + require.NotEqualf(t, link.OAuthAccessToken, updated.OAuthAccessToken, "token is updated") + //nolint:gocritic // testing + dbLink, err := db.GetGitAuthLink(dbauthz.AsSystemRestricted(context.Background()), database.GetGitAuthLinkParams{ + ProviderID: link.ProviderID, + UserID: link.UserID, }) - _, valid, err := config.RefreshToken(context.Background(), db, link) require.NoError(t, err) - require.True(t, valid) + require.Equal(t, updated.OAuthAccessToken, dbLink.OAuthAccessToken, "token is updated in the DB") }) } @@ -314,3 +345,65 @@ func TestConvertYAML(t *testing.T) { require.Equal(t, "https://auth.com?client_id=id&redirect_uri=%2Fgitauth%2Fgitlab%2Fcallback&response_type=code&scope=read", config[0].AuthCodeURL("")) }) } + +type testConfig struct { + FakeIDPOpts []oidctest.FakeIDPOpt + CoderOIDCConfigOpts []func(cfg *coderd.OIDCConfig) + GitConfigOpt func(cfg *gitauth.Config) + // If DB is passed in, the link will be inserted into the DB. + DB database.Store +} + +// setupTest will configure a fake IDP and a gitauth.Config for testing. +// The Fake's userinfo endpoint is used for validating tokens. +// No http servers are started so use the fake IDP's HTTPClient to make requests. +// The returned token is a fully valid token for the IDP. Feel free to manipulate it +// to test different scenarios. +func setupOauth2Test(t *testing.T, settings testConfig) (*oidctest.FakeIDP, *gitauth.Config, database.GitAuthLink) { + t.Helper() + + const providerID = "test-idp" + fake := oidctest.NewFakeIDP(t, + append([]oidctest.FakeIDPOpt{}, settings.FakeIDPOpts...)..., + ) + + config := &gitauth.Config{ + OAuth2Config: fake.OIDCConfig(t, nil, settings.CoderOIDCConfigOpts...), + ID: providerID, + ValidateURL: fake.WellknownConfig().UserInfoURL, + } + settings.GitConfigOpt(config) + + oauthToken, err := fake.GenerateAuthenticatedToken(jwt.MapClaims{ + "email": "test@coder.com", + }) + require.NoError(t, err) + + now := time.Now() + link := database.GitAuthLink{ + ProviderID: providerID, + UserID: uuid.New(), + CreatedAt: now, + UpdatedAt: now, + OAuthAccessToken: oauthToken.AccessToken, + OAuthRefreshToken: oauthToken.RefreshToken, + // The caller can manually expire this if they want. + OAuthExpiry: now.Add(time.Hour), + } + + if settings.DB != nil { + // Feel free to insert additional things like the user, etc if required. + link, err = settings.DB.InsertGitAuthLink(context.Background(), database.InsertGitAuthLinkParams{ + ProviderID: link.ProviderID, + UserID: link.UserID, + CreatedAt: link.CreatedAt, + UpdatedAt: link.UpdatedAt, + OAuthAccessToken: link.OAuthAccessToken, + OAuthRefreshToken: link.OAuthRefreshToken, + OAuthExpiry: link.OAuthExpiry, + }) + require.NoError(t, err, "failed to insert link into DB") + } + + return fake, config, link +} From 74889002c709b6101c9cd691f66a71f05f889b74 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 31 Aug 2023 15:50:21 -0500 Subject: [PATCH 3/5] Linting --- coderd/coderdtest/oidctest/idp.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index 0970b4a10b486..b2b9c372da308 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -810,7 +810,7 @@ func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *co func httpErrorCode(defaultCode int, err error) int { var stautsErr statusHookError - var status = defaultCode + status := defaultCode if errors.As(err, &stautsErr) { status = stautsErr.HTTPStatusCode } From 399f91f5e02e9f0c8bec9704414c2bf333bd6733 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 5 Sep 2023 08:52:49 -0500 Subject: [PATCH 4/5] Merge issue, fix import --- coderd/gitauth/config.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coderd/gitauth/config.go b/coderd/gitauth/config.go index 490e37f0b46ce..f5810f733181d 100644 --- a/coderd/gitauth/config.go +++ b/coderd/gitauth/config.go @@ -68,7 +68,7 @@ func (c *Config) RefreshToken(ctx context.Context, db database.Store, gitAuthLin // If the time is set to 0, then it should never expire. // This is true for github, which has no expiry. !gitAuthLink.OAuthExpiry.IsZero() && - gitAuthLink.OAuthExpiry.Before(database.Now()) { + gitAuthLink.OAuthExpiry.Before(dbtime.Now()) { return gitAuthLink, false, nil } From 72cf5091adf92883ae694d03c980b04d038c7d1a Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 5 Sep 2023 08:55:00 -0500 Subject: [PATCH 5/5] WithRefreshHook -> WithRefresh --- coderd/coderdtest/oidctest/idp.go | 4 ++-- coderd/gitauth/config_test.go | 10 +++++----- coderd/userauth_test.go | 12 ++++++------ enterprise/coderd/userauth_test.go | 2 +- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index b2b9c372da308..6f060aea2c6b6 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -104,9 +104,9 @@ func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeID } } -// WithRefreshHook is called when a refresh token is used. The email is +// WithRefresh is called when a refresh token is used. The email is // the email of the user that is being refreshed assuming the claims are correct. -func WithRefreshHook(hook func(email string) error) func(*FakeIDP) { +func WithRefresh(hook func(email string) error) func(*FakeIDP) { return func(f *FakeIDP) { f.hookOnRefresh = hook } diff --git a/coderd/gitauth/config_test.go b/coderd/gitauth/config_test.go index 6915dfc6b4933..f6c97a440e3cb 100644 --- a/coderd/gitauth/config_test.go +++ b/coderd/gitauth/config_test.go @@ -33,7 +33,7 @@ func TestRefreshToken(t *testing.T) { t.Parallel() fake, config, link := setupOauth2Test(t, testConfig{ FakeIDPOpts: []oidctest.FakeIDPOpt{ - oidctest.WithRefreshHook(func(_ string) error { + oidctest.WithRefresh(func(_ string) error { t.Error("refresh on the IDP was called, but NoRefresh was set") return xerrors.New("should not be called") }), @@ -66,7 +66,7 @@ func TestRefreshToken(t *testing.T) { validated := false fake, config, link := setupOauth2Test(t, testConfig{ FakeIDPOpts: []oidctest.FakeIDPOpt{ - oidctest.WithRefreshHook(func(_ string) error { + oidctest.WithRefresh(func(_ string) error { t.Error("refresh on the IDP was called, but NoRefresh was set") return xerrors.New("should not be called") }), @@ -163,7 +163,7 @@ func TestRefreshToken(t *testing.T) { validateCalls := 0 fake, config, link := setupOauth2Test(t, testConfig{ FakeIDPOpts: []oidctest.FakeIDPOpt{ - oidctest.WithRefreshHook(func(_ string) error { + oidctest.WithRefresh(func(_ string) error { t.Error("refresh on the IDP was called, but the token is not expired") return xerrors.New("should not be called") }), @@ -197,7 +197,7 @@ func TestRefreshToken(t *testing.T) { validateCalls := 0 fake, config, link := setupOauth2Test(t, testConfig{ FakeIDPOpts: []oidctest.FakeIDPOpt{ - oidctest.WithRefreshHook(func(_ string) error { + oidctest.WithRefresh(func(_ string) error { t.Error("refresh on the IDP was called, but the token is not expired") return xerrors.New("should not be called") }), @@ -228,7 +228,7 @@ func TestRefreshToken(t *testing.T) { refreshCalls := 0 fake, config, link := setupOauth2Test(t, testConfig{ FakeIDPOpts: []oidctest.FakeIDPOpt{ - oidctest.WithRefreshHook(func(_ string) error { + oidctest.WithRefresh(func(_ string) error { refreshCalls++ return nil }), diff --git a/coderd/userauth_test.go b/coderd/userauth_test.go index 1f37a0721a1e7..fe6ded1e901b1 100644 --- a/coderd/userauth_test.go +++ b/coderd/userauth_test.go @@ -37,7 +37,7 @@ func TestOIDCOauthLoginWithExisting(t *testing.T) { t.Parallel() fake := oidctest.NewFakeIDP(t, - oidctest.WithRefreshHook(func(_ string) error { + oidctest.WithRefresh(func(_ string) error { return xerrors.New("refreshing token should never occur") }), oidctest.WithServing(), @@ -797,7 +797,7 @@ func TestUserOIDC(t *testing.T) { t.Run(tc.Name, func(t *testing.T) { t.Parallel() fake := oidctest.NewFakeIDP(t, - oidctest.WithRefreshHook(func(_ string) error { + oidctest.WithRefresh(func(_ string) error { return xerrors.New("refreshing token should never occur") }), oidctest.WithServing(), @@ -851,7 +851,7 @@ func TestUserOIDC(t *testing.T) { auditor := audit.NewMock() fake := oidctest.NewFakeIDP(t, - oidctest.WithRefreshHook(func(_ string) error { + oidctest.WithRefresh(func(_ string) error { return xerrors.New("refreshing token should never occur") }), oidctest.WithServing(), @@ -898,7 +898,7 @@ func TestUserOIDC(t *testing.T) { t.Parallel() auditor := audit.NewMock() fake := oidctest.NewFakeIDP(t, - oidctest.WithRefreshHook(func(_ string) error { + oidctest.WithRefresh(func(_ string) error { return xerrors.New("refreshing token should never occur") }), oidctest.WithServing(), @@ -959,7 +959,7 @@ func TestUserOIDC(t *testing.T) { t.Run("NoIDToken", func(t *testing.T) { t.Parallel() fake := oidctest.NewFakeIDP(t, - oidctest.WithRefreshHook(func(_ string) error { + oidctest.WithRefresh(func(_ string) error { return xerrors.New("refreshing token should never occur") }), oidctest.WithServing(), @@ -984,7 +984,7 @@ func TestUserOIDC(t *testing.T) { badProvider := &oidc.Provider{} fake := oidctest.NewFakeIDP(t, - oidctest.WithRefreshHook(func(_ string) error { + oidctest.WithRefresh(func(_ string) error { return xerrors.New("refreshing token should never occur") }), oidctest.WithServing(), diff --git a/enterprise/coderd/userauth_test.go b/enterprise/coderd/userauth_test.go index 2927ea88b9d1e..9d7e2762f005e 100644 --- a/enterprise/coderd/userauth_test.go +++ b/enterprise/coderd/userauth_test.go @@ -365,7 +365,7 @@ func TestUserOIDC(t *testing.T) { runner := setupOIDCTest(t, oidcTestConfig{ FakeOpts: []oidctest.FakeIDPOpt{ - oidctest.WithRefreshHook(func(_ string) error { + oidctest.WithRefresh(func(_ string) error { // Always "expired" refresh token. return xerrors.New("refresh token is expired") }),