Skip to content

feat: add sourcing secondary claims from access_token #16517

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion cli/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,17 @@ func createOIDCConfig(ctx context.Context, logger slog.Logger, vals *codersdk.De
groupAllowList[group] = true
}

secondaryClaimsSrc := coderd.MergedClaimsSourceUserInfo
if !vals.OIDC.IgnoreUserInfo && vals.OIDC.UserInfoFromAccessToken {
return nil, xerrors.Errorf("to use 'oidc-access-token-claims', 'oidc-ignore-userinfo' must be set to 'false'")
}
if vals.OIDC.IgnoreUserInfo {
secondaryClaimsSrc = coderd.MergedClaimsSourceNone
}
if vals.OIDC.UserInfoFromAccessToken {
secondaryClaimsSrc = coderd.MergedClaimsSourceAccessToken
}

return &coderd.OIDCConfig{
OAuth2Config: useCfg,
Provider: oidcProvider,
Expand All @@ -187,7 +198,7 @@ func createOIDCConfig(ctx context.Context, logger slog.Logger, vals *codersdk.De
NameField: vals.OIDC.NameField.String(),
EmailField: vals.OIDC.EmailField.String(),
AuthURLParams: vals.OIDC.AuthURLParams.Value,
IgnoreUserInfo: vals.OIDC.IgnoreUserInfo.Value(),
SecondaryClaims: secondaryClaimsSrc,
SignInText: vals.OIDC.SignInText.String(),
SignupsDisabledText: vals.OIDC.SignupsDisabledText.String(),
IconURL: vals.OIDC.IconURL.String(),
Expand Down
6 changes: 6 additions & 0 deletions cli/testdata/server-config.yaml.golden
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,12 @@ oidc:
# Ignore the userinfo endpoint and only use the ID token for user information.
# (default: false, type: bool)
ignoreUserInfo: false
# Source supplemental user claims from the 'access_token'. This assumes the token
# is a jwt signed by the same issuer as the id_token. Using this requires setting
# 'oidc-ignore-userinfo' to true. This setting is not compliant with the OIDC
# specification and is not recommended. Use at your own risk.
# (default: false, type: bool)
accessTokenClaims: false
# This field must be set if using the organization sync feature. Set to the claim
# to be used for organizations.
# (default: <unset>, type: string)
Expand Down
5 changes: 5 additions & 0 deletions coderd/apidoc/docs.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions coderd/apidoc/swagger.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 24 additions & 7 deletions coderd/coderdtest/oidctest/idp.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ type FakeIDP struct {
// "Authorized Redirect URLs". This can be used to emulate that.
hookValidRedirectURL func(redirectURL string) error
hookUserInfo func(email string) (jwt.MapClaims, error)
hookAccessTokenJWT func(email string, exp time.Time) jwt.MapClaims
// defaultIDClaims is if a new client connects and we didn't preset
// some claims.
defaultIDClaims jwt.MapClaims
Expand Down Expand Up @@ -154,6 +155,12 @@ func WithMiddlewares(mws ...func(http.Handler) http.Handler) func(*FakeIDP) {
}
}

func WithAccessTokenJWTHook(hook func(email string, exp time.Time) jwt.MapClaims) func(*FakeIDP) {
return func(f *FakeIDP) {
f.hookAccessTokenJWT = hook
}
}

func WithHookWellKnown(hook func(r *http.Request, j *ProviderJSON) error) func(*FakeIDP) {
return func(f *FakeIDP) {
f.hookWellKnown = hook
Expand Down Expand Up @@ -316,8 +323,7 @@ const (
func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
t.Helper()

block, _ := pem.Decode([]byte(testRSAPrivateKey))
pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
pkey, err := FakeIDPKey()
require.NoError(t, err)

idp := &FakeIDP{
Expand Down Expand Up @@ -676,8 +682,13 @@ func (f *FakeIDP) newCode(state string) string {

// newToken enforces the access token exchanged is actually a valid access token
// created by the IDP.
func (f *FakeIDP) newToken(email string, expires time.Time) string {
func (f *FakeIDP) newToken(t testing.TB, email string, expires time.Time) string {
accessToken := uuid.NewString()
if f.hookAccessTokenJWT != nil {
claims := f.hookAccessTokenJWT(email, expires)
accessToken = f.encodeClaims(t, claims)
}

f.accessTokens.Store(accessToken, token{
issued: time.Now(),
email: email,
Expand Down Expand Up @@ -963,7 +974,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
email := getEmail(claims)
refreshToken := f.newRefreshTokens(email)
token := map[string]interface{}{
"access_token": f.newToken(email, exp),
"access_token": f.newToken(t, email, exp),
"refresh_token": refreshToken,
"token_type": "Bearer",
"expires_in": int64((f.defaultExpire).Seconds()),
Expand Down Expand Up @@ -1465,9 +1476,10 @@ func (f *FakeIDP) internalOIDCConfig(ctx context.Context, t testing.TB, scopes [
Verifier: oidc.NewVerifier(f.provider.Issuer, &oidc.StaticKeySet{
PublicKeys: []crypto.PublicKey{f.key.Public()},
}, verifierConfig),
UsernameField: "preferred_username",
EmailField: "email",
AuthURLParams: map[string]string{"access_type": "offline"},
UsernameField: "preferred_username",
EmailField: "email",
AuthURLParams: map[string]string{"access_type": "offline"},
SecondaryClaims: coderd.MergedClaimsSourceUserInfo,
}

for _, opt := range opts {
Expand Down Expand Up @@ -1552,3 +1564,8 @@ d8h4Ht09E+f3nhTEc87mODkl7WJZpHL6V2sORfeq/eIkds+H6CJ4hy5w/bSw8tjf
sz9Di8sGIaUbLZI2rd0CQQCzlVwEtRtoNCyMJTTrkgUuNufLP19RZ5FpyXxBO5/u
QastnN77KfUwdj3SJt44U/uh1jAIv4oSLBr8HYUkbnI8
-----END RSA PRIVATE KEY-----`

func FakeIDPKey() (*rsa.PrivateKey, error) {
block, _ := pem.Decode([]byte(testRSAPrivateKey))
return x509.ParsePKCS1PrivateKey(block.Bytes)
}
151 changes: 106 additions & 45 deletions coderd/userauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ import (
"github.com/coder/coder/v2/cryptorand"
)

type MergedClaimsSource string

var (
MergedClaimsSourceNone MergedClaimsSource = "none"
MergedClaimsSourceUserInfo MergedClaimsSource = "user_info"
MergedClaimsSourceAccessToken MergedClaimsSource = "access_token"
)

const (
userAuthLoggerName = "userauth"
OAuthConvertCookieValue = "coder_oauth_convert_jwt"
Expand Down Expand Up @@ -1042,11 +1050,13 @@ type OIDCConfig struct {
// AuthURLParams are additional parameters to be passed to the OIDC provider
// when requesting an access token.
AuthURLParams map[string]string
// IgnoreUserInfo causes Coder to only use claims from the ID token to
// process OIDC logins. This is useful if the OIDC provider does not
// support the userinfo endpoint, or if the userinfo endpoint causes
// undesirable behavior.
IgnoreUserInfo bool
// SecondaryClaims indicates where to source additional claim information from.
// The standard is either 'MergedClaimsSourceNone' or 'MergedClaimsSourceUserInfo'.
//
// The OIDC compliant way is to use the userinfo endpoint. This option
// is useful when the userinfo endpoint does not exist or causes undesirable
// behavior.
SecondaryClaims MergedClaimsSource
// SignInText is the text to display on the OIDC login button
SignInText string
// IconURL points to the URL of an icon to display on the OIDC login button
Expand Down Expand Up @@ -1142,50 +1152,39 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
// Some providers (e.g. ADFS) do not support custom OIDC claims in the
// UserInfo endpoint, so we allow users to disable it and only rely on the
// ID token.
userInfoClaims := make(map[string]interface{})
//
// If user info is skipped, the idtokenClaims are the claims.
mergedClaims := idtokenClaims
if !api.OIDCConfig.IgnoreUserInfo {
userInfo, err := api.OIDCConfig.Provider.UserInfo(ctx, oauth2.StaticTokenSource(state.Token))
if err == nil {
err = userInfo.Claims(&userInfoClaims)
if err != nil {
logger.Error(ctx, "oauth2: unable to unmarshal user info claims", slog.Error(err))
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to unmarshal user info claims.",
Detail: err.Error(),
})
return
}
logger.Debug(ctx, "got oidc claims",
slog.F("source", "userinfo"),
slog.F("claim_fields", claimFields(userInfoClaims)),
slog.F("blank", blankFields(userInfoClaims)),
)

// Merge the claims from the ID token and the UserInfo endpoint.
// Information from UserInfo takes precedence.
mergedClaims = mergeClaims(idtokenClaims, userInfoClaims)
supplementaryClaims := make(map[string]interface{})
switch api.OIDCConfig.SecondaryClaims {
case MergedClaimsSourceUserInfo:
supplementaryClaims, ok = api.userInfoClaims(ctx, rw, state, logger)
if !ok {
return
}

// Log all of the field names after merging.
logger.Debug(ctx, "got oidc claims",
slog.F("source", "merged"),
slog.F("claim_fields", claimFields(mergedClaims)),
slog.F("blank", blankFields(mergedClaims)),
)
} else if !strings.Contains(err.Error(), "user info endpoint is not supported by this provider") {
logger.Error(ctx, "oauth2: unable to obtain user information claims", slog.Error(err))
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to obtain user information claims.",
Detail: "The attempt to fetch claims via the UserInfo endpoint failed: " + err.Error(),
})
// The precedence ordering is userInfoClaims > idTokenClaims.
// Note: Unsure why exactly this is the case. idTokenClaims feels more
// important?
mergedClaims = mergeClaims(idtokenClaims, supplementaryClaims)
case MergedClaimsSourceAccessToken:
supplementaryClaims, ok = api.accessTokenClaims(ctx, rw, state, logger)
if !ok {
return
} else {
// The OIDC provider does not support the UserInfo endpoint.
// This is not an error, but we should log it as it may mean
// that some claims are missing.
logger.Warn(ctx, "OIDC provider does not support the user info endpoint, ensure that all required claims are present in the id_token")
}
// idTokenClaims take priority over accessTokenClaims. The order should
// not matter. It is just safer to assume idTokenClaims is the truth,
// and accessTokenClaims are supplemental.
mergedClaims = mergeClaims(supplementaryClaims, idtokenClaims)
case MergedClaimsSourceNone:
// noop, keep the userInfoClaims empty
default:
// This should never happen and is a developer error
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Invalid source for secondary user claims.",
Detail: fmt.Sprintf("invalid source: %q", api.OIDCConfig.SecondaryClaims),
})
return // Invalid MergedClaimsSource
}

usernameRaw, ok := mergedClaims[api.OIDCConfig.UsernameField]
Expand Down Expand Up @@ -1339,7 +1338,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
RoleSync: roleSync,
UserClaims: database.UserLinkClaims{
IDTokenClaims: idtokenClaims,
UserInfoClaims: userInfoClaims,
UserInfoClaims: supplementaryClaims,
MergedClaims: mergedClaims,
},
}).SetInitAuditRequest(func(params *audit.RequestParams) (*audit.Request[database.User], func()) {
Expand Down Expand Up @@ -1373,6 +1372,68 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect)
}

func (api *API) accessTokenClaims(ctx context.Context, rw http.ResponseWriter, state httpmw.OAuth2State, logger slog.Logger) (accessTokenClaims map[string]interface{}, ok bool) {
// Assume the access token is a jwt, and signed by the provider.
accessToken, err := api.OIDCConfig.Verifier.Verify(ctx, state.Token.AccessToken)
if err != nil {
logger.Error(ctx, "oauth2: unable to verify access token as secondary claims source", slog.Error(err))
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to verify access token.",
Detail: fmt.Sprintf("sourcing secondary claims from access token: %s", err.Error()),
})
return nil, false
}

rawClaims := make(map[string]any)
err = accessToken.Claims(&rawClaims)
if err != nil {
logger.Error(ctx, "oauth2: unable to unmarshal access token claims", slog.Error(err))
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to unmarshal access token claims.",
Detail: err.Error(),
})
return nil, false
}

return rawClaims, true
}

func (api *API) userInfoClaims(ctx context.Context, rw http.ResponseWriter, state httpmw.OAuth2State, logger slog.Logger) (userInfoClaims map[string]interface{}, ok bool) {
userInfoClaims = make(map[string]interface{})
userInfo, err := api.OIDCConfig.Provider.UserInfo(ctx, oauth2.StaticTokenSource(state.Token))
if err == nil {
err = userInfo.Claims(&userInfoClaims)
if err != nil {
logger.Error(ctx, "oauth2: unable to unmarshal user info claims", slog.Error(err))
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to unmarshal user info claims.",
Detail: err.Error(),
})
return nil, false
}
logger.Debug(ctx, "got oidc claims",
slog.F("source", "userinfo"),
slog.F("claim_fields", claimFields(userInfoClaims)),
slog.F("blank", blankFields(userInfoClaims)),
)
} else if !strings.Contains(err.Error(), "user info endpoint is not supported by this provider") {
logger.Error(ctx, "oauth2: unable to obtain user information claims", slog.Error(err))
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to obtain user information claims.",
Detail: "The attempt to fetch claims via the UserInfo endpoint failed: " + err.Error(),
})
return nil, false
} else {
// The OIDC provider does not support the UserInfo endpoint.
// This is not an error, but we should log it as it may mean
// that some claims are missing.
logger.Warn(ctx, "OIDC provider does not support the user info endpoint, ensure that all required claims are present in the id_token",
slog.Error(err),
)
}
return userInfoClaims, true
}

// claimFields returns the sorted list of fields in the claims map.
func claimFields(claims map[string]interface{}) []string {
fields := []string{}
Expand Down
Loading
Loading