From 46c2c62d1dd05fb5ea807fa79ce015a4b8ddeafd Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 23 Aug 2023 15:04:04 -0500 Subject: [PATCH 01/18] test: implement fake OIDC provider with full functionality --- coderd/coderdtest/oidctest/idp.go | 445 +++++++++++++++++++++++++ coderd/coderdtest/oidctest/idp_test.go | 65 ++++ 2 files changed, 510 insertions(+) create mode 100644 coderd/coderdtest/oidctest/idp.go create mode 100644 coderd/coderdtest/oidctest/idp_test.go diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go new file mode 100644 index 0000000000000..d18904db09f0d --- /dev/null +++ b/coderd/coderdtest/oidctest/idp.go @@ -0,0 +1,445 @@ +package oidctest + +import ( + "context" + "crypto" + "crypto/rsa" + "crypto/x509" + "encoding/json" + "encoding/pem" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync" + "testing" + "time" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coreos/go-oidc/v3/oidc" + "github.com/go-chi/chi/v5" + "github.com/go-jose/go-jose/v3" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd" +) + +type FakeIDP struct { + issuer string + key *rsa.PrivateKey + provider providerJSON + handler http.Handler + + // clientID to be used by coderd + clientID string + clientSecret string + logger slog.Logger + + codeToStateMap sync.Map + accessTokens sync.Map + refreshTokens sync.Map + + // hooks + hookUserInfo func(token string) map[string]string +} + +func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) { + return func(f *FakeIDP) { + f.logger = slogtest.Make(t, options) + } +} + +func WithUserInfoHook(uf func(token string) map[string]string) func(*FakeIDP) { + return func(f *FakeIDP) { + f.hookUserInfo = uf + } +} + +func WithIssuer(issuer string) func(*FakeIDP) { + return func(f *FakeIDP) { + f.issuer = issuer + } +} + +const ( + authorizePath = "/oauth2/authorize" + tokenPath = "/oauth2/token" + keysPath = "/oauth2/keys" + userInfoPath = "/oauth2/userinfo" +) + +func NewFakeIDP(t testing.TB, opts ...func(idp *FakeIDP)) *FakeIDP { + t.Helper() + + block, _ := pem.Decode([]byte(testRSAPrivateKey)) + pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + require.NoError(t, err) + + idp := &FakeIDP{ + key: pkey, + clientID: uuid.NewString(), + clientSecret: uuid.NewString(), + logger: slog.Make(), + codeToStateMap: sync.Map{}, + accessTokens: sync.Map{}, + refreshTokens: sync.Map{}, + hookUserInfo: func(token string) map[string]string { + return map[string]string{} + }, + } + + for _, opt := range opts { + opt(idp) + } + + if idp.issuer == "" { + idp.issuer = "https://coder.com" + } + + u, err := url.Parse(idp.issuer) + require.NoError(t, err, "invalid issuer URL") + + // providerJSON is the JSON representation of the OpenID Connect provider + // These are all the urls that the IDP will respond to. + idp.provider = providerJSON{ + Issuer: idp.issuer, + AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(), + TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(), + JWKSURL: u.ResolveReference(&url.URL{Path: keysPath}).String(), + UserInfoURL: u.ResolveReference(&url.URL{Path: userInfoPath}).String(), + Algorithms: []string{ + "RS256", + }, + } + idp.handler = idp.httpHandler(t) + + return idp +} + +func OIDCCallback(t testing.TB, cfg *coderd.OIDCConfig, cli *http.Client, state string) (*http.Response, error) { + t.Helper() + + url := cfg.AuthCodeURL(state) + req, err := http.NewRequest("GET", url, nil) + require.NoError(t, err) + + resp, err := cli.Do(req.WithContext(context.Background())) + require.NoError(t, err) + + return resp, nil +} + +type providerJSON struct { + Issuer string `json:"issuer"` + AuthURL string `json:"authorization_endpoint"` + TokenURL string `json:"token_endpoint"` + JWKSURL string `json:"jwks_uri"` + UserInfoURL string `json:"userinfo_endpoint"` + Algorithms []string `json:"id_token_signing_alg_values_supported"` +} + +// newCode enforces the code exchanged is actually a valid code +// created by the IDP. +func (f *FakeIDP) newCode(state string) string { + code := uuid.NewString() + f.codeToStateMap.Store(code, state) + return code +} + +// newToken enforces the access token exchanged is actually a valid access token +// created by the IDP. +func (f *FakeIDP) newToken(exp time.Time) string { + accessToken := uuid.NewString() + f.accessTokens.Store(accessToken, exp) + return accessToken +} + +func (f *FakeIDP) newRefreshTokens(exp time.Time) string { + refreshToken := uuid.NewString() + f.refreshTokens.Store(refreshToken, exp) + return refreshToken +} + +func (f *FakeIDP) authenticateBearerTokenRequest(t testing.TB, req *http.Request) (string, error) { + t.Helper() + + auth := req.Header.Get("Authorization") + token := strings.TrimPrefix(auth, "Bearer ") + _, ok := f.accessTokens.Load(token) + if !ok { + return "", xerrors.New("invalid access token") + } + return token, nil +} + +func (f *FakeIDP) authenticateOIDClientRequest(t testing.TB, req *http.Request) (url.Values, error) { + t.Helper() + + data, _ := io.ReadAll(req.Body) + values, err := url.ParseQuery(string(data)) + if !assert.NoError(t, err, "parse token request values") { + return nil, xerrors.New("invalid token request") + + } + + if !assert.Equal(t, f.clientID, values.Get("client_id"), "client_id mismatch") { + return nil, xerrors.New("client_id mismatch") + } + + if !assert.Equal(t, f.clientSecret, values.Get("client_secret"), "client_secret mismatch") { + return nil, xerrors.New("client_secret mismatch") + } + + return values, nil +} + +func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { + t.Helper() + + mux := chi.NewMux() + // This endpoint is required to initialize the OIDC provider. + // It is used to get the OIDC configuration. + mux.Get("/.well-known/openid-configuration", func(rw http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(rw).Encode(f.provider) + }) + + // Authorize is called when the user is redirected to the IDP to login. + // This is the browser hitting the IDP and the user logging into Google or + // w/e and clicking "Allow". They will be redirected back to the redirect + // when this is done. + mux.Handle(authorizePath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + f.logger.Info(r.Context(), "HTTP Call Authorize", slog.F("url", string(r.URL.String()))) + + clientID := r.URL.Query().Get("client_id") + if clientID != f.clientID { + t.Errorf("unexpected client_id %q", clientID) + http.Error(rw, "invalid client_id", http.StatusBadRequest) + } + + redirectURI := r.URL.Query().Get("redirect_uri") + state := r.URL.Query().Get("state") + + scope := r.URL.Query().Get("scope") + var _ = scope + + responseType := r.URL.Query().Get("response_type") + switch responseType { + case "code": + case "token": + t.Errorf("response_type %q not supported", responseType) + http.Error(rw, "invalid response_type", http.StatusBadRequest) + return + default: + t.Errorf("unexpected response_type %q", responseType) + http.Error(rw, "invalid response_type", http.StatusBadRequest) + return + } + + ru, err := url.Parse(redirectURI) + if err != nil { + t.Errorf("invalid redirect_uri %q", redirectURI) + http.Error(rw, "invalid redirect_uri", http.StatusBadRequest) + return + } + + q := ru.Query() + q.Set("state", state) + q.Set("code", f.newCode(state)) + ru.RawQuery = q.Encode() + + http.Redirect(rw, r, ru.String(), http.StatusTemporaryRedirect) + })) + + mux.Handle(tokenPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + values, err := f.authenticateOIDClientRequest(t, r) + f.logger.Info(r.Context(), "HTTP Call Token", + slog.Error(err), + slog.F("values", values.Encode()), + ) + if err != nil { + http.Error(rw, fmt.Sprintf("invalid token request: %s", err.Error()), http.StatusBadRequest) + return + } + + switch values.Get("grant_type") { + case "authorization_code": + code := values.Get("code") + if !assert.NotEmpty(t, code, "code is empty") { + http.Error(rw, "invalid code", http.StatusBadRequest) + return + } + _, ok := f.codeToStateMap.Load(code) + if !assert.True(t, ok, "invalid code") { + http.Error(rw, "invalid code", http.StatusBadRequest) + return + } + // Always invalidate the code after it is used. + f.codeToStateMap.Delete(code) + case "refresh_token": + refreshToken := values.Get("refresh_token") + if !assert.NotEmpty(t, refreshToken, "refresh_token is empty") { + http.Error(rw, "invalid refresh_token", http.StatusBadRequest) + return + } + + _, ok := f.refreshTokens.Load(refreshToken) + if !assert.True(t, ok, "invalid refresh_token") { + http.Error(rw, "invalid refresh_token", http.StatusBadRequest) + return + } + // Always invalidate the refresh token after it is used. + f.refreshTokens.Delete(refreshToken) + default: + t.Errorf("unexpected grant_type %q", values.Get("grant_type")) + http.Error(rw, "invalid grant_type", http.StatusBadRequest) + return + } + + exp := time.Now().Add(time.Minute * 5) + token := oauth2.Token{ + // Sometimes the access token is a jwt. Not going to do that here. + AccessToken: f.newToken(time.Now().Add(time.Minute * 5)), + RefreshToken: f.newRefreshTokens(time.Now().Add(time.Minute * 30)), + TokenType: "Bearer", + Expiry: exp, + } + + rw.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(rw).Encode(token) + })) + + mux.Handle(userInfoPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + token, err := f.authenticateBearerTokenRequest(t, r) + f.logger.Info(r.Context(), "HTTP Call UserInfo", + slog.Error(err), + ) + if err != nil { + http.Error(rw, fmt.Sprintf("invalid user info request: %s", err.Error()), http.StatusBadRequest) + return + } + + _ = json.NewEncoder(rw).Encode(f.hookUserInfo(token)) + })) + + mux.Handle(keysPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + f.logger.Info(r.Context(), "HTTP Call Keys") + set := jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + { + Key: f.key.Public(), + KeyID: "test-key", + Algorithm: "RSA", + }, + }, + } + _ = json.NewEncoder(rw).Encode(set) + })) + + mux.NotFound(func(rw http.ResponseWriter, r *http.Request) { + f.logger.Error(r.Context(), "HTTP Call NotFound", slog.F("path", r.URL.Path)) + t.Errorf("unexpected request to IDP at path %q. Not supported", r.URL.Path) + }) + + return mux +} + +// HTTPClient runs the IDP in memory and returns an http.Client that can be used +// to make requests to the IDP. All requests are handled in memory, and no network +// requests are made. +// +// If a request is not to the IDP, then the passed in client will be used. +// If no client is passed in, then any regular network requests will fail. +func (f *FakeIDP) HTTPClient(rest *http.Client) *http.Client { + return &http.Client{ + Transport: fakeRoundTripper{ + roundTrip: func(req *http.Request) (*http.Response, error) { + u, _ := url.Parse(f.issuer) + if req.URL.Host != u.Host { + if rest == nil { + return nil, fmt.Errorf("unexpected request to %q", req.URL.Host) + } + return rest.Do(req) + } + resp := httptest.NewRecorder() + f.handler.ServeHTTP(resp, req) + return resp.Result(), nil + }, + }, + } +} + +func (f *FakeIDP) OIDCConfig(t testing.TB, redirect string, scopes []string, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig { + t.Helper() + if len(scopes) == 0 { + scopes = []string{"openid", "email", "profile"} + } + + oauthCfg := &oauth2.Config{ + ClientID: f.clientID, + ClientSecret: f.clientSecret, + Endpoint: oauth2.Endpoint{ + AuthURL: f.provider.AuthURL, + TokenURL: f.provider.TokenURL, + AuthStyle: oauth2.AuthStyleInParams, + }, + RedirectURL: redirect, + Scopes: scopes, + } + + ctx := oidc.ClientContext(context.Background(), f.HTTPClient(nil)) + p, err := oidc.NewProvider(ctx, f.provider.Issuer) + require.NoError(t, err, "failed to create OIDC provider") + cfg := &coderd.OIDCConfig{ + OAuth2Config: oauthCfg, + Provider: p, + Verifier: oidc.NewVerifier(f.provider.Issuer, &oidc.StaticKeySet{ + PublicKeys: []crypto.PublicKey{f.key.Public()}, + }, &oidc.Config{ + ClientID: oauthCfg.ClientID, + SupportedSigningAlgs: []string{ + "RS256", + }, + // Todo: add support for Now() + }), + UsernameField: "preferred_username", + EmailField: "email", + AuthURLParams: map[string]string{"access_type": "offline"}, + } + + for _, opt := range opts { + opt(cfg) + } + + return cfg +} + +type fakeRoundTripper struct { + roundTrip func(req *http.Request) (*http.Response, error) +} + +func (f fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return f.roundTrip(req) +} + +const testRSAPrivateKey = `-----BEGIN RSA PRIVATE KEY----- +MIICXQIBAAKBgQDLets8+7M+iAQAqN/5BVyCIjhTQ4cmXulL+gm3v0oGMWzLupUS +v8KPA+Tp7dgC/DZPfMLaNH1obBBhJ9DhS6RdS3AS3kzeFrdu8zFHLWF53DUBhS92 +5dCAEuJpDnNizdEhxTfoHrhuCmz8l2nt1pe5eUK2XWgd08Uc93h5ij098wIDAQAB +AoGAHLaZeWGLSaen6O/rqxg2laZ+jEFbMO7zvOTruiIkL/uJfrY1kw+8RLIn+1q0 +wLcWcuEIHgKKL9IP/aXAtAoYh1FBvRPLkovF1NZB0Je/+CSGka6wvc3TGdvppZJe +rKNcUvuOYLxkmLy4g9zuY5qrxFyhtIn2qZzXEtLaVOHzPQECQQDvN0mSajpU7dTB +w4jwx7IRXGSSx65c+AsHSc1Rj++9qtPC6WsFgAfFN2CEmqhMbEUVGPv/aPjdyWk9 +pyLE9xR/AkEA2cGwyIunijE5v2rlZAD7C4vRgdcMyCf3uuPcgzFtsR6ZhyQSgLZ8 +YRPuvwm4cdPJMmO3YwBfxT6XGuSc2k8MjQJBAI0+b8prvpV2+DCQa8L/pjxp+VhR +Xrq2GozrHrgR7NRokTB88hwFRJFF6U9iogy9wOx8HA7qxEbwLZuhm/4AhbECQC2a +d8h4Ht09E+f3nhTEc87mODkl7WJZpHL6V2sORfeq/eIkds+H6CJ4hy5w/bSw8tjf +sz9Di8sGIaUbLZI2rd0CQQCzlVwEtRtoNCyMJTTrkgUuNufLP19RZ5FpyXxBO5/u +QastnN77KfUwdj3SJt44U/uh1jAIv4oSLBr8HYUkbnI8 +-----END RSA PRIVATE KEY-----` diff --git a/coderd/coderdtest/oidctest/idp_test.go b/coderd/coderdtest/oidctest/idp_test.go new file mode 100644 index 0000000000000..b7de8bd9dbe6d --- /dev/null +++ b/coderd/coderdtest/oidctest/idp_test.go @@ -0,0 +1,65 @@ +package oidctest_test + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/coder/coder/v2/coderd/coderdtest/oidctest" + "github.com/coreos/go-oidc/v3/oidc" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +// TestFakeIDPBasicFlow tests the basic flow of the fake IDP. +func TestFakeIDPBasicFlow(t *testing.T) { + fake := oidctest.NewFakeIDP(t, oidctest.WithLogging(t, nil)) + + var handler http.Handler + srv := httptest.NewServer(http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler.ServeHTTP(w, r) + }))) + defer srv.Close() + + cfg := fake.OIDCConfig(t, srv.URL, nil) + cli := fake.HTTPClient(nil) + ctx := oidc.ClientContext(context.Background(), cli) + + const expectedState = "random-state" + var token *oauth2.Token + // This is the Coder callback using an actual network request. + handler = http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Emulate OIDC flow + code := r.URL.Query().Get("code") + state := r.URL.Query().Get("state") + assert.Equal(t, expectedState, state, "state mismatch") + + oauthToken, err := cfg.Exchange(ctx, code) + if assert.NoError(t, err, "failed to exchange code") { + assert.NotEmpty(t, oauthToken.AccessToken, "access token is empty") + assert.NotEmpty(t, oauthToken.RefreshToken, "refresh token is empty") + } + token = oauthToken + })) + + resp, err := oidctest.OIDCCallback(t, cfg, fake.HTTPClient(srv.Client()), expectedState) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Test the user info + _, err = cfg.Provider.UserInfo(ctx, oauth2.StaticTokenSource(token)) + require.NoError(t, err) + + // Now test it can refresh + refreshed, err := cfg.TokenSource(ctx, &oauth2.Token{ + AccessToken: token.AccessToken, + RefreshToken: token.RefreshToken, + Expiry: time.Now().Add(time.Minute * -1), + }).Token() + require.NoError(t, err, "failed to refresh token") + require.NotEmpty(t, refreshed.AccessToken, "access token is empty on refresh") +} From 7d16bc1a4f0dc2c478d629c3e66960357e6983ab Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 24 Aug 2023 10:58:03 -0500 Subject: [PATCH 02/18] full idp fake for oidc Begin work on unit testing refactor --- coderd/coderdtest/coderdtest.go | 2 +- coderd/coderdtest/oidctest/idp.go | 307 +++++++++++++++++++++---- coderd/coderdtest/oidctest/idp_test.go | 14 +- coderd/coderdtest/oidctest/map.go | 71 ++++++ enterprise/coderd/userauth_test.go | 256 +++++++++++++++------ 5 files changed, 527 insertions(+), 123 deletions(-) create mode 100644 coderd/coderdtest/oidctest/map.go diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index bc2cb5e5925a0..b12cb3812ccff 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -33,7 +33,7 @@ import ( "cloud.google.com/go/compute/metadata" "github.com/coreos/go-oidc/v3/oidc" "github.com/fullsailor/pkcs7" - "github.com/golang-jwt/jwt" + "github.com/golang-jwt/jwt/v4" "github.com/google/uuid" "github.com/moby/moby/pkg/namesgenerator" "github.com/prometheus/client_golang/prometheus" diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index d18904db09f0d..b925811c5d6d6 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -9,25 +9,29 @@ import ( "encoding/pem" "fmt" "io" + "net" "net/http" + "net/http/cookiejar" "net/http/httptest" "net/url" "strings" - "sync" "testing" "time" - "cdr.dev/slog" - "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/codersdk" + "github.com/coreos/go-oidc/v3/oidc" "github.com/go-chi/chi/v5" "github.com/go-jose/go-jose/v3" + "github.com/golang-jwt/jwt/v4" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" "golang.org/x/xerrors" + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/coderd" ) @@ -36,18 +40,30 @@ type FakeIDP struct { key *rsa.PrivateKey provider providerJSON handler http.Handler + cfg *oauth2.Config // clientID to be used by coderd clientID string clientSecret string logger slog.Logger - codeToStateMap sync.Map - accessTokens sync.Map - refreshTokens sync.Map + codeToStateMap *SyncMap[string, string] + // Token -> Email + accessTokens *SyncMap[string, string] + // Refresh Token -> Email + refreshTokensUsed *SyncMap[string, bool] + refreshTokens *SyncMap[string, string] + stateToIDTokenClaims *SyncMap[string, jwt.MapClaims] + refreshIDTokenClaims *SyncMap[string, jwt.MapClaims] // hooks - hookUserInfo func(token string) map[string]string + hookUserInfo func(email string) jwt.MapClaims + hookIDTokenClaims jwt.MapClaims + fakeCoderd func(req *http.Request) (*http.Response, error) + // Optional if you want to use a real http network request assuming + // it is not directed to the IDP. + defaultClient *http.Client + serve bool } func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) { @@ -56,15 +72,23 @@ func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) { } } -func WithUserInfoHook(uf func(token string) map[string]string) func(*FakeIDP) { +func WithStaticUserInfo(info jwt.MapClaims) func(*FakeIDP) { return func(f *FakeIDP) { - f.hookUserInfo = uf + f.hookUserInfo = func(_ string) jwt.MapClaims { + return info + } } } -func WithIssuer(issuer string) func(*FakeIDP) { +func WithDynamicUserInfo(userInfoFunc func(email string) jwt.MapClaims) func(*FakeIDP) { return func(f *FakeIDP) { - f.issuer = issuer + f.hookUserInfo = userInfoFunc + } +} + +func WithServing() func(*FakeIDP) { + return func(f *FakeIDP) { + f.serve = true } } @@ -83,17 +107,19 @@ func NewFakeIDP(t testing.TB, opts ...func(idp *FakeIDP)) *FakeIDP { require.NoError(t, err) idp := &FakeIDP{ - key: pkey, - clientID: uuid.NewString(), - clientSecret: uuid.NewString(), - logger: slog.Make(), - codeToStateMap: sync.Map{}, - accessTokens: sync.Map{}, - refreshTokens: sync.Map{}, - hookUserInfo: func(token string) map[string]string { - return map[string]string{} - }, + key: pkey, + clientID: uuid.NewString(), + clientSecret: uuid.NewString(), + logger: slog.Make(), + codeToStateMap: NewSyncMap[string, string](), + accessTokens: NewSyncMap[string, string](), + refreshTokens: NewSyncMap[string, string](), + refreshTokensUsed: NewSyncMap[string, bool](), + stateToIDTokenClaims: NewSyncMap[string, jwt.MapClaims](), + refreshIDTokenClaims: NewSyncMap[string, jwt.MapClaims](), + hookUserInfo: func(email string) jwt.MapClaims { return jwt.MapClaims{} }, } + idp.handler = idp.httpHandler(t) for _, opt := range opts { opt(idp) @@ -103,13 +129,25 @@ func NewFakeIDP(t testing.TB, opts ...func(idp *FakeIDP)) *FakeIDP { idp.issuer = "https://coder.com" } - u, err := url.Parse(idp.issuer) + idp.updateIssuerURL(t, idp.issuer) + if idp.serve { + idp.Serve(t) + } + + return idp +} + +func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) { + t.Helper() + + u, err := url.Parse(issuer) require.NoError(t, err, "invalid issuer URL") + f.issuer = issuer // providerJSON is the JSON representation of the OpenID Connect provider // These are all the urls that the IDP will respond to. - idp.provider = providerJSON{ - Issuer: idp.issuer, + f.provider = providerJSON{ + Issuer: issuer, AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(), TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(), JWKSURL: u.ResolveReference(&url.URL{Path: keysPath}).String(), @@ -118,21 +156,103 @@ func NewFakeIDP(t testing.TB, opts ...func(idp *FakeIDP)) *FakeIDP { "RS256", }, } - idp.handler = idp.httpHandler(t) +} - return idp +// Serve is optional, but turns the FakeIDP into a real http server. +func (f *FakeIDP) Serve(t testing.TB) *httptest.Server { + t.Helper() + + ctx, cancel := context.WithCancel(context.Background()) + srv := httptest.NewUnstartedServer(f.handler) + srv.Config.BaseContext = func(_ net.Listener) context.Context { + return ctx + } + srv.Start() + t.Cleanup(srv.CloseClientConnections) + t.Cleanup(srv.Close) + t.Cleanup(cancel) + + f.updateIssuerURL(t, srv.URL) + return srv } -func OIDCCallback(t testing.TB, cfg *coderd.OIDCConfig, cli *http.Client, state string) (*http.Response, error) { +// LoginClient does the full OIDC flow starting at the "LoginButton". +// The client argument is just to get the URL of the Coder instance. +func (f *FakeIDP) LoginClient(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims) (*codersdk.Client, *http.Response) { t.Helper() - url := cfg.AuthCodeURL(state) - req, err := http.NewRequest("GET", url, nil) + coderOauthURL, err := client.URL.Parse("/api/v2/users/oidc/callback") + require.NoError(t, err) + f.SetRedirect(t, coderOauthURL.String()) + + cli := f.HTTPClient(client.HTTPClient) + shallowCpyCli := &(*cli) + shallowCpyCli.CheckRedirect = func(req *http.Request, via []*http.Request) error { + // Store the idTokenClaims to the specific state request. This ties + // the claims 1:1 with a given authentication flow. + state := req.URL.Query().Get("state") + f.stateToIDTokenClaims.Store(state, idTokenClaims) + return nil + } + + req, err := http.NewRequestWithContext(context.Background(), "GET", coderOauthURL.String(), nil) + require.NoError(t, err) + if shallowCpyCli.Jar == nil { + shallowCpyCli.Jar, err = cookiejar.New(nil) + require.NoError(t, err, "failed to create cookie jar") + } + + res, err := shallowCpyCli.Do(req) + require.NoError(t, err) + + // If the coder session token exists, return the new authed client! + var user *codersdk.Client + cookies := shallowCpyCli.Jar.Cookies(client.URL) + for _, cookie := range cookies { + if cookie.Name == codersdk.SessionTokenCookie { + user = codersdk.New(client.URL) + user.SetSessionToken(cookie.Value) + } + } + + t.Cleanup(func() { + if res.Body != nil { + res.Body.Close() + } + }) + return user, res +} + +// OIDCCallback will emulate the IDP redirecting back to the Coder callback. +// This is helpful if no Coderd exists. +func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.MapClaims) (*http.Response, error) { + t.Helper() + f.stateToIDTokenClaims.Store(state, idTokenClaims) + + baseCli := http.DefaultClient + if f.fakeCoderd != nil { + baseCli = &http.Client{ + Transport: fakeRoundTripper{ + roundTrip: func(req *http.Request) (*http.Response, error) { + return f.fakeCoderd(req) + }, + }, + } + } + + cli := f.HTTPClient(baseCli) + u := f.cfg.AuthCodeURL(state) + req, err := http.NewRequest("GET", u, nil) require.NoError(t, err) resp, err := cli.Do(req.WithContext(context.Background())) require.NoError(t, err) + t.Cleanup(func() { + if resp.Body != nil { + resp.Body.Close() + } + }) return resp, nil } @@ -155,15 +275,15 @@ func (f *FakeIDP) newCode(state string) string { // newToken enforces the access token exchanged is actually a valid access token // created by the IDP. -func (f *FakeIDP) newToken(exp time.Time) string { +func (f *FakeIDP) newToken(email string) string { accessToken := uuid.NewString() - f.accessTokens.Store(accessToken, exp) + f.accessTokens.Store(accessToken, email) return accessToken } -func (f *FakeIDP) newRefreshTokens(exp time.Time) string { +func (f *FakeIDP) newRefreshTokens(email string) string { refreshToken := uuid.NewString() - f.refreshTokens.Store(refreshToken, exp) + f.refreshTokens.Store(refreshToken, email) return refreshToken } @@ -200,6 +320,27 @@ func (f *FakeIDP) authenticateOIDClientRequest(t testing.TB, req *http.Request) return values, nil } +func (f *FakeIDP) encodeClaims(t testing.TB, claims jwt.MapClaims) string { + t.Helper() + + if _, ok := claims["exp"]; !ok { + claims["exp"] = time.Now().Add(time.Hour).UnixMilli() + } + + if _, ok := claims["aud"]; !ok { + claims["aud"] = f.clientID + } + + if _, ok := claims["iss"]; !ok { + claims["iss"] = f.issuer + } + + signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(f.key) + require.NoError(t, err) + + return signed +} + func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { t.Helper() @@ -268,6 +409,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { return } + var claims jwt.MapClaims switch values.Get("grant_type") { case "authorization_code": code := values.Get("code") @@ -275,13 +417,21 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { http.Error(rw, "invalid code", http.StatusBadRequest) return } - _, ok := f.codeToStateMap.Load(code) + stateStr, ok := f.codeToStateMap.Load(code) if !assert.True(t, ok, "invalid code") { http.Error(rw, "invalid code", http.StatusBadRequest) return } // Always invalidate the code after it is used. f.codeToStateMap.Delete(code) + + idTokenClaims, ok := f.stateToIDTokenClaims.Load(stateStr) + if !ok { + t.Errorf("missing id token claims") + http.Error(rw, "missing id token claims", http.StatusBadRequest) + return + } + claims = idTokenClaims case "refresh_token": refreshToken := values.Get("refresh_token") if !assert.NotEmpty(t, refreshToken, "refresh_token is empty") { @@ -296,6 +446,15 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { } // Always invalidate the refresh token after it is used. f.refreshTokens.Delete(refreshToken) + + idTokenClaims, ok := f.refreshIDTokenClaims.Load(refreshToken) + if !ok { + t.Errorf("missing id token claims in refresh") + http.Error(rw, "missing id token claims in refresh", http.StatusBadRequest) + return + } + claims = idTokenClaims + f.refreshTokensUsed.Store(refreshToken, true) default: t.Errorf("unexpected grant_type %q", values.Get("grant_type")) http.Error(rw, "invalid grant_type", http.StatusBadRequest) @@ -303,13 +462,22 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { } exp := time.Now().Add(time.Minute * 5) - token := oauth2.Token{ - // Sometimes the access token is a jwt. Not going to do that here. - AccessToken: f.newToken(time.Now().Add(time.Minute * 5)), - RefreshToken: f.newRefreshTokens(time.Now().Add(time.Minute * 30)), - TokenType: "Bearer", - Expiry: exp, + claims["exp"] = exp.UnixMilli() + email, ok := claims["email"] + if !ok || email.(string) == "" { + email = "unknown" + } + refreshToken := f.newRefreshTokens(email.(string)) + token := map[string]interface{}{ + "access_token": f.newToken(email.(string)), + "refresh_token": refreshToken, + "token_type": "Bearer", + "expires_in": int64(time.Minute * 5), + "expiry": exp.Unix(), + "id_token": f.encodeClaims(t, claims), } + // Store the claims for the next refresh + f.refreshIDTokenClaims.Store(refreshToken, claims) rw.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(rw).Encode(token) @@ -324,8 +492,15 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { http.Error(rw, fmt.Sprintf("invalid user info request: %s", err.Error()), http.StatusBadRequest) return } + var _ = token - _ = json.NewEncoder(rw).Encode(f.hookUserInfo(token)) + email, ok := f.accessTokens.Load(token) + if !ok { + t.Errorf("access token user for user_info has no email to indicate which user") + http.Error(rw, "invalid access token, missing user info", http.StatusBadRequest) + return + } + _ = json.NewEncoder(rw).Encode(f.hookUserInfo(email)) })) mux.Handle(keysPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { @@ -357,15 +532,21 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { // If a request is not to the IDP, then the passed in client will be used. // If no client is passed in, then any regular network requests will fail. func (f *FakeIDP) HTTPClient(rest *http.Client) *http.Client { + if f.serve { + if rest == nil { + return http.DefaultClient + } + return rest + } return &http.Client{ Transport: fakeRoundTripper{ roundTrip: func(req *http.Request) (*http.Response, error) { u, _ := url.Parse(f.issuer) if req.URL.Host != u.Host { if rest == nil { - return nil, fmt.Errorf("unexpected request to %q", req.URL.Host) + return nil, fmt.Errorf("unexpected network request to %q", req.URL.Host) } - return rest.Do(req) + return rest.Transport.RoundTrip(req) } resp := httptest.NewRecorder() f.handler.ServeHTTP(resp, req) @@ -375,7 +556,39 @@ func (f *FakeIDP) HTTPClient(rest *http.Client) *http.Client { } } -func (f *FakeIDP) OIDCConfig(t testing.TB, redirect string, scopes []string, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig { +// RefreshUsed returns if the refresh token has been used. All refresh tokens +// can only be used once, then they are deleted. +func (f *FakeIDP) RefreshUsed(refreshToken string) bool { + used, _ := f.refreshTokensUsed.Load(refreshToken) + return used +} + +func (f *FakeIDP) UpdateRefreshClaims(refreshToken string, claims jwt.MapClaims) { + f.refreshIDTokenClaims.Store(refreshToken, claims) +} + +func (f *FakeIDP) SetRedirect(t testing.TB, url string) { + t.Helper() + + f.cfg.RedirectURL = url +} + +func (f *FakeIDP) SetCoderdCallback(callback func(req *http.Request) (*http.Response, error)) { + f.fakeCoderd = callback +} + +func (f *FakeIDP) SetCoderdCallbackHandler(handler http.HandlerFunc) { + if f.serve { + panic("cannot set callback handler when using 'WithServing'. Must implement an actual 'Coderd'") + } + f.fakeCoderd = func(req *http.Request) (*http.Response, error) { + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + return resp.Result(), nil + } +} + +func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig { t.Helper() if len(scopes) == 0 { scopes = []string{"openid", "email", "profile"} @@ -389,7 +602,9 @@ func (f *FakeIDP) OIDCConfig(t testing.TB, redirect string, scopes []string, opt TokenURL: f.provider.TokenURL, AuthStyle: oauth2.AuthStyleInParams, }, - RedirectURL: redirect, + // If the user is using a real network request, they will need to do + // 'fake.SetRedirect()' + RedirectURL: "https://redirect.com", Scopes: scopes, } @@ -417,6 +632,8 @@ func (f *FakeIDP) OIDCConfig(t testing.TB, redirect string, scopes []string, opt opt(cfg) } + f.cfg = oauthCfg + return cfg } diff --git a/coderd/coderdtest/oidctest/idp_test.go b/coderd/coderdtest/oidctest/idp_test.go index b7de8bd9dbe6d..a8098c874787f 100644 --- a/coderd/coderdtest/oidctest/idp_test.go +++ b/coderd/coderdtest/oidctest/idp_test.go @@ -7,6 +7,8 @@ import ( "testing" "time" + "github.com/golang-jwt/jwt/v4" + "github.com/stretchr/testify/assert" "github.com/coder/coder/v2/coderd/coderdtest/oidctest" @@ -17,7 +19,9 @@ import ( // TestFakeIDPBasicFlow tests the basic flow of the fake IDP. func TestFakeIDPBasicFlow(t *testing.T) { - fake := oidctest.NewFakeIDP(t, oidctest.WithLogging(t, nil)) + fake := oidctest.NewFakeIDP(t, + oidctest.WithLogging(t, nil), + ) var handler http.Handler srv := httptest.NewServer(http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -25,14 +29,14 @@ func TestFakeIDPBasicFlow(t *testing.T) { }))) defer srv.Close() - cfg := fake.OIDCConfig(t, srv.URL, nil) + cfg := fake.OIDCConfig(t, nil) cli := fake.HTTPClient(nil) ctx := oidc.ClientContext(context.Background(), cli) const expectedState = "random-state" var token *oauth2.Token // This is the Coder callback using an actual network request. - handler = http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fake.SetCoderdCallbackHandler(func(w http.ResponseWriter, r *http.Request) { // Emulate OIDC flow code := r.URL.Query().Get("code") state := r.URL.Query().Get("state") @@ -44,9 +48,9 @@ func TestFakeIDPBasicFlow(t *testing.T) { assert.NotEmpty(t, oauthToken.RefreshToken, "refresh token is empty") } token = oauthToken - })) + }) - resp, err := oidctest.OIDCCallback(t, cfg, fake.HTTPClient(srv.Client()), expectedState) + resp, err := fake.OIDCCallback(t, expectedState, jwt.MapClaims{}) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) diff --git a/coderd/coderdtest/oidctest/map.go b/coderd/coderdtest/oidctest/map.go new file mode 100644 index 0000000000000..864e4ae926ee1 --- /dev/null +++ b/coderd/coderdtest/oidctest/map.go @@ -0,0 +1,71 @@ +package oidctest + +import "sync" + +type SyncMap[K, V any] struct { + m sync.Map +} + +func NewSyncMap[K, V any]() *SyncMap[K, V] { + return &SyncMap[K, V]{ + m: sync.Map{}, + } +} + +func (s *SyncMap[K, V]) Store(k K, v V) { + s.m.Store(k, v) +} + +func (s *SyncMap[K, V]) Load(key K) (value V, ok bool) { + v, ok := s.m.Load(key) + if !ok { + var empty V + return empty, false + } + return v.(V), ok +} + +func (m *SyncMap[K, V]) Delete(key K) { + m.m.Delete(key) +} + +func (m *SyncMap[K, V]) LoadAndDelete(key K) (actual V, loaded bool) { + act, loaded := m.m.LoadAndDelete(key) + if !loaded { + var empty V + return empty, loaded + } + return act.(V), loaded +} + +func (m *SyncMap[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { + act, loaded := m.m.LoadOrStore(key, value) + if !loaded { + var empty V + return empty, loaded + } + return act.(V), loaded +} + +func (m *SyncMap[K, V]) CompareAndSwap(key K, old V, new V) bool { + return m.m.CompareAndSwap(key, old, new) +} + +func (m *SyncMap[K, V]) CompareAndDelete(key K, old V) (deleted bool) { + return m.m.CompareAndDelete(key, old) +} + +func (m *SyncMap[K, V]) Swap(key K, value V) (previous any, loaded bool) { + previous, loaded = m.m.Swap(key, value) + if !loaded { + var empty V + return empty, loaded + } + return previous.(V), loaded +} + +func (m *SyncMap[K, V]) Range(f func(key K, value V) bool) { + m.m.Range(func(key, value interface{}) bool { + return f(key.(K), value.(V)) + }) +} diff --git a/enterprise/coderd/userauth_test.go b/enterprise/coderd/userauth_test.go index d6f6db3cbedbd..9ba1b7b4a10ab 100644 --- a/enterprise/coderd/userauth_test.go +++ b/enterprise/coderd/userauth_test.go @@ -7,16 +7,19 @@ import ( "net/http" "regexp" "testing" + "time" - "github.com/golang-jwt/jwt" + "github.com/golang-jwt/jwt/v4" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/coder/coder/v2/coderd" "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/coderd/coderdtest/oidctest" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/util/slice" "github.com/coder/coder/v2/codersdk" @@ -31,97 +34,67 @@ func TestUserOIDC(t *testing.T) { t.Run("RoleSync", func(t *testing.T) { t.Parallel() + // NoRoles is the "control group". It has claims with 0 roles + // assigned, and asserts that the user has no roles. t.Run("NoRoles", func(t *testing.T) { t.Parallel() - ctx := testutil.Context(t, testutil.WaitMedium) - conf := coderdtest.NewOIDCConfig(t, "") - - oidcRoleName := "TemplateAuthor" - - config := conf.OIDCConfig(t, jwt.MapClaims{}, func(cfg *coderd.OIDCConfig) { - cfg.UserRoleMapping = map[string][]string{oidcRoleName: {rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()}} - }) - config.AllowSignups = true - config.UserRoleField = "roles" - - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - Options: &coderdtest.Options{ - OIDCConfig: config, - }, - LicenseOptions: &coderdenttest.LicenseOptions{ - Features: license.Features{codersdk.FeatureUserRoleManagement: 1}, + const oidcRoleName = "TemplateAuthor" + runner := setupOIDCTest(t, oidcTestConfig{ + Config: func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + cfg.UserRoleField = "roles" }, }) - admin, err := client.User(ctx, "me") - require.NoError(t, err) - require.Len(t, admin.OrganizationIDs, 1) - - resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{ + claims := jwt.MapClaims{ "email": "alice@coder.com", - })) - require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) - user, err := client.User(ctx, "alice") - require.NoError(t, err) - - require.Len(t, user.Roles, 0) - roleNames := []string{} - require.ElementsMatch(t, roleNames, []string{}) + } + // Login a new client that signs up + client, resp := runner.Login(claims) + require.Equal(t, http.StatusOK, resp.StatusCode) + // User should be in 0 groups. + runner.AssertRoles(t, "alice", []string{}) + // Force a refresh, and assert nothing has changes + runner.ForceRefresh(t, client, claims)(t) + runner.AssertRoles(t, "alice", []string{}) }) - t.Run("NewUserAndRemoveRoles", func(t *testing.T) { + // A user has some roles, then on an oauth refresh will lose said + // roles from an updated claim. + t.Run("NewUserAndRemoveRolesOnRefresh", func(t *testing.T) { + t.Skip("Refreshing tokens does not update roles :(") t.Parallel() - ctx := testutil.Context(t, testutil.WaitMedium) - conf := coderdtest.NewOIDCConfig(t, "") - - oidcRoleName := "TemplateAuthor" - - config := conf.OIDCConfig(t, jwt.MapClaims{}, func(cfg *coderd.OIDCConfig) { - cfg.UserRoleMapping = map[string][]string{oidcRoleName: {rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()}} - }) - config.AllowSignups = true - config.UserRoleField = "roles" - - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - Options: &coderdtest.Options{ - OIDCConfig: config, - }, - LicenseOptions: &coderdenttest.LicenseOptions{ - Features: license.Features{codersdk.FeatureUserRoleManagement: 1}, + const oidcRoleName = "TemplateAuthor" + runner := setupOIDCTest(t, oidcTestConfig{ + Userinfo: jwt.MapClaims{oidcRoleName: []string{rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()}}, + Config: func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + cfg.UserRoleField = "roles" + cfg.UserRoleMapping = map[string][]string{ + oidcRoleName: {rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()}, + } }, }) - admin, err := client.User(ctx, "me") - require.NoError(t, err) - require.Len(t, admin.OrganizationIDs, 1) - - resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{ + // User starts with the owner role + client, resp := runner.Login(jwt.MapClaims{ "email": "alice@coder.com", "roles": []string{"random", oidcRoleName, rbac.RoleOwner()}, - })) - require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) - _ = resp.Body.Close() - user, err := client.User(ctx, "alice") - require.NoError(t, err) - - require.Len(t, user.Roles, 3) - roleNames := []string{user.Roles[0].Name, user.Roles[1].Name, user.Roles[2].Name} - require.ElementsMatch(t, roleNames, []string{rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin(), rbac.RoleOwner()}) + }) + require.Equal(t, http.StatusOK, resp.StatusCode) + runner.AssertRoles(t, "alice", []string{rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin(), rbac.RoleOwner()}) - // Now remove the roles with a new oidc login - resp = oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{ + // Now refresh the oauth, and check the roles are removed. + // Force a refresh, and assert nothing has changes + runner.ForceRefresh(t, client, jwt.MapClaims{ "email": "alice@coder.com", "roles": []string{"random"}, - })) - require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) - _ = resp.Body.Close() - user, err = client.User(ctx, "alice") - require.NoError(t, err) - - require.Len(t, user.Roles, 0) + })(t) + runner.AssertRoles(t, "alice", []string{}) }) + t.Run("BlockAssignRoles", func(t *testing.T) { t.Parallel() @@ -588,3 +561,142 @@ func oidcCallback(t *testing.T, client *codersdk.Client, code string) *http.Resp t.Log(string(data)) return res } + +// oidcTestRunner is just a helper to setup and run oidc tests. +// An actual Coderd instance is used to run the tests. +type oidcTestRunner struct { + AdminClient *codersdk.Client + AdminUser codersdk.User + + // Login will call the OIDC flow with an unauthenticated client. + // The customer actions will all be taken care of, and the idToken claims + // will be returned. + Login func(idToken jwt.MapClaims) (*codersdk.Client, *http.Response) + // ForceRefresh will use an authenticated codersdk.Client, and force their + // OIDC token to be expired and require a refresh. The refresh will use the claims provided. + // + // The client MUST be used to actually trigger the refresh. This just + // expires the oauth token so the next authenticated API call will + // trigger a refresh. The returned function is an example of said call. + // It just calls the /users/me endpoint to trigger the refresh. + ForceRefresh func(t *testing.T, client *codersdk.Client, idToken jwt.MapClaims) func(t *testing.T) +} + +type oidcTestConfig struct { + Userinfo jwt.MapClaims + + // Config allows modifying the Coderd OIDC configuration. + Config func(cfg *coderd.OIDCConfig) +} + +func (r *oidcTestRunner) AssertRoles(t *testing.T, userIdent string, roles []string) { + t.Helper() + + ctx := testutil.Context(t, testutil.WaitMedium) + user, err := r.AdminClient.User(ctx, userIdent) + require.NoError(t, err) + + roleNames := []string{} + for _, role := range user.Roles { + roleNames = append(roleNames, role.Name) + } + require.ElementsMatch(t, roles, roleNames, "expected roles") +} + +func (r *oidcTestRunner) AssertGroups(t *testing.T, userIdent string, groups []string) { + ctx := testutil.Context(t, testutil.WaitMedium) + user, err := r.AdminClient.User(ctx, userIdent) + require.NoError(t, err) + + allGroups, err := r.AdminClient.GroupsByOrganization(ctx, user.OrganizationIDs[0]) + require.NoError(t, err) + + userInGroups := []string{} + for _, g := range allGroups { + for _, mem := range g.Members { + if mem.ID == user.ID { + userInGroups = append(userInGroups, g.Name) + } + } + } + + require.ElementsMatch(t, groups, userInGroups, "expected groups") +} + +func setupOIDCTest(t *testing.T, settings oidcTestConfig) *oidcTestRunner { + t.Helper() + + fake := oidctest.NewFakeIDP(t, + oidctest.WithStaticUserInfo(settings.Userinfo), + oidctest.WithLogging(t, nil), + // Run fake IDP on a real webserver + oidctest.WithServing(), + ) + + ctx := testutil.Context(t, testutil.WaitMedium) + cfg := fake.OIDCConfig(t, nil, settings.Config) + client, _, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + OIDCConfig: cfg, + }, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{codersdk.FeatureUserRoleManagement: 1}, + }, + }) + admin, err := client.User(ctx, "me") + require.NoError(t, err) + unauthenticatedClient := codersdk.New(client.URL) + + return &oidcTestRunner{ + AdminClient: client, + AdminUser: admin, + Login: func(idToken jwt.MapClaims) (*codersdk.Client, *http.Response) { + return fake.LoginClient(t, unauthenticatedClient, idToken) + }, + ForceRefresh: func(t *testing.T, client *codersdk.Client, idToken jwt.MapClaims) (authenticatedCall func(t *testing.T)) { + t.Helper() + + //nolint:gocritic // Testing + ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitMedium)) + + id, _, err := httpmw.SplitAPIToken(client.SessionToken()) + require.NoError(t, err) + + // We need to get the OIDC link and update it in the database to force + // it to be expired. + key, err := api.Database.GetAPIKeyByID(ctx, id) + require.NoError(t, err, "get api key") + + link, err := api.Database.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{ + UserID: key.UserID, + LoginType: database.LoginTypeOIDC, + }) + require.NoError(t, err, "get user link") + + // Updates the claims that the IDP will return. By default, it always + // uses the original claims for the original oauth token. + fake.UpdateRefreshClaims(link.OAuthRefreshToken, idToken) + + // Fetch the oauth link for the given user. + _, err = api.Database.UpdateUserLink(ctx, database.UpdateUserLinkParams{ + OAuthAccessToken: link.OAuthAccessToken, + OAuthRefreshToken: link.OAuthRefreshToken, + OAuthExpiry: time.Now().Add(time.Hour * -1), + UserID: key.UserID, + LoginType: database.LoginTypeOIDC, + }) + require.NoError(t, err, "expire user link") + t.Cleanup(func() { + require.True(t, fake.RefreshUsed(link.OAuthRefreshToken), "refresh token must be used, but has not. Did you forget to call the returned function from this call?") + }) + + return func(t *testing.T) { + t.Helper() + + // Do any authenticated call to force the refresh + _, err := client.User(testutil.Context(t, testutil.WaitShort), "me") + require.NoError(t, err, "user must be able to be fetched") + } + }, + } +} From 8bfd42810c9bf3014a6f13c98fbaf30393cc38d9 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 24 Aug 2023 11:33:03 -0500 Subject: [PATCH 03/18] Refactor existing tests --- enterprise/coderd/userauth_test.go | 426 +++++++++++++++-------------- 1 file changed, 222 insertions(+), 204 deletions(-) diff --git a/enterprise/coderd/userauth_test.go b/enterprise/coderd/userauth_test.go index 9ba1b7b4a10ab..664b6080ee708 100644 --- a/enterprise/coderd/userauth_test.go +++ b/enterprise/coderd/userauth_test.go @@ -1,17 +1,12 @@ package coderd_test import ( - "context" - "fmt" - "io" "net/http" "regexp" "testing" "time" "github.com/golang-jwt/jwt/v4" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/coder/coder/v2/coderd" @@ -23,6 +18,7 @@ import ( "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/util/slice" "github.com/coder/coder/v2/codersdk" + coderden "github.com/coder/coder/v2/enterprise/coderd" "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" "github.com/coder/coder/v2/enterprise/coderd/license" "github.com/coder/coder/v2/testutil" @@ -39,7 +35,6 @@ func TestUserOIDC(t *testing.T) { t.Run("NoRoles", func(t *testing.T) { t.Parallel() - const oidcRoleName = "TemplateAuthor" runner := setupOIDCTest(t, oidcTestConfig{ Config: func(cfg *coderd.OIDCConfig) { cfg.AllowSignups = true @@ -63,6 +58,8 @@ func TestUserOIDC(t *testing.T) { // A user has some roles, then on an oauth refresh will lose said // roles from an updated claim. t.Run("NewUserAndRemoveRolesOnRefresh", func(t *testing.T) { + // TODO: Implement new feature to update roles/groups on OIDC + // refresh tokens. t.Skip("Refreshing tokens does not update roles :(") t.Parallel() @@ -95,37 +92,62 @@ func TestUserOIDC(t *testing.T) { runner.AssertRoles(t, "alice", []string{}) }) - t.Run("BlockAssignRoles", func(t *testing.T) { + // A user has some roles, then on another oauth login will lose said + // roles from an updated claim. + t.Run("NewUserAndRemoveRolesOnReAuth", func(t *testing.T) { t.Parallel() - ctx := testutil.Context(t, testutil.WaitMedium) - conf := coderdtest.NewOIDCConfig(t, "") + const oidcRoleName = "TemplateAuthor" + runner := setupOIDCTest(t, oidcTestConfig{ + Userinfo: jwt.MapClaims{oidcRoleName: []string{rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()}}, + Config: func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + cfg.UserRoleField = "roles" + cfg.UserRoleMapping = map[string][]string{ + oidcRoleName: {rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()}, + } + }, + }) - config := conf.OIDCConfig(t, jwt.MapClaims{}) - config.AllowSignups = true - config.UserRoleField = "roles" + // User starts with the owner role + _, resp := runner.Login(jwt.MapClaims{ + "email": "alice@coder.com", + "roles": []string{"random", oidcRoleName, rbac.RoleOwner()}, + }) + require.Equal(t, http.StatusOK, resp.StatusCode) + runner.AssertRoles(t, "alice", []string{rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin(), rbac.RoleOwner()}) - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - Options: &coderdtest.Options{ - OIDCConfig: config, - }, - LicenseOptions: &coderdenttest.LicenseOptions{ - Features: license.Features{codersdk.FeatureUserRoleManagement: 1}, - }, + // Now login with oauth again, and check the roles are removed. + _, resp = runner.Login(jwt.MapClaims{ + "email": "alice@coder.com", + "roles": []string{"random"}, }) + require.Equal(t, http.StatusOK, resp.StatusCode) - admin, err := client.User(ctx, "me") - require.NoError(t, err) - require.Len(t, admin.OrganizationIDs, 1) + runner.AssertRoles(t, "alice", []string{}) + }) + + // All manual role updates should fail when role sync is enabled. + t.Run("BlockAssignRoles", func(t *testing.T) { + t.Parallel() + + const oidcRoleName = "TemplateAuthor" + runner := setupOIDCTest(t, oidcTestConfig{ + Config: func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + cfg.UserRoleField = "roles" + }, + }) - resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{ + _, resp := runner.Login(jwt.MapClaims{ "email": "alice@coder.com", "roles": []string{}, - })) - require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + }) + require.Equal(t, http.StatusOK, resp.StatusCode) // Try to manually update user roles, even though controlled by oidc // role sync. - _, err = client.UpdateUserRoles(ctx, "alice", codersdk.UpdateRoles{ + ctx := testutil.Context(t, testutil.WaitShort) + _, err := runner.AdminClient.UpdateUserRoles(ctx, "alice", codersdk.UpdateRoles{ Roles: []string{ rbac.RoleTemplateAdmin(), }, @@ -137,199 +159,211 @@ func TestUserOIDC(t *testing.T) { t.Run("Groups", func(t *testing.T) { t.Parallel() + + // Assigns does a simple test of assigning a user to a group based + // on the oidc claims. t.Run("Assigns", func(t *testing.T) { t.Parallel() - ctx := testutil.Context(t, testutil.WaitLong) - conf := coderdtest.NewOIDCConfig(t, "") - const groupClaim = "custom-groups" - config := conf.OIDCConfig(t, jwt.MapClaims{}, func(cfg *coderd.OIDCConfig) { - cfg.GroupField = groupClaim - }) - config.AllowSignups = true - - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - Options: &coderdtest.Options{ - OIDCConfig: config, - }, - LicenseOptions: &coderdenttest.LicenseOptions{ - Features: license.Features{codersdk.FeatureTemplateRBAC: 1}, + const groupName = "bingbong" + runner := setupOIDCTest(t, oidcTestConfig{ + Config: func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + cfg.GroupField = groupClaim }, }) - admin, err := client.User(ctx, "me") - require.NoError(t, err) - require.Len(t, admin.OrganizationIDs, 1) - - groupName := "bingbong" - group, err := client.CreateGroup(ctx, admin.OrganizationIDs[0], codersdk.CreateGroupRequest{ + ctx := testutil.Context(t, testutil.WaitShort) + group, err := runner.AdminClient.CreateGroup(ctx, runner.AdminUser.OrganizationIDs[0], codersdk.CreateGroupRequest{ Name: groupName, }) require.NoError(t, err) require.Len(t, group.Members, 0) - resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{ - "email": "colin@coder.com", + _, resp := runner.Login(jwt.MapClaims{ + "email": "alice@coder.com", groupClaim: []string{groupName}, - })) - assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) - - group, err = client.Group(ctx, group.ID) - require.NoError(t, err) - require.Len(t, group.Members, 1) + }) + require.Equal(t, http.StatusOK, resp.StatusCode) + runner.AssertGroups(t, "alice", []string{groupName}) }) + + // Tests the group mapping feature. t.Run("AssignsMapped", func(t *testing.T) { t.Parallel() - ctx := testutil.Context(t, testutil.WaitMedium) - conf := coderdtest.NewOIDCConfig(t, "") - - oidcGroupName := "pingpong" - coderGroupName := "bingbong" - - config := conf.OIDCConfig(t, jwt.MapClaims{}, func(cfg *coderd.OIDCConfig) { - cfg.GroupMapping = map[string]string{oidcGroupName: coderGroupName} - }) - config.AllowSignups = true + const groupClaim = "custom-groups" - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - Options: &coderdtest.Options{ - OIDCConfig: config, - }, - LicenseOptions: &coderdenttest.LicenseOptions{ - Features: license.Features{codersdk.FeatureTemplateRBAC: 1}, + const oidcGroupName = "pingpong" + const coderGroupName = "bingbong" + runner := setupOIDCTest(t, oidcTestConfig{ + Config: func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + cfg.GroupField = groupClaim + cfg.GroupMapping = map[string]string{oidcGroupName: coderGroupName} }, }) - admin, err := client.User(ctx, "me") - require.NoError(t, err) - require.Len(t, admin.OrganizationIDs, 1) - - group, err := client.CreateGroup(ctx, admin.OrganizationIDs[0], codersdk.CreateGroupRequest{ + ctx := testutil.Context(t, testutil.WaitShort) + group, err := runner.AdminClient.CreateGroup(ctx, runner.AdminUser.OrganizationIDs[0], codersdk.CreateGroupRequest{ Name: coderGroupName, }) require.NoError(t, err) require.Len(t, group.Members, 0) - resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{ - "email": "colin@coder.com", - "groups": []string{oidcGroupName}, - })) - assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) - - group, err = client.Group(ctx, group.ID) - require.NoError(t, err) - require.Len(t, group.Members, 1) + _, resp := runner.Login(jwt.MapClaims{ + "email": "alice@coder.com", + groupClaim: []string{oidcGroupName}, + }) + require.Equal(t, http.StatusOK, resp.StatusCode) + runner.AssertGroups(t, "alice", []string{coderGroupName}) }) - t.Run("AddThenRemove", func(t *testing.T) { + // User is in a group, then on an oauth refresh will lose said + // group. + t.Run("AddThenRemoveOnRefresh", func(t *testing.T) { t.Parallel() - ctx := testutil.Context(t, testutil.WaitLong) - conf := coderdtest.NewOIDCConfig(t, "") + // TODO: Implement new feature to update roles/groups on OIDC + // refresh tokens. + t.Skip("Refreshing tokens does not update groups :(") - config := conf.OIDCConfig(t, jwt.MapClaims{}) - config.AllowSignups = true - - client, firstUser := coderdenttest.New(t, &coderdenttest.Options{ - Options: &coderdtest.Options{ - OIDCConfig: config, - }, - LicenseOptions: &coderdenttest.LicenseOptions{ - Features: license.Features{codersdk.FeatureTemplateRBAC: 1}, + const groupClaim = "custom-groups" + const groupName = "bingbong" + runner := setupOIDCTest(t, oidcTestConfig{ + Config: func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + cfg.GroupField = groupClaim }, }) - // Add some extra users/groups that should be asserted after. - // Adding this user as there was a bug that removing 1 user removed - // all users from the group. - _, extra := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID) - groupName := "bingbong" - group, err := client.CreateGroup(ctx, firstUser.OrganizationID, codersdk.CreateGroupRequest{ + ctx := testutil.Context(t, testutil.WaitShort) + group, err := runner.AdminClient.CreateGroup(ctx, runner.AdminUser.OrganizationIDs[0], codersdk.CreateGroupRequest{ Name: groupName, }) - require.NoError(t, err, "create group") + require.NoError(t, err) + require.Len(t, group.Members, 0) - group, err = client.PatchGroup(ctx, group.ID, codersdk.PatchGroupRequest{ - AddUsers: []string{ - firstUser.UserID.String(), - extra.ID.String(), - }, + client, resp := runner.Login(jwt.MapClaims{ + "email": "alice@coder.com", + groupClaim: []string{groupName}, }) - require.NoError(t, err, "patch group") - require.Len(t, group.Members, 2, "expect both members") + require.Equal(t, http.StatusOK, resp.StatusCode) + runner.AssertGroups(t, "alice", []string{groupName}) - // Now add OIDC user into the group - resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{ - "email": "colin@coder.com", - "groups": []string{groupName}, - })) - assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + // Refresh without the group claim + runner.ForceRefresh(t, client, jwt.MapClaims{ + "email": "alice@coder.com", + })(t) + runner.AssertGroups(t, "alice", []string{}) + }) - group, err = client.Group(ctx, group.ID) - require.NoError(t, err) - require.Len(t, group.Members, 3) + t.Run("AddThenRemoveOnReAuth", func(t *testing.T) { + t.Parallel() - // Login to remove the OIDC user from the group - resp = oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{ - "email": "colin@coder.com", - "groups": []string{}, - })) - assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + const groupClaim = "custom-groups" + const groupName = "bingbong" + runner := setupOIDCTest(t, oidcTestConfig{ + Config: func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + cfg.GroupField = groupClaim + }, + }) - group, err = client.Group(ctx, group.ID) + ctx := testutil.Context(t, testutil.WaitShort) + group, err := runner.AdminClient.CreateGroup(ctx, runner.AdminUser.OrganizationIDs[0], codersdk.CreateGroupRequest{ + Name: groupName, + }) require.NoError(t, err) - require.Len(t, group.Members, 2) - var expected []uuid.UUID - for _, mem := range group.Members { - expected = append(expected, mem.ID) - } - require.ElementsMatchf(t, expected, []uuid.UUID{firstUser.UserID, extra.ID}, "expected members") + require.Len(t, group.Members, 0) + + _, resp := runner.Login(jwt.MapClaims{ + "email": "alice@coder.com", + groupClaim: []string{groupName}, + }) + require.Equal(t, http.StatusOK, resp.StatusCode) + runner.AssertGroups(t, "alice", []string{groupName}) + + // Refresh without the group claim + _, resp = runner.Login(jwt.MapClaims{ + "email": "alice@coder.com", + }) + require.Equal(t, http.StatusOK, resp.StatusCode) + runner.AssertGroups(t, "alice", []string{}) }) + // Updating groups where the claimed group does not exist. t.Run("NoneMatch", func(t *testing.T) { t.Parallel() - ctx := testutil.Context(t, testutil.WaitLong) - conf := coderdtest.NewOIDCConfig(t, "") + const groupClaim = "custom-groups" + runner := setupOIDCTest(t, oidcTestConfig{ + Config: func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + cfg.GroupField = groupClaim + }, + }) - config := conf.OIDCConfig(t, jwt.MapClaims{}) - config.AllowSignups = true + _, resp := runner.Login(jwt.MapClaims{ + "email": "alice@coder.com", + groupClaim: []string{"not-exists"}, + }) + require.Equal(t, http.StatusOK, resp.StatusCode) + runner.AssertGroups(t, "alice", []string{}) + }) - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - Options: &coderdtest.Options{ - OIDCConfig: config, - }, - LicenseOptions: &coderdenttest.LicenseOptions{ - Features: license.Features{codersdk.FeatureTemplateRBAC: 1}, + // Updating groups where the claimed group does not exist creates + // the group. + t.Run("AutoCreate", func(t *testing.T) { + t.Parallel() + + const groupClaim = "custom-groups" + const groupName = "make-me" + runner := setupOIDCTest(t, oidcTestConfig{ + Config: func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + cfg.GroupField = groupClaim + cfg.CreateMissingGroups = true }, }) - admin, err := client.User(ctx, "me") - require.NoError(t, err) - require.Len(t, admin.OrganizationIDs, 1) + _, resp := runner.Login(jwt.MapClaims{ + "email": "alice@coder.com", + groupClaim: []string{groupName}, + }) + require.Equal(t, http.StatusOK, resp.StatusCode) + runner.AssertGroups(t, "alice", []string{groupName}) + }) + }) - groupName := "bingbong" - group, err := client.CreateGroup(ctx, admin.OrganizationIDs[0], codersdk.CreateGroupRequest{ - Name: groupName, + t.Run("Refresh", func(t *testing.T) { + t.Run("RefreshTokensMultiple", func(t *testing.T) { + t.Parallel() + + runner := setupOIDCTest(t, oidcTestConfig{ + Config: func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + cfg.UserRoleField = "roles" + }, }) - require.NoError(t, err) - require.Len(t, group.Members, 0) - resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{ - "email": "colin@coder.com", - "groups": []string{"coolin"}, - })) - assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + claims := jwt.MapClaims{ + "email": "alice@coder.com", + } + // Login a new client that signs up + client, resp := runner.Login(claims) + require.Equal(t, http.StatusOK, resp.StatusCode) - group, err = client.Group(ctx, group.ID) - require.NoError(t, err) - require.Len(t, group.Members, 0) + // Refresh multiple times. + for i := 0; i < 3; i++ { + runner.ForceRefresh(t, client, claims)(t) + } }) }) } +// nolint:bodyclose func TestGroupSync(t *testing.T) { t.Parallel() @@ -443,28 +477,20 @@ func TestGroupSync(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx := testutil.Context(t, testutil.WaitLong) - conf := coderdtest.NewOIDCConfig(t, "") - - config := conf.OIDCConfig(t, jwt.MapClaims{}, tc.modCfg) - - client, _, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ - Options: &coderdtest.Options{ - OIDCConfig: config, - }, - LicenseOptions: &coderdenttest.LicenseOptions{ - Features: license.Features{codersdk.FeatureTemplateRBAC: 1}, + runner := setupOIDCTest(t, oidcTestConfig{ + Config: func(cfg *coderd.OIDCConfig) { + cfg.GroupField = "groups" + tc.modCfg(cfg) }, }) - admin, err := client.User(ctx, "me") - require.NoError(t, err) - require.Len(t, admin.OrganizationIDs, 1) - // Setup + ctx := testutil.Context(t, testutil.WaitLong) + org := runner.AdminUser.OrganizationIDs[0] + initialGroups := make(map[string]codersdk.Group) for _, group := range tc.initialOrgGroups { - newGroup, err := client.CreateGroup(ctx, admin.OrganizationIDs[0], codersdk.CreateGroupRequest{ + newGroup, err := runner.AdminClient.CreateGroup(ctx, org, codersdk.CreateGroupRequest{ Name: group, }) require.NoError(t, err) @@ -473,16 +499,16 @@ func TestGroupSync(t *testing.T) { } // Create the user and add them to their initial groups - _, user := coderdtest.CreateAnotherUser(t, client, admin.OrganizationIDs[0]) + _, user := coderdtest.CreateAnotherUser(t, runner.AdminClient, org) for _, group := range tc.initialUserGroups { - _, err := client.PatchGroup(ctx, initialGroups[group].ID, codersdk.PatchGroupRequest{ + _, err := runner.AdminClient.PatchGroup(ctx, initialGroups[group].ID, codersdk.PatchGroupRequest{ AddUsers: []string{user.ID.String()}, }) require.NoError(t, err) } // nolint:gocritic - _, err = api.Database.UpdateUserLoginType(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLoginTypeParams{ + _, err := runner.API.Database.UpdateUserLoginType(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLoginTypeParams{ NewLoginType: database.LoginTypeOIDC, UserID: user.ID, }) @@ -490,11 +516,11 @@ func TestGroupSync(t *testing.T) { // Log in the new user tc.claims["email"] = user.Email - resp := oidcCallback(t, client, conf.EncodeClaims(t, tc.claims)) - assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) - _ = resp.Body.Close() + _, resp := runner.Login(tc.claims) + require.Equal(t, http.StatusOK, resp.StatusCode) - orgGroups, err := client.GroupsByOrganization(ctx, admin.OrganizationIDs[0]) + // Check group sources + orgGroups, err := runner.AdminClient.GroupsByOrganization(ctx, org) require.NoError(t, err) for _, group := range orgGroups { @@ -540,33 +566,12 @@ func TestGroupSync(t *testing.T) { } } -func oidcCallback(t *testing.T, client *codersdk.Client, code string) *http.Response { - t.Helper() - client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - } - oauthURL, err := client.URL.Parse(fmt.Sprintf("/api/v2/users/oidc/callback?code=%s&state=somestate", code)) - require.NoError(t, err) - req, err := http.NewRequestWithContext(context.Background(), "GET", oauthURL.String(), nil) - require.NoError(t, err) - req.AddCookie(&http.Cookie{ - Name: codersdk.OAuth2StateCookie, - Value: "somestate", - }) - res, err := client.HTTPClient.Do(req) - require.NoError(t, err) - defer res.Body.Close() - data, err := io.ReadAll(res.Body) - require.NoError(t, err) - t.Log(string(data)) - return res -} - // oidcTestRunner is just a helper to setup and run oidc tests. // An actual Coderd instance is used to run the tests. type oidcTestRunner struct { AdminClient *codersdk.Client AdminUser codersdk.User + API *coderden.API // Login will call the OIDC flow with an unauthenticated client. // The customer actions will all be taken care of, and the idToken claims @@ -604,6 +609,15 @@ func (r *oidcTestRunner) AssertRoles(t *testing.T, userIdent string, roles []str } func (r *oidcTestRunner) AssertGroups(t *testing.T, userIdent string, groups []string) { + t.Helper() + + if !slice.Contains(groups, database.EveryoneGroup) { + var cpy []string + cpy = append(cpy, groups...) + // always include everyone group + cpy = append(cpy, database.EveryoneGroup) + groups = cpy + } ctx := testutil.Context(t, testutil.WaitMedium) user, err := r.AdminClient.User(ctx, userIdent) require.NoError(t, err) @@ -640,7 +654,10 @@ func setupOIDCTest(t *testing.T, settings oidcTestConfig) *oidcTestRunner { OIDCConfig: cfg, }, LicenseOptions: &coderdenttest.LicenseOptions{ - Features: license.Features{codersdk.FeatureUserRoleManagement: 1}, + Features: license.Features{ + codersdk.FeatureUserRoleManagement: 1, + codersdk.FeatureTemplateRBAC: 1, + }, }, }) admin, err := client.User(ctx, "me") @@ -650,6 +667,7 @@ func setupOIDCTest(t *testing.T, settings oidcTestConfig) *oidcTestRunner { return &oidcTestRunner{ AdminClient: client, AdminUser: admin, + API: api, Login: func(idToken jwt.MapClaims) (*codersdk.Client, *http.Response) { return fake.LoginClient(t, unauthenticatedClient, idToken) }, From da69b1646fe52fb1cac370677db055cc35cb177a Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 24 Aug 2023 11:54:25 -0500 Subject: [PATCH 04/18] create refresh helper --- coderd/coderdtest/oidctest/runner.go | 94 ++++++++++++++++++++++++++++ coderd/userauth_test.go | 29 +++++++++ enterprise/coderd/userauth_test.go | 91 +++++++-------------------- 3 files changed, 146 insertions(+), 68 deletions(-) create mode 100644 coderd/coderdtest/oidctest/runner.go diff --git a/coderd/coderdtest/oidctest/runner.go b/coderd/coderdtest/oidctest/runner.go new file mode 100644 index 0000000000000..15c21bfcf925a --- /dev/null +++ b/coderd/coderdtest/oidctest/runner.go @@ -0,0 +1,94 @@ +package oidctest + +import ( + "net/http" + "testing" + "time" + + "github.com/golang-jwt/jwt/v4" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" +) + +// LoginHelper helps with logging in a user and refreshing their oauth tokens. +// It is mainly because refreshing oauth tokens is a bit tricky and requires +// some database manipulation. +type LoginHelper struct { + fake *FakeIDP + owner *codersdk.Client +} + +func NewLoginHelper(owner *codersdk.Client, fake *FakeIDP) *LoginHelper { + if owner == nil { + panic("owner must not be nil") + } + if fake == nil { + panic("fake must not be nil") + } + return &LoginHelper{ + fake: fake, + owner: owner, + } +} + +// Login just helps by making an unauthenticated client and logging in with +// the given claims. All Logins should be unauthenticated, so this is a +// convenience method. +func (h *LoginHelper) Login(t *testing.T, idTokenClaims jwt.MapClaims) (*codersdk.Client, *http.Response) { + t.Helper() + unauthenticatedClient := codersdk.New(h.owner.URL) + + return h.fake.LoginClient(t, unauthenticatedClient, idTokenClaims) +} + +// ForceRefresh forces the client to refresh its oauth token. +func (h *LoginHelper) ForceRefresh(t *testing.T, db database.Store, user *codersdk.Client, idToken jwt.MapClaims) (authenticatedCall func(t *testing.T)) { + t.Helper() + + //nolint:gocritic // Testing + ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitMedium)) + + id, _, err := httpmw.SplitAPIToken(user.SessionToken()) + require.NoError(t, err) + + // We need to get the OIDC link and update it in the database to force + // it to be expired. + key, err := db.GetAPIKeyByID(ctx, id) + require.NoError(t, err, "get api key") + + link, err := db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{ + UserID: key.UserID, + LoginType: database.LoginTypeOIDC, + }) + require.NoError(t, err, "get user link") + + // Updates the claims that the IDP will return. By default, it always + // uses the original claims for the original oauth token. + h.fake.UpdateRefreshClaims(link.OAuthRefreshToken, idToken) + + // Fetch the oauth link for the given user. + _, err = db.UpdateUserLink(ctx, database.UpdateUserLinkParams{ + OAuthAccessToken: link.OAuthAccessToken, + OAuthRefreshToken: link.OAuthRefreshToken, + OAuthExpiry: time.Now().Add(time.Hour * -1), + UserID: key.UserID, + LoginType: database.LoginTypeOIDC, + }) + require.NoError(t, err, "expire user link") + t.Cleanup(func() { + require.True(t, h.fake.RefreshUsed(link.OAuthRefreshToken), "refresh token must be used, but has not. Did you forget to call the returned function from this call?") + }) + + return func(t *testing.T) { + t.Helper() + + // Do any authenticated call to force the refresh + _, err := user.User(testutil.Context(t, testutil.WaitShort), "me") + require.NoError(t, err, "user must be able to be fetched") + } +} diff --git a/coderd/userauth_test.go b/coderd/userauth_test.go index 10bf7ecf67234..17e0bf0f3b101 100644 --- a/coderd/userauth_test.go +++ b/coderd/userauth_test.go @@ -11,6 +11,10 @@ import ( "testing" "time" + "github.com/coder/coder/v2/coderd/coderdtest/oidctest" + "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" + "github.com/coder/coder/v2/enterprise/coderd/license" + "github.com/coreos/go-oidc/v3/oidc" "github.com/golang-jwt/jwt" "github.com/google/go-github/v43/github" @@ -38,6 +42,31 @@ import ( func TestOIDCOauthLoginWithExisting(t *testing.T) { t.Parallel() + fake := oidctest.NewFakeIDP(t) + ctx := testutil.Context(t, testutil.WaitMedium) + + cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) { + + }) + + client, _, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + OIDCConfig: cfg, + }, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureUserRoleManagement: 1, + codersdk.FeatureTemplateRBAC: 1, + }, + }, + }) + helper := oidctest.NewLoginHelper(client, fake) + + + + + // + conf := coderdtest.NewOIDCConfig(t, "", // Provide a refresh token so we use the refresh token flow coderdtest.WithRefreshToken("refresh_token"), diff --git a/enterprise/coderd/userauth_test.go b/enterprise/coderd/userauth_test.go index 664b6080ee708..8360a3ada3e97 100644 --- a/enterprise/coderd/userauth_test.go +++ b/enterprise/coderd/userauth_test.go @@ -4,7 +4,6 @@ import ( "net/http" "regexp" "testing" - "time" "github.com/golang-jwt/jwt/v4" "github.com/stretchr/testify/require" @@ -14,7 +13,6 @@ import ( "github.com/coder/coder/v2/coderd/coderdtest/oidctest" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" - "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/util/slice" "github.com/coder/coder/v2/codersdk" @@ -46,7 +44,7 @@ func TestUserOIDC(t *testing.T) { "email": "alice@coder.com", } // Login a new client that signs up - client, resp := runner.Login(claims) + client, resp := runner.Login(t, claims) require.Equal(t, http.StatusOK, resp.StatusCode) // User should be in 0 groups. runner.AssertRoles(t, "alice", []string{}) @@ -76,7 +74,7 @@ func TestUserOIDC(t *testing.T) { }) // User starts with the owner role - client, resp := runner.Login(jwt.MapClaims{ + client, resp := runner.Login(t, jwt.MapClaims{ "email": "alice@coder.com", "roles": []string{"random", oidcRoleName, rbac.RoleOwner()}, }) @@ -110,7 +108,7 @@ func TestUserOIDC(t *testing.T) { }) // User starts with the owner role - _, resp := runner.Login(jwt.MapClaims{ + _, resp := runner.Login(t, jwt.MapClaims{ "email": "alice@coder.com", "roles": []string{"random", oidcRoleName, rbac.RoleOwner()}, }) @@ -118,7 +116,7 @@ func TestUserOIDC(t *testing.T) { runner.AssertRoles(t, "alice", []string{rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin(), rbac.RoleOwner()}) // Now login with oauth again, and check the roles are removed. - _, resp = runner.Login(jwt.MapClaims{ + _, resp = runner.Login(t, jwt.MapClaims{ "email": "alice@coder.com", "roles": []string{"random"}, }) @@ -139,7 +137,7 @@ func TestUserOIDC(t *testing.T) { }, }) - _, resp := runner.Login(jwt.MapClaims{ + _, resp := runner.Login(t, jwt.MapClaims{ "email": "alice@coder.com", "roles": []string{}, }) @@ -181,7 +179,7 @@ func TestUserOIDC(t *testing.T) { require.NoError(t, err) require.Len(t, group.Members, 0) - _, resp := runner.Login(jwt.MapClaims{ + _, resp := runner.Login(t, jwt.MapClaims{ "email": "alice@coder.com", groupClaim: []string{groupName}, }) @@ -212,7 +210,7 @@ func TestUserOIDC(t *testing.T) { require.NoError(t, err) require.Len(t, group.Members, 0) - _, resp := runner.Login(jwt.MapClaims{ + _, resp := runner.Login(t, jwt.MapClaims{ "email": "alice@coder.com", groupClaim: []string{oidcGroupName}, }) @@ -245,7 +243,7 @@ func TestUserOIDC(t *testing.T) { require.NoError(t, err) require.Len(t, group.Members, 0) - client, resp := runner.Login(jwt.MapClaims{ + client, resp := runner.Login(t, jwt.MapClaims{ "email": "alice@coder.com", groupClaim: []string{groupName}, }) @@ -278,7 +276,7 @@ func TestUserOIDC(t *testing.T) { require.NoError(t, err) require.Len(t, group.Members, 0) - _, resp := runner.Login(jwt.MapClaims{ + _, resp := runner.Login(t, jwt.MapClaims{ "email": "alice@coder.com", groupClaim: []string{groupName}, }) @@ -286,7 +284,7 @@ func TestUserOIDC(t *testing.T) { runner.AssertGroups(t, "alice", []string{groupName}) // Refresh without the group claim - _, resp = runner.Login(jwt.MapClaims{ + _, resp = runner.Login(t, jwt.MapClaims{ "email": "alice@coder.com", }) require.Equal(t, http.StatusOK, resp.StatusCode) @@ -305,7 +303,7 @@ func TestUserOIDC(t *testing.T) { }, }) - _, resp := runner.Login(jwt.MapClaims{ + _, resp := runner.Login(t, jwt.MapClaims{ "email": "alice@coder.com", groupClaim: []string{"not-exists"}, }) @@ -328,7 +326,7 @@ func TestUserOIDC(t *testing.T) { }, }) - _, resp := runner.Login(jwt.MapClaims{ + _, resp := runner.Login(t, jwt.MapClaims{ "email": "alice@coder.com", groupClaim: []string{groupName}, }) @@ -352,7 +350,7 @@ func TestUserOIDC(t *testing.T) { "email": "alice@coder.com", } // Login a new client that signs up - client, resp := runner.Login(claims) + client, resp := runner.Login(t, claims) require.Equal(t, http.StatusOK, resp.StatusCode) // Refresh multiple times. @@ -516,7 +514,7 @@ func TestGroupSync(t *testing.T) { // Log in the new user tc.claims["email"] = user.Email - _, resp := runner.Login(tc.claims) + _, resp := runner.Login(t, tc.claims) require.Equal(t, http.StatusOK, resp.StatusCode) // Check group sources @@ -576,7 +574,7 @@ type oidcTestRunner struct { // Login will call the OIDC flow with an unauthenticated client. // The customer actions will all be taken care of, and the idToken claims // will be returned. - Login func(idToken jwt.MapClaims) (*codersdk.Client, *http.Response) + Login func(t *testing.T, idToken jwt.MapClaims) (*codersdk.Client, *http.Response) // ForceRefresh will use an authenticated codersdk.Client, and force their // OIDC token to be expired and require a refresh. The refresh will use the claims provided. // @@ -649,7 +647,7 @@ func setupOIDCTest(t *testing.T, settings oidcTestConfig) *oidcTestRunner { ctx := testutil.Context(t, testutil.WaitMedium) cfg := fake.OIDCConfig(t, nil, settings.Config) - client, _, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ + owner, _, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ Options: &coderdtest.Options{ OIDCConfig: cfg, }, @@ -660,61 +658,18 @@ func setupOIDCTest(t *testing.T, settings oidcTestConfig) *oidcTestRunner { }, }, }) - admin, err := client.User(ctx, "me") + admin, err := owner.User(ctx, "me") require.NoError(t, err) - unauthenticatedClient := codersdk.New(client.URL) + + helper := oidctest.NewLoginHelper(owner, fake) return &oidcTestRunner{ - AdminClient: client, + AdminClient: owner, AdminUser: admin, API: api, - Login: func(idToken jwt.MapClaims) (*codersdk.Client, *http.Response) { - return fake.LoginClient(t, unauthenticatedClient, idToken) - }, - ForceRefresh: func(t *testing.T, client *codersdk.Client, idToken jwt.MapClaims) (authenticatedCall func(t *testing.T)) { - t.Helper() - - //nolint:gocritic // Testing - ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitMedium)) - - id, _, err := httpmw.SplitAPIToken(client.SessionToken()) - require.NoError(t, err) - - // We need to get the OIDC link and update it in the database to force - // it to be expired. - key, err := api.Database.GetAPIKeyByID(ctx, id) - require.NoError(t, err, "get api key") - - link, err := api.Database.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{ - UserID: key.UserID, - LoginType: database.LoginTypeOIDC, - }) - require.NoError(t, err, "get user link") - - // Updates the claims that the IDP will return. By default, it always - // uses the original claims for the original oauth token. - fake.UpdateRefreshClaims(link.OAuthRefreshToken, idToken) - - // Fetch the oauth link for the given user. - _, err = api.Database.UpdateUserLink(ctx, database.UpdateUserLinkParams{ - OAuthAccessToken: link.OAuthAccessToken, - OAuthRefreshToken: link.OAuthRefreshToken, - OAuthExpiry: time.Now().Add(time.Hour * -1), - UserID: key.UserID, - LoginType: database.LoginTypeOIDC, - }) - require.NoError(t, err, "expire user link") - t.Cleanup(func() { - require.True(t, fake.RefreshUsed(link.OAuthRefreshToken), "refresh token must be used, but has not. Did you forget to call the returned function from this call?") - }) - - return func(t *testing.T) { - t.Helper() - - // Do any authenticated call to force the refresh - _, err := client.User(testutil.Context(t, testutil.WaitShort), "me") - require.NoError(t, err, "user must be able to be fetched") - } + Login: helper.Login, + ForceRefresh: func(t *testing.T, client *codersdk.Client, idToken jwt.MapClaims) func(t *testing.T) { + return helper.ForceRefresh(t, api.Database, client, idToken) }, } } From a8a9633eeb345ea5a86768a4c6d6135dcc5f0aa9 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 24 Aug 2023 14:52:59 -0500 Subject: [PATCH 05/18] refactor all old tests, delete old fake --- coderd/coderdtest/coderdtest.go | 148 -------- .../oidctest/{runner.go => helper.go} | 34 +- coderd/coderdtest/oidctest/idp.go | 129 +++++-- coderd/userauth_test.go | 326 +++++++----------- coderd/users_test.go | 27 +- enterprise/coderd/userauth_test.go | 14 +- 6 files changed, 262 insertions(+), 416 deletions(-) rename coderd/coderdtest/oidctest/{runner.go => helper.go} (77%) diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index b12cb3812ccff..03b74b38e289d 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -31,7 +31,6 @@ import ( "time" "cloud.google.com/go/compute/metadata" - "github.com/coreos/go-oidc/v3/oidc" "github.com/fullsailor/pkcs7" "github.com/golang-jwt/jwt/v4" "github.com/google/uuid" @@ -40,7 +39,6 @@ import ( "github.com/spf13/afero" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/oauth2" "golang.org/x/xerrors" "google.golang.org/api/idtoken" "google.golang.org/api/option" @@ -1022,152 +1020,6 @@ func NewAWSInstanceIdentity(t *testing.T, instanceID string) (awsidentity.Certif } } -type OIDCConfig struct { - key *rsa.PrivateKey - issuer string - // These are optional - refreshToken string - oidcTokenExpires func() time.Time - tokenSource func() (*oauth2.Token, error) -} - -func WithRefreshToken(token string) func(cfg *OIDCConfig) { - return func(cfg *OIDCConfig) { - cfg.refreshToken = token - } -} - -func WithTokenExpires(expFunc func() time.Time) func(cfg *OIDCConfig) { - return func(cfg *OIDCConfig) { - cfg.oidcTokenExpires = expFunc - } -} - -func WithTokenSource(src func() (*oauth2.Token, error)) func(cfg *OIDCConfig) { - return func(cfg *OIDCConfig) { - cfg.tokenSource = src - } -} - -func NewOIDCConfig(t *testing.T, issuer string, opts ...func(cfg *OIDCConfig)) *OIDCConfig { - t.Helper() - - block, _ := pem.Decode([]byte(testRSAPrivateKey)) - pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes) - require.NoError(t, err) - - if issuer == "" { - issuer = "https://coder.com" - } - - cfg := &OIDCConfig{ - key: pkey, - issuer: issuer, - } - for _, opt := range opts { - opt(cfg) - } - return cfg -} - -func (*OIDCConfig) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string { - return "/?state=" + url.QueryEscape(state) -} - -type tokenSource struct { - src func() (*oauth2.Token, error) -} - -func (s tokenSource) Token() (*oauth2.Token, error) { - return s.src() -} - -func (cfg *OIDCConfig) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource { - if cfg.tokenSource == nil { - return nil - } - return tokenSource{ - src: cfg.tokenSource, - } -} - -func (cfg *OIDCConfig) Exchange(_ context.Context, code string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) { - token, err := base64.StdEncoding.DecodeString(code) - if err != nil { - return nil, xerrors.Errorf("decode code: %w", err) - } - - var exp time.Time - if cfg.oidcTokenExpires != nil { - exp = cfg.oidcTokenExpires() - } - - return (&oauth2.Token{ - AccessToken: "token", - RefreshToken: cfg.refreshToken, - Expiry: exp, - }).WithExtra(map[string]interface{}{ - "id_token": string(token), - }), nil -} - -func (cfg *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string { - t.Helper() - - if _, ok := claims["exp"]; !ok { - claims["exp"] = time.Now().Add(time.Hour).UnixMilli() - } - - if _, ok := claims["iss"]; !ok { - claims["iss"] = cfg.issuer - } - - if _, ok := claims["sub"]; !ok { - claims["sub"] = "testme" - } - - signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(cfg.key) - require.NoError(t, err) - - return base64.StdEncoding.EncodeToString([]byte(signed)) -} - -func (cfg *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig { - // By default, the provider can be empty. - // This means it won't support any endpoints! - provider := &oidc.Provider{} - if userInfoClaims != nil { - resp, err := json.Marshal(userInfoClaims) - require.NoError(t, err) - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - _, _ = w.Write(resp) - })) - t.Cleanup(srv.Close) - cfg := &oidc.ProviderConfig{ - UserInfoURL: srv.URL, - } - provider = cfg.NewProvider(context.Background()) - } - newCFG := &coderd.OIDCConfig{ - OAuth2Config: cfg, - Verifier: oidc.NewVerifier(cfg.issuer, &oidc.StaticKeySet{ - PublicKeys: []crypto.PublicKey{cfg.key.Public()}, - }, &oidc.Config{ - SkipClientIDCheck: true, - }), - Provider: provider, - UsernameField: "preferred_username", - EmailField: "email", - AuthURLParams: map[string]string{"access_type": "offline"}, - GroupField: "groups", - } - for _, opt := range opts { - opt(newCFG) - } - return newCFG -} - // NewAzureInstanceIdentity returns a metadata client and ID token validator for faking // instance authentication for Azure. func NewAzureInstanceIdentity(t *testing.T, instanceID string) (x509.VerifyOptions, *http.Client) { diff --git a/coderd/coderdtest/oidctest/runner.go b/coderd/coderdtest/oidctest/helper.go similarity index 77% rename from coderd/coderdtest/oidctest/runner.go rename to coderd/coderdtest/oidctest/helper.go index 15c21bfcf925a..dc75b0efd5f6e 100644 --- a/coderd/coderdtest/oidctest/runner.go +++ b/coderd/coderdtest/oidctest/helper.go @@ -43,11 +43,10 @@ func (h *LoginHelper) Login(t *testing.T, idTokenClaims jwt.MapClaims) (*codersd t.Helper() unauthenticatedClient := codersdk.New(h.owner.URL) - return h.fake.LoginClient(t, unauthenticatedClient, idTokenClaims) + return h.fake.Login(t, unauthenticatedClient, idTokenClaims) } -// ForceRefresh forces the client to refresh its oauth token. -func (h *LoginHelper) ForceRefresh(t *testing.T, db database.Store, user *codersdk.Client, idToken jwt.MapClaims) (authenticatedCall func(t *testing.T)) { +func (h *LoginHelper) ExpireOauthToken(t *testing.T, db database.Store, user *codersdk.Client) (refreshToken string) { t.Helper() //nolint:gocritic // Testing @@ -67,10 +66,6 @@ func (h *LoginHelper) ForceRefresh(t *testing.T, db database.Store, user *coders }) require.NoError(t, err, "get user link") - // Updates the claims that the IDP will return. By default, it always - // uses the original claims for the original oauth token. - h.fake.UpdateRefreshClaims(link.OAuthRefreshToken, idToken) - // Fetch the oauth link for the given user. _, err = db.UpdateUserLink(ctx, database.UpdateUserLinkParams{ OAuthAccessToken: link.OAuthAccessToken, @@ -80,15 +75,24 @@ func (h *LoginHelper) ForceRefresh(t *testing.T, db database.Store, user *coders LoginType: database.LoginTypeOIDC, }) require.NoError(t, err, "expire user link") + + return link.OAuthRefreshToken +} + +// ForceRefresh forces the client to refresh its oauth token. +func (h *LoginHelper) ForceRefresh(t *testing.T, db database.Store, user *codersdk.Client, idToken jwt.MapClaims) { + t.Helper() + + refreshToken := h.ExpireOauthToken(t, db, user) + // Updates the claims that the IDP will return. By default, it always + // uses the original claims for the original oauth token. + h.fake.UpdateRefreshClaims(refreshToken, idToken) + t.Cleanup(func() { - require.True(t, h.fake.RefreshUsed(link.OAuthRefreshToken), "refresh token must be used, but has not. Did you forget to call the returned function from this call?") + require.True(t, h.fake.RefreshUsed(refreshToken), "refresh token must be used, but has not. Did you forget to call the returned function from this call?") }) - return func(t *testing.T) { - t.Helper() - - // Do any authenticated call to force the refresh - _, err := user.User(testutil.Context(t, testutil.WaitShort), "me") - require.NoError(t, err, "user must be able to be fetched") - } + // Do any authenticated call to force the refresh + _, err := user.User(testutil.Context(t, testutil.WaitShort), "me") + require.NoError(t, err, "user must be able to be fetched") } diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index b925811c5d6d6..19fe771704a9f 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -57,15 +57,23 @@ type FakeIDP struct { refreshIDTokenClaims *SyncMap[string, jwt.MapClaims] // hooks - hookUserInfo func(email string) jwt.MapClaims - hookIDTokenClaims jwt.MapClaims - fakeCoderd func(req *http.Request) (*http.Response, error) + hookUserInfo func(email string) jwt.MapClaims + fakeCoderd func(req *http.Request) (*http.Response, error) + hookOnRefresh func(email string) error // Optional if you want to use a real http network request assuming // it is not directed to the IDP. defaultClient *http.Client serve bool } +type FakeIDPOpt func(idp *FakeIDP) + +func WithRefreshHook(hook func(email string) error) func(*FakeIDP) { + return func(f *FakeIDP) { + f.hookOnRefresh = hook + } +} + func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) { return func(f *FakeIDP) { f.logger = slogtest.Make(t, options) @@ -99,7 +107,7 @@ const ( userInfoPath = "/oauth2/userinfo" ) -func NewFakeIDP(t testing.TB, opts ...func(idp *FakeIDP)) *FakeIDP { +func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { t.Helper() block, _ := pem.Decode([]byte(testRSAPrivateKey)) @@ -117,6 +125,7 @@ func NewFakeIDP(t testing.TB, opts ...func(idp *FakeIDP)) *FakeIDP { refreshTokensUsed: NewSyncMap[string, bool](), stateToIDTokenClaims: NewSyncMap[string, jwt.MapClaims](), refreshIDTokenClaims: NewSyncMap[string, jwt.MapClaims](), + hookOnRefresh: func(_ string) error { return nil }, hookUserInfo: func(email string) jwt.MapClaims { return jwt.MapClaims{} }, } idp.handler = idp.httpHandler(t) @@ -176,9 +185,40 @@ func (f *FakeIDP) Serve(t testing.TB) *httptest.Server { return srv } -// LoginClient does the full OIDC flow starting at the "LoginButton". +// Login does the full OIDC flow starting at the "LoginButton". // The client argument is just to get the URL of the Coder instance. -func (f *FakeIDP) LoginClient(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims) (*codersdk.Client, *http.Response) { +// +// The client passed in is just to get the url of the Coder instance. +// The actual client that is used is 100% unauthenticated and fresh. +func (f *FakeIDP) Login(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims, opts ...func(r *http.Request)) (*codersdk.Client, *http.Response) { + t.Helper() + + client, resp := f.AttemptLogin(t, client, idTokenClaims, opts...) + require.Equal(t, http.StatusOK, resp.StatusCode, "client failed to login") + return client, resp +} + +func (f *FakeIDP) AttemptLogin(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims, opts ...func(r *http.Request)) (*codersdk.Client, *http.Response) { + t.Helper() + var err error + + cli := f.HTTPClient(client.HTTPClient) + shallowCpyCli := &(*cli) + + if shallowCpyCli.Jar == nil { + shallowCpyCli.Jar, err = cookiejar.New(nil) + require.NoError(t, err, "failed to create cookie jar") + } + + unauthenticated := codersdk.New(client.URL) + unauthenticated.HTTPClient = shallowCpyCli + + return f.LoginClient(t, unauthenticated, idTokenClaims, opts...) +} + +// LoginClient reuses the context of the passed in client. This means the same +// cookies will be used. This should be an unauthenticated client in most cases. +func (f *FakeIDP) LoginClient(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims, opts ...func(r *http.Request)) (*codersdk.Client, *http.Response) { t.Helper() coderOauthURL, err := client.URL.Parse("/api/v2/users/oidc/callback") @@ -186,8 +226,7 @@ func (f *FakeIDP) LoginClient(t testing.TB, client *codersdk.Client, idTokenClai f.SetRedirect(t, coderOauthURL.String()) cli := f.HTTPClient(client.HTTPClient) - shallowCpyCli := &(*cli) - shallowCpyCli.CheckRedirect = func(req *http.Request, via []*http.Request) error { + cli.CheckRedirect = func(req *http.Request, via []*http.Request) error { // Store the idTokenClaims to the specific state request. This ties // the claims 1:1 with a given authentication flow. state := req.URL.Query().Get("state") @@ -197,17 +236,21 @@ func (f *FakeIDP) LoginClient(t testing.TB, client *codersdk.Client, idTokenClai req, err := http.NewRequestWithContext(context.Background(), "GET", coderOauthURL.String(), nil) require.NoError(t, err) - if shallowCpyCli.Jar == nil { - shallowCpyCli.Jar, err = cookiejar.New(nil) + if cli.Jar == nil { + cli.Jar, err = cookiejar.New(nil) require.NoError(t, err, "failed to create cookie jar") } - res, err := shallowCpyCli.Do(req) + for _, opt := range opts { + opt(req) + } + + res, err := cli.Do(req) require.NoError(t, err) // If the coder session token exists, return the new authed client! var user *codersdk.Client - cookies := shallowCpyCli.Jar.Cookies(client.URL) + cookies := cli.Jar.Cookies(client.URL) for _, cookie := range cookies { if cookie.Name == codersdk.SessionTokenCookie { user = codersdk.New(client.URL) @@ -220,6 +263,7 @@ func (f *FakeIDP) LoginClient(t testing.TB, client *codersdk.Client, idTokenClai res.Body.Close() } }) + return user, res } @@ -229,18 +273,7 @@ func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.Map t.Helper() f.stateToIDTokenClaims.Store(state, idTokenClaims) - baseCli := http.DefaultClient - if f.fakeCoderd != nil { - baseCli = &http.Client{ - Transport: fakeRoundTripper{ - roundTrip: func(req *http.Request) (*http.Response, error) { - return f.fakeCoderd(req) - }, - }, - } - } - - cli := f.HTTPClient(baseCli) + cli := f.HTTPClient(nil) u := f.cfg.AuthCodeURL(state) req, err := http.NewRequest("GET", u, nil) require.NoError(t, err) @@ -408,6 +441,16 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { http.Error(rw, fmt.Sprintf("invalid token request: %s", err.Error()), http.StatusBadRequest) return } + getEmail := func(claims jwt.MapClaims) string { + email, ok := claims["email"] + if !ok { + return "unknown" + } + if _, ok := email.(string); !ok { + return "wrong-type" + } + return email.(string) + } var claims jwt.MapClaims switch values.Get("grant_type") { @@ -444,8 +487,6 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { http.Error(rw, "invalid refresh_token", http.StatusBadRequest) return } - // Always invalidate the refresh token after it is used. - f.refreshTokens.Delete(refreshToken) idTokenClaims, ok := f.refreshIDTokenClaims.Load(refreshToken) if !ok { @@ -453,8 +494,17 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { http.Error(rw, "missing id token claims in refresh", http.StatusBadRequest) return } + claims = idTokenClaims + err := f.hookOnRefresh(getEmail(claims)) + if err != nil { + http.Error(rw, fmt.Sprintf("refresh hook blocked refresh: %s", err.Error()), http.StatusBadRequest) + return + } + f.refreshTokensUsed.Store(refreshToken, true) + // Always invalidate the refresh token after it is used. + f.refreshTokens.Delete(refreshToken) default: t.Errorf("unexpected grant_type %q", values.Get("grant_type")) http.Error(rw, "invalid grant_type", http.StatusBadRequest) @@ -463,13 +513,10 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { exp := time.Now().Add(time.Minute * 5) claims["exp"] = exp.UnixMilli() - email, ok := claims["email"] - if !ok || email.(string) == "" { - email = "unknown" - } - refreshToken := f.newRefreshTokens(email.(string)) + email := getEmail(claims) + refreshToken := f.newRefreshTokens(email) token := map[string]interface{}{ - "access_token": f.newToken(email.(string)), + "access_token": f.newToken(email), "refresh_token": refreshToken, "token_type": "Bearer", "expires_in": int64(time.Minute * 5), @@ -533,17 +580,26 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { // If no client is passed in, then any regular network requests will fail. func (f *FakeIDP) HTTPClient(rest *http.Client) *http.Client { if f.serve { - if rest == nil { - return http.DefaultClient + if rest == nil || rest.Transport == nil { + return &http.Client{} } return rest } + + var jar http.CookieJar + if rest != nil { + jar = rest.Jar + } return &http.Client{ + Jar: jar, Transport: fakeRoundTripper{ roundTrip: func(req *http.Request) (*http.Response, error) { u, _ := url.Parse(f.issuer) if req.URL.Host != u.Host { - if rest == nil { + if f.fakeCoderd != nil { + return f.fakeCoderd(req) + } + if rest == nil || rest.Transport == nil { return nil, fmt.Errorf("unexpected network request to %q", req.URL.Host) } return rest.Transport.RoundTrip(req) @@ -629,6 +685,9 @@ func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *co } for _, opt := range opts { + if opt == nil { + continue + } opt(cfg) } diff --git a/coderd/userauth_test.go b/coderd/userauth_test.go index 17e0bf0f3b101..4fd17abc35633 100644 --- a/coderd/userauth_test.go +++ b/coderd/userauth_test.go @@ -4,35 +4,30 @@ import ( "context" "crypto" "fmt" - "io" "net/http" "net/http/cookiejar" + "net/url" "strings" "testing" - "time" - - "github.com/coder/coder/v2/coderd/coderdtest/oidctest" - "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" - "github.com/coder/coder/v2/enterprise/coderd/license" "github.com/coreos/go-oidc/v3/oidc" - "github.com/golang-jwt/jwt" + "github.com/golang-jwt/jwt/v4" "github.com/google/go-github/v43/github" "github.com/google/uuid" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/oauth2" "golang.org/x/xerrors" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/coderd" "github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/coderd/coderdtest/oidctest" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" + "github.com/coder/coder/v2/enterprise/coderd/license" "github.com/coder/coder/v2/testutil" ) @@ -42,11 +37,16 @@ import ( func TestOIDCOauthLoginWithExisting(t *testing.T) { t.Parallel() - fake := oidctest.NewFakeIDP(t) - ctx := testutil.Context(t, testutil.WaitMedium) + fake := oidctest.NewFakeIDP(t, + oidctest.WithRefreshHook(func(_ string) error { + return xerrors.New("refreshing token should never occur") + }), + oidctest.WithServing(), + ) cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) { - + cfg.AllowSignups = true + cfg.IgnoreUserInfo = true }) client, _, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ @@ -60,89 +60,23 @@ func TestOIDCOauthLoginWithExisting(t *testing.T) { }, }, }) - helper := oidctest.NewLoginHelper(client, fake) - - - - // - - conf := coderdtest.NewOIDCConfig(t, "", - // Provide a refresh token so we use the refresh token flow - coderdtest.WithRefreshToken("refresh_token"), - // We need to set the expire in the future for the first api calls. - coderdtest.WithTokenExpires(func() time.Time { - return time.Now().Add(time.Hour).UTC() - }), - // No refresh should actually happen in this test. - coderdtest.WithTokenSource(func() (*oauth2.Token, error) { - return nil, xerrors.New("token should not require refresh") - }), - ) - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - auditor := audit.NewMock() const username = "alice" claims := jwt.MapClaims{ "email": "alice@coder.com", "email_verified": true, "preferred_username": username, } - config := conf.OIDCConfig(t, claims) - - config.AllowSignups = true - config.IgnoreUserInfo = true - client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{ - Auditor: auditor, - OIDCConfig: config, - Logger: &logger, - }) + helper := oidctest.NewLoginHelper(client, fake) // Signup alice - resp := oidcCallback(t, client, conf.EncodeClaims(t, claims)) - // Set the client to use this OIDC context - authCookie := authCookieValue(resp.Cookies()) - client.SetSessionToken(authCookie) - _ = resp.Body.Close() - - ctx := testutil.Context(t, testutil.WaitLong) - // Verify the user and oauth link - user, err := client.User(ctx, "me") - require.NoError(t, err) - require.Equal(t, username, user.Username) + userClient, _ := helper.Login(t, claims) - // nolint:gocritic - link, err := api.Database.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{ - UserID: user.ID, - LoginType: database.LoginType(user.LoginType), - }) - require.NoError(t, err, "failed to get user link") - - // Expire the link - // nolint:gocritic - _, err = api.Database.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{ - OAuthAccessToken: link.OAuthAccessToken, - OAuthRefreshToken: link.OAuthRefreshToken, - OAuthExpiry: time.Now().Add(time.Hour * -1).UTC(), - UserID: link.UserID, - LoginType: link.LoginType, - }) - require.NoError(t, err, "failed to update user link") - - // Log in again with OIDC - loginAgain := oidcCallbackWithState(t, client, conf.EncodeClaims(t, claims), "seconds_login", func(req *http.Request) { - req.AddCookie(&http.Cookie{ - Name: codersdk.SessionTokenCookie, - Value: authCookie, - Path: "/", - }) - }) - require.Equal(t, http.StatusTemporaryRedirect, loginAgain.StatusCode) - _ = loginAgain.Body.Close() + // Expire the link. This will force the client to refresh the token. + helper.ExpireOauthToken(t, api.Database, userClient) - // Try to use new login - client.SetSessionToken(authCookieValue(resp.Cookies())) - _, err = client.User(ctx, "me") - require.NoError(t, err, "use new session") + // Instead of refreshing, just log in again. + userClient, _ = helper.Login(t, claims) } func TestUserLogin(t *testing.T) { @@ -689,7 +623,7 @@ func TestUserOIDC(t *testing.T) { "email": "kyle@kwc.io", }, AllowSignups: true, - StatusCode: http.StatusTemporaryRedirect, + StatusCode: http.StatusOK, Username: "kyle", }, { Name: "EmailNotVerified", @@ -714,7 +648,7 @@ func TestUserOIDC(t *testing.T) { "email_verified": false, }, AllowSignups: true, - StatusCode: http.StatusTemporaryRedirect, + StatusCode: http.StatusOK, Username: "kyle", IgnoreEmailVerified: true, }, { @@ -738,7 +672,7 @@ func TestUserOIDC(t *testing.T) { EmailDomain: []string{ "kwc.io", }, - StatusCode: http.StatusTemporaryRedirect, + StatusCode: http.StatusOK, }, { Name: "EmptyClaims", IDTokenClaims: jwt.MapClaims{}, @@ -759,7 +693,7 @@ func TestUserOIDC(t *testing.T) { }, Username: "kyle", AllowSignups: true, - StatusCode: http.StatusTemporaryRedirect, + StatusCode: http.StatusOK, }, { Name: "UsernameFromClaims", IDTokenClaims: jwt.MapClaims{ @@ -769,7 +703,7 @@ func TestUserOIDC(t *testing.T) { }, Username: "hotdog", AllowSignups: true, - StatusCode: http.StatusTemporaryRedirect, + StatusCode: http.StatusOK, }, { // Services like Okta return the email as the username: // https://developer.okta.com/docs/reference/api/oidc/#base-claims-always-present @@ -781,7 +715,7 @@ func TestUserOIDC(t *testing.T) { }, Username: "kyle", AllowSignups: true, - StatusCode: http.StatusTemporaryRedirect, + StatusCode: http.StatusOK, }, { // See: https://github.com/coder/coder/issues/4472 Name: "UsernameIsEmail", @@ -790,7 +724,7 @@ func TestUserOIDC(t *testing.T) { }, Username: "kyle", AllowSignups: true, - StatusCode: http.StatusTemporaryRedirect, + StatusCode: http.StatusOK, }, { Name: "WithPicture", IDTokenClaims: jwt.MapClaims{ @@ -802,7 +736,7 @@ func TestUserOIDC(t *testing.T) { Username: "kyle", AllowSignups: true, AvatarURL: "/example.png", - StatusCode: http.StatusTemporaryRedirect, + StatusCode: http.StatusOK, }, { Name: "WithUserInfoClaims", IDTokenClaims: jwt.MapClaims{ @@ -816,7 +750,7 @@ func TestUserOIDC(t *testing.T) { Username: "potato", AllowSignups: true, AvatarURL: "/example.png", - StatusCode: http.StatusTemporaryRedirect, + StatusCode: http.StatusOK, }, { Name: "GroupsDoesNothing", IDTokenClaims: jwt.MapClaims{ @@ -824,7 +758,7 @@ func TestUserOIDC(t *testing.T) { "groups": []string{"pingpong"}, }, AllowSignups: true, - StatusCode: http.StatusTemporaryRedirect, + StatusCode: http.StatusOK, }, { Name: "UserInfoOverridesIDTokenClaims", IDTokenClaims: jwt.MapClaims{ @@ -839,7 +773,7 @@ func TestUserOIDC(t *testing.T) { Username: "user", AllowSignups: true, IgnoreEmailVerified: false, - StatusCode: http.StatusTemporaryRedirect, + StatusCode: http.StatusOK, }, { Name: "InvalidUserInfo", IDTokenClaims: jwt.MapClaims{ @@ -866,36 +800,41 @@ func TestUserOIDC(t *testing.T) { Username: "user", IgnoreUserInfo: true, AllowSignups: true, - StatusCode: http.StatusTemporaryRedirect, + StatusCode: http.StatusOK, }} { tc := tc t.Run(tc.Name, func(t *testing.T) { t.Parallel() - auditor := audit.NewMock() - conf := coderdtest.NewOIDCConfig(t, "") - - config := conf.OIDCConfig(t, tc.UserInfoClaims) - config.AllowSignups = tc.AllowSignups - config.EmailDomain = tc.EmailDomain - config.IgnoreEmailVerified = tc.IgnoreEmailVerified - config.IgnoreUserInfo = tc.IgnoreUserInfo + fake := oidctest.NewFakeIDP(t, + oidctest.WithRefreshHook(func(_ string) error { + return xerrors.New("refreshing token should never occur") + }), + oidctest.WithServing(), + oidctest.WithStaticUserInfo(tc.UserInfoClaims), + ) + cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = tc.AllowSignups + cfg.EmailDomain = tc.EmailDomain + cfg.IgnoreEmailVerified = tc.IgnoreEmailVerified + cfg.IgnoreUserInfo = tc.IgnoreUserInfo + }) + auditor := audit.NewMock() logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - client := coderdtest.New(t, &coderdtest.Options{ + owner := coderdtest.New(t, &coderdtest.Options{ Auditor: auditor, - OIDCConfig: config, + OIDCConfig: cfg, Logger: &logger, }) numLogs := len(auditor.AuditLogs()) - resp := oidcCallback(t, client, conf.EncodeClaims(t, tc.IDTokenClaims)) + client, resp := fake.AttemptLogin(t, owner, tc.IDTokenClaims) numLogs++ // add an audit log for login - assert.Equal(t, tc.StatusCode, resp.StatusCode) + require.Equal(t, tc.StatusCode, resp.StatusCode) ctx := testutil.Context(t, testutil.WaitLong) if tc.Username != "" { - client.SetSessionToken(authCookieValue(resp.Cookies())) user, err := client.User(ctx, "me") require.NoError(t, err) require.Equal(t, tc.Username, user.Username) @@ -906,7 +845,6 @@ func TestUserOIDC(t *testing.T) { } if tc.AvatarURL != "" { - client.SetSessionToken(authCookieValue(resp.Cookies())) user, err := client.User(ctx, "me") require.NoError(t, err) require.Equal(t, tc.AvatarURL, user.AvatarURL) @@ -919,26 +857,29 @@ func TestUserOIDC(t *testing.T) { t.Run("OIDCConvert", func(t *testing.T) { t.Parallel() - auditor := audit.NewMock() - conf := coderdtest.NewOIDCConfig(t, "") - config := conf.OIDCConfig(t, nil) - config.AllowSignups = true + auditor := audit.NewMock() + fake := oidctest.NewFakeIDP(t, + oidctest.WithRefreshHook(func(_ string) error { + return xerrors.New("refreshing token should never occur") + }), + oidctest.WithServing(), + ) + cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + }) - cfg := coderdtest.DeploymentValues(t) client := coderdtest.New(t, &coderdtest.Options{ - Auditor: auditor, - OIDCConfig: config, - DeploymentValues: cfg, + Auditor: auditor, + OIDCConfig: cfg, }) - owner := coderdtest.CreateFirstUser(t, client) + owner := coderdtest.CreateFirstUser(t, client) user, userData := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) - code := conf.EncodeClaims(t, jwt.MapClaims{ + claims := jwt.MapClaims{ "email": userData.Email, - }) - + } var err error user.HTTPClient.Jar, err = cookiejar.New(nil) require.NoError(t, err) @@ -950,52 +891,58 @@ func TestUserOIDC(t *testing.T) { }) require.NoError(t, err) - resp := oidcCallbackWithState(t, user, code, convertResponse.StateString, nil) - require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + fake.LoginClient(t, user, claims, func(r *http.Request) { + r.URL.RawQuery = url.Values{ + "oidc_merge_state": {convertResponse.StateString}, + }.Encode() + r.Header.Set(codersdk.SessionTokenHeader, user.SessionToken()) + cookies := user.HTTPClient.Jar.Cookies(r.URL) + for _, cookie := range cookies { + r.AddCookie(cookie) + } + }) }) t.Run("AlternateUsername", func(t *testing.T) { t.Parallel() auditor := audit.NewMock() - conf := coderdtest.NewOIDCConfig(t, "") - - config := conf.OIDCConfig(t, nil) - config.AllowSignups = true + fake := oidctest.NewFakeIDP(t, + oidctest.WithRefreshHook(func(_ string) error { + return xerrors.New("refreshing token should never occur") + }), + oidctest.WithServing(), + ) + cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + }) client := coderdtest.New(t, &coderdtest.Options{ Auditor: auditor, - OIDCConfig: config, + OIDCConfig: cfg, }) - numLogs := len(auditor.AuditLogs()) - code := conf.EncodeClaims(t, jwt.MapClaims{ + numLogs := len(auditor.AuditLogs()) + claims := jwt.MapClaims{ "email": "jon@coder.com", - }) - resp := oidcCallback(t, client, code) - numLogs++ // add an audit log for login + } - assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + userClient, _ := fake.Login(t, client, claims) + numLogs++ // add an audit log for login ctx := testutil.Context(t, testutil.WaitLong) - - client.SetSessionToken(authCookieValue(resp.Cookies())) - user, err := client.User(ctx, "me") + user, err := userClient.User(ctx, "me") require.NoError(t, err) require.Equal(t, "jon", user.Username) // Pass a different subject field so that we prompt creating a - // new user. - code = conf.EncodeClaims(t, jwt.MapClaims{ + // new user + userClient, _ = fake.Login(t, client, jwt.MapClaims{ "email": "jon@example2.com", "sub": "diff", }) - resp = oidcCallback(t, client, code) numLogs++ // add an audit log for login - assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) - - client.SetSessionToken(authCookieValue(resp.Cookies())) - user, err = client.User(ctx, "me") + user, err = userClient.User(ctx, "me") require.NoError(t, err) require.True(t, strings.HasPrefix(user.Username, "jon-"), "username %q should have prefix %q", user.Username, "jon-") @@ -1006,45 +953,62 @@ func TestUserOIDC(t *testing.T) { t.Run("Disabled", func(t *testing.T) { t.Parallel() client := coderdtest.New(t, nil) - resp := oidcCallback(t, client, "asdf") + oauthURL, err := client.URL.Parse("/api/v2/users/oidc/callback") + require.NoError(t, err) + + req, err := http.NewRequest("GET", oauthURL.String(), nil) + require.NoError(t, err) + resp, err := client.HTTPClient.Do(req) + require.NoError(t, err) + resp.Body.Close() + require.Equal(t, http.StatusBadRequest, resp.StatusCode) }) t.Run("NoIDToken", func(t *testing.T) { t.Parallel() + fake := oidctest.NewFakeIDP(t, + oidctest.WithRefreshHook(func(_ string) error { + return xerrors.New("refreshing token should never occur") + }), + oidctest.WithServing(), + ) + cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + }) + client := coderdtest.New(t, &coderdtest.Options{ - OIDCConfig: &coderd.OIDCConfig{ - OAuth2Config: &testutil.OAuth2Config{}, - }, + OIDCConfig: cfg, }) - resp := oidcCallback(t, client, "asdf") + _, resp := fake.AttemptLogin(t, client, jwt.MapClaims{}) require.Equal(t, http.StatusBadRequest, resp.StatusCode) }) t.Run("BadVerify", func(t *testing.T) { t.Parallel() - verifier := oidc.NewVerifier("", &oidc.StaticKeySet{ + badVerifier := oidc.NewVerifier("", &oidc.StaticKeySet{ PublicKeys: []crypto.PublicKey{}, }, &oidc.Config{}) - provider := &oidc.Provider{} + badProvider := &oidc.Provider{} + + fake := oidctest.NewFakeIDP(t, + oidctest.WithRefreshHook(func(_ string) error { + return xerrors.New("refreshing token should never occur") + }), + oidctest.WithServing(), + ) + cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + cfg.Provider = badProvider + cfg.Verifier = badVerifier + }) client := coderdtest.New(t, &coderdtest.Options{ - OIDCConfig: &coderd.OIDCConfig{ - OAuth2Config: &testutil.OAuth2Config{ - Token: (&oauth2.Token{ - AccessToken: "token", - }).WithExtra(map[string]interface{}{ - "id_token": "invalid", - }), - }, - Provider: provider, - Verifier: verifier, - }, + OIDCConfig: cfg, }) - resp := oidcCallback(t, client, "asdf") - + _, resp := fake.AttemptLogin(t, client, jwt.MapClaims{}) require.Equal(t, http.StatusBadRequest, resp.StatusCode) }) } @@ -1175,36 +1139,6 @@ func oauth2Callback(t *testing.T, client *codersdk.Client) *http.Response { return res } -func oidcCallback(t *testing.T, client *codersdk.Client, code string) *http.Response { - return oidcCallbackWithState(t, client, code, "somestate", nil) -} - -func oidcCallbackWithState(t *testing.T, client *codersdk.Client, code, state string, modify func(r *http.Request)) *http.Response { - t.Helper() - - client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - } - oauthURL, err := client.URL.Parse(fmt.Sprintf("/api/v2/users/oidc/callback?code=%s&state=%s", code, state)) - require.NoError(t, err) - req, err := http.NewRequestWithContext(context.Background(), "GET", oauthURL.String(), nil) - require.NoError(t, err) - req.AddCookie(&http.Cookie{ - Name: codersdk.OAuth2StateCookie, - Value: state, - }) - if modify != nil { - modify(req) - } - res, err := client.HTTPClient.Do(req) - require.NoError(t, err) - defer res.Body.Close() - data, err := io.ReadAll(res.Body) - require.NoError(t, err) - t.Log(string(data)) - return res -} - func i64ptr(i int64) *int64 { return &i } diff --git a/coderd/users_test.go b/coderd/users_test.go index c36b4fad98afd..d4150365a9ba3 100644 --- a/coderd/users_test.go +++ b/coderd/users_test.go @@ -8,7 +8,10 @@ import ( "testing" "time" - "github.com/golang-jwt/jwt" + "github.com/coder/coder/v2/coderd" + "github.com/coder/coder/v2/coderd/coderdtest/oidctest" + + "github.com/golang-jwt/jwt/v4" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -593,15 +596,15 @@ func TestPostUsers(t *testing.T) { t.Run("CreateOIDCLoginType", func(t *testing.T) { t.Parallel() email := "another@user.org" - conf := coderdtest.NewOIDCConfig(t, "") - config := conf.OIDCConfig(t, jwt.MapClaims{ - "email": email, + fake := oidctest.NewFakeIDP(t, + oidctest.WithServing(), + ) + cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true }) - config.AllowSignups = false - config.IgnoreUserInfo = true client := coderdtest.New(t, &coderdtest.Options{ - OIDCConfig: config, + OIDCConfig: cfg, }) first := coderdtest.CreateFirstUser(t, client) @@ -618,15 +621,9 @@ func TestPostUsers(t *testing.T) { require.NoError(t, err) // Try to log in with OIDC. - userClient := codersdk.New(client.URL) - resp := oidcCallback(t, userClient, conf.EncodeClaims(t, jwt.MapClaims{ + userClient, _ := fake.Login(t, client, jwt.MapClaims{ "email": email, - })) - require.Equal(t, resp.StatusCode, http.StatusTemporaryRedirect) - // Set the client to use this OIDC context - authCookie := authCookieValue(resp.Cookies()) - userClient.SetSessionToken(authCookie) - _ = resp.Body.Close() + }) found, err := userClient.User(ctx, "me") require.NoError(t, err) diff --git a/enterprise/coderd/userauth_test.go b/enterprise/coderd/userauth_test.go index 8360a3ada3e97..07328b2f2c617 100644 --- a/enterprise/coderd/userauth_test.go +++ b/enterprise/coderd/userauth_test.go @@ -49,7 +49,7 @@ func TestUserOIDC(t *testing.T) { // User should be in 0 groups. runner.AssertRoles(t, "alice", []string{}) // Force a refresh, and assert nothing has changes - runner.ForceRefresh(t, client, claims)(t) + runner.ForceRefresh(t, client, claims) runner.AssertRoles(t, "alice", []string{}) }) @@ -86,7 +86,7 @@ func TestUserOIDC(t *testing.T) { runner.ForceRefresh(t, client, jwt.MapClaims{ "email": "alice@coder.com", "roles": []string{"random"}, - })(t) + }) runner.AssertRoles(t, "alice", []string{}) }) @@ -253,7 +253,7 @@ func TestUserOIDC(t *testing.T) { // Refresh without the group claim runner.ForceRefresh(t, client, jwt.MapClaims{ "email": "alice@coder.com", - })(t) + }) runner.AssertGroups(t, "alice", []string{}) }) @@ -355,7 +355,7 @@ func TestUserOIDC(t *testing.T) { // Refresh multiple times. for i := 0; i < 3; i++ { - runner.ForceRefresh(t, client, claims)(t) + runner.ForceRefresh(t, client, claims) } }) }) @@ -582,7 +582,7 @@ type oidcTestRunner struct { // expires the oauth token so the next authenticated API call will // trigger a refresh. The returned function is an example of said call. // It just calls the /users/me endpoint to trigger the refresh. - ForceRefresh func(t *testing.T, client *codersdk.Client, idToken jwt.MapClaims) func(t *testing.T) + ForceRefresh func(t *testing.T, client *codersdk.Client, idToken jwt.MapClaims) } type oidcTestConfig struct { @@ -668,8 +668,8 @@ func setupOIDCTest(t *testing.T, settings oidcTestConfig) *oidcTestRunner { AdminUser: admin, API: api, Login: helper.Login, - ForceRefresh: func(t *testing.T, client *codersdk.Client, idToken jwt.MapClaims) func(t *testing.T) { - return helper.ForceRefresh(t, api.Database, client, idToken) + ForceRefresh: func(t *testing.T, client *codersdk.Client, idToken jwt.MapClaims) { + helper.ForceRefresh(t, api.Database, client, idToken) }, } } From a1b716b6aecbbefb757f18d93fcaf4afac3a94f8 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 24 Aug 2023 15:15:13 -0500 Subject: [PATCH 06/18] comments --- coderd/coderdtest/oidctest/helper.go | 7 +++- coderd/coderdtest/oidctest/idp.go | 53 +++++++++++++++++++------- coderd/coderdtest/oidctest/idp_test.go | 1 - coderd/coderdtest/oidctest/map.go | 1 + enterprise/coderd/userauth_test.go | 7 +--- 5 files changed, 48 insertions(+), 21 deletions(-) diff --git a/coderd/coderdtest/oidctest/helper.go b/coderd/coderdtest/oidctest/helper.go index dc75b0efd5f6e..40f643181f9fc 100644 --- a/coderd/coderdtest/oidctest/helper.go +++ b/coderd/coderdtest/oidctest/helper.go @@ -46,6 +46,7 @@ func (h *LoginHelper) Login(t *testing.T, idTokenClaims jwt.MapClaims) (*codersd return h.fake.Login(t, unauthenticatedClient, idTokenClaims) } +// ExpireOauthToken expires the oauth token for the given user. func (h *LoginHelper) ExpireOauthToken(t *testing.T, db database.Store, user *codersdk.Client) (refreshToken string) { t.Helper() @@ -79,7 +80,11 @@ func (h *LoginHelper) ExpireOauthToken(t *testing.T, db database.Store, user *co return link.OAuthRefreshToken } -// ForceRefresh forces the client to refresh its oauth token. +// ForceRefresh forces the client to refresh its oauth token. It does this by +// expiring the oauth token, then doing an authenticated call. This will force +// the API Key middleware to refresh the oauth token. +// +// A unit test assertion makes sure the refresh token is used. func (h *LoginHelper) ForceRefresh(t *testing.T, db database.Store, user *codersdk.Client, idToken jwt.MapClaims) { t.Helper() diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index 19fe771704a9f..048228c94a4b1 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -18,8 +18,6 @@ import ( "testing" "time" - "github.com/coder/coder/v2/codersdk" - "github.com/coreos/go-oidc/v3/oidc" "github.com/go-chi/chi/v5" "github.com/go-jose/go-jose/v3" @@ -33,8 +31,11 @@ import ( "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/coderd" + "github.com/coder/coder/v2/codersdk" ) +// FakeIDP is a functional OIDC provider. +// It only supports 1 OIDC client. type FakeIDP struct { issuer string key *rsa.PrivateKey @@ -47,6 +48,8 @@ type FakeIDP struct { clientSecret string logger slog.Logger + // These maps are used to control the state of the IDP. + // That is the various access tokens, refresh tokens, states, etc. codeToStateMap *SyncMap[string, string] // Token -> Email accessTokens *SyncMap[string, string] @@ -68,18 +71,23 @@ type FakeIDP struct { type FakeIDPOpt func(idp *FakeIDP) +// WithRefreshHook is called when a refresh token is used. The email is +// the email of the user that is being refreshed assuming the claims are correct. func WithRefreshHook(hook func(email string) error) func(*FakeIDP) { return func(f *FakeIDP) { f.hookOnRefresh = hook } } +// WithLogging is optional, but will log some HTTP calls made to the IDP. func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) { return func(f *FakeIDP) { f.logger = slogtest.Make(t, options) } } +// WithStaticUserInfo is optional, but will return the same user info for +// every user on the /userinfo endpoint. func WithStaticUserInfo(info jwt.MapClaims) func(*FakeIDP) { return func(f *FakeIDP) { f.hookUserInfo = func(_ string) jwt.MapClaims { @@ -94,6 +102,7 @@ func WithDynamicUserInfo(userInfoFunc func(email string) jwt.MapClaims) func(*Fa } } +// WithServing makes the IDP run an actual http server. func WithServing() func(*FakeIDP) { return func(f *FakeIDP) { f.serve = true @@ -218,6 +227,8 @@ func (f *FakeIDP) AttemptLogin(t testing.TB, client *codersdk.Client, idTokenCla // LoginClient reuses the context of the passed in client. This means the same // cookies will be used. This should be an unauthenticated client in most cases. +// +// This is a niche case, but it is needed for testing ConvertLoginType. func (f *FakeIDP) LoginClient(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims, opts ...func(r *http.Request)) (*codersdk.Client, *http.Response) { t.Helper() @@ -268,7 +279,10 @@ func (f *FakeIDP) LoginClient(t testing.TB, client *codersdk.Client, idTokenClai } // OIDCCallback will emulate the IDP redirecting back to the Coder callback. -// This is helpful if no Coderd exists. +// This is helpful if no Coderd exists because the IDP needs to redirect to +// something. +// Essentially this is used to fake the Coderd side of the exchange. +// The flow starts at the user hitting the OIDC login page. func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.MapClaims) (*http.Response, error) { t.Helper() f.stateToIDTokenClaims.Store(state, idTokenClaims) @@ -320,6 +334,7 @@ func (f *FakeIDP) newRefreshTokens(email string) string { return refreshToken } +// authenticateBearerTokenRequest enforces the access token is valid. func (f *FakeIDP) authenticateBearerTokenRequest(t testing.TB, req *http.Request) (string, error) { t.Helper() @@ -332,6 +347,7 @@ func (f *FakeIDP) authenticateBearerTokenRequest(t testing.TB, req *http.Request return token, nil } +// authenticateOIDClientRequest enforces the client_id and client_secret are valid. func (f *FakeIDP) authenticateOIDClientRequest(t testing.TB, req *http.Request) (url.Values, error) { t.Helper() @@ -353,6 +369,7 @@ func (f *FakeIDP) authenticateOIDClientRequest(t testing.TB, req *http.Request) return values, nil } +// encodeClaims is a helper func to convert claims to a valid JWT. func (f *FakeIDP) encodeClaims(t testing.TB, claims jwt.MapClaims) string { t.Helper() @@ -374,6 +391,7 @@ func (f *FakeIDP) encodeClaims(t testing.TB, claims jwt.MapClaims) string { return signed } +// httpHandler is the IDP http server. func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { t.Helper() @@ -572,12 +590,12 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { return mux } -// HTTPClient runs the IDP in memory and returns an http.Client that can be used -// to make requests to the IDP. All requests are handled in memory, and no network -// requests are made. +// HTTPClient does nothing if IsServing is used. // -// If a request is not to the IDP, then the passed in client will be used. -// If no client is passed in, then any regular network requests will fail. +// If IsServing is not used, then it will return a client that will make requests +// to the IDP all in memory. If a request is not to the IDP, then the passed in +// client will be used. If no client is passed in, then any regular network +// requests will fail. func (f *FakeIDP) HTTPClient(rest *http.Client) *http.Client { if f.serve { if rest == nil || rest.Transport == nil { @@ -619,31 +637,40 @@ func (f *FakeIDP) RefreshUsed(refreshToken string) bool { return used } +// UpdateRefreshClaims allows the caller to change what claims are returned +// for a given refresh token. By default, all refreshes use the same claims as +// the original IDToken issuance. func (f *FakeIDP) UpdateRefreshClaims(refreshToken string, claims jwt.MapClaims) { f.refreshIDTokenClaims.Store(refreshToken, claims) } +// SetRedirect is required for the IDP to know where to redirect and call +// Coderd. func (f *FakeIDP) SetRedirect(t testing.TB, url string) { t.Helper() f.cfg.RedirectURL = url } +// SetCoderdCallback is optional and only works if not using the IsServing. +// It will setup a fake "Coderd" for the IDP to call when the IDP redirects +// back after authenticating. func (f *FakeIDP) SetCoderdCallback(callback func(req *http.Request) (*http.Response, error)) { + if f.serve { + panic("cannot set callback handler when using 'WithServing'. Must implement an actual 'Coderd'") + } f.fakeCoderd = callback } func (f *FakeIDP) SetCoderdCallbackHandler(handler http.HandlerFunc) { - if f.serve { - panic("cannot set callback handler when using 'WithServing'. Must implement an actual 'Coderd'") - } - f.fakeCoderd = func(req *http.Request) (*http.Response, error) { + f.SetCoderdCallback(func(req *http.Request) (*http.Response, error) { resp := httptest.NewRecorder() handler.ServeHTTP(resp, req) return resp.Result(), nil - } + }) } +// OIDCConfig returns the OIDC config to use for Coderd. func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig { t.Helper() if len(scopes) == 0 { diff --git a/coderd/coderdtest/oidctest/idp_test.go b/coderd/coderdtest/oidctest/idp_test.go index a8098c874787f..e0663f9f9aa0b 100644 --- a/coderd/coderdtest/oidctest/idp_test.go +++ b/coderd/coderdtest/oidctest/idp_test.go @@ -8,7 +8,6 @@ import ( "time" "github.com/golang-jwt/jwt/v4" - "github.com/stretchr/testify/assert" "github.com/coder/coder/v2/coderd/coderdtest/oidctest" diff --git a/coderd/coderdtest/oidctest/map.go b/coderd/coderdtest/oidctest/map.go index 864e4ae926ee1..c978c24f99773 100644 --- a/coderd/coderdtest/oidctest/map.go +++ b/coderd/coderdtest/oidctest/map.go @@ -2,6 +2,7 @@ package oidctest import "sync" +// SyncMap is a type safe sync.Map type SyncMap[K, V any] struct { m sync.Map } diff --git a/enterprise/coderd/userauth_test.go b/enterprise/coderd/userauth_test.go index 07328b2f2c617..45d91c9f45701 100644 --- a/enterprise/coderd/userauth_test.go +++ b/enterprise/coderd/userauth_test.go @@ -572,15 +572,10 @@ type oidcTestRunner struct { API *coderden.API // Login will call the OIDC flow with an unauthenticated client. - // The customer actions will all be taken care of, and the idToken claims - // will be returned. + // The IDP will return the idToken claims. Login func(t *testing.T, idToken jwt.MapClaims) (*codersdk.Client, *http.Response) // ForceRefresh will use an authenticated codersdk.Client, and force their // OIDC token to be expired and require a refresh. The refresh will use the claims provided. - // - // The client MUST be used to actually trigger the refresh. This just - // expires the oauth token so the next authenticated API call will - // trigger a refresh. The returned function is an example of said call. // It just calls the /users/me endpoint to trigger the refresh. ForceRefresh func(t *testing.T, client *codersdk.Client, idToken jwt.MapClaims) } From dffeb16bb4ef240181d9ad3ddcfd14b37df30b71 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 24 Aug 2023 15:17:22 -0500 Subject: [PATCH 07/18] Fix import --- coderd/coderdtest/oidctest/idp.go | 4 +--- coderd/userauth_test.go | 14 ++------------ 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index 048228c94a4b1..bc5947452064c 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -355,7 +355,6 @@ func (f *FakeIDP) authenticateOIDClientRequest(t testing.TB, req *http.Request) values, err := url.ParseQuery(string(data)) if !assert.NoError(t, err, "parse token request values") { return nil, xerrors.New("invalid token request") - } if !assert.Equal(t, f.clientID, values.Get("client_id"), "client_id mismatch") { @@ -419,7 +418,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { state := r.URL.Query().Get("state") scope := r.URL.Query().Get("scope") - var _ = scope + _ = scope responseType := r.URL.Query().Get("response_type") switch responseType { @@ -557,7 +556,6 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { http.Error(rw, fmt.Sprintf("invalid user info request: %s", err.Error()), http.StatusBadRequest) return } - var _ = token email, ok := f.accessTokens.Load(token) if !ok { diff --git a/coderd/userauth_test.go b/coderd/userauth_test.go index 4fd17abc35633..fef8aa3af233a 100644 --- a/coderd/userauth_test.go +++ b/coderd/userauth_test.go @@ -26,8 +26,6 @@ import ( "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" - "github.com/coder/coder/v2/enterprise/coderd/license" "github.com/coder/coder/v2/testutil" ) @@ -49,16 +47,8 @@ func TestOIDCOauthLoginWithExisting(t *testing.T) { cfg.IgnoreUserInfo = true }) - client, _, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ - Options: &coderdtest.Options{ - OIDCConfig: cfg, - }, - LicenseOptions: &coderdenttest.LicenseOptions{ - Features: license.Features{ - codersdk.FeatureUserRoleManagement: 1, - codersdk.FeatureTemplateRBAC: 1, - }, - }, + client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{ + OIDCConfig: cfg, }) const username = "alice" From 3be7d312b4f0cf006d9bb1e8f05ecb4dfa673a01 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 24 Aug 2023 15:24:50 -0500 Subject: [PATCH 08/18] Fix comments --- coderd/coderdtest/oidctest/helper.go | 6 +++--- coderd/coderdtest/oidctest/idp.go | 16 ++++++++++------ coderd/userauth_test.go | 2 +- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/coderd/coderdtest/oidctest/helper.go b/coderd/coderdtest/oidctest/helper.go index 40f643181f9fc..f236ae8117ea4 100644 --- a/coderd/coderdtest/oidctest/helper.go +++ b/coderd/coderdtest/oidctest/helper.go @@ -67,13 +67,13 @@ func (h *LoginHelper) ExpireOauthToken(t *testing.T, db database.Store, user *co }) require.NoError(t, err, "get user link") - // Fetch the oauth link for the given user. + // Expire the oauth link for the given user. _, err = db.UpdateUserLink(ctx, database.UpdateUserLinkParams{ OAuthAccessToken: link.OAuthAccessToken, OAuthRefreshToken: link.OAuthRefreshToken, OAuthExpiry: time.Now().Add(time.Hour * -1), - UserID: key.UserID, - LoginType: database.LoginTypeOIDC, + UserID: link.UserID, + LoginType: link.LoginType, }) require.NoError(t, err, "expire user link") diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index bc5947452064c..ef32096057605 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -149,7 +149,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { idp.updateIssuerURL(t, idp.issuer) if idp.serve { - idp.Serve(t) + idp.realServer(t) } return idp @@ -176,8 +176,8 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) { } } -// Serve is optional, but turns the FakeIDP into a real http server. -func (f *FakeIDP) Serve(t testing.TB) *httptest.Server { +// realServer turns the FakeIDP into a real http server. +func (f *FakeIDP) realServer(t testing.TB) *httptest.Server { t.Helper() ctx, cancel := context.WithCancel(context.Background()) @@ -222,14 +222,14 @@ func (f *FakeIDP) AttemptLogin(t testing.TB, client *codersdk.Client, idTokenCla unauthenticated := codersdk.New(client.URL) unauthenticated.HTTPClient = shallowCpyCli - return f.LoginClient(t, unauthenticated, idTokenClaims, opts...) + return f.LoginWithClient(t, unauthenticated, idTokenClaims, opts...) } -// LoginClient reuses the context of the passed in client. This means the same +// LoginWithClient reuses the context of the passed in client. This means the same // cookies will be used. This should be an unauthenticated client in most cases. // // This is a niche case, but it is needed for testing ConvertLoginType. -func (f *FakeIDP) LoginClient(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims, opts ...func(r *http.Request)) (*codersdk.Client, *http.Response) { +func (f *FakeIDP) LoginWithClient(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims, opts ...func(r *http.Request)) (*codersdk.Client, *http.Response) { t.Helper() coderOauthURL, err := client.URL.Parse("/api/v2/users/oidc/callback") @@ -285,6 +285,10 @@ func (f *FakeIDP) LoginClient(t testing.TB, client *codersdk.Client, idTokenClai // The flow starts at the user hitting the OIDC login page. func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.MapClaims) (*http.Response, error) { t.Helper() + if f.serve { + panic("cannot use OIDCCallback with WithServing. This is only for the in memory usage") + } + f.stateToIDTokenClaims.Store(state, idTokenClaims) cli := f.HTTPClient(nil) diff --git a/coderd/userauth_test.go b/coderd/userauth_test.go index fef8aa3af233a..96b2de95c935f 100644 --- a/coderd/userauth_test.go +++ b/coderd/userauth_test.go @@ -881,7 +881,7 @@ func TestUserOIDC(t *testing.T) { }) require.NoError(t, err) - fake.LoginClient(t, user, claims, func(r *http.Request) { + fake.LoginWithClient(t, user, claims, func(r *http.Request) { r.URL.RawQuery = url.Values{ "oidc_merge_state": {convertResponse.StateString}, }.Encode() From 4b31e28eff230ebf75e26a5b77c2c1acbb30af37 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 24 Aug 2023 16:06:12 -0500 Subject: [PATCH 09/18] Fix issue with expirary json --- coderd/coderdtest/oidctest/idp.go | 22 ++++++++++-- coderd/oauthpki/oidcpki.go | 5 ++- coderd/oauthpki/okidcpki_test.go | 58 +++++++++++++++++++++++++++++-- 3 files changed, 80 insertions(+), 5 deletions(-) diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index ef32096057605..f0c3f09c75384 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -63,6 +63,9 @@ type FakeIDP struct { hookUserInfo func(email string) jwt.MapClaims fakeCoderd func(req *http.Request) (*http.Response, error) hookOnRefresh func(email string) error + // Custom authentication for the client. This is useful if you want + // to test something like PKI auth vs a client_secret. + hookAuthenticateClient func(t testing.TB, req *http.Request) (url.Values, error) // Optional if you want to use a real http network request assuming // it is not directed to the IDP. defaultClient *http.Client @@ -79,6 +82,12 @@ func WithRefreshHook(hook func(email string) error) func(*FakeIDP) { } } +func WithCustomClientAuth(hook func(t testing.TB, req *http.Request) (url.Values, error)) func(*FakeIDP) { + return func(f *FakeIDP) { + f.hookAuthenticateClient = hook + } +} + // WithLogging is optional, but will log some HTTP calls made to the IDP. func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) { return func(f *FakeIDP) { @@ -109,6 +118,12 @@ func WithServing() func(*FakeIDP) { } } +func WithIssuer(issuer string) func(*FakeIDP) { + return func(f *FakeIDP) { + f.issuer = issuer + } +} + const ( authorizePath = "/oauth2/authorize" tokenPath = "/oauth2/token" @@ -137,7 +152,6 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { hookOnRefresh: func(_ string) error { return nil }, hookUserInfo: func(email string) jwt.MapClaims { return jwt.MapClaims{} }, } - idp.handler = idp.httpHandler(t) for _, opt := range opts { opt(idp) @@ -147,6 +161,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { idp.issuer = "https://coder.com" } + idp.handler = idp.httpHandler(t) idp.updateIssuerURL(t, idp.issuer) if idp.serve { idp.realServer(t) @@ -355,6 +370,10 @@ func (f *FakeIDP) authenticateBearerTokenRequest(t testing.TB, req *http.Request func (f *FakeIDP) authenticateOIDClientRequest(t testing.TB, req *http.Request) (url.Values, error) { t.Helper() + if f.hookAuthenticateClient != nil { + return f.hookAuthenticateClient(t, req) + } + data, _ := io.ReadAll(req.Body) values, err := url.ParseQuery(string(data)) if !assert.NoError(t, err, "parse token request values") { @@ -541,7 +560,6 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { "refresh_token": refreshToken, "token_type": "Bearer", "expires_in": int64(time.Minute * 5), - "expiry": exp.Unix(), "id_token": f.encodeClaims(t, claims), } // Store the claims for the next refresh diff --git a/coderd/oauthpki/oidcpki.go b/coderd/oauthpki/oidcpki.go index d5bc625336ab7..c44d130e5be9f 100644 --- a/coderd/oauthpki/oidcpki.go +++ b/coderd/oauthpki/oidcpki.go @@ -215,7 +215,10 @@ func (src *jwtTokenSource) Token() (*oauth2.Token, error) { } var tokenRes struct { - oauth2.Token + AccessToken string `json:"access_token"` + TokenType string `json:"token_type,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + // Extra fields returned by the refresh that are needed IDToken string `json:"id_token"` ExpiresIn int64 `json:"expires_in"` // relative seconds from now diff --git a/coderd/oauthpki/okidcpki_test.go b/coderd/oauthpki/okidcpki_test.go index 27593607f2a16..682d8e06ac7fd 100644 --- a/coderd/oauthpki/okidcpki_test.go +++ b/coderd/oauthpki/okidcpki_test.go @@ -12,12 +12,15 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" - "github.com/golang-jwt/jwt" + "github.com/golang-jwt/jwt/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" "golang.org/x/xerrors" + "github.com/coder/coder/v2/coderd" + "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/coderd/coderdtest/oidctest" "github.com/coder/coder/v2/coderd/oauthpki" "github.com/coder/coder/v2/testutil" ) @@ -123,6 +126,57 @@ func TestAzureADPKIOIDC(t *testing.T) { require.Error(t, err, "error expected") } +// TestAzureAKPKIWithCoderd uses a fake IDP and a real Coderd to test PKI auth. +func TestAzureAKPKIWithCoderd(t *testing.T) { + t.Parallel() + + scopes := []string{"openid", "email", "profile", "offline_access"} + fake := oidctest.NewFakeIDP(t, + oidctest.WithIssuer("https://login.microsoftonline.com/fake_app"), + oidctest.WithCustomClientAuth(func(t testing.TB, req *http.Request) (url.Values, error) { + values := assertJWTAuth(t, req) + if values == nil { + return nil, xerrors.New("authorizatin failed in request") + } + return values, nil + }), + oidctest.WithServing(), + ) + cfg := fake.OIDCConfig(t, scopes, func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + }) + + oauthCfg := cfg.OAuth2Config.(*oauth2.Config) + // Create the oauthpki config + pki, err := oauthpki.NewOauth2PKIConfig(oauthpki.ConfigParams{ + ClientID: oauthCfg.ClientID, + TokenURL: oauthCfg.Endpoint.TokenURL, + Scopes: scopes, + PemEncodedKey: []byte(testClientKey), + PemEncodedCert: []byte(testClientCert), + Config: oauthCfg, + }) + require.NoError(t, err) + cfg.OAuth2Config = pki + + owner, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{ + OIDCConfig: cfg, + }) + + // Create a user and login + const email = "alice@coder.com" + claims := jwt.MapClaims{ + "email": email, + } + helper := oidctest.NewLoginHelper(owner, fake) + user, _ := helper.Login(t, claims) + + // Try refreshing the token more than once. + for i := 0; i < 2; i++ { + helper.ForceRefresh(t, api.Database, user, claims) + } +} + // TestSavedAzureADPKIOIDC was created by capturing actual responses from an Azure // AD instance and saving them to replay, removing some details. // The reason this is done is that this is the only way to assert values @@ -269,7 +323,7 @@ func (f fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { // assertJWTAuth will assert the basic JWT auth assertions. It will return the // url.Values from the request body for any additional assertions to be made. -func assertJWTAuth(t *testing.T, r *http.Request) url.Values { +func assertJWTAuth(t testing.TB, r *http.Request) url.Values { body, err := io.ReadAll(r.Body) if !assert.NoError(t, err) { return nil From 722e36ff8a03af06e8d7342cde23ef56b6c03716 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 24 Aug 2023 16:07:06 -0500 Subject: [PATCH 10/18] fixup! Fix issue with expirary json --- coderd/coderdtest/oidctest/idp.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index f0c3f09c75384..f5364fb5f4832 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -559,7 +559,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { "access_token": f.newToken(email), "refresh_token": refreshToken, "token_type": "Bearer", - "expires_in": int64(time.Minute * 5), + "expires_in": int64((time.Minute * 5).Seconds()), "id_token": f.encodeClaims(t, claims), } // Store the claims for the next refresh From 00b47606e18c6117555e9217b13312639aa562ed Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 24 Aug 2023 16:53:03 -0500 Subject: [PATCH 11/18] Linting --- coderd/coderdtest/oidctest/helper.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coderd/coderdtest/oidctest/helper.go b/coderd/coderdtest/oidctest/helper.go index f236ae8117ea4..e0d0910428d7e 100644 --- a/coderd/coderdtest/oidctest/helper.go +++ b/coderd/coderdtest/oidctest/helper.go @@ -47,7 +47,7 @@ func (h *LoginHelper) Login(t *testing.T, idTokenClaims jwt.MapClaims) (*codersd } // ExpireOauthToken expires the oauth token for the given user. -func (h *LoginHelper) ExpireOauthToken(t *testing.T, db database.Store, user *codersdk.Client) (refreshToken string) { +func (*LoginHelper) ExpireOauthToken(t *testing.T, db database.Store, user *codersdk.Client) (refreshToken string) { t.Helper() //nolint:gocritic // Testing From 1c7e8b47d2f4f990e19a1baf30ffb5637400fd59 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 25 Aug 2023 10:02:43 -0500 Subject: [PATCH 12/18] Pr feedback, add authorized redirect urls --- coderd/coderdtest/oidctest/helper.go | 28 ++++++++++---------- coderd/coderdtest/oidctest/idp.go | 38 +++++++++++++++++++++------- 2 files changed, 43 insertions(+), 23 deletions(-) diff --git a/coderd/coderdtest/oidctest/helper.go b/coderd/coderdtest/oidctest/helper.go index e0d0910428d7e..11d9114be2ce8 100644 --- a/coderd/coderdtest/oidctest/helper.go +++ b/coderd/coderdtest/oidctest/helper.go @@ -19,20 +19,20 @@ import ( // It is mainly because refreshing oauth tokens is a bit tricky and requires // some database manipulation. type LoginHelper struct { - fake *FakeIDP - owner *codersdk.Client + fake *FakeIDP + client *codersdk.Client } -func NewLoginHelper(owner *codersdk.Client, fake *FakeIDP) *LoginHelper { - if owner == nil { - panic("owner must not be nil") +func NewLoginHelper(client *codersdk.Client, fake *FakeIDP) *LoginHelper { + if client == nil { + panic("client must not be nil") } if fake == nil { panic("fake must not be nil") } return &LoginHelper{ - fake: fake, - owner: owner, + fake: fake, + client: client, } } @@ -41,13 +41,13 @@ func NewLoginHelper(owner *codersdk.Client, fake *FakeIDP) *LoginHelper { // convenience method. func (h *LoginHelper) Login(t *testing.T, idTokenClaims jwt.MapClaims) (*codersdk.Client, *http.Response) { t.Helper() - unauthenticatedClient := codersdk.New(h.owner.URL) + unauthenticatedClient := codersdk.New(h.client.URL) return h.fake.Login(t, unauthenticatedClient, idTokenClaims) } // ExpireOauthToken expires the oauth token for the given user. -func (*LoginHelper) ExpireOauthToken(t *testing.T, db database.Store, user *codersdk.Client) (refreshToken string) { +func (*LoginHelper) ExpireOauthToken(t *testing.T, db database.Store, user *codersdk.Client) database.UserLink { t.Helper() //nolint:gocritic // Testing @@ -68,7 +68,7 @@ func (*LoginHelper) ExpireOauthToken(t *testing.T, db database.Store, user *code require.NoError(t, err, "get user link") // Expire the oauth link for the given user. - _, err = db.UpdateUserLink(ctx, database.UpdateUserLinkParams{ + updated, err := db.UpdateUserLink(ctx, database.UpdateUserLinkParams{ OAuthAccessToken: link.OAuthAccessToken, OAuthRefreshToken: link.OAuthRefreshToken, OAuthExpiry: time.Now().Add(time.Hour * -1), @@ -77,7 +77,7 @@ func (*LoginHelper) ExpireOauthToken(t *testing.T, db database.Store, user *code }) require.NoError(t, err, "expire user link") - return link.OAuthRefreshToken + return updated } // ForceRefresh forces the client to refresh its oauth token. It does this by @@ -88,13 +88,13 @@ func (*LoginHelper) ExpireOauthToken(t *testing.T, db database.Store, user *code func (h *LoginHelper) ForceRefresh(t *testing.T, db database.Store, user *codersdk.Client, idToken jwt.MapClaims) { t.Helper() - refreshToken := h.ExpireOauthToken(t, db, user) + link := h.ExpireOauthToken(t, db, user) // Updates the claims that the IDP will return. By default, it always // uses the original claims for the original oauth token. - h.fake.UpdateRefreshClaims(refreshToken, idToken) + h.fake.UpdateRefreshClaims(link.OAuthRefreshToken, idToken) t.Cleanup(func() { - require.True(t, h.fake.RefreshUsed(refreshToken), "refresh token must be used, but has not. Did you forget to call the returned function from this call?") + require.True(t, h.fake.RefreshUsed(link.OAuthRefreshToken), "refresh token must be used, but has not. Did you forget to call the returned function from this call?") }) // Do any authenticated call to force the refresh diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index f5364fb5f4832..bc5661555c3ee 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -60,9 +60,13 @@ type FakeIDP struct { refreshIDTokenClaims *SyncMap[string, jwt.MapClaims] // hooks - hookUserInfo func(email string) jwt.MapClaims - fakeCoderd func(req *http.Request) (*http.Response, error) - hookOnRefresh func(email string) error + // hookValidRedirectURL can be used to reject a redirect url from the + // IDP -> Application. Almost all IDPs have the concept of + // "Authorized Redirect URLs". This can be used to emulate that. + hookValidRedirectURL func(redirectURL string) error + hookUserInfo func(email string) jwt.MapClaims + fakeCoderd func(req *http.Request) (*http.Response, error) + hookOnRefresh func(email string) error // Custom authentication for the client. This is useful if you want // to test something like PKI auth vs a client_secret. hookAuthenticateClient func(t testing.TB, req *http.Request) (url.Values, error) @@ -74,6 +78,12 @@ type FakeIDP struct { type FakeIDPOpt func(idp *FakeIDP) +func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeIDP) { + return func(f *FakeIDP) { + f.hookValidRedirectURL = hook + } +} + // WithRefreshHook is called when a refresh token is used. The email is // the email of the user that is being refreshed assuming the claims are correct. func WithRefreshHook(hook func(email string) error) func(*FakeIDP) { @@ -421,6 +431,8 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { // This endpoint is required to initialize the OIDC provider. // It is used to get the OIDC configuration. mux.Get("/.well-known/openid-configuration", func(rw http.ResponseWriter, r *http.Request) { + f.logger.Info(r.Context(), "HTTP OIDC Config", slog.F("url", r.URL.String())) + _ = json.NewEncoder(rw).Encode(f.provider) }) @@ -429,19 +441,19 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { // w/e and clicking "Allow". They will be redirected back to the redirect // when this is done. mux.Handle(authorizePath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - f.logger.Info(r.Context(), "HTTP Call Authorize", slog.F("url", string(r.URL.String()))) + f.logger.Info(r.Context(), "HTTP Call Authorize", slog.F("url", r.URL.String())) clientID := r.URL.Query().Get("client_id") - if clientID != f.clientID { - t.Errorf("unexpected client_id %q", clientID) + if !assert.Equal(t, f.clientID, clientID, "unexpected client_id") { http.Error(rw, "invalid client_id", http.StatusBadRequest) + return } redirectURI := r.URL.Query().Get("redirect_uri") state := r.URL.Query().Get("state") scope := r.URL.Query().Get("scope") - _ = scope + assert.NotEmpty(t, scope, "scope is empty") responseType := r.URL.Query().Get("response_type") switch responseType { @@ -456,10 +468,17 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { return } + err := f.hookValidRedirectURL(redirectURI) + if err != nil { + t.Errorf("not authorized redirect_uri by custom hook %q: %s", redirectURI, err.Error()) + http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), http.StatusBadRequest) + return + } + ru, err := url.Parse(redirectURI) if err != nil { - t.Errorf("invalid redirect_uri %q", redirectURI) - http.Error(rw, "invalid redirect_uri", http.StatusBadRequest) + t.Errorf("invalid redirect_uri %q: %s", redirectURI, err.Error()) + http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), http.StatusBadRequest) return } @@ -573,6 +592,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { token, err := f.authenticateBearerTokenRequest(t, r) f.logger.Info(r.Context(), "HTTP Call UserInfo", slog.Error(err), + slog.F("url", r.URL.String()), ) if err != nil { http.Error(rw, fmt.Sprintf("invalid user info request: %s", err.Error()), http.StatusBadRequest) From 930e7dd3797120b733d5bd985eaad6834cedb23b Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 25 Aug 2023 10:17:29 -0500 Subject: [PATCH 13/18] PR Feedback --- coderd/coderdtest/oidctest/idp.go | 11 +++++++---- enterprise/coderd/userauth_test.go | 4 ++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index bc5661555c3ee..6d0b6716f4571 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -376,15 +376,18 @@ func (f *FakeIDP) authenticateBearerTokenRequest(t testing.TB, req *http.Request return token, nil } -// authenticateOIDClientRequest enforces the client_id and client_secret are valid. -func (f *FakeIDP) authenticateOIDClientRequest(t testing.TB, req *http.Request) (url.Values, error) { +// authenticateOIDCClientRequest enforces the client_id and client_secret are valid. +func (f *FakeIDP) authenticateOIDCClientRequest(t testing.TB, req *http.Request) (url.Values, error) { t.Helper() if f.hookAuthenticateClient != nil { return f.hookAuthenticateClient(t, req) } - data, _ := io.ReadAll(req.Body) + data, err := io.ReadAll(req.Body) + if !assert.NoError(t, err, "read token request body") { + return nil, xerrors.Errorf("authenticate request, read body: %w", err) + } values, err := url.ParseQuery(string(data)) if !assert.NoError(t, err, "parse token request values") { return nil, xerrors.New("invalid token request") @@ -491,7 +494,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { })) mux.Handle(tokenPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - values, err := f.authenticateOIDClientRequest(t, r) + values, err := f.authenticateOIDCClientRequest(t, r) f.logger.Info(r.Context(), "HTTP Call Token", slog.Error(err), slog.F("values", values.Encode()), diff --git a/enterprise/coderd/userauth_test.go b/enterprise/coderd/userauth_test.go index 45d91c9f45701..976a9a459cc4d 100644 --- a/enterprise/coderd/userauth_test.go +++ b/enterprise/coderd/userauth_test.go @@ -57,7 +57,7 @@ func TestUserOIDC(t *testing.T) { // roles from an updated claim. t.Run("NewUserAndRemoveRolesOnRefresh", func(t *testing.T) { // TODO: Implement new feature to update roles/groups on OIDC - // refresh tokens. + // refresh tokens. https://github.com/coder/coder/issues/9312 t.Skip("Refreshing tokens does not update roles :(") t.Parallel() @@ -224,7 +224,7 @@ func TestUserOIDC(t *testing.T) { t.Parallel() // TODO: Implement new feature to update roles/groups on OIDC - // refresh tokens. + // refresh tokens. https://github.com/coder/coder/issues/9312 t.Skip("Refreshing tokens does not update groups :(") const groupClaim = "custom-groups" From c61ceb9e29d3446afe679f4a006a159d3c7d94f1 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 25 Aug 2023 10:27:20 -0500 Subject: [PATCH 14/18] Linting --- coderd/coderdtest/coderdtest.go | 16 ------------- coderd/coderdtest/oidctest/idp.go | 31 +++++++++++++------------- coderd/coderdtest/oidctest/idp_test.go | 3 +++ coderd/coderdtest/oidctest/map.go | 13 +++++++---- coderd/userauth_test.go | 3 ++- coderd/users_test.go | 1 + enterprise/coderd/userauth_test.go | 1 - 7 files changed, 30 insertions(+), 38 deletions(-) diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 03b74b38e289d..e0a681e613d9e 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -1108,22 +1108,6 @@ func SDKError(t *testing.T, err error) *codersdk.Error { return cerr } -const testRSAPrivateKey = `-----BEGIN RSA PRIVATE KEY----- -MIICXQIBAAKBgQDLets8+7M+iAQAqN/5BVyCIjhTQ4cmXulL+gm3v0oGMWzLupUS -v8KPA+Tp7dgC/DZPfMLaNH1obBBhJ9DhS6RdS3AS3kzeFrdu8zFHLWF53DUBhS92 -5dCAEuJpDnNizdEhxTfoHrhuCmz8l2nt1pe5eUK2XWgd08Uc93h5ij098wIDAQAB -AoGAHLaZeWGLSaen6O/rqxg2laZ+jEFbMO7zvOTruiIkL/uJfrY1kw+8RLIn+1q0 -wLcWcuEIHgKKL9IP/aXAtAoYh1FBvRPLkovF1NZB0Je/+CSGka6wvc3TGdvppZJe -rKNcUvuOYLxkmLy4g9zuY5qrxFyhtIn2qZzXEtLaVOHzPQECQQDvN0mSajpU7dTB -w4jwx7IRXGSSx65c+AsHSc1Rj++9qtPC6WsFgAfFN2CEmqhMbEUVGPv/aPjdyWk9 -pyLE9xR/AkEA2cGwyIunijE5v2rlZAD7C4vRgdcMyCf3uuPcgzFtsR6ZhyQSgLZ8 -YRPuvwm4cdPJMmO3YwBfxT6XGuSc2k8MjQJBAI0+b8prvpV2+DCQa8L/pjxp+VhR -Xrq2GozrHrgR7NRokTB88hwFRJFF6U9iogy9wOx8HA7qxEbwLZuhm/4AhbECQC2a -d8h4Ht09E+f3nhTEc87mODkl7WJZpHL6V2sORfeq/eIkds+H6CJ4hy5w/bSw8tjf -sz9Di8sGIaUbLZI2rd0CQQCzlVwEtRtoNCyMJTTrkgUuNufLP19RZ5FpyXxBO5/u -QastnN77KfUwdj3SJt44U/uh1jAIv4oSLBr8HYUkbnI8 ------END RSA PRIVATE KEY-----` - func DeploymentValues(t testing.TB) *codersdk.DeploymentValues { var cfg codersdk.DeploymentValues opts := cfg.Options() diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index 6d0b6716f4571..3d65c38b771be 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -70,10 +70,7 @@ type FakeIDP struct { // Custom authentication for the client. This is useful if you want // to test something like PKI auth vs a client_secret. hookAuthenticateClient func(t testing.TB, req *http.Request) (url.Values, error) - // Optional if you want to use a real http network request assuming - // it is not directed to the IDP. - defaultClient *http.Client - serve bool + serve bool } type FakeIDPOpt func(idp *FakeIDP) @@ -135,8 +132,9 @@ func WithIssuer(issuer string) func(*FakeIDP) { } const ( - authorizePath = "/oauth2/authorize" + // nolint:gosec // It thinks this is a secret lol tokenPath = "/oauth2/token" + authorizePath = "/oauth2/authorize" keysPath = "/oauth2/keys" userInfoPath = "/oauth2/userinfo" ) @@ -237,7 +235,7 @@ func (f *FakeIDP) AttemptLogin(t testing.TB, client *codersdk.Client, idTokenCla var err error cli := f.HTTPClient(client.HTTPClient) - shallowCpyCli := &(*cli) + shallowCpyCli := *cli if shallowCpyCli.Jar == nil { shallowCpyCli.Jar, err = cookiejar.New(nil) @@ -245,7 +243,7 @@ func (f *FakeIDP) AttemptLogin(t testing.TB, client *codersdk.Client, idTokenCla } unauthenticated := codersdk.New(client.URL) - unauthenticated.HTTPClient = shallowCpyCli + unauthenticated.HTTPClient = &shallowCpyCli return f.LoginWithClient(t, unauthenticated, idTokenClaims, opts...) } @@ -296,7 +294,7 @@ func (f *FakeIDP) LoginWithClient(t testing.TB, client *codersdk.Client, idToken t.Cleanup(func() { if res.Body != nil { - res.Body.Close() + _ = res.Body.Close() } }) @@ -326,7 +324,7 @@ func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.Map t.Cleanup(func() { if resp.Body != nil { - resp.Body.Close() + _ = resp.Body.Close() } }) return resp, nil @@ -444,7 +442,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { // w/e and clicking "Allow". They will be redirected back to the redirect // when this is done. mux.Handle(authorizePath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - f.logger.Info(r.Context(), "HTTP Call Authorize", slog.F("url", r.URL.String())) + f.logger.Info(r.Context(), "http call authorize", slog.F("url", r.URL.String())) clientID := r.URL.Query().Get("client_id") if !assert.Equal(t, f.clientID, clientID, "unexpected client_id") { @@ -495,7 +493,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { mux.Handle(tokenPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { values, err := f.authenticateOIDCClientRequest(t, r) - f.logger.Info(r.Context(), "HTTP Call Token", + f.logger.Info(r.Context(), "http call token", slog.Error(err), slog.F("values", values.Encode()), ) @@ -508,10 +506,11 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { if !ok { return "unknown" } - if _, ok := email.(string); !ok { + emailStr, ok := email.(string) + if !ok { return "wrong-type" } - return email.(string) + return emailStr } var claims jwt.MapClaims @@ -593,7 +592,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { mux.Handle(userInfoPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { token, err := f.authenticateBearerTokenRequest(t, r) - f.logger.Info(r.Context(), "HTTP Call UserInfo", + f.logger.Info(r.Context(), "http call user info", slog.Error(err), slog.F("url", r.URL.String()), ) @@ -612,7 +611,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { })) mux.Handle(keysPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - f.logger.Info(r.Context(), "HTTP Call Keys") + f.logger.Info(r.Context(), "http call keys") set := jose.JSONWebKeySet{ Keys: []jose.JSONWebKey{ { @@ -626,7 +625,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { })) mux.NotFound(func(rw http.ResponseWriter, r *http.Request) { - f.logger.Error(r.Context(), "HTTP Call NotFound", slog.F("path", r.URL.Path)) + f.logger.Error(r.Context(), "http call not found", slog.F("path", r.URL.Path)) t.Errorf("unexpected request to IDP at path %q. Not supported", r.URL.Path) }) diff --git a/coderd/coderdtest/oidctest/idp_test.go b/coderd/coderdtest/oidctest/idp_test.go index e0663f9f9aa0b..30529be91b37c 100644 --- a/coderd/coderdtest/oidctest/idp_test.go +++ b/coderd/coderdtest/oidctest/idp_test.go @@ -17,7 +17,10 @@ import ( ) // TestFakeIDPBasicFlow tests the basic flow of the fake IDP. +// It is done all in memory with no actual network requests. func TestFakeIDPBasicFlow(t *testing.T) { + t.Parallel() + fake := oidctest.NewFakeIDP(t, oidctest.WithLogging(t, nil), ) diff --git a/coderd/coderdtest/oidctest/map.go b/coderd/coderdtest/oidctest/map.go index c978c24f99773..36e4e506d6509 100644 --- a/coderd/coderdtest/oidctest/map.go +++ b/coderd/coderdtest/oidctest/map.go @@ -13,12 +13,13 @@ func NewSyncMap[K, V any]() *SyncMap[K, V] { } } -func (s *SyncMap[K, V]) Store(k K, v V) { - s.m.Store(k, v) +func (m *SyncMap[K, V]) Store(k K, v V) { + m.m.Store(k, v) } -func (s *SyncMap[K, V]) Load(key K) (value V, ok bool) { - v, ok := s.m.Load(key) +//nolint:forcetypeassert +func (m *SyncMap[K, V]) Load(key K) (value V, ok bool) { + v, ok := m.m.Load(key) if !ok { var empty V return empty, false @@ -30,6 +31,7 @@ func (m *SyncMap[K, V]) Delete(key K) { m.m.Delete(key) } +//nolint:forcetypeassert func (m *SyncMap[K, V]) LoadAndDelete(key K) (actual V, loaded bool) { act, loaded := m.m.LoadAndDelete(key) if !loaded { @@ -39,6 +41,7 @@ func (m *SyncMap[K, V]) LoadAndDelete(key K) (actual V, loaded bool) { return act.(V), loaded } +//nolint:forcetypeassert func (m *SyncMap[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { act, loaded := m.m.LoadOrStore(key, value) if !loaded { @@ -56,6 +59,7 @@ func (m *SyncMap[K, V]) CompareAndDelete(key K, old V) (deleted bool) { return m.m.CompareAndDelete(key, old) } +//nolint:forcetypeassert func (m *SyncMap[K, V]) Swap(key K, value V) (previous any, loaded bool) { previous, loaded = m.m.Swap(key, value) if !loaded { @@ -65,6 +69,7 @@ func (m *SyncMap[K, V]) Swap(key K, value V) (previous any, loaded bool) { return previous.(V), loaded } +//nolint:forcetypeassert func (m *SyncMap[K, V]) Range(f func(key K, value V) bool) { m.m.Range(func(key, value interface{}) bool { return f(key.(K), value.(V)) diff --git a/coderd/userauth_test.go b/coderd/userauth_test.go index 96b2de95c935f..df8186a195175 100644 --- a/coderd/userauth_test.go +++ b/coderd/userauth_test.go @@ -32,6 +32,7 @@ import ( // This test specifically tests logging in with OIDC when an expired // OIDC session token exists. // The token refreshing should not happen since we are reauthenticating. +// nolint:bodyclose func TestOIDCOauthLoginWithExisting(t *testing.T) { t.Parallel() @@ -946,7 +947,7 @@ func TestUserOIDC(t *testing.T) { oauthURL, err := client.URL.Parse("/api/v2/users/oidc/callback") require.NoError(t, err) - req, err := http.NewRequest("GET", oauthURL.String(), nil) + req, err := http.NewRequestWithContext(context.Background(), "GET", oauthURL.String(), nil) require.NoError(t, err) resp, err := client.HTTPClient.Do(req) require.NoError(t, err) diff --git a/coderd/users_test.go b/coderd/users_test.go index d4150365a9ba3..60e6ddb82aecf 100644 --- a/coderd/users_test.go +++ b/coderd/users_test.go @@ -406,6 +406,7 @@ func TestPostLogout(t *testing.T) { }) } +// nolint:bodyclose func TestPostUsers(t *testing.T) { t.Parallel() t.Run("NoAuth", func(t *testing.T) { diff --git a/enterprise/coderd/userauth_test.go b/enterprise/coderd/userauth_test.go index 976a9a459cc4d..8e76a36b1df14 100644 --- a/enterprise/coderd/userauth_test.go +++ b/enterprise/coderd/userauth_test.go @@ -129,7 +129,6 @@ func TestUserOIDC(t *testing.T) { t.Run("BlockAssignRoles", func(t *testing.T) { t.Parallel() - const oidcRoleName = "TemplateAuthor" runner := setupOIDCTest(t, oidcTestConfig{ Config: func(cfg *coderd.OIDCConfig) { cfg.AllowSignups = true From 6e6084bce56b817b60e3e0e93a0446781bc1db51 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 25 Aug 2023 10:40:23 -0500 Subject: [PATCH 15/18] Linting --- coderd/coderdtest/oidctest/idp_test.go | 1 + coderd/oauthpki/okidcpki_test.go | 1 + 2 files changed, 2 insertions(+) diff --git a/coderd/coderdtest/oidctest/idp_test.go b/coderd/coderdtest/oidctest/idp_test.go index 30529be91b37c..0dc1149d93fa9 100644 --- a/coderd/coderdtest/oidctest/idp_test.go +++ b/coderd/coderdtest/oidctest/idp_test.go @@ -18,6 +18,7 @@ import ( // TestFakeIDPBasicFlow tests the basic flow of the fake IDP. // It is done all in memory with no actual network requests. +// nolint:bodyclose func TestFakeIDPBasicFlow(t *testing.T) { t.Parallel() diff --git a/coderd/oauthpki/okidcpki_test.go b/coderd/oauthpki/okidcpki_test.go index 682d8e06ac7fd..ab6e3e3a08179 100644 --- a/coderd/oauthpki/okidcpki_test.go +++ b/coderd/oauthpki/okidcpki_test.go @@ -127,6 +127,7 @@ func TestAzureADPKIOIDC(t *testing.T) { } // TestAzureAKPKIWithCoderd uses a fake IDP and a real Coderd to test PKI auth. +// nolint:bodyclose func TestAzureAKPKIWithCoderd(t *testing.T) { t.Parallel() From 9a600c099f217daa768681200175780d5de2a4ec Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 25 Aug 2023 10:42:46 -0500 Subject: [PATCH 16/18] Move sync map to utils --- coderd/coderdtest/oidctest/idp.go | 26 +++++++++-------- .../oidctest => util/syncmap}/map.go | 28 +++++++++---------- 2 files changed, 28 insertions(+), 26 deletions(-) rename coderd/{coderdtest/oidctest => util/syncmap}/map.go (55%) diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index 3d65c38b771be..60db8dca00230 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -18,6 +18,8 @@ import ( "testing" "time" + "github.com/coder/coder/v2/coderd/util/syncmap" + "github.com/coreos/go-oidc/v3/oidc" "github.com/go-chi/chi/v5" "github.com/go-jose/go-jose/v3" @@ -50,14 +52,14 @@ type FakeIDP struct { // These maps are used to control the state of the IDP. // That is the various access tokens, refresh tokens, states, etc. - codeToStateMap *SyncMap[string, string] + codeToStateMap *syncmap.Map[string, string] // Token -> Email - accessTokens *SyncMap[string, string] + accessTokens *syncmap.Map[string, string] // Refresh Token -> Email - refreshTokensUsed *SyncMap[string, bool] - refreshTokens *SyncMap[string, string] - stateToIDTokenClaims *SyncMap[string, jwt.MapClaims] - refreshIDTokenClaims *SyncMap[string, jwt.MapClaims] + refreshTokensUsed *syncmap.Map[string, bool] + refreshTokens *syncmap.Map[string, string] + stateToIDTokenClaims *syncmap.Map[string, jwt.MapClaims] + refreshIDTokenClaims *syncmap.Map[string, jwt.MapClaims] // hooks // hookValidRedirectURL can be used to reject a redirect url from the @@ -151,12 +153,12 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { clientID: uuid.NewString(), clientSecret: uuid.NewString(), logger: slog.Make(), - codeToStateMap: NewSyncMap[string, string](), - accessTokens: NewSyncMap[string, string](), - refreshTokens: NewSyncMap[string, string](), - refreshTokensUsed: NewSyncMap[string, bool](), - stateToIDTokenClaims: NewSyncMap[string, jwt.MapClaims](), - refreshIDTokenClaims: NewSyncMap[string, jwt.MapClaims](), + codeToStateMap: syncmap.New[string, string](), + accessTokens: syncmap.New[string, string](), + refreshTokens: syncmap.New[string, string](), + refreshTokensUsed: syncmap.New[string, bool](), + stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](), + refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](), hookOnRefresh: func(_ string) error { return nil }, hookUserInfo: func(email string) jwt.MapClaims { return jwt.MapClaims{} }, } diff --git a/coderd/coderdtest/oidctest/map.go b/coderd/util/syncmap/map.go similarity index 55% rename from coderd/coderdtest/oidctest/map.go rename to coderd/util/syncmap/map.go index 36e4e506d6509..d245973efa844 100644 --- a/coderd/coderdtest/oidctest/map.go +++ b/coderd/util/syncmap/map.go @@ -1,24 +1,24 @@ -package oidctest +package syncmap import "sync" -// SyncMap is a type safe sync.Map -type SyncMap[K, V any] struct { +// Map is a type safe sync.Map +type Map[K, V any] struct { m sync.Map } -func NewSyncMap[K, V any]() *SyncMap[K, V] { - return &SyncMap[K, V]{ +func New[K, V any]() *Map[K, V] { + return &Map[K, V]{ m: sync.Map{}, } } -func (m *SyncMap[K, V]) Store(k K, v V) { +func (m *Map[K, V]) Store(k K, v V) { m.m.Store(k, v) } //nolint:forcetypeassert -func (m *SyncMap[K, V]) Load(key K) (value V, ok bool) { +func (m *Map[K, V]) Load(key K) (value V, ok bool) { v, ok := m.m.Load(key) if !ok { var empty V @@ -27,12 +27,12 @@ func (m *SyncMap[K, V]) Load(key K) (value V, ok bool) { return v.(V), ok } -func (m *SyncMap[K, V]) Delete(key K) { +func (m *Map[K, V]) Delete(key K) { m.m.Delete(key) } //nolint:forcetypeassert -func (m *SyncMap[K, V]) LoadAndDelete(key K) (actual V, loaded bool) { +func (m *Map[K, V]) LoadAndDelete(key K) (actual V, loaded bool) { act, loaded := m.m.LoadAndDelete(key) if !loaded { var empty V @@ -42,7 +42,7 @@ func (m *SyncMap[K, V]) LoadAndDelete(key K) (actual V, loaded bool) { } //nolint:forcetypeassert -func (m *SyncMap[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { +func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { act, loaded := m.m.LoadOrStore(key, value) if !loaded { var empty V @@ -51,16 +51,16 @@ func (m *SyncMap[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { return act.(V), loaded } -func (m *SyncMap[K, V]) CompareAndSwap(key K, old V, new V) bool { +func (m *Map[K, V]) CompareAndSwap(key K, old V, new V) bool { return m.m.CompareAndSwap(key, old, new) } -func (m *SyncMap[K, V]) CompareAndDelete(key K, old V) (deleted bool) { +func (m *Map[K, V]) CompareAndDelete(key K, old V) (deleted bool) { return m.m.CompareAndDelete(key, old) } //nolint:forcetypeassert -func (m *SyncMap[K, V]) Swap(key K, value V) (previous any, loaded bool) { +func (m *Map[K, V]) Swap(key K, value V) (previous any, loaded bool) { previous, loaded = m.m.Swap(key, value) if !loaded { var empty V @@ -70,7 +70,7 @@ func (m *SyncMap[K, V]) Swap(key K, value V) (previous any, loaded bool) { } //nolint:forcetypeassert -func (m *SyncMap[K, V]) Range(f func(key K, value V) bool) { +func (m *Map[K, V]) Range(f func(key K, value V) bool) { m.m.Range(func(key, value interface{}) bool { return f(key.(K), value.(V)) }) From 6030a5faa8cfa949f990ea8b8891de4012d05291 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 25 Aug 2023 10:46:32 -0500 Subject: [PATCH 17/18] Fix panic --- coderd/coderdtest/oidctest/idp.go | 1 + 1 file changed, 1 insertion(+) diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index 60db8dca00230..1778b2e696f70 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -161,6 +161,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](), hookOnRefresh: func(_ string) error { return nil }, hookUserInfo: func(email string) jwt.MapClaims { return jwt.MapClaims{} }, + hookValidRedirectURL: func(redirectURL string) error { return nil }, } for _, opt := range opts { From 2b0ff9cea445c0030ee05066589b231430d44c15 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 25 Aug 2023 10:54:58 -0500 Subject: [PATCH 18/18] Linting --- coderd/coderdtest/oidctest/idp.go | 12 ++++++------ coderd/userauth_test.go | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index 1778b2e696f70..912d9acd7c221 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -435,7 +435,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { // This endpoint is required to initialize the OIDC provider. // It is used to get the OIDC configuration. mux.Get("/.well-known/openid-configuration", func(rw http.ResponseWriter, r *http.Request) { - f.logger.Info(r.Context(), "HTTP OIDC Config", slog.F("url", r.URL.String())) + f.logger.Info(r.Context(), "http OIDC config", slog.F("url", r.URL.String())) _ = json.NewEncoder(rw).Encode(f.provider) }) @@ -496,7 +496,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { mux.Handle(tokenPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { values, err := f.authenticateOIDCClientRequest(t, r) - f.logger.Info(r.Context(), "http call token", + f.logger.Info(r.Context(), "http idp call token", slog.Error(err), slog.F("values", values.Encode()), ) @@ -595,7 +595,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { mux.Handle(userInfoPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { token, err := f.authenticateBearerTokenRequest(t, r) - f.logger.Info(r.Context(), "http call user info", + f.logger.Info(r.Context(), "http call idp user info", slog.Error(err), slog.F("url", r.URL.String()), ) @@ -614,7 +614,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { })) mux.Handle(keysPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - f.logger.Info(r.Context(), "http call keys") + f.logger.Info(r.Context(), "http call idp /keys") set := jose.JSONWebKeySet{ Keys: []jose.JSONWebKey{ { @@ -691,10 +691,10 @@ func (f *FakeIDP) UpdateRefreshClaims(refreshToken string, claims jwt.MapClaims) // SetRedirect is required for the IDP to know where to redirect and call // Coderd. -func (f *FakeIDP) SetRedirect(t testing.TB, url string) { +func (f *FakeIDP) SetRedirect(t testing.TB, u string) { t.Helper() - f.cfg.RedirectURL = url + f.cfg.RedirectURL = u } // SetCoderdCallback is optional and only works if not using the IsServing. diff --git a/coderd/userauth_test.go b/coderd/userauth_test.go index df8186a195175..1f37a0721a1e7 100644 --- a/coderd/userauth_test.go +++ b/coderd/userauth_test.go @@ -67,7 +67,7 @@ func TestOIDCOauthLoginWithExisting(t *testing.T) { helper.ExpireOauthToken(t, api.Database, userClient) // Instead of refreshing, just log in again. - userClient, _ = helper.Login(t, claims) + helper.Login(t, claims) } func TestUserLogin(t *testing.T) {