Skip to content

chore: implement OIDCClaimFieldValues for idp sync mappings auto complete #15576

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions coderd/database/dbauthz/dbauthz.go
Original file line number Diff line number Diff line change
Expand Up @@ -3283,6 +3283,17 @@ func (q *querier) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID
return q.db.ListWorkspaceAgentPortShares(ctx, workspaceID)
}

func (q *querier) OIDCClaimFieldValues(ctx context.Context, args database.OIDCClaimFieldValuesParams) ([]string, error) {
resource := rbac.ResourceIdpsyncSettings
if args.OrganizationID != uuid.Nil {
resource = resource.InOrg(args.OrganizationID)
}
if err := q.authorizeContext(ctx, policy.ActionRead, resource); err != nil {
return nil, err
}
return q.db.OIDCClaimFieldValues(ctx, args)
}

func (q *querier) OIDCClaimFields(ctx context.Context, organizationID uuid.UUID) ([]string, error) {
resource := rbac.ResourceIdpsyncSettings
if organizationID != uuid.Nil {
Expand Down
13 changes: 13 additions & 0 deletions coderd/database/dbauthz/dbauthz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,19 @@ func (s *MethodTestSuite) TestOrganization() {
id := uuid.New()
check.Args(id).Asserts(rbac.ResourceIdpsyncSettings.InOrg(id), policy.ActionRead).Returns([]string{})
}))
s.Run("Deployment/OIDCClaimFieldValues", s.Subtest(func(db database.Store, check *expects) {
check.Args(database.OIDCClaimFieldValuesParams{
ClaimField: "claim-field",
OrganizationID: uuid.Nil,
}).Asserts(rbac.ResourceIdpsyncSettings, policy.ActionRead).Returns([]string{})
}))
s.Run("Organization/OIDCClaimFieldValues", s.Subtest(func(db database.Store, check *expects) {
id := uuid.New()
check.Args(database.OIDCClaimFieldValuesParams{
ClaimField: "claim-field",
OrganizationID: id,
}).Asserts(rbac.ResourceIdpsyncSettings.InOrg(id), policy.ActionRead).Returns([]string{})
}))
s.Run("ByOrganization/GetGroups", s.Subtest(func(db database.Store, check *expects) {
o := dbgen.Organization(s.T(), db, database.Organization{})
a := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
Expand Down
51 changes: 47 additions & 4 deletions coderd/database/dbmem/dbmem.go
Original file line number Diff line number Diff line change
Expand Up @@ -8409,6 +8409,52 @@ func (q *FakeQuerier) ListWorkspaceAgentPortShares(_ context.Context, workspaceI
return shares, nil
}

// nolint:forcetypeassert
func (q *FakeQuerier) OIDCClaimFieldValues(_ context.Context, args database.OIDCClaimFieldValuesParams) ([]string, error) {
orgMembers := q.getOrganizationMemberNoLock(args.OrganizationID)

var values []string
for _, link := range q.userLinks {
if args.OrganizationID != uuid.Nil {
inOrg := slices.ContainsFunc(orgMembers, func(organizationMember database.OrganizationMember) bool {
return organizationMember.UserID == link.UserID
})
if !inOrg {
continue
}
}

if link.LoginType != database.LoginTypeOIDC {
continue
}

if len(link.Claims.MergedClaims) == 0 {
continue
}

value, ok := link.Claims.MergedClaims[args.ClaimField]
if !ok {
continue
}
switch value := value.(type) {
case string:
values = append(values, value)
case []string:
values = append(values, value...)
case []any:
for _, v := range value {
if sv, ok := v.(string); ok {
values = append(values, sv)
}
}
default:
continue
}
}

return slice.Unique(values), nil
}

func (q *FakeQuerier) OIDCClaimFields(_ context.Context, organizationID uuid.UUID) ([]string, error) {
orgMembers := q.getOrganizationMemberNoLock(organizationID)

Expand All @@ -8427,10 +8473,7 @@ func (q *FakeQuerier) OIDCClaimFields(_ context.Context, organizationID uuid.UUI
continue
}

for k := range link.Claims.IDTokenClaims {
fields = append(fields, k)
}
for k := range link.Claims.UserInfoClaims {
for k := range link.Claims.MergedClaims {
fields = append(fields, k)
}
}
Expand Down
7 changes: 7 additions & 0 deletions coderd/database/dbmetrics/querymetrics.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions coderd/database/dbmock/dbmock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

38 changes: 34 additions & 4 deletions coderd/database/oidcclaims_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ func TestOIDCClaims(t *testing.T) {
db, _ := dbtestutil.NewDB(t)
g := userGenerator{t: t, db: db}

const claimField = "claim-list"

// https://en.wikipedia.org/wiki/Alice_and_Bob#Cast_of_characters
alice := g.withLink(database.LoginTypeOIDC, toJSON(extraKeys{
UserLinkClaims: database.UserLinkClaims{
Expand All @@ -43,6 +45,9 @@ func TestOIDCClaims(t *testing.T) {
MergedClaims: map[string]interface{}{
"sub": "alice",
"alice-id": "from-bob",
claimField: []string{
"one", "two", "three",
},
},
},
// Always should be a no-op
Expand Down Expand Up @@ -79,6 +84,9 @@ func TestOIDCClaims(t *testing.T) {
"foo": "bar",
},
"nil": nil,
claimField: []any{
"three", 5, []string{"test"}, "four",
},
},
}))
charlie := g.withLink(database.LoginTypeOIDC, toJSON(database.UserLinkClaims{
Expand All @@ -94,6 +102,7 @@ func TestOIDCClaims(t *testing.T) {
"sub": "charlie",
"charlie-id": "charlie",
"charlie-info": "charlie",
claimField: "charlie",
},
}))

Expand All @@ -113,8 +122,9 @@ func TestOIDCClaims(t *testing.T) {
"do-not": "look",
},
MergedClaims: map[string]interface{}{
"not": "allowed",
"do-not": "look",
"not": "allowed",
"do-not": "look",
claimField: 42,
},
})), // github should be omitted

Expand All @@ -140,12 +150,32 @@ func TestOIDCClaims(t *testing.T) {

// Verify the OIDC claim fields
always := []string{"array", "map", "nil", "number"}
expectA := append([]string{"sub", "alice-id", "bob-id", "bob-info"}, always...)
expectB := append([]string{"sub", "bob-id", "bob-info", "charlie-id", "charlie-info"}, always...)
expectA := append([]string{"sub", "alice-id", "bob-id", "bob-info", "claim-list"}, always...)
expectB := append([]string{"sub", "bob-id", "bob-info", "charlie-id", "charlie-info", "claim-list"}, always...)
requireClaims(t, db, orgA.Org.ID, expectA)
requireClaims(t, db, orgB.Org.ID, expectB)
requireClaims(t, db, orgC.Org.ID, []string{})
requireClaims(t, db, uuid.Nil, slice.Unique(append(expectA, expectB...)))

// Verify the claim field values
expectAValues := []string{"one", "two", "three", "four"}
expectBValues := []string{"three", "four", "charlie"}
requireClaimValues(t, db, orgA.Org.ID, claimField, expectAValues)
requireClaimValues(t, db, orgB.Org.ID, claimField, expectBValues)
requireClaimValues(t, db, orgC.Org.ID, claimField, []string{})
}

func requireClaimValues(t *testing.T, db database.Store, orgID uuid.UUID, field string, want []string) {
t.Helper()

ctx := testutil.Context(t, testutil.WaitMedium)
got, err := db.OIDCClaimFieldValues(ctx, database.OIDCClaimFieldValuesParams{
ClaimField: field,
OrganizationID: orgID,
})
require.NoError(t, err)

require.ElementsMatch(t, want, got)
}

func requireClaims(t *testing.T, db database.Store, orgID uuid.UUID, want []string) {
Expand Down
1 change: 1 addition & 0 deletions coderd/database/querier.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

61 changes: 61 additions & 0 deletions coderd/database/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 33 additions & 1 deletion coderd/database/queries/user_links.sql
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ SET
WHERE
user_id = $7 AND login_type = $8 RETURNING *;


-- name: OIDCClaimFields :many
-- OIDCClaimFields returns a list of distinct keys in the the merged_claims fields.
-- This query is used to generate the list of available sync fields for idp sync settings.
Expand All @@ -78,3 +77,36 @@ WHERE
ELSE true
END
;

-- name: OIDCClaimFieldValues :many
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is this being / going to be called?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the Mappings section, this generates the left hand side options.

So if doing Group sync, this shows the user which IDP groups exist.

SELECT
-- DISTINCT to remove duplicates
DISTINCT jsonb_array_elements_text(CASE
-- When the type is an array, filter out any non-string elements.
-- This is to keep the return type consistent.
WHEN jsonb_typeof(claims->'merged_claims'->sqlc.arg('claim_field')::text) = 'array' THEN
(
SELECT
jsonb_agg(element)
FROM
jsonb_array_elements(claims->'merged_claims'->sqlc.arg('claim_field')::text) AS element
WHERE
-- Filtering out non-string elements
jsonb_typeof(element) = 'string'
)
-- Some IDPs return a single string instead of an array of strings.
WHEN jsonb_typeof(claims->'merged_claims'->sqlc.arg('claim_field')::text) = 'string' THEN
jsonb_build_array(claims->'merged_claims'->sqlc.arg('claim_field')::text)
END)
FROM
user_links
WHERE
-- IDP sync only supports string and array (of string) types
jsonb_typeof(claims->'merged_claims'->sqlc.arg('claim_field')::text) = ANY(ARRAY['string', 'array'])
AND login_type = 'oidc'
AND CASE
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
user_links.user_id = ANY(SELECT organization_members.user_id FROM organization_members WHERE organization_id = @organization_id)
ELSE true
END
;
Loading