diff --git a/cli/server.go b/cli/server.go index c862769e58b67..4631df82dfc44 100644 --- a/cli/server.go +++ b/cli/server.go @@ -1773,12 +1773,6 @@ func configureGithubOAuth2(instrument *promoauth.Factory, accessURL *url.URL, cl Slug: parts[1], }) } - createClient := func(client *http.Client) (*github.Client, error) { - if enterpriseBaseURL != "" { - return github.NewEnterpriseClient(enterpriseBaseURL, "", client) - } - return github.NewClient(client), nil - } endpoint := xgithub.Endpoint if enterpriseBaseURL != "" { @@ -1800,24 +1794,34 @@ func configureGithubOAuth2(instrument *promoauth.Factory, accessURL *url.URL, cl } } + instrumentedOauth := instrument.NewGithub("github-login", &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + Endpoint: endpoint, + RedirectURL: redirectURL.String(), + Scopes: []string{ + "read:user", + "read:org", + "user:email", + }, + }) + + createClient := func(client *http.Client, source promoauth.Oauth2Source) (*github.Client, error) { + client = instrumentedOauth.InstrumentHTTPClient(client, source) + if enterpriseBaseURL != "" { + return github.NewEnterpriseClient(enterpriseBaseURL, "", client) + } + return github.NewClient(client), nil + } + return &coderd.GithubOAuth2Config{ - OAuth2Config: instrument.NewGithub("github-login", &oauth2.Config{ - ClientID: clientID, - ClientSecret: clientSecret, - Endpoint: endpoint, - RedirectURL: redirectURL.String(), - Scopes: []string{ - "read:user", - "read:org", - "user:email", - }, - }), + OAuth2Config: instrumentedOauth, AllowSignups: allowSignups, AllowEveryone: allowEveryone, AllowOrganizations: allowOrgs, AllowTeams: allowTeams, AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) { - api, err := createClient(client) + api, err := createClient(client, promoauth.SourceGitAPIAuthUser) if err != nil { return nil, err } @@ -1825,7 +1829,7 @@ func configureGithubOAuth2(instrument *promoauth.Factory, accessURL *url.URL, cl return user, err }, ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) { - api, err := createClient(client) + api, err := createClient(client, promoauth.SourceGitAPIListEmails) if err != nil { return nil, err } @@ -1833,7 +1837,7 @@ func configureGithubOAuth2(instrument *promoauth.Factory, accessURL *url.URL, cl return emails, err }, ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) { - api, err := createClient(client) + api, err := createClient(client, promoauth.SourceGitAPIOrgMemberships) if err != nil { return nil, err } @@ -1846,7 +1850,7 @@ func configureGithubOAuth2(instrument *promoauth.Factory, accessURL *url.URL, cl return memberships, err }, TeamMembership: func(ctx context.Context, client *http.Client, org, teamSlug, username string) (*github.Membership, error) { - api, err := createClient(client) + api, err := createClient(client, promoauth.SourceGitAPITeamMemberships) if err != nil { return nil, err } diff --git a/coderd/promoauth/oauth2.go b/coderd/promoauth/oauth2.go index 258694563581c..30e5269cd319e 100644 --- a/coderd/promoauth/oauth2.go +++ b/coderd/promoauth/oauth2.go @@ -19,6 +19,11 @@ const ( SourceTokenSource Oauth2Source = "TokenSource" SourceAppInstallations Oauth2Source = "AppInstallations" SourceAuthorizeDevice Oauth2Source = "AuthorizeDevice" + + SourceGitAPIAuthUser Oauth2Source = "GitAPIAuthUser" + SourceGitAPIListEmails Oauth2Source = "GitAPIListEmails" + SourceGitAPIOrgMemberships Oauth2Source = "GitAPIOrgMemberships" + SourceGitAPITeamMemberships Oauth2Source = "GitAPITeamMemberships" ) // OAuth2Config exposes a subset of *oauth2.Config functions for easier testing. @@ -209,6 +214,12 @@ func (c *Config) TokenSource(ctx context.Context, token *oauth2.Token) oauth2.To return c.underlying.TokenSource(c.wrapClient(ctx, SourceTokenSource), token) } +func (c *Config) InstrumentHTTPClient(hc *http.Client, source Oauth2Source) *http.Client { + // The new tripper will instrument every request made by the oauth2 client. + hc.Transport = newInstrumentedTripper(c, source, hc.Transport) + return hc +} + // wrapClient is the only way we can accurately instrument the oauth2 client. // This is because method calls to the 'OAuth2Config' interface are not 1:1 with // network requests. @@ -229,8 +240,7 @@ func (c *Config) oauthHTTPClient(ctx context.Context, source Oauth2Source) *http cli = hc } - // The new tripper will instrument every request made by the oauth2 client. - cli.Transport = newInstrumentedTripper(c, source, cli.Transport) + cli = c.InstrumentHTTPClient(cli, source) return cli }