Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
chore: unit tests for OIDCClaimFieldValues and fixup sql arg types
  • Loading branch information
Emyrk committed Nov 18, 2024
commit 2ee347f9452522f8153015c1b41c979f4a89e7b4
12 changes: 10 additions & 2 deletions coderd/database/dbauthz/dbauthz.go
Original file line number Diff line number Diff line change
Expand Up @@ -3283,8 +3283,16 @@ func (q *querier) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID
return q.db.ListWorkspaceAgentPortShares(ctx, workspaceID)
}

func (q *querier) OIDCClaimFieldValues(ctx context.Context, organizationID uuid.UUID) ([]string, error) {
panic("not implemented")
func (q *querier) OIDCClaimFieldValues(ctx context.Context, args database.OIDCClaimFieldValuesParams) ([]string, error) {
resource := rbac.ResourceIdpsyncSettings
if args.OrganizationID != uuid.Nil {
resource = resource.InOrg(args.OrganizationID)
}
if err := q.authorizeContext(ctx, policy.ActionRead, resource); err != nil {
return nil, err
}
return q.db.OIDCClaimFieldValues(ctx, args)

}

func (q *querier) OIDCClaimFields(ctx context.Context, organizationID uuid.UUID) ([]string, error) {
Expand Down
50 changes: 44 additions & 6 deletions coderd/database/dbmem/dbmem.go
Original file line number Diff line number Diff line change
Expand Up @@ -8409,8 +8409,49 @@ func (q *FakeQuerier) ListWorkspaceAgentPortShares(_ context.Context, workspaceI
return shares, nil
}

func (q *FakeQuerier) OIDCClaimFieldValues(ctx context.Context, organizationID uuid.UUID) ([]string, error) {
panic("not implemented")
func (q *FakeQuerier) OIDCClaimFieldValues(ctx context.Context, args database.OIDCClaimFieldValuesParams) ([]string, error) {
orgMembers := q.getOrganizationMemberNoLock(args.OrganizationID)

var values []string
for _, link := range q.userLinks {
if args.OrganizationID != uuid.Nil {
inOrg := slices.ContainsFunc(orgMembers, func(organizationMember database.OrganizationMember) bool {
return organizationMember.UserID == link.UserID
})
if !inOrg {
continue
}
}

if link.LoginType != database.LoginTypeOIDC {
continue
}

if len(link.Claims.MergedClaims) == 0 {
continue
}

value, ok := link.Claims.MergedClaims[args.ClaimField]
if !ok {
continue
}
switch value.(type) {
case string:
values = append(values, value.(string))
case []string:
values = append(values, value.([]string)...)
case []any:
for _, v := range value.([]any) {
if sv, ok := v.(string); ok {
values = append(values, sv)
}
}
default:
continue
}
}

return slice.Unique(values), nil
}

func (q *FakeQuerier) OIDCClaimFields(_ context.Context, organizationID uuid.UUID) ([]string, error) {
Expand All @@ -8431,10 +8472,7 @@ func (q *FakeQuerier) OIDCClaimFields(_ context.Context, organizationID uuid.UUI
continue
}

for k := range link.Claims.IDTokenClaims {
fields = append(fields, k)
}
for k := range link.Claims.UserInfoClaims {
for k := range link.Claims.MergedClaims {
fields = append(fields, k)
}
}
Expand Down
2 changes: 1 addition & 1 deletion coderd/database/dbmetrics/querymetrics.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

38 changes: 34 additions & 4 deletions coderd/database/oidcclaims_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ func TestOIDCClaims(t *testing.T) {
db, _ := dbtestutil.NewDB(t)
g := userGenerator{t: t, db: db}

const claimField = "claim-list"

// https://en.wikipedia.org/wiki/Alice_and_Bob#Cast_of_characters
alice := g.withLink(database.LoginTypeOIDC, toJSON(extraKeys{
UserLinkClaims: database.UserLinkClaims{
Expand All @@ -43,6 +45,9 @@ func TestOIDCClaims(t *testing.T) {
MergedClaims: map[string]interface{}{
"sub": "alice",
"alice-id": "from-bob",
claimField: []string{
"one", "two", "three",
},
},
},
// Always should be a no-op
Expand Down Expand Up @@ -79,6 +84,9 @@ func TestOIDCClaims(t *testing.T) {
"foo": "bar",
},
"nil": nil,
claimField: []any{
"three", 5, []string{"test"}, "four",
},
},
}))
charlie := g.withLink(database.LoginTypeOIDC, toJSON(database.UserLinkClaims{
Expand All @@ -94,6 +102,7 @@ func TestOIDCClaims(t *testing.T) {
"sub": "charlie",
"charlie-id": "charlie",
"charlie-info": "charlie",
claimField: "charlie",
},
}))

Expand All @@ -113,8 +122,9 @@ func TestOIDCClaims(t *testing.T) {
"do-not": "look",
},
MergedClaims: map[string]interface{}{
"not": "allowed",
"do-not": "look",
"not": "allowed",
"do-not": "look",
claimField: 42,
},
})), // github should be omitted

Expand All @@ -140,12 +150,32 @@ func TestOIDCClaims(t *testing.T) {

// Verify the OIDC claim fields
always := []string{"array", "map", "nil", "number"}
expectA := append([]string{"sub", "alice-id", "bob-id", "bob-info"}, always...)
expectB := append([]string{"sub", "bob-id", "bob-info", "charlie-id", "charlie-info"}, always...)
expectA := append([]string{"sub", "alice-id", "bob-id", "bob-info", "claim-list"}, always...)
expectB := append([]string{"sub", "bob-id", "bob-info", "charlie-id", "charlie-info", "claim-list"}, always...)
requireClaims(t, db, orgA.Org.ID, expectA)
requireClaims(t, db, orgB.Org.ID, expectB)
requireClaims(t, db, orgC.Org.ID, []string{})
requireClaims(t, db, uuid.Nil, slice.Unique(append(expectA, expectB...)))

// Verify the claim field values
expectAValues := []string{"one", "two", "three", "four"}
expectBValues := []string{"three", "four", "charlie"}
requireClaimValues(t, db, orgA.Org.ID, claimField, expectAValues)
requireClaimValues(t, db, orgB.Org.ID, claimField, expectBValues)
requireClaimValues(t, db, orgC.Org.ID, claimField, []string{})
}

func requireClaimValues(t *testing.T, db database.Store, orgID uuid.UUID, field string, want []string) {
t.Helper()

ctx := testutil.Context(t, testutil.WaitMedium)
got, err := db.OIDCClaimFieldValues(ctx, database.OIDCClaimFieldValuesParams{
ClaimField: field,
OrganizationID: orgID,
})
require.NoError(t, err)

require.ElementsMatch(t, want, got)
}

func requireClaims(t *testing.T, db database.Store, orgID uuid.UUID, want []string) {
Expand Down
2 changes: 1 addition & 1 deletion coderd/database/querier.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 15 additions & 9 deletions coderd/database/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 8 additions & 7 deletions coderd/database/queries/user_links.sql
Original file line number Diff line number Diff line change
Expand Up @@ -84,28 +84,29 @@ SELECT
DISTINCT jsonb_array_elements_text(CASE
-- When the type is an array, filter out any non-string elements.
-- This is to keep the return type consistent.
WHEN jsonb_typeof(claims->'merged_claims'->'groups') = 'array' THEN
WHEN jsonb_typeof(claims->'merged_claims'->sqlc.arg('claim_field')::text) = 'array' THEN
(
SELECT
jsonb_agg(element)
FROM
jsonb_array_elements(claims->'merged_claims'->@claim_field) AS element
jsonb_array_elements(claims->'merged_claims'->sqlc.arg('claim_field')::text) AS element
WHERE
-- Filtering out non-string elements
jsonb_typeof(element) = 'string'
)
-- Some IDPs return a single string instead of an array of strings.
WHEN jsonb_typeof(claims->'merged_claims'->'groups') = 'string' THEN
jsonb_build_array(claims->'merged_claims'->@claim_field)
WHEN jsonb_typeof(claims->'merged_claims'->sqlc.arg('claim_field')::text) = 'string' THEN
jsonb_build_array(claims->'merged_claims'->sqlc.arg('claim_field')::text)
END)
FROM
user_links
WHERE
-- IDP sync only supports string and array (of string) types
jsonb_typeof(claims->'merged_claims'->@claim_field) = ANY(ARRAY['string', 'array'])
jsonb_typeof(claims->'merged_claims'->sqlc.arg('claim_field')::text) = ANY(ARRAY['string', 'array'])
AND login_type = 'oidc'
AND CASE WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
user_links.user_id = ANY(SELECT organization_members.user_id FROM organization_members WHERE organization_id = @organization_id)
AND CASE
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
user_links.user_id = ANY(SELECT organization_members.user_id FROM organization_members WHERE organization_id = @organization_id)
ELSE true
END
;
Loading