Skip to content

Commit 26438aa

Browse files
authored
chore: implement OIDCClaimFieldValues for idp sync mappings auto complete (coder#15576)
When creating IDP sync mappings, these are the values that can be selected from. These are the values that can be mapped from in org/group/role sync.
1 parent 5b7fa78 commit 26438aa

File tree

9 files changed

+222
-9
lines changed

9 files changed

+222
-9
lines changed

coderd/database/dbauthz/dbauthz.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3283,6 +3283,17 @@ func (q *querier) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID
32833283
return q.db.ListWorkspaceAgentPortShares(ctx, workspaceID)
32843284
}
32853285

3286+
func (q *querier) OIDCClaimFieldValues(ctx context.Context, args database.OIDCClaimFieldValuesParams) ([]string, error) {
3287+
resource := rbac.ResourceIdpsyncSettings
3288+
if args.OrganizationID != uuid.Nil {
3289+
resource = resource.InOrg(args.OrganizationID)
3290+
}
3291+
if err := q.authorizeContext(ctx, policy.ActionRead, resource); err != nil {
3292+
return nil, err
3293+
}
3294+
return q.db.OIDCClaimFieldValues(ctx, args)
3295+
}
3296+
32863297
func (q *querier) OIDCClaimFields(ctx context.Context, organizationID uuid.UUID) ([]string, error) {
32873298
resource := rbac.ResourceIdpsyncSettings
32883299
if organizationID != uuid.Nil {

coderd/database/dbauthz/dbauthz_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,19 @@ func (s *MethodTestSuite) TestOrganization() {
633633
id := uuid.New()
634634
check.Args(id).Asserts(rbac.ResourceIdpsyncSettings.InOrg(id), policy.ActionRead).Returns([]string{})
635635
}))
636+
s.Run("Deployment/OIDCClaimFieldValues", s.Subtest(func(db database.Store, check *expects) {
637+
check.Args(database.OIDCClaimFieldValuesParams{
638+
ClaimField: "claim-field",
639+
OrganizationID: uuid.Nil,
640+
}).Asserts(rbac.ResourceIdpsyncSettings, policy.ActionRead).Returns([]string{})
641+
}))
642+
s.Run("Organization/OIDCClaimFieldValues", s.Subtest(func(db database.Store, check *expects) {
643+
id := uuid.New()
644+
check.Args(database.OIDCClaimFieldValuesParams{
645+
ClaimField: "claim-field",
646+
OrganizationID: id,
647+
}).Asserts(rbac.ResourceIdpsyncSettings.InOrg(id), policy.ActionRead).Returns([]string{})
648+
}))
636649
s.Run("ByOrganization/GetGroups", s.Subtest(func(db database.Store, check *expects) {
637650
o := dbgen.Organization(s.T(), db, database.Organization{})
638651
a := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})

coderd/database/dbmem/dbmem.go

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8409,6 +8409,52 @@ func (q *FakeQuerier) ListWorkspaceAgentPortShares(_ context.Context, workspaceI
84098409
return shares, nil
84108410
}
84118411

8412+
// nolint:forcetypeassert
8413+
func (q *FakeQuerier) OIDCClaimFieldValues(_ context.Context, args database.OIDCClaimFieldValuesParams) ([]string, error) {
8414+
orgMembers := q.getOrganizationMemberNoLock(args.OrganizationID)
8415+
8416+
var values []string
8417+
for _, link := range q.userLinks {
8418+
if args.OrganizationID != uuid.Nil {
8419+
inOrg := slices.ContainsFunc(orgMembers, func(organizationMember database.OrganizationMember) bool {
8420+
return organizationMember.UserID == link.UserID
8421+
})
8422+
if !inOrg {
8423+
continue
8424+
}
8425+
}
8426+
8427+
if link.LoginType != database.LoginTypeOIDC {
8428+
continue
8429+
}
8430+
8431+
if len(link.Claims.MergedClaims) == 0 {
8432+
continue
8433+
}
8434+
8435+
value, ok := link.Claims.MergedClaims[args.ClaimField]
8436+
if !ok {
8437+
continue
8438+
}
8439+
switch value := value.(type) {
8440+
case string:
8441+
values = append(values, value)
8442+
case []string:
8443+
values = append(values, value...)
8444+
case []any:
8445+
for _, v := range value {
8446+
if sv, ok := v.(string); ok {
8447+
values = append(values, sv)
8448+
}
8449+
}
8450+
default:
8451+
continue
8452+
}
8453+
}
8454+
8455+
return slice.Unique(values), nil
8456+
}
8457+
84128458
func (q *FakeQuerier) OIDCClaimFields(_ context.Context, organizationID uuid.UUID) ([]string, error) {
84138459
orgMembers := q.getOrganizationMemberNoLock(organizationID)
84148460

@@ -8427,10 +8473,7 @@ func (q *FakeQuerier) OIDCClaimFields(_ context.Context, organizationID uuid.UUI
84278473
continue
84288474
}
84298475

8430-
for k := range link.Claims.IDTokenClaims {
8431-
fields = append(fields, k)
8432-
}
8433-
for k := range link.Claims.UserInfoClaims {
8476+
for k := range link.Claims.MergedClaims {
84348477
fields = append(fields, k)
84358478
}
84368479
}

coderd/database/dbmetrics/querymetrics.go

Lines changed: 7 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dbmock/dbmock.go

Lines changed: 15 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/oidcclaims_test.go

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ func TestOIDCClaims(t *testing.T) {
3232
db, _ := dbtestutil.NewDB(t)
3333
g := userGenerator{t: t, db: db}
3434

35+
const claimField = "claim-list"
36+
3537
// https://en.wikipedia.org/wiki/Alice_and_Bob#Cast_of_characters
3638
alice := g.withLink(database.LoginTypeOIDC, toJSON(extraKeys{
3739
UserLinkClaims: database.UserLinkClaims{
@@ -43,6 +45,9 @@ func TestOIDCClaims(t *testing.T) {
4345
MergedClaims: map[string]interface{}{
4446
"sub": "alice",
4547
"alice-id": "from-bob",
48+
claimField: []string{
49+
"one", "two", "three",
50+
},
4651
},
4752
},
4853
// Always should be a no-op
@@ -79,6 +84,9 @@ func TestOIDCClaims(t *testing.T) {
7984
"foo": "bar",
8085
},
8186
"nil": nil,
87+
claimField: []any{
88+
"three", 5, []string{"test"}, "four",
89+
},
8290
},
8391
}))
8492
charlie := g.withLink(database.LoginTypeOIDC, toJSON(database.UserLinkClaims{
@@ -94,6 +102,7 @@ func TestOIDCClaims(t *testing.T) {
94102
"sub": "charlie",
95103
"charlie-id": "charlie",
96104
"charlie-info": "charlie",
105+
claimField: "charlie",
97106
},
98107
}))
99108

@@ -113,8 +122,9 @@ func TestOIDCClaims(t *testing.T) {
113122
"do-not": "look",
114123
},
115124
MergedClaims: map[string]interface{}{
116-
"not": "allowed",
117-
"do-not": "look",
125+
"not": "allowed",
126+
"do-not": "look",
127+
claimField: 42,
118128
},
119129
})), // github should be omitted
120130

@@ -140,12 +150,32 @@ func TestOIDCClaims(t *testing.T) {
140150

141151
// Verify the OIDC claim fields
142152
always := []string{"array", "map", "nil", "number"}
143-
expectA := append([]string{"sub", "alice-id", "bob-id", "bob-info"}, always...)
144-
expectB := append([]string{"sub", "bob-id", "bob-info", "charlie-id", "charlie-info"}, always...)
153+
expectA := append([]string{"sub", "alice-id", "bob-id", "bob-info", "claim-list"}, always...)
154+
expectB := append([]string{"sub", "bob-id", "bob-info", "charlie-id", "charlie-info", "claim-list"}, always...)
145155
requireClaims(t, db, orgA.Org.ID, expectA)
146156
requireClaims(t, db, orgB.Org.ID, expectB)
147157
requireClaims(t, db, orgC.Org.ID, []string{})
148158
requireClaims(t, db, uuid.Nil, slice.Unique(append(expectA, expectB...)))
159+
160+
// Verify the claim field values
161+
expectAValues := []string{"one", "two", "three", "four"}
162+
expectBValues := []string{"three", "four", "charlie"}
163+
requireClaimValues(t, db, orgA.Org.ID, claimField, expectAValues)
164+
requireClaimValues(t, db, orgB.Org.ID, claimField, expectBValues)
165+
requireClaimValues(t, db, orgC.Org.ID, claimField, []string{})
166+
}
167+
168+
func requireClaimValues(t *testing.T, db database.Store, orgID uuid.UUID, field string, want []string) {
169+
t.Helper()
170+
171+
ctx := testutil.Context(t, testutil.WaitMedium)
172+
got, err := db.OIDCClaimFieldValues(ctx, database.OIDCClaimFieldValuesParams{
173+
ClaimField: field,
174+
OrganizationID: orgID,
175+
})
176+
require.NoError(t, err)
177+
178+
require.ElementsMatch(t, want, got)
149179
}
150180

151181
func requireClaims(t *testing.T, db database.Store, orgID uuid.UUID, want []string) {

coderd/database/querier.go

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries.sql.go

Lines changed: 61 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries/user_links.sql

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ SET
5858
WHERE
5959
user_id = $7 AND login_type = $8 RETURNING *;
6060

61-
6261
-- name: OIDCClaimFields :many
6362
-- OIDCClaimFields returns a list of distinct keys in the the merged_claims fields.
6463
-- This query is used to generate the list of available sync fields for idp sync settings.
@@ -78,3 +77,36 @@ WHERE
7877
ELSE true
7978
END
8079
;
80+
81+
-- name: OIDCClaimFieldValues :many
82+
SELECT
83+
-- DISTINCT to remove duplicates
84+
DISTINCT jsonb_array_elements_text(CASE
85+
-- When the type is an array, filter out any non-string elements.
86+
-- This is to keep the return type consistent.
87+
WHEN jsonb_typeof(claims->'merged_claims'->sqlc.arg('claim_field')::text) = 'array' THEN
88+
(
89+
SELECT
90+
jsonb_agg(element)
91+
FROM
92+
jsonb_array_elements(claims->'merged_claims'->sqlc.arg('claim_field')::text) AS element
93+
WHERE
94+
-- Filtering out non-string elements
95+
jsonb_typeof(element) = 'string'
96+
)
97+
-- Some IDPs return a single string instead of an array of strings.
98+
WHEN jsonb_typeof(claims->'merged_claims'->sqlc.arg('claim_field')::text) = 'string' THEN
99+
jsonb_build_array(claims->'merged_claims'->sqlc.arg('claim_field')::text)
100+
END)
101+
FROM
102+
user_links
103+
WHERE
104+
-- IDP sync only supports string and array (of string) types
105+
jsonb_typeof(claims->'merged_claims'->sqlc.arg('claim_field')::text) = ANY(ARRAY['string', 'array'])
106+
AND login_type = 'oidc'
107+
AND CASE
108+
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
109+
user_links.user_id = ANY(SELECT organization_members.user_id FROM organization_members WHERE organization_id = @organization_id)
110+
ELSE true
111+
END
112+
;

0 commit comments

Comments
 (0)