diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 4845ff22288fe..2e9a85f8ba578 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -3283,6 +3283,17 @@ func (q *querier) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID return q.db.ListWorkspaceAgentPortShares(ctx, workspaceID) } +func (q *querier) OIDCClaimFieldValues(ctx context.Context, args database.OIDCClaimFieldValuesParams) ([]string, error) { + resource := rbac.ResourceIdpsyncSettings + if args.OrganizationID != uuid.Nil { + resource = resource.InOrg(args.OrganizationID) + } + if err := q.authorizeContext(ctx, policy.ActionRead, resource); err != nil { + return nil, err + } + return q.db.OIDCClaimFieldValues(ctx, args) +} + func (q *querier) OIDCClaimFields(ctx context.Context, organizationID uuid.UUID) ([]string, error) { resource := rbac.ResourceIdpsyncSettings if organizationID != uuid.Nil { diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 2eb75f8b738c4..ef572f0a11247 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -633,6 +633,19 @@ func (s *MethodTestSuite) TestOrganization() { id := uuid.New() check.Args(id).Asserts(rbac.ResourceIdpsyncSettings.InOrg(id), policy.ActionRead).Returns([]string{}) })) + s.Run("Deployment/OIDCClaimFieldValues", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.OIDCClaimFieldValuesParams{ + ClaimField: "claim-field", + OrganizationID: uuid.Nil, + }).Asserts(rbac.ResourceIdpsyncSettings, policy.ActionRead).Returns([]string{}) + })) + s.Run("Organization/OIDCClaimFieldValues", s.Subtest(func(db database.Store, check *expects) { + id := uuid.New() + check.Args(database.OIDCClaimFieldValuesParams{ + ClaimField: "claim-field", + OrganizationID: id, + }).Asserts(rbac.ResourceIdpsyncSettings.InOrg(id), policy.ActionRead).Returns([]string{}) + })) s.Run("ByOrganization/GetGroups", s.Subtest(func(db database.Store, check *expects) { o := dbgen.Organization(s.T(), db, database.Organization{}) a := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index aed57e9284b3a..2be0a8f583bb7 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -8409,6 +8409,52 @@ func (q *FakeQuerier) ListWorkspaceAgentPortShares(_ context.Context, workspaceI return shares, nil } +// nolint:forcetypeassert +func (q *FakeQuerier) OIDCClaimFieldValues(_ context.Context, args database.OIDCClaimFieldValuesParams) ([]string, error) { + orgMembers := q.getOrganizationMemberNoLock(args.OrganizationID) + + var values []string + for _, link := range q.userLinks { + if args.OrganizationID != uuid.Nil { + inOrg := slices.ContainsFunc(orgMembers, func(organizationMember database.OrganizationMember) bool { + return organizationMember.UserID == link.UserID + }) + if !inOrg { + continue + } + } + + if link.LoginType != database.LoginTypeOIDC { + continue + } + + if len(link.Claims.MergedClaims) == 0 { + continue + } + + value, ok := link.Claims.MergedClaims[args.ClaimField] + if !ok { + continue + } + switch value := value.(type) { + case string: + values = append(values, value) + case []string: + values = append(values, value...) + case []any: + for _, v := range value { + if sv, ok := v.(string); ok { + values = append(values, sv) + } + } + default: + continue + } + } + + return slice.Unique(values), nil +} + func (q *FakeQuerier) OIDCClaimFields(_ context.Context, organizationID uuid.UUID) ([]string, error) { orgMembers := q.getOrganizationMemberNoLock(organizationID) @@ -8427,10 +8473,7 @@ func (q *FakeQuerier) OIDCClaimFields(_ context.Context, organizationID uuid.UUI continue } - for k := range link.Claims.IDTokenClaims { - fields = append(fields, k) - } - for k := range link.Claims.UserInfoClaims { + for k := range link.Claims.MergedClaims { fields = append(fields, k) } } diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 32d3cce658525..ba869ec42f27a 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -2058,6 +2058,13 @@ func (m queryMetricsStore) ListWorkspaceAgentPortShares(ctx context.Context, wor return r0, r1 } +func (m queryMetricsStore) OIDCClaimFieldValues(ctx context.Context, organizationID database.OIDCClaimFieldValuesParams) ([]string, error) { + start := time.Now() + r0, r1 := m.s.OIDCClaimFieldValues(ctx, organizationID) + m.queryLatencies.WithLabelValues("OIDCClaimFieldValues").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m queryMetricsStore) OIDCClaimFields(ctx context.Context, organizationID uuid.UUID) ([]string, error) { start := time.Now() r0, r1 := m.s.OIDCClaimFields(ctx, organizationID) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index d6c34411f8208..c4bef6acb75c4 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -4359,6 +4359,21 @@ func (mr *MockStoreMockRecorder) ListWorkspaceAgentPortShares(arg0, arg1 any) *g return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListWorkspaceAgentPortShares", reflect.TypeOf((*MockStore)(nil).ListWorkspaceAgentPortShares), arg0, arg1) } +// OIDCClaimFieldValues mocks base method. +func (m *MockStore) OIDCClaimFieldValues(arg0 context.Context, arg1 database.OIDCClaimFieldValuesParams) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OIDCClaimFieldValues", arg0, arg1) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OIDCClaimFieldValues indicates an expected call of OIDCClaimFieldValues. +func (mr *MockStoreMockRecorder) OIDCClaimFieldValues(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OIDCClaimFieldValues", reflect.TypeOf((*MockStore)(nil).OIDCClaimFieldValues), arg0, arg1) +} + // OIDCClaimFields mocks base method. func (m *MockStore) OIDCClaimFields(arg0 context.Context, arg1 uuid.UUID) ([]string, error) { m.ctrl.T.Helper() diff --git a/coderd/database/oidcclaims_test.go b/coderd/database/oidcclaims_test.go index 85fd5b3df3812..f9fe1711b19b8 100644 --- a/coderd/database/oidcclaims_test.go +++ b/coderd/database/oidcclaims_test.go @@ -32,6 +32,8 @@ func TestOIDCClaims(t *testing.T) { db, _ := dbtestutil.NewDB(t) g := userGenerator{t: t, db: db} + const claimField = "claim-list" + // https://en.wikipedia.org/wiki/Alice_and_Bob#Cast_of_characters alice := g.withLink(database.LoginTypeOIDC, toJSON(extraKeys{ UserLinkClaims: database.UserLinkClaims{ @@ -43,6 +45,9 @@ func TestOIDCClaims(t *testing.T) { MergedClaims: map[string]interface{}{ "sub": "alice", "alice-id": "from-bob", + claimField: []string{ + "one", "two", "three", + }, }, }, // Always should be a no-op @@ -79,6 +84,9 @@ func TestOIDCClaims(t *testing.T) { "foo": "bar", }, "nil": nil, + claimField: []any{ + "three", 5, []string{"test"}, "four", + }, }, })) charlie := g.withLink(database.LoginTypeOIDC, toJSON(database.UserLinkClaims{ @@ -94,6 +102,7 @@ func TestOIDCClaims(t *testing.T) { "sub": "charlie", "charlie-id": "charlie", "charlie-info": "charlie", + claimField: "charlie", }, })) @@ -113,8 +122,9 @@ func TestOIDCClaims(t *testing.T) { "do-not": "look", }, MergedClaims: map[string]interface{}{ - "not": "allowed", - "do-not": "look", + "not": "allowed", + "do-not": "look", + claimField: 42, }, })), // github should be omitted @@ -140,12 +150,32 @@ func TestOIDCClaims(t *testing.T) { // Verify the OIDC claim fields always := []string{"array", "map", "nil", "number"} - expectA := append([]string{"sub", "alice-id", "bob-id", "bob-info"}, always...) - expectB := append([]string{"sub", "bob-id", "bob-info", "charlie-id", "charlie-info"}, always...) + expectA := append([]string{"sub", "alice-id", "bob-id", "bob-info", "claim-list"}, always...) + expectB := append([]string{"sub", "bob-id", "bob-info", "charlie-id", "charlie-info", "claim-list"}, always...) requireClaims(t, db, orgA.Org.ID, expectA) requireClaims(t, db, orgB.Org.ID, expectB) requireClaims(t, db, orgC.Org.ID, []string{}) requireClaims(t, db, uuid.Nil, slice.Unique(append(expectA, expectB...))) + + // Verify the claim field values + expectAValues := []string{"one", "two", "three", "four"} + expectBValues := []string{"three", "four", "charlie"} + requireClaimValues(t, db, orgA.Org.ID, claimField, expectAValues) + requireClaimValues(t, db, orgB.Org.ID, claimField, expectBValues) + requireClaimValues(t, db, orgC.Org.ID, claimField, []string{}) +} + +func requireClaimValues(t *testing.T, db database.Store, orgID uuid.UUID, field string, want []string) { + t.Helper() + + ctx := testutil.Context(t, testutil.WaitMedium) + got, err := db.OIDCClaimFieldValues(ctx, database.OIDCClaimFieldValuesParams{ + ClaimField: field, + OrganizationID: orgID, + }) + require.NoError(t, err) + + require.ElementsMatch(t, want, got) } func requireClaims(t *testing.T, db database.Store, orgID uuid.UUID, want []string) { diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 49ba6fbf8496a..acd284b58f4cf 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -413,6 +413,7 @@ type sqlcQuerier interface { ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error) ListProvisionerKeysByOrganizationExcludeReserved(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]WorkspaceAgentPortShare, error) + OIDCClaimFieldValues(ctx context.Context, arg OIDCClaimFieldValuesParams) ([]string, error) // OIDCClaimFields returns a list of distinct keys in the the merged_claims fields. // This query is used to generate the list of available sync fields for idp sync settings. OIDCClaimFields(ctx context.Context, organizationID uuid.UUID) ([]string, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 09dd4c1fbc488..0ce09dffa1d40 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -9846,6 +9846,67 @@ func (q *sqlQuerier) InsertUserLink(ctx context.Context, arg InsertUserLinkParam return i, err } +const oIDCClaimFieldValues = `-- name: OIDCClaimFieldValues :many +SELECT + -- DISTINCT to remove duplicates + DISTINCT jsonb_array_elements_text(CASE + -- When the type is an array, filter out any non-string elements. + -- This is to keep the return type consistent. + WHEN jsonb_typeof(claims->'merged_claims'->$1::text) = 'array' THEN + ( + SELECT + jsonb_agg(element) + FROM + jsonb_array_elements(claims->'merged_claims'->$1::text) AS element + WHERE + -- Filtering out non-string elements + jsonb_typeof(element) = 'string' + ) + -- Some IDPs return a single string instead of an array of strings. + WHEN jsonb_typeof(claims->'merged_claims'->$1::text) = 'string' THEN + jsonb_build_array(claims->'merged_claims'->$1::text) + END) +FROM + user_links +WHERE + -- IDP sync only supports string and array (of string) types + jsonb_typeof(claims->'merged_claims'->$1::text) = ANY(ARRAY['string', 'array']) + AND login_type = 'oidc' + AND CASE + WHEN $2 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + user_links.user_id = ANY(SELECT organization_members.user_id FROM organization_members WHERE organization_id = $2) + ELSE true + END +` + +type OIDCClaimFieldValuesParams struct { + ClaimField string `db:"claim_field" json:"claim_field"` + OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` +} + +func (q *sqlQuerier) OIDCClaimFieldValues(ctx context.Context, arg OIDCClaimFieldValuesParams) ([]string, error) { + rows, err := q.db.QueryContext(ctx, oIDCClaimFieldValues, arg.ClaimField, arg.OrganizationID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var jsonb_array_elements_text string + if err := rows.Scan(&jsonb_array_elements_text); err != nil { + return nil, err + } + items = append(items, jsonb_array_elements_text) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const oIDCClaimFields = `-- name: OIDCClaimFields :many SELECT DISTINCT jsonb_object_keys(claims->'merged_claims') diff --git a/coderd/database/queries/user_links.sql b/coderd/database/queries/user_links.sql index 274193b0c8bf6..43e7fad64e7bd 100644 --- a/coderd/database/queries/user_links.sql +++ b/coderd/database/queries/user_links.sql @@ -58,7 +58,6 @@ SET WHERE user_id = $7 AND login_type = $8 RETURNING *; - -- name: OIDCClaimFields :many -- OIDCClaimFields returns a list of distinct keys in the the merged_claims fields. -- This query is used to generate the list of available sync fields for idp sync settings. @@ -78,3 +77,36 @@ WHERE ELSE true END ; + +-- name: OIDCClaimFieldValues :many +SELECT + -- DISTINCT to remove duplicates + DISTINCT jsonb_array_elements_text(CASE + -- When the type is an array, filter out any non-string elements. + -- This is to keep the return type consistent. + WHEN jsonb_typeof(claims->'merged_claims'->sqlc.arg('claim_field')::text) = 'array' THEN + ( + SELECT + jsonb_agg(element) + FROM + jsonb_array_elements(claims->'merged_claims'->sqlc.arg('claim_field')::text) AS element + WHERE + -- Filtering out non-string elements + jsonb_typeof(element) = 'string' + ) + -- Some IDPs return a single string instead of an array of strings. + WHEN jsonb_typeof(claims->'merged_claims'->sqlc.arg('claim_field')::text) = 'string' THEN + jsonb_build_array(claims->'merged_claims'->sqlc.arg('claim_field')::text) + END) +FROM + user_links +WHERE + -- IDP sync only supports string and array (of string) types + jsonb_typeof(claims->'merged_claims'->sqlc.arg('claim_field')::text) = ANY(ARRAY['string', 'array']) + AND login_type = 'oidc' + AND CASE + WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + user_links.user_id = ANY(SELECT organization_members.user_id FROM organization_members WHERE organization_id = @organization_id) + ELSE true + END +;