From 0b20020a6373aeb237de28a74a270c46d5758453 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 18 Nov 2024 12:28:02 -0600 Subject: [PATCH 1/9] chore: implement OIDCClaimFieldValues for idp sync mappings help --- coderd/database/queries/user_links.sql | 33 +++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/coderd/database/queries/user_links.sql b/coderd/database/queries/user_links.sql index 274193b0c8bf6..adc6d648d06d2 100644 --- a/coderd/database/queries/user_links.sql +++ b/coderd/database/queries/user_links.sql @@ -58,7 +58,6 @@ SET WHERE user_id = $7 AND login_type = $8 RETURNING *; - -- name: OIDCClaimFields :many -- OIDCClaimFields returns a list of distinct keys in the the merged_claims fields. -- This query is used to generate the list of available sync fields for idp sync settings. @@ -78,3 +77,35 @@ WHERE ELSE true END ; + +-- name: OIDCClaimFieldValues :many +SELECT + -- DISTINCT to remove duplicates + DISTINCT jsonb_array_elements_text(CASE + -- When the type is an array, filter out any non-string elements. + -- This is to keep the return type consistent. + WHEN jsonb_typeof(claims->'merged_claims'->'groups') = 'array' THEN + ( + SELECT + jsonb_agg(element) + FROM + jsonb_array_elements(claims->'merged_claims'->@claim_field) AS element + WHERE + -- Filtering out non-string elements + jsonb_typeof(element) = 'string' + ) + -- Some IDPs return a single string instead of an array of strings. + WHEN jsonb_typeof(claims->'merged_claims'->'groups') = 'string' THEN + jsonb_build_array(claims->'merged_claims'->@claim_field) + END)::text +FROM + user_links +WHERE + -- IDP sync only supports string and array (of string) types + jsonb_typeof(claims->'merged_claims'->@claim_field) = ANY(ARRAY['string', 'array']) + AND login_type = 'oidc' + AND CASE WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + user_links.user_id = ANY(SELECT organization_members.user_id FROM organization_members WHERE organization_id = @organization_id) + ELSE true + END +; From 917ce361fd6d9091ed93385bf4d44f7c00a7bfc3 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 18 Nov 2024 14:32:03 -0600 Subject: [PATCH 2/9] remove text cast --- coderd/database/queries/user_links.sql | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coderd/database/queries/user_links.sql b/coderd/database/queries/user_links.sql index adc6d648d06d2..83b321656edb7 100644 --- a/coderd/database/queries/user_links.sql +++ b/coderd/database/queries/user_links.sql @@ -97,7 +97,7 @@ SELECT -- Some IDPs return a single string instead of an array of strings. WHEN jsonb_typeof(claims->'merged_claims'->'groups') = 'string' THEN jsonb_build_array(claims->'merged_claims'->@claim_field) - END)::text + END) FROM user_links WHERE From 1a5dc303a23d39f4bb728f55aa9c344090181111 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 18 Nov 2024 14:40:58 -0600 Subject: [PATCH 3/9] make gen --- coderd/database/dbauthz/dbauthz.go | 4 ++ coderd/database/dbmem/dbmem.go | 4 ++ coderd/database/dbmetrics/querymetrics.go | 7 +++ coderd/database/dbmock/dbmock.go | 15 +++++++ coderd/database/querier.go | 1 + coderd/database/queries.sql.go | 55 +++++++++++++++++++++++ 6 files changed, 86 insertions(+) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 4845ff22288fe..144839b58ab5f 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -3283,6 +3283,10 @@ func (q *querier) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID return q.db.ListWorkspaceAgentPortShares(ctx, workspaceID) } +func (q *querier) OIDCClaimFieldValues(ctx context.Context, organizationID uuid.UUID) ([]string, error) { + panic("not implemented") +} + func (q *querier) OIDCClaimFields(ctx context.Context, organizationID uuid.UUID) ([]string, error) { resource := rbac.ResourceIdpsyncSettings if organizationID != uuid.Nil { diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index aed57e9284b3a..227f2934d6a56 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -8409,6 +8409,10 @@ func (q *FakeQuerier) ListWorkspaceAgentPortShares(_ context.Context, workspaceI return shares, nil } +func (q *FakeQuerier) OIDCClaimFieldValues(ctx context.Context, organizationID uuid.UUID) ([]string, error) { + panic("not implemented") +} + func (q *FakeQuerier) OIDCClaimFields(_ context.Context, organizationID uuid.UUID) ([]string, error) { orgMembers := q.getOrganizationMemberNoLock(organizationID) diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 32d3cce658525..65894f99fbc8b 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -2058,6 +2058,13 @@ func (m queryMetricsStore) ListWorkspaceAgentPortShares(ctx context.Context, wor return r0, r1 } +func (m queryMetricsStore) OIDCClaimFieldValues(ctx context.Context, organizationID uuid.UUID) ([]string, error) { + start := time.Now() + r0, r1 := m.s.OIDCClaimFieldValues(ctx, organizationID) + m.queryLatencies.WithLabelValues("OIDCClaimFieldValues").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m queryMetricsStore) OIDCClaimFields(ctx context.Context, organizationID uuid.UUID) ([]string, error) { start := time.Now() r0, r1 := m.s.OIDCClaimFields(ctx, organizationID) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index d6c34411f8208..53e3f094b9418 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -4359,6 +4359,21 @@ func (mr *MockStoreMockRecorder) ListWorkspaceAgentPortShares(arg0, arg1 any) *g return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListWorkspaceAgentPortShares", reflect.TypeOf((*MockStore)(nil).ListWorkspaceAgentPortShares), arg0, arg1) } +// OIDCClaimFieldValues mocks base method. +func (m *MockStore) OIDCClaimFieldValues(arg0 context.Context, arg1 uuid.UUID) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OIDCClaimFieldValues", arg0, arg1) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OIDCClaimFieldValues indicates an expected call of OIDCClaimFieldValues. +func (mr *MockStoreMockRecorder) OIDCClaimFieldValues(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OIDCClaimFieldValues", reflect.TypeOf((*MockStore)(nil).OIDCClaimFieldValues), arg0, arg1) +} + // OIDCClaimFields mocks base method. func (m *MockStore) OIDCClaimFields(arg0 context.Context, arg1 uuid.UUID) ([]string, error) { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 49ba6fbf8496a..ff0625f2c08f1 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -413,6 +413,7 @@ type sqlcQuerier interface { ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error) ListProvisionerKeysByOrganizationExcludeReserved(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]WorkspaceAgentPortShare, error) + OIDCClaimFieldValues(ctx context.Context, organizationID uuid.UUID) ([]string, error) // OIDCClaimFields returns a list of distinct keys in the the merged_claims fields. // This query is used to generate the list of available sync fields for idp sync settings. OIDCClaimFields(ctx context.Context, organizationID uuid.UUID) ([]string, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 09dd4c1fbc488..b05277c6af2c1 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -9846,6 +9846,61 @@ func (q *sqlQuerier) InsertUserLink(ctx context.Context, arg InsertUserLinkParam return i, err } +const oIDCClaimFieldValues = `-- name: OIDCClaimFieldValues :many +SELECT + -- DISTINCT to remove duplicates + DISTINCT jsonb_array_elements_text(CASE + -- When the type is an array, filter out any non-string elements. + -- This is to keep the return type consistent. + WHEN jsonb_typeof(claims->'merged_claims'->'groups') = 'array' THEN + ( + SELECT + jsonb_agg(element) + FROM + jsonb_array_elements(claims->'merged_claims'->@claim_field) AS element + WHERE + -- Filtering out non-string elements + jsonb_typeof(element) = 'string' + ) + -- Some IDPs return a single string instead of an array of strings. + WHEN jsonb_typeof(claims->'merged_claims'->'groups') = 'string' THEN + jsonb_build_array(claims->'merged_claims'->@claim_field) + END) +FROM + user_links +WHERE + -- IDP sync only supports string and array (of string) types + jsonb_typeof(claims->'merged_claims'->@claim_field) = ANY(ARRAY['string', 'array']) + AND login_type = 'oidc' + AND CASE WHEN $1 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + user_links.user_id = ANY(SELECT organization_members.user_id FROM organization_members WHERE organization_id = $1) + ELSE true + END +` + +func (q *sqlQuerier) OIDCClaimFieldValues(ctx context.Context, organizationID uuid.UUID) ([]string, error) { + rows, err := q.db.QueryContext(ctx, oIDCClaimFieldValues, organizationID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var jsonb_array_elements_text string + if err := rows.Scan(&jsonb_array_elements_text); err != nil { + return nil, err + } + items = append(items, jsonb_array_elements_text) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const oIDCClaimFields = `-- name: OIDCClaimFields :many SELECT DISTINCT jsonb_object_keys(claims->'merged_claims') From 2ee347f9452522f8153015c1b41c979f4a89e7b4 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 18 Nov 2024 15:29:11 -0600 Subject: [PATCH 4/9] chore: unit tests for OIDCClaimFieldValues and fixup sql arg types --- coderd/database/dbauthz/dbauthz.go | 12 +++++- coderd/database/dbmem/dbmem.go | 50 ++++++++++++++++++++--- coderd/database/dbmetrics/querymetrics.go | 2 +- coderd/database/oidcclaims_test.go | 38 +++++++++++++++-- coderd/database/querier.go | 2 +- coderd/database/queries.sql.go | 24 +++++++---- coderd/database/queries/user_links.sql | 15 +++---- 7 files changed, 113 insertions(+), 30 deletions(-) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 144839b58ab5f..4ca5c7249b59a 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -3283,8 +3283,16 @@ func (q *querier) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID return q.db.ListWorkspaceAgentPortShares(ctx, workspaceID) } -func (q *querier) OIDCClaimFieldValues(ctx context.Context, organizationID uuid.UUID) ([]string, error) { - panic("not implemented") +func (q *querier) OIDCClaimFieldValues(ctx context.Context, args database.OIDCClaimFieldValuesParams) ([]string, error) { + resource := rbac.ResourceIdpsyncSettings + if args.OrganizationID != uuid.Nil { + resource = resource.InOrg(args.OrganizationID) + } + if err := q.authorizeContext(ctx, policy.ActionRead, resource); err != nil { + return nil, err + } + return q.db.OIDCClaimFieldValues(ctx, args) + } func (q *querier) OIDCClaimFields(ctx context.Context, organizationID uuid.UUID) ([]string, error) { diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 227f2934d6a56..1e4c818f34085 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -8409,8 +8409,49 @@ func (q *FakeQuerier) ListWorkspaceAgentPortShares(_ context.Context, workspaceI return shares, nil } -func (q *FakeQuerier) OIDCClaimFieldValues(ctx context.Context, organizationID uuid.UUID) ([]string, error) { - panic("not implemented") +func (q *FakeQuerier) OIDCClaimFieldValues(ctx context.Context, args database.OIDCClaimFieldValuesParams) ([]string, error) { + orgMembers := q.getOrganizationMemberNoLock(args.OrganizationID) + + var values []string + for _, link := range q.userLinks { + if args.OrganizationID != uuid.Nil { + inOrg := slices.ContainsFunc(orgMembers, func(organizationMember database.OrganizationMember) bool { + return organizationMember.UserID == link.UserID + }) + if !inOrg { + continue + } + } + + if link.LoginType != database.LoginTypeOIDC { + continue + } + + if len(link.Claims.MergedClaims) == 0 { + continue + } + + value, ok := link.Claims.MergedClaims[args.ClaimField] + if !ok { + continue + } + switch value.(type) { + case string: + values = append(values, value.(string)) + case []string: + values = append(values, value.([]string)...) + case []any: + for _, v := range value.([]any) { + if sv, ok := v.(string); ok { + values = append(values, sv) + } + } + default: + continue + } + } + + return slice.Unique(values), nil } func (q *FakeQuerier) OIDCClaimFields(_ context.Context, organizationID uuid.UUID) ([]string, error) { @@ -8431,10 +8472,7 @@ func (q *FakeQuerier) OIDCClaimFields(_ context.Context, organizationID uuid.UUI continue } - for k := range link.Claims.IDTokenClaims { - fields = append(fields, k) - } - for k := range link.Claims.UserInfoClaims { + for k := range link.Claims.MergedClaims { fields = append(fields, k) } } diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 65894f99fbc8b..ba869ec42f27a 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -2058,7 +2058,7 @@ func (m queryMetricsStore) ListWorkspaceAgentPortShares(ctx context.Context, wor return r0, r1 } -func (m queryMetricsStore) OIDCClaimFieldValues(ctx context.Context, organizationID uuid.UUID) ([]string, error) { +func (m queryMetricsStore) OIDCClaimFieldValues(ctx context.Context, organizationID database.OIDCClaimFieldValuesParams) ([]string, error) { start := time.Now() r0, r1 := m.s.OIDCClaimFieldValues(ctx, organizationID) m.queryLatencies.WithLabelValues("OIDCClaimFieldValues").Observe(time.Since(start).Seconds()) diff --git a/coderd/database/oidcclaims_test.go b/coderd/database/oidcclaims_test.go index 85fd5b3df3812..f9fe1711b19b8 100644 --- a/coderd/database/oidcclaims_test.go +++ b/coderd/database/oidcclaims_test.go @@ -32,6 +32,8 @@ func TestOIDCClaims(t *testing.T) { db, _ := dbtestutil.NewDB(t) g := userGenerator{t: t, db: db} + const claimField = "claim-list" + // https://en.wikipedia.org/wiki/Alice_and_Bob#Cast_of_characters alice := g.withLink(database.LoginTypeOIDC, toJSON(extraKeys{ UserLinkClaims: database.UserLinkClaims{ @@ -43,6 +45,9 @@ func TestOIDCClaims(t *testing.T) { MergedClaims: map[string]interface{}{ "sub": "alice", "alice-id": "from-bob", + claimField: []string{ + "one", "two", "three", + }, }, }, // Always should be a no-op @@ -79,6 +84,9 @@ func TestOIDCClaims(t *testing.T) { "foo": "bar", }, "nil": nil, + claimField: []any{ + "three", 5, []string{"test"}, "four", + }, }, })) charlie := g.withLink(database.LoginTypeOIDC, toJSON(database.UserLinkClaims{ @@ -94,6 +102,7 @@ func TestOIDCClaims(t *testing.T) { "sub": "charlie", "charlie-id": "charlie", "charlie-info": "charlie", + claimField: "charlie", }, })) @@ -113,8 +122,9 @@ func TestOIDCClaims(t *testing.T) { "do-not": "look", }, MergedClaims: map[string]interface{}{ - "not": "allowed", - "do-not": "look", + "not": "allowed", + "do-not": "look", + claimField: 42, }, })), // github should be omitted @@ -140,12 +150,32 @@ func TestOIDCClaims(t *testing.T) { // Verify the OIDC claim fields always := []string{"array", "map", "nil", "number"} - expectA := append([]string{"sub", "alice-id", "bob-id", "bob-info"}, always...) - expectB := append([]string{"sub", "bob-id", "bob-info", "charlie-id", "charlie-info"}, always...) + expectA := append([]string{"sub", "alice-id", "bob-id", "bob-info", "claim-list"}, always...) + expectB := append([]string{"sub", "bob-id", "bob-info", "charlie-id", "charlie-info", "claim-list"}, always...) requireClaims(t, db, orgA.Org.ID, expectA) requireClaims(t, db, orgB.Org.ID, expectB) requireClaims(t, db, orgC.Org.ID, []string{}) requireClaims(t, db, uuid.Nil, slice.Unique(append(expectA, expectB...))) + + // Verify the claim field values + expectAValues := []string{"one", "two", "three", "four"} + expectBValues := []string{"three", "four", "charlie"} + requireClaimValues(t, db, orgA.Org.ID, claimField, expectAValues) + requireClaimValues(t, db, orgB.Org.ID, claimField, expectBValues) + requireClaimValues(t, db, orgC.Org.ID, claimField, []string{}) +} + +func requireClaimValues(t *testing.T, db database.Store, orgID uuid.UUID, field string, want []string) { + t.Helper() + + ctx := testutil.Context(t, testutil.WaitMedium) + got, err := db.OIDCClaimFieldValues(ctx, database.OIDCClaimFieldValuesParams{ + ClaimField: field, + OrganizationID: orgID, + }) + require.NoError(t, err) + + require.ElementsMatch(t, want, got) } func requireClaims(t *testing.T, db database.Store, orgID uuid.UUID, want []string) { diff --git a/coderd/database/querier.go b/coderd/database/querier.go index ff0625f2c08f1..acd284b58f4cf 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -413,7 +413,7 @@ type sqlcQuerier interface { ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error) ListProvisionerKeysByOrganizationExcludeReserved(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]WorkspaceAgentPortShare, error) - OIDCClaimFieldValues(ctx context.Context, organizationID uuid.UUID) ([]string, error) + OIDCClaimFieldValues(ctx context.Context, arg OIDCClaimFieldValuesParams) ([]string, error) // OIDCClaimFields returns a list of distinct keys in the the merged_claims fields. // This query is used to generate the list of available sync fields for idp sync settings. OIDCClaimFields(ctx context.Context, organizationID uuid.UUID) ([]string, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index b05277c6af2c1..0ce09dffa1d40 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -9852,34 +9852,40 @@ SELECT DISTINCT jsonb_array_elements_text(CASE -- When the type is an array, filter out any non-string elements. -- This is to keep the return type consistent. - WHEN jsonb_typeof(claims->'merged_claims'->'groups') = 'array' THEN + WHEN jsonb_typeof(claims->'merged_claims'->$1::text) = 'array' THEN ( SELECT jsonb_agg(element) FROM - jsonb_array_elements(claims->'merged_claims'->@claim_field) AS element + jsonb_array_elements(claims->'merged_claims'->$1::text) AS element WHERE -- Filtering out non-string elements jsonb_typeof(element) = 'string' ) -- Some IDPs return a single string instead of an array of strings. - WHEN jsonb_typeof(claims->'merged_claims'->'groups') = 'string' THEN - jsonb_build_array(claims->'merged_claims'->@claim_field) + WHEN jsonb_typeof(claims->'merged_claims'->$1::text) = 'string' THEN + jsonb_build_array(claims->'merged_claims'->$1::text) END) FROM user_links WHERE -- IDP sync only supports string and array (of string) types - jsonb_typeof(claims->'merged_claims'->@claim_field) = ANY(ARRAY['string', 'array']) + jsonb_typeof(claims->'merged_claims'->$1::text) = ANY(ARRAY['string', 'array']) AND login_type = 'oidc' - AND CASE WHEN $1 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN - user_links.user_id = ANY(SELECT organization_members.user_id FROM organization_members WHERE organization_id = $1) + AND CASE + WHEN $2 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + user_links.user_id = ANY(SELECT organization_members.user_id FROM organization_members WHERE organization_id = $2) ELSE true END ` -func (q *sqlQuerier) OIDCClaimFieldValues(ctx context.Context, organizationID uuid.UUID) ([]string, error) { - rows, err := q.db.QueryContext(ctx, oIDCClaimFieldValues, organizationID) +type OIDCClaimFieldValuesParams struct { + ClaimField string `db:"claim_field" json:"claim_field"` + OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` +} + +func (q *sqlQuerier) OIDCClaimFieldValues(ctx context.Context, arg OIDCClaimFieldValuesParams) ([]string, error) { + rows, err := q.db.QueryContext(ctx, oIDCClaimFieldValues, arg.ClaimField, arg.OrganizationID) if err != nil { return nil, err } diff --git a/coderd/database/queries/user_links.sql b/coderd/database/queries/user_links.sql index 83b321656edb7..43e7fad64e7bd 100644 --- a/coderd/database/queries/user_links.sql +++ b/coderd/database/queries/user_links.sql @@ -84,28 +84,29 @@ SELECT DISTINCT jsonb_array_elements_text(CASE -- When the type is an array, filter out any non-string elements. -- This is to keep the return type consistent. - WHEN jsonb_typeof(claims->'merged_claims'->'groups') = 'array' THEN + WHEN jsonb_typeof(claims->'merged_claims'->sqlc.arg('claim_field')::text) = 'array' THEN ( SELECT jsonb_agg(element) FROM - jsonb_array_elements(claims->'merged_claims'->@claim_field) AS element + jsonb_array_elements(claims->'merged_claims'->sqlc.arg('claim_field')::text) AS element WHERE -- Filtering out non-string elements jsonb_typeof(element) = 'string' ) -- Some IDPs return a single string instead of an array of strings. - WHEN jsonb_typeof(claims->'merged_claims'->'groups') = 'string' THEN - jsonb_build_array(claims->'merged_claims'->@claim_field) + WHEN jsonb_typeof(claims->'merged_claims'->sqlc.arg('claim_field')::text) = 'string' THEN + jsonb_build_array(claims->'merged_claims'->sqlc.arg('claim_field')::text) END) FROM user_links WHERE -- IDP sync only supports string and array (of string) types - jsonb_typeof(claims->'merged_claims'->@claim_field) = ANY(ARRAY['string', 'array']) + jsonb_typeof(claims->'merged_claims'->sqlc.arg('claim_field')::text) = ANY(ARRAY['string', 'array']) AND login_type = 'oidc' - AND CASE WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN - user_links.user_id = ANY(SELECT organization_members.user_id FROM organization_members WHERE organization_id = @organization_id) + AND CASE + WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + user_links.user_id = ANY(SELECT organization_members.user_id FROM organization_members WHERE organization_id = @organization_id) ELSE true END ; From 0ef9f1076476fa68ea475bf274c4288999dd98f6 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 18 Nov 2024 15:34:40 -0600 Subject: [PATCH 5/9] fix mock --- coderd/database/dbmock/dbmock.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 53e3f094b9418..c4bef6acb75c4 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -4360,7 +4360,7 @@ func (mr *MockStoreMockRecorder) ListWorkspaceAgentPortShares(arg0, arg1 any) *g } // OIDCClaimFieldValues mocks base method. -func (m *MockStore) OIDCClaimFieldValues(arg0 context.Context, arg1 uuid.UUID) ([]string, error) { +func (m *MockStore) OIDCClaimFieldValues(arg0 context.Context, arg1 database.OIDCClaimFieldValuesParams) ([]string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "OIDCClaimFieldValues", arg0, arg1) ret0, _ := ret[0].([]string) From 614acc77677ec0c1e68f7d5a97ac316659f4f729 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 18 Nov 2024 15:57:40 -0600 Subject: [PATCH 6/9] linting --- coderd/database/dbmem/dbmem.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 1e4c818f34085..25b657a106ab8 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -8409,7 +8409,8 @@ func (q *FakeQuerier) ListWorkspaceAgentPortShares(_ context.Context, workspaceI return shares, nil } -func (q *FakeQuerier) OIDCClaimFieldValues(ctx context.Context, args database.OIDCClaimFieldValuesParams) ([]string, error) { +// nolint:forcetypeassert +func (q *FakeQuerier) OIDCClaimFieldValues(_ context.Context, args database.OIDCClaimFieldValuesParams) ([]string, error) { orgMembers := q.getOrganizationMemberNoLock(args.OrganizationID) var values []string From e04102fd168073d98f3ff1476ba19ca6cec7371b Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 18 Nov 2024 16:00:57 -0600 Subject: [PATCH 7/9] fmt --- coderd/database/dbauthz/dbauthz.go | 1 - 1 file changed, 1 deletion(-) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 4ca5c7249b59a..2e9a85f8ba578 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -3292,7 +3292,6 @@ func (q *querier) OIDCClaimFieldValues(ctx context.Context, args database.OIDCCl return nil, err } return q.db.OIDCClaimFieldValues(ctx, args) - } func (q *querier) OIDCClaimFields(ctx context.Context, organizationID uuid.UUID) ([]string, error) { From dae9216d7253fde4f567ba7f1e393f99271d0081 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 18 Nov 2024 16:09:17 -0600 Subject: [PATCH 8/9] dbauthz tests --- coderd/database/dbauthz/dbauthz_test.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 2eb75f8b738c4..ef572f0a11247 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -633,6 +633,19 @@ func (s *MethodTestSuite) TestOrganization() { id := uuid.New() check.Args(id).Asserts(rbac.ResourceIdpsyncSettings.InOrg(id), policy.ActionRead).Returns([]string{}) })) + s.Run("Deployment/OIDCClaimFieldValues", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.OIDCClaimFieldValuesParams{ + ClaimField: "claim-field", + OrganizationID: uuid.Nil, + }).Asserts(rbac.ResourceIdpsyncSettings, policy.ActionRead).Returns([]string{}) + })) + s.Run("Organization/OIDCClaimFieldValues", s.Subtest(func(db database.Store, check *expects) { + id := uuid.New() + check.Args(database.OIDCClaimFieldValuesParams{ + ClaimField: "claim-field", + OrganizationID: id, + }).Asserts(rbac.ResourceIdpsyncSettings.InOrg(id), policy.ActionRead).Returns([]string{}) + })) s.Run("ByOrganization/GetGroups", s.Subtest(func(db database.Store, check *expects) { o := dbgen.Organization(s.T(), db, database.Organization{}) a := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) From 785243411f463162a54f4600566af14f952e3912 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 19 Nov 2024 10:41:55 -0600 Subject: [PATCH 9/9] linting --- coderd/database/dbmem/dbmem.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 25b657a106ab8..2be0a8f583bb7 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -8436,13 +8436,13 @@ func (q *FakeQuerier) OIDCClaimFieldValues(_ context.Context, args database.OIDC if !ok { continue } - switch value.(type) { + switch value := value.(type) { case string: - values = append(values, value.(string)) + values = append(values, value) case []string: - values = append(values, value.([]string)...) + values = append(values, value...) case []any: - for _, v := range value.([]any) { + for _, v := range value { if sv, ok := v.(string); ok { values = append(values, sv) }