Skip to content

Commit 2ee347f

Browse files
committed
chore: unit tests for OIDCClaimFieldValues and fixup sql arg types
1 parent 1a5dc30 commit 2ee347f

File tree

7 files changed

+113
-30
lines changed

7 files changed

+113
-30
lines changed

coderd/database/dbauthz/dbauthz.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3283,8 +3283,16 @@ func (q *querier) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID
32833283
return q.db.ListWorkspaceAgentPortShares(ctx, workspaceID)
32843284
}
32853285

3286-
func (q *querier) OIDCClaimFieldValues(ctx context.Context, organizationID uuid.UUID) ([]string, error) {
3287-
panic("not implemented")
3286+
func (q *querier) OIDCClaimFieldValues(ctx context.Context, args database.OIDCClaimFieldValuesParams) ([]string, error) {
3287+
resource := rbac.ResourceIdpsyncSettings
3288+
if args.OrganizationID != uuid.Nil {
3289+
resource = resource.InOrg(args.OrganizationID)
3290+
}
3291+
if err := q.authorizeContext(ctx, policy.ActionRead, resource); err != nil {
3292+
return nil, err
3293+
}
3294+
return q.db.OIDCClaimFieldValues(ctx, args)
3295+
32883296
}
32893297

32903298
func (q *querier) OIDCClaimFields(ctx context.Context, organizationID uuid.UUID) ([]string, error) {

coderd/database/dbmem/dbmem.go

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8409,8 +8409,49 @@ func (q *FakeQuerier) ListWorkspaceAgentPortShares(_ context.Context, workspaceI
84098409
return shares, nil
84108410
}
84118411

8412-
func (q *FakeQuerier) OIDCClaimFieldValues(ctx context.Context, organizationID uuid.UUID) ([]string, error) {
8413-
panic("not implemented")
8412+
func (q *FakeQuerier) OIDCClaimFieldValues(ctx context.Context, args database.OIDCClaimFieldValuesParams) ([]string, error) {
8413+
orgMembers := q.getOrganizationMemberNoLock(args.OrganizationID)
8414+
8415+
var values []string
8416+
for _, link := range q.userLinks {
8417+
if args.OrganizationID != uuid.Nil {
8418+
inOrg := slices.ContainsFunc(orgMembers, func(organizationMember database.OrganizationMember) bool {
8419+
return organizationMember.UserID == link.UserID
8420+
})
8421+
if !inOrg {
8422+
continue
8423+
}
8424+
}
8425+
8426+
if link.LoginType != database.LoginTypeOIDC {
8427+
continue
8428+
}
8429+
8430+
if len(link.Claims.MergedClaims) == 0 {
8431+
continue
8432+
}
8433+
8434+
value, ok := link.Claims.MergedClaims[args.ClaimField]
8435+
if !ok {
8436+
continue
8437+
}
8438+
switch value.(type) {
8439+
case string:
8440+
values = append(values, value.(string))
8441+
case []string:
8442+
values = append(values, value.([]string)...)
8443+
case []any:
8444+
for _, v := range value.([]any) {
8445+
if sv, ok := v.(string); ok {
8446+
values = append(values, sv)
8447+
}
8448+
}
8449+
default:
8450+
continue
8451+
}
8452+
}
8453+
8454+
return slice.Unique(values), nil
84148455
}
84158456

84168457
func (q *FakeQuerier) OIDCClaimFields(_ context.Context, organizationID uuid.UUID) ([]string, error) {
@@ -8431,10 +8472,7 @@ func (q *FakeQuerier) OIDCClaimFields(_ context.Context, organizationID uuid.UUI
84318472
continue
84328473
}
84338474

8434-
for k := range link.Claims.IDTokenClaims {
8435-
fields = append(fields, k)
8436-
}
8437-
for k := range link.Claims.UserInfoClaims {
8475+
for k := range link.Claims.MergedClaims {
84388476
fields = append(fields, k)
84398477
}
84408478
}

coderd/database/dbmetrics/querymetrics.go

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/oidcclaims_test.go

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ func TestOIDCClaims(t *testing.T) {
3232
db, _ := dbtestutil.NewDB(t)
3333
g := userGenerator{t: t, db: db}
3434

35+
const claimField = "claim-list"
36+
3537
// https://en.wikipedia.org/wiki/Alice_and_Bob#Cast_of_characters
3638
alice := g.withLink(database.LoginTypeOIDC, toJSON(extraKeys{
3739
UserLinkClaims: database.UserLinkClaims{
@@ -43,6 +45,9 @@ func TestOIDCClaims(t *testing.T) {
4345
MergedClaims: map[string]interface{}{
4446
"sub": "alice",
4547
"alice-id": "from-bob",
48+
claimField: []string{
49+
"one", "two", "three",
50+
},
4651
},
4752
},
4853
// Always should be a no-op
@@ -79,6 +84,9 @@ func TestOIDCClaims(t *testing.T) {
7984
"foo": "bar",
8085
},
8186
"nil": nil,
87+
claimField: []any{
88+
"three", 5, []string{"test"}, "four",
89+
},
8290
},
8391
}))
8492
charlie := g.withLink(database.LoginTypeOIDC, toJSON(database.UserLinkClaims{
@@ -94,6 +102,7 @@ func TestOIDCClaims(t *testing.T) {
94102
"sub": "charlie",
95103
"charlie-id": "charlie",
96104
"charlie-info": "charlie",
105+
claimField: "charlie",
97106
},
98107
}))
99108

@@ -113,8 +122,9 @@ func TestOIDCClaims(t *testing.T) {
113122
"do-not": "look",
114123
},
115124
MergedClaims: map[string]interface{}{
116-
"not": "allowed",
117-
"do-not": "look",
125+
"not": "allowed",
126+
"do-not": "look",
127+
claimField: 42,
118128
},
119129
})), // github should be omitted
120130

@@ -140,12 +150,32 @@ func TestOIDCClaims(t *testing.T) {
140150

141151
// Verify the OIDC claim fields
142152
always := []string{"array", "map", "nil", "number"}
143-
expectA := append([]string{"sub", "alice-id", "bob-id", "bob-info"}, always...)
144-
expectB := append([]string{"sub", "bob-id", "bob-info", "charlie-id", "charlie-info"}, always...)
153+
expectA := append([]string{"sub", "alice-id", "bob-id", "bob-info", "claim-list"}, always...)
154+
expectB := append([]string{"sub", "bob-id", "bob-info", "charlie-id", "charlie-info", "claim-list"}, always...)
145155
requireClaims(t, db, orgA.Org.ID, expectA)
146156
requireClaims(t, db, orgB.Org.ID, expectB)
147157
requireClaims(t, db, orgC.Org.ID, []string{})
148158
requireClaims(t, db, uuid.Nil, slice.Unique(append(expectA, expectB...)))
159+
160+
// Verify the claim field values
161+
expectAValues := []string{"one", "two", "three", "four"}
162+
expectBValues := []string{"three", "four", "charlie"}
163+
requireClaimValues(t, db, orgA.Org.ID, claimField, expectAValues)
164+
requireClaimValues(t, db, orgB.Org.ID, claimField, expectBValues)
165+
requireClaimValues(t, db, orgC.Org.ID, claimField, []string{})
166+
}
167+
168+
func requireClaimValues(t *testing.T, db database.Store, orgID uuid.UUID, field string, want []string) {
169+
t.Helper()
170+
171+
ctx := testutil.Context(t, testutil.WaitMedium)
172+
got, err := db.OIDCClaimFieldValues(ctx, database.OIDCClaimFieldValuesParams{
173+
ClaimField: field,
174+
OrganizationID: orgID,
175+
})
176+
require.NoError(t, err)
177+
178+
require.ElementsMatch(t, want, got)
149179
}
150180

151181
func requireClaims(t *testing.T, db database.Store, orgID uuid.UUID, want []string) {

coderd/database/querier.go

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries.sql.go

Lines changed: 15 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries/user_links.sql

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -84,28 +84,29 @@ SELECT
8484
DISTINCT jsonb_array_elements_text(CASE
8585
-- When the type is an array, filter out any non-string elements.
8686
-- This is to keep the return type consistent.
87-
WHEN jsonb_typeof(claims->'merged_claims'->'groups') = 'array' THEN
87+
WHEN jsonb_typeof(claims->'merged_claims'->sqlc.arg('claim_field')::text) = 'array' THEN
8888
(
8989
SELECT
9090
jsonb_agg(element)
9191
FROM
92-
jsonb_array_elements(claims->'merged_claims'->@claim_field) AS element
92+
jsonb_array_elements(claims->'merged_claims'->sqlc.arg('claim_field')::text) AS element
9393
WHERE
9494
-- Filtering out non-string elements
9595
jsonb_typeof(element) = 'string'
9696
)
9797
-- Some IDPs return a single string instead of an array of strings.
98-
WHEN jsonb_typeof(claims->'merged_claims'->'groups') = 'string' THEN
99-
jsonb_build_array(claims->'merged_claims'->@claim_field)
98+
WHEN jsonb_typeof(claims->'merged_claims'->sqlc.arg('claim_field')::text) = 'string' THEN
99+
jsonb_build_array(claims->'merged_claims'->sqlc.arg('claim_field')::text)
100100
END)
101101
FROM
102102
user_links
103103
WHERE
104104
-- IDP sync only supports string and array (of string) types
105-
jsonb_typeof(claims->'merged_claims'->@claim_field) = ANY(ARRAY['string', 'array'])
105+
jsonb_typeof(claims->'merged_claims'->sqlc.arg('claim_field')::text) = ANY(ARRAY['string', 'array'])
106106
AND login_type = 'oidc'
107-
AND CASE WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
108-
user_links.user_id = ANY(SELECT organization_members.user_id FROM organization_members WHERE organization_id = @organization_id)
107+
AND CASE
108+
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
109+
user_links.user_id = ANY(SELECT organization_members.user_id FROM organization_members WHERE organization_id = @organization_id)
109110
ELSE true
110111
END
111112
;

0 commit comments

Comments
 (0)