diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index 67186a4fd7ddf..d4f24140b6726 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -20,6 +20,7 @@ import ( "net/url" "strconv" "strings" + "sync" "testing" "time" @@ -58,15 +59,107 @@ type deviceFlow struct { granted bool } +// fakeIDPLocked is a set of fields of FakeIDP that are protected +// behind a mutex. +type fakeIDPLocked struct { + mu sync.RWMutex + + issuer string + issuerURL *url.URL + key *rsa.PrivateKey + provider ProviderJSON + handler http.Handler + cfg *oauth2.Config + fakeCoderd func(req *http.Request) (*http.Response, error) +} + +func (f *fakeIDPLocked) Issuer() string { + f.mu.RLock() + defer f.mu.RUnlock() + return f.issuer +} + +func (f *fakeIDPLocked) IssuerURL() *url.URL { + f.mu.RLock() + defer f.mu.RUnlock() + return f.issuerURL +} + +func (f *fakeIDPLocked) PrivateKey() *rsa.PrivateKey { + f.mu.RLock() + defer f.mu.RUnlock() + return f.key +} + +func (f *fakeIDPLocked) Provider() ProviderJSON { + f.mu.RLock() + defer f.mu.RUnlock() + return f.provider +} + +func (f *fakeIDPLocked) Config() *oauth2.Config { + f.mu.RLock() + defer f.mu.RUnlock() + return f.cfg +} + +func (f *fakeIDPLocked) Handler() http.Handler { + f.mu.RLock() + defer f.mu.RUnlock() + return f.handler +} + +func (f *fakeIDPLocked) SetIssuer(issuer string) { + f.mu.Lock() + defer f.mu.Unlock() + f.issuer = issuer +} + +func (f *fakeIDPLocked) SetIssuerURL(issuerURL *url.URL) { + f.mu.Lock() + defer f.mu.Unlock() + f.issuerURL = issuerURL +} + +func (f *fakeIDPLocked) SetProvider(provider ProviderJSON) { + f.mu.Lock() + defer f.mu.Unlock() + f.provider = provider +} + +// MutateConfig is a helper function to mutate the oauth2.Config. +// Beware of re-entrant locks! +func (f *fakeIDPLocked) MutateConfig(fn func(cfg *oauth2.Config)) { + f.mu.Lock() + if f.cfg == nil { + f.cfg = &oauth2.Config{} + } + fn(f.cfg) + f.mu.Unlock() +} + +func (f *fakeIDPLocked) SetHandler(handler http.Handler) { + f.mu.Lock() + defer f.mu.Unlock() + f.handler = handler +} + +func (f *fakeIDPLocked) SetFakeCoderd(fakeCoderd func(req *http.Request) (*http.Response, error)) { + f.mu.Lock() + defer f.mu.Unlock() + f.fakeCoderd = fakeCoderd +} + +func (f *fakeIDPLocked) FakeCoderd() func(req *http.Request) (*http.Response, error) { + f.mu.RLock() + defer f.mu.RUnlock() + return f.fakeCoderd +} + // FakeIDP is a functional OIDC provider. // It only supports 1 OIDC client. type FakeIDP struct { - issuer string - issuerURL *url.URL - key *rsa.PrivateKey - provider ProviderJSON - handler http.Handler - cfg *oauth2.Config + locked fakeIDPLocked // callbackPath allows changing where the callback path to coderd is expected. // This only affects using the Login helper functions. @@ -110,7 +203,6 @@ type FakeIDP struct { // some claims. defaultIDClaims jwt.MapClaims hookMutateToken func(token map[string]interface{}) - fakeCoderd func(req *http.Request) (*http.Response, error) hookOnRefresh func(email string) error // Custom authentication for the client. This is useful if you want // to test something like PKI auth vs a client_secret. @@ -256,7 +348,7 @@ func WithServing() func(*FakeIDP) { func WithIssuer(issuer string) func(*FakeIDP) { return func(f *FakeIDP) { - f.issuer = issuer + f.locked.SetIssuer(issuer) } } @@ -327,7 +419,9 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { require.NoError(t, err) idp := &FakeIDP{ - key: pkey, + locked: fakeIDPLocked{ + key: pkey, + }, clientID: uuid.NewString(), clientSecret: uuid.NewString(), logger: slog.Make(), @@ -348,12 +442,12 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { opt(idp) } - if idp.issuer == "" { - idp.issuer = "https://coder.com" + if idp.locked.Issuer() == "" { + idp.locked.SetIssuer("https://coder.com") } - idp.handler = idp.httpHandler(t) - idp.updateIssuerURL(t, idp.issuer) + idp.locked.SetHandler(idp.httpHandler(t)) + idp.updateIssuerURL(t, idp.locked.Issuer()) if idp.serve { idp.realServer(t) } @@ -369,11 +463,11 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { } func (f *FakeIDP) WellknownConfig() ProviderJSON { - return f.provider + return f.locked.Provider() } func (f *FakeIDP) IssuerURL() *url.URL { - return f.issuerURL + return f.locked.IssuerURL() } func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) { @@ -382,11 +476,11 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) { u, err := url.Parse(issuer) require.NoError(t, err, "invalid issuer URL") - f.issuer = issuer - f.issuerURL = u + f.locked.SetIssuer(issuer) + f.locked.SetIssuerURL(u) // ProviderJSON is the JSON representation of the OpenID Connect provider // These are all the urls that the IDP will respond to. - f.provider = ProviderJSON{ + f.locked.SetProvider(ProviderJSON{ Issuer: issuer, AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(), TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(), @@ -397,7 +491,7 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) { "RS256", }, ExternalAuthURL: u.ResolveReference(&url.URL{Path: "/external-auth-validate/user"}).String(), - } + }) } // realServer turns the FakeIDP into a real http server. @@ -405,7 +499,7 @@ func (f *FakeIDP) realServer(t testing.TB) *httptest.Server { t.Helper() srvURL := "localhost:0" - issURL, err := url.Parse(f.issuer) + issURL, err := url.Parse(f.locked.Issuer()) if err == nil { if issURL.Hostname() == "localhost" || issURL.Hostname() == "127.0.0.1" { srvURL = issURL.Host @@ -418,7 +512,7 @@ func (f *FakeIDP) realServer(t testing.TB) *httptest.Server { ctx, cancel := context.WithCancel(context.Background()) srv := &httptest.Server{ Listener: l, - Config: &http.Server{Handler: f.handler, ReadHeaderTimeout: time.Second * 5}, + Config: &http.Server{Handler: f.locked.Handler(), ReadHeaderTimeout: time.Second * 5}, } srv.Config.BaseContext = func(_ net.Listener) context.Context { @@ -439,7 +533,7 @@ func (f *FakeIDP) GenerateAuthenticatedToken(claims jwt.MapClaims) (*oauth2.Toke state := uuid.NewString() f.stateToIDTokenClaims.Store(state, claims) code := f.newCode(state) - return f.cfg.Exchange(oidc.ClientContext(context.Background(), f.HTTPClient(nil)), code) + return f.locked.Config().Exchange(oidc.ClientContext(context.Background(), f.HTTPClient(nil)), code) } // Login does the full OIDC flow starting at the "LoginButton". @@ -620,9 +714,9 @@ func (f *FakeIDP) CreateAuthCode(t testing.TB, state string) string { // it expects some claims to be present. f.stateToIDTokenClaims.Store(state, jwt.MapClaims{}) - code, err := OAuth2GetCode(f.cfg.AuthCodeURL(state), func(req *http.Request) (*http.Response, error) { + code, err := OAuth2GetCode(f.locked.Config().AuthCodeURL(state), func(req *http.Request) (*http.Response, error) { rw := httptest.NewRecorder() - f.handler.ServeHTTP(rw, req) + f.locked.Handler().ServeHTTP(rw, req) resp := rw.Result() return resp, nil }) @@ -644,7 +738,7 @@ func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.Map f.stateToIDTokenClaims.Store(state, idTokenClaims) cli := f.HTTPClient(nil) - u := f.cfg.AuthCodeURL(state) + u := f.locked.Config().AuthCodeURL(state) req, err := http.NewRequest("GET", u, nil) require.NoError(t, err) @@ -762,10 +856,10 @@ func (f *FakeIDP) encodeClaims(t testing.TB, claims jwt.MapClaims) string { } if _, ok := claims["iss"]; !ok { - claims["iss"] = f.issuer + claims["iss"] = f.locked.Issuer() } - signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(f.key) + signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(f.locked.PrivateKey()) require.NoError(t, err) return signed @@ -782,7 +876,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { mux.Get("/.well-known/openid-configuration", func(rw http.ResponseWriter, r *http.Request) { f.logger.Info(r.Context(), "http OIDC config", slogRequestFields(r)...) - cpy := f.provider + cpy := f.locked.Provider() if f.hookWellKnown != nil { err := f.hookWellKnown(r, &cpy) if err != nil { @@ -1082,7 +1176,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { set := jose.JSONWebKeySet{ Keys: []jose.JSONWebKey{ { - Key: f.key.Public(), + Key: f.locked.PrivateKey().Public(), KeyID: "test-key", Algorithm: "RSA", }, @@ -1181,7 +1275,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { exp: time.Now().Add(lifetime), }) - verifyURL := f.issuerURL.ResolveReference(&url.URL{ + verifyURL := f.locked.IssuerURL().ResolveReference(&url.URL{ Path: deviceVerify, RawQuery: url.Values{ "device_code": {deviceCode}, @@ -1240,10 +1334,10 @@ func (f *FakeIDP) HTTPClient(rest *http.Client) *http.Client { Jar: jar, Transport: fakeRoundTripper{ roundTrip: func(req *http.Request) (*http.Response, error) { - u, _ := url.Parse(f.issuer) + u, _ := url.Parse(f.locked.Issuer()) if req.URL.Host != u.Host { - if f.fakeCoderd != nil { - return f.fakeCoderd(req) + if fakeCoderd := f.locked.FakeCoderd(); fakeCoderd != nil { + return fakeCoderd(req) } if rest == nil || rest.Transport == nil { return nil, xerrors.Errorf("unexpected network request to %q", req.URL.Host) @@ -1251,7 +1345,7 @@ func (f *FakeIDP) HTTPClient(rest *http.Client) *http.Client { return rest.Transport.RoundTrip(req) } resp := httptest.NewRecorder() - f.handler.ServeHTTP(resp, req) + f.locked.Handler().ServeHTTP(resp, req) return resp.Result(), nil }, }, @@ -1269,6 +1363,7 @@ func (f *FakeIDP) RefreshUsed(refreshToken string) bool { // for a given refresh token. By default, all refreshes use the same claims as // the original IDToken issuance. func (f *FakeIDP) UpdateRefreshClaims(refreshToken string, claims jwt.MapClaims) { + // no mutex because it's a sync.Map f.refreshIDTokenClaims.Store(refreshToken, claims) } @@ -1276,8 +1371,9 @@ func (f *FakeIDP) UpdateRefreshClaims(refreshToken string, claims jwt.MapClaims) // Coderd. func (f *FakeIDP) SetRedirect(t testing.TB, u string) { t.Helper() - - f.cfg.RedirectURL = u + f.locked.MutateConfig(func(cfg *oauth2.Config) { + cfg.RedirectURL = u + }) } // SetCoderdCallback is optional and only works if not using the IsServing. @@ -1287,7 +1383,7 @@ func (f *FakeIDP) SetCoderdCallback(callback func(req *http.Request) (*http.Resp if f.serve { panic("cannot set callback handler when using 'WithServing'. Must implement an actual 'Coderd'") } - f.fakeCoderd = callback + f.locked.SetFakeCoderd(callback) } func (f *FakeIDP) SetCoderdCallbackHandler(handler http.HandlerFunc) { @@ -1384,13 +1480,13 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu DisplayIcon: f.WellknownConfig().UserInfoURL, // Omit the /user for the validate so we can easily append to it when modifying // the cfg for advanced tests. - ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path: "/external-auth-validate/"}).String(), + ValidateURL: f.locked.IssuerURL().ResolveReference(&url.URL{Path: "/external-auth-validate/"}).String(), DeviceAuth: &externalauth.DeviceAuth{ Config: oauthCfg, ClientID: f.clientID, - TokenURL: f.provider.TokenURL, + TokenURL: f.locked.Provider().TokenURL, Scopes: []string{}, - CodeURL: f.provider.DeviceCodeURL, + CodeURL: f.locked.Provider().DeviceCodeURL, }, } @@ -1401,7 +1497,7 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu for _, opt := range opts { opt(cfg) } - f.updateIssuerURL(t, f.issuer) + f.updateIssuerURL(t, f.locked.Issuer()) return cfg } @@ -1410,35 +1506,35 @@ func (f *FakeIDP) AppCredentials() (clientID string, clientSecret string) { } func (f *FakeIDP) PublicKey() crypto.PublicKey { - return f.key.Public() + return f.locked.PrivateKey().Public() } func (f *FakeIDP) OauthConfig(t testing.TB, scopes []string) *oauth2.Config { t.Helper() - if len(scopes) == 0 { - scopes = []string{"openid", "email", "profile"} - } - oauthCfg := &oauth2.Config{ - ClientID: f.clientID, - ClientSecret: f.clientSecret, - Endpoint: oauth2.Endpoint{ - AuthURL: f.provider.AuthURL, - TokenURL: f.provider.TokenURL, + provider := f.locked.Provider() + f.locked.MutateConfig(func(cfg *oauth2.Config) { + if len(scopes) == 0 { + scopes = []string{"openid", "email", "profile"} + } + cfg.ClientID = f.clientID + cfg.ClientSecret = f.clientSecret + cfg.Endpoint = oauth2.Endpoint{ + AuthURL: provider.AuthURL, + TokenURL: provider.TokenURL, AuthStyle: oauth2.AuthStyleInParams, - }, + } // If the user is using a real network request, they will need to do // 'fake.SetRedirect()' - RedirectURL: "https://redirect.com", - Scopes: scopes, - } - f.cfg = oauthCfg + cfg.RedirectURL = "https://redirect.com" + cfg.Scopes = scopes + }) - return oauthCfg + return f.locked.Config() } func (f *FakeIDP) OIDCConfigSkipIssuerChecks(t testing.TB, scopes []string, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig { - ctx := oidc.InsecureIssuerURLContext(context.Background(), f.issuer) + ctx := oidc.InsecureIssuerURLContext(context.Background(), f.locked.Issuer()) return f.internalOIDCConfig(ctx, t, scopes, func(config *oidc.Config) { config.SkipIssuerCheck = true @@ -1456,7 +1552,7 @@ func (f *FakeIDP) internalOIDCConfig(ctx context.Context, t testing.TB, scopes [ oauthCfg := f.OauthConfig(t, scopes) ctx = oidc.ClientContext(ctx, f.HTTPClient(nil)) - p, err := oidc.NewProvider(ctx, f.provider.Issuer) + p, err := oidc.NewProvider(ctx, f.locked.Issuer()) require.NoError(t, err, "failed to create OIDC provider") verifierConfig := &oidc.Config{ @@ -1473,8 +1569,8 @@ func (f *FakeIDP) internalOIDCConfig(ctx context.Context, t testing.TB, scopes [ cfg := &coderd.OIDCConfig{ OAuth2Config: oauthCfg, Provider: p, - Verifier: oidc.NewVerifier(f.provider.Issuer, &oidc.StaticKeySet{ - PublicKeys: []crypto.PublicKey{f.key.Public()}, + Verifier: oidc.NewVerifier(f.locked.Issuer(), &oidc.StaticKeySet{ + PublicKeys: []crypto.PublicKey{f.locked.PrivateKey().Public()}, }, verifierConfig), UsernameField: "preferred_username", EmailField: "email",