diff --git a/cli/server.go b/cli/server.go index 170afea8f984a..4fa36c42a211e 100644 --- a/cli/server.go +++ b/cli/server.go @@ -87,6 +87,7 @@ func server() *cobra.Command { oauth2GithubAllowedOrganizations []string oauth2GithubAllowedTeams []string oauth2GithubAllowSignups bool + oauth2GithubEnterpriseBaseURL string oidcAllowSignups bool oidcClientID string oidcClientSecret string @@ -286,7 +287,7 @@ func server() *cobra.Command { } if oauth2GithubClientSecret != "" { - options.GithubOAuth2Config, err = configureGithubOAuth2(accessURLParsed, oauth2GithubClientID, oauth2GithubClientSecret, oauth2GithubAllowSignups, oauth2GithubAllowedOrganizations, oauth2GithubAllowedTeams) + options.GithubOAuth2Config, err = configureGithubOAuth2(accessURLParsed, oauth2GithubClientID, oauth2GithubClientSecret, oauth2GithubAllowSignups, oauth2GithubAllowedOrganizations, oauth2GithubAllowedTeams, oauth2GithubEnterpriseBaseURL) if err != nil { return xerrors.Errorf("configure github oauth2: %w", err) } @@ -689,6 +690,8 @@ func server() *cobra.Command { "Specifies teams inside organizations the user must be a member of to authenticate with GitHub. Formatted as: /.") cliflag.BoolVarP(root.Flags(), &oauth2GithubAllowSignups, "oauth2-github-allow-signups", "", "CODER_OAUTH2_GITHUB_ALLOW_SIGNUPS", false, "Specifies whether new users can sign up with GitHub.") + cliflag.StringVarP(root.Flags(), &oauth2GithubEnterpriseBaseURL, "oauth2-github-enterprise-base-url", "", "CODER_OAUTH2_GITHUB_ENTERPRISE_BASE_URL", "", + "Specifies the base URL of a GitHub Enterprise instance to use for oauth2.") cliflag.BoolVarP(root.Flags(), &oidcAllowSignups, "oidc-allow-signups", "", "CODER_OIDC_ALLOW_SIGNUPS", true, "Specifies whether new users can sign up with OIDC.") cliflag.StringVarP(root.Flags(), &oidcClientID, "oidc-client-id", "", "CODER_OIDC_CLIENT_ID", "", @@ -966,7 +969,7 @@ func configureTLS(listener net.Listener, tlsMinVersion, tlsClientAuth, tlsCertFi return tls.NewListener(listener, tlsConfig), nil } -func configureGithubOAuth2(accessURL *url.URL, clientID, clientSecret string, allowSignups bool, allowOrgs []string, rawTeams []string) (*coderd.GithubOAuth2Config, error) { +func configureGithubOAuth2(accessURL *url.URL, clientID, clientSecret string, allowSignups bool, allowOrgs []string, rawTeams []string, enterpriseBaseURL string) (*coderd.GithubOAuth2Config, error) { redirectURL, err := accessURL.Parse("/api/v2/users/oauth2/github/callback") if err != nil { return nil, xerrors.Errorf("parse github oauth callback url: %w", err) @@ -982,11 +985,38 @@ func configureGithubOAuth2(accessURL *url.URL, clientID, clientSecret string, al Slug: parts[1], }) } + createClient := func(client *http.Client) (*github.Client, error) { + if enterpriseBaseURL != "" { + return github.NewEnterpriseClient(enterpriseBaseURL, "", client) + } + return github.NewClient(client), nil + } + + endpoint := xgithub.Endpoint + if enterpriseBaseURL != "" { + enterpriseURL, err := url.Parse(enterpriseBaseURL) + if err != nil { + return nil, xerrors.Errorf("parse enterprise base url: %w", err) + } + authURL, err := enterpriseURL.Parse("/login/oauth/authorize") + if err != nil { + return nil, xerrors.Errorf("parse enterprise auth url: %w", err) + } + tokenURL, err := enterpriseURL.Parse("/login/oauth/access_token") + if err != nil { + return nil, xerrors.Errorf("parse enterprise token url: %w", err) + } + endpoint = oauth2.Endpoint{ + AuthURL: authURL.String(), + TokenURL: tokenURL.String(), + } + } + return &coderd.GithubOAuth2Config{ OAuth2Config: &oauth2.Config{ ClientID: clientID, ClientSecret: clientSecret, - Endpoint: xgithub.Endpoint, + Endpoint: endpoint, RedirectURL: redirectURL.String(), Scopes: []string{ "read:user", @@ -998,15 +1028,27 @@ func configureGithubOAuth2(accessURL *url.URL, clientID, clientSecret string, al AllowOrganizations: allowOrgs, AllowTeams: allowTeams, AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) { - user, _, err := github.NewClient(client).Users.Get(ctx, "") + api, err := createClient(client) + if err != nil { + return nil, err + } + user, _, err := api.Users.Get(ctx, "") return user, err }, ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) { - emails, _, err := github.NewClient(client).Users.ListEmails(ctx, &github.ListOptions{}) + api, err := createClient(client) + if err != nil { + return nil, err + } + emails, _, err := api.Users.ListEmails(ctx, &github.ListOptions{}) return emails, err }, ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) { - memberships, _, err := github.NewClient(client).Organizations.ListOrgMemberships(ctx, &github.ListOrgMembershipsOptions{ + api, err := createClient(client) + if err != nil { + return nil, err + } + memberships, _, err := api.Organizations.ListOrgMemberships(ctx, &github.ListOrgMembershipsOptions{ State: "active", ListOptions: github.ListOptions{ PerPage: 100, @@ -1015,7 +1057,11 @@ func configureGithubOAuth2(accessURL *url.URL, clientID, clientSecret string, al return memberships, err }, TeamMembership: func(ctx context.Context, client *http.Client, org, teamSlug, username string) (*github.Membership, error) { - team, _, err := github.NewClient(client).Teams.GetTeamMembershipBySlug(ctx, org, teamSlug, username) + api, err := createClient(client) + if err != nil { + return nil, err + } + team, _, err := api.Teams.GetTeamMembershipBySlug(ctx, org, teamSlug, username) return team, err }, }, nil diff --git a/cli/server_test.go b/cli/server_test.go index 226f4a9f2eabb..97027c83f35ee 100644 --- a/cli/server_test.go +++ b/cli/server_test.go @@ -429,6 +429,39 @@ func TestServer(t *testing.T) { cancelFunc() <-serverErr }) + t.Run("GitHubOAuth", func(t *testing.T) { + t.Parallel() + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + fakeRedirect := "https://fake-url.com" + root, cfg := clitest.New(t, + "server", + "--in-memory", + "--address", ":0", + "--oauth2-github-client-id", "fake", + "--oauth2-github-client-secret", "fake", + "--oauth2-github-enterprise-base-url", fakeRedirect, + ) + serverErr := make(chan error, 1) + go func() { + serverErr <- root.ExecuteContext(ctx) + }() + accessURL := waitAccessURL(t, cfg) + client := codersdk.New(accessURL) + client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + githubURL, err := accessURL.Parse("/api/v2/users/oauth2/github") + require.NoError(t, err) + res, err := client.HTTPClient.Get(githubURL.String()) + require.NoError(t, err) + fakeURL, err := res.Location() + require.NoError(t, err) + require.True(t, strings.HasPrefix(fakeURL.String(), fakeRedirect), fakeURL.String()) + cancelFunc() + <-serverErr + }) } func generateTLSCertificate(t testing.TB) (certPath, keyPath string) { diff --git a/docs/install/auth.md b/docs/install/auth.md index 76ed35b830b4a..bae7f1ed5f41d 100644 --- a/docs/install/auth.md +++ b/docs/install/auth.md @@ -25,6 +25,8 @@ server: coder server --oauth2-github-allow-signups=true --oauth2-github-allowed-orgs="your-org" --oauth2-github-client-id="8d1...e05" --oauth2-github-client-secret="57ebc9...02c24c" ``` +> For GitHub Enterprise support, specify the `--oauth2-github-enterprise-base-url` flag. + Alternatively, if you are running Coder as a system service, you can achieve the same result as the command above by adding the following environment variables to the `/etc/coder.d/coder.env` file: