From 66ecf589136cc04c6a9c8d90260b1a310cac5ac6 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Mon, 9 Oct 2023 21:22:26 +0000 Subject: [PATCH 1/9] feat: add enhanced support for Slack OAuth --- cli/server.go | 6 +++ coderd/coderdtest/oidctest/idp.go | 12 ++++++ coderd/externalauth/externalauth.go | 47 +++++++++++++++++----- coderd/externalauth/externalauth_test.go | 50 +++++++++++++++++++----- codersdk/deployment.go | 6 +++ codersdk/externalauth.go | 1 + site/static/icon/slack.svg | 6 +++ 7 files changed, 109 insertions(+), 19 deletions(-) create mode 100644 site/static/icon/slack.svg diff --git a/cli/server.go b/cli/server.go index f9ef1aaa65c8c..31a2b63a49660 100644 --- a/cli/server.go +++ b/cli/server.go @@ -2259,6 +2259,12 @@ func parseExternalAuthProvidersFromEnv(prefix string, environ []string) ([]coder provider.DisplayName = v.Value case "DISPLAY_ICON": provider.DisplayIcon = v.Value + case "SLACK_AUTHED_USER_TOKEN": + b, err := strconv.ParseBool(v.Value) + if err != nil { + return nil, xerrors.Errorf("parse bool: %s", v.Value) + } + provider.SlackAuthedUserToken = b } providers[providerNum] = provider } diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index 6f060aea2c6b6..8daf6a63720db 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -68,6 +68,7 @@ type FakeIDP struct { // "Authorized Redirect URLs". This can be used to emulate that. hookValidRedirectURL func(redirectURL string) error hookUserInfo func(email string) (jwt.MapClaims, error) + hookExtra func(email string) map[string]interface{} fakeCoderd func(req *http.Request) (*http.Response, error) hookOnRefresh func(email string) error // Custom authentication for the client. This is useful if you want @@ -112,6 +113,12 @@ func WithRefresh(hook func(email string) error) func(*FakeIDP) { } } +func WithExtra(extra func(email string) map[string]interface{}) func(*FakeIDP) { + return func(f *FakeIDP) { + f.hookExtra = extra + } +} + func WithCustomClientAuth(hook func(t testing.TB, req *http.Request) (url.Values, error)) func(*FakeIDP) { return func(f *FakeIDP) { f.hookAuthenticateClient = hook @@ -621,6 +628,11 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { "expires_in": int64((time.Minute * 5).Seconds()), "id_token": f.encodeClaims(t, claims), } + if f.hookExtra != nil { + for k, v := range f.hookExtra(email) { + token[k] = v + } + } // Store the claims for the next refresh f.refreshIDTokenClaims.Store(refreshToken, claims) diff --git a/coderd/externalauth/externalauth.go b/coderd/externalauth/externalauth.go index 34bd7e9253ee7..92a2aa28cb75c 100644 --- a/coderd/externalauth/externalauth.go +++ b/coderd/externalauth/externalauth.go @@ -67,6 +67,10 @@ type Config struct { // AppInstallationsURL is an API endpoint that returns a list of // installations for the user. This is used for GitHub Apps. AppInstallationsURL string + + // SlackAuthedUserToken is true if the user token should be returned + // instead of the bot token. + SlackAuthedUserToken bool } // RefreshToken automatically refreshes the token if expired and permitted. @@ -101,6 +105,22 @@ func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAu // we aren't trying to surface an error, we're just trying to obtain a valid token. return externalAuthLink, false, nil } + + // Slack's new OAuth2 flow has the user access token in a different field. + // It's weird and unfortunate, but the only way to access the user token. + // See: https://api.slack.com/authentication/oauth-v2#exchanging + if c.Type == string(codersdk.EnhancedExternalAuthProviderSlack) && c.SlackAuthedUserToken { + rawMap, ok := token.Extra("authed_user").(map[string]interface{}) + if !ok { + return externalAuthLink, false, xerrors.Errorf("slack: could not obtain user access token from payload: %+v", token.Extra("authed_user")) + } + accessToken, ok := rawMap["access_token"].(string) + if !ok { + return externalAuthLink, false, xerrors.Errorf("slack: could not obtain user access token from payload: %+v", token.Extra("authed_user")) + } + token.AccessToken = accessToken + } + r := retry.New(50*time.Millisecond, 200*time.Millisecond) // See the comment below why the retry and cancel is required. retryCtx, retryCtxCancel := context.WithTimeout(ctx, time.Second) @@ -424,16 +444,17 @@ func ConvertConfig(entries []codersdk.ExternalAuthConfig, accessURL *url.URL) ([ } cfg := &Config{ - OAuth2Config: oauthConfig, - ID: entry.ID, - Regex: regex, - Type: entry.Type, - NoRefresh: entry.NoRefresh, - ValidateURL: entry.ValidateURL, - AppInstallationsURL: entry.AppInstallationsURL, - AppInstallURL: entry.AppInstallURL, - DisplayName: entry.DisplayName, - DisplayIcon: entry.DisplayIcon, + OAuth2Config: oauthConfig, + ID: entry.ID, + Regex: regex, + Type: entry.Type, + NoRefresh: entry.NoRefresh, + ValidateURL: entry.ValidateURL, + AppInstallationsURL: entry.AppInstallationsURL, + AppInstallURL: entry.AppInstallURL, + DisplayName: entry.DisplayName, + DisplayIcon: entry.DisplayIcon, + SlackAuthedUserToken: entry.SlackAuthedUserToken, } if entry.DeviceFlow { @@ -539,6 +560,12 @@ var defaults = map[codersdk.EnhancedExternalAuthProvider]codersdk.ExternalAuthCo DeviceCodeURL: "https://github.com/login/device/code", AppInstallationsURL: "https://api.github.com/user/installations", }, + codersdk.EnhancedExternalAuthProviderSlack: { + AuthURL: "https://slack.com/oauth/v2/authorize", + TokenURL: "https://slack.com/api/oauth.v2.access", + DisplayName: "Slack", + DisplayIcon: "/icon/slack.svg", + }, } // jwtConfig is a new OAuth2 config that uses a custom diff --git a/coderd/externalauth/externalauth_test.go b/coderd/externalauth/externalauth_test.go index 418d143d16e7e..04d67736489cd 100644 --- a/coderd/externalauth/externalauth_test.go +++ b/coderd/externalauth/externalauth_test.go @@ -43,7 +43,7 @@ func TestRefreshToken(t *testing.T) { return nil, xerrors.New("should not be called") }), }, - GitConfigOpt: func(cfg *externalauth.Config) { + ExternalAuthOpt: func(cfg *externalauth.Config) { cfg.NoRefresh = true }, }) @@ -74,7 +74,7 @@ func TestRefreshToken(t *testing.T) { return jwt.MapClaims{}, nil }), }, - GitConfigOpt: func(cfg *externalauth.Config) { + ExternalAuthOpt: func(cfg *externalauth.Config) { cfg.NoRefresh = true }, }) @@ -117,7 +117,7 @@ func TestRefreshToken(t *testing.T) { return jwt.MapClaims{}, xerrors.New(staticError) }), }, - GitConfigOpt: func(cfg *externalauth.Config) { + ExternalAuthOpt: func(cfg *externalauth.Config) { }, }) @@ -142,7 +142,7 @@ func TestRefreshToken(t *testing.T) { return jwt.MapClaims{}, oidctest.StatusError(http.StatusUnauthorized, xerrors.New(staticError)) }), }, - GitConfigOpt: func(cfg *externalauth.Config) { + ExternalAuthOpt: func(cfg *externalauth.Config) { }, }) @@ -175,7 +175,7 @@ func TestRefreshToken(t *testing.T) { return jwt.MapClaims{}, oidctest.StatusError(http.StatusUnauthorized, xerrors.New(staticError)) }), }, - GitConfigOpt: func(cfg *externalauth.Config) { + ExternalAuthOpt: func(cfg *externalauth.Config) { cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String() }, }) @@ -205,7 +205,7 @@ func TestRefreshToken(t *testing.T) { return jwt.MapClaims{}, nil }), }, - GitConfigOpt: func(cfg *externalauth.Config) { + ExternalAuthOpt: func(cfg *externalauth.Config) { cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String() }, }) @@ -236,7 +236,7 @@ func TestRefreshToken(t *testing.T) { return jwt.MapClaims{}, nil }), }, - GitConfigOpt: func(cfg *externalauth.Config) { + ExternalAuthOpt: func(cfg *externalauth.Config) { cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String() }, DB: db, @@ -260,6 +260,38 @@ func TestRefreshToken(t *testing.T) { require.NoError(t, err) require.Equal(t, updated.OAuthAccessToken, dbLink.OAuthAccessToken, "token is updated in the DB") }) + + t.Run("SlackUserToken", func(t *testing.T) { + t.Parallel() + + db := dbfake.New() + fake, config, link := setupOauth2Test(t, testConfig{ + FakeIDPOpts: []oidctest.FakeIDPOpt{ + oidctest.WithExtra(func(email string) map[string]interface{} { + return map[string]interface{}{ + "authed_user": map[string]interface{}{ + "access_token": "slack-user-token", + }, + } + }), + }, + ExternalAuthOpt: func(cfg *externalauth.Config) { + cfg.Type = codersdk.EnhancedExternalAuthProviderSlack.String() + cfg.SlackAuthedUserToken = true + cfg.ValidateURL = "" + }, + DB: db, + }) + + ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil)) + // Force a refresh + link.OAuthExpiry = expired + + updated, ok, err := config.RefreshToken(ctx, db, link) + require.NoError(t, err) + require.True(t, ok) + require.Equal(t, "slack-user-token", updated.OAuthAccessToken) + }) } func TestConvertYAML(t *testing.T) { @@ -344,7 +376,7 @@ func TestConvertYAML(t *testing.T) { type testConfig struct { FakeIDPOpts []oidctest.FakeIDPOpt CoderOIDCConfigOpts []func(cfg *coderd.OIDCConfig) - GitConfigOpt func(cfg *externalauth.Config) + ExternalAuthOpt func(cfg *externalauth.Config) // If DB is passed in, the link will be inserted into the DB. DB database.Store } @@ -367,7 +399,7 @@ func setupOauth2Test(t *testing.T, settings testConfig) (*oidctest.FakeIDP, *ext ID: providerID, ValidateURL: fake.WellknownConfig().UserInfoURL, } - settings.GitConfigOpt(config) + settings.ExternalAuthOpt(config) oauthToken, err := fake.GenerateAuthenticatedToken(jwt.MapClaims{ "email": "test@coder.com", diff --git a/codersdk/deployment.go b/codersdk/deployment.go index db9113e63dc67..56412fbe39033 100644 --- a/codersdk/deployment.go +++ b/codersdk/deployment.go @@ -349,6 +349,12 @@ type ExternalAuthConfig struct { DisplayName string `json:"display_name"` // DisplayIcon is a URL to an icon to display in the UI. DisplayIcon string `json:"display_icon"` + + // SlackAuthedUserToken is a Slack-specific field that controls + // whether the Bot or User token is returned from the OAuth exchange. + // Slack returns multiple OAuth tokens as part of it's flow. + // See: https://api.slack.com/authentication/oauth-v2#exchanging + SlackAuthedUserToken bool `json:"slack_authed_user_token"` } type ProvisionerConfig struct { diff --git a/codersdk/externalauth.go b/codersdk/externalauth.go index 6aff5ad63bf76..0167ca8156259 100644 --- a/codersdk/externalauth.go +++ b/codersdk/externalauth.go @@ -34,6 +34,7 @@ const ( EnhancedExternalAuthProviderGitHub EnhancedExternalAuthProvider = "github" EnhancedExternalAuthProviderGitLab EnhancedExternalAuthProvider = "gitlab" EnhancedExternalAuthProviderBitBucket EnhancedExternalAuthProvider = "bitbucket" + EnhancedExternalAuthProviderSlack EnhancedExternalAuthProvider = "slack" ) type ExternalAuth struct { diff --git a/site/static/icon/slack.svg b/site/static/icon/slack.svg new file mode 100644 index 0000000000000..fb55f7245df5b --- /dev/null +++ b/site/static/icon/slack.svg @@ -0,0 +1,6 @@ + + + + + + From fdc6a5d12fda0b8534475b73325e18f12f152e01 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Mon, 9 Oct 2023 21:42:40 +0000 Subject: [PATCH 2/9] feat: allow storing extra oauth token properties in the database --- cli/server.go | 8 +- coderd/database/dbfake/dbfake.go | 2 + coderd/database/dump.sql | 3 +- .../000163_external_auth_extra.down.sql | 1 + .../000163_external_auth_extra.up.sql | 1 + coderd/database/models.go | 3 +- coderd/database/queries.sql.go | 285 +++++++++--------- coderd/database/queries/externalauth.sql | 9 +- coderd/database/sqlc.yaml | 1 + coderd/externalauth.go | 22 ++ coderd/externalauth/externalauth.go | 59 ++-- coderd/externalauth/externalauth_test.go | 6 +- codersdk/deployment.go | 7 +- site/src/api/typesGenerated.ts | 5 +- 14 files changed, 228 insertions(+), 184 deletions(-) create mode 100644 coderd/database/migrations/000163_external_auth_extra.down.sql create mode 100644 coderd/database/migrations/000163_external_auth_extra.up.sql diff --git a/cli/server.go b/cli/server.go index 31a2b63a49660..9f33ced438f84 100644 --- a/cli/server.go +++ b/cli/server.go @@ -2251,6 +2251,8 @@ func parseExternalAuthProvidersFromEnv(prefix string, environ []string) ([]coder provider.NoRefresh = b case "SCOPES": provider.Scopes = strings.Split(v.Value, " ") + case "EXTRA_TOKEN_KEYS": + provider.ExtraTokenKeys = strings.Split(v.Value, " ") case "APP_INSTALL_URL": provider.AppInstallURL = v.Value case "APP_INSTALLATIONS_URL": @@ -2259,12 +2261,6 @@ func parseExternalAuthProvidersFromEnv(prefix string, environ []string) ([]coder provider.DisplayName = v.Value case "DISPLAY_ICON": provider.DisplayIcon = v.Value - case "SLACK_AUTHED_USER_TOKEN": - b, err := strconv.ParseBool(v.Value) - if err != nil { - return nil, xerrors.Errorf("parse bool: %s", v.Value) - } - provider.SlackAuthedUserToken = b } providers[providerNum] = provider } diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index 7b1eb86c135df..edc5572dd551b 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -4246,6 +4246,7 @@ func (q *FakeQuerier) InsertExternalAuthLink(_ context.Context, arg database.Ins OAuthRefreshToken: arg.OAuthRefreshToken, OAuthRefreshTokenKeyID: arg.OAuthRefreshTokenKeyID, OAuthExpiry: arg.OAuthExpiry, + OAuthExtra: arg.OAuthExtra, } q.externalAuthLinks = append(q.externalAuthLinks, gitAuthLink) return gitAuthLink, nil @@ -5301,6 +5302,7 @@ func (q *FakeQuerier) UpdateExternalAuthLink(_ context.Context, arg database.Upd gitAuthLink.OAuthRefreshToken = arg.OAuthRefreshToken gitAuthLink.OAuthRefreshTokenKeyID = arg.OAuthRefreshTokenKeyID gitAuthLink.OAuthExpiry = arg.OAuthExpiry + gitAuthLink.OAuthExtra = arg.OAuthExtra q.externalAuthLinks[index] = gitAuthLink return gitAuthLink, nil diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 4e5f6ed5d62f1..c806248a18c48 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -359,7 +359,8 @@ CREATE TABLE external_auth_links ( oauth_refresh_token text NOT NULL, oauth_expiry timestamp with time zone NOT NULL, oauth_access_token_key_id text, - oauth_refresh_token_key_id text + oauth_refresh_token_key_id text, + oauth_extra jsonb ); COMMENT ON COLUMN external_auth_links.oauth_access_token_key_id IS 'The ID of the key used to encrypt the OAuth access token. If this is NULL, the access token is not encrypted'; diff --git a/coderd/database/migrations/000163_external_auth_extra.down.sql b/coderd/database/migrations/000163_external_auth_extra.down.sql new file mode 100644 index 0000000000000..b926f31d81254 --- /dev/null +++ b/coderd/database/migrations/000163_external_auth_extra.down.sql @@ -0,0 +1 @@ +ALTER TABLE external_auth_links DROP COLUMN "oauth_extra"; diff --git a/coderd/database/migrations/000163_external_auth_extra.up.sql b/coderd/database/migrations/000163_external_auth_extra.up.sql new file mode 100644 index 0000000000000..7c8d9ec6ae163 --- /dev/null +++ b/coderd/database/migrations/000163_external_auth_extra.up.sql @@ -0,0 +1 @@ +ALTER TABLE external_auth_links ADD COLUMN "oauth_extra" jsonb; diff --git a/coderd/database/models.go b/coderd/database/models.go index 218f9020c57b5..5f389f36be4b8 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -1680,7 +1680,8 @@ type ExternalAuthLink struct { // The ID of the key used to encrypt the OAuth access token. If this is NULL, the access token is not encrypted OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"` // The ID of the key used to encrypt the OAuth refresh token. If this is NULL, the refresh token is not encrypted - OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"` + OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"` + OAuthExtra pqtype.NullRawMessage `db:"oauth_extra" json:"oauth_extra"` } type File struct { diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 2e165b006778d..3c4ae4160caa4 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -751,7 +751,7 @@ func (q *sqlQuerier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest strin } const getExternalAuthLink = `-- name: GetExternalAuthLink :one -SELECT provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id FROM external_auth_links WHERE provider_id = $1 AND user_id = $2 +SELECT provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, oauth_extra FROM external_auth_links WHERE provider_id = $1 AND user_id = $2 ` type GetExternalAuthLinkParams struct { @@ -772,12 +772,13 @@ func (q *sqlQuerier) GetExternalAuthLink(ctx context.Context, arg GetExternalAut &i.OAuthExpiry, &i.OAuthAccessTokenKeyID, &i.OAuthRefreshTokenKeyID, + &i.OAuthExtra, ) return i, err } const getExternalAuthLinksByUserID = `-- name: GetExternalAuthLinksByUserID :many -SELECT provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id FROM external_auth_links WHERE user_id = $1 +SELECT provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, oauth_extra FROM external_auth_links WHERE user_id = $1 ` func (q *sqlQuerier) GetExternalAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]ExternalAuthLink, error) { @@ -799,6 +800,7 @@ func (q *sqlQuerier) GetExternalAuthLinksByUserID(ctx context.Context, userID uu &i.OAuthExpiry, &i.OAuthAccessTokenKeyID, &i.OAuthRefreshTokenKeyID, + &i.OAuthExtra, ); err != nil { return nil, err } @@ -823,7 +825,8 @@ INSERT INTO external_auth_links ( oauth_access_token_key_id, oauth_refresh_token, oauth_refresh_token_key_id, - oauth_expiry + oauth_expiry, + oauth_extra ) VALUES ( $1, $2, @@ -833,20 +836,22 @@ INSERT INTO external_auth_links ( $6, $7, $8, - $9 -) RETURNING provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id + $9, + $10 +) RETURNING provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, oauth_extra ` type InsertExternalAuthLinkParams struct { - ProviderID string `db:"provider_id" json:"provider_id"` - UserID uuid.UUID `db:"user_id" json:"user_id"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` - OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"` - OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` - OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"` - OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` + ProviderID string `db:"provider_id" json:"provider_id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` + OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"` + OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` + OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"` + OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` + OAuthExtra pqtype.NullRawMessage `db:"oauth_extra" json:"oauth_extra"` } func (q *sqlQuerier) InsertExternalAuthLink(ctx context.Context, arg InsertExternalAuthLinkParams) (ExternalAuthLink, error) { @@ -860,6 +865,7 @@ func (q *sqlQuerier) InsertExternalAuthLink(ctx context.Context, arg InsertExter arg.OAuthRefreshToken, arg.OAuthRefreshTokenKeyID, arg.OAuthExpiry, + arg.OAuthExtra, ) var i ExternalAuthLink err := row.Scan( @@ -872,6 +878,7 @@ func (q *sqlQuerier) InsertExternalAuthLink(ctx context.Context, arg InsertExter &i.OAuthExpiry, &i.OAuthAccessTokenKeyID, &i.OAuthRefreshTokenKeyID, + &i.OAuthExtra, ) return i, err } @@ -883,19 +890,21 @@ UPDATE external_auth_links SET oauth_access_token_key_id = $5, oauth_refresh_token = $6, oauth_refresh_token_key_id = $7, - oauth_expiry = $8 -WHERE provider_id = $1 AND user_id = $2 RETURNING provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id + oauth_expiry = $8, + oauth_extra = $9 +WHERE provider_id = $1 AND user_id = $2 RETURNING provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, oauth_extra ` type UpdateExternalAuthLinkParams struct { - ProviderID string `db:"provider_id" json:"provider_id"` - UserID uuid.UUID `db:"user_id" json:"user_id"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` - OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"` - OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` - OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"` - OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` + ProviderID string `db:"provider_id" json:"provider_id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` + OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"` + OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` + OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"` + OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` + OAuthExtra pqtype.NullRawMessage `db:"oauth_extra" json:"oauth_extra"` } func (q *sqlQuerier) UpdateExternalAuthLink(ctx context.Context, arg UpdateExternalAuthLinkParams) (ExternalAuthLink, error) { @@ -908,6 +917,7 @@ func (q *sqlQuerier) UpdateExternalAuthLink(ctx context.Context, arg UpdateExter arg.OAuthRefreshToken, arg.OAuthRefreshTokenKeyID, arg.OAuthExpiry, + arg.OAuthExtra, ) var i ExternalAuthLink err := row.Scan( @@ -920,6 +930,7 @@ func (q *sqlQuerier) UpdateExternalAuthLink(ctx context.Context, arg UpdateExter &i.OAuthExpiry, &i.OAuthAccessTokenKeyID, &i.OAuthRefreshTokenKeyID, + &i.OAuthExtra, ) return i, err } @@ -9561,6 +9572,119 @@ func (q *sqlQuerier) InsertWorkspaceResourceMetadata(ctx context.Context, arg In return items, nil } +const getWorkspaceAgentScriptsByAgentIDs = `-- name: GetWorkspaceAgentScriptsByAgentIDs :many +SELECT workspace_agent_id, log_source_id, log_path, created_at, script, cron, start_blocks_login, run_on_start, run_on_stop, timeout_seconds FROM workspace_agent_scripts WHERE workspace_agent_id = ANY($1 :: uuid [ ]) +` + +func (q *sqlQuerier) GetWorkspaceAgentScriptsByAgentIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceAgentScript, error) { + rows, err := q.db.QueryContext(ctx, getWorkspaceAgentScriptsByAgentIDs, pq.Array(ids)) + if err != nil { + return nil, err + } + defer rows.Close() + var items []WorkspaceAgentScript + for rows.Next() { + var i WorkspaceAgentScript + if err := rows.Scan( + &i.WorkspaceAgentID, + &i.LogSourceID, + &i.LogPath, + &i.CreatedAt, + &i.Script, + &i.Cron, + &i.StartBlocksLogin, + &i.RunOnStart, + &i.RunOnStop, + &i.TimeoutSeconds, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const insertWorkspaceAgentScripts = `-- name: InsertWorkspaceAgentScripts :many +INSERT INTO + workspace_agent_scripts (workspace_agent_id, created_at, log_source_id, log_path, script, cron, start_blocks_login, run_on_start, run_on_stop, timeout_seconds) +SELECT + $1 :: uuid AS workspace_agent_id, + $2 :: timestamptz AS created_at, + unnest($3 :: uuid [ ]) AS log_source_id, + unnest($4 :: text [ ]) AS log_path, + unnest($5 :: text [ ]) AS script, + unnest($6 :: text [ ]) AS cron, + unnest($7 :: boolean [ ]) AS start_blocks_login, + unnest($8 :: boolean [ ]) AS run_on_start, + unnest($9 :: boolean [ ]) AS run_on_stop, + unnest($10 :: integer [ ]) AS timeout_seconds +RETURNING workspace_agent_scripts.workspace_agent_id, workspace_agent_scripts.log_source_id, workspace_agent_scripts.log_path, workspace_agent_scripts.created_at, workspace_agent_scripts.script, workspace_agent_scripts.cron, workspace_agent_scripts.start_blocks_login, workspace_agent_scripts.run_on_start, workspace_agent_scripts.run_on_stop, workspace_agent_scripts.timeout_seconds +` + +type InsertWorkspaceAgentScriptsParams struct { + WorkspaceAgentID uuid.UUID `db:"workspace_agent_id" json:"workspace_agent_id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + LogSourceID []uuid.UUID `db:"log_source_id" json:"log_source_id"` + LogPath []string `db:"log_path" json:"log_path"` + Script []string `db:"script" json:"script"` + Cron []string `db:"cron" json:"cron"` + StartBlocksLogin []bool `db:"start_blocks_login" json:"start_blocks_login"` + RunOnStart []bool `db:"run_on_start" json:"run_on_start"` + RunOnStop []bool `db:"run_on_stop" json:"run_on_stop"` + TimeoutSeconds []int32 `db:"timeout_seconds" json:"timeout_seconds"` +} + +func (q *sqlQuerier) InsertWorkspaceAgentScripts(ctx context.Context, arg InsertWorkspaceAgentScriptsParams) ([]WorkspaceAgentScript, error) { + rows, err := q.db.QueryContext(ctx, insertWorkspaceAgentScripts, + arg.WorkspaceAgentID, + arg.CreatedAt, + pq.Array(arg.LogSourceID), + pq.Array(arg.LogPath), + pq.Array(arg.Script), + pq.Array(arg.Cron), + pq.Array(arg.StartBlocksLogin), + pq.Array(arg.RunOnStart), + pq.Array(arg.RunOnStop), + pq.Array(arg.TimeoutSeconds), + ) + if err != nil { + return nil, err + } + defer rows.Close() + var items []WorkspaceAgentScript + for rows.Next() { + var i WorkspaceAgentScript + if err := rows.Scan( + &i.WorkspaceAgentID, + &i.LogSourceID, + &i.LogPath, + &i.CreatedAt, + &i.Script, + &i.Cron, + &i.StartBlocksLogin, + &i.RunOnStart, + &i.RunOnStop, + &i.TimeoutSeconds, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getDeploymentWorkspaceStats = `-- name: GetDeploymentWorkspaceStats :one WITH workspaces_with_jobs AS ( SELECT @@ -10521,116 +10645,3 @@ func (q *sqlQuerier) UpdateWorkspacesDormantDeletingAtByTemplateID(ctx context.C _, err := q.db.ExecContext(ctx, updateWorkspacesDormantDeletingAtByTemplateID, arg.TimeTilDormantAutodeleteMs, arg.DormantAt, arg.TemplateID) return err } - -const getWorkspaceAgentScriptsByAgentIDs = `-- name: GetWorkspaceAgentScriptsByAgentIDs :many -SELECT workspace_agent_id, log_source_id, log_path, created_at, script, cron, start_blocks_login, run_on_start, run_on_stop, timeout_seconds FROM workspace_agent_scripts WHERE workspace_agent_id = ANY($1 :: uuid [ ]) -` - -func (q *sqlQuerier) GetWorkspaceAgentScriptsByAgentIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceAgentScript, error) { - rows, err := q.db.QueryContext(ctx, getWorkspaceAgentScriptsByAgentIDs, pq.Array(ids)) - if err != nil { - return nil, err - } - defer rows.Close() - var items []WorkspaceAgentScript - for rows.Next() { - var i WorkspaceAgentScript - if err := rows.Scan( - &i.WorkspaceAgentID, - &i.LogSourceID, - &i.LogPath, - &i.CreatedAt, - &i.Script, - &i.Cron, - &i.StartBlocksLogin, - &i.RunOnStart, - &i.RunOnStop, - &i.TimeoutSeconds, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const insertWorkspaceAgentScripts = `-- name: InsertWorkspaceAgentScripts :many -INSERT INTO - workspace_agent_scripts (workspace_agent_id, created_at, log_source_id, log_path, script, cron, start_blocks_login, run_on_start, run_on_stop, timeout_seconds) -SELECT - $1 :: uuid AS workspace_agent_id, - $2 :: timestamptz AS created_at, - unnest($3 :: uuid [ ]) AS log_source_id, - unnest($4 :: text [ ]) AS log_path, - unnest($5 :: text [ ]) AS script, - unnest($6 :: text [ ]) AS cron, - unnest($7 :: boolean [ ]) AS start_blocks_login, - unnest($8 :: boolean [ ]) AS run_on_start, - unnest($9 :: boolean [ ]) AS run_on_stop, - unnest($10 :: integer [ ]) AS timeout_seconds -RETURNING workspace_agent_scripts.workspace_agent_id, workspace_agent_scripts.log_source_id, workspace_agent_scripts.log_path, workspace_agent_scripts.created_at, workspace_agent_scripts.script, workspace_agent_scripts.cron, workspace_agent_scripts.start_blocks_login, workspace_agent_scripts.run_on_start, workspace_agent_scripts.run_on_stop, workspace_agent_scripts.timeout_seconds -` - -type InsertWorkspaceAgentScriptsParams struct { - WorkspaceAgentID uuid.UUID `db:"workspace_agent_id" json:"workspace_agent_id"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - LogSourceID []uuid.UUID `db:"log_source_id" json:"log_source_id"` - LogPath []string `db:"log_path" json:"log_path"` - Script []string `db:"script" json:"script"` - Cron []string `db:"cron" json:"cron"` - StartBlocksLogin []bool `db:"start_blocks_login" json:"start_blocks_login"` - RunOnStart []bool `db:"run_on_start" json:"run_on_start"` - RunOnStop []bool `db:"run_on_stop" json:"run_on_stop"` - TimeoutSeconds []int32 `db:"timeout_seconds" json:"timeout_seconds"` -} - -func (q *sqlQuerier) InsertWorkspaceAgentScripts(ctx context.Context, arg InsertWorkspaceAgentScriptsParams) ([]WorkspaceAgentScript, error) { - rows, err := q.db.QueryContext(ctx, insertWorkspaceAgentScripts, - arg.WorkspaceAgentID, - arg.CreatedAt, - pq.Array(arg.LogSourceID), - pq.Array(arg.LogPath), - pq.Array(arg.Script), - pq.Array(arg.Cron), - pq.Array(arg.StartBlocksLogin), - pq.Array(arg.RunOnStart), - pq.Array(arg.RunOnStop), - pq.Array(arg.TimeoutSeconds), - ) - if err != nil { - return nil, err - } - defer rows.Close() - var items []WorkspaceAgentScript - for rows.Next() { - var i WorkspaceAgentScript - if err := rows.Scan( - &i.WorkspaceAgentID, - &i.LogSourceID, - &i.LogPath, - &i.CreatedAt, - &i.Script, - &i.Cron, - &i.StartBlocksLogin, - &i.RunOnStart, - &i.RunOnStop, - &i.TimeoutSeconds, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} diff --git a/coderd/database/queries/externalauth.sql b/coderd/database/queries/externalauth.sql index 12a791d8c1cfb..dfc195b9ea886 100644 --- a/coderd/database/queries/externalauth.sql +++ b/coderd/database/queries/externalauth.sql @@ -14,7 +14,8 @@ INSERT INTO external_auth_links ( oauth_access_token_key_id, oauth_refresh_token, oauth_refresh_token_key_id, - oauth_expiry + oauth_expiry, + oauth_extra ) VALUES ( $1, $2, @@ -24,7 +25,8 @@ INSERT INTO external_auth_links ( $6, $7, $8, - $9 + $9, + $10 ) RETURNING *; -- name: UpdateExternalAuthLink :one @@ -34,5 +36,6 @@ UPDATE external_auth_links SET oauth_access_token_key_id = $5, oauth_refresh_token = $6, oauth_refresh_token_key_id = $7, - oauth_expiry = $8 + oauth_expiry = $8, + oauth_extra = $9 WHERE provider_id = $1 AND user_id = $2 RETURNING *; diff --git a/coderd/database/sqlc.yaml b/coderd/database/sqlc.yaml index 1bdc972927f6f..592b2c7b5e32e 100644 --- a/coderd/database/sqlc.yaml +++ b/coderd/database/sqlc.yaml @@ -53,6 +53,7 @@ overrides: oauth_id_token: OAuthIDToken oauth_refresh_token: OAuthRefreshToken oauth_refresh_token_key_id: OAuthRefreshTokenKeyID + oauth_extra: OAuthExtra parameter_type_system_hcl: ParameterTypeSystemHCL userstatus: UserStatus gitsshkey: GitSSHKey diff --git a/coderd/externalauth.go b/coderd/externalauth.go index 577fdfa0b0877..b87ba32852331 100644 --- a/coderd/externalauth.go +++ b/coderd/externalauth.go @@ -2,6 +2,7 @@ package coderd import ( "database/sql" + "encoding/json" "errors" "fmt" "net/http" @@ -14,6 +15,7 @@ import ( "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/codersdk" + "github.com/sqlc-dev/pqtype" ) // @Summary Get external auth by ID @@ -201,6 +203,24 @@ func (api *API) externalAuthCallback(externalAuthConfig *externalauth.Config) ht apiKey = httpmw.APIKey(r) ) + extra := pqtype.NullRawMessage{} + if len(externalAuthConfig.ExtraTokenKeys) > 0 { + extraMap := map[string]interface{}{} + for _, key := range externalAuthConfig.ExtraTokenKeys { + extraMap[key] = state.Token.Extra(key) + } + extraData, err := json.Marshal(extraMap) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to marshal extra token keys.", + Detail: err.Error(), + }) + return + } + extra.RawMessage = extraData + extra.Valid = true + } + _, err := api.Database.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{ ProviderID: externalAuthConfig.ID, UserID: apiKey.UserID, @@ -224,6 +244,7 @@ func (api *API) externalAuthCallback(externalAuthConfig *externalauth.Config) ht OAuthRefreshToken: state.Token.RefreshToken, OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will set as required OAuthExpiry: state.Token.Expiry, + OAuthExtra: extra, }) if err != nil { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ @@ -242,6 +263,7 @@ func (api *API) externalAuthCallback(externalAuthConfig *externalauth.Config) ht OAuthRefreshToken: state.Token.RefreshToken, OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required OAuthExpiry: state.Token.Expiry, + OAuthExtra: extra, }) if err != nil { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ diff --git a/coderd/externalauth/externalauth.go b/coderd/externalauth/externalauth.go index 92a2aa28cb75c..13f92793d9dee 100644 --- a/coderd/externalauth/externalauth.go +++ b/coderd/externalauth/externalauth.go @@ -15,6 +15,7 @@ import ( "golang.org/x/xerrors" "github.com/google/go-github/v43/github" + "github.com/sqlc-dev/pqtype" xgithub "golang.org/x/oauth2/github" "github.com/coder/coder/v2/coderd/database" @@ -44,6 +45,14 @@ type Config struct { // DisplayIcon is the path to an image that will be displayed to the user. DisplayIcon string + // ExtraTokenKeys is a list of extra properties to + // store in the database returned from the token endpoint. + // + // e.g. Slack returns `authed_user` in the token which is + // a payload that contains information about the authenticated + // user. + ExtraTokenKeys []string + // NoRefresh stops Coder from using the refresh token // to renew the access token. // @@ -67,10 +76,6 @@ type Config struct { // AppInstallationsURL is an API endpoint that returns a list of // installations for the user. This is used for GitHub Apps. AppInstallationsURL string - - // SlackAuthedUserToken is true if the user token should be returned - // instead of the bot token. - SlackAuthedUserToken bool } // RefreshToken automatically refreshes the token if expired and permitted. @@ -106,19 +111,16 @@ func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAu return externalAuthLink, false, nil } - // Slack's new OAuth2 flow has the user access token in a different field. - // It's weird and unfortunate, but the only way to access the user token. - // See: https://api.slack.com/authentication/oauth-v2#exchanging - if c.Type == string(codersdk.EnhancedExternalAuthProviderSlack) && c.SlackAuthedUserToken { - rawMap, ok := token.Extra("authed_user").(map[string]interface{}) - if !ok { - return externalAuthLink, false, xerrors.Errorf("slack: could not obtain user access token from payload: %+v", token.Extra("authed_user")) + var extra json.RawMessage + if len(c.ExtraTokenKeys) > 0 { + extraMap := map[string]interface{}{} + for _, key := range c.ExtraTokenKeys { + extraMap[key] = token.Extra(key) } - accessToken, ok := rawMap["access_token"].(string) - if !ok { - return externalAuthLink, false, xerrors.Errorf("slack: could not obtain user access token from payload: %+v", token.Extra("authed_user")) + extra, err = json.Marshal(extraMap) + if err != nil { + return externalAuthLink, false, xerrors.Errorf("marshal extra token keys: %w", err) } - token.AccessToken = accessToken } r := retry.New(50*time.Millisecond, 200*time.Millisecond) @@ -155,6 +157,10 @@ validate: OAuthRefreshToken: token.RefreshToken, OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required OAuthExpiry: token.Expiry, + OAuthExtra: pqtype.NullRawMessage{ + Valid: extra != nil, + RawMessage: extra, + }, }) if err != nil { return updatedAuthLink, false, xerrors.Errorf("update external auth link: %w", err) @@ -444,17 +450,16 @@ func ConvertConfig(entries []codersdk.ExternalAuthConfig, accessURL *url.URL) ([ } cfg := &Config{ - OAuth2Config: oauthConfig, - ID: entry.ID, - Regex: regex, - Type: entry.Type, - NoRefresh: entry.NoRefresh, - ValidateURL: entry.ValidateURL, - AppInstallationsURL: entry.AppInstallationsURL, - AppInstallURL: entry.AppInstallURL, - DisplayName: entry.DisplayName, - DisplayIcon: entry.DisplayIcon, - SlackAuthedUserToken: entry.SlackAuthedUserToken, + OAuth2Config: oauthConfig, + ID: entry.ID, + Regex: regex, + Type: entry.Type, + NoRefresh: entry.NoRefresh, + ValidateURL: entry.ValidateURL, + AppInstallationsURL: entry.AppInstallationsURL, + AppInstallURL: entry.AppInstallURL, + DisplayName: entry.DisplayName, + DisplayIcon: entry.DisplayIcon, } if entry.DeviceFlow { @@ -565,6 +570,8 @@ var defaults = map[codersdk.EnhancedExternalAuthProvider]codersdk.ExternalAuthCo TokenURL: "https://slack.com/api/oauth.v2.access", DisplayName: "Slack", DisplayIcon: "/icon/slack.svg", + // See: https://api.slack.com/authentication/oauth-v2#exchanging + ExtraTokenKeys: []string{"authed_user"}, }, } diff --git a/coderd/externalauth/externalauth_test.go b/coderd/externalauth/externalauth_test.go index 04d67736489cd..1f27c7932f8cf 100644 --- a/coderd/externalauth/externalauth_test.go +++ b/coderd/externalauth/externalauth_test.go @@ -261,7 +261,7 @@ func TestRefreshToken(t *testing.T) { require.Equal(t, updated.OAuthAccessToken, dbLink.OAuthAccessToken, "token is updated in the DB") }) - t.Run("SlackUserToken", func(t *testing.T) { + t.Run("WithExtra", func(t *testing.T) { t.Parallel() db := dbfake.New() @@ -277,7 +277,7 @@ func TestRefreshToken(t *testing.T) { }, ExternalAuthOpt: func(cfg *externalauth.Config) { cfg.Type = codersdk.EnhancedExternalAuthProviderSlack.String() - cfg.SlackAuthedUserToken = true + cfg.ExtraTokenKeys = []string{"authed_user"} cfg.ValidateURL = "" }, DB: db, @@ -290,7 +290,7 @@ func TestRefreshToken(t *testing.T) { updated, ok, err := config.RefreshToken(ctx, db, link) require.NoError(t, err) require.True(t, ok) - require.Equal(t, "slack-user-token", updated.OAuthAccessToken) + require.True(t, updated.OAuthExtra.Valid) }) } diff --git a/codersdk/deployment.go b/codersdk/deployment.go index 56412fbe39033..bd89af4201a10 100644 --- a/codersdk/deployment.go +++ b/codersdk/deployment.go @@ -336,6 +336,7 @@ type ExternalAuthConfig struct { AppInstallationsURL string `json:"app_installations_url"` NoRefresh bool `json:"no_refresh"` Scopes []string `json:"scopes"` + ExtraTokenKeys []string `json:"extra_token_keys"` DeviceFlow bool `json:"device_flow"` DeviceCodeURL string `json:"device_code_url"` // Regex allows API requesters to match an auth config by @@ -349,12 +350,6 @@ type ExternalAuthConfig struct { DisplayName string `json:"display_name"` // DisplayIcon is a URL to an icon to display in the UI. DisplayIcon string `json:"display_icon"` - - // SlackAuthedUserToken is a Slack-specific field that controls - // whether the Bot or User token is returned from the OAuth exchange. - // Slack returns multiple OAuth tokens as part of it's flow. - // See: https://api.slack.com/authentication/oauth-v2#exchanging - SlackAuthedUserToken bool `json:"slack_authed_user_token"` } type ProvisionerConfig struct { diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 55a3fbe6915a2..45b2676666904 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -453,6 +453,7 @@ export interface ExternalAuthConfig { readonly app_installations_url: string; readonly no_refresh: boolean; readonly scopes: string[]; + readonly extra_token_keys: string[]; readonly device_flow: boolean; readonly device_code_url: string; readonly regex: string; @@ -1650,12 +1651,14 @@ export type EnhancedExternalAuthProvider = | "azure-devops" | "bitbucket" | "github" - | "gitlab"; + | "gitlab" + | "slack"; export const EnhancedExternalAuthProviders: EnhancedExternalAuthProvider[] = [ "azure-devops", "bitbucket", "github", "gitlab", + "slack", ]; // From codersdk/deployment.go From 40156194c977940127e32172c7b476b9d3117d19 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Mon, 9 Oct 2023 21:48:38 +0000 Subject: [PATCH 3/9] Comments --- coderd/apidoc/docs.go | 6 ++++++ coderd/apidoc/swagger.json | 6 ++++++ coderd/coderdtest/oidctest/idp.go | 2 ++ docs/api/general.md | 1 + docs/api/schemas.md | 5 +++++ site/src/theme/icons.json | 1 + 6 files changed, 21 insertions(+) diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index bb7dd87636cbb..bb925f2a9ced7 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -8321,6 +8321,12 @@ const docTemplate = `{ "description": "DisplayName is shown in the UI to identify the auth config.", "type": "string" }, + "extra_token_keys": { + "type": "array", + "items": { + "type": "string" + } + }, "id": { "description": "ID is a unique identifier for the auth config.\nIt defaults to ` + "`" + `type` + "`" + ` when not provided.", "type": "string" diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index 9e26d94e0023f..ab60a425687df 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -7459,6 +7459,12 @@ "description": "DisplayName is shown in the UI to identify the auth config.", "type": "string" }, + "extra_token_keys": { + "type": "array", + "items": { + "type": "string" + } + }, "id": { "description": "ID is a unique identifier for the auth config.\nIt defaults to `type` when not provided.", "type": "string" diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index 8daf6a63720db..8c6019deee3da 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -113,6 +113,8 @@ func WithRefresh(hook func(email string) error) func(*FakeIDP) { } } +// WithExtra returns extra fields that be accessed on the returned Oauth Token. +// These extra fields can override the default fields (id_token, access_token, etc). func WithExtra(extra func(email string) map[string]interface{}) func(*FakeIDP) { return func(f *FakeIDP) { f.hookExtra = extra diff --git a/docs/api/general.md b/docs/api/general.md index ad2dcf67f05ca..bd6942fd3d271 100644 --- a/docs/api/general.md +++ b/docs/api/general.md @@ -223,6 +223,7 @@ curl -X GET http://coder-server:8080/api/v2/deployment/config \ "device_flow": true, "display_icon": "string", "display_name": "string", + "extra_token_keys": ["string"], "id": "string", "no_refresh": true, "regex": "string", diff --git a/docs/api/schemas.md b/docs/api/schemas.md index 926df93007bdb..8dd0d7b7a4a3e 100644 --- a/docs/api/schemas.md +++ b/docs/api/schemas.md @@ -634,6 +634,7 @@ _None_ "device_flow": true, "display_icon": "string", "display_name": "string", + "extra_token_keys": ["string"], "id": "string", "no_refresh": true, "regex": "string", @@ -2073,6 +2074,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in "device_flow": true, "display_icon": "string", "display_name": "string", + "extra_token_keys": ["string"], "id": "string", "no_refresh": true, "regex": "string", @@ -2440,6 +2442,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in "device_flow": true, "display_icon": "string", "display_name": "string", + "extra_token_keys": ["string"], "id": "string", "no_refresh": true, "regex": "string", @@ -2852,6 +2855,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in "device_flow": true, "display_icon": "string", "display_name": "string", + "extra_token_keys": ["string"], "id": "string", "no_refresh": true, "regex": "string", @@ -2874,6 +2878,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in | `device_flow` | boolean | false | | | | `display_icon` | string | false | | Display icon is a URL to an icon to display in the UI. | | `display_name` | string | false | | Display name is shown in the UI to identify the auth config. | +| `extra_token_keys` | array of string | false | | | | `id` | string | false | | ID is a unique identifier for the auth config. It defaults to `type` when not provided. | | `no_refresh` | boolean | false | | | | `regex` | string | false | | Regex allows API requesters to match an auth config by a string (e.g. coder.com) instead of by it's type. | diff --git a/site/src/theme/icons.json b/site/src/theme/icons.json index 753a0121a4b97..12fbaec3f2805 100644 --- a/site/src/theme/icons.json +++ b/site/src/theme/icons.json @@ -55,6 +55,7 @@ "ruby.png", "rubymine.svg", "rust.svg", + "slack.svg", "swift.svg", "tensorflow.svg", "terminal.svg", From b73fd0b9b50bfebe67f1018961029e6ca6c845d3 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Mon, 9 Oct 2023 21:52:16 +0000 Subject: [PATCH 4/9] Fix db gen --- coderd/database/queries.sql.go | 226 ++++++++++++++++----------------- 1 file changed, 113 insertions(+), 113 deletions(-) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 3c4ae4160caa4..f69eb27729672 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -9572,119 +9572,6 @@ func (q *sqlQuerier) InsertWorkspaceResourceMetadata(ctx context.Context, arg In return items, nil } -const getWorkspaceAgentScriptsByAgentIDs = `-- name: GetWorkspaceAgentScriptsByAgentIDs :many -SELECT workspace_agent_id, log_source_id, log_path, created_at, script, cron, start_blocks_login, run_on_start, run_on_stop, timeout_seconds FROM workspace_agent_scripts WHERE workspace_agent_id = ANY($1 :: uuid [ ]) -` - -func (q *sqlQuerier) GetWorkspaceAgentScriptsByAgentIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceAgentScript, error) { - rows, err := q.db.QueryContext(ctx, getWorkspaceAgentScriptsByAgentIDs, pq.Array(ids)) - if err != nil { - return nil, err - } - defer rows.Close() - var items []WorkspaceAgentScript - for rows.Next() { - var i WorkspaceAgentScript - if err := rows.Scan( - &i.WorkspaceAgentID, - &i.LogSourceID, - &i.LogPath, - &i.CreatedAt, - &i.Script, - &i.Cron, - &i.StartBlocksLogin, - &i.RunOnStart, - &i.RunOnStop, - &i.TimeoutSeconds, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const insertWorkspaceAgentScripts = `-- name: InsertWorkspaceAgentScripts :many -INSERT INTO - workspace_agent_scripts (workspace_agent_id, created_at, log_source_id, log_path, script, cron, start_blocks_login, run_on_start, run_on_stop, timeout_seconds) -SELECT - $1 :: uuid AS workspace_agent_id, - $2 :: timestamptz AS created_at, - unnest($3 :: uuid [ ]) AS log_source_id, - unnest($4 :: text [ ]) AS log_path, - unnest($5 :: text [ ]) AS script, - unnest($6 :: text [ ]) AS cron, - unnest($7 :: boolean [ ]) AS start_blocks_login, - unnest($8 :: boolean [ ]) AS run_on_start, - unnest($9 :: boolean [ ]) AS run_on_stop, - unnest($10 :: integer [ ]) AS timeout_seconds -RETURNING workspace_agent_scripts.workspace_agent_id, workspace_agent_scripts.log_source_id, workspace_agent_scripts.log_path, workspace_agent_scripts.created_at, workspace_agent_scripts.script, workspace_agent_scripts.cron, workspace_agent_scripts.start_blocks_login, workspace_agent_scripts.run_on_start, workspace_agent_scripts.run_on_stop, workspace_agent_scripts.timeout_seconds -` - -type InsertWorkspaceAgentScriptsParams struct { - WorkspaceAgentID uuid.UUID `db:"workspace_agent_id" json:"workspace_agent_id"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - LogSourceID []uuid.UUID `db:"log_source_id" json:"log_source_id"` - LogPath []string `db:"log_path" json:"log_path"` - Script []string `db:"script" json:"script"` - Cron []string `db:"cron" json:"cron"` - StartBlocksLogin []bool `db:"start_blocks_login" json:"start_blocks_login"` - RunOnStart []bool `db:"run_on_start" json:"run_on_start"` - RunOnStop []bool `db:"run_on_stop" json:"run_on_stop"` - TimeoutSeconds []int32 `db:"timeout_seconds" json:"timeout_seconds"` -} - -func (q *sqlQuerier) InsertWorkspaceAgentScripts(ctx context.Context, arg InsertWorkspaceAgentScriptsParams) ([]WorkspaceAgentScript, error) { - rows, err := q.db.QueryContext(ctx, insertWorkspaceAgentScripts, - arg.WorkspaceAgentID, - arg.CreatedAt, - pq.Array(arg.LogSourceID), - pq.Array(arg.LogPath), - pq.Array(arg.Script), - pq.Array(arg.Cron), - pq.Array(arg.StartBlocksLogin), - pq.Array(arg.RunOnStart), - pq.Array(arg.RunOnStop), - pq.Array(arg.TimeoutSeconds), - ) - if err != nil { - return nil, err - } - defer rows.Close() - var items []WorkspaceAgentScript - for rows.Next() { - var i WorkspaceAgentScript - if err := rows.Scan( - &i.WorkspaceAgentID, - &i.LogSourceID, - &i.LogPath, - &i.CreatedAt, - &i.Script, - &i.Cron, - &i.StartBlocksLogin, - &i.RunOnStart, - &i.RunOnStop, - &i.TimeoutSeconds, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - const getDeploymentWorkspaceStats = `-- name: GetDeploymentWorkspaceStats :one WITH workspaces_with_jobs AS ( SELECT @@ -10645,3 +10532,116 @@ func (q *sqlQuerier) UpdateWorkspacesDormantDeletingAtByTemplateID(ctx context.C _, err := q.db.ExecContext(ctx, updateWorkspacesDormantDeletingAtByTemplateID, arg.TimeTilDormantAutodeleteMs, arg.DormantAt, arg.TemplateID) return err } + +const getWorkspaceAgentScriptsByAgentIDs = `-- name: GetWorkspaceAgentScriptsByAgentIDs :many +SELECT workspace_agent_id, log_source_id, log_path, created_at, script, cron, start_blocks_login, run_on_start, run_on_stop, timeout_seconds FROM workspace_agent_scripts WHERE workspace_agent_id = ANY($1 :: uuid [ ]) +` + +func (q *sqlQuerier) GetWorkspaceAgentScriptsByAgentIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceAgentScript, error) { + rows, err := q.db.QueryContext(ctx, getWorkspaceAgentScriptsByAgentIDs, pq.Array(ids)) + if err != nil { + return nil, err + } + defer rows.Close() + var items []WorkspaceAgentScript + for rows.Next() { + var i WorkspaceAgentScript + if err := rows.Scan( + &i.WorkspaceAgentID, + &i.LogSourceID, + &i.LogPath, + &i.CreatedAt, + &i.Script, + &i.Cron, + &i.StartBlocksLogin, + &i.RunOnStart, + &i.RunOnStop, + &i.TimeoutSeconds, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const insertWorkspaceAgentScripts = `-- name: InsertWorkspaceAgentScripts :many +INSERT INTO + workspace_agent_scripts (workspace_agent_id, created_at, log_source_id, log_path, script, cron, start_blocks_login, run_on_start, run_on_stop, timeout_seconds) +SELECT + $1 :: uuid AS workspace_agent_id, + $2 :: timestamptz AS created_at, + unnest($3 :: uuid [ ]) AS log_source_id, + unnest($4 :: text [ ]) AS log_path, + unnest($5 :: text [ ]) AS script, + unnest($6 :: text [ ]) AS cron, + unnest($7 :: boolean [ ]) AS start_blocks_login, + unnest($8 :: boolean [ ]) AS run_on_start, + unnest($9 :: boolean [ ]) AS run_on_stop, + unnest($10 :: integer [ ]) AS timeout_seconds +RETURNING workspace_agent_scripts.workspace_agent_id, workspace_agent_scripts.log_source_id, workspace_agent_scripts.log_path, workspace_agent_scripts.created_at, workspace_agent_scripts.script, workspace_agent_scripts.cron, workspace_agent_scripts.start_blocks_login, workspace_agent_scripts.run_on_start, workspace_agent_scripts.run_on_stop, workspace_agent_scripts.timeout_seconds +` + +type InsertWorkspaceAgentScriptsParams struct { + WorkspaceAgentID uuid.UUID `db:"workspace_agent_id" json:"workspace_agent_id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + LogSourceID []uuid.UUID `db:"log_source_id" json:"log_source_id"` + LogPath []string `db:"log_path" json:"log_path"` + Script []string `db:"script" json:"script"` + Cron []string `db:"cron" json:"cron"` + StartBlocksLogin []bool `db:"start_blocks_login" json:"start_blocks_login"` + RunOnStart []bool `db:"run_on_start" json:"run_on_start"` + RunOnStop []bool `db:"run_on_stop" json:"run_on_stop"` + TimeoutSeconds []int32 `db:"timeout_seconds" json:"timeout_seconds"` +} + +func (q *sqlQuerier) InsertWorkspaceAgentScripts(ctx context.Context, arg InsertWorkspaceAgentScriptsParams) ([]WorkspaceAgentScript, error) { + rows, err := q.db.QueryContext(ctx, insertWorkspaceAgentScripts, + arg.WorkspaceAgentID, + arg.CreatedAt, + pq.Array(arg.LogSourceID), + pq.Array(arg.LogPath), + pq.Array(arg.Script), + pq.Array(arg.Cron), + pq.Array(arg.StartBlocksLogin), + pq.Array(arg.RunOnStart), + pq.Array(arg.RunOnStop), + pq.Array(arg.TimeoutSeconds), + ) + if err != nil { + return nil, err + } + defer rows.Close() + var items []WorkspaceAgentScript + for rows.Next() { + var i WorkspaceAgentScript + if err := rows.Scan( + &i.WorkspaceAgentID, + &i.LogSourceID, + &i.LogPath, + &i.CreatedAt, + &i.Script, + &i.Cron, + &i.StartBlocksLogin, + &i.RunOnStart, + &i.RunOnStop, + &i.TimeoutSeconds, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} From fb66ac193dd9fab673ba8be31c92f463d2f2851f Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Mon, 9 Oct 2023 22:01:47 +0000 Subject: [PATCH 5/9] Comments --- coderd/externalauth.go | 28 +++++++--------------- coderd/externalauth/externalauth.go | 37 ++++++++++++++++++----------- 2 files changed, 31 insertions(+), 34 deletions(-) diff --git a/coderd/externalauth.go b/coderd/externalauth.go index b87ba32852331..9dc223dd3f6dd 100644 --- a/coderd/externalauth.go +++ b/coderd/externalauth.go @@ -2,7 +2,6 @@ package coderd import ( "database/sql" - "encoding/json" "errors" "fmt" "net/http" @@ -15,7 +14,6 @@ import ( "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/codersdk" - "github.com/sqlc-dev/pqtype" ) // @Summary Get external auth by ID @@ -203,25 +201,15 @@ func (api *API) externalAuthCallback(externalAuthConfig *externalauth.Config) ht apiKey = httpmw.APIKey(r) ) - extra := pqtype.NullRawMessage{} - if len(externalAuthConfig.ExtraTokenKeys) > 0 { - extraMap := map[string]interface{}{} - for _, key := range externalAuthConfig.ExtraTokenKeys { - extraMap[key] = state.Token.Extra(key) - } - extraData, err := json.Marshal(extraMap) - if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to marshal extra token keys.", - Detail: err.Error(), - }) - return - } - extra.RawMessage = extraData - extra.Valid = true + extra, err := externalAuthConfig.GenerateTokenExtra(state.Token) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to generate token extra.", + Detail: err.Error(), + }) + return } - - _, err := api.Database.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{ + _, err = api.Database.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{ ProviderID: externalAuthConfig.ID, UserID: apiKey.UserID, }) diff --git a/coderd/externalauth/externalauth.go b/coderd/externalauth/externalauth.go index 13f92793d9dee..7a84b56e04752 100644 --- a/coderd/externalauth/externalauth.go +++ b/coderd/externalauth/externalauth.go @@ -78,6 +78,25 @@ type Config struct { AppInstallationsURL string } +// GenerateTokenExtra generates the extra token data to store in the database. +func (c *Config) GenerateTokenExtra(token *oauth2.Token) (pqtype.NullRawMessage, error) { + if len(c.ExtraTokenKeys) == 0 { + return pqtype.NullRawMessage{}, nil + } + extraMap := map[string]interface{}{} + for _, key := range c.ExtraTokenKeys { + extraMap[key] = token.Extra(key) + } + data, err := json.Marshal(extraMap) + if err != nil { + return pqtype.NullRawMessage{}, err + } + return pqtype.NullRawMessage{ + RawMessage: data, + Valid: true, + }, nil +} + // RefreshToken automatically refreshes the token if expired and permitted. // It returns the token and a bool indicating if the token is valid. func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAuthLink database.ExternalAuthLink) (database.ExternalAuthLink, bool, error) { @@ -111,16 +130,9 @@ func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAu return externalAuthLink, false, nil } - var extra json.RawMessage - if len(c.ExtraTokenKeys) > 0 { - extraMap := map[string]interface{}{} - for _, key := range c.ExtraTokenKeys { - extraMap[key] = token.Extra(key) - } - extra, err = json.Marshal(extraMap) - if err != nil { - return externalAuthLink, false, xerrors.Errorf("marshal extra token keys: %w", err) - } + extra, err := c.GenerateTokenExtra(token) + if err != nil { + return externalAuthLink, false, xerrors.Errorf("generate token extra: %w", err) } r := retry.New(50*time.Millisecond, 200*time.Millisecond) @@ -157,10 +169,7 @@ validate: OAuthRefreshToken: token.RefreshToken, OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required OAuthExpiry: token.Expiry, - OAuthExtra: pqtype.NullRawMessage{ - Valid: extra != nil, - RawMessage: extra, - }, + OAuthExtra: extra, }) if err != nil { return updatedAuthLink, false, xerrors.Errorf("update external auth link: %w", err) From f18900c67327eabec1c88bb29b05fc959024968d Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Mon, 9 Oct 2023 22:08:54 +0000 Subject: [PATCH 6/9] Fix FE --- .../ExternalAuthSettingsPageView.stories.tsx | 1 + 1 file changed, 1 insertion(+) diff --git a/site/src/pages/DeploySettingsPage/ExternalAuthSettingsPage/ExternalAuthSettingsPageView.stories.tsx b/site/src/pages/DeploySettingsPage/ExternalAuthSettingsPage/ExternalAuthSettingsPageView.stories.tsx index bb6ee4eb68e55..311d20109e129 100644 --- a/site/src/pages/DeploySettingsPage/ExternalAuthSettingsPage/ExternalAuthSettingsPageView.stories.tsx +++ b/site/src/pages/DeploySettingsPage/ExternalAuthSettingsPage/ExternalAuthSettingsPageView.stories.tsx @@ -19,6 +19,7 @@ const meta: Meta = { app_installations_url: "", no_refresh: false, scopes: [], + extra_token_keys: [], device_flow: true, device_code_url: "", display_icon: "", From 4d62ebaad5d7f2571d3061703c7cb2d284b8b8f2 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Mon, 9 Oct 2023 22:11:49 +0000 Subject: [PATCH 7/9] Fix WithExtra --- coderd/coderdtest/oidctest/idp.go | 12 +++++------- coderd/externalauth/externalauth_test.go | 8 +++----- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index 8c6019deee3da..807257ff18df1 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -68,7 +68,7 @@ type FakeIDP struct { // "Authorized Redirect URLs". This can be used to emulate that. hookValidRedirectURL func(redirectURL string) error hookUserInfo func(email string) (jwt.MapClaims, error) - hookExtra func(email string) map[string]interface{} + hookMutateToken func(token map[string]interface{}) fakeCoderd func(req *http.Request) (*http.Response, error) hookOnRefresh func(email string) error // Custom authentication for the client. This is useful if you want @@ -115,9 +115,9 @@ func WithRefresh(hook func(email string) error) func(*FakeIDP) { // WithExtra returns extra fields that be accessed on the returned Oauth Token. // These extra fields can override the default fields (id_token, access_token, etc). -func WithExtra(extra func(email string) map[string]interface{}) func(*FakeIDP) { +func WithMutateToken(mutateToken func(token map[string]interface{})) func(*FakeIDP) { return func(f *FakeIDP) { - f.hookExtra = extra + f.hookMutateToken = mutateToken } } @@ -630,10 +630,8 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { "expires_in": int64((time.Minute * 5).Seconds()), "id_token": f.encodeClaims(t, claims), } - if f.hookExtra != nil { - for k, v := range f.hookExtra(email) { - token[k] = v - } + if f.hookMutateToken != nil { + f.hookMutateToken(token) } // Store the claims for the next refresh f.refreshIDTokenClaims.Store(refreshToken, claims) diff --git a/coderd/externalauth/externalauth_test.go b/coderd/externalauth/externalauth_test.go index 1f27c7932f8cf..3c92b76b96729 100644 --- a/coderd/externalauth/externalauth_test.go +++ b/coderd/externalauth/externalauth_test.go @@ -267,11 +267,9 @@ func TestRefreshToken(t *testing.T) { db := dbfake.New() fake, config, link := setupOauth2Test(t, testConfig{ FakeIDPOpts: []oidctest.FakeIDPOpt{ - oidctest.WithExtra(func(email string) map[string]interface{} { - return map[string]interface{}{ - "authed_user": map[string]interface{}{ - "access_token": "slack-user-token", - }, + oidctest.WithMutateToken(func(token map[string]interface{}) { + token["authed_user"] = map[string]interface{}{ + "access_token": "slack-user-token", } }), }, From e2ac6c718e4caa64ee7c8378504db6f1e3ca6b8f Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Mon, 9 Oct 2023 22:13:59 +0000 Subject: [PATCH 8/9] Fix assertion --- coderd/externalauth/externalauth_test.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/coderd/externalauth/externalauth_test.go b/coderd/externalauth/externalauth_test.go index 3c92b76b96729..d790c32989ea7 100644 --- a/coderd/externalauth/externalauth_test.go +++ b/coderd/externalauth/externalauth_test.go @@ -2,6 +2,7 @@ package externalauth_test import ( "context" + "encoding/json" "net/http" "net/url" "testing" @@ -269,7 +270,7 @@ func TestRefreshToken(t *testing.T) { FakeIDPOpts: []oidctest.FakeIDPOpt{ oidctest.WithMutateToken(func(token map[string]interface{}) { token["authed_user"] = map[string]interface{}{ - "access_token": "slack-user-token", + "access_token": token["access_token"], } }), }, @@ -289,6 +290,11 @@ func TestRefreshToken(t *testing.T) { require.NoError(t, err) require.True(t, ok) require.True(t, updated.OAuthExtra.Valid) + extra := map[string]interface{}{} + require.NoError(t, json.Unmarshal(updated.OAuthExtra.RawMessage, &extra)) + mapping, ok := extra["authed_user"].(map[string]interface{}) + require.True(t, ok) + require.Equal(t, updated.OAuthAccessToken, mapping["access_token"]) }) } From 8acbb587ff0181634d5104fd36db45deab1e9a40 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Mon, 9 Oct 2023 22:32:28 +0000 Subject: [PATCH 9/9] Fix lint --- coderd/database/dbgen/dbgen.go | 2 ++ coderd/externalauth.go | 4 ++++ enterprise/dbcrypt/cliutil.go | 42 ++++++++++++++++++---------------- 3 files changed, 28 insertions(+), 20 deletions(-) diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index b2c462cfec79d..b1146b4f49d81 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -514,6 +514,7 @@ func UserLink(t testing.TB, db database.Store, orig database.UserLink) database. } func ExternalAuthLink(t testing.TB, db database.Store, orig database.ExternalAuthLink) database.ExternalAuthLink { + msg := takeFirst(&orig.OAuthExtra, &pqtype.NullRawMessage{}) link, err := db.InsertExternalAuthLink(genCtx, database.InsertExternalAuthLinkParams{ ProviderID: takeFirst(orig.ProviderID, uuid.New().String()), UserID: takeFirst(orig.UserID, uuid.New()), @@ -524,6 +525,7 @@ func ExternalAuthLink(t testing.TB, db database.Store, orig database.ExternalAut OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)), CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()), UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()), + OAuthExtra: *msg, }) require.NoError(t, err, "insert external auth link") diff --git a/coderd/externalauth.go b/coderd/externalauth.go index 9dc223dd3f6dd..775ff5436284f 100644 --- a/coderd/externalauth.go +++ b/coderd/externalauth.go @@ -14,6 +14,7 @@ import ( "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/codersdk" + "github.com/sqlc-dev/pqtype" ) // @Summary Get external auth by ID @@ -132,6 +133,8 @@ func (api *API) postExternalAuthDeviceByID(rw http.ResponseWriter, r *http.Reque OAuthRefreshToken: token.RefreshToken, OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will set as required OAuthExpiry: token.Expiry, + // No extra data from device auth! + OAuthExtra: pqtype.NullRawMessage{}, }) if err != nil { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ @@ -150,6 +153,7 @@ func (api *API) postExternalAuthDeviceByID(rw http.ResponseWriter, r *http.Reque OAuthRefreshToken: token.RefreshToken, OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required OAuthExpiry: token.Expiry, + OAuthExtra: pqtype.NullRawMessage{}, }) if err != nil { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ diff --git a/enterprise/dbcrypt/cliutil.go b/enterprise/dbcrypt/cliutil.go index 2e375a5112fbf..77986b669bb61 100644 --- a/enterprise/dbcrypt/cliutil.go +++ b/enterprise/dbcrypt/cliutil.go @@ -48,26 +48,27 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe } } - gitAuthLinks, err := cryptTx.GetExternalAuthLinksByUserID(ctx, uid) + externalAuthLinks, err := cryptTx.GetExternalAuthLinksByUserID(ctx, uid) if err != nil { return xerrors.Errorf("get git auth links for user: %w", err) } - for _, gitAuthLink := range gitAuthLinks { - if gitAuthLink.OAuthAccessTokenKeyID.String == ciphers[0].HexDigest() && gitAuthLink.OAuthRefreshTokenKeyID.String == ciphers[0].HexDigest() { - log.Debug(ctx, "skipping git auth link", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) + for _, externalAuthLink := range externalAuthLinks { + if externalAuthLink.OAuthAccessTokenKeyID.String == ciphers[0].HexDigest() && externalAuthLink.OAuthRefreshTokenKeyID.String == ciphers[0].HexDigest() { + log.Debug(ctx, "skipping external auth link", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) continue } if _, err := cryptTx.UpdateExternalAuthLink(ctx, database.UpdateExternalAuthLinkParams{ - ProviderID: gitAuthLink.ProviderID, + ProviderID: externalAuthLink.ProviderID, UserID: uid, - UpdatedAt: gitAuthLink.UpdatedAt, - OAuthAccessToken: gitAuthLink.OAuthAccessToken, + UpdatedAt: externalAuthLink.UpdatedAt, + OAuthAccessToken: externalAuthLink.OAuthAccessToken, OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required - OAuthRefreshToken: gitAuthLink.OAuthRefreshToken, + OAuthRefreshToken: externalAuthLink.OAuthRefreshToken, OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required - OAuthExpiry: gitAuthLink.OAuthExpiry, + OAuthExpiry: externalAuthLink.OAuthExpiry, + OAuthExtra: externalAuthLink.OAuthExtra, }); err != nil { - return xerrors.Errorf("update git auth link user_id=%s provider_id=%s: %w", gitAuthLink.UserID, gitAuthLink.ProviderID, err) + return xerrors.Errorf("update external auth link user_id=%s provider_id=%s: %w", externalAuthLink.UserID, externalAuthLink.ProviderID, err) } } return nil @@ -136,26 +137,27 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph } } - gitAuthLinks, err := tx.GetExternalAuthLinksByUserID(ctx, uid) + externalAuthLinks, err := tx.GetExternalAuthLinksByUserID(ctx, uid) if err != nil { return xerrors.Errorf("get git auth links for user: %w", err) } - for _, gitAuthLink := range gitAuthLinks { - if !gitAuthLink.OAuthAccessTokenKeyID.Valid && !gitAuthLink.OAuthRefreshTokenKeyID.Valid { - log.Debug(ctx, "skipping git auth link", slog.F("user_id", uid), slog.F("current", idx+1)) + for _, externalAuthLink := range externalAuthLinks { + if !externalAuthLink.OAuthAccessTokenKeyID.Valid && !externalAuthLink.OAuthRefreshTokenKeyID.Valid { + log.Debug(ctx, "skipping external auth link", slog.F("user_id", uid), slog.F("current", idx+1)) continue } if _, err := tx.UpdateExternalAuthLink(ctx, database.UpdateExternalAuthLinkParams{ - ProviderID: gitAuthLink.ProviderID, + ProviderID: externalAuthLink.ProviderID, UserID: uid, - UpdatedAt: gitAuthLink.UpdatedAt, - OAuthAccessToken: gitAuthLink.OAuthAccessToken, + UpdatedAt: externalAuthLink.UpdatedAt, + OAuthAccessToken: externalAuthLink.OAuthAccessToken, OAuthAccessTokenKeyID: sql.NullString{}, // we explicitly want to clear the key id - OAuthRefreshToken: gitAuthLink.OAuthRefreshToken, + OAuthRefreshToken: externalAuthLink.OAuthRefreshToken, OAuthRefreshTokenKeyID: sql.NullString{}, // we explicitly want to clear the key id - OAuthExpiry: gitAuthLink.OAuthExpiry, + OAuthExpiry: externalAuthLink.OAuthExpiry, + OAuthExtra: externalAuthLink.OAuthExtra, }); err != nil { - return xerrors.Errorf("update git auth link user_id=%s provider_id=%s: %w", gitAuthLink.UserID, gitAuthLink.ProviderID, err) + return xerrors.Errorf("update external auth link user_id=%s provider_id=%s: %w", externalAuthLink.UserID, externalAuthLink.ProviderID, err) } } return nil