diff --git a/coderd/coderdtest/oidctest/helper.go b/coderd/coderdtest/oidctest/helper.go index beb1243e2ce74..c817c8ca47e8e 100644 --- a/coderd/coderdtest/oidctest/helper.go +++ b/coderd/coderdtest/oidctest/helper.go @@ -3,7 +3,6 @@ package oidctest import ( "context" "database/sql" - "encoding/json" "net/http" "net/url" "testing" @@ -89,7 +88,7 @@ func (*LoginHelper) ExpireOauthToken(t *testing.T, db database.Store, user *code OAuthExpiry: time.Now().Add(time.Hour * -1), UserID: link.UserID, LoginType: link.LoginType, - DebugContext: json.RawMessage("{}"), + Claims: database.UserLinkClaims{}, }) require.NoError(t, err, "expire user link") diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index b610efe0349f5..3f55bc5bdaf21 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -1281,7 +1281,7 @@ func (s *MethodTestSuite) TestUser() { OAuthExpiry: link.OAuthExpiry, UserID: link.UserID, LoginType: link.LoginType, - DebugContext: json.RawMessage("{}"), + Claims: database.UserLinkClaims{}, }).Asserts(rbac.ResourceUserObject(link.UserID), policy.ActionUpdatePersonal).Returns(link) })) s.Run("UpdateUserRoles", s.Subtest(func(db database.Store, check *expects) { diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index 5e83125a93b84..5679e16f40bd3 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -726,7 +726,7 @@ func UserLink(t testing.TB, db database.Store, orig database.UserLink) database. OAuthRefreshToken: takeFirst(orig.OAuthRefreshToken, uuid.NewString()), OAuthRefreshTokenKeyID: takeFirst(orig.OAuthRefreshTokenKeyID, sql.NullString{}), OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)), - DebugContext: takeFirstSlice(orig.DebugContext, json.RawMessage("{}")), + Claims: orig.Claims, }) require.NoError(t, err, "insert link") diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 0cb3789fffd3e..d72c8cc5a7636 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -7857,7 +7857,7 @@ func (q *FakeQuerier) InsertUserLink(_ context.Context, args database.InsertUser OAuthRefreshToken: args.OAuthRefreshToken, OAuthRefreshTokenKeyID: args.OAuthRefreshTokenKeyID, OAuthExpiry: args.OAuthExpiry, - DebugContext: args.DebugContext, + Claims: args.Claims, } q.userLinks = append(q.userLinks, link) @@ -9318,7 +9318,7 @@ func (q *FakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUs link.OAuthRefreshToken = params.OAuthRefreshToken link.OAuthRefreshTokenKeyID = params.OAuthRefreshTokenKeyID link.OAuthExpiry = params.OAuthExpiry - link.DebugContext = params.DebugContext + link.Claims = params.Claims q.userLinks[i] = link return link, nil diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 557b5c2dd9325..3ce183f56e351 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -1337,14 +1337,14 @@ CREATE TABLE user_links ( oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL, oauth_access_token_key_id text, oauth_refresh_token_key_id text, - debug_context jsonb DEFAULT '{}'::jsonb NOT NULL + claims jsonb DEFAULT '{}'::jsonb NOT NULL ); COMMENT ON COLUMN user_links.oauth_access_token_key_id IS 'The ID of the key used to encrypt the OAuth access token. If this is NULL, the access token is not encrypted'; COMMENT ON COLUMN user_links.oauth_refresh_token_key_id IS 'The ID of the key used to encrypt the OAuth refresh token. If this is NULL, the refresh token is not encrypted'; -COMMENT ON COLUMN user_links.debug_context IS 'Debug information includes information like id_token and userinfo claims.'; +COMMENT ON COLUMN user_links.claims IS 'Claims from the IDP for the linked user. Includes both id_token and userinfo claims. '; CREATE TABLE workspace_agent_log_sources ( workspace_agent_id uuid NOT NULL, diff --git a/coderd/database/migrations/000274_rename_user_link_claims.down.sql b/coderd/database/migrations/000274_rename_user_link_claims.down.sql new file mode 100644 index 0000000000000..39ff8803efa48 --- /dev/null +++ b/coderd/database/migrations/000274_rename_user_link_claims.down.sql @@ -0,0 +1,3 @@ +ALTER TABLE user_links RENAME COLUMN claims TO debug_context; + +COMMENT ON COLUMN user_links.debug_context IS 'Debug information includes information like id_token and userinfo claims.'; diff --git a/coderd/database/migrations/000274_rename_user_link_claims.up.sql b/coderd/database/migrations/000274_rename_user_link_claims.up.sql new file mode 100644 index 0000000000000..2f518c2033024 --- /dev/null +++ b/coderd/database/migrations/000274_rename_user_link_claims.up.sql @@ -0,0 +1,3 @@ +ALTER TABLE user_links RENAME COLUMN debug_context TO claims; + +COMMENT ON COLUMN user_links.claims IS 'Claims from the IDP for the linked user. Includes both id_token and userinfo claims. '; diff --git a/coderd/database/models.go b/coderd/database/models.go index 680450a7826d0..d84030107de7f 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -2892,8 +2892,8 @@ type UserLink struct { OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"` // The ID of the key used to encrypt the OAuth refresh token. If this is NULL, the refresh token is not encrypted OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"` - // Debug information includes information like id_token and userinfo claims. - DebugContext json.RawMessage `db:"debug_context" json:"debug_context"` + // Claims from the IDP for the linked user. Includes both id_token and userinfo claims. + Claims UserLinkClaims `db:"claims" json:"claims"` } // Visible fields of users are allowed to be joined with other tables for including context of other resources. diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 14afd75403c89..4c7c25ab69a71 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -9689,7 +9689,7 @@ func (q *sqlQuerier) InsertTemplateVersionWorkspaceTag(ctx context.Context, arg const getUserLinkByLinkedID = `-- name: GetUserLinkByLinkedID :one SELECT - user_links.user_id, user_links.login_type, user_links.linked_id, user_links.oauth_access_token, user_links.oauth_refresh_token, user_links.oauth_expiry, user_links.oauth_access_token_key_id, user_links.oauth_refresh_token_key_id, user_links.debug_context + user_links.user_id, user_links.login_type, user_links.linked_id, user_links.oauth_access_token, user_links.oauth_refresh_token, user_links.oauth_expiry, user_links.oauth_access_token_key_id, user_links.oauth_refresh_token_key_id, user_links.claims FROM user_links INNER JOIN @@ -9712,14 +9712,14 @@ func (q *sqlQuerier) GetUserLinkByLinkedID(ctx context.Context, linkedID string) &i.OAuthExpiry, &i.OAuthAccessTokenKeyID, &i.OAuthRefreshTokenKeyID, - &i.DebugContext, + &i.Claims, ) return i, err } const getUserLinkByUserIDLoginType = `-- name: GetUserLinkByUserIDLoginType :one SELECT - user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, debug_context + user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, claims FROM user_links WHERE @@ -9743,13 +9743,13 @@ func (q *sqlQuerier) GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUs &i.OAuthExpiry, &i.OAuthAccessTokenKeyID, &i.OAuthRefreshTokenKeyID, - &i.DebugContext, + &i.Claims, ) return i, err } const getUserLinksByUserID = `-- name: GetUserLinksByUserID :many -SELECT user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, debug_context FROM user_links WHERE user_id = $1 +SELECT user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, claims FROM user_links WHERE user_id = $1 ` func (q *sqlQuerier) GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]UserLink, error) { @@ -9770,7 +9770,7 @@ func (q *sqlQuerier) GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) &i.OAuthExpiry, &i.OAuthAccessTokenKeyID, &i.OAuthRefreshTokenKeyID, - &i.DebugContext, + &i.Claims, ); err != nil { return nil, err } @@ -9796,22 +9796,22 @@ INSERT INTO oauth_refresh_token, oauth_refresh_token_key_id, oauth_expiry, - debug_context + claims ) VALUES - ( $1, $2, $3, $4, $5, $6, $7, $8, $9 ) RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, debug_context + ( $1, $2, $3, $4, $5, $6, $7, $8, $9 ) RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, claims ` type InsertUserLinkParams struct { - UserID uuid.UUID `db:"user_id" json:"user_id"` - LoginType LoginType `db:"login_type" json:"login_type"` - LinkedID string `db:"linked_id" json:"linked_id"` - OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` - OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"` - OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` - OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"` - OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` - DebugContext json.RawMessage `db:"debug_context" json:"debug_context"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + LoginType LoginType `db:"login_type" json:"login_type"` + LinkedID string `db:"linked_id" json:"linked_id"` + OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` + OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"` + OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` + OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"` + OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` + Claims UserLinkClaims `db:"claims" json:"claims"` } func (q *sqlQuerier) InsertUserLink(ctx context.Context, arg InsertUserLinkParams) (UserLink, error) { @@ -9824,7 +9824,7 @@ func (q *sqlQuerier) InsertUserLink(ctx context.Context, arg InsertUserLinkParam arg.OAuthRefreshToken, arg.OAuthRefreshTokenKeyID, arg.OAuthExpiry, - arg.DebugContext, + arg.Claims, ) var i UserLink err := row.Scan( @@ -9836,7 +9836,7 @@ func (q *sqlQuerier) InsertUserLink(ctx context.Context, arg InsertUserLinkParam &i.OAuthExpiry, &i.OAuthAccessTokenKeyID, &i.OAuthRefreshTokenKeyID, - &i.DebugContext, + &i.Claims, ) return i, err } @@ -9850,20 +9850,20 @@ SET oauth_refresh_token = $3, oauth_refresh_token_key_id = $4, oauth_expiry = $5, - debug_context = $6 + claims = $6 WHERE - user_id = $7 AND login_type = $8 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, debug_context + user_id = $7 AND login_type = $8 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, claims ` type UpdateUserLinkParams struct { - OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` - OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"` - OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` - OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"` - OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` - DebugContext json.RawMessage `db:"debug_context" json:"debug_context"` - UserID uuid.UUID `db:"user_id" json:"user_id"` - LoginType LoginType `db:"login_type" json:"login_type"` + OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` + OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"` + OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` + OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"` + OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` + Claims UserLinkClaims `db:"claims" json:"claims"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + LoginType LoginType `db:"login_type" json:"login_type"` } func (q *sqlQuerier) UpdateUserLink(ctx context.Context, arg UpdateUserLinkParams) (UserLink, error) { @@ -9873,7 +9873,7 @@ func (q *sqlQuerier) UpdateUserLink(ctx context.Context, arg UpdateUserLinkParam arg.OAuthRefreshToken, arg.OAuthRefreshTokenKeyID, arg.OAuthExpiry, - arg.DebugContext, + arg.Claims, arg.UserID, arg.LoginType, ) @@ -9887,7 +9887,7 @@ func (q *sqlQuerier) UpdateUserLink(ctx context.Context, arg UpdateUserLinkParam &i.OAuthExpiry, &i.OAuthAccessTokenKeyID, &i.OAuthRefreshTokenKeyID, - &i.DebugContext, + &i.Claims, ) return i, err } @@ -9898,7 +9898,7 @@ UPDATE SET linked_id = $1 WHERE - user_id = $2 AND login_type = $3 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, debug_context + user_id = $2 AND login_type = $3 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, claims ` type UpdateUserLinkedIDParams struct { @@ -9919,7 +9919,7 @@ func (q *sqlQuerier) UpdateUserLinkedID(ctx context.Context, arg UpdateUserLinke &i.OAuthExpiry, &i.OAuthAccessTokenKeyID, &i.OAuthRefreshTokenKeyID, - &i.DebugContext, + &i.Claims, ) return i, err } diff --git a/coderd/database/queries/user_links.sql b/coderd/database/queries/user_links.sql index 9fc0e6f9d7598..d0d52c3eac054 100644 --- a/coderd/database/queries/user_links.sql +++ b/coderd/database/queries/user_links.sql @@ -32,7 +32,7 @@ INSERT INTO oauth_refresh_token, oauth_refresh_token_key_id, oauth_expiry, - debug_context + claims ) VALUES ( $1, $2, $3, $4, $5, $6, $7, $8, $9 ) RETURNING *; @@ -54,6 +54,6 @@ SET oauth_refresh_token = $3, oauth_refresh_token_key_id = $4, oauth_expiry = $5, - debug_context = $6 + claims = $6 WHERE user_id = $7 AND login_type = $8 RETURNING *; diff --git a/coderd/database/sqlc.yaml b/coderd/database/sqlc.yaml index 2161feb47e1c3..8f73570077636 100644 --- a/coderd/database/sqlc.yaml +++ b/coderd/database/sqlc.yaml @@ -79,6 +79,9 @@ sql: - column: "provisioner_job_stats.*_secs" go_type: type: "float64" + - column: "user_links.claims" + go_type: + type: "UserLinkClaims" rename: group_member: GroupMemberTable group_members_expanded: GroupMember diff --git a/coderd/database/types.go b/coderd/database/types.go index 8e22258382abb..188825cea6eb7 100644 --- a/coderd/database/types.go +++ b/coderd/database/types.go @@ -207,3 +207,25 @@ func (p *AgentIDNamePair) Scan(src interface{}) error { func (p AgentIDNamePair) Value() (driver.Value, error) { return fmt.Sprintf(`(%s,%s)`, p.ID.String(), p.Name), nil } + +// UserLinkClaims is the returned IDP claims for a given user link. +// These claims are fetched at login time. These are the claims that were +// used for IDP sync. +type UserLinkClaims struct { + IDTokenClaims map[string]interface{} `json:"id_token_claims"` + UserInfoClaims map[string]interface{} `json:"user_info_claims"` +} + +func (a *UserLinkClaims) Scan(src interface{}) error { + switch v := src.(type) { + case string: + return json.Unmarshal([]byte(v), &a) + case []byte: + return json.Unmarshal(v, &a) + } + return xerrors.Errorf("unexpected type %T", src) +} + +func (a UserLinkClaims) Value() (driver.Value, error) { + return json.Marshal(a) +} diff --git a/coderd/httpmw/apikey.go b/coderd/httpmw/apikey.go index f6746b95eb20e..38ba74031ba46 100644 --- a/coderd/httpmw/apikey.go +++ b/coderd/httpmw/apikey.go @@ -377,7 +377,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon OAuthExpiry: link.OAuthExpiry, // Refresh should keep the same debug context because we use // the original claims for the group/role sync. - DebugContext: link.DebugContext, + Claims: link.Claims, }) if err != nil { return write(http.StatusInternalServerError, codersdk.Response{ diff --git a/coderd/provisionerdserver/provisionerdserver.go b/coderd/provisionerdserver/provisionerdserver.go index fdbf0bcd98212..5fc0f198102c9 100644 --- a/coderd/provisionerdserver/provisionerdserver.go +++ b/coderd/provisionerdserver/provisionerdserver.go @@ -2083,7 +2083,7 @@ func obtainOIDCAccessToken(ctx context.Context, db database.Store, oidcConfig pr OAuthRefreshToken: link.OAuthRefreshToken, OAuthRefreshTokenKeyID: sql.NullString{}, // set by dbcrypt if required OAuthExpiry: link.OAuthExpiry, - DebugContext: link.DebugContext, + Claims: link.Claims, }) if err != nil { return "", xerrors.Errorf("update user link: %w", err) diff --git a/coderd/userauth.go b/coderd/userauth.go index e7db9e9719c35..7d8e6c2bbc44f 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -3,7 +3,6 @@ package coderd import ( "context" "database/sql" - "encoding/json" "errors" "fmt" "net/http" @@ -966,7 +965,7 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { Username: username, AvatarURL: ghUser.GetAvatarURL(), Name: normName, - DebugContext: OauthDebugContext{}, + UserClaims: database.UserLinkClaims{}, GroupSync: idpsync.GroupParams{ SyncEntitled: false, }, @@ -1324,7 +1323,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { OrganizationSync: orgSync, GroupSync: groupSync, RoleSync: roleSync, - DebugContext: OauthDebugContext{ + UserClaims: database.UserLinkClaims{ IDTokenClaims: idtokenClaims, UserInfoClaims: userInfoClaims, }, @@ -1421,7 +1420,9 @@ type oauthLoginParams struct { GroupSync idpsync.GroupParams RoleSync idpsync.RoleParams - DebugContext OauthDebugContext + // UserClaims should only be populated for OIDC logins. + // It is used to save the user's claims on login. + UserClaims database.UserLinkClaims commitLock sync.Mutex initAuditRequest func(params *audit.RequestParams) *audit.Request[database.User] @@ -1591,11 +1592,6 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C dormantConvertAudit.New = user } - debugContext, err := json.Marshal(params.DebugContext) - if err != nil { - return xerrors.Errorf("marshal debug context: %w", err) - } - if link.UserID == uuid.Nil { //nolint:gocritic // System needs to insert the user link (linked_id, oauth_token, oauth_expiry). link, err = tx.InsertUserLink(dbauthz.AsSystemRestricted(ctx), database.InsertUserLinkParams{ @@ -1607,7 +1603,7 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C OAuthRefreshToken: params.State.Token.RefreshToken, OAuthRefreshTokenKeyID: sql.NullString{}, // set by dbcrypt if required OAuthExpiry: params.State.Token.Expiry, - DebugContext: debugContext, + Claims: params.UserClaims, }) if err != nil { return xerrors.Errorf("insert user link: %w", err) @@ -1624,7 +1620,7 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C OAuthRefreshToken: params.State.Token.RefreshToken, OAuthRefreshTokenKeyID: sql.NullString{}, // set by dbcrypt if required OAuthExpiry: params.State.Token.Expiry, - DebugContext: debugContext, + Claims: params.UserClaims, }) if err != nil { return xerrors.Errorf("update user link: %w", err) diff --git a/coderd/userauth_test.go b/coderd/userauth_test.go index 36713f6a6ae40..a7032364c1807 100644 --- a/coderd/userauth_test.go +++ b/coderd/userauth_test.go @@ -843,7 +843,7 @@ func TestUserOAuth2Github(t *testing.T) { OAuthAccessToken: "random", OAuthRefreshToken: "random", OAuthExpiry: time.Now(), - DebugContext: []byte(`{}`), + Claims: database.UserLinkClaims{}, }) require.ErrorContains(t, err, "Cannot create user_link for deleted user") diff --git a/coderd/users.go b/coderd/users.go index 4978e12a788b9..2fccef83f2013 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -70,8 +70,7 @@ func (api *API) userDebugOIDC(rw http.ResponseWriter, r *http.Request) { return } - // This will encode properly because it is a json.RawMessage. - httpapi.Write(ctx, rw, http.StatusOK, link.DebugContext) + httpapi.Write(ctx, rw, http.StatusOK, link.Claims) } // Returns whether the initial user has been created or not. diff --git a/enterprise/dbcrypt/cliutil.go b/enterprise/dbcrypt/cliutil.go index 47045f9bfefab..120b41972de05 100644 --- a/enterprise/dbcrypt/cliutil.go +++ b/enterprise/dbcrypt/cliutil.go @@ -43,7 +43,7 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe OAuthExpiry: userLink.OAuthExpiry, UserID: uid, LoginType: userLink.LoginType, - DebugContext: userLink.DebugContext, + Claims: userLink.Claims, }); err != nil { return xerrors.Errorf("update user link user_id=%s linked_id=%s: %w", userLink.UserID, userLink.LinkedID, err) } @@ -133,7 +133,7 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph OAuthExpiry: userLink.OAuthExpiry, UserID: uid, LoginType: userLink.LoginType, - DebugContext: userLink.DebugContext, + Claims: userLink.Claims, }); err != nil { return xerrors.Errorf("update user link user_id=%s linked_id=%s: %w", userLink.UserID, userLink.LinkedID, err) } diff --git a/enterprise/dbcrypt/dbcrypt_internal_test.go b/enterprise/dbcrypt/dbcrypt_internal_test.go index 8800180493d12..10d56b50a074c 100644 --- a/enterprise/dbcrypt/dbcrypt_internal_test.go +++ b/enterprise/dbcrypt/dbcrypt_internal_test.go @@ -5,7 +5,6 @@ import ( "crypto/rand" "database/sql" "encoding/base64" - "encoding/json" "io" "testing" "time" @@ -52,12 +51,27 @@ func TestUserLinks(t *testing.T) { UserID: user.ID, }) + expectedClaims := database.UserLinkClaims{ + IDTokenClaims: map[string]interface{}{ + "sub": "123", + "groups": []interface{}{ + "foo", "bar", + }, + }, + UserInfoClaims: map[string]interface{}{ + "number": float64(2), + "struct": map[string]interface{}{ + "number": float64(2), + }, + }, + } + updated, err := crypt.UpdateUserLink(ctx, database.UpdateUserLinkParams{ OAuthAccessToken: "access", OAuthRefreshToken: "refresh", UserID: link.UserID, LoginType: link.LoginType, - DebugContext: json.RawMessage("{}"), + Claims: expectedClaims, }) require.NoError(t, err) require.Equal(t, "access", updated.OAuthAccessToken) @@ -69,6 +83,7 @@ func TestUserLinks(t *testing.T) { require.NoError(t, err) requireEncryptedEquals(t, ciphers[0], rawLink.OAuthAccessToken, "access") requireEncryptedEquals(t, ciphers[0], rawLink.OAuthRefreshToken, "refresh") + require.EqualValues(t, expectedClaims, rawLink.Claims) }) t.Run("GetUserLinkByLinkedID", func(t *testing.T) {