Skip to content

chore: implement user link claims as a typed golang object #15502

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
chore: fixup old references
  • Loading branch information
Emyrk committed Nov 14, 2024
commit 918b1f52a8d3987e8d1f562e6611c67f43c45da6
3 changes: 1 addition & 2 deletions coderd/coderdtest/oidctest/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package oidctest
import (
"context"
"database/sql"
"encoding/json"
"net/http"
"net/url"
"testing"
Expand Down Expand Up @@ -89,7 +88,7 @@ func (*LoginHelper) ExpireOauthToken(t *testing.T, db database.Store, user *code
OAuthExpiry: time.Now().Add(time.Hour * -1),
UserID: link.UserID,
LoginType: link.LoginType,
DebugContext: json.RawMessage("{}"),
Claims: database.UserLinkClaims{},
})
require.NoError(t, err, "expire user link")

Expand Down
2 changes: 1 addition & 1 deletion coderd/database/dbgen/dbgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ func UserLink(t testing.TB, db database.Store, orig database.UserLink) database.
OAuthRefreshToken: takeFirst(orig.OAuthRefreshToken, uuid.NewString()),
OAuthRefreshTokenKeyID: takeFirst(orig.OAuthRefreshTokenKeyID, sql.NullString{}),
OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)),
DebugContext: takeFirstSlice(orig.DebugContext, json.RawMessage("{}")),
Claims: orig.Claims,
})

require.NoError(t, err, "insert link")
Expand Down
4 changes: 2 additions & 2 deletions coderd/database/dbmem/dbmem.go
Original file line number Diff line number Diff line change
Expand Up @@ -7857,7 +7857,7 @@ func (q *FakeQuerier) InsertUserLink(_ context.Context, args database.InsertUser
OAuthRefreshToken: args.OAuthRefreshToken,
OAuthRefreshTokenKeyID: args.OAuthRefreshTokenKeyID,
OAuthExpiry: args.OAuthExpiry,
DebugContext: args.DebugContext,
Claims: args.Claims,
}

q.userLinks = append(q.userLinks, link)
Expand Down Expand Up @@ -9318,7 +9318,7 @@ func (q *FakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUs
link.OAuthRefreshToken = params.OAuthRefreshToken
link.OAuthRefreshTokenKeyID = params.OAuthRefreshTokenKeyID
link.OAuthExpiry = params.OAuthExpiry
link.DebugContext = params.DebugContext
link.Claims = params.Claims

q.userLinks[i] = link
return link, nil
Expand Down
2 changes: 1 addition & 1 deletion coderd/httpmw/apikey.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
OAuthExpiry: link.OAuthExpiry,
// Refresh should keep the same debug context because we use
// the original claims for the group/role sync.
DebugContext: link.DebugContext,
Claims: link.Claims,
})
if err != nil {
return write(http.StatusInternalServerError, codersdk.Response{
Expand Down
2 changes: 1 addition & 1 deletion coderd/provisionerdserver/provisionerdserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2083,7 +2083,7 @@ func obtainOIDCAccessToken(ctx context.Context, db database.Store, oidcConfig pr
OAuthRefreshToken: link.OAuthRefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // set by dbcrypt if required
OAuthExpiry: link.OAuthExpiry,
DebugContext: link.DebugContext,
Claims: link.Claims,
})
if err != nil {
return "", xerrors.Errorf("update user link: %w", err)
Expand Down
18 changes: 7 additions & 11 deletions coderd/userauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package coderd
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -966,7 +965,7 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
Username: username,
AvatarURL: ghUser.GetAvatarURL(),
Name: normName,
DebugContext: OauthDebugContext{},
UserClaims: database.UserLinkClaims{},
GroupSync: idpsync.GroupParams{
SyncEntitled: false,
},
Expand Down Expand Up @@ -1324,7 +1323,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
OrganizationSync: orgSync,
GroupSync: groupSync,
RoleSync: roleSync,
DebugContext: OauthDebugContext{
UserClaims: database.UserLinkClaims{
IDTokenClaims: idtokenClaims,
UserInfoClaims: userInfoClaims,
},
Expand Down Expand Up @@ -1421,7 +1420,9 @@ type oauthLoginParams struct {
GroupSync idpsync.GroupParams
RoleSync idpsync.RoleParams

DebugContext OauthDebugContext
// UserClaims should only be populated for OIDC logins.
// It is used to save the user's claims on login.
UserClaims database.UserLinkClaims

commitLock sync.Mutex
initAuditRequest func(params *audit.RequestParams) *audit.Request[database.User]
Expand Down Expand Up @@ -1591,11 +1592,6 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
dormantConvertAudit.New = user
}

debugContext, err := json.Marshal(params.DebugContext)
if err != nil {
return xerrors.Errorf("marshal debug context: %w", err)
}

if link.UserID == uuid.Nil {
//nolint:gocritic // System needs to insert the user link (linked_id, oauth_token, oauth_expiry).
link, err = tx.InsertUserLink(dbauthz.AsSystemRestricted(ctx), database.InsertUserLinkParams{
Expand All @@ -1607,7 +1603,7 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
OAuthRefreshToken: params.State.Token.RefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // set by dbcrypt if required
OAuthExpiry: params.State.Token.Expiry,
DebugContext: debugContext,
Claims: params.UserClaims,
})
if err != nil {
return xerrors.Errorf("insert user link: %w", err)
Expand All @@ -1624,7 +1620,7 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
OAuthRefreshToken: params.State.Token.RefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // set by dbcrypt if required
OAuthExpiry: params.State.Token.Expiry,
DebugContext: debugContext,
Claims: params.UserClaims,
})
if err != nil {
return xerrors.Errorf("update user link: %w", err)
Expand Down
2 changes: 1 addition & 1 deletion coderd/userauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,7 @@ func TestUserOAuth2Github(t *testing.T) {
OAuthAccessToken: "random",
OAuthRefreshToken: "random",
OAuthExpiry: time.Now(),
DebugContext: []byte(`{}`),
Claims: database.UserLinkClaims{},
})
require.ErrorContains(t, err, "Cannot create user_link for deleted user")

Expand Down
3 changes: 1 addition & 2 deletions coderd/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@ func (api *API) userDebugOIDC(rw http.ResponseWriter, r *http.Request) {
return
}

// This will encode properly because it is a json.RawMessage.
httpapi.Write(ctx, rw, http.StatusOK, link.DebugContext)
httpapi.Write(ctx, rw, http.StatusOK, link.Claims)
}

// Returns whether the initial user has been created or not.
Expand Down