From eef11bc81cedf94daac3f9352a9afb55aa1a3535 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Fri, 28 Mar 2025 13:59:45 +0000 Subject: [PATCH 1/4] chore(coderd/coderdtest/oidctest): protect mutable fields with rwmutex --- coderd/coderdtest/oidctest/idp.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index 67186a4fd7ddf..8792c0dace8bd 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -20,6 +20,7 @@ import ( "net/url" "strconv" "strings" + "sync" "testing" "time" @@ -61,6 +62,8 @@ type deviceFlow struct { // FakeIDP is a functional OIDC provider. // It only supports 1 OIDC client. type FakeIDP struct { + mu sync.RWMutex + issuer string issuerURL *url.URL key *rsa.PrivateKey @@ -369,10 +372,14 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { } func (f *FakeIDP) WellknownConfig() ProviderJSON { + f.mu.RLock() + defer f.mu.RUnlock() return f.provider } func (f *FakeIDP) IssuerURL() *url.URL { + f.mu.RLock() + defer f.mu.RUnlock() return f.issuerURL } @@ -382,6 +389,7 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) { u, err := url.Parse(issuer) require.NoError(t, err, "invalid issuer URL") + f.mu.Lock() f.issuer = issuer f.issuerURL = u // ProviderJSON is the JSON representation of the OpenID Connect provider @@ -398,6 +406,7 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) { }, ExternalAuthURL: u.ResolveReference(&url.URL{Path: "/external-auth-validate/user"}).String(), } + f.mu.Unlock() } // realServer turns the FakeIDP into a real http server. @@ -437,7 +446,9 @@ func (f *FakeIDP) realServer(t testing.TB) *httptest.Server { // valid token for some given claims. func (f *FakeIDP) GenerateAuthenticatedToken(claims jwt.MapClaims) (*oauth2.Token, error) { state := uuid.NewString() + f.mu.Lock() f.stateToIDTokenClaims.Store(state, claims) + f.mu.Unlock() code := f.newCode(state) return f.cfg.Exchange(oidc.ClientContext(context.Background(), f.HTTPClient(nil)), code) } @@ -1269,6 +1280,7 @@ func (f *FakeIDP) RefreshUsed(refreshToken string) bool { // for a given refresh token. By default, all refreshes use the same claims as // the original IDToken issuance. func (f *FakeIDP) UpdateRefreshClaims(refreshToken string, claims jwt.MapClaims) { + // no mutex because it's a sync.Map f.refreshIDTokenClaims.Store(refreshToken, claims) } @@ -1277,7 +1289,9 @@ func (f *FakeIDP) UpdateRefreshClaims(refreshToken string, claims jwt.MapClaims) func (f *FakeIDP) SetRedirect(t testing.TB, u string) { t.Helper() + f.mu.Lock() f.cfg.RedirectURL = u + f.mu.Unlock() } // SetCoderdCallback is optional and only works if not using the IsServing. @@ -1287,15 +1301,19 @@ func (f *FakeIDP) SetCoderdCallback(callback func(req *http.Request) (*http.Resp if f.serve { panic("cannot set callback handler when using 'WithServing'. Must implement an actual 'Coderd'") } + f.mu.Lock() f.fakeCoderd = callback + f.mu.Unlock() } func (f *FakeIDP) SetCoderdCallbackHandler(handler http.HandlerFunc) { + f.mu.Lock() f.SetCoderdCallback(func(req *http.Request) (*http.Response, error) { resp := httptest.NewRecorder() handler.ServeHTTP(resp, req) return resp.Result(), nil }) + f.mu.Unlock() } // ExternalAuthConfigOptions exists to provide additional functionality ontop From e3b553cb0784afc33cc61738ff0a60251526095a Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Fri, 28 Mar 2025 16:33:11 +0000 Subject: [PATCH 2/4] go absolutely ham --- coderd/coderdtest/oidctest/idp.go | 234 ++++++++++++++++++++---------- 1 file changed, 156 insertions(+), 78 deletions(-) diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index 8792c0dace8bd..a28b612cf8745 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -59,17 +59,107 @@ type deviceFlow struct { granted bool } +type fakeIDPProtected struct { + mu sync.RWMutex + + issuer string + issuerURL *url.URL + key *rsa.PrivateKey + provider ProviderJSON + handler http.Handler + cfg *oauth2.Config + fakeCoderd func(req *http.Request) (*http.Response, error) +} + +func (f *fakeIDPProtected) Issuer() string { + f.mu.RLock() + defer f.mu.RUnlock() + return f.issuer +} + +func (f *fakeIDPProtected) IssuerURL() *url.URL { + f.mu.RLock() + defer f.mu.RUnlock() + return f.issuerURL +} + +func (f *fakeIDPProtected) PrivateKey() *rsa.PrivateKey { + f.mu.RLock() + defer f.mu.RUnlock() + return f.key +} + +func (f *fakeIDPProtected) Provider() ProviderJSON { + f.mu.RLock() + defer f.mu.RUnlock() + return f.provider +} + +func (f *fakeIDPProtected) Config() *oauth2.Config { + f.mu.RLock() + defer f.mu.RUnlock() + return f.cfg +} + +func (f *fakeIDPProtected) Handler() http.Handler { + f.mu.RLock() + defer f.mu.RUnlock() + return f.handler +} + +func (f *fakeIDPProtected) SetIssuer(issuer string) { + f.mu.Lock() + defer f.mu.Unlock() + f.issuer = issuer +} + +func (f *fakeIDPProtected) SetIssuerURL(issuerURL *url.URL) { + f.mu.Lock() + defer f.mu.Unlock() + f.issuerURL = issuerURL +} + +func (f *fakeIDPProtected) SetProvider(provider ProviderJSON) { + f.mu.Lock() + defer f.mu.Unlock() + f.provider = provider +} + +func (f *fakeIDPProtected) MutateConfig(fn func(cfg *oauth2.Config)) { + f.mu.RLock() + cfg := f.cfg + if cfg == nil { + cfg = &oauth2.Config{} + } + fn(cfg) + f.mu.RUnlock() + f.mu.Lock() + f.cfg = cfg + f.mu.Unlock() +} + +func (f *fakeIDPProtected) SetHandler(handler http.Handler) { + f.mu.Lock() + defer f.mu.Unlock() + f.handler = handler +} + +func (f *fakeIDPProtected) SetFakeCoderd(fakeCoderd func(req *http.Request) (*http.Response, error)) { + f.mu.Lock() + defer f.mu.Unlock() + f.fakeCoderd = fakeCoderd +} + +func (f *fakeIDPProtected) FakeCoderd() func(req *http.Request) (*http.Response, error) { + f.mu.RLock() + defer f.mu.RUnlock() + return f.fakeCoderd +} + // FakeIDP is a functional OIDC provider. // It only supports 1 OIDC client. type FakeIDP struct { - mu sync.RWMutex - - issuer string - issuerURL *url.URL - key *rsa.PrivateKey - provider ProviderJSON - handler http.Handler - cfg *oauth2.Config + prot fakeIDPProtected // callbackPath allows changing where the callback path to coderd is expected. // This only affects using the Login helper functions. @@ -113,8 +203,8 @@ type FakeIDP struct { // some claims. defaultIDClaims jwt.MapClaims hookMutateToken func(token map[string]interface{}) - fakeCoderd func(req *http.Request) (*http.Response, error) - hookOnRefresh func(email string) error + // fakeCoderd func(req *http.Request) (*http.Response, error) + hookOnRefresh func(email string) error // Custom authentication for the client. This is useful if you want // to test something like PKI auth vs a client_secret. hookAuthenticateClient func(t testing.TB, req *http.Request) (url.Values, error) @@ -259,7 +349,7 @@ func WithServing() func(*FakeIDP) { func WithIssuer(issuer string) func(*FakeIDP) { return func(f *FakeIDP) { - f.issuer = issuer + f.prot.SetIssuer(issuer) } } @@ -330,7 +420,9 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { require.NoError(t, err) idp := &FakeIDP{ - key: pkey, + prot: fakeIDPProtected{ + key: pkey, + }, clientID: uuid.NewString(), clientSecret: uuid.NewString(), logger: slog.Make(), @@ -351,12 +443,12 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { opt(idp) } - if idp.issuer == "" { - idp.issuer = "https://coder.com" + if idp.prot.Issuer() == "" { + idp.prot.SetIssuer("https://coder.com") } - idp.handler = idp.httpHandler(t) - idp.updateIssuerURL(t, idp.issuer) + idp.prot.SetHandler(idp.httpHandler(t)) + idp.updateIssuerURL(t, idp.prot.Issuer()) if idp.serve { idp.realServer(t) } @@ -372,15 +464,11 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { } func (f *FakeIDP) WellknownConfig() ProviderJSON { - f.mu.RLock() - defer f.mu.RUnlock() - return f.provider + return f.prot.Provider() } func (f *FakeIDP) IssuerURL() *url.URL { - f.mu.RLock() - defer f.mu.RUnlock() - return f.issuerURL + return f.prot.IssuerURL() } func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) { @@ -389,12 +477,11 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) { u, err := url.Parse(issuer) require.NoError(t, err, "invalid issuer URL") - f.mu.Lock() - f.issuer = issuer - f.issuerURL = u + f.prot.SetIssuer(issuer) + f.prot.SetIssuerURL(u) // ProviderJSON is the JSON representation of the OpenID Connect provider // These are all the urls that the IDP will respond to. - f.provider = ProviderJSON{ + f.prot.SetProvider(ProviderJSON{ Issuer: issuer, AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(), TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(), @@ -405,8 +492,7 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) { "RS256", }, ExternalAuthURL: u.ResolveReference(&url.URL{Path: "/external-auth-validate/user"}).String(), - } - f.mu.Unlock() + }) } // realServer turns the FakeIDP into a real http server. @@ -414,7 +500,7 @@ func (f *FakeIDP) realServer(t testing.TB) *httptest.Server { t.Helper() srvURL := "localhost:0" - issURL, err := url.Parse(f.issuer) + issURL, err := url.Parse(f.prot.Issuer()) if err == nil { if issURL.Hostname() == "localhost" || issURL.Hostname() == "127.0.0.1" { srvURL = issURL.Host @@ -427,7 +513,7 @@ func (f *FakeIDP) realServer(t testing.TB) *httptest.Server { ctx, cancel := context.WithCancel(context.Background()) srv := &httptest.Server{ Listener: l, - Config: &http.Server{Handler: f.handler, ReadHeaderTimeout: time.Second * 5}, + Config: &http.Server{Handler: f.prot.Handler(), ReadHeaderTimeout: time.Second * 5}, } srv.Config.BaseContext = func(_ net.Listener) context.Context { @@ -446,11 +532,9 @@ func (f *FakeIDP) realServer(t testing.TB) *httptest.Server { // valid token for some given claims. func (f *FakeIDP) GenerateAuthenticatedToken(claims jwt.MapClaims) (*oauth2.Token, error) { state := uuid.NewString() - f.mu.Lock() f.stateToIDTokenClaims.Store(state, claims) - f.mu.Unlock() code := f.newCode(state) - return f.cfg.Exchange(oidc.ClientContext(context.Background(), f.HTTPClient(nil)), code) + return f.prot.Config().Exchange(oidc.ClientContext(context.Background(), f.HTTPClient(nil)), code) } // Login does the full OIDC flow starting at the "LoginButton". @@ -631,9 +715,9 @@ func (f *FakeIDP) CreateAuthCode(t testing.TB, state string) string { // it expects some claims to be present. f.stateToIDTokenClaims.Store(state, jwt.MapClaims{}) - code, err := OAuth2GetCode(f.cfg.AuthCodeURL(state), func(req *http.Request) (*http.Response, error) { + code, err := OAuth2GetCode(f.prot.Config().AuthCodeURL(state), func(req *http.Request) (*http.Response, error) { rw := httptest.NewRecorder() - f.handler.ServeHTTP(rw, req) + f.prot.Handler().ServeHTTP(rw, req) resp := rw.Result() return resp, nil }) @@ -655,7 +739,7 @@ func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.Map f.stateToIDTokenClaims.Store(state, idTokenClaims) cli := f.HTTPClient(nil) - u := f.cfg.AuthCodeURL(state) + u := f.prot.Config().AuthCodeURL(state) req, err := http.NewRequest("GET", u, nil) require.NoError(t, err) @@ -773,10 +857,10 @@ func (f *FakeIDP) encodeClaims(t testing.TB, claims jwt.MapClaims) string { } if _, ok := claims["iss"]; !ok { - claims["iss"] = f.issuer + claims["iss"] = f.prot.Issuer() } - signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(f.key) + signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(f.prot.PrivateKey()) require.NoError(t, err) return signed @@ -793,7 +877,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { mux.Get("/.well-known/openid-configuration", func(rw http.ResponseWriter, r *http.Request) { f.logger.Info(r.Context(), "http OIDC config", slogRequestFields(r)...) - cpy := f.provider + cpy := f.prot.Provider() if f.hookWellKnown != nil { err := f.hookWellKnown(r, &cpy) if err != nil { @@ -1093,7 +1177,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { set := jose.JSONWebKeySet{ Keys: []jose.JSONWebKey{ { - Key: f.key.Public(), + Key: f.prot.PrivateKey().Public(), KeyID: "test-key", Algorithm: "RSA", }, @@ -1192,7 +1276,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { exp: time.Now().Add(lifetime), }) - verifyURL := f.issuerURL.ResolveReference(&url.URL{ + verifyURL := f.prot.IssuerURL().ResolveReference(&url.URL{ Path: deviceVerify, RawQuery: url.Values{ "device_code": {deviceCode}, @@ -1251,10 +1335,10 @@ func (f *FakeIDP) HTTPClient(rest *http.Client) *http.Client { Jar: jar, Transport: fakeRoundTripper{ roundTrip: func(req *http.Request) (*http.Response, error) { - u, _ := url.Parse(f.issuer) + u, _ := url.Parse(f.prot.Issuer()) if req.URL.Host != u.Host { - if f.fakeCoderd != nil { - return f.fakeCoderd(req) + if fakeCoderd := f.prot.FakeCoderd(); fakeCoderd != nil { + return fakeCoderd(req) } if rest == nil || rest.Transport == nil { return nil, xerrors.Errorf("unexpected network request to %q", req.URL.Host) @@ -1262,7 +1346,7 @@ func (f *FakeIDP) HTTPClient(rest *http.Client) *http.Client { return rest.Transport.RoundTrip(req) } resp := httptest.NewRecorder() - f.handler.ServeHTTP(resp, req) + f.prot.Handler().ServeHTTP(resp, req) return resp.Result(), nil }, }, @@ -1288,10 +1372,9 @@ func (f *FakeIDP) UpdateRefreshClaims(refreshToken string, claims jwt.MapClaims) // Coderd. func (f *FakeIDP) SetRedirect(t testing.TB, u string) { t.Helper() - - f.mu.Lock() - f.cfg.RedirectURL = u - f.mu.Unlock() + f.prot.MutateConfig(func(cfg *oauth2.Config) { + cfg.RedirectURL = u + }) } // SetCoderdCallback is optional and only works if not using the IsServing. @@ -1301,19 +1384,15 @@ func (f *FakeIDP) SetCoderdCallback(callback func(req *http.Request) (*http.Resp if f.serve { panic("cannot set callback handler when using 'WithServing'. Must implement an actual 'Coderd'") } - f.mu.Lock() - f.fakeCoderd = callback - f.mu.Unlock() + f.prot.SetFakeCoderd(callback) } func (f *FakeIDP) SetCoderdCallbackHandler(handler http.HandlerFunc) { - f.mu.Lock() f.SetCoderdCallback(func(req *http.Request) (*http.Response, error) { resp := httptest.NewRecorder() handler.ServeHTTP(resp, req) return resp.Result(), nil }) - f.mu.Unlock() } // ExternalAuthConfigOptions exists to provide additional functionality ontop @@ -1402,13 +1481,13 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu DisplayIcon: f.WellknownConfig().UserInfoURL, // Omit the /user for the validate so we can easily append to it when modifying // the cfg for advanced tests. - ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path: "/external-auth-validate/"}).String(), + ValidateURL: f.prot.IssuerURL().ResolveReference(&url.URL{Path: "/external-auth-validate/"}).String(), DeviceAuth: &externalauth.DeviceAuth{ Config: oauthCfg, ClientID: f.clientID, - TokenURL: f.provider.TokenURL, + TokenURL: f.prot.Provider().TokenURL, Scopes: []string{}, - CodeURL: f.provider.DeviceCodeURL, + CodeURL: f.prot.Provider().DeviceCodeURL, }, } @@ -1419,7 +1498,7 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu for _, opt := range opts { opt(cfg) } - f.updateIssuerURL(t, f.issuer) + f.updateIssuerURL(t, f.prot.Issuer()) return cfg } @@ -1428,35 +1507,34 @@ func (f *FakeIDP) AppCredentials() (clientID string, clientSecret string) { } func (f *FakeIDP) PublicKey() crypto.PublicKey { - return f.key.Public() + return f.prot.PrivateKey().Public() } func (f *FakeIDP) OauthConfig(t testing.TB, scopes []string) *oauth2.Config { t.Helper() - if len(scopes) == 0 { - scopes = []string{"openid", "email", "profile"} - } - oauthCfg := &oauth2.Config{ - ClientID: f.clientID, - ClientSecret: f.clientSecret, - Endpoint: oauth2.Endpoint{ - AuthURL: f.provider.AuthURL, - TokenURL: f.provider.TokenURL, + f.prot.MutateConfig(func(cfg *oauth2.Config) { + if len(scopes) == 0 { + scopes = []string{"openid", "email", "profile"} + } + cfg.ClientID = f.clientID + cfg.ClientSecret = f.clientSecret + cfg.Endpoint = oauth2.Endpoint{ + AuthURL: f.prot.Provider().AuthURL, + TokenURL: f.prot.Provider().TokenURL, AuthStyle: oauth2.AuthStyleInParams, - }, + } // If the user is using a real network request, they will need to do // 'fake.SetRedirect()' - RedirectURL: "https://redirect.com", - Scopes: scopes, - } - f.cfg = oauthCfg + cfg.RedirectURL = "https://redirect.com" + cfg.Scopes = scopes + }) - return oauthCfg + return f.prot.Config() } func (f *FakeIDP) OIDCConfigSkipIssuerChecks(t testing.TB, scopes []string, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig { - ctx := oidc.InsecureIssuerURLContext(context.Background(), f.issuer) + ctx := oidc.InsecureIssuerURLContext(context.Background(), f.prot.Issuer()) return f.internalOIDCConfig(ctx, t, scopes, func(config *oidc.Config) { config.SkipIssuerCheck = true @@ -1474,7 +1552,7 @@ func (f *FakeIDP) internalOIDCConfig(ctx context.Context, t testing.TB, scopes [ oauthCfg := f.OauthConfig(t, scopes) ctx = oidc.ClientContext(ctx, f.HTTPClient(nil)) - p, err := oidc.NewProvider(ctx, f.provider.Issuer) + p, err := oidc.NewProvider(ctx, f.prot.Issuer()) require.NoError(t, err, "failed to create OIDC provider") verifierConfig := &oidc.Config{ @@ -1491,8 +1569,8 @@ func (f *FakeIDP) internalOIDCConfig(ctx context.Context, t testing.TB, scopes [ cfg := &coderd.OIDCConfig{ OAuth2Config: oauthCfg, Provider: p, - Verifier: oidc.NewVerifier(f.provider.Issuer, &oidc.StaticKeySet{ - PublicKeys: []crypto.PublicKey{f.key.Public()}, + Verifier: oidc.NewVerifier(f.prot.Issuer(), &oidc.StaticKeySet{ + PublicKeys: []crypto.PublicKey{f.prot.PrivateKey().Public()}, }, verifierConfig), UsernameField: "preferred_username", EmailField: "email", From 104eaf00ab15c33cacc7745f3528beda6aa695bb Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Wed, 2 Apr 2025 11:02:14 +0100 Subject: [PATCH 3/4] address PR comments --- coderd/coderdtest/oidctest/idp.go | 169 +++++++++++++++++------------- 1 file changed, 94 insertions(+), 75 deletions(-) diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index a28b612cf8745..329a03d9acba0 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -59,7 +59,9 @@ type deviceFlow struct { granted bool } -type fakeIDPProtected struct { +// fakeIDPLocked is a set of fields of FakeIDP that are protected +// behind a mutex. +type fakeIDPLocked struct { mu sync.RWMutex issuer string @@ -71,87 +73,103 @@ type fakeIDPProtected struct { fakeCoderd func(req *http.Request) (*http.Response, error) } -func (f *fakeIDPProtected) Issuer() string { +func (f *fakeIDPLocked) Issuer() string { f.mu.RLock() defer f.mu.RUnlock() return f.issuer } -func (f *fakeIDPProtected) IssuerURL() *url.URL { +func (f *fakeIDPLocked) IssuerURL() *url.URL { f.mu.RLock() defer f.mu.RUnlock() return f.issuerURL } -func (f *fakeIDPProtected) PrivateKey() *rsa.PrivateKey { +func (f *fakeIDPLocked) PrivateKey() *rsa.PrivateKey { f.mu.RLock() defer f.mu.RUnlock() return f.key } -func (f *fakeIDPProtected) Provider() ProviderJSON { - f.mu.RLock() +func (f *fakeIDPLocked) Provider() ProviderJSON { + if !f.mu.TryRLock() { + panic("developer error: fakeIDPLocked is already locked") + } defer f.mu.RUnlock() return f.provider } -func (f *fakeIDPProtected) Config() *oauth2.Config { - f.mu.RLock() +func (f *fakeIDPLocked) Config() *oauth2.Config { + if !f.mu.TryRLock() { + panic("developer error: fakeIDPLocked is already locked") + } defer f.mu.RUnlock() return f.cfg } -func (f *fakeIDPProtected) Handler() http.Handler { - f.mu.RLock() +func (f *fakeIDPLocked) Handler() http.Handler { + if !f.mu.TryRLock() { + panic("developer error: fakeIDPLocked is already locked") + } defer f.mu.RUnlock() return f.handler } -func (f *fakeIDPProtected) SetIssuer(issuer string) { - f.mu.Lock() +func (f *fakeIDPLocked) SetIssuer(issuer string) { + if !f.mu.TryLock() { + panic("developer error: fakeIDPLocked is already locked") + } defer f.mu.Unlock() f.issuer = issuer } -func (f *fakeIDPProtected) SetIssuerURL(issuerURL *url.URL) { - f.mu.Lock() +func (f *fakeIDPLocked) SetIssuerURL(issuerURL *url.URL) { + if !f.mu.TryLock() { + panic("developer error: fakeIDPLocked is already locked") + } defer f.mu.Unlock() f.issuerURL = issuerURL } -func (f *fakeIDPProtected) SetProvider(provider ProviderJSON) { - f.mu.Lock() +func (f *fakeIDPLocked) SetProvider(provider ProviderJSON) { + if !f.mu.TryLock() { + panic("developer error: fakeIDPLocked is already locked") + } defer f.mu.Unlock() f.provider = provider } -func (f *fakeIDPProtected) MutateConfig(fn func(cfg *oauth2.Config)) { - f.mu.RLock() - cfg := f.cfg - if cfg == nil { - cfg = &oauth2.Config{} - } - fn(cfg) - f.mu.RUnlock() - f.mu.Lock() - f.cfg = cfg +func (f *fakeIDPLocked) MutateConfig(fn func(cfg *oauth2.Config)) { + if !f.mu.TryLock() { + panic("developer error: fakeIDPLocked is already locked") + } + if f.cfg == nil { + f.cfg = &oauth2.Config{} + } + fn(f.cfg) f.mu.Unlock() } -func (f *fakeIDPProtected) SetHandler(handler http.Handler) { - f.mu.Lock() +func (f *fakeIDPLocked) SetHandler(handler http.Handler) { + if !f.mu.TryLock() { + panic("developer error: fakeIDPLocked is already locked") + } defer f.mu.Unlock() f.handler = handler } -func (f *fakeIDPProtected) SetFakeCoderd(fakeCoderd func(req *http.Request) (*http.Response, error)) { - f.mu.Lock() +func (f *fakeIDPLocked) SetFakeCoderd(fakeCoderd func(req *http.Request) (*http.Response, error)) { + if !f.mu.TryLock() { + panic("developer error: fakeIDPLocked is already locked") + } defer f.mu.Unlock() f.fakeCoderd = fakeCoderd } -func (f *fakeIDPProtected) FakeCoderd() func(req *http.Request) (*http.Response, error) { - f.mu.RLock() +func (f *fakeIDPLocked) FakeCoderd() func(req *http.Request) (*http.Response, error) { + if !f.mu.TryRLock() { + panic("developer error: fakeIDPLocked is already locked") + } defer f.mu.RUnlock() return f.fakeCoderd } @@ -159,7 +177,7 @@ func (f *fakeIDPProtected) FakeCoderd() func(req *http.Request) (*http.Response, // FakeIDP is a functional OIDC provider. // It only supports 1 OIDC client. type FakeIDP struct { - prot fakeIDPProtected + locked fakeIDPLocked // callbackPath allows changing where the callback path to coderd is expected. // This only affects using the Login helper functions. @@ -203,8 +221,7 @@ type FakeIDP struct { // some claims. defaultIDClaims jwt.MapClaims hookMutateToken func(token map[string]interface{}) - // fakeCoderd func(req *http.Request) (*http.Response, error) - hookOnRefresh func(email string) error + hookOnRefresh func(email string) error // Custom authentication for the client. This is useful if you want // to test something like PKI auth vs a client_secret. hookAuthenticateClient func(t testing.TB, req *http.Request) (url.Values, error) @@ -349,7 +366,7 @@ func WithServing() func(*FakeIDP) { func WithIssuer(issuer string) func(*FakeIDP) { return func(f *FakeIDP) { - f.prot.SetIssuer(issuer) + f.locked.SetIssuer(issuer) } } @@ -420,7 +437,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { require.NoError(t, err) idp := &FakeIDP{ - prot: fakeIDPProtected{ + locked: fakeIDPLocked{ key: pkey, }, clientID: uuid.NewString(), @@ -443,12 +460,12 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { opt(idp) } - if idp.prot.Issuer() == "" { - idp.prot.SetIssuer("https://coder.com") + if idp.locked.Issuer() == "" { + idp.locked.SetIssuer("https://coder.com") } - idp.prot.SetHandler(idp.httpHandler(t)) - idp.updateIssuerURL(t, idp.prot.Issuer()) + idp.locked.SetHandler(idp.httpHandler(t)) + idp.updateIssuerURL(t, idp.locked.Issuer()) if idp.serve { idp.realServer(t) } @@ -464,11 +481,11 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { } func (f *FakeIDP) WellknownConfig() ProviderJSON { - return f.prot.Provider() + return f.locked.Provider() } func (f *FakeIDP) IssuerURL() *url.URL { - return f.prot.IssuerURL() + return f.locked.IssuerURL() } func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) { @@ -477,11 +494,11 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) { u, err := url.Parse(issuer) require.NoError(t, err, "invalid issuer URL") - f.prot.SetIssuer(issuer) - f.prot.SetIssuerURL(u) + f.locked.SetIssuer(issuer) + f.locked.SetIssuerURL(u) // ProviderJSON is the JSON representation of the OpenID Connect provider // These are all the urls that the IDP will respond to. - f.prot.SetProvider(ProviderJSON{ + f.locked.SetProvider(ProviderJSON{ Issuer: issuer, AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(), TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(), @@ -500,7 +517,7 @@ func (f *FakeIDP) realServer(t testing.TB) *httptest.Server { t.Helper() srvURL := "localhost:0" - issURL, err := url.Parse(f.prot.Issuer()) + issURL, err := url.Parse(f.locked.Issuer()) if err == nil { if issURL.Hostname() == "localhost" || issURL.Hostname() == "127.0.0.1" { srvURL = issURL.Host @@ -513,7 +530,7 @@ func (f *FakeIDP) realServer(t testing.TB) *httptest.Server { ctx, cancel := context.WithCancel(context.Background()) srv := &httptest.Server{ Listener: l, - Config: &http.Server{Handler: f.prot.Handler(), ReadHeaderTimeout: time.Second * 5}, + Config: &http.Server{Handler: f.locked.Handler(), ReadHeaderTimeout: time.Second * 5}, } srv.Config.BaseContext = func(_ net.Listener) context.Context { @@ -534,7 +551,7 @@ func (f *FakeIDP) GenerateAuthenticatedToken(claims jwt.MapClaims) (*oauth2.Toke state := uuid.NewString() f.stateToIDTokenClaims.Store(state, claims) code := f.newCode(state) - return f.prot.Config().Exchange(oidc.ClientContext(context.Background(), f.HTTPClient(nil)), code) + return f.locked.Config().Exchange(oidc.ClientContext(context.Background(), f.HTTPClient(nil)), code) } // Login does the full OIDC flow starting at the "LoginButton". @@ -715,9 +732,9 @@ func (f *FakeIDP) CreateAuthCode(t testing.TB, state string) string { // it expects some claims to be present. f.stateToIDTokenClaims.Store(state, jwt.MapClaims{}) - code, err := OAuth2GetCode(f.prot.Config().AuthCodeURL(state), func(req *http.Request) (*http.Response, error) { + code, err := OAuth2GetCode(f.locked.Config().AuthCodeURL(state), func(req *http.Request) (*http.Response, error) { rw := httptest.NewRecorder() - f.prot.Handler().ServeHTTP(rw, req) + f.locked.Handler().ServeHTTP(rw, req) resp := rw.Result() return resp, nil }) @@ -739,7 +756,7 @@ func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.Map f.stateToIDTokenClaims.Store(state, idTokenClaims) cli := f.HTTPClient(nil) - u := f.prot.Config().AuthCodeURL(state) + u := f.locked.Config().AuthCodeURL(state) req, err := http.NewRequest("GET", u, nil) require.NoError(t, err) @@ -857,10 +874,10 @@ func (f *FakeIDP) encodeClaims(t testing.TB, claims jwt.MapClaims) string { } if _, ok := claims["iss"]; !ok { - claims["iss"] = f.prot.Issuer() + claims["iss"] = f.locked.Issuer() } - signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(f.prot.PrivateKey()) + signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(f.locked.PrivateKey()) require.NoError(t, err) return signed @@ -877,7 +894,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { mux.Get("/.well-known/openid-configuration", func(rw http.ResponseWriter, r *http.Request) { f.logger.Info(r.Context(), "http OIDC config", slogRequestFields(r)...) - cpy := f.prot.Provider() + cpy := f.locked.Provider() if f.hookWellKnown != nil { err := f.hookWellKnown(r, &cpy) if err != nil { @@ -1177,7 +1194,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { set := jose.JSONWebKeySet{ Keys: []jose.JSONWebKey{ { - Key: f.prot.PrivateKey().Public(), + Key: f.locked.PrivateKey().Public(), KeyID: "test-key", Algorithm: "RSA", }, @@ -1276,7 +1293,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { exp: time.Now().Add(lifetime), }) - verifyURL := f.prot.IssuerURL().ResolveReference(&url.URL{ + verifyURL := f.locked.IssuerURL().ResolveReference(&url.URL{ Path: deviceVerify, RawQuery: url.Values{ "device_code": {deviceCode}, @@ -1335,9 +1352,9 @@ func (f *FakeIDP) HTTPClient(rest *http.Client) *http.Client { Jar: jar, Transport: fakeRoundTripper{ roundTrip: func(req *http.Request) (*http.Response, error) { - u, _ := url.Parse(f.prot.Issuer()) + u, _ := url.Parse(f.locked.Issuer()) if req.URL.Host != u.Host { - if fakeCoderd := f.prot.FakeCoderd(); fakeCoderd != nil { + if fakeCoderd := f.locked.FakeCoderd(); fakeCoderd != nil { return fakeCoderd(req) } if rest == nil || rest.Transport == nil { @@ -1346,7 +1363,7 @@ func (f *FakeIDP) HTTPClient(rest *http.Client) *http.Client { return rest.Transport.RoundTrip(req) } resp := httptest.NewRecorder() - f.prot.Handler().ServeHTTP(resp, req) + f.locked.Handler().ServeHTTP(resp, req) return resp.Result(), nil }, }, @@ -1372,7 +1389,7 @@ func (f *FakeIDP) UpdateRefreshClaims(refreshToken string, claims jwt.MapClaims) // Coderd. func (f *FakeIDP) SetRedirect(t testing.TB, u string) { t.Helper() - f.prot.MutateConfig(func(cfg *oauth2.Config) { + f.locked.MutateConfig(func(cfg *oauth2.Config) { cfg.RedirectURL = u }) } @@ -1384,7 +1401,7 @@ func (f *FakeIDP) SetCoderdCallback(callback func(req *http.Request) (*http.Resp if f.serve { panic("cannot set callback handler when using 'WithServing'. Must implement an actual 'Coderd'") } - f.prot.SetFakeCoderd(callback) + f.locked.SetFakeCoderd(callback) } func (f *FakeIDP) SetCoderdCallbackHandler(handler http.HandlerFunc) { @@ -1481,13 +1498,13 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu DisplayIcon: f.WellknownConfig().UserInfoURL, // Omit the /user for the validate so we can easily append to it when modifying // the cfg for advanced tests. - ValidateURL: f.prot.IssuerURL().ResolveReference(&url.URL{Path: "/external-auth-validate/"}).String(), + ValidateURL: f.locked.IssuerURL().ResolveReference(&url.URL{Path: "/external-auth-validate/"}).String(), DeviceAuth: &externalauth.DeviceAuth{ Config: oauthCfg, ClientID: f.clientID, - TokenURL: f.prot.Provider().TokenURL, + TokenURL: f.locked.Provider().TokenURL, Scopes: []string{}, - CodeURL: f.prot.Provider().DeviceCodeURL, + CodeURL: f.locked.Provider().DeviceCodeURL, }, } @@ -1498,7 +1515,7 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu for _, opt := range opts { opt(cfg) } - f.updateIssuerURL(t, f.prot.Issuer()) + f.updateIssuerURL(t, f.locked.Issuer()) return cfg } @@ -1507,21 +1524,23 @@ func (f *FakeIDP) AppCredentials() (clientID string, clientSecret string) { } func (f *FakeIDP) PublicKey() crypto.PublicKey { - return f.prot.PrivateKey().Public() + return f.locked.PrivateKey().Public() } func (f *FakeIDP) OauthConfig(t testing.TB, scopes []string) *oauth2.Config { t.Helper() - f.prot.MutateConfig(func(cfg *oauth2.Config) { + authURL := f.locked.Provider().AuthURL + tokenURL := f.locked.Provider().TokenURL + f.locked.MutateConfig(func(cfg *oauth2.Config) { if len(scopes) == 0 { scopes = []string{"openid", "email", "profile"} } cfg.ClientID = f.clientID cfg.ClientSecret = f.clientSecret cfg.Endpoint = oauth2.Endpoint{ - AuthURL: f.prot.Provider().AuthURL, - TokenURL: f.prot.Provider().TokenURL, + AuthURL: authURL, + TokenURL: tokenURL, AuthStyle: oauth2.AuthStyleInParams, } // If the user is using a real network request, they will need to do @@ -1530,11 +1549,11 @@ func (f *FakeIDP) OauthConfig(t testing.TB, scopes []string) *oauth2.Config { cfg.Scopes = scopes }) - return f.prot.Config() + return f.locked.Config() } func (f *FakeIDP) OIDCConfigSkipIssuerChecks(t testing.TB, scopes []string, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig { - ctx := oidc.InsecureIssuerURLContext(context.Background(), f.prot.Issuer()) + ctx := oidc.InsecureIssuerURLContext(context.Background(), f.locked.Issuer()) return f.internalOIDCConfig(ctx, t, scopes, func(config *oidc.Config) { config.SkipIssuerCheck = true @@ -1552,7 +1571,7 @@ func (f *FakeIDP) internalOIDCConfig(ctx context.Context, t testing.TB, scopes [ oauthCfg := f.OauthConfig(t, scopes) ctx = oidc.ClientContext(ctx, f.HTTPClient(nil)) - p, err := oidc.NewProvider(ctx, f.prot.Issuer()) + p, err := oidc.NewProvider(ctx, f.locked.Issuer()) require.NoError(t, err, "failed to create OIDC provider") verifierConfig := &oidc.Config{ @@ -1569,8 +1588,8 @@ func (f *FakeIDP) internalOIDCConfig(ctx context.Context, t testing.TB, scopes [ cfg := &coderd.OIDCConfig{ OAuth2Config: oauthCfg, Provider: p, - Verifier: oidc.NewVerifier(f.prot.Issuer(), &oidc.StaticKeySet{ - PublicKeys: []crypto.PublicKey{f.prot.PrivateKey().Public()}, + Verifier: oidc.NewVerifier(f.locked.Issuer(), &oidc.StaticKeySet{ + PublicKeys: []crypto.PublicKey{f.locked.PrivateKey().Public()}, }, verifierConfig), UsernameField: "preferred_username", EmailField: "email", From afab7856b62917b5716ce9a63060dfcd6b9172bf Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Wed, 2 Apr 2025 12:23:01 +0100 Subject: [PATCH 4/4] no more TryLock --- coderd/coderdtest/oidctest/idp.go | 49 ++++++++++--------------------- 1 file changed, 15 insertions(+), 34 deletions(-) diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index 329a03d9acba0..d4f24140b6726 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -92,57 +92,45 @@ func (f *fakeIDPLocked) PrivateKey() *rsa.PrivateKey { } func (f *fakeIDPLocked) Provider() ProviderJSON { - if !f.mu.TryRLock() { - panic("developer error: fakeIDPLocked is already locked") - } + f.mu.RLock() defer f.mu.RUnlock() return f.provider } func (f *fakeIDPLocked) Config() *oauth2.Config { - if !f.mu.TryRLock() { - panic("developer error: fakeIDPLocked is already locked") - } + f.mu.RLock() defer f.mu.RUnlock() return f.cfg } func (f *fakeIDPLocked) Handler() http.Handler { - if !f.mu.TryRLock() { - panic("developer error: fakeIDPLocked is already locked") - } + f.mu.RLock() defer f.mu.RUnlock() return f.handler } func (f *fakeIDPLocked) SetIssuer(issuer string) { - if !f.mu.TryLock() { - panic("developer error: fakeIDPLocked is already locked") - } + f.mu.Lock() defer f.mu.Unlock() f.issuer = issuer } func (f *fakeIDPLocked) SetIssuerURL(issuerURL *url.URL) { - if !f.mu.TryLock() { - panic("developer error: fakeIDPLocked is already locked") - } + f.mu.Lock() defer f.mu.Unlock() f.issuerURL = issuerURL } func (f *fakeIDPLocked) SetProvider(provider ProviderJSON) { - if !f.mu.TryLock() { - panic("developer error: fakeIDPLocked is already locked") - } + f.mu.Lock() defer f.mu.Unlock() f.provider = provider } +// MutateConfig is a helper function to mutate the oauth2.Config. +// Beware of re-entrant locks! func (f *fakeIDPLocked) MutateConfig(fn func(cfg *oauth2.Config)) { - if !f.mu.TryLock() { - panic("developer error: fakeIDPLocked is already locked") - } + f.mu.Lock() if f.cfg == nil { f.cfg = &oauth2.Config{} } @@ -151,25 +139,19 @@ func (f *fakeIDPLocked) MutateConfig(fn func(cfg *oauth2.Config)) { } func (f *fakeIDPLocked) SetHandler(handler http.Handler) { - if !f.mu.TryLock() { - panic("developer error: fakeIDPLocked is already locked") - } + f.mu.Lock() defer f.mu.Unlock() f.handler = handler } func (f *fakeIDPLocked) SetFakeCoderd(fakeCoderd func(req *http.Request) (*http.Response, error)) { - if !f.mu.TryLock() { - panic("developer error: fakeIDPLocked is already locked") - } + f.mu.Lock() defer f.mu.Unlock() f.fakeCoderd = fakeCoderd } func (f *fakeIDPLocked) FakeCoderd() func(req *http.Request) (*http.Response, error) { - if !f.mu.TryRLock() { - panic("developer error: fakeIDPLocked is already locked") - } + f.mu.RLock() defer f.mu.RUnlock() return f.fakeCoderd } @@ -1530,8 +1512,7 @@ func (f *FakeIDP) PublicKey() crypto.PublicKey { func (f *FakeIDP) OauthConfig(t testing.TB, scopes []string) *oauth2.Config { t.Helper() - authURL := f.locked.Provider().AuthURL - tokenURL := f.locked.Provider().TokenURL + provider := f.locked.Provider() f.locked.MutateConfig(func(cfg *oauth2.Config) { if len(scopes) == 0 { scopes = []string{"openid", "email", "profile"} @@ -1539,8 +1520,8 @@ func (f *FakeIDP) OauthConfig(t testing.TB, scopes []string) *oauth2.Config { cfg.ClientID = f.clientID cfg.ClientSecret = f.clientSecret cfg.Endpoint = oauth2.Endpoint{ - AuthURL: authURL, - TokenURL: tokenURL, + AuthURL: provider.AuthURL, + TokenURL: provider.TokenURL, AuthStyle: oauth2.AuthStyleInParams, } // If the user is using a real network request, they will need to do