From b7f13fada0bbcf908eb0fd40df0d322bafc749f7 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 9 Jan 2024 11:00:41 -0600 Subject: [PATCH 01/20] chore: instrument external oauth2 requests --- cli/server.go | 24 ++-- coderd/externalauth/externalauth.go | 15 +- coderd/httpmw/apikey.go | 7 +- coderd/httpmw/oauth2.go | 11 +- coderd/oauthpki/oidcpki.go | 6 +- coderd/promoauth/doc.go | 4 + coderd/promoauth/oauth2.go | 128 ++++++++++++++++++ .../provisionerdserver/provisionerdserver.go | 8 +- coderd/userauth.go | 5 +- 9 files changed, 168 insertions(+), 40 deletions(-) create mode 100644 coderd/promoauth/doc.go create mode 100644 coderd/promoauth/oauth2.go diff --git a/cli/server.go b/cli/server.go index 72c72679fd2b9..4f62aefdf98b3 100644 --- a/cli/server.go +++ b/cli/server.go @@ -80,6 +80,7 @@ import ( "github.com/coder/coder/v2/coderd/oauthpki" "github.com/coder/coder/v2/coderd/prometheusmetrics" "github.com/coder/coder/v2/coderd/prometheusmetrics/insights" + "github.com/coder/coder/v2/coderd/promoauth" "github.com/coder/coder/v2/coderd/schedule" "github.com/coder/coder/v2/coderd/telemetry" "github.com/coder/coder/v2/coderd/tracing" @@ -102,7 +103,7 @@ import ( "github.com/coder/wgtunnel/tunnelsdk" ) -func createOIDCConfig(ctx context.Context, vals *codersdk.DeploymentValues) (*coderd.OIDCConfig, error) { +func createOIDCConfig(ctx context.Context, instrument *promoauth.Factory, vals *codersdk.DeploymentValues) (*coderd.OIDCConfig, error) { if vals.OIDC.ClientID == "" { return nil, xerrors.Errorf("OIDC client ID must be set!") } @@ -133,7 +134,7 @@ func createOIDCConfig(ctx context.Context, vals *codersdk.DeploymentValues) (*co Scopes: vals.OIDC.Scopes, } - var useCfg httpmw.OAuth2Config = oauthCfg + var useCfg promoauth.OAuth2Config = oauthCfg if vals.OIDC.ClientKeyFile != "" { // PKI authentication is done in the params. If a // counter example is found, we can add a config option to @@ -159,7 +160,7 @@ func createOIDCConfig(ctx context.Context, vals *codersdk.DeploymentValues) (*co } return &coderd.OIDCConfig{ - OAuth2Config: useCfg, + OAuth2Config: instrument.New("oidc-login", useCfg), Provider: oidcProvider, Verifier: oidcProvider.Verifier(&oidc.Config{ ClientID: vals.OIDC.ClientID.String(), @@ -523,8 +524,11 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. return xerrors.Errorf("read external auth providers from env: %w", err) } + promRegistry := prometheus.NewRegistry() + oauthIntrument := promoauth.NewFactory(promRegistry) vals.ExternalAuthConfigs.Value = append(vals.ExternalAuthConfigs.Value, extAuthEnv...) externalAuthConfigs, err := externalauth.ConvertConfig( + oauthIntrument, vals.ExternalAuthConfigs.Value, vals.AccessURL.Value(), ) @@ -571,7 +575,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. // the DeploymentValues instead, this just serves to indicate the source of each // option. This is just defensive to prevent accidentally leaking. DeploymentOptions: codersdk.DeploymentOptionsWithoutSecrets(opts), - PrometheusRegistry: prometheus.NewRegistry(), + PrometheusRegistry: promRegistry, APIRateLimit: int(vals.RateLimit.API.Value()), LoginRateLimit: loginRateLimit, FilesRateLimit: filesRateLimit, @@ -617,7 +621,9 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. } if vals.OAuth2.Github.ClientSecret != "" { - options.GithubOAuth2Config, err = configureGithubOAuth2(vals.AccessURL.Value(), + options.GithubOAuth2Config, err = configureGithubOAuth2( + oauthIntrument, + vals.AccessURL.Value(), vals.OAuth2.Github.ClientID.String(), vals.OAuth2.Github.ClientSecret.String(), vals.OAuth2.Github.AllowSignups.Value(), @@ -636,7 +642,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. logger.Warn(ctx, "coder will not check email_verified for OIDC logins") } - oc, err := createOIDCConfig(ctx, vals) + oc, err := createOIDCConfig(ctx, oauthIntrument, vals) if err != nil { return xerrors.Errorf("create oidc config: %w", err) } @@ -1737,7 +1743,7 @@ func configureCAPool(tlsClientCAFile string, tlsConfig *tls.Config) error { } //nolint:revive // Ignore flag-parameter: parameter 'allowEveryone' seems to be a control flag, avoid control coupling (revive) -func configureGithubOAuth2(accessURL *url.URL, clientID, clientSecret string, allowSignups, allowEveryone bool, allowOrgs []string, rawTeams []string, enterpriseBaseURL string) (*coderd.GithubOAuth2Config, error) { +func configureGithubOAuth2(instrument *promoauth.Factory, accessURL *url.URL, clientID, clientSecret string, allowSignups, allowEveryone bool, allowOrgs []string, rawTeams []string, enterpriseBaseURL string) (*coderd.GithubOAuth2Config, error) { redirectURL, err := accessURL.Parse("/api/v2/users/oauth2/github/callback") if err != nil { return nil, xerrors.Errorf("parse github oauth callback url: %w", err) @@ -1790,7 +1796,7 @@ func configureGithubOAuth2(accessURL *url.URL, clientID, clientSecret string, al } return &coderd.GithubOAuth2Config{ - OAuth2Config: &oauth2.Config{ + OAuth2Config: instrument.New("github-login", &oauth2.Config{ ClientID: clientID, ClientSecret: clientSecret, Endpoint: endpoint, @@ -1800,7 +1806,7 @@ func configureGithubOAuth2(accessURL *url.URL, clientID, clientSecret string, al "read:org", "user:email", }, - }, + }), AllowSignups: allowSignups, AllowEveryone: allowEveryone, AllowOrganizations: allowOrgs, diff --git a/coderd/externalauth/externalauth.go b/coderd/externalauth/externalauth.go index 9243aa29e44e4..0c49cbcb62d53 100644 --- a/coderd/externalauth/externalauth.go +++ b/coderd/externalauth/externalauth.go @@ -22,19 +22,14 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/coderd/promoauth" "github.com/coder/coder/v2/codersdk" "github.com/coder/retry" ) -type OAuth2Config interface { - AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string - Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) - TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource -} - // Config is used for authentication for Git operations. type Config struct { - OAuth2Config + promoauth.OAuth2Config // ID is a unique identifier for the authenticator. ID string // Type is the type of provider. @@ -401,7 +396,7 @@ func (c *DeviceAuth) formatDeviceCodeURL() (string, error) { // ConvertConfig converts the SDK configuration entry format // to the parsed and ready-to-consume in coderd provider type. -func ConvertConfig(entries []codersdk.ExternalAuthConfig, accessURL *url.URL) ([]*Config, error) { +func ConvertConfig(instrument *promoauth.Factory, entries []codersdk.ExternalAuthConfig, accessURL *url.URL) ([]*Config, error) { ids := map[string]struct{}{} configs := []*Config{} for _, entry := range entries { @@ -453,7 +448,7 @@ func ConvertConfig(entries []codersdk.ExternalAuthConfig, accessURL *url.URL) ([ Scopes: entry.Scopes, } - var oauthConfig OAuth2Config = oc + var oauthConfig promoauth.OAuth2Config = oc // Azure DevOps uses JWT token authentication! if entry.Type == string(codersdk.EnhancedExternalAuthProviderAzureDevops) { oauthConfig = &jwtConfig{oc} @@ -463,7 +458,7 @@ func ConvertConfig(entries []codersdk.ExternalAuthConfig, accessURL *url.URL) ([ } cfg := &Config{ - OAuth2Config: oauthConfig, + OAuth2Config: instrument.New(entry.ID, oauthConfig), ID: entry.ID, Regex: regex, Type: entry.Type, diff --git a/coderd/httpmw/apikey.go b/coderd/httpmw/apikey.go index dfffe9cf092df..46d8c97014bc3 100644 --- a/coderd/httpmw/apikey.go +++ b/coderd/httpmw/apikey.go @@ -22,6 +22,7 @@ import ( "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/coderd/promoauth" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/codersdk" ) @@ -74,8 +75,8 @@ func UserAuthorization(r *http.Request) Authorization { // OAuth2Configs is a collection of configurations for OAuth-based authentication. // This should be extended to support other authentication types in the future. type OAuth2Configs struct { - Github OAuth2Config - OIDC OAuth2Config + Github promoauth.OAuth2Config + OIDC promoauth.OAuth2Config } func (c *OAuth2Configs) IsZero() bool { @@ -270,7 +271,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon }) } - var oauthConfig OAuth2Config + var oauthConfig promoauth.OAuth2Config switch key.LoginType { case database.LoginTypeGithub: oauthConfig = cfg.OAuth2Configs.Github diff --git a/coderd/httpmw/oauth2.go b/coderd/httpmw/oauth2.go index c300576aa82c2..dbb763bc9de3e 100644 --- a/coderd/httpmw/oauth2.go +++ b/coderd/httpmw/oauth2.go @@ -10,6 +10,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/coderd/promoauth" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/cryptorand" ) @@ -22,14 +23,6 @@ type OAuth2State struct { StateString string } -// OAuth2Config exposes a subset of *oauth2.Config functions for easier testing. -// *oauth2.Config should be used instead of implementing this in production. -type OAuth2Config interface { - AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string - Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) - TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource -} - // OAuth2 returns the state from an oauth request. func OAuth2(r *http.Request) OAuth2State { oauth, ok := r.Context().Value(oauth2StateKey{}).(OAuth2State) @@ -44,7 +37,7 @@ func OAuth2(r *http.Request) OAuth2State { // a "code" URL parameter will be redirected. // AuthURLOpts are passed to the AuthCodeURL function. If this is nil, // the default option oauth2.AccessTypeOffline will be used. -func ExtractOAuth2(config OAuth2Config, client *http.Client, authURLOpts map[string]string) func(http.Handler) http.Handler { +func ExtractOAuth2(config promoauth.OAuth2Config, client *http.Client, authURLOpts map[string]string) func(http.Handler) http.Handler { opts := make([]oauth2.AuthCodeOption, 0, len(authURLOpts)+1) opts = append(opts, oauth2.AccessTypeOffline) for k, v := range authURLOpts { diff --git a/coderd/oauthpki/oidcpki.go b/coderd/oauthpki/oidcpki.go index c44d130e5be9f..dff1240ca7dc1 100644 --- a/coderd/oauthpki/oidcpki.go +++ b/coderd/oauthpki/oidcpki.go @@ -20,7 +20,7 @@ import ( "golang.org/x/oauth2/jws" "golang.org/x/xerrors" - "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/promoauth" ) // Config uses jwt assertions over client_secret for oauth2 authentication of @@ -33,7 +33,7 @@ import ( // // https://datatracker.ietf.org/doc/html/rfc7523 type Config struct { - cfg httpmw.OAuth2Config + cfg promoauth.OAuth2Config // These values should match those provided in the oauth2.Config. // Because the inner config is an interface, we need to duplicate these @@ -57,7 +57,7 @@ type ConfigParams struct { PemEncodedKey []byte PemEncodedCert []byte - Config httpmw.OAuth2Config + Config promoauth.OAuth2Config } // NewOauth2PKIConfig creates the oauth2 config for PKI based auth. It requires the certificate and it's private key. diff --git a/coderd/promoauth/doc.go b/coderd/promoauth/doc.go new file mode 100644 index 0000000000000..72f30b48cff7a --- /dev/null +++ b/coderd/promoauth/doc.go @@ -0,0 +1,4 @@ +// Package promoauth is for instrumenting oauth2 flows with prometheus metrics. +// Specifically, it is intended to count the number of external requests made +// by the underlying oauth2 exchanges. +package promoauth diff --git a/coderd/promoauth/oauth2.go b/coderd/promoauth/oauth2.go new file mode 100644 index 0000000000000..8331d512ecdd7 --- /dev/null +++ b/coderd/promoauth/oauth2.go @@ -0,0 +1,128 @@ +package promoauth + +import ( + "context" + "fmt" + "net/http" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "golang.org/x/oauth2" +) + +// OAuth2Config exposes a subset of *oauth2.Config functions for easier testing. +// *oauth2.Config should be used instead of implementing this in production. +type OAuth2Config interface { + AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string + Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) + TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource +} + +var _ OAuth2Config = (*Config)(nil) + +type Factory struct { + metrics *metrics +} + +// metrics is the reusable metrics for all oauth2 providers. +type metrics struct { + externalRequestCount *prometheus.CounterVec +} + +func NewFactory(registry prometheus.Registerer) *Factory { + factory := promauto.With(registry) + + return &Factory{ + metrics: &metrics{ + externalRequestCount: factory.NewCounterVec(prometheus.CounterOpts{ + Namespace: "coderd", + Subsystem: "oauth2", + Name: "external_requests_total", + Help: "The total number of api calls made to external oauth2 providers. 'status_code' will be 0 if the request failed with no response.", + }, []string{ + "name", + "status_code", + "domain", + }), + }, + } +} + +func (f *Factory) New(name string, under OAuth2Config) *Config { + return &Config{ + name: name, + underlying: under, + metrics: f.metrics, + } +} + +type Config struct { + // Name is a human friendly name to identify the oauth2 provider. This should be + // deterministic from restart to restart, as it is going to be used as a label in + // prometheus metrics. + name string + underlying OAuth2Config + metrics *metrics +} + +func (c *Config) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string { + // No external requests are made when constructing the auth code url. + return c.underlying.AuthCodeURL(state, opts...) +} + +func (c *Config) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return c.underlying.Exchange(c.wrapClient(ctx), code, opts...) +} + +func (c *Config) TokenSource(ctx context.Context, token *oauth2.Token) oauth2.TokenSource { + return c.underlying.TokenSource(c.wrapClient(ctx), token) +} + +// wrapClient is the only way we can accurately instrument the oauth2 client. +// This is because method calls to the 'OAuth2Config' interface are not 1:1 with +// network requests. +// +// For example, the 'TokenSource' method will return a token +// source that will make a network request when the 'Token' method is called on +// it if the token is expired. +func (c *Config) wrapClient(ctx context.Context) context.Context { + cli := http.DefaultClient + + // Check if the context has an http client already. + if hc, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok { + cli = hc + } + + // The new tripper will instrument every request made by the oauth2 client. + cli.Transport = newInstrumentedTripper(c, cli.Transport) + return context.WithValue(ctx, oauth2.HTTPClient, cli) +} + +type instrumentedTripper struct { + c *Config + underlying http.RoundTripper +} + +func newInstrumentedTripper(c *Config, under http.RoundTripper) *instrumentedTripper { + if under == nil { + under = http.DefaultTransport + } + return &instrumentedTripper{ + c: c, + underlying: under, + } +} + +func (i *instrumentedTripper) RoundTrip(r *http.Request) (*http.Response, error) { + resp, err := i.underlying.RoundTrip(r) + var statusCode int + if resp != nil { + statusCode = resp.StatusCode + } + i.c.metrics.externalRequestCount.With(prometheus.Labels{ + "name": i.c.name, + "status_code": fmt.Sprintf("%d", statusCode), + "domain": r.URL.Host, + }).Inc() + return resp, err +} diff --git a/coderd/provisionerdserver/provisionerdserver.go b/coderd/provisionerdserver/provisionerdserver.go index 2715ba6776e8d..1330f370d6191 100644 --- a/coderd/provisionerdserver/provisionerdserver.go +++ b/coderd/provisionerdserver/provisionerdserver.go @@ -32,7 +32,7 @@ import ( "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/database/pubsub" "github.com/coder/coder/v2/coderd/externalauth" - "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/promoauth" "github.com/coder/coder/v2/coderd/schedule" "github.com/coder/coder/v2/coderd/telemetry" "github.com/coder/coder/v2/coderd/tracing" @@ -55,7 +55,7 @@ const ( ) type Options struct { - OIDCConfig httpmw.OAuth2Config + OIDCConfig promoauth.OAuth2Config ExternalAuthConfigs []*externalauth.Config // TimeNowFn is only used in tests TimeNowFn func() time.Time @@ -96,7 +96,7 @@ type server struct { UserQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore] DeploymentValues *codersdk.DeploymentValues - OIDCConfig httpmw.OAuth2Config + OIDCConfig promoauth.OAuth2Config TimeNowFn func() time.Time @@ -1736,7 +1736,7 @@ func deleteSessionToken(ctx context.Context, db database.Store, workspace databa // obtainOIDCAccessToken returns a valid OpenID Connect access token // for the user if it's able to obtain one, otherwise it returns an empty string. -func obtainOIDCAccessToken(ctx context.Context, db database.Store, oidcConfig httpmw.OAuth2Config, userID uuid.UUID) (string, error) { +func obtainOIDCAccessToken(ctx context.Context, db database.Store, oidcConfig promoauth.OAuth2Config, userID uuid.UUID) (string, error) { link, err := db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{ UserID: userID, LoginType: database.LoginTypeOIDC, diff --git a/coderd/userauth.go b/coderd/userauth.go index 94fe821da7cf2..54f10d7388f79 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -31,6 +31,7 @@ import ( "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/promoauth" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/userpassword" "github.com/coder/coder/v2/codersdk" @@ -438,7 +439,7 @@ type GithubOAuth2Team struct { // GithubOAuth2Provider exposes required functions for the Github authentication flow. type GithubOAuth2Config struct { - httpmw.OAuth2Config + promoauth.OAuth2Config AuthenticatedUser func(ctx context.Context, client *http.Client) (*github.User, error) ListEmails func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) ListOrganizationMemberships func(ctx context.Context, client *http.Client) ([]*github.Membership, error) @@ -662,7 +663,7 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { } type OIDCConfig struct { - httpmw.OAuth2Config + promoauth.OAuth2Config Provider *oidc.Provider Verifier *oidc.IDTokenVerifier From fd1e012a67e961594f0611bd4607e299f584e89e Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 9 Jan 2024 12:40:23 -0600 Subject: [PATCH 02/20] Add a unit test --- coderd/coderdtest/oidctest/idp.go | 53 +++++++++++++++++++++++++++---- coderd/promoauth/oauth2.go | 25 ++++++++++----- coderd/promoauth/oauth2_test.go | 53 +++++++++++++++++++++++++++++++ 3 files changed, 117 insertions(+), 14 deletions(-) create mode 100644 coderd/promoauth/oauth2_test.go diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index 20702be16ab33..da113b346da74 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -397,6 +397,43 @@ func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...f _ = res.Body.Close() } +// CreateAuthCode emulates a user clicking "allow" on the IDP page. When doing +// unit tests, it's easier to skip this step sometimes. It does make an actual +// request to the IDP, so it should be equivalent to doing this "manually" with +// actual requests. +func (f *FakeIDP) CreateAuthCode(t testing.TB, state string, opts ...func(r *http.Request)) string { + // We need to store some claims, because this is also an OIDC provider, and + // it expects some claims to be present. + f.stateToIDTokenClaims.Store(state, jwt.MapClaims{}) + + u := f.cfg.AuthCodeURL(state) + r, err := http.NewRequestWithContext(context.Background(), http.MethodPost, u, nil) + require.NoError(t, err, "failed to create auth request") + + for _, opt := range opts { + opt(r) + } + + rw := httptest.NewRecorder() + f.handler.ServeHTTP(rw, r) + resp := rw.Result() + + require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode, "expected redirect") + to := resp.Header.Get("Location") + require.NotEmpty(t, to, "expected redirect location") + + toUrl, err := url.Parse(to) + require.NoError(t, err, "failed to parse redirect location") + + code := toUrl.Query().Get("code") + require.NotEmpty(t, code, "expected code in redirect location") + + newState := toUrl.Query().Get("state") + require.Equal(t, state, newState, "expected state to match") + + return code +} + // OIDCCallback will emulate the IDP redirecting back to the Coder callback. // This is helpful if no Coderd exists because the IDP needs to redirect to // something. @@ -917,13 +954,10 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu return cfg } -// OIDCConfig returns the OIDC config to use for Coderd. -func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig { - t.Helper() +func (f *FakeIDP) OAuthConfig(scopes ...string) *oauth2.Config { if len(scopes) == 0 { scopes = []string{"openid", "email", "profile"} } - oauthCfg := &oauth2.Config{ ClientID: f.clientID, ClientSecret: f.clientSecret, @@ -937,6 +971,15 @@ func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *co RedirectURL: "https://redirect.com", Scopes: scopes, } + f.cfg = oauthCfg + return oauthCfg +} + +// OIDCConfig returns the OIDC config to use for Coderd. +func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig { + t.Helper() + + oauthCfg := f.OAuthConfig(scopes...) ctx := oidc.ClientContext(context.Background(), f.HTTPClient(nil)) p, err := oidc.NewProvider(ctx, f.provider.Issuer) @@ -965,8 +1008,6 @@ func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *co opt(cfg) } - f.cfg = oauthCfg - return cfg } diff --git a/coderd/promoauth/oauth2.go b/coderd/promoauth/oauth2.go index 8331d512ecdd7..140078cb51121 100644 --- a/coderd/promoauth/oauth2.go +++ b/coderd/promoauth/oauth2.go @@ -41,8 +41,8 @@ func NewFactory(registry prometheus.Registerer) *Factory { Help: "The total number of api calls made to external oauth2 providers. 'status_code' will be 0 if the request failed with no response.", }, []string{ "name", + "source", "status_code", - "domain", }), }, } @@ -71,11 +71,11 @@ func (c *Config) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string } func (c *Config) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { - return c.underlying.Exchange(c.wrapClient(ctx), code, opts...) + return c.underlying.Exchange(c.wrapClient(ctx, "Exchange"), code, opts...) } func (c *Config) TokenSource(ctx context.Context, token *oauth2.Token) oauth2.TokenSource { - return c.underlying.TokenSource(c.wrapClient(ctx), token) + return c.underlying.TokenSource(c.wrapClient(ctx, "TokenSource"), token) } // wrapClient is the only way we can accurately instrument the oauth2 client. @@ -85,8 +85,8 @@ func (c *Config) TokenSource(ctx context.Context, token *oauth2.Token) oauth2.To // For example, the 'TokenSource' method will return a token // source that will make a network request when the 'Token' method is called on // it if the token is expired. -func (c *Config) wrapClient(ctx context.Context) context.Context { - cli := http.DefaultClient +func (c *Config) wrapClient(ctx context.Context, source string) context.Context { + cli := &http.Client{} // Check if the context has an http client already. if hc, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok { @@ -94,21 +94,30 @@ func (c *Config) wrapClient(ctx context.Context) context.Context { } // The new tripper will instrument every request made by the oauth2 client. - cli.Transport = newInstrumentedTripper(c, cli.Transport) + cli.Transport = newInstrumentedTripper(c, source, cli.Transport) return context.WithValue(ctx, oauth2.HTTPClient, cli) } type instrumentedTripper struct { c *Config + source string underlying http.RoundTripper } -func newInstrumentedTripper(c *Config, under http.RoundTripper) *instrumentedTripper { +func newInstrumentedTripper(c *Config, source string, under http.RoundTripper) *instrumentedTripper { if under == nil { under = http.DefaultTransport } + + // If the underlying transport is the default, we need to clone it. + // We should also clone it if it supports cloning. + if tr, ok := under.(*http.Transport); ok { + under = tr.Clone() + } + return &instrumentedTripper{ c: c, + source: source, underlying: under, } } @@ -121,8 +130,8 @@ func (i *instrumentedTripper) RoundTrip(r *http.Request) (*http.Response, error) } i.c.metrics.externalRequestCount.With(prometheus.Labels{ "name": i.c.name, + "source": i.source, "status_code": fmt.Sprintf("%d", statusCode), - "domain": r.URL.Host, }).Inc() return resp, err } diff --git a/coderd/promoauth/oauth2_test.go b/coderd/promoauth/oauth2_test.go new file mode 100644 index 0000000000000..1e1fa612d7c50 --- /dev/null +++ b/coderd/promoauth/oauth2_test.go @@ -0,0 +1,53 @@ +package promoauth_test + +import ( + "net/http" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + ptestutil "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/coderdtest/oidctest" + "github.com/coder/coder/v2/coderd/promoauth" + "github.com/coder/coder/v2/testutil" +) + +func TestMaintainDefault(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + idp := oidctest.NewFakeIDP(t, oidctest.WithServing()) + reg := prometheus.NewRegistry() + count := func() int { + return ptestutil.CollectAndCount(reg, "coderd_oauth2_external_requests_total") + } + + factory := promoauth.NewFactory(reg) + cfg := factory.New("test", idp.OAuthConfig()) + + // 0 Requests before we start + require.Equal(t, count(), 0) + + // Exchange should trigger a request + code := idp.CreateAuthCode(t, "foo") + token, err := cfg.Exchange(ctx, code) + require.NoError(t, err) + require.Equal(t, count(), 1) + + // Force a refresh + token.Expiry = time.Now().Add(time.Hour * -1) + src := cfg.TokenSource(ctx, token) + refreshed, err := src.Token() + require.NoError(t, err) + require.NotEqual(t, token.AccessToken, refreshed.AccessToken, "token refreshed") + require.Equal(t, count(), 2) + + // Verify the default client was not broken. This check is added because we + // extend the http.DefaultTransport. If a `.Clone()` is not done, this can be + // mis-used. It is cheap to run this quick check. + _, err = http.DefaultClient.Get("https://coder.com") + require.NoError(t, err) + require.Equal(t, count(), 2) +} From a6de1e3681e52de80039dc58021401d354052216 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 9 Jan 2024 12:43:51 -0600 Subject: [PATCH 03/20] remove internet based request --- coderd/coderdtest/oidctest/idp.go | 4 ++++ coderd/promoauth/oauth2_test.go | 15 ++++++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index da113b346da74..d5bb28e81e5cc 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -223,6 +223,10 @@ func (f *FakeIDP) WellknownConfig() ProviderJSON { return f.provider } +func (f *FakeIDP) IssuerURL() *url.URL { + return f.issuerURL +} + func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) { t.Helper() diff --git a/coderd/promoauth/oauth2_test.go b/coderd/promoauth/oauth2_test.go index 1e1fa612d7c50..44ebb82319b66 100644 --- a/coderd/promoauth/oauth2_test.go +++ b/coderd/promoauth/oauth2_test.go @@ -47,7 +47,20 @@ func TestMaintainDefault(t *testing.T) { // Verify the default client was not broken. This check is added because we // extend the http.DefaultTransport. If a `.Clone()` is not done, this can be // mis-used. It is cheap to run this quick check. - _, err = http.DefaultClient.Get("https://coder.com") + req, err := http.NewRequest(http.MethodGet, + must(idp.IssuerURL().Parse("/.well-known/openid-configuration")).String(), nil) require.NoError(t, err) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + _ = resp.Body.Close() + require.Equal(t, count(), 2) } + +func must[V any](v V, err error) V { + if err != nil { + panic(err) + } + return v +} From 005883a439cc9d9ab85185bc0c1a6208775d4ed2 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 9 Jan 2024 12:44:20 -0600 Subject: [PATCH 04/20] use ctx --- coderd/promoauth/oauth2_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/coderd/promoauth/oauth2_test.go b/coderd/promoauth/oauth2_test.go index 44ebb82319b66..1d7b15923cd34 100644 --- a/coderd/promoauth/oauth2_test.go +++ b/coderd/promoauth/oauth2_test.go @@ -50,6 +50,7 @@ func TestMaintainDefault(t *testing.T) { req, err := http.NewRequest(http.MethodGet, must(idp.IssuerURL().Parse("/.well-known/openid-configuration")).String(), nil) require.NoError(t, err) + req = req.WithContext(ctx) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) From 2f77f9996f772e98886257127418b535856569f9 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 9 Jan 2024 13:33:17 -0600 Subject: [PATCH 05/20] typo --- cli/server.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cli/server.go b/cli/server.go index 4f62aefdf98b3..cfac4e9b71c45 100644 --- a/cli/server.go +++ b/cli/server.go @@ -525,10 +525,10 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. } promRegistry := prometheus.NewRegistry() - oauthIntrument := promoauth.NewFactory(promRegistry) + oauthInstrument := promoauth.NewFactory(promRegistry) vals.ExternalAuthConfigs.Value = append(vals.ExternalAuthConfigs.Value, extAuthEnv...) externalAuthConfigs, err := externalauth.ConvertConfig( - oauthIntrument, + oauthInstrument, vals.ExternalAuthConfigs.Value, vals.AccessURL.Value(), ) @@ -622,7 +622,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. if vals.OAuth2.Github.ClientSecret != "" { options.GithubOAuth2Config, err = configureGithubOAuth2( - oauthIntrument, + oauthInstrument, vals.AccessURL.Value(), vals.OAuth2.Github.ClientID.String(), vals.OAuth2.Github.ClientSecret.String(), @@ -642,7 +642,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. logger.Warn(ctx, "coder will not check email_verified for OIDC logins") } - oc, err := createOIDCConfig(ctx, oauthIntrument, vals) + oc, err := createOIDCConfig(ctx, oauthInstrument, vals) if err != nil { return xerrors.Errorf("create oidc config: %w", err) } From 3377a9b20e28c8fb6bb6dd274c5312d763b7a801 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 9 Jan 2024 13:48:39 -0600 Subject: [PATCH 06/20] Work on instrumenting related oauth2 calls --- coderd/coderdtest/oidctest/idp.go | 7 +++--- coderd/externalauth/externalauth.go | 37 ++++++++++++++-------------- coderd/promoauth/oauth2.go | 38 ++++++++++++++++++++++++++++- coderd/workspaceagents.go | 2 +- 4 files changed, 60 insertions(+), 24 deletions(-) diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index d5bb28e81e5cc..93f446c7cefa8 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -421,18 +421,19 @@ func (f *FakeIDP) CreateAuthCode(t testing.TB, state string, opts ...func(r *htt rw := httptest.NewRecorder() f.handler.ServeHTTP(rw, r) resp := rw.Result() + defer resp.Body.Close() require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode, "expected redirect") to := resp.Header.Get("Location") require.NotEmpty(t, to, "expected redirect location") - toUrl, err := url.Parse(to) + toURL, err := url.Parse(to) require.NoError(t, err, "failed to parse redirect location") - code := toUrl.Query().Get("code") + code := toURL.Query().Get("code") require.NotEmpty(t, code, "expected code in redirect location") - newState := toUrl.Query().Get("state") + newState := toURL.Query().Get("state") require.Equal(t, state, newState, "expected state to match") return code diff --git a/coderd/externalauth/externalauth.go b/coderd/externalauth/externalauth.go index 0c49cbcb62d53..d42e3dd49d2b2 100644 --- a/coderd/externalauth/externalauth.go +++ b/coderd/externalauth/externalauth.go @@ -29,7 +29,7 @@ import ( // Config is used for authentication for Git operations. type Config struct { - promoauth.OAuth2Config + promoauth.InstrumentedeOAuth2Config // ID is a unique identifier for the authenticator. ID string // Type is the type of provider. @@ -187,12 +187,8 @@ func (c *Config) ValidateToken(ctx context.Context, token string) (bool, *coders return false, nil, err } - cli := http.DefaultClient - if v, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok { - cli = v - } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - res, err := cli.Do(req) + res, err := c.InstrumentedeOAuth2Config.Do(ctx, "ValidateToken", req) if err != nil { return false, nil, err } @@ -242,7 +238,7 @@ func (c *Config) AppInstallations(ctx context.Context, token string) ([]codersdk return nil, false, err } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - res, err := http.DefaultClient.Do(req) + res, err := c.InstrumentedeOAuth2Config.Do(ctx, "AppInstallations", req) if err != nil { return nil, false, err } @@ -282,6 +278,8 @@ func (c *Config) AppInstallations(ctx context.Context, token string) ([]codersdk } type DeviceAuth struct { + // Cfg is provided for the http client method. + Cfg promoauth.InstrumentedeOAuth2Config ClientID string TokenURL string Scopes []string @@ -302,8 +300,8 @@ func (c *DeviceAuth) AuthorizeDevice(ctx context.Context) (*codersdk.ExternalAut if err != nil { return nil, err } + resp, err := c.Cfg.Do(ctx, "AuthorizeDevice", req) req.Header.Set("Accept", "application/json") - resp, err := http.DefaultClient.Do(req) if err != nil { return nil, err } @@ -458,17 +456,17 @@ func ConvertConfig(instrument *promoauth.Factory, entries []codersdk.ExternalAut } cfg := &Config{ - OAuth2Config: instrument.New(entry.ID, oauthConfig), - ID: entry.ID, - Regex: regex, - Type: entry.Type, - NoRefresh: entry.NoRefresh, - ValidateURL: entry.ValidateURL, - AppInstallationsURL: entry.AppInstallationsURL, - AppInstallURL: entry.AppInstallURL, - DisplayName: entry.DisplayName, - DisplayIcon: entry.DisplayIcon, - ExtraTokenKeys: entry.ExtraTokenKeys, + InstrumentedeOAuth2Config: instrument.New(entry.ID, oauthConfig), + ID: entry.ID, + Regex: regex, + Type: entry.Type, + NoRefresh: entry.NoRefresh, + ValidateURL: entry.ValidateURL, + AppInstallationsURL: entry.AppInstallationsURL, + AppInstallURL: entry.AppInstallURL, + DisplayName: entry.DisplayName, + DisplayIcon: entry.DisplayIcon, + ExtraTokenKeys: entry.ExtraTokenKeys, } if entry.DeviceFlow { @@ -476,6 +474,7 @@ func ConvertConfig(instrument *promoauth.Factory, entries []codersdk.ExternalAut return nil, xerrors.Errorf("external auth provider %q: device auth url must be provided", entry.ID) } cfg.DeviceAuth = &DeviceAuth{ + Cfg: cfg, ClientID: entry.ClientID, TokenURL: oc.Endpoint.TokenURL, Scopes: entry.Scopes, diff --git a/coderd/promoauth/oauth2.go b/coderd/promoauth/oauth2.go index 140078cb51121..2780c7ff9ccf6 100644 --- a/coderd/promoauth/oauth2.go +++ b/coderd/promoauth/oauth2.go @@ -18,6 +18,24 @@ type OAuth2Config interface { TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource } +type InstrumentedeOAuth2Config interface { + OAuth2Config + + // Do is provided as a convience method to make a request with the oauth2 client. + // It mirrors `http.Client.Do`. + // We need this because Coder adds some extra functionality to + // oauth clients such as the `ValidateToken()` method. + Do(ctx context.Context, source string, req *http.Request) (*http.Response, error) +} + +type HTTPDo interface { + // Do is provided as a convience method to make a request with the oauth2 client. + // It mirrors `http.Client.Do`. + // We need this because Coder adds some extra functionality to + // oauth clients such as the `ValidateToken()` method. + Do(ctx context.Context, source string, req *http.Request) (*http.Response, error) +} + var _ OAuth2Config = (*Config)(nil) type Factory struct { @@ -65,6 +83,11 @@ type Config struct { metrics *metrics } +func (c *Config) Do(ctx context.Context, source string, req *http.Request) (*http.Response, error) { + cli := c.oauthHTTPClient(ctx, source) + return cli.Do(req) +} + func (c *Config) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string { // No external requests are made when constructing the auth code url. return c.underlying.AuthCodeURL(state, opts...) @@ -86,6 +109,10 @@ func (c *Config) TokenSource(ctx context.Context, token *oauth2.Token) oauth2.To // source that will make a network request when the 'Token' method is called on // it if the token is expired. func (c *Config) wrapClient(ctx context.Context, source string) context.Context { + return context.WithValue(ctx, oauth2.HTTPClient, c.oauthHTTPClient(ctx, source)) +} + +func (c *Config) oauthHTTPClient(ctx context.Context, source string) *http.Client { cli := &http.Client{} // Check if the context has an http client already. @@ -95,7 +122,7 @@ func (c *Config) wrapClient(ctx context.Context, source string) context.Context // The new tripper will instrument every request made by the oauth2 client. cli.Transport = newInstrumentedTripper(c, source, cli.Transport) - return context.WithValue(ctx, oauth2.HTTPClient, cli) + return cli } type instrumentedTripper struct { @@ -133,5 +160,14 @@ func (i *instrumentedTripper) RoundTrip(r *http.Request) (*http.Response, error) "source": i.source, "status_code": fmt.Sprintf("%d", statusCode), }).Inc() + if err == nil { + fmt.Println(map[string]string{ + "limit": resp.Header.Get("x-ratelimit-limit"), + "remain": resp.Header.Get("x-ratelimit-remaining"), + "used": resp.Header.Get("x-ratelimit-used"), + "reset": resp.Header.Get("x-ratelimit-reset"), + "resource": resp.Header.Get("x-ratelimit-resource"), + }) + } return resp, err } diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 917e979e092ee..da90df232d631 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -2100,7 +2100,7 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ }) return } - httpapi.Write(ctx, rw, http.StatusOK, resp) + httpapi.Write(ctx, rw, http.StatusInternalServerError, resp) return } } From 07fd10d25f9b2ae813221fd40712506898478cc2 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 9 Jan 2024 13:50:30 -0600 Subject: [PATCH 07/20] remove print --- coderd/oauthpki/oidcpki.go | 2 ++ coderd/promoauth/oauth2.go | 10 +--------- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/coderd/oauthpki/oidcpki.go b/coderd/oauthpki/oidcpki.go index dff1240ca7dc1..d761c43e446ff 100644 --- a/coderd/oauthpki/oidcpki.go +++ b/coderd/oauthpki/oidcpki.go @@ -180,6 +180,8 @@ func (src *jwtTokenSource) Token() (*oauth2.Token, error) { } cli := http.DefaultClient if v, ok := src.ctx.Value(oauth2.HTTPClient).(*http.Client); ok { + // This client should be the instrumented client already. So no need to + // handle this manually. cli = v } diff --git a/coderd/promoauth/oauth2.go b/coderd/promoauth/oauth2.go index 2780c7ff9ccf6..8f0d42ca3ef4c 100644 --- a/coderd/promoauth/oauth2.go +++ b/coderd/promoauth/oauth2.go @@ -160,14 +160,6 @@ func (i *instrumentedTripper) RoundTrip(r *http.Request) (*http.Response, error) "source": i.source, "status_code": fmt.Sprintf("%d", statusCode), }).Inc() - if err == nil { - fmt.Println(map[string]string{ - "limit": resp.Header.Get("x-ratelimit-limit"), - "remain": resp.Header.Get("x-ratelimit-remaining"), - "used": resp.Header.Get("x-ratelimit-used"), - "reset": resp.Header.Get("x-ratelimit-reset"), - "resource": resp.Header.Get("x-ratelimit-resource"), - }) - } + return resp, err } From 117a40542aa43ec270e192303f68ebf28a6e99c9 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 9 Jan 2024 13:56:43 -0600 Subject: [PATCH 08/20] typos --- coderd/externalauth/externalauth.go | 30 ++++++++++++++--------------- coderd/promoauth/oauth2.go | 13 ++++--------- 2 files changed, 19 insertions(+), 24 deletions(-) diff --git a/coderd/externalauth/externalauth.go b/coderd/externalauth/externalauth.go index d42e3dd49d2b2..07721c7e63989 100644 --- a/coderd/externalauth/externalauth.go +++ b/coderd/externalauth/externalauth.go @@ -29,7 +29,7 @@ import ( // Config is used for authentication for Git operations. type Config struct { - promoauth.InstrumentedeOAuth2Config + promoauth.InstrumentedOAuth2Config // ID is a unique identifier for the authenticator. ID string // Type is the type of provider. @@ -188,7 +188,7 @@ func (c *Config) ValidateToken(ctx context.Context, token string) (bool, *coders } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - res, err := c.InstrumentedeOAuth2Config.Do(ctx, "ValidateToken", req) + res, err := c.InstrumentedOAuth2Config.Do(ctx, "ValidateToken", req) if err != nil { return false, nil, err } @@ -238,7 +238,7 @@ func (c *Config) AppInstallations(ctx context.Context, token string) ([]codersdk return nil, false, err } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - res, err := c.InstrumentedeOAuth2Config.Do(ctx, "AppInstallations", req) + res, err := c.InstrumentedOAuth2Config.Do(ctx, "AppInstallations", req) if err != nil { return nil, false, err } @@ -279,7 +279,7 @@ func (c *Config) AppInstallations(ctx context.Context, token string) ([]codersdk type DeviceAuth struct { // Cfg is provided for the http client method. - Cfg promoauth.InstrumentedeOAuth2Config + Cfg promoauth.InstrumentedOAuth2Config ClientID string TokenURL string Scopes []string @@ -456,17 +456,17 @@ func ConvertConfig(instrument *promoauth.Factory, entries []codersdk.ExternalAut } cfg := &Config{ - InstrumentedeOAuth2Config: instrument.New(entry.ID, oauthConfig), - ID: entry.ID, - Regex: regex, - Type: entry.Type, - NoRefresh: entry.NoRefresh, - ValidateURL: entry.ValidateURL, - AppInstallationsURL: entry.AppInstallationsURL, - AppInstallURL: entry.AppInstallURL, - DisplayName: entry.DisplayName, - DisplayIcon: entry.DisplayIcon, - ExtraTokenKeys: entry.ExtraTokenKeys, + InstrumentedOAuth2Config: instrument.New(entry.ID, oauthConfig), + ID: entry.ID, + Regex: regex, + Type: entry.Type, + NoRefresh: entry.NoRefresh, + ValidateURL: entry.ValidateURL, + AppInstallationsURL: entry.AppInstallationsURL, + AppInstallURL: entry.AppInstallURL, + DisplayName: entry.DisplayName, + DisplayIcon: entry.DisplayIcon, + ExtraTokenKeys: entry.ExtraTokenKeys, } if entry.DeviceFlow { diff --git a/coderd/promoauth/oauth2.go b/coderd/promoauth/oauth2.go index 8f0d42ca3ef4c..cb79f5971d273 100644 --- a/coderd/promoauth/oauth2.go +++ b/coderd/promoauth/oauth2.go @@ -18,7 +18,10 @@ type OAuth2Config interface { TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource } -type InstrumentedeOAuth2Config interface { +// InstrumentedOAuth2Config extends OAuth2Config with a `Do` method that allows +// external oauth related calls to be instrumented. This is to support +// "ValidateToken" which is not an oauth2 specified method. +type InstrumentedOAuth2Config interface { OAuth2Config // Do is provided as a convience method to make a request with the oauth2 client. @@ -28,14 +31,6 @@ type InstrumentedeOAuth2Config interface { Do(ctx context.Context, source string, req *http.Request) (*http.Response, error) } -type HTTPDo interface { - // Do is provided as a convience method to make a request with the oauth2 client. - // It mirrors `http.Client.Do`. - // We need this because Coder adds some extra functionality to - // oauth clients such as the `ValidateToken()` method. - Do(ctx context.Context, source string, req *http.Request) (*http.Response, error) -} - var _ OAuth2Config = (*Config)(nil) type Factory struct { From 73abae6a7dff47e8f7553b9cfd70d05cba32cdd2 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 9 Jan 2024 14:01:26 -0600 Subject: [PATCH 09/20] Fixup some unit tests --- cli/create_test.go | 10 +++++----- coderd/coderdtest/oidctest/idp.go | 22 ++++++++++------------ coderd/promoauth/oauth2.go | 10 +++++++++- coderd/promoauth/oauth2_test.go | 2 +- testutil/oauth2.go | 5 +++++ 5 files changed, 30 insertions(+), 19 deletions(-) diff --git a/cli/create_test.go b/cli/create_test.go index 42b526d404cfc..903694167fd72 100644 --- a/cli/create_test.go +++ b/cli/create_test.go @@ -767,11 +767,11 @@ func TestCreateWithGitAuth(t *testing.T) { client := coderdtest.New(t, &coderdtest.Options{ ExternalAuthConfigs: []*externalauth.Config{{ - OAuth2Config: &testutil.OAuth2Config{}, - ID: "github", - Regex: regexp.MustCompile(`github\.com`), - Type: codersdk.EnhancedExternalAuthProviderGitHub.String(), - DisplayName: "GitHub", + InstrumentedOAuth2Config: &testutil.OAuth2Config{}, + ID: "github", + Regex: regexp.MustCompile(`github\.com`), + Type: codersdk.EnhancedExternalAuthProviderGitHub.String(), + DisplayName: "GitHub", }}, IncludeProvisionerDaemon: true, }) diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index 93f446c7cefa8..460c687a9751a 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -24,6 +24,7 @@ import ( "github.com/go-jose/go-jose/v3" "github.com/golang-jwt/jwt/v4" "github.com/google/uuid" + "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" @@ -33,6 +34,7 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/coderd" "github.com/coder/coder/v2/coderd/externalauth" + "github.com/coder/coder/v2/coderd/promoauth" "github.com/coder/coder/v2/coderd/util/syncmap" "github.com/coder/coder/v2/codersdk" ) @@ -943,9 +945,10 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu handle(email, rw, r) } } + instrumentF := promoauth.NewFactory(prometheus.NewRegistry()) cfg := &externalauth.Config{ - OAuth2Config: f.OIDCConfig(t, nil), - ID: id, + InstrumentedOAuth2Config: instrumentF.New(f.clientID, f.OIDCConfig(t, nil)), + ID: id, // No defaults for these fields by omitting the type Type: "", DisplayIcon: f.WellknownConfig().UserInfoURL, @@ -959,7 +962,10 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu return cfg } -func (f *FakeIDP) OAuthConfig(scopes ...string) *oauth2.Config { +// OIDCConfig returns the OIDC config to use for Coderd. +func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig { + t.Helper() + if len(scopes) == 0 { scopes = []string{"openid", "email", "profile"} } @@ -976,15 +982,6 @@ func (f *FakeIDP) OAuthConfig(scopes ...string) *oauth2.Config { RedirectURL: "https://redirect.com", Scopes: scopes, } - f.cfg = oauthCfg - return oauthCfg -} - -// OIDCConfig returns the OIDC config to use for Coderd. -func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig { - t.Helper() - - oauthCfg := f.OAuthConfig(scopes...) ctx := oidc.ClientContext(context.Background(), f.HTTPClient(nil)) p, err := oidc.NewProvider(ctx, f.provider.Issuer) @@ -1013,6 +1010,7 @@ func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *co opt(cfg) } + f.cfg = oauthCfg return cfg } diff --git a/coderd/promoauth/oauth2.go b/coderd/promoauth/oauth2.go index cb79f5971d273..4d11c0967b4ae 100644 --- a/coderd/promoauth/oauth2.go +++ b/coderd/promoauth/oauth2.go @@ -155,6 +155,14 @@ func (i *instrumentedTripper) RoundTrip(r *http.Request) (*http.Response, error) "source": i.source, "status_code": fmt.Sprintf("%d", statusCode), }).Inc() - + if err == nil { + fmt.Println(map[string]string{ + "limit": resp.Header.Get("x-ratelimit-limit"), + "remain": resp.Header.Get("x-ratelimit-remaining"), + "used": resp.Header.Get("x-ratelimit-used"), + "reset": resp.Header.Get("x-ratelimit-reset"), + "resource": resp.Header.Get("x-ratelimit-resource"), + }) + } return resp, err } diff --git a/coderd/promoauth/oauth2_test.go b/coderd/promoauth/oauth2_test.go index 1d7b15923cd34..777bdd942e162 100644 --- a/coderd/promoauth/oauth2_test.go +++ b/coderd/promoauth/oauth2_test.go @@ -25,7 +25,7 @@ func TestMaintainDefault(t *testing.T) { } factory := promoauth.NewFactory(reg) - cfg := factory.New("test", idp.OAuthConfig()) + cfg := factory.New("test", idp.OIDCConfig(t, []string{})) // 0 Requests before we start require.Equal(t, count(), 0) diff --git a/testutil/oauth2.go b/testutil/oauth2.go index e152caf956db5..3bb22b0a03f5a 100644 --- a/testutil/oauth2.go +++ b/testutil/oauth2.go @@ -2,6 +2,7 @@ package testutil import ( "context" + "net/http" "net/url" "time" @@ -13,6 +14,10 @@ type OAuth2Config struct { TokenSourceFunc OAuth2TokenSource } +func (*OAuth2Config) Do(_ context.Context, _ string, req *http.Request) (*http.Response, error) { + return http.DefaultClient.Do(req) +} + func (*OAuth2Config) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string { return "/?state=" + url.QueryEscape(state) } From e5e190da27629a2795243cab73d076531d8a6348 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 9 Jan 2024 14:03:00 -0600 Subject: [PATCH 10/20] fixup test cases --- coderd/promoauth/oauth2.go | 9 --------- coderd/provisionerdserver/provisionerdserver_test.go | 4 ++-- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/coderd/promoauth/oauth2.go b/coderd/promoauth/oauth2.go index 4d11c0967b4ae..a4dd81e125603 100644 --- a/coderd/promoauth/oauth2.go +++ b/coderd/promoauth/oauth2.go @@ -155,14 +155,5 @@ func (i *instrumentedTripper) RoundTrip(r *http.Request) (*http.Response, error) "source": i.source, "status_code": fmt.Sprintf("%d", statusCode), }).Inc() - if err == nil { - fmt.Println(map[string]string{ - "limit": resp.Header.Get("x-ratelimit-limit"), - "remain": resp.Header.Get("x-ratelimit-remaining"), - "used": resp.Header.Get("x-ratelimit-used"), - "reset": resp.Header.Get("x-ratelimit-reset"), - "resource": resp.Header.Get("x-ratelimit-resource"), - }) - } return resp, err } diff --git a/coderd/provisionerdserver/provisionerdserver_test.go b/coderd/provisionerdserver/provisionerdserver_test.go index 915b50a31dc02..01a0837a4d028 100644 --- a/coderd/provisionerdserver/provisionerdserver_test.go +++ b/coderd/provisionerdserver/provisionerdserver_test.go @@ -187,8 +187,8 @@ func TestAcquireJob(t *testing.T) { srv, db, ps, _ := setup(t, false, &overrides{ deploymentValues: dv, externalAuthConfigs: []*externalauth.Config{{ - ID: gitAuthProvider, - OAuth2Config: &testutil.OAuth2Config{}, + ID: gitAuthProvider, + InstrumentedOAuth2Config: &testutil.OAuth2Config{}, }}, }) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) From d4b36d3cc0265c6f07f23872c11fbe9b8f168c5e Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 9 Jan 2024 14:05:44 -0600 Subject: [PATCH 11/20] remove oidc instrument --- cli/server.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/cli/server.go b/cli/server.go index cfac4e9b71c45..dd45e4dea8aef 100644 --- a/cli/server.go +++ b/cli/server.go @@ -103,7 +103,7 @@ import ( "github.com/coder/wgtunnel/tunnelsdk" ) -func createOIDCConfig(ctx context.Context, instrument *promoauth.Factory, vals *codersdk.DeploymentValues) (*coderd.OIDCConfig, error) { +func createOIDCConfig(ctx context.Context, vals *codersdk.DeploymentValues) (*coderd.OIDCConfig, error) { if vals.OIDC.ClientID == "" { return nil, xerrors.Errorf("OIDC client ID must be set!") } @@ -160,7 +160,7 @@ func createOIDCConfig(ctx context.Context, instrument *promoauth.Factory, vals * } return &coderd.OIDCConfig{ - OAuth2Config: instrument.New("oidc-login", useCfg), + OAuth2Config: useCfg, Provider: oidcProvider, Verifier: oidcProvider.Verifier(&oidc.Config{ ClientID: vals.OIDC.ClientID.String(), @@ -642,7 +642,13 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. logger.Warn(ctx, "coder will not check email_verified for OIDC logins") } - oc, err := createOIDCConfig(ctx, oauthInstrument, vals) + // This OIDC config is **not** being instrumented with the + // oauth2 instrument wrapper. If we implement the missing + // oidc methods, then we can instrument it. + // Missing: + // - Userinfo + // - Verify + oc, err := createOIDCConfig(ctx, vals) if err != nil { return xerrors.Errorf("create oidc config: %w", err) } From 89642973072480a9accaf1067a0d60d77a194c1e Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 9 Jan 2024 14:06:15 -0600 Subject: [PATCH 12/20] typos --- coderd/promoauth/oauth2.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coderd/promoauth/oauth2.go b/coderd/promoauth/oauth2.go index a4dd81e125603..ae35e801180ca 100644 --- a/coderd/promoauth/oauth2.go +++ b/coderd/promoauth/oauth2.go @@ -24,7 +24,7 @@ type OAuth2Config interface { type InstrumentedOAuth2Config interface { OAuth2Config - // Do is provided as a convience method to make a request with the oauth2 client. + // Do is provided as a convenience method to make a request with the oauth2 client. // It mirrors `http.Client.Do`. // We need this because Coder adds some extra functionality to // oauth clients such as the `ValidateToken()` method. From 2ba7a5ce519047f2dc2781aaa18be07a4dc859ef Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 9 Jan 2024 14:21:21 -0600 Subject: [PATCH 13/20] more test coverage --- coderd/externalauth/externalauth_test.go | 21 +++++++---- coderd/externalauth_test.go | 44 ++++++++++++------------ coderd/promoauth/oauth2_test.go | 18 ++++++++-- coderd/templateversions_test.go | 8 ++--- 4 files changed, 55 insertions(+), 36 deletions(-) diff --git a/coderd/externalauth/externalauth_test.go b/coderd/externalauth/externalauth_test.go index 387bdc77382aa..84fbe4ff5de35 100644 --- a/coderd/externalauth/externalauth_test.go +++ b/coderd/externalauth/externalauth_test.go @@ -12,6 +12,7 @@ import ( "github.com/coreos/go-oidc/v3/oidc" "github.com/golang-jwt/jwt/v4" "github.com/google/uuid" + "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/require" "golang.org/x/oauth2" "golang.org/x/xerrors" @@ -22,6 +23,7 @@ import ( "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbmem" "github.com/coder/coder/v2/coderd/externalauth" + "github.com/coder/coder/v2/coderd/promoauth" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/testutil" ) @@ -94,7 +96,7 @@ func TestRefreshToken(t *testing.T) { t.Run("FalseIfTokenSourceFails", func(t *testing.T) { t.Parallel() config := &externalauth.Config{ - OAuth2Config: &testutil.OAuth2Config{ + InstrumentedOAuth2Config: &testutil.OAuth2Config{ TokenSourceFunc: func() (*oauth2.Token, error) { return nil, xerrors.New("failure") }, @@ -301,9 +303,10 @@ func TestRefreshToken(t *testing.T) { func TestExchangeWithClientSecret(t *testing.T) { t.Parallel() + instrument := promoauth.NewFactory(prometheus.NewRegistry()) // This ensures a provider that requires the custom // client secret exchange works. - configs, err := externalauth.ConvertConfig([]codersdk.ExternalAuthConfig{{ + configs, err := externalauth.ConvertConfig(instrument, []codersdk.ExternalAuthConfig{{ // JFrog just happens to require this custom type. Type: codersdk.EnhancedExternalAuthProviderJFrog.String(), @@ -335,6 +338,8 @@ func TestExchangeWithClientSecret(t *testing.T) { func TestConvertYAML(t *testing.T) { t.Parallel() + + instrument := promoauth.NewFactory(prometheus.NewRegistry()) for _, tc := range []struct { Name string Input []codersdk.ExternalAuthConfig @@ -387,7 +392,7 @@ func TestConvertYAML(t *testing.T) { tc := tc t.Run(tc.Name, func(t *testing.T) { t.Parallel() - output, err := externalauth.ConvertConfig(tc.Input, &url.URL{}) + output, err := externalauth.ConvertConfig(instrument, tc.Input, &url.URL{}) if tc.Error != "" { require.Error(t, err) require.Contains(t, err.Error(), tc.Error) @@ -399,7 +404,7 @@ func TestConvertYAML(t *testing.T) { t.Run("CustomScopesAndEndpoint", func(t *testing.T) { t.Parallel() - config, err := externalauth.ConvertConfig([]codersdk.ExternalAuthConfig{{ + config, err := externalauth.ConvertConfig(instrument, []codersdk.ExternalAuthConfig{{ Type: string(codersdk.EnhancedExternalAuthProviderGitLab), ClientID: "id", ClientSecret: "secret", @@ -433,10 +438,12 @@ func setupOauth2Test(t *testing.T, settings testConfig) (*oidctest.FakeIDP, *ext append([]oidctest.FakeIDPOpt{}, settings.FakeIDPOpts...)..., ) + f := promoauth.NewFactory(prometheus.NewRegistry()) config := &externalauth.Config{ - OAuth2Config: fake.OIDCConfig(t, nil, settings.CoderOIDCConfigOpts...), - ID: providerID, - ValidateURL: fake.WellknownConfig().UserInfoURL, + InstrumentedOAuth2Config: f.New("test-oauth2", + fake.OIDCConfig(t, nil, settings.CoderOIDCConfigOpts...)), + ID: providerID, + ValidateURL: fake.WellknownConfig().UserInfoURL, } settings.ExternalAuthOpt(config) diff --git a/coderd/externalauth_test.go b/coderd/externalauth_test.go index 34c1fe7bcdc1e..1d0b06bbc0506 100644 --- a/coderd/externalauth_test.go +++ b/coderd/externalauth_test.go @@ -316,10 +316,10 @@ func TestExternalAuthCallback(t *testing.T) { client := coderdtest.New(t, &coderdtest.Options{ IncludeProvisionerDaemon: true, ExternalAuthConfigs: []*externalauth.Config{{ - OAuth2Config: &testutil.OAuth2Config{}, - ID: "github", - Regex: regexp.MustCompile(`github\.com`), - Type: codersdk.EnhancedExternalAuthProviderGitHub.String(), + InstrumentedOAuth2Config: &testutil.OAuth2Config{}, + ID: "github", + Regex: regexp.MustCompile(`github\.com`), + Type: codersdk.EnhancedExternalAuthProviderGitHub.String(), }}, }) user := coderdtest.CreateFirstUser(t, client) @@ -347,10 +347,10 @@ func TestExternalAuthCallback(t *testing.T) { client := coderdtest.New(t, &coderdtest.Options{ IncludeProvisionerDaemon: true, ExternalAuthConfigs: []*externalauth.Config{{ - OAuth2Config: &testutil.OAuth2Config{}, - ID: "github", - Regex: regexp.MustCompile(`github\.com`), - Type: codersdk.EnhancedExternalAuthProviderGitHub.String(), + InstrumentedOAuth2Config: &testutil.OAuth2Config{}, + ID: "github", + Regex: regexp.MustCompile(`github\.com`), + Type: codersdk.EnhancedExternalAuthProviderGitHub.String(), }}, }) resp := coderdtest.RequestExternalAuthCallback(t, "github", client) @@ -361,10 +361,10 @@ func TestExternalAuthCallback(t *testing.T) { client := coderdtest.New(t, &coderdtest.Options{ IncludeProvisionerDaemon: true, ExternalAuthConfigs: []*externalauth.Config{{ - OAuth2Config: &testutil.OAuth2Config{}, - ID: "github", - Regex: regexp.MustCompile(`github\.com`), - Type: codersdk.EnhancedExternalAuthProviderGitHub.String(), + InstrumentedOAuth2Config: &testutil.OAuth2Config{}, + ID: "github", + Regex: regexp.MustCompile(`github\.com`), + Type: codersdk.EnhancedExternalAuthProviderGitHub.String(), }}, }) _ = coderdtest.CreateFirstUser(t, client) @@ -387,11 +387,11 @@ func TestExternalAuthCallback(t *testing.T) { client := coderdtest.New(t, &coderdtest.Options{ IncludeProvisionerDaemon: true, ExternalAuthConfigs: []*externalauth.Config{{ - ValidateURL: srv.URL, - OAuth2Config: &testutil.OAuth2Config{}, - ID: "github", - Regex: regexp.MustCompile(`github\.com`), - Type: codersdk.EnhancedExternalAuthProviderGitHub.String(), + ValidateURL: srv.URL, + InstrumentedOAuth2Config: &testutil.OAuth2Config{}, + ID: "github", + Regex: regexp.MustCompile(`github\.com`), + Type: codersdk.EnhancedExternalAuthProviderGitHub.String(), }}, }) user := coderdtest.CreateFirstUser(t, client) @@ -443,7 +443,7 @@ func TestExternalAuthCallback(t *testing.T) { client := coderdtest.New(t, &coderdtest.Options{ IncludeProvisionerDaemon: true, ExternalAuthConfigs: []*externalauth.Config{{ - OAuth2Config: &testutil.OAuth2Config{ + InstrumentedOAuth2Config: &testutil.OAuth2Config{ Token: &oauth2.Token{ AccessToken: "token", RefreshToken: "something", @@ -497,10 +497,10 @@ func TestExternalAuthCallback(t *testing.T) { client := coderdtest.New(t, &coderdtest.Options{ IncludeProvisionerDaemon: true, ExternalAuthConfigs: []*externalauth.Config{{ - OAuth2Config: &testutil.OAuth2Config{}, - ID: "github", - Regex: regexp.MustCompile(`github\.com`), - Type: codersdk.EnhancedExternalAuthProviderGitHub.String(), + InstrumentedOAuth2Config: &testutil.OAuth2Config{}, + ID: "github", + Regex: regexp.MustCompile(`github\.com`), + Type: codersdk.EnhancedExternalAuthProviderGitHub.String(), }}, }) user := coderdtest.CreateFirstUser(t, client) diff --git a/coderd/promoauth/oauth2_test.go b/coderd/promoauth/oauth2_test.go index 777bdd942e162..78466ddceff21 100644 --- a/coderd/promoauth/oauth2_test.go +++ b/coderd/promoauth/oauth2_test.go @@ -10,11 +10,12 @@ import ( "github.com/stretchr/testify/require" "github.com/coder/coder/v2/coderd/coderdtest/oidctest" + "github.com/coder/coder/v2/coderd/externalauth" "github.com/coder/coder/v2/coderd/promoauth" "github.com/coder/coder/v2/testutil" ) -func TestMaintainDefault(t *testing.T) { +func TestInstrument(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) @@ -25,7 +26,12 @@ func TestMaintainDefault(t *testing.T) { } factory := promoauth.NewFactory(reg) - cfg := factory.New("test", idp.OIDCConfig(t, []string{})) + const id = "test" + cfg := externalauth.Config{ + InstrumentedOAuth2Config: factory.New(id, idp.OIDCConfig(t, []string{})), + ID: "test", + ValidateURL: must(idp.IssuerURL().Parse("/oauth2/userinfo")).String(), + } // 0 Requests before we start require.Equal(t, count(), 0) @@ -44,6 +50,12 @@ func TestMaintainDefault(t *testing.T) { require.NotEqual(t, token.AccessToken, refreshed.AccessToken, "token refreshed") require.Equal(t, count(), 2) + // Try a validate + valid, _, err := cfg.ValidateToken(ctx, refreshed.AccessToken) + require.NoError(t, err) + require.True(t, valid) + require.Equal(t, count(), 3) + // Verify the default client was not broken. This check is added because we // extend the http.DefaultTransport. If a `.Clone()` is not done, this can be // mis-used. It is cheap to run this quick check. @@ -56,7 +68,7 @@ func TestMaintainDefault(t *testing.T) { require.NoError(t, err) _ = resp.Body.Close() - require.Equal(t, count(), 2) + require.Equal(t, count(), 3) } func must[V any](v V, err error) V { diff --git a/coderd/templateversions_test.go b/coderd/templateversions_test.go index b7765f076b2f7..4423bbc4e7056 100644 --- a/coderd/templateversions_test.go +++ b/coderd/templateversions_test.go @@ -335,10 +335,10 @@ func TestTemplateVersionsExternalAuth(t *testing.T) { client := coderdtest.New(t, &coderdtest.Options{ IncludeProvisionerDaemon: true, ExternalAuthConfigs: []*externalauth.Config{{ - OAuth2Config: &testutil.OAuth2Config{}, - ID: "github", - Regex: regexp.MustCompile(`github\.com`), - Type: codersdk.EnhancedExternalAuthProviderGitHub.String(), + InstrumentedOAuth2Config: &testutil.OAuth2Config{}, + ID: "github", + Regex: regexp.MustCompile(`github\.com`), + Type: codersdk.EnhancedExternalAuthProviderGitHub.String(), }}, }) user := coderdtest.CreateFirstUser(t, client) From 8963aaa8374a06023255a0612caa4f4a712bcb6a Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 9 Jan 2024 14:27:50 -0600 Subject: [PATCH 14/20] left debug --- coderd/workspaceagents.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index da90df232d631..917e979e092ee 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -2100,7 +2100,7 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ }) return } - httpapi.Write(ctx, rw, http.StatusInternalServerError, resp) + httpapi.Write(ctx, rw, http.StatusOK, resp) return } } From bfa427f3d9414ccc05cac2842b1e7aab48bc6b88 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 9 Jan 2024 14:36:43 -0600 Subject: [PATCH 15/20] Add comments --- coderd/externalauth/externalauth.go | 11 ++++++++++- coderd/promoauth/oauth2.go | 10 +++++++--- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/coderd/externalauth/externalauth.go b/coderd/externalauth/externalauth.go index 07721c7e63989..904de81ab2abf 100644 --- a/coderd/externalauth/externalauth.go +++ b/coderd/externalauth/externalauth.go @@ -300,7 +300,16 @@ func (c *DeviceAuth) AuthorizeDevice(ctx context.Context) (*codersdk.ExternalAut if err != nil { return nil, err } - resp, err := c.Cfg.Do(ctx, "AuthorizeDevice", req) + + do := http.DefaultClient.Do + if c.Cfg != nil { + // The cfg can be nil in unit tests. + do = func(req *http.Request) (*http.Response, error) { + return c.Cfg.Do(ctx, "AuthorizeDevice", req) + } + } + + resp, err := do(req) req.Header.Set("Accept", "application/json") if err != nil { return nil, err diff --git a/coderd/promoauth/oauth2.go b/coderd/promoauth/oauth2.go index ae35e801180ca..bd74610a5f8c1 100644 --- a/coderd/promoauth/oauth2.go +++ b/coderd/promoauth/oauth2.go @@ -21,18 +21,19 @@ type OAuth2Config interface { // InstrumentedOAuth2Config extends OAuth2Config with a `Do` method that allows // external oauth related calls to be instrumented. This is to support // "ValidateToken" which is not an oauth2 specified method. +// These calls still count against the api rate limit, and should be instrumented. type InstrumentedOAuth2Config interface { OAuth2Config // Do is provided as a convenience method to make a request with the oauth2 client. // It mirrors `http.Client.Do`. - // We need this because Coder adds some extra functionality to - // oauth clients such as the `ValidateToken()` method. Do(ctx context.Context, source string, req *http.Request) (*http.Response, error) } var _ OAuth2Config = (*Config)(nil) +// Factory allows us to have 1 set of metrics for all oauth2 providers. +// Primarily to avoid any prometheus errors registering duplicate metrics. type Factory struct { metrics *metrics } @@ -107,10 +108,11 @@ func (c *Config) wrapClient(ctx context.Context, source string) context.Context return context.WithValue(ctx, oauth2.HTTPClient, c.oauthHTTPClient(ctx, source)) } +// oauthHTTPClient returns an http client that will instrument every request made. func (c *Config) oauthHTTPClient(ctx context.Context, source string) *http.Client { cli := &http.Client{} - // Check if the context has an http client already. + // Check if the context has a http client already. if hc, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok { cli = hc } @@ -126,6 +128,8 @@ type instrumentedTripper struct { underlying http.RoundTripper } +// newInstrumentedTripper intercepts a http request, and increments the +// externalRequestCount metric. func newInstrumentedTripper(c *Config, source string, under http.RoundTripper) *instrumentedTripper { if under == nil { under = http.DefaultTransport From 9d1c76c96c6ff19905efb9b69366c29048dff944 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 10 Jan 2024 08:41:32 -0600 Subject: [PATCH 16/20] use consts for source labels --- coderd/externalauth/externalauth.go | 6 +++--- coderd/promoauth/oauth2.go | 28 +++++++++++++++++++--------- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/coderd/externalauth/externalauth.go b/coderd/externalauth/externalauth.go index 904de81ab2abf..d08bac359fcb3 100644 --- a/coderd/externalauth/externalauth.go +++ b/coderd/externalauth/externalauth.go @@ -188,7 +188,7 @@ func (c *Config) ValidateToken(ctx context.Context, token string) (bool, *coders } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - res, err := c.InstrumentedOAuth2Config.Do(ctx, "ValidateToken", req) + res, err := c.InstrumentedOAuth2Config.Do(ctx, promoauth.SourceValidateToken, req) if err != nil { return false, nil, err } @@ -238,7 +238,7 @@ func (c *Config) AppInstallations(ctx context.Context, token string) ([]codersdk return nil, false, err } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - res, err := c.InstrumentedOAuth2Config.Do(ctx, "AppInstallations", req) + res, err := c.InstrumentedOAuth2Config.Do(ctx, promoauth.SourceAppInstallations, req) if err != nil { return nil, false, err } @@ -305,7 +305,7 @@ func (c *DeviceAuth) AuthorizeDevice(ctx context.Context) (*codersdk.ExternalAut if c.Cfg != nil { // The cfg can be nil in unit tests. do = func(req *http.Request) (*http.Response, error) { - return c.Cfg.Do(ctx, "AuthorizeDevice", req) + return c.Cfg.Do(ctx, promoauth.SourceAuthorizeDevice, req) } } diff --git a/coderd/promoauth/oauth2.go b/coderd/promoauth/oauth2.go index bd74610a5f8c1..d3d168b8ec96c 100644 --- a/coderd/promoauth/oauth2.go +++ b/coderd/promoauth/oauth2.go @@ -10,6 +10,16 @@ import ( "golang.org/x/oauth2" ) +type Oauth2Source string + +const ( + SourceValidateToken Oauth2Source = "ValidateToken" + SourceExchange Oauth2Source = "Exchange" + SourceTokenSource Oauth2Source = "TokenSource" + SourceAppInstallations Oauth2Source = "AppInstallations" + SourceAuthorizeDevice Oauth2Source = "AuthorizeDevice" +) + // OAuth2Config exposes a subset of *oauth2.Config functions for easier testing. // *oauth2.Config should be used instead of implementing this in production. type OAuth2Config interface { @@ -27,7 +37,7 @@ type InstrumentedOAuth2Config interface { // Do is provided as a convenience method to make a request with the oauth2 client. // It mirrors `http.Client.Do`. - Do(ctx context.Context, source string, req *http.Request) (*http.Response, error) + Do(ctx context.Context, source Oauth2Source, req *http.Request) (*http.Response, error) } var _ OAuth2Config = (*Config)(nil) @@ -79,7 +89,7 @@ type Config struct { metrics *metrics } -func (c *Config) Do(ctx context.Context, source string, req *http.Request) (*http.Response, error) { +func (c *Config) Do(ctx context.Context, source Oauth2Source, req *http.Request) (*http.Response, error) { cli := c.oauthHTTPClient(ctx, source) return cli.Do(req) } @@ -90,11 +100,11 @@ func (c *Config) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string } func (c *Config) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { - return c.underlying.Exchange(c.wrapClient(ctx, "Exchange"), code, opts...) + return c.underlying.Exchange(c.wrapClient(ctx, SourceExchange), code, opts...) } func (c *Config) TokenSource(ctx context.Context, token *oauth2.Token) oauth2.TokenSource { - return c.underlying.TokenSource(c.wrapClient(ctx, "TokenSource"), token) + return c.underlying.TokenSource(c.wrapClient(ctx, SourceTokenSource), token) } // wrapClient is the only way we can accurately instrument the oauth2 client. @@ -104,12 +114,12 @@ func (c *Config) TokenSource(ctx context.Context, token *oauth2.Token) oauth2.To // For example, the 'TokenSource' method will return a token // source that will make a network request when the 'Token' method is called on // it if the token is expired. -func (c *Config) wrapClient(ctx context.Context, source string) context.Context { +func (c *Config) wrapClient(ctx context.Context, source Oauth2Source) context.Context { return context.WithValue(ctx, oauth2.HTTPClient, c.oauthHTTPClient(ctx, source)) } // oauthHTTPClient returns an http client that will instrument every request made. -func (c *Config) oauthHTTPClient(ctx context.Context, source string) *http.Client { +func (c *Config) oauthHTTPClient(ctx context.Context, source Oauth2Source) *http.Client { cli := &http.Client{} // Check if the context has a http client already. @@ -124,13 +134,13 @@ func (c *Config) oauthHTTPClient(ctx context.Context, source string) *http.Clien type instrumentedTripper struct { c *Config - source string + source Oauth2Source underlying http.RoundTripper } // newInstrumentedTripper intercepts a http request, and increments the // externalRequestCount metric. -func newInstrumentedTripper(c *Config, source string, under http.RoundTripper) *instrumentedTripper { +func newInstrumentedTripper(c *Config, source Oauth2Source, under http.RoundTripper) *instrumentedTripper { if under == nil { under = http.DefaultTransport } @@ -156,7 +166,7 @@ func (i *instrumentedTripper) RoundTrip(r *http.Request) (*http.Response, error) } i.c.metrics.externalRequestCount.With(prometheus.Labels{ "name": i.c.name, - "source": i.source, + "source": string(i.source), "status_code": fmt.Sprintf("%d", statusCode), }).Inc() return resp, err From c149f8f1183db7553820112d8e515624c601db89 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 10 Jan 2024 08:43:59 -0600 Subject: [PATCH 17/20] Spell out config --- coderd/externalauth/externalauth.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/coderd/externalauth/externalauth.go b/coderd/externalauth/externalauth.go index d08bac359fcb3..09d7d829ba2ce 100644 --- a/coderd/externalauth/externalauth.go +++ b/coderd/externalauth/externalauth.go @@ -278,8 +278,8 @@ func (c *Config) AppInstallations(ctx context.Context, token string) ([]codersdk } type DeviceAuth struct { - // Cfg is provided for the http client method. - Cfg promoauth.InstrumentedOAuth2Config + // Config is provided for the http client method. + Config promoauth.InstrumentedOAuth2Config ClientID string TokenURL string Scopes []string @@ -302,10 +302,10 @@ func (c *DeviceAuth) AuthorizeDevice(ctx context.Context) (*codersdk.ExternalAut } do := http.DefaultClient.Do - if c.Cfg != nil { + if c.Config != nil { // The cfg can be nil in unit tests. do = func(req *http.Request) (*http.Response, error) { - return c.Cfg.Do(ctx, promoauth.SourceAuthorizeDevice, req) + return c.Config.Do(ctx, promoauth.SourceAuthorizeDevice, req) } } @@ -483,7 +483,7 @@ func ConvertConfig(instrument *promoauth.Factory, entries []codersdk.ExternalAut return nil, xerrors.Errorf("external auth provider %q: device auth url must be provided", entry.ID) } cfg.DeviceAuth = &DeviceAuth{ - Cfg: cfg, + Config: cfg, ClientID: entry.ClientID, TokenURL: oc.Endpoint.TokenURL, Scopes: entry.Scopes, From 30c459f7837726ed786e6efb3b3aece8ac460325 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 10 Jan 2024 08:45:00 -0600 Subject: [PATCH 18/20] fix compile --- testutil/oauth2.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/testutil/oauth2.go b/testutil/oauth2.go index 3bb22b0a03f5a..196e2e7bf712e 100644 --- a/testutil/oauth2.go +++ b/testutil/oauth2.go @@ -7,6 +7,8 @@ import ( "time" "golang.org/x/oauth2" + + "github.com/coder/coder/v2/coderd/promoauth" ) type OAuth2Config struct { @@ -14,7 +16,7 @@ type OAuth2Config struct { TokenSourceFunc OAuth2TokenSource } -func (*OAuth2Config) Do(_ context.Context, _ string, req *http.Request) (*http.Response, error) { +func (*OAuth2Config) Do(_ context.Context, _ promoauth.Oauth2Source, req *http.Request) (*http.Response, error) { return http.DefaultClient.Do(req) } From cd988067bbc78a030fab4553b7e3b513267e5690 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 10 Jan 2024 08:47:29 -0600 Subject: [PATCH 19/20] Use req with context --- coderd/promoauth/oauth2_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/coderd/promoauth/oauth2_test.go b/coderd/promoauth/oauth2_test.go index 78466ddceff21..2a74925f0f7fa 100644 --- a/coderd/promoauth/oauth2_test.go +++ b/coderd/promoauth/oauth2_test.go @@ -59,10 +59,9 @@ func TestInstrument(t *testing.T) { // Verify the default client was not broken. This check is added because we // extend the http.DefaultTransport. If a `.Clone()` is not done, this can be // mis-used. It is cheap to run this quick check. - req, err := http.NewRequest(http.MethodGet, + req, err := http.NewRequestWithContext(ctx, http.MethodGet, must(idp.IssuerURL().Parse("/.well-known/openid-configuration")).String(), nil) require.NoError(t, err) - req = req.WithContext(ctx) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) From 85e2d91ace97cc1177db1c67c7aba6697899824f Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 10 Jan 2024 08:49:42 -0600 Subject: [PATCH 20/20] no panic --- coderd/promoauth/oauth2_test.go | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/coderd/promoauth/oauth2_test.go b/coderd/promoauth/oauth2_test.go index 2a74925f0f7fa..b9c72f95a3a21 100644 --- a/coderd/promoauth/oauth2_test.go +++ b/coderd/promoauth/oauth2_test.go @@ -2,6 +2,7 @@ package promoauth_test import ( "net/http" + "net/url" "testing" "time" @@ -30,7 +31,7 @@ func TestInstrument(t *testing.T) { cfg := externalauth.Config{ InstrumentedOAuth2Config: factory.New(id, idp.OIDCConfig(t, []string{})), ID: "test", - ValidateURL: must(idp.IssuerURL().Parse("/oauth2/userinfo")).String(), + ValidateURL: must[*url.URL](t)(idp.IssuerURL().Parse("/oauth2/userinfo")).String(), } // 0 Requests before we start @@ -60,7 +61,7 @@ func TestInstrument(t *testing.T) { // extend the http.DefaultTransport. If a `.Clone()` is not done, this can be // mis-used. It is cheap to run this quick check. req, err := http.NewRequestWithContext(ctx, http.MethodGet, - must(idp.IssuerURL().Parse("/.well-known/openid-configuration")).String(), nil) + must[*url.URL](t)(idp.IssuerURL().Parse("/.well-known/openid-configuration")).String(), nil) require.NoError(t, err) resp, err := http.DefaultClient.Do(req) @@ -70,9 +71,10 @@ func TestInstrument(t *testing.T) { require.Equal(t, count(), 3) } -func must[V any](v V, err error) V { - if err != nil { - panic(err) +func must[V any](t *testing.T) func(v V, err error) V { + return func(v V, err error) V { + t.Helper() + require.NoError(t, err) + return v } - return v }