Skip to content

Commit d3a56ae

Browse files
authored
feat: enable GitHub OAuth2 login by default on new deployments (coder#16662)
Third and final PR to address coder#16230. This PR enables GitHub OAuth2 login by default on new deployments. Combined with coder#16629, this will allow the first admin user to sign up with GitHub rather than email and password. We take care not to enable the default on deployments that would upgrade to a Coder version with this change. To disable the default provider an admin can set the `CODER_OAUTH2_GITHUB_DEFAULT_PROVIDER` env variable to false.
1 parent 67d89bb commit d3a56ae

25 files changed

+544
-83
lines changed

cli/server.go

Lines changed: 114 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -688,24 +688,6 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
688688
}
689689
}
690690

691-
if vals.OAuth2.Github.ClientSecret != "" || vals.OAuth2.Github.DeviceFlow.Value() {
692-
options.GithubOAuth2Config, err = configureGithubOAuth2(
693-
oauthInstrument,
694-
vals.AccessURL.Value(),
695-
vals.OAuth2.Github.ClientID.String(),
696-
vals.OAuth2.Github.ClientSecret.String(),
697-
vals.OAuth2.Github.DeviceFlow.Value(),
698-
vals.OAuth2.Github.AllowSignups.Value(),
699-
vals.OAuth2.Github.AllowEveryone.Value(),
700-
vals.OAuth2.Github.AllowedOrgs,
701-
vals.OAuth2.Github.AllowedTeams,
702-
vals.OAuth2.Github.EnterpriseBaseURL.String(),
703-
)
704-
if err != nil {
705-
return xerrors.Errorf("configure github oauth2: %w", err)
706-
}
707-
}
708-
709691
// As OIDC clients can be confidential or public,
710692
// we should only check for a client id being set.
711693
// The underlying library handles the case of no
@@ -793,6 +775,20 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
793775
return xerrors.Errorf("set deployment id: %w", err)
794776
}
795777

778+
githubOAuth2ConfigParams, err := getGithubOAuth2ConfigParams(ctx, options.Database, vals)
779+
if err != nil {
780+
return xerrors.Errorf("get github oauth2 config params: %w", err)
781+
}
782+
if githubOAuth2ConfigParams != nil {
783+
options.GithubOAuth2Config, err = configureGithubOAuth2(
784+
oauthInstrument,
785+
githubOAuth2ConfigParams,
786+
)
787+
if err != nil {
788+
return xerrors.Errorf("configure github oauth2: %w", err)
789+
}
790+
}
791+
796792
options.RuntimeConfig = runtimeconfig.NewManager()
797793

798794
// This should be output before the logs start streaming.
@@ -1843,25 +1839,101 @@ func configureCAPool(tlsClientCAFile string, tlsConfig *tls.Config) error {
18431839
return nil
18441840
}
18451841

1846-
// TODO: convert the argument list to a struct, it's easy to mix up the order of the arguments
1847-
//
1842+
const (
1843+
// Client ID for https://github.com/apps/coder
1844+
GithubOAuth2DefaultProviderClientID = "Iv1.6a2b4b4aec4f4fe7"
1845+
GithubOAuth2DefaultProviderAllowEveryone = true
1846+
GithubOAuth2DefaultProviderDeviceFlow = true
1847+
)
1848+
1849+
type githubOAuth2ConfigParams struct {
1850+
accessURL *url.URL
1851+
clientID string
1852+
clientSecret string
1853+
deviceFlow bool
1854+
allowSignups bool
1855+
allowEveryone bool
1856+
allowOrgs []string
1857+
rawTeams []string
1858+
enterpriseBaseURL string
1859+
}
1860+
1861+
func getGithubOAuth2ConfigParams(ctx context.Context, db database.Store, vals *codersdk.DeploymentValues) (*githubOAuth2ConfigParams, error) {
1862+
params := githubOAuth2ConfigParams{
1863+
accessURL: vals.AccessURL.Value(),
1864+
clientID: vals.OAuth2.Github.ClientID.String(),
1865+
clientSecret: vals.OAuth2.Github.ClientSecret.String(),
1866+
deviceFlow: vals.OAuth2.Github.DeviceFlow.Value(),
1867+
allowSignups: vals.OAuth2.Github.AllowSignups.Value(),
1868+
allowEveryone: vals.OAuth2.Github.AllowEveryone.Value(),
1869+
allowOrgs: vals.OAuth2.Github.AllowedOrgs.Value(),
1870+
rawTeams: vals.OAuth2.Github.AllowedTeams.Value(),
1871+
enterpriseBaseURL: vals.OAuth2.Github.EnterpriseBaseURL.String(),
1872+
}
1873+
1874+
// If the user manually configured the GitHub OAuth2 provider,
1875+
// we won't add the default configuration.
1876+
if params.clientID != "" || params.clientSecret != "" || params.enterpriseBaseURL != "" {
1877+
return &params, nil
1878+
}
1879+
1880+
// Check if the user manually disabled the default GitHub OAuth2 provider.
1881+
if !vals.OAuth2.Github.DefaultProviderEnable.Value() {
1882+
return nil, nil //nolint:nilnil
1883+
}
1884+
1885+
// Check if the deployment is eligible for the default GitHub OAuth2 provider.
1886+
// We want to enable it only for new deployments, and avoid enabling it
1887+
// if a deployment was upgraded from an older version.
1888+
// nolint:gocritic // Requires system privileges
1889+
defaultEligible, err := db.GetOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx))
1890+
if err != nil && !errors.Is(err, sql.ErrNoRows) {
1891+
return nil, xerrors.Errorf("get github default eligible: %w", err)
1892+
}
1893+
defaultEligibleNotSet := errors.Is(err, sql.ErrNoRows)
1894+
1895+
if defaultEligibleNotSet {
1896+
// nolint:gocritic // User count requires system privileges
1897+
userCount, err := db.GetUserCount(dbauthz.AsSystemRestricted(ctx))
1898+
if err != nil {
1899+
return nil, xerrors.Errorf("get user count: %w", err)
1900+
}
1901+
// We check if a deployment is new by checking if it has any users.
1902+
defaultEligible = userCount == 0
1903+
// nolint:gocritic // Requires system privileges
1904+
if err := db.UpsertOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx), defaultEligible); err != nil {
1905+
return nil, xerrors.Errorf("upsert github default eligible: %w", err)
1906+
}
1907+
}
1908+
1909+
if !defaultEligible {
1910+
return nil, nil //nolint:nilnil
1911+
}
1912+
1913+
params.clientID = GithubOAuth2DefaultProviderClientID
1914+
params.allowEveryone = GithubOAuth2DefaultProviderAllowEveryone
1915+
params.deviceFlow = GithubOAuth2DefaultProviderDeviceFlow
1916+
1917+
return &params, nil
1918+
}
1919+
18481920
//nolint:revive // Ignore flag-parameter: parameter 'allowEveryone' seems to be a control flag, avoid control coupling (revive)
1849-
func configureGithubOAuth2(instrument *promoauth.Factory, accessURL *url.URL, clientID, clientSecret string, deviceFlow, allowSignups, allowEveryone bool, allowOrgs []string, rawTeams []string, enterpriseBaseURL string) (*coderd.GithubOAuth2Config, error) {
1850-
redirectURL, err := accessURL.Parse("/api/v2/users/oauth2/github/callback")
1921+
func configureGithubOAuth2(instrument *promoauth.Factory, params *githubOAuth2ConfigParams) (*coderd.GithubOAuth2Config, error) {
1922+
redirectURL, err := params.accessURL.Parse("/api/v2/users/oauth2/github/callback")
18511923
if err != nil {
18521924
return nil, xerrors.Errorf("parse github oauth callback url: %w", err)
18531925
}
1854-
if allowEveryone && len(allowOrgs) > 0 {
1926+
if params.allowEveryone && len(params.allowOrgs) > 0 {
18551927
return nil, xerrors.New("allow everyone and allowed orgs cannot be used together")
18561928
}
1857-
if allowEveryone && len(rawTeams) > 0 {
1929+
if params.allowEveryone && len(params.rawTeams) > 0 {
18581930
return nil, xerrors.New("allow everyone and allowed teams cannot be used together")
18591931
}
1860-
if !allowEveryone && len(allowOrgs) == 0 {
1932+
if !params.allowEveryone && len(params.allowOrgs) == 0 {
18611933
return nil, xerrors.New("allowed orgs is empty: must specify at least one org or allow everyone")
18621934
}
1863-
allowTeams := make([]coderd.GithubOAuth2Team, 0, len(rawTeams))
1864-
for _, rawTeam := range rawTeams {
1935+
allowTeams := make([]coderd.GithubOAuth2Team, 0, len(params.rawTeams))
1936+
for _, rawTeam := range params.rawTeams {
18651937
parts := strings.SplitN(rawTeam, "/", 2)
18661938
if len(parts) != 2 {
18671939
return nil, xerrors.Errorf("github team allowlist is formatted incorrectly. got %s; wanted <organization>/<team>", rawTeam)
@@ -1873,8 +1945,8 @@ func configureGithubOAuth2(instrument *promoauth.Factory, accessURL *url.URL, cl
18731945
}
18741946

18751947
endpoint := xgithub.Endpoint
1876-
if enterpriseBaseURL != "" {
1877-
enterpriseURL, err := url.Parse(enterpriseBaseURL)
1948+
if params.enterpriseBaseURL != "" {
1949+
enterpriseURL, err := url.Parse(params.enterpriseBaseURL)
18781950
if err != nil {
18791951
return nil, xerrors.Errorf("parse enterprise base url: %w", err)
18801952
}
@@ -1893,8 +1965,8 @@ func configureGithubOAuth2(instrument *promoauth.Factory, accessURL *url.URL, cl
18931965
}
18941966

18951967
instrumentedOauth := instrument.NewGithub("github-login", &oauth2.Config{
1896-
ClientID: clientID,
1897-
ClientSecret: clientSecret,
1968+
ClientID: params.clientID,
1969+
ClientSecret: params.clientSecret,
18981970
Endpoint: endpoint,
18991971
RedirectURL: redirectURL.String(),
19001972
Scopes: []string{
@@ -1906,17 +1978,17 @@ func configureGithubOAuth2(instrument *promoauth.Factory, accessURL *url.URL, cl
19061978

19071979
createClient := func(client *http.Client, source promoauth.Oauth2Source) (*github.Client, error) {
19081980
client = instrumentedOauth.InstrumentHTTPClient(client, source)
1909-
if enterpriseBaseURL != "" {
1910-
return github.NewEnterpriseClient(enterpriseBaseURL, "", client)
1981+
if params.enterpriseBaseURL != "" {
1982+
return github.NewEnterpriseClient(params.enterpriseBaseURL, "", client)
19111983
}
19121984
return github.NewClient(client), nil
19131985
}
19141986

19151987
var deviceAuth *externalauth.DeviceAuth
1916-
if deviceFlow {
1988+
if params.deviceFlow {
19171989
deviceAuth = &externalauth.DeviceAuth{
19181990
Config: instrumentedOauth,
1919-
ClientID: clientID,
1991+
ClientID: params.clientID,
19201992
TokenURL: endpoint.TokenURL,
19211993
Scopes: []string{"read:user", "read:org", "user:email"},
19221994
CodeURL: endpoint.DeviceAuthURL,
@@ -1925,9 +1997,9 @@ func configureGithubOAuth2(instrument *promoauth.Factory, accessURL *url.URL, cl
19251997

19261998
return &coderd.GithubOAuth2Config{
19271999
OAuth2Config: instrumentedOauth,
1928-
AllowSignups: allowSignups,
1929-
AllowEveryone: allowEveryone,
1930-
AllowOrganizations: allowOrgs,
2000+
AllowSignups: params.allowSignups,
2001+
AllowEveryone: params.allowEveryone,
2002+
AllowOrganizations: params.allowOrgs,
19312003
AllowTeams: allowTeams,
19322004
AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) {
19332005
api, err := createClient(client, promoauth.SourceGitAPIAuthUser)
@@ -1966,19 +2038,20 @@ func configureGithubOAuth2(instrument *promoauth.Factory, accessURL *url.URL, cl
19662038
team, _, err := api.Teams.GetTeamMembershipBySlug(ctx, org, teamSlug, username)
19672039
return team, err
19682040
},
1969-
DeviceFlowEnabled: deviceFlow,
2041+
DeviceFlowEnabled: params.deviceFlow,
19702042
ExchangeDeviceCode: func(ctx context.Context, deviceCode string) (*oauth2.Token, error) {
1971-
if !deviceFlow {
2043+
if !params.deviceFlow {
19722044
return nil, xerrors.New("device flow is not enabled")
19732045
}
19742046
return deviceAuth.ExchangeDeviceCode(ctx, deviceCode)
19752047
},
19762048
AuthorizeDevice: func(ctx context.Context) (*codersdk.ExternalAuthDevice, error) {
1977-
if !deviceFlow {
2049+
if !params.deviceFlow {
19782050
return nil, xerrors.New("device flow is not enabled")
19792051
}
19802052
return deviceAuth.AuthorizeDevice(ctx)
19812053
},
2054+
DefaultProviderConfigured: params.clientID == GithubOAuth2DefaultProviderClientID,
19822055
}, nil
19832056
}
19842057

cli/server_test.go

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ import (
4545
"github.com/coder/coder/v2/cli/clitest"
4646
"github.com/coder/coder/v2/cli/config"
4747
"github.com/coder/coder/v2/coderd/coderdtest"
48+
"github.com/coder/coder/v2/coderd/database"
49+
"github.com/coder/coder/v2/coderd/database/dbgen"
4850
"github.com/coder/coder/v2/coderd/database/dbtestutil"
4951
"github.com/coder/coder/v2/coderd/database/migrations"
5052
"github.com/coder/coder/v2/coderd/httpapi"
@@ -306,6 +308,145 @@ func TestServer(t *testing.T) {
306308
require.Less(t, numLines, 20)
307309
})
308310

311+
t.Run("OAuth2GitHubDefaultProvider", func(t *testing.T) {
312+
type testCase struct {
313+
name string
314+
githubDefaultProviderEnabled string
315+
githubClientID string
316+
githubClientSecret string
317+
expectGithubEnabled bool
318+
expectGithubDefaultProviderConfigured bool
319+
createUserPreStart bool
320+
createUserPostRestart bool
321+
}
322+
323+
runGitHubProviderTest := func(t *testing.T, tc testCase) {
324+
t.Parallel()
325+
if !dbtestutil.WillUsePostgres() {
326+
t.Skip("test requires postgres")
327+
}
328+
329+
ctx, cancelFunc := context.WithCancel(testutil.Context(t, testutil.WaitLong))
330+
defer cancelFunc()
331+
332+
dbURL, err := dbtestutil.Open(t)
333+
require.NoError(t, err)
334+
db, _ := dbtestutil.NewDB(t, dbtestutil.WithURL(dbURL))
335+
336+
if tc.createUserPreStart {
337+
_ = dbgen.User(t, db, database.User{})
338+
}
339+
340+
args := []string{
341+
"server",
342+
"--postgres-url", dbURL,
343+
"--http-address", ":0",
344+
"--access-url", "https://example.com",
345+
}
346+
if tc.githubClientID != "" {
347+
args = append(args, fmt.Sprintf("--oauth2-github-client-id=%s", tc.githubClientID))
348+
}
349+
if tc.githubClientSecret != "" {
350+
args = append(args, fmt.Sprintf("--oauth2-github-client-secret=%s", tc.githubClientSecret))
351+
}
352+
if tc.githubClientID != "" || tc.githubClientSecret != "" {
353+
args = append(args, "--oauth2-github-allow-everyone")
354+
}
355+
if tc.githubDefaultProviderEnabled != "" {
356+
args = append(args, fmt.Sprintf("--oauth2-github-default-provider-enable=%s", tc.githubDefaultProviderEnabled))
357+
}
358+
359+
inv, cfg := clitest.New(t, args...)
360+
errChan := make(chan error, 1)
361+
go func() {
362+
errChan <- inv.WithContext(ctx).Run()
363+
}()
364+
accessURLChan := make(chan *url.URL, 1)
365+
go func() {
366+
accessURLChan <- waitAccessURL(t, cfg)
367+
}()
368+
369+
var accessURL *url.URL
370+
select {
371+
case err := <-errChan:
372+
require.NoError(t, err)
373+
case accessURL = <-accessURLChan:
374+
require.NotNil(t, accessURL)
375+
}
376+
377+
client := codersdk.New(accessURL)
378+
379+
authMethods, err := client.AuthMethods(ctx)
380+
require.NoError(t, err)
381+
require.Equal(t, tc.expectGithubEnabled, authMethods.Github.Enabled)
382+
require.Equal(t, tc.expectGithubDefaultProviderConfigured, authMethods.Github.DefaultProviderConfigured)
383+
384+
cancelFunc()
385+
select {
386+
case err := <-errChan:
387+
require.NoError(t, err)
388+
case <-time.After(testutil.WaitLong):
389+
t.Fatal("server did not exit")
390+
}
391+
392+
if tc.createUserPostRestart {
393+
_ = dbgen.User(t, db, database.User{})
394+
}
395+
396+
// Ensure that it stays at that setting after the server restarts.
397+
inv, cfg = clitest.New(t, args...)
398+
clitest.Start(t, inv)
399+
accessURL = waitAccessURL(t, cfg)
400+
client = codersdk.New(accessURL)
401+
402+
ctx = testutil.Context(t, testutil.WaitLong)
403+
authMethods, err = client.AuthMethods(ctx)
404+
require.NoError(t, err)
405+
require.Equal(t, tc.expectGithubEnabled, authMethods.Github.Enabled)
406+
require.Equal(t, tc.expectGithubDefaultProviderConfigured, authMethods.Github.DefaultProviderConfigured)
407+
}
408+
409+
for _, tc := range []testCase{
410+
{
411+
name: "NewDeployment",
412+
expectGithubEnabled: true,
413+
expectGithubDefaultProviderConfigured: true,
414+
createUserPreStart: false,
415+
createUserPostRestart: true,
416+
},
417+
{
418+
name: "ExistingDeployment",
419+
expectGithubEnabled: false,
420+
expectGithubDefaultProviderConfigured: false,
421+
createUserPreStart: true,
422+
createUserPostRestart: false,
423+
},
424+
{
425+
name: "ManuallyDisabled",
426+
githubDefaultProviderEnabled: "false",
427+
expectGithubEnabled: false,
428+
expectGithubDefaultProviderConfigured: false,
429+
},
430+
{
431+
name: "ConfiguredClientID",
432+
githubClientID: "123",
433+
expectGithubEnabled: true,
434+
expectGithubDefaultProviderConfigured: false,
435+
},
436+
{
437+
name: "ConfiguredClientSecret",
438+
githubClientSecret: "456",
439+
expectGithubEnabled: true,
440+
expectGithubDefaultProviderConfigured: false,
441+
},
442+
} {
443+
tc := tc
444+
t.Run(tc.name, func(t *testing.T) {
445+
runGitHubProviderTest(t, tc)
446+
})
447+
}
448+
})
449+
309450
// Validate that a warning is printed that it may not be externally
310451
// reachable.
311452
t.Run("LocalAccessURL", func(t *testing.T) {

0 commit comments

Comments
 (0)