Skip to content

chore: implement user link claims as a typed golang object #15502

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions coderd/coderdtest/oidctest/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package oidctest
import (
"context"
"database/sql"
"encoding/json"
"net/http"
"net/url"
"testing"
Expand Down Expand Up @@ -89,7 +88,7 @@ func (*LoginHelper) ExpireOauthToken(t *testing.T, db database.Store, user *code
OAuthExpiry: time.Now().Add(time.Hour * -1),
UserID: link.UserID,
LoginType: link.LoginType,
DebugContext: json.RawMessage("{}"),
Claims: database.UserLinkClaims{},
})
require.NoError(t, err, "expire user link")

Expand Down
2 changes: 1 addition & 1 deletion coderd/database/dbauthz/dbauthz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1281,7 +1281,7 @@ func (s *MethodTestSuite) TestUser() {
OAuthExpiry: link.OAuthExpiry,
UserID: link.UserID,
LoginType: link.LoginType,
DebugContext: json.RawMessage("{}"),
Claims: database.UserLinkClaims{},
}).Asserts(rbac.ResourceUserObject(link.UserID), policy.ActionUpdatePersonal).Returns(link)
}))
s.Run("UpdateUserRoles", s.Subtest(func(db database.Store, check *expects) {
Expand Down
2 changes: 1 addition & 1 deletion coderd/database/dbgen/dbgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ func UserLink(t testing.TB, db database.Store, orig database.UserLink) database.
OAuthRefreshToken: takeFirst(orig.OAuthRefreshToken, uuid.NewString()),
OAuthRefreshTokenKeyID: takeFirst(orig.OAuthRefreshTokenKeyID, sql.NullString{}),
OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)),
DebugContext: takeFirstSlice(orig.DebugContext, json.RawMessage("{}")),
Claims: orig.Claims,
})

require.NoError(t, err, "insert link")
Expand Down
4 changes: 2 additions & 2 deletions coderd/database/dbmem/dbmem.go
Original file line number Diff line number Diff line change
Expand Up @@ -7857,7 +7857,7 @@ func (q *FakeQuerier) InsertUserLink(_ context.Context, args database.InsertUser
OAuthRefreshToken: args.OAuthRefreshToken,
OAuthRefreshTokenKeyID: args.OAuthRefreshTokenKeyID,
OAuthExpiry: args.OAuthExpiry,
DebugContext: args.DebugContext,
Claims: args.Claims,
}

q.userLinks = append(q.userLinks, link)
Expand Down Expand Up @@ -9318,7 +9318,7 @@ func (q *FakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUs
link.OAuthRefreshToken = params.OAuthRefreshToken
link.OAuthRefreshTokenKeyID = params.OAuthRefreshTokenKeyID
link.OAuthExpiry = params.OAuthExpiry
link.DebugContext = params.DebugContext
link.Claims = params.Claims

q.userLinks[i] = link
return link, nil
Expand Down
4 changes: 2 additions & 2 deletions coderd/database/dump.sql

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
ALTER TABLE user_links RENAME COLUMN claims TO debug_context;

COMMENT ON COLUMN user_links.debug_context IS 'Debug information includes information like id_token and userinfo claims.';
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
ALTER TABLE user_links RENAME COLUMN debug_context TO claims;

COMMENT ON COLUMN user_links.claims IS 'Claims from the IDP for the linked user. Includes both id_token and userinfo claims. ';
4 changes: 2 additions & 2 deletions coderd/database/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

66 changes: 33 additions & 33 deletions coderd/database/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions coderd/database/queries/user_links.sql
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ INSERT INTO
oauth_refresh_token,
oauth_refresh_token_key_id,
oauth_expiry,
debug_context
claims
)
VALUES
( $1, $2, $3, $4, $5, $6, $7, $8, $9 ) RETURNING *;
Expand All @@ -54,6 +54,6 @@ SET
oauth_refresh_token = $3,
oauth_refresh_token_key_id = $4,
oauth_expiry = $5,
debug_context = $6
claims = $6
WHERE
user_id = $7 AND login_type = $8 RETURNING *;
3 changes: 3 additions & 0 deletions coderd/database/sqlc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ sql:
- column: "provisioner_job_stats.*_secs"
go_type:
type: "float64"
- column: "user_links.claims"
go_type:
type: "UserLinkClaims"
rename:
group_member: GroupMemberTable
group_members_expanded: GroupMember
Expand Down
22 changes: 22 additions & 0 deletions coderd/database/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,25 @@ func (p *AgentIDNamePair) Scan(src interface{}) error {
func (p AgentIDNamePair) Value() (driver.Value, error) {
return fmt.Sprintf(`(%s,%s)`, p.ID.String(), p.Name), nil
}

// UserLinkClaims is the returned IDP claims for a given user link.
// These claims are fetched at login time. These are the claims that were
// used for IDP sync.
type UserLinkClaims struct {
IDTokenClaims map[string]interface{} `json:"id_token_claims"`
UserInfoClaims map[string]interface{} `json:"user_info_claims"`
}

func (a *UserLinkClaims) Scan(src interface{}) error {
switch v := src.(type) {
case string:
return json.Unmarshal([]byte(v), &a)
case []byte:
return json.Unmarshal(v, &a)
}
return xerrors.Errorf("unexpected type %T", src)
}

func (a UserLinkClaims) Value() (driver.Value, error) {
return json.Marshal(a)
}
2 changes: 1 addition & 1 deletion coderd/httpmw/apikey.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
OAuthExpiry: link.OAuthExpiry,
// Refresh should keep the same debug context because we use
// the original claims for the group/role sync.
DebugContext: link.DebugContext,
Claims: link.Claims,
})
if err != nil {
return write(http.StatusInternalServerError, codersdk.Response{
Expand Down
2 changes: 1 addition & 1 deletion coderd/provisionerdserver/provisionerdserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2083,7 +2083,7 @@ func obtainOIDCAccessToken(ctx context.Context, db database.Store, oidcConfig pr
OAuthRefreshToken: link.OAuthRefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // set by dbcrypt if required
OAuthExpiry: link.OAuthExpiry,
DebugContext: link.DebugContext,
Claims: link.Claims,
})
if err != nil {
return "", xerrors.Errorf("update user link: %w", err)
Expand Down
18 changes: 7 additions & 11 deletions coderd/userauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package coderd
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -966,7 +965,7 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
Username: username,
AvatarURL: ghUser.GetAvatarURL(),
Name: normName,
DebugContext: OauthDebugContext{},
UserClaims: database.UserLinkClaims{},
GroupSync: idpsync.GroupParams{
SyncEntitled: false,
},
Expand Down Expand Up @@ -1324,7 +1323,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
OrganizationSync: orgSync,
GroupSync: groupSync,
RoleSync: roleSync,
DebugContext: OauthDebugContext{
UserClaims: database.UserLinkClaims{
IDTokenClaims: idtokenClaims,
UserInfoClaims: userInfoClaims,
},
Expand Down Expand Up @@ -1421,7 +1420,9 @@ type oauthLoginParams struct {
GroupSync idpsync.GroupParams
RoleSync idpsync.RoleParams

DebugContext OauthDebugContext
// UserClaims should only be populated for OIDC logins.
// It is used to save the user's claims on login.
UserClaims database.UserLinkClaims

commitLock sync.Mutex
initAuditRequest func(params *audit.RequestParams) *audit.Request[database.User]
Expand Down Expand Up @@ -1591,11 +1592,6 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
dormantConvertAudit.New = user
}

debugContext, err := json.Marshal(params.DebugContext)
if err != nil {
return xerrors.Errorf("marshal debug context: %w", err)
}

if link.UserID == uuid.Nil {
//nolint:gocritic // System needs to insert the user link (linked_id, oauth_token, oauth_expiry).
link, err = tx.InsertUserLink(dbauthz.AsSystemRestricted(ctx), database.InsertUserLinkParams{
Expand All @@ -1607,7 +1603,7 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
OAuthRefreshToken: params.State.Token.RefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // set by dbcrypt if required
OAuthExpiry: params.State.Token.Expiry,
DebugContext: debugContext,
Claims: params.UserClaims,
})
if err != nil {
return xerrors.Errorf("insert user link: %w", err)
Expand All @@ -1624,7 +1620,7 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
OAuthRefreshToken: params.State.Token.RefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // set by dbcrypt if required
OAuthExpiry: params.State.Token.Expiry,
DebugContext: debugContext,
Claims: params.UserClaims,
})
if err != nil {
return xerrors.Errorf("update user link: %w", err)
Expand Down
Loading
Loading