diff --git a/cli/server.go b/cli/server.go index 4e3b1e16a1482..c2cd476edfaa4 100644 --- a/cli/server.go +++ b/cli/server.go @@ -187,11 +187,6 @@ func createOIDCConfig(ctx context.Context, logger slog.Logger, vals *codersdk.De EmailField: vals.OIDC.EmailField.String(), AuthURLParams: vals.OIDC.AuthURLParams.Value, IgnoreUserInfo: vals.OIDC.IgnoreUserInfo.Value(), - GroupField: vals.OIDC.GroupField.String(), - GroupFilter: vals.OIDC.GroupRegexFilter.Value(), - GroupAllowList: groupAllowList, - CreateMissingGroups: vals.OIDC.GroupAutoCreate.Value(), - GroupMapping: vals.OIDC.GroupMapping.Value, UserRoleField: vals.OIDC.UserRoleField.String(), UserRoleMapping: vals.OIDC.UserRoleMapping.Value, UserRolesDefault: vals.OIDC.UserRolesDefault.GetSlice(), diff --git a/coderd/coderd.go b/coderd/coderd.go index 51b6780e4dc47..e04f13d367c6e 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -181,7 +181,6 @@ type Options struct { NetworkTelemetryBatchFrequency time.Duration NetworkTelemetryBatchMaxSize int SwaggerEndpoint bool - SetUserGroups func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, orgGroupNames map[uuid.UUID][]string, createMissingGroups bool) error SetUserSiteRoles func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, roles []string) error TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore] UserQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore] @@ -276,13 +275,6 @@ func New(options *Options) *API { if options.Entitlements == nil { options.Entitlements = entitlements.New() } - if options.IDPSync == nil { - options.IDPSync = idpsync.NewAGPLSync(options.Logger, idpsync.SyncSettings{ - OrganizationField: options.DeploymentValues.OIDC.OrganizationField.Value(), - OrganizationMapping: options.DeploymentValues.OIDC.OrganizationMapping.Value, - OrganizationAssignDefault: options.DeploymentValues.OIDC.OrganizationAssignDefault.Value(), - }) - } if options.NewTicker == nil { options.NewTicker = func(duration time.Duration) (tick <-chan time.Time, done func()) { ticker := time.NewTicker(duration) @@ -318,6 +310,10 @@ func New(options *Options) *API { options.AccessControlStore, ) + if options.IDPSync == nil { + options.IDPSync = idpsync.NewAGPLSync(options.Logger, options.RuntimeConfig, idpsync.FromDeploymentValues(options.DeploymentValues)) + } + experiments := ReadExperiments( options.Logger, options.DeploymentValues.Experiments.Value(), ) @@ -377,16 +373,6 @@ func New(options *Options) *API { if options.TracerProvider == nil { options.TracerProvider = trace.NewNoopTracerProvider() } - if options.SetUserGroups == nil { - options.SetUserGroups = func(ctx context.Context, logger slog.Logger, _ database.Store, userID uuid.UUID, orgGroupNames map[uuid.UUID][]string, createMissingGroups bool) error { - logger.Warn(ctx, "attempted to assign OIDC groups without enterprise license", - slog.F("user_id", userID), - slog.F("groups", orgGroupNames), - slog.F("create_missing_groups", createMissingGroups), - ) - return nil - } - } if options.SetUserSiteRoles == nil { options.SetUserSiteRoles = func(ctx context.Context, logger slog.Logger, _ database.Store, userID uuid.UUID, roles []string) error { logger.Warn(ctx, "attempted to assign OIDC user roles without enterprise license", diff --git a/coderd/coderdtest/uuids.go b/coderd/coderdtest/uuids.go new file mode 100644 index 0000000000000..1ff60bf26c572 --- /dev/null +++ b/coderd/coderdtest/uuids.go @@ -0,0 +1,25 @@ +package coderdtest + +import "github.com/google/uuid" + +// DeterministicUUIDGenerator allows "naming" uuids for unit tests. +// An example of where this is useful, is when a tabled test references +// a UUID that is not yet known. An alternative to this would be to +// hard code some UUID strings, but these strings are not human friendly. +type DeterministicUUIDGenerator struct { + Named map[string]uuid.UUID +} + +func NewDeterministicUUIDGenerator() *DeterministicUUIDGenerator { + return &DeterministicUUIDGenerator{ + Named: make(map[string]uuid.UUID), + } +} + +func (d *DeterministicUUIDGenerator) ID(name string) uuid.UUID { + if v, ok := d.Named[name]; ok { + return v + } + d.Named[name] = uuid.New() + return d.Named[name] +} diff --git a/coderd/coderdtest/uuids_test.go b/coderd/coderdtest/uuids_test.go new file mode 100644 index 0000000000000..935be36eb8b15 --- /dev/null +++ b/coderd/coderdtest/uuids_test.go @@ -0,0 +1,17 @@ +package coderdtest_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/coderdtest" +) + +func TestDeterministicUUIDGenerator(t *testing.T) { + t.Parallel() + + ids := coderdtest.NewDeterministicUUIDGenerator() + require.Equal(t, ids.ID("g1"), ids.ID("g1")) + require.NotEqual(t, ids.ID("g1"), ids.ID("g2")) +} diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 5782bdc8e7155..077d704be1300 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2892,6 +2892,14 @@ func (q *querier) InsertUser(ctx context.Context, arg database.InsertUserParams) return insert(q.log, q.auth, obj, q.db.InsertUser)(ctx, arg) } +func (q *querier) InsertUserGroupsByID(ctx context.Context, arg database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) { + // This is used by OIDC sync. So only used by a system user. + if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.InsertUserGroupsByID(ctx, arg) +} + func (q *querier) InsertUserGroupsByName(ctx context.Context, arg database.InsertUserGroupsByNameParams) error { // This will add the user to all named groups. This counts as updating a group. // NOTE: instead of checking if the user has permission to update each group, we instead @@ -3100,6 +3108,14 @@ func (q *querier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) return q.db.RemoveUserFromAllGroups(ctx, userID) } +func (q *querier) RemoveUserFromGroups(ctx context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) { + // This is a system function to clear user groups in group sync. + if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.RemoveUserFromGroups(ctx, arg) +} + func (q *querier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error { if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil { return err diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index d23bb48184b61..4b4874f34247c 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -388,6 +388,17 @@ func (s *MethodTestSuite) TestGroup() { GroupNames: slice.New(g1.Name, g2.Name), }).Asserts(rbac.ResourceGroup.InOrg(o.ID), policy.ActionUpdate).Returns() })) + s.Run("InsertUserGroupsByID", s.Subtest(func(db database.Store, check *expects) { + o := dbgen.Organization(s.T(), db, database.Organization{}) + u1 := dbgen.User(s.T(), db, database.User{}) + g1 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) + g2 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) + _ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g1.ID, UserID: u1.ID}) + check.Args(database.InsertUserGroupsByIDParams{ + UserID: u1.ID, + GroupIds: slice.New(g1.ID, g2.ID), + }).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns(slice.New(g1.ID, g2.ID)) + })) s.Run("RemoveUserFromAllGroups", s.Subtest(func(db database.Store, check *expects) { o := dbgen.Organization(s.T(), db, database.Organization{}) u1 := dbgen.User(s.T(), db, database.User{}) @@ -397,6 +408,18 @@ func (s *MethodTestSuite) TestGroup() { _ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g2.ID, UserID: u1.ID}) check.Args(u1.ID).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns() })) + s.Run("RemoveUserFromGroups", s.Subtest(func(db database.Store, check *expects) { + o := dbgen.Organization(s.T(), db, database.Organization{}) + u1 := dbgen.User(s.T(), db, database.User{}) + g1 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) + g2 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) + _ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g1.ID, UserID: u1.ID}) + _ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g2.ID, UserID: u1.ID}) + check.Args(database.RemoveUserFromGroupsParams{ + UserID: u1.ID, + GroupIds: []uuid.UUID{g1.ID, g2.ID}, + }).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns(slice.New(g1.ID, g2.ID)) + })) s.Run("UpdateGroupByID", s.Subtest(func(db database.Store, check *expects) { g := dbgen.Group(s.T(), db, database.Group{}) check.Args(database.UpdateGroupByIDParams{ diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 04f0d32537f90..ed766d48ecd43 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -2695,18 +2695,18 @@ func (q *FakeQuerier) GetGroups(_ context.Context, arg database.GetGroupsParams) q.mutex.RLock() defer q.mutex.RUnlock() - groupIDs := make(map[uuid.UUID]struct{}) + userGroupIDs := make(map[uuid.UUID]struct{}) if arg.HasMemberID != uuid.Nil { for _, member := range q.groupMembers { if member.UserID == arg.HasMemberID { - groupIDs[member.GroupID] = struct{}{} + userGroupIDs[member.GroupID] = struct{}{} } } // Handle the everyone group for _, orgMember := range q.organizationMembers { if orgMember.UserID == arg.HasMemberID { - groupIDs[orgMember.OrganizationID] = struct{}{} + userGroupIDs[orgMember.OrganizationID] = struct{}{} } } } @@ -2718,11 +2718,15 @@ func (q *FakeQuerier) GetGroups(_ context.Context, arg database.GetGroupsParams) continue } - _, ok := groupIDs[group.ID] + _, ok := userGroupIDs[group.ID] if arg.HasMemberID != uuid.Nil && !ok { continue } + if len(arg.GroupNames) > 0 && !slices.Contains(arg.GroupNames, group.Name) { + continue + } + orgDetails, ok := orgDetailsCache[group.ID] if !ok { for _, org := range q.organizations { @@ -7015,7 +7019,37 @@ func (q *FakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParam return user, nil } +func (q *FakeQuerier) InsertUserGroupsByID(_ context.Context, arg database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) { + err := validateDatabaseType(arg) + if err != nil { + return nil, err + } + + q.mutex.Lock() + defer q.mutex.Unlock() + + var groupIDs []uuid.UUID + for _, group := range q.groups { + for _, groupID := range arg.GroupIds { + if group.ID == groupID { + q.groupMembers = append(q.groupMembers, database.GroupMemberTable{ + UserID: arg.UserID, + GroupID: groupID, + }) + groupIDs = append(groupIDs, group.ID) + } + } + } + + return groupIDs, nil +} + func (q *FakeQuerier) InsertUserGroupsByName(_ context.Context, arg database.InsertUserGroupsByNameParams) error { + err := validateDatabaseType(arg) + if err != nil { + return err + } + q.mutex.Lock() defer q.mutex.Unlock() @@ -7607,6 +7641,34 @@ func (q *FakeQuerier) RemoveUserFromAllGroups(_ context.Context, userID uuid.UUI return nil } +func (q *FakeQuerier) RemoveUserFromGroups(_ context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) { + err := validateDatabaseType(arg) + if err != nil { + return nil, err + } + + q.mutex.Lock() + defer q.mutex.Unlock() + + removed := make([]uuid.UUID, 0) + q.data.groupMembers = slices.DeleteFunc(q.data.groupMembers, func(groupMember database.GroupMemberTable) bool { + // Delete all group members that match the arguments. + if groupMember.UserID != arg.UserID { + // Not the right user, ignore. + return false + } + + if !slices.Contains(arg.GroupIds, groupMember.GroupID) { + return false + } + + removed = append(removed, groupMember.GroupID) + return true + }) + + return removed, nil +} + func (q *FakeQuerier) RevokeDBCryptKey(_ context.Context, activeKeyDigest string) error { q.mutex.Lock() defer q.mutex.Unlock() diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index 5aa3a0c8d8cfb..0ec70c1736d43 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -1789,6 +1789,13 @@ func (m metricsStore) InsertUser(ctx context.Context, arg database.InsertUserPar return user, err } +func (m metricsStore) InsertUserGroupsByID(ctx context.Context, arg database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) { + start := time.Now() + r0, r1 := m.s.InsertUserGroupsByID(ctx, arg) + m.queryLatencies.WithLabelValues("InsertUserGroupsByID").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) InsertUserGroupsByName(ctx context.Context, arg database.InsertUserGroupsByNameParams) error { start := time.Now() err := m.s.InsertUserGroupsByName(ctx, arg) @@ -1943,6 +1950,13 @@ func (m metricsStore) RemoveUserFromAllGroups(ctx context.Context, userID uuid.U return r0 } +func (m metricsStore) RemoveUserFromGroups(ctx context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) { + start := time.Now() + r0, r1 := m.s.RemoveUserFromGroups(ctx, arg) + m.queryLatencies.WithLabelValues("RemoveUserFromGroups").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error { start := time.Now() r0 := m.s.RevokeDBCryptKey(ctx, activeKeyDigest) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 6d881cfe6fc1b..c5d579e1c2656 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -3766,6 +3766,21 @@ func (mr *MockStoreMockRecorder) InsertUser(arg0, arg1 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertUser", reflect.TypeOf((*MockStore)(nil).InsertUser), arg0, arg1) } +// InsertUserGroupsByID mocks base method. +func (m *MockStore) InsertUserGroupsByID(arg0 context.Context, arg1 database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InsertUserGroupsByID", arg0, arg1) + ret0, _ := ret[0].([]uuid.UUID) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// InsertUserGroupsByID indicates an expected call of InsertUserGroupsByID. +func (mr *MockStoreMockRecorder) InsertUserGroupsByID(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertUserGroupsByID", reflect.TypeOf((*MockStore)(nil).InsertUserGroupsByID), arg0, arg1) +} + // InsertUserGroupsByName mocks base method. func (m *MockStore) InsertUserGroupsByName(arg0 context.Context, arg1 database.InsertUserGroupsByNameParams) error { m.ctrl.T.Helper() @@ -4103,6 +4118,21 @@ func (mr *MockStoreMockRecorder) RemoveUserFromAllGroups(arg0, arg1 any) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveUserFromAllGroups", reflect.TypeOf((*MockStore)(nil).RemoveUserFromAllGroups), arg0, arg1) } +// RemoveUserFromGroups mocks base method. +func (m *MockStore) RemoveUserFromGroups(arg0 context.Context, arg1 database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoveUserFromGroups", arg0, arg1) + ret0, _ := ret[0].([]uuid.UUID) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RemoveUserFromGroups indicates an expected call of RemoveUserFromGroups. +func (mr *MockStoreMockRecorder) RemoveUserFromGroups(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveUserFromGroups", reflect.TypeOf((*MockStore)(nil).RemoveUserFromGroups), arg0, arg1) +} + // RevokeDBCryptKey mocks base method. func (m *MockStore) RevokeDBCryptKey(arg0 context.Context, arg1 string) error { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 3432bac7dada1..ee9a64f12076d 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -369,6 +369,9 @@ type sqlcQuerier interface { InsertTemplateVersionVariable(ctx context.Context, arg InsertTemplateVersionVariableParams) (TemplateVersionVariable, error) InsertTemplateVersionWorkspaceTag(ctx context.Context, arg InsertTemplateVersionWorkspaceTagParams) (TemplateVersionWorkspaceTag, error) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) + // InsertUserGroupsByID adds a user to all provided groups, if they exist. + // If there is a conflict, the user is already a member + InsertUserGroupsByID(ctx context.Context, arg InsertUserGroupsByIDParams) ([]uuid.UUID, error) // InsertUserGroupsByName adds a user to all provided groups, if they exist. InsertUserGroupsByName(ctx context.Context, arg InsertUserGroupsByNameParams) error InsertUserLink(ctx context.Context, arg InsertUserLinkParams) (UserLink, error) @@ -396,6 +399,7 @@ type sqlcQuerier interface { ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error + RemoveUserFromGroups(ctx context.Context, arg RemoveUserFromGroupsParams) ([]uuid.UUID, error) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error // Non blocking lock. Returns true if the lock was acquired, false otherwise. // diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 89822a72a7855..3616fcb66d3fb 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1446,6 +1446,56 @@ func (q *sqlQuerier) InsertGroupMember(ctx context.Context, arg InsertGroupMembe return err } +const insertUserGroupsByID = `-- name: InsertUserGroupsByID :many +WITH groups AS ( + SELECT + id + FROM + groups + WHERE + groups.id = ANY($2 :: uuid []) +) +INSERT INTO + group_members (user_id, group_id) +SELECT + $1, + groups.id +FROM + groups +ON CONFLICT DO NOTHING +RETURNING group_id +` + +type InsertUserGroupsByIDParams struct { + UserID uuid.UUID `db:"user_id" json:"user_id"` + GroupIds []uuid.UUID `db:"group_ids" json:"group_ids"` +} + +// InsertUserGroupsByID adds a user to all provided groups, if they exist. +// If there is a conflict, the user is already a member +func (q *sqlQuerier) InsertUserGroupsByID(ctx context.Context, arg InsertUserGroupsByIDParams) ([]uuid.UUID, error) { + rows, err := q.db.QueryContext(ctx, insertUserGroupsByID, arg.UserID, pq.Array(arg.GroupIds)) + if err != nil { + return nil, err + } + defer rows.Close() + var items []uuid.UUID + for rows.Next() { + var group_id uuid.UUID + if err := rows.Scan(&group_id); err != nil { + return nil, err + } + items = append(items, group_id) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const insertUserGroupsByName = `-- name: InsertUserGroupsByName :exec WITH groups AS ( SELECT @@ -1489,6 +1539,43 @@ func (q *sqlQuerier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UU return err } +const removeUserFromGroups = `-- name: RemoveUserFromGroups :many +DELETE FROM + group_members +WHERE + user_id = $1 AND + group_id = ANY($2 :: uuid []) +RETURNING group_id +` + +type RemoveUserFromGroupsParams struct { + UserID uuid.UUID `db:"user_id" json:"user_id"` + GroupIds []uuid.UUID `db:"group_ids" json:"group_ids"` +} + +func (q *sqlQuerier) RemoveUserFromGroups(ctx context.Context, arg RemoveUserFromGroupsParams) ([]uuid.UUID, error) { + rows, err := q.db.QueryContext(ctx, removeUserFromGroups, arg.UserID, pq.Array(arg.GroupIds)) + if err != nil { + return nil, err + } + defer rows.Close() + var items []uuid.UUID + for rows.Next() { + var group_id uuid.UUID + if err := rows.Scan(&group_id); err != nil { + return nil, err + } + items = append(items, group_id) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const deleteGroupByID = `-- name: DeleteGroupByID :exec DELETE FROM groups @@ -1592,11 +1679,16 @@ WHERE ) ELSE true END + AND CASE WHEN array_length($3 :: text[], 1) > 0 THEN + groups.name = ANY($3) + ELSE true + END ` type GetGroupsParams struct { OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` HasMemberID uuid.UUID `db:"has_member_id" json:"has_member_id"` + GroupNames []string `db:"group_names" json:"group_names"` } type GetGroupsRow struct { @@ -1606,7 +1698,7 @@ type GetGroupsRow struct { } func (q *sqlQuerier) GetGroups(ctx context.Context, arg GetGroupsParams) ([]GetGroupsRow, error) { - rows, err := q.db.QueryContext(ctx, getGroups, arg.OrganizationID, arg.HasMemberID) + rows, err := q.db.QueryContext(ctx, getGroups, arg.OrganizationID, arg.HasMemberID, pq.Array(arg.GroupNames)) if err != nil { return nil, err } diff --git a/coderd/database/queries/groupmembers.sql b/coderd/database/queries/groupmembers.sql index 0ef2c72323cc9..4efe9bf488590 100644 --- a/coderd/database/queries/groupmembers.sql +++ b/coderd/database/queries/groupmembers.sql @@ -29,12 +29,41 @@ SELECT FROM groups; +-- InsertUserGroupsByID adds a user to all provided groups, if they exist. +-- name: InsertUserGroupsByID :many +WITH groups AS ( + SELECT + id + FROM + groups + WHERE + groups.id = ANY(@group_ids :: uuid []) +) +INSERT INTO + group_members (user_id, group_id) +SELECT + @user_id, + groups.id +FROM + groups +-- If there is a conflict, the user is already a member +ON CONFLICT DO NOTHING +RETURNING group_id; + -- name: RemoveUserFromAllGroups :exec DELETE FROM group_members WHERE user_id = @user_id; +-- name: RemoveUserFromGroups :many +DELETE FROM + group_members +WHERE + user_id = @user_id AND + group_id = ANY(@group_ids :: uuid []) +RETURNING group_id; + -- name: InsertGroupMember :exec INSERT INTO group_members (user_id, group_id) diff --git a/coderd/database/queries/groups.sql b/coderd/database/queries/groups.sql index 1752ccd112ea7..780c0d0154740 100644 --- a/coderd/database/queries/groups.sql +++ b/coderd/database/queries/groups.sql @@ -52,6 +52,10 @@ WHERE ) ELSE true END + AND CASE WHEN array_length(@group_names :: text[], 1) > 0 THEN + groups.name = ANY(@group_names) + ELSE true + END ; -- name: InsertGroup :one diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go new file mode 100644 index 0000000000000..1b6b8f76dc685 --- /dev/null +++ b/coderd/idpsync/group.go @@ -0,0 +1,416 @@ +package idpsync + +import ( + "context" + "encoding/json" + "fmt" + "regexp" + + "github.com/golang-jwt/jwt/v4" + "github.com/google/uuid" + "golang.org/x/xerrors" + + "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/db2sdk" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/runtimeconfig" + "github.com/coder/coder/v2/coderd/util/slice" +) + +type GroupParams struct { + // SyncEnabled if false will skip syncing the user's groups + SyncEnabled bool + MergedClaims jwt.MapClaims +} + +func (AGPLIDPSync) GroupSyncEnabled() bool { + // AGPL does not support syncing groups. + return false +} + +func (s AGPLIDPSync) GroupSyncSettings() runtimeconfig.RuntimeEntry[*GroupSyncSettings] { + return s.Group +} + +func (s AGPLIDPSync) ParseGroupClaims(_ context.Context, _ jwt.MapClaims) (GroupParams, *HTTPError) { + return GroupParams{ + SyncEnabled: s.GroupSyncEnabled(), + }, nil +} + +func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user database.User, params GroupParams) error { + // Nothing happens if sync is not enabled + if !params.SyncEnabled { + return nil + } + + // nolint:gocritic // all syncing is done as a system user + ctx = dbauthz.AsSystemRestricted(ctx) + + // Only care about the default org for deployment settings if the + // legacy deployment settings exist. + defaultOrgID := uuid.Nil + // Default organization is configured via legacy deployment values + if s.DeploymentSyncSettings.Legacy.GroupField != "" { + defaultOrganization, err := db.GetDefaultOrganization(ctx) + if err != nil { + return xerrors.Errorf("get default organization: %w", err) + } + defaultOrgID = defaultOrganization.ID + } + + err := db.InTx(func(tx database.Store) error { + userGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{ + HasMemberID: user.ID, + }) + if err != nil { + return xerrors.Errorf("get user groups: %w", err) + } + + // Figure out which organizations the user is a member of. + // The "Everyone" group is always included, so we can infer organization + // membership via the groups the user is in. + userOrgs := make(map[uuid.UUID][]database.GetGroupsRow) + for _, g := range userGroups { + g := g + userOrgs[g.Group.OrganizationID] = append(userOrgs[g.Group.OrganizationID], g) + } + + // For each org, we need to fetch the sync settings + // This loop also handles any legacy settings for the default + // organization. + orgSettings := make(map[uuid.UUID]GroupSyncSettings) + for orgID := range userOrgs { + orgResolver := s.Manager.OrganizationResolver(tx, orgID) + settings, err := s.SyncSettings.Group.Resolve(ctx, orgResolver) + if err != nil { + if !xerrors.Is(err, runtimeconfig.ErrEntryNotFound) { + return xerrors.Errorf("resolve group sync settings: %w", err) + } + // Default to not being configured + settings = &GroupSyncSettings{} + } + + // Legacy deployment settings will override empty settings. + if orgID == defaultOrgID && settings.Field == "" { + settings = &GroupSyncSettings{ + Field: s.Legacy.GroupField, + LegacyNameMapping: s.Legacy.GroupMapping, + RegexFilter: s.Legacy.GroupFilter, + AutoCreateMissing: s.Legacy.CreateMissingGroups, + } + } + orgSettings[orgID] = *settings + } + + // groupIDsToAdd & groupIDsToRemove are the final group differences + // needed to be applied to user. The loop below will iterate over all + // organizations the user is in, and determine the diffs. + // The diffs are applied as a batch sql query, rather than each + // organization having to execute a query. + groupIDsToAdd := make([]uuid.UUID, 0) + groupIDsToRemove := make([]uuid.UUID, 0) + // For each org, determine which groups the user should land in + for orgID, settings := range orgSettings { + if settings.Field == "" { + // No group sync enabled for this org, so do nothing. + // The user can remain in their groups for this org. + continue + } + + // expectedGroups is the set of groups the IDP expects the + // user to be a member of. + expectedGroups, err := settings.ParseClaims(orgID, params.MergedClaims) + if err != nil { + s.Logger.Debug(ctx, "failed to parse claims for groups", + slog.F("organization_field", s.GroupField), + slog.F("organization_id", orgID), + slog.Error(err), + ) + // Unsure where to raise this error on the UI or database. + // TODO: This error prevents group sync, but we have no way + // to raise this to an org admin. Come up with a solution to + // notify the admin and user of this issue. + continue + } + // Everyone group is always implied, so include it. + expectedGroups = append(expectedGroups, ExpectedGroup{ + OrganizationID: orgID, + GroupID: &orgID, + }) + + // Now we know what groups the user should be in for a given org, + // determine if we have to do any group updates to sync the user's + // state. + existingGroups := userOrgs[orgID] + existingGroupsTyped := db2sdk.List(existingGroups, func(f database.GetGroupsRow) ExpectedGroup { + return ExpectedGroup{ + OrganizationID: orgID, + GroupID: &f.Group.ID, + GroupName: &f.Group.Name, + } + }) + + add, remove := slice.SymmetricDifferenceFunc(existingGroupsTyped, expectedGroups, func(a, b ExpectedGroup) bool { + return a.Equal(b) + }) + + for _, r := range remove { + if r.GroupID == nil { + // This should never happen. All group removals come from the + // existing set, which come from the db. All groups from the + // database have IDs. This code is purely defensive. + detail := "user:" + user.Username + if r.GroupName != nil { + detail += fmt.Sprintf(" from group %s", *r.GroupName) + } + return xerrors.Errorf("removal group has nil ID, which should never happen: %s", detail) + } + groupIDsToRemove = append(groupIDsToRemove, *r.GroupID) + } + + // HandleMissingGroups will add the new groups to the org if + // the settings specify. It will convert all group names into uuids + // for easier assignment. + // TODO: This code should be batched at the end of the for loop. + // Optimizing this is being pushed because if AutoCreate is disabled, + // this code will only add cost on the first login for each user. + // AutoCreate is usually disabled for large deployments. + // For small deployments, this is less of a problem. + assignGroups, err := settings.HandleMissingGroups(ctx, tx, orgID, add) + if err != nil { + return xerrors.Errorf("handle missing groups: %w", err) + } + + groupIDsToAdd = append(groupIDsToAdd, assignGroups...) + } + + // ApplyGroupDifference will take the total adds and removes, and apply + // them. + err = s.ApplyGroupDifference(ctx, tx, user, groupIDsToAdd, groupIDsToRemove) + if err != nil { + return xerrors.Errorf("apply group difference: %w", err) + } + + return nil + }, nil) + if err != nil { + return err + } + + return nil +} + +// ApplyGroupDifference will add and remove the user from the specified groups. +func (s AGPLIDPSync) ApplyGroupDifference(ctx context.Context, tx database.Store, user database.User, add []uuid.UUID, removeIDs []uuid.UUID) error { + if len(removeIDs) > 0 { + removedGroupIDs, err := tx.RemoveUserFromGroups(ctx, database.RemoveUserFromGroupsParams{ + UserID: user.ID, + GroupIds: removeIDs, + }) + if err != nil { + return xerrors.Errorf("remove user from %d groups: %w", len(removeIDs), err) + } + if len(removedGroupIDs) != len(removeIDs) { + s.Logger.Debug(ctx, "user not removed from expected number of groups", + slog.F("user_id", user.ID), + slog.F("groups_removed_count", len(removedGroupIDs)), + slog.F("expected_count", len(removeIDs)), + ) + } + } + + if len(add) > 0 { + add = slice.Unique(add) + // Defensive programming to only insert uniques. + assignedGroupIDs, err := tx.InsertUserGroupsByID(ctx, database.InsertUserGroupsByIDParams{ + UserID: user.ID, + GroupIds: add, + }) + if err != nil { + return xerrors.Errorf("insert user into %d groups: %w", len(add), err) + } + if len(assignedGroupIDs) != len(add) { + s.Logger.Debug(ctx, "user not assigned to expected number of groups", + slog.F("user_id", user.ID), + slog.F("groups_assigned_count", len(assignedGroupIDs)), + slog.F("expected_count", len(add)), + ) + } + } + + return nil +} + +type GroupSyncSettings struct { + // Field selects the claim field to be used as the created user's + // groups. If the group field is the empty string, then no group updates + // will ever come from the OIDC provider. + Field string `json:"field"` + // Mapping maps from an OIDC group --> Coder group ID + Mapping map[string][]uuid.UUID `json:"mapping"` + // RegexFilter is a regular expression that filters the groups returned by + // the OIDC provider. Any group not matched by this regex will be ignored. + // If the group filter is nil, then no group filtering will occur. + RegexFilter *regexp.Regexp `json:"regex_filter"` + // AutoCreateMissing controls whether groups returned by the OIDC provider + // are automatically created in Coder if they are missing. + AutoCreateMissing bool `json:"auto_create_missing_groups"` + // LegacyNameMapping is deprecated. It remaps an IDP group name to + // a Coder group name. Since configuration is now done at runtime, + // group IDs are used to account for group renames. + // For legacy configurations, this config option has to remain. + // Deprecated: Use Mapping instead. + LegacyNameMapping map[string]string `json:"legacy_group_name_mapping,omitempty"` +} + +func (s *GroupSyncSettings) Set(v string) error { + return json.Unmarshal([]byte(v), s) +} + +func (s *GroupSyncSettings) String() string { + return runtimeconfig.JSONString(s) +} + +type ExpectedGroup struct { + OrganizationID uuid.UUID + GroupID *uuid.UUID + GroupName *string +} + +// Equal compares two ExpectedGroups. The org id must be the same. +// If the group ID is set, it will be compared and take priority, ignoring the +// name value. So 2 groups with the same ID but different names will be +// considered equal. +func (a ExpectedGroup) Equal(b ExpectedGroup) bool { + // Must match + if a.OrganizationID != b.OrganizationID { + return false + } + // Only the name or the name needs to be checked, priority is given to the ID. + if a.GroupID != nil && b.GroupID != nil { + return *a.GroupID == *b.GroupID + } + if a.GroupName != nil && b.GroupName != nil { + return *a.GroupName == *b.GroupName + } + + // If everything is nil, it is equal. Although a bit pointless + if a.GroupID == nil && b.GroupID == nil && + a.GroupName == nil && b.GroupName == nil { + return true + } + return false +} + +// ParseClaims will take the merged claims from the IDP and return the groups +// the user is expected to be a member of. The expected group can either be a +// name or an ID. +// It is unfortunate we cannot use exclusively names or exclusively IDs. +// When configuring though, if a group is mapped from "A" -> "UUID 1234", and +// the group "UUID 1234" is renamed, we want to maintain the mapping. +// We have to keep names because group sync supports syncing groups by name if +// the external IDP group name matches the Coder one. +func (s GroupSyncSettings) ParseClaims(orgID uuid.UUID, mergedClaims jwt.MapClaims) ([]ExpectedGroup, error) { + groupsRaw, ok := mergedClaims[s.Field] + if !ok { + return []ExpectedGroup{}, nil + } + + parsedGroups, err := ParseStringSliceClaim(groupsRaw) + if err != nil { + return nil, xerrors.Errorf("parse groups field, unexpected type %T: %w", groupsRaw, err) + } + + groups := make([]ExpectedGroup, 0) + for _, group := range parsedGroups { + group := group + + // Legacy group mappings happen before the regex filter. + mappedGroupName, ok := s.LegacyNameMapping[group] + if ok { + group = mappedGroupName + } + + // Only allow through groups that pass the regex + if s.RegexFilter != nil { + if !s.RegexFilter.MatchString(group) { + continue + } + } + + mappedGroupIDs, ok := s.Mapping[group] + if ok { + for _, gid := range mappedGroupIDs { + gid := gid + groups = append(groups, ExpectedGroup{OrganizationID: orgID, GroupID: &gid}) + } + continue + } + + groups = append(groups, ExpectedGroup{OrganizationID: orgID, GroupName: &group}) + } + + return groups, nil +} + +// HandleMissingGroups ensures all ExpectedGroups convert to uuids. +// Groups can be referenced by name via legacy params or IDP group names. +// These group names are converted to IDs for easier assignment. +// Missing groups are created if AutoCreate is enabled. +// TODO: Batching this would be better, as this is 1 or 2 db calls per organization. +func (s GroupSyncSettings) HandleMissingGroups(ctx context.Context, tx database.Store, orgID uuid.UUID, add []ExpectedGroup) ([]uuid.UUID, error) { + // All expected that are missing IDs means the group does not exist + // in the database, or it is a legacy mapping, and we need to do a lookup. + var missingGroups []string + addIDs := make([]uuid.UUID, 0) + + for _, expected := range add { + if expected.GroupID == nil && expected.GroupName != nil { + missingGroups = append(missingGroups, *expected.GroupName) + } else if expected.GroupID != nil { + // Keep the IDs to sync the groups. + addIDs = append(addIDs, *expected.GroupID) + } + } + + if s.AutoCreateMissing && len(missingGroups) > 0 { + // Insert any missing groups. If the groups already exist, this is a noop. + _, err := tx.InsertMissingGroups(ctx, database.InsertMissingGroupsParams{ + OrganizationID: orgID, + Source: database.GroupSourceOidc, + GroupNames: missingGroups, + }) + if err != nil { + return nil, xerrors.Errorf("insert missing groups: %w", err) + } + } + + // Fetch any missing groups by name. If they exist, their IDs will be + // matched and returned. + if len(missingGroups) > 0 { + // Do name lookups for all groups that are missing IDs. + newGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{ + OrganizationID: orgID, + HasMemberID: uuid.UUID{}, + GroupNames: missingGroups, + }) + if err != nil { + return nil, xerrors.Errorf("get groups by names: %w", err) + } + for _, g := range newGroups { + addIDs = append(addIDs, g.Group.ID) + } + } + + return addIDs, nil +} + +func ConvertAllowList(allowList []string) map[string]struct{} { + allowMap := make(map[string]struct{}, len(allowList)) + for _, group := range allowList { + allowMap[group] = struct{}{} + } + return allowMap +} diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go new file mode 100644 index 0000000000000..cf312a576d720 --- /dev/null +++ b/coderd/idpsync/group_test.go @@ -0,0 +1,814 @@ +package idpsync_test + +import ( + "context" + "database/sql" + "regexp" + "testing" + + "github.com/golang-jwt/jwt/v4" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "golang.org/x/exp/slices" + "golang.org/x/xerrors" + + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/db2sdk" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/idpsync" + "github.com/coder/coder/v2/coderd/runtimeconfig" + "github.com/coder/coder/v2/coderd/util/ptr" + "github.com/coder/coder/v2/testutil" +) + +func TestParseGroupClaims(t *testing.T) { + t.Parallel() + + t.Run("EmptyConfig", func(t *testing.T) { + t.Parallel() + + s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), + runtimeconfig.NewManager(), + idpsync.DeploymentSyncSettings{}) + + ctx := testutil.Context(t, testutil.WaitMedium) + + params, err := s.ParseGroupClaims(ctx, jwt.MapClaims{}) + require.Nil(t, err) + + require.False(t, params.SyncEnabled) + }) + + // AllowList has no effect in AGPL + t.Run("AllowList", func(t *testing.T) { + t.Parallel() + + s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), + runtimeconfig.NewManager(), + idpsync.DeploymentSyncSettings{ + GroupField: "groups", + GroupAllowList: map[string]struct{}{ + "foo": {}, + }, + }) + + ctx := testutil.Context(t, testutil.WaitMedium) + + params, err := s.ParseGroupClaims(ctx, jwt.MapClaims{}) + require.Nil(t, err) + require.False(t, params.SyncEnabled) + }) +} + +func TestGroupSyncTable(t *testing.T) { + t.Parallel() + + // Last checked, takes 30s with postgres on a fast machine. + if dbtestutil.WillUsePostgres() { + t.Skip("Skipping test because it populates a lot of db entries, which is slow on postgres.") + } + + userClaims := jwt.MapClaims{ + "groups": []string{ + "foo", "bar", "baz", + "create-bar", "create-baz", + "legacy-bar", + }, + } + + ids := coderdtest.NewDeterministicUUIDGenerator() + testCases := []orgSetupDefinition{ + { + Name: "SwitchGroups", + Settings: &idpsync.GroupSyncSettings{ + Field: "groups", + Mapping: map[string][]uuid.UUID{ + "foo": {ids.ID("sg-foo"), ids.ID("sg-foo-2")}, + "bar": {ids.ID("sg-bar")}, + "baz": {ids.ID("sg-baz")}, + }, + }, + Groups: map[uuid.UUID]bool{ + uuid.New(): true, + uuid.New(): true, + // Extra groups + ids.ID("sg-foo"): false, + ids.ID("sg-foo-2"): false, + ids.ID("sg-bar"): false, + ids.ID("sg-baz"): false, + }, + ExpectedGroups: []uuid.UUID{ + ids.ID("sg-foo"), + ids.ID("sg-foo-2"), + ids.ID("sg-bar"), + ids.ID("sg-baz"), + }, + }, + { + Name: "StayInGroup", + Settings: &idpsync.GroupSyncSettings{ + Field: "groups", + // Only match foo, so bar does not map + RegexFilter: regexp.MustCompile("^foo$"), + Mapping: map[string][]uuid.UUID{ + "foo": {ids.ID("gg-foo"), uuid.New()}, + "bar": {ids.ID("gg-bar")}, + "baz": {ids.ID("gg-baz")}, + }, + }, + Groups: map[uuid.UUID]bool{ + ids.ID("gg-foo"): true, + ids.ID("gg-bar"): false, + }, + ExpectedGroups: []uuid.UUID{ + ids.ID("gg-foo"), + }, + }, + { + Name: "UserJoinsGroups", + Settings: &idpsync.GroupSyncSettings{ + Field: "groups", + Mapping: map[string][]uuid.UUID{ + "foo": {ids.ID("ng-foo"), uuid.New()}, + "bar": {ids.ID("ng-bar"), ids.ID("ng-bar-2")}, + "baz": {ids.ID("ng-baz")}, + }, + }, + Groups: map[uuid.UUID]bool{ + ids.ID("ng-foo"): false, + ids.ID("ng-bar"): false, + ids.ID("ng-bar-2"): false, + ids.ID("ng-baz"): false, + }, + ExpectedGroups: []uuid.UUID{ + ids.ID("ng-foo"), + ids.ID("ng-bar"), + ids.ID("ng-bar-2"), + ids.ID("ng-baz"), + }, + }, + { + Name: "CreateGroups", + Settings: &idpsync.GroupSyncSettings{ + Field: "groups", + RegexFilter: regexp.MustCompile("^create"), + AutoCreateMissing: true, + }, + Groups: map[uuid.UUID]bool{}, + ExpectedGroupNames: []string{ + "create-bar", + "create-baz", + }, + }, + { + Name: "GroupNamesNoMapping", + Settings: &idpsync.GroupSyncSettings{ + Field: "groups", + RegexFilter: regexp.MustCompile(".*"), + AutoCreateMissing: false, + }, + GroupNames: map[string]bool{ + "foo": false, + "bar": false, + "goob": true, + }, + ExpectedGroupNames: []string{ + "foo", + "bar", + }, + }, + { + Name: "NoUser", + Settings: &idpsync.GroupSyncSettings{ + Field: "groups", + Mapping: map[string][]uuid.UUID{ + // Extra ID that does not map to a group + "foo": {ids.ID("ow-foo"), uuid.New()}, + }, + RegexFilter: nil, + AutoCreateMissing: false, + }, + NotMember: true, + Groups: map[uuid.UUID]bool{ + ids.ID("ow-foo"): false, + ids.ID("ow-bar"): false, + }, + }, + { + Name: "NoSettingsNoUser", + Settings: nil, + Groups: map[uuid.UUID]bool{}, + }, + { + Name: "LegacyMapping", + Settings: &idpsync.GroupSyncSettings{ + Field: "groups", + RegexFilter: regexp.MustCompile("^legacy"), + LegacyNameMapping: map[string]string{ + "create-bar": "legacy-bar", + "foo": "legacy-foo", + "bop": "legacy-bop", + }, + AutoCreateMissing: true, + }, + Groups: map[uuid.UUID]bool{ + ids.ID("lg-foo"): true, + }, + GroupNames: map[string]bool{ + "legacy-foo": false, + "extra": true, + "legacy-bop": true, + }, + ExpectedGroupNames: []string{ + "legacy-bar", + "legacy-foo", + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.Name, func(t *testing.T) { + t.Parallel() + + db, _ := dbtestutil.NewDB(t) + manager := runtimeconfig.NewManager() + s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), + manager, + idpsync.DeploymentSyncSettings{ + GroupField: "groups", + }, + ) + + ctx := testutil.Context(t, testutil.WaitSuperLong) + user := dbgen.User(t, db, database.User{}) + orgID := uuid.New() + SetupOrganization(t, s, db, user, orgID, tc) + + // Do the group sync! + err := s.SyncGroups(ctx, db, user, idpsync.GroupParams{ + SyncEnabled: true, + MergedClaims: userClaims, + }) + require.NoError(t, err) + + tc.Assert(t, orgID, db, user) + }) + } + + // AllTogether runs the entire tabled test as a singular user and + // deployment. This tests all organizations being synced together. + // The reason we do them individually, is that it is much easier to + // debug a single test case. + t.Run("AllTogether", func(t *testing.T) { + t.Parallel() + + db, _ := dbtestutil.NewDB(t) + manager := runtimeconfig.NewManager() + s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), + manager, + // Also sync the default org! + idpsync.DeploymentSyncSettings{ + GroupField: "groups", + Legacy: idpsync.DefaultOrgLegacySettings{ + GroupField: "groups", + GroupMapping: map[string]string{ + "foo": "legacy-foo", + "baz": "legacy-baz", + }, + GroupFilter: regexp.MustCompile("^legacy"), + CreateMissingGroups: true, + }, + }, + ) + + ctx := testutil.Context(t, testutil.WaitSuperLong) + user := dbgen.User(t, db, database.User{}) + + var asserts []func(t *testing.T) + // The default org is also going to do something + def := orgSetupDefinition{ + Name: "DefaultOrg", + GroupNames: map[string]bool{ + "legacy-foo": false, + "legacy-baz": true, + "random": true, + }, + // No settings, because they come from the deployment values + Settings: nil, + ExpectedGroups: nil, + ExpectedGroupNames: []string{"legacy-foo", "legacy-baz", "legacy-bar"}, + } + + //nolint:gocritic // testing + defOrg, err := db.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx)) + require.NoError(t, err) + SetupOrganization(t, s, db, user, defOrg.ID, def) + asserts = append(asserts, func(t *testing.T) { + t.Run(def.Name, func(t *testing.T) { + t.Parallel() + def.Assert(t, defOrg.ID, db, user) + }) + }) + + for _, tc := range testCases { + tc := tc + + orgID := uuid.New() + SetupOrganization(t, s, db, user, orgID, tc) + asserts = append(asserts, func(t *testing.T) { + t.Run(tc.Name, func(t *testing.T) { + t.Parallel() + tc.Assert(t, orgID, db, user) + }) + }) + } + + asserts = append(asserts, func(t *testing.T) { + t.Helper() + def.Assert(t, defOrg.ID, db, user) + }) + + // Do the group sync! + err = s.SyncGroups(ctx, db, user, idpsync.GroupParams{ + SyncEnabled: true, + MergedClaims: userClaims, + }) + require.NoError(t, err) + + for _, assert := range asserts { + assert(t) + } + }) +} + +func TestSyncDisabled(t *testing.T) { + t.Parallel() + + if dbtestutil.WillUsePostgres() { + t.Skip("Skipping test because it populates a lot of db entries, which is slow on postgres.") + } + + db, _ := dbtestutil.NewDB(t) + manager := runtimeconfig.NewManager() + s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), + manager, + idpsync.DeploymentSyncSettings{}, + ) + + ids := coderdtest.NewDeterministicUUIDGenerator() + ctx := testutil.Context(t, testutil.WaitSuperLong) + user := dbgen.User(t, db, database.User{}) + orgID := uuid.New() + + def := orgSetupDefinition{ + Name: "SyncDisabled", + Groups: map[uuid.UUID]bool{ + ids.ID("foo"): true, + ids.ID("bar"): true, + ids.ID("baz"): false, + ids.ID("bop"): false, + }, + Settings: &idpsync.GroupSyncSettings{ + Field: "groups", + Mapping: map[string][]uuid.UUID{ + "foo": {ids.ID("foo")}, + "baz": {ids.ID("baz")}, + }, + }, + ExpectedGroups: []uuid.UUID{ + ids.ID("foo"), + ids.ID("bar"), + }, + } + + SetupOrganization(t, s, db, user, orgID, def) + + // Do the group sync! + err := s.SyncGroups(ctx, db, user, idpsync.GroupParams{ + SyncEnabled: false, + MergedClaims: jwt.MapClaims{ + "groups": []string{"baz", "bop"}, + }, + }) + require.NoError(t, err) + + def.Assert(t, orgID, db, user) +} + +// TestApplyGroupDifference is mainly testing the database functions +func TestApplyGroupDifference(t *testing.T) { + t.Parallel() + + ids := coderdtest.NewDeterministicUUIDGenerator() + testCase := []struct { + Name string + Before map[uuid.UUID]bool + Add []uuid.UUID + Remove []uuid.UUID + Expect []uuid.UUID + }{ + { + Name: "Empty", + }, + { + Name: "AddFromNone", + Before: map[uuid.UUID]bool{ + ids.ID("g1"): false, + }, + Add: []uuid.UUID{ + ids.ID("g1"), + }, + Expect: []uuid.UUID{ + ids.ID("g1"), + }, + }, + { + Name: "AddSome", + Before: map[uuid.UUID]bool{ + ids.ID("g1"): true, + ids.ID("g2"): false, + ids.ID("g3"): false, + uuid.New(): false, + }, + Add: []uuid.UUID{ + ids.ID("g2"), + ids.ID("g3"), + }, + Expect: []uuid.UUID{ + ids.ID("g1"), + ids.ID("g2"), + ids.ID("g3"), + }, + }, + { + Name: "RemoveAll", + Before: map[uuid.UUID]bool{ + uuid.New(): false, + ids.ID("g2"): true, + ids.ID("g3"): true, + }, + Remove: []uuid.UUID{ + ids.ID("g2"), + ids.ID("g3"), + }, + Expect: []uuid.UUID{}, + }, + { + Name: "Mixed", + Before: map[uuid.UUID]bool{ + // adds + ids.ID("a1"): true, + ids.ID("a2"): true, + ids.ID("a3"): false, + ids.ID("a4"): false, + // removes + ids.ID("r1"): true, + ids.ID("r2"): true, + ids.ID("r3"): false, + ids.ID("r4"): false, + // stable + ids.ID("s1"): true, + ids.ID("s2"): true, + // noise + uuid.New(): false, + uuid.New(): false, + }, + Add: []uuid.UUID{ + ids.ID("a1"), ids.ID("a2"), + ids.ID("a3"), ids.ID("a4"), + // Double up to try and confuse + ids.ID("a1"), + ids.ID("a4"), + }, + Remove: []uuid.UUID{ + ids.ID("r1"), ids.ID("r2"), + ids.ID("r3"), ids.ID("r4"), + // Double up to try and confuse + ids.ID("r1"), + ids.ID("r4"), + }, + Expect: []uuid.UUID{ + ids.ID("a1"), ids.ID("a2"), ids.ID("a3"), ids.ID("a4"), + ids.ID("s1"), ids.ID("s2"), + }, + }, + } + + for _, tc := range testCase { + tc := tc + t.Run(tc.Name, func(t *testing.T) { + t.Parallel() + + mgr := runtimeconfig.NewManager() + db, _ := dbtestutil.NewDB(t) + + ctx := testutil.Context(t, testutil.WaitMedium) + //nolint:gocritic // testing + ctx = dbauthz.AsSystemRestricted(ctx) + + org := dbgen.Organization(t, db, database.Organization{}) + _, err := db.InsertAllUsersGroup(ctx, org.ID) + require.NoError(t, err) + + user := dbgen.User(t, db, database.User{}) + _ = dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + + for gid, in := range tc.Before { + group := dbgen.Group(t, db, database.Group{ + ID: gid, + OrganizationID: org.ID, + }) + if in { + _ = dbgen.GroupMember(t, db, database.GroupMemberTable{ + UserID: user.ID, + GroupID: group.ID, + }) + } + } + + s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), mgr, idpsync.FromDeploymentValues(coderdtest.DeploymentValues(t))) + err = s.ApplyGroupDifference(context.Background(), db, user, tc.Add, tc.Remove) + require.NoError(t, err) + + userGroups, err := db.GetGroups(ctx, database.GetGroupsParams{ + HasMemberID: user.ID, + }) + require.NoError(t, err) + + // assert + found := db2sdk.List(userGroups, func(g database.GetGroupsRow) uuid.UUID { + return g.Group.ID + }) + + // Add everyone group + require.ElementsMatch(t, append(tc.Expect, org.ID), found) + }) + } +} + +func TestExpectedGroupEqual(t *testing.T) { + t.Parallel() + + ids := coderdtest.NewDeterministicUUIDGenerator() + testCases := []struct { + Name string + A idpsync.ExpectedGroup + B idpsync.ExpectedGroup + Equal bool + }{ + { + Name: "Empty", + A: idpsync.ExpectedGroup{}, + B: idpsync.ExpectedGroup{}, + Equal: true, + }, + { + Name: "DifferentOrgs", + A: idpsync.ExpectedGroup{ + OrganizationID: uuid.New(), + GroupID: ptr.Ref(ids.ID("g1")), + GroupName: nil, + }, + B: idpsync.ExpectedGroup{ + OrganizationID: uuid.New(), + GroupID: ptr.Ref(ids.ID("g1")), + GroupName: nil, + }, + Equal: false, + }, + { + Name: "SameID", + A: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: ptr.Ref(ids.ID("g1")), + GroupName: nil, + }, + B: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: ptr.Ref(ids.ID("g1")), + GroupName: nil, + }, + Equal: true, + }, + { + Name: "DifferentIDs", + A: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: ptr.Ref(uuid.New()), + GroupName: nil, + }, + B: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: ptr.Ref(uuid.New()), + GroupName: nil, + }, + Equal: false, + }, + { + Name: "SameName", + A: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: nil, + GroupName: ptr.Ref("foo"), + }, + B: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: nil, + GroupName: ptr.Ref("foo"), + }, + Equal: true, + }, + { + Name: "DifferentName", + A: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: nil, + GroupName: ptr.Ref("foo"), + }, + B: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: nil, + GroupName: ptr.Ref("bar"), + }, + Equal: false, + }, + // Edge cases + { + // A bit strange, but valid as ID takes priority. + // We assume 2 groups with the same ID are equal, even if + // their names are different. Names are mutable, IDs are not, + // so there is 0% chance they are different groups. + Name: "DifferentIDSameName", + A: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: ptr.Ref(ids.ID("g1")), + GroupName: ptr.Ref("foo"), + }, + B: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: ptr.Ref(ids.ID("g1")), + GroupName: ptr.Ref("bar"), + }, + Equal: true, + }, + { + Name: "MixedNils", + A: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: ptr.Ref(ids.ID("g1")), + GroupName: nil, + }, + B: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: nil, + GroupName: ptr.Ref("bar"), + }, + Equal: false, + }, + { + Name: "NoComparable", + A: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: ptr.Ref(ids.ID("g1")), + GroupName: nil, + }, + B: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: nil, + GroupName: nil, + }, + Equal: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.Name, func(t *testing.T) { + t.Parallel() + + require.Equal(t, tc.Equal, tc.A.Equal(tc.B)) + }) + } +} + +func SetupOrganization(t *testing.T, s *idpsync.AGPLIDPSync, db database.Store, user database.User, orgID uuid.UUID, def orgSetupDefinition) { + t.Helper() + + // Account that the org might be the default organization + org, err := db.GetOrganizationByID(context.Background(), orgID) + if xerrors.Is(err, sql.ErrNoRows) { + org = dbgen.Organization(t, db, database.Organization{ + ID: orgID, + }) + } + + _, err = db.InsertAllUsersGroup(context.Background(), org.ID) + if !database.IsUniqueViolation(err) { + require.NoError(t, err, "Everyone group for an org") + } + + manager := runtimeconfig.NewManager() + orgResolver := manager.OrganizationResolver(db, org.ID) + err = s.Group.SetRuntimeValue(context.Background(), orgResolver, def.Settings) + require.NoError(t, err) + + if !def.NotMember { + dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + } + for groupID, in := range def.Groups { + dbgen.Group(t, db, database.Group{ + ID: groupID, + OrganizationID: org.ID, + }) + if in { + dbgen.GroupMember(t, db, database.GroupMemberTable{ + UserID: user.ID, + GroupID: groupID, + }) + } + } + for groupName, in := range def.GroupNames { + group := dbgen.Group(t, db, database.Group{ + Name: groupName, + OrganizationID: org.ID, + }) + if in { + dbgen.GroupMember(t, db, database.GroupMemberTable{ + UserID: user.ID, + GroupID: group.ID, + }) + } + } +} + +type orgSetupDefinition struct { + Name string + // True if the user is a member of the group + Groups map[uuid.UUID]bool + GroupNames map[string]bool + NotMember bool + + Settings *idpsync.GroupSyncSettings + ExpectedGroups []uuid.UUID + ExpectedGroupNames []string +} + +func (o orgSetupDefinition) Assert(t *testing.T, orgID uuid.UUID, db database.Store, user database.User) { + t.Helper() + + ctx := context.Background() + + members, err := db.OrganizationMembers(ctx, database.OrganizationMembersParams{ + OrganizationID: orgID, + UserID: user.ID, + }) + require.NoError(t, err) + if o.NotMember { + require.Len(t, members, 0, "should not be a member") + } else { + require.Len(t, members, 1, "should be a member") + } + + userGroups, err := db.GetGroups(ctx, database.GetGroupsParams{ + OrganizationID: orgID, + HasMemberID: user.ID, + }) + require.NoError(t, err) + if o.ExpectedGroups == nil { + o.ExpectedGroups = make([]uuid.UUID, 0) + } + if len(o.ExpectedGroupNames) > 0 && len(o.ExpectedGroups) > 0 { + t.Fatal("ExpectedGroups and ExpectedGroupNames are mutually exclusive") + } + + // Everyone groups mess up our asserts + userGroups = slices.DeleteFunc(userGroups, func(row database.GetGroupsRow) bool { + return row.Group.ID == row.Group.OrganizationID + }) + + if len(o.ExpectedGroupNames) > 0 { + found := db2sdk.List(userGroups, func(g database.GetGroupsRow) string { + return g.Group.Name + }) + require.ElementsMatch(t, o.ExpectedGroupNames, found, "user groups by name") + require.Len(t, o.ExpectedGroups, 0, "ExpectedGroups should be empty") + } else { + // Check by ID, recommended + found := db2sdk.List(userGroups, func(g database.GetGroupsRow) uuid.UUID { + return g.Group.ID + }) + require.ElementsMatch(t, o.ExpectedGroups, found, "user groups") + require.Len(t, o.ExpectedGroupNames, 0, "ExpectedGroupNames should be empty") + } +} diff --git a/coderd/idpsync/idpsync.go b/coderd/idpsync/idpsync.go index 73a7b9b6f530d..2c8ed10ce9bcc 100644 --- a/coderd/idpsync/idpsync.go +++ b/coderd/idpsync/idpsync.go @@ -3,6 +3,7 @@ package idpsync import ( "context" "net/http" + "regexp" "strings" "github.com/golang-jwt/jwt/v4" @@ -12,6 +13,7 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/coderd/runtimeconfig" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/site" ) @@ -25,21 +27,34 @@ type IDPSync interface { OrganizationSyncEnabled() bool // ParseOrganizationClaims takes claims from an OIDC provider, and returns the // organization sync params for assigning users into organizations. - ParseOrganizationClaims(ctx context.Context, _ jwt.MapClaims) (OrganizationParams, *HTTPError) + ParseOrganizationClaims(ctx context.Context, mergedClaims jwt.MapClaims) (OrganizationParams, *HTTPError) // SyncOrganizations assigns and removed users from organizations based on the // provided params. SyncOrganizations(ctx context.Context, tx database.Store, user database.User, params OrganizationParams) error + + GroupSyncEnabled() bool + // ParseGroupClaims takes claims from an OIDC provider, and returns the params + // for group syncing. Most of the logic happens in SyncGroups. + ParseGroupClaims(ctx context.Context, mergedClaims jwt.MapClaims) (GroupParams, *HTTPError) + // SyncGroups assigns and removes users from groups based on the provided params. + SyncGroups(ctx context.Context, db database.Store, user database.User, params GroupParams) error + // GroupSyncSettings is exposed for the API to implement CRUD operations + // on the settings used by IDPSync. This entry is thread safe and can be + // accessed concurrently. The settings are stored in the database. + GroupSyncSettings() runtimeconfig.RuntimeEntry[*GroupSyncSettings] } // AGPLIDPSync is the configuration for syncing user information from an external // IDP. All related code to syncing user information should be in this package. type AGPLIDPSync struct { - Logger slog.Logger + Logger slog.Logger + Manager *runtimeconfig.Manager SyncSettings } -type SyncSettings struct { +// DeploymentSyncSettings are static and are sourced from the deployment config. +type DeploymentSyncSettings struct { // OrganizationField selects the claim field to be used as the created user's // organizations. If the field is the empty string, then no organization updates // will ever come from the OIDC provider. @@ -50,23 +65,62 @@ type SyncSettings struct { // placed into the default organization. This is mostly a hack to support // legacy deployments. OrganizationAssignDefault bool + + // GroupField at the deployment level is used for deployment level group claim + // settings. + GroupField string + // GroupAllowList (if set) will restrict authentication to only users who + // have at least one group in this list. + // A map representation is used for easier lookup. + GroupAllowList map[string]struct{} + // Legacy deployment settings that only apply to the default org. + Legacy DefaultOrgLegacySettings } -type OrganizationParams struct { - // SyncEnabled if false will skip syncing the user's organizations. - SyncEnabled bool - // IncludeDefault is primarily for single org deployments. It will ensure - // a user is always inserted into the default org. - IncludeDefault bool - // Organizations is the list of organizations the user should be a member of - // assuming syncing is turned on. - Organizations []uuid.UUID +type DefaultOrgLegacySettings struct { + GroupField string + GroupMapping map[string]string + GroupFilter *regexp.Regexp + CreateMissingGroups bool +} + +func FromDeploymentValues(dv *codersdk.DeploymentValues) DeploymentSyncSettings { + if dv == nil { + panic("Developer error: DeploymentValues should not be nil") + } + return DeploymentSyncSettings{ + OrganizationField: dv.OIDC.OrganizationField.Value(), + OrganizationMapping: dv.OIDC.OrganizationMapping.Value, + OrganizationAssignDefault: dv.OIDC.OrganizationAssignDefault.Value(), + + // TODO: Separate group field for allow list from default org. + // Right now you cannot disable group sync from the default org and + // configure an allow list. + GroupField: dv.OIDC.GroupField.Value(), + GroupAllowList: ConvertAllowList(dv.OIDC.GroupAllowList.Value()), + Legacy: DefaultOrgLegacySettings{ + GroupField: dv.OIDC.GroupField.Value(), + GroupMapping: dv.OIDC.GroupMapping.Value, + GroupFilter: dv.OIDC.GroupRegexFilter.Value(), + CreateMissingGroups: dv.OIDC.GroupAutoCreate.Value(), + }, + } +} + +type SyncSettings struct { + DeploymentSyncSettings + + Group runtimeconfig.RuntimeEntry[*GroupSyncSettings] } -func NewAGPLSync(logger slog.Logger, settings SyncSettings) *AGPLIDPSync { +func NewAGPLSync(logger slog.Logger, manager *runtimeconfig.Manager, settings DeploymentSyncSettings) *AGPLIDPSync { return &AGPLIDPSync{ - Logger: logger.Named("idp-sync"), - SyncSettings: settings, + Logger: logger.Named("idp-sync"), + Manager: manager, + SyncSettings: SyncSettings{ + DeploymentSyncSettings: settings, + Group: runtimeconfig.MustNew[*GroupSyncSettings]("group-sync-settings"), + }, } } diff --git a/coderd/idpsync/organization.go b/coderd/idpsync/organization.go index 6d475f28ea0ef..fa091eba441ad 100644 --- a/coderd/idpsync/organization.go +++ b/coderd/idpsync/organization.go @@ -16,6 +16,17 @@ import ( "github.com/coder/coder/v2/coderd/util/slice" ) +type OrganizationParams struct { + // SyncEnabled if false will skip syncing the user's organizations. + SyncEnabled bool + // IncludeDefault is primarily for single org deployments. It will ensure + // a user is always inserted into the default org. + IncludeDefault bool + // Organizations is the list of organizations the user should be a member of + // assuming syncing is turned on. + Organizations []uuid.UUID +} + func (AGPLIDPSync) OrganizationSyncEnabled() bool { // AGPL does not support syncing organizations. return false diff --git a/coderd/idpsync/organizations_test.go b/coderd/idpsync/organizations_test.go index 03b1ebfa4b27b..1670beaaedc75 100644 --- a/coderd/idpsync/organizations_test.go +++ b/coderd/idpsync/organizations_test.go @@ -9,6 +9,7 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/coderd/idpsync" + "github.com/coder/coder/v2/coderd/runtimeconfig" "github.com/coder/coder/v2/testutil" ) @@ -18,11 +19,13 @@ func TestParseOrganizationClaims(t *testing.T) { t.Run("SingleOrgDeployment", func(t *testing.T) { t.Parallel() - s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), idpsync.SyncSettings{ - OrganizationField: "", - OrganizationMapping: nil, - OrganizationAssignDefault: true, - }) + s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), + runtimeconfig.NewManager(), + idpsync.DeploymentSyncSettings{ + OrganizationField: "", + OrganizationMapping: nil, + OrganizationAssignDefault: true, + }) ctx := testutil.Context(t, testutil.WaitMedium) @@ -38,13 +41,15 @@ func TestParseOrganizationClaims(t *testing.T) { t.Parallel() // AGPL has limited behavior - s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), idpsync.SyncSettings{ - OrganizationField: "orgs", - OrganizationMapping: map[string][]uuid.UUID{ - "random": {uuid.New()}, - }, - OrganizationAssignDefault: false, - }) + s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), + runtimeconfig.NewManager(), + idpsync.DeploymentSyncSettings{ + OrganizationField: "orgs", + OrganizationMapping: map[string][]uuid.UUID{ + "random": {uuid.New()}, + }, + OrganizationAssignDefault: false, + }) ctx := testutil.Context(t, testutil.WaitMedium) diff --git a/coderd/runtimeconfig/entry.go b/coderd/runtimeconfig/entry.go index 780138a89d03b..c0260b0268ddb 100644 --- a/coderd/runtimeconfig/entry.go +++ b/coderd/runtimeconfig/entry.go @@ -2,6 +2,7 @@ package runtimeconfig import ( "context" + "encoding/json" "fmt" "golang.org/x/xerrors" @@ -93,3 +94,11 @@ func (e *RuntimeEntry[T]) name() (string, error) { return e.n, nil } + +func JSONString(v any) string { + s, err := json.Marshal(v) + if err != nil { + return "decode failed: " + err.Error() + } + return string(s) +} diff --git a/coderd/userauth.go b/coderd/userauth.go index bb149d9d07379..223f697c09bb9 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -8,7 +8,6 @@ import ( "fmt" "net/http" "net/mail" - "regexp" "sort" "strconv" "strings" @@ -20,7 +19,6 @@ import ( "github.com/google/go-github/v43/github" "github.com/google/uuid" "github.com/moby/moby/pkg/namesgenerator" - "golang.org/x/exp/slices" "golang.org/x/oauth2" "golang.org/x/xerrors" @@ -659,6 +657,9 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { AvatarURL: ghUser.GetAvatarURL(), Name: normName, DebugContext: OauthDebugContext{}, + GroupSync: idpsync.GroupParams{ + SyncEnabled: false, + }, OrganizationSync: idpsync.OrganizationParams{ SyncEnabled: false, IncludeDefault: true, @@ -741,27 +742,6 @@ type OIDCConfig struct { // support the userinfo endpoint, or if the userinfo endpoint causes // undesirable behavior. IgnoreUserInfo bool - - // TODO: Move all idp fields into the IDPSync struct - // GroupField selects the claim field to be used as the created user's - // groups. If the group field is the empty string, then no group updates - // will ever come from the OIDC provider. - GroupField string - // CreateMissingGroups controls whether groups returned by the OIDC provider - // are automatically created in Coder if they are missing. - CreateMissingGroups bool - // GroupFilter is a regular expression that filters the groups returned by - // the OIDC provider. Any group not matched by this regex will be ignored. - // If the group filter is nil, then no group filtering will occur. - GroupFilter *regexp.Regexp - // GroupAllowList is a list of groups that are allowed to log in. - // If the list length is 0, then the allow list will not be applied and - // this feature is disabled. - GroupAllowList map[string]bool - // GroupMapping controls how groups returned by the OIDC provider get mapped - // to groups within Coder. - // map[oidcGroupName]coderGroupName - GroupMapping map[string]string // UserRoleField selects the claim field to be used as the created user's // roles. If the field is the empty string, then no role updates // will ever come from the OIDC provider. @@ -1004,11 +984,6 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { } ctx = slog.With(ctx, slog.F("email", email), slog.F("username", username), slog.F("name", name)) - usingGroups, groups, groupErr := api.oidcGroups(ctx, mergedClaims) - if groupErr != nil { - groupErr.Write(rw, r) - return - } roles, roleErr := api.oidcRoles(ctx, mergedClaims) if roleErr != nil { @@ -1032,6 +1007,12 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { return } + groupSync, groupSyncErr := api.IDPSync.ParseGroupClaims(ctx, mergedClaims) + if groupSyncErr != nil { + groupSyncErr.Write(rw, r) + return + } + // If a new user is authenticating for the first time // the audit action is 'register', not 'login' if user.ID == uuid.Nil { @@ -1039,23 +1020,20 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { } params := (&oauthLoginParams{ - User: user, - Link: link, - State: state, - LinkedID: oidcLinkedID(idToken), - LoginType: database.LoginTypeOIDC, - AllowSignups: api.OIDCConfig.AllowSignups, - Email: email, - Username: username, - Name: name, - AvatarURL: picture, - UsingRoles: api.OIDCConfig.RoleSyncEnabled(), - Roles: roles, - UsingGroups: usingGroups, - Groups: groups, - OrganizationSync: orgSync, - CreateMissingGroups: api.OIDCConfig.CreateMissingGroups, - GroupFilter: api.OIDCConfig.GroupFilter, + User: user, + Link: link, + State: state, + LinkedID: oidcLinkedID(idToken), + LoginType: database.LoginTypeOIDC, + AllowSignups: api.OIDCConfig.AllowSignups, + Email: email, + Username: username, + Name: name, + AvatarURL: picture, + UsingRoles: api.OIDCConfig.RoleSyncEnabled(), + Roles: roles, + OrganizationSync: orgSync, + GroupSync: groupSync, DebugContext: OauthDebugContext{ IDTokenClaims: idtokenClaims, UserInfoClaims: userInfoClaims, @@ -1091,79 +1069,6 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect) } -// oidcGroups returns the groups for the user from the OIDC claims. -func (api *API) oidcGroups(ctx context.Context, mergedClaims map[string]interface{}) (bool, []string, *idpsync.HTTPError) { - logger := api.Logger.Named(userAuthLoggerName) - usingGroups := false - var groups []string - - // If the GroupField is the empty string, then groups from OIDC are not used. - // This is so we can support manual group assignment. - if api.OIDCConfig.GroupField != "" { - // If the allow list is empty, then the user is allowed to log in. - // Otherwise, they must belong to at least 1 group in the allow list. - inAllowList := len(api.OIDCConfig.GroupAllowList) == 0 - - usingGroups = true - groupsRaw, ok := mergedClaims[api.OIDCConfig.GroupField] - if ok { - parsedGroups, err := idpsync.ParseStringSliceClaim(groupsRaw) - if err != nil { - api.Logger.Debug(ctx, "groups field was an unknown type in oidc claims", - slog.F("type", fmt.Sprintf("%T", groupsRaw)), - slog.Error(err), - ) - return false, nil, &idpsync.HTTPError{ - Code: http.StatusBadRequest, - Msg: "Failed to sync groups from OIDC claims", - Detail: err.Error(), - RenderStaticPage: false, - } - } - - api.Logger.Debug(ctx, "groups returned in oidc claims", - slog.F("len", len(parsedGroups)), - slog.F("groups", parsedGroups), - ) - - for _, group := range parsedGroups { - if mappedGroup, ok := api.OIDCConfig.GroupMapping[group]; ok { - group = mappedGroup - } - if _, ok := api.OIDCConfig.GroupAllowList[group]; ok { - inAllowList = true - } - groups = append(groups, group) - } - } - - if !inAllowList { - logger.Debug(ctx, "oidc group claim not in allow list, rejecting login", - slog.F("allow_list_count", len(api.OIDCConfig.GroupAllowList)), - slog.F("user_group_count", len(groups)), - ) - detail := "Ask an administrator to add one of your groups to the whitelist" - if len(groups) == 0 { - detail = "You are currently not a member of any groups! Ask an administrator to add you to an authorized group to login." - } - return usingGroups, groups, &idpsync.HTTPError{ - Code: http.StatusForbidden, - Msg: "Not a member of an allowed group", - Detail: detail, - RenderStaticPage: true, - } - } - } - - // This conditional is purely to warn the user they might have misconfigured their OIDC - // configuration. - if _, groupClaimExists := mergedClaims["groups"]; !usingGroups && groupClaimExists { - logger.Debug(ctx, "claim 'groups' was returned, but 'oidc-group-field' is not set, check your coder oidc settings") - } - - return usingGroups, groups, nil -} - // oidcRoles returns the roles for the user from the OIDC claims. // If the function returns false, then the caller should return early. // All writes to the response writer are handled by this function. @@ -1278,14 +1183,7 @@ type oauthLoginParams struct { AvatarURL string // OrganizationSync has the organizations that the user will be assigned to. OrganizationSync idpsync.OrganizationParams - // Is UsingGroups is true, then the user will be assigned - // to the Groups provided. - UsingGroups bool - CreateMissingGroups bool - // These are the group names from the IDP. Internally, they will map to - // some organization groups. - Groups []string - GroupFilter *regexp.Regexp + GroupSync idpsync.GroupParams // Is UsingRoles is true, then the user will be assigned // the roles provided. UsingRoles bool @@ -1491,53 +1389,11 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C return xerrors.Errorf("sync organizations: %w", err) } - // Ensure groups are correct. - // This places all groups into the default organization. - // To go multi-org, we need to add a mapping feature here to know which - // groups go to which orgs. - if params.UsingGroups { - filtered := params.Groups - if params.GroupFilter != nil { - filtered = make([]string, 0, len(params.Groups)) - for _, group := range params.Groups { - if params.GroupFilter.MatchString(group) { - filtered = append(filtered, group) - } - } - } - - //nolint:gocritic // No user present in the context. - defaultOrganization, err := tx.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx)) - if err != nil { - // If there is no default org, then we can't assign groups. - // By default, we assume all groups belong to the default org. - return xerrors.Errorf("get default organization: %w", err) - } - - //nolint:gocritic // No user present in the context. - memberships, err := tx.OrganizationMembers(dbauthz.AsSystemRestricted(ctx), database.OrganizationMembersParams{ - UserID: user.ID, - OrganizationID: uuid.Nil, - }) - if err != nil { - return xerrors.Errorf("get organization memberships: %w", err) - } - - // If the user is not in the default organization, then we can't assign groups. - // A user cannot be in groups to an org they are not a member of. - if !slices.ContainsFunc(memberships, func(member database.OrganizationMembersRow) bool { - return member.OrganizationMember.OrganizationID == defaultOrganization.ID - }) { - return xerrors.Errorf("user %s is not a member of the default organization, cannot assign to groups in the org", user.ID) - } - - //nolint:gocritic - err = api.Options.SetUserGroups(dbauthz.AsSystemRestricted(ctx), logger, tx, user.ID, map[uuid.UUID][]string{ - defaultOrganization.ID: filtered, - }, params.CreateMissingGroups) - if err != nil { - return xerrors.Errorf("set user groups: %w", err) - } + // Group sync needs to occur after org sync, since a user can join an org, + // then have their groups sync to said org. + err = api.IDPSync.SyncGroups(ctx, tx, user, params.GroupSync) + if err != nil { + return xerrors.Errorf("sync groups: %w", err) } // Ensure roles are correct. diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 6cd3e796d1825..f9ab3e452ac04 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -80,13 +80,6 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { if options.Entitlements == nil { options.Entitlements = entitlements.New() } - if options.IDPSync == nil { - options.IDPSync = enidpsync.NewSync(options.Logger, options.Entitlements, idpsync.SyncSettings{ - OrganizationField: options.DeploymentValues.OIDC.OrganizationField.Value(), - OrganizationMapping: options.DeploymentValues.OIDC.OrganizationMapping.Value, - OrganizationAssignDefault: options.DeploymentValues.OIDC.OrganizationAssignDefault.Value(), - }) - } ctx, cancelFunc := context.WithCancel(ctx) @@ -118,6 +111,11 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { } options.Database = cryptDB + + if options.IDPSync == nil { + options.IDPSync = enidpsync.NewSync(options.Logger, options.RuntimeConfig, options.Entitlements, idpsync.FromDeploymentValues(options.DeploymentValues)) + } + api := &API{ ctx: ctx, cancel: cancelFunc, @@ -147,7 +145,6 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { } return c.Subject, c.Trial, nil } - api.AGPL.Options.SetUserGroups = api.setUserGroups api.AGPL.Options.SetUserSiteRoles = api.setUserSiteRoles api.AGPL.SiteHandler.RegionsFetcher = func(ctx context.Context) (any, error) { // If the user can read the workspace proxy resource, return that. diff --git a/enterprise/coderd/enidpsync/enidpsync.go b/enterprise/coderd/enidpsync/enidpsync.go index bb21c68501e1b..c7ba8dd3ecdc6 100644 --- a/enterprise/coderd/enidpsync/enidpsync.go +++ b/enterprise/coderd/enidpsync/enidpsync.go @@ -2,9 +2,9 @@ package enidpsync import ( "cdr.dev/slog" - "github.com/coder/coder/v2/coderd/entitlements" "github.com/coder/coder/v2/coderd/idpsync" + "github.com/coder/coder/v2/coderd/runtimeconfig" ) // EnterpriseIDPSync enabled syncing user information from an external IDP. @@ -17,9 +17,9 @@ type EnterpriseIDPSync struct { *idpsync.AGPLIDPSync } -func NewSync(logger slog.Logger, set *entitlements.Set, settings idpsync.SyncSettings) *EnterpriseIDPSync { +func NewSync(logger slog.Logger, manager *runtimeconfig.Manager, set *entitlements.Set, settings idpsync.DeploymentSyncSettings) *EnterpriseIDPSync { return &EnterpriseIDPSync{ entitlements: set, - AGPLIDPSync: idpsync.NewAGPLSync(logger.With(slog.F("enterprise_capable", "true")), settings), + AGPLIDPSync: idpsync.NewAGPLSync(logger.With(slog.F("enterprise_capable", "true")), manager, settings), } } diff --git a/enterprise/coderd/enidpsync/groups.go b/enterprise/coderd/enidpsync/groups.go new file mode 100644 index 0000000000000..dc8456fc6b1c9 --- /dev/null +++ b/enterprise/coderd/enidpsync/groups.go @@ -0,0 +1,70 @@ +package enidpsync + +import ( + "context" + "net/http" + + "github.com/golang-jwt/jwt/v4" + + "github.com/coder/coder/v2/coderd/idpsync" + "github.com/coder/coder/v2/codersdk" +) + +func (e EnterpriseIDPSync) GroupSyncEnabled() bool { + return e.entitlements.Enabled(codersdk.FeatureTemplateRBAC) +} + +// ParseGroupClaims parses the user claims and handles deployment wide group behavior. +// Almost all behavior is deferred since each organization configures it's own +// group sync settings. +// GroupAllowList is implemented here to prevent login by unauthorized users. +// TODO: GroupAllowList overlaps with the default organization group sync settings. +func (e EnterpriseIDPSync) ParseGroupClaims(ctx context.Context, mergedClaims jwt.MapClaims) (idpsync.GroupParams, *idpsync.HTTPError) { + if !e.GroupSyncEnabled() { + return e.AGPLIDPSync.ParseGroupClaims(ctx, mergedClaims) + } + + if e.GroupField != "" && len(e.GroupAllowList) > 0 { + groupsRaw, ok := mergedClaims[e.GroupField] + if !ok { + return idpsync.GroupParams{}, &idpsync.HTTPError{ + Code: http.StatusForbidden, + Msg: "Not a member of an allowed group", + Detail: "You have no groups in your claims!", + RenderStaticPage: true, + } + } + parsedGroups, err := idpsync.ParseStringSliceClaim(groupsRaw) + if err != nil { + return idpsync.GroupParams{}, &idpsync.HTTPError{ + Code: http.StatusBadRequest, + Msg: "Failed read groups from claims for allow list check. Ask an administrator for help.", + Detail: err.Error(), + RenderStaticPage: true, + } + } + + inAllowList := false + AllowListCheckLoop: + for _, group := range parsedGroups { + if _, ok := e.GroupAllowList[group]; ok { + inAllowList = true + break AllowListCheckLoop + } + } + + if !inAllowList { + return idpsync.GroupParams{}, &idpsync.HTTPError{ + Code: http.StatusForbidden, + Msg: "Not a member of an allowed group", + Detail: "Ask an administrator to add one of your groups to the allow list.", + RenderStaticPage: true, + } + } + } + + return idpsync.GroupParams{ + SyncEnabled: true, + MergedClaims: mergedClaims, + }, nil +} diff --git a/enterprise/coderd/enidpsync/groups_test.go b/enterprise/coderd/enidpsync/groups_test.go new file mode 100644 index 0000000000000..77b078cd9e3f0 --- /dev/null +++ b/enterprise/coderd/enidpsync/groups_test.go @@ -0,0 +1,96 @@ +package enidpsync_test + +import ( + "testing" + + "github.com/golang-jwt/jwt/v4" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/entitlements" + "github.com/coder/coder/v2/coderd/idpsync" + "github.com/coder/coder/v2/coderd/runtimeconfig" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/enterprise/coderd/enidpsync" + "github.com/coder/coder/v2/testutil" +) + +func TestEnterpriseParseGroupClaims(t *testing.T) { + t.Parallel() + + entitled := entitlements.New() + entitled.Update(func(entitlements *codersdk.Entitlements) { + entitlements.Features[codersdk.FeatureTemplateRBAC] = codersdk.Feature{ + Entitlement: codersdk.EntitlementEntitled, + Enabled: true, + } + }) + + t.Run("NoEntitlements", func(t *testing.T) { + t.Parallel() + + s := enidpsync.NewSync(slogtest.Make(t, &slogtest.Options{}), + runtimeconfig.NewManager(), + entitlements.New(), + idpsync.DeploymentSyncSettings{}) + + ctx := testutil.Context(t, testutil.WaitMedium) + + params, err := s.ParseGroupClaims(ctx, jwt.MapClaims{}) + require.Nil(t, err) + + require.False(t, params.SyncEnabled) + }) + + t.Run("NotInAllowList", func(t *testing.T) { + t.Parallel() + + s := enidpsync.NewSync(slogtest.Make(t, &slogtest.Options{}), + runtimeconfig.NewManager(), + entitled, + idpsync.DeploymentSyncSettings{ + GroupField: "groups", + GroupAllowList: map[string]struct{}{ + "foo": {}, + }, + }) + + ctx := testutil.Context(t, testutil.WaitMedium) + + // Try with incorrect group + _, err := s.ParseGroupClaims(ctx, jwt.MapClaims{ + "groups": []string{"bar"}, + }) + require.NotNil(t, err) + require.Equal(t, 403, err.Code) + + // Try with no groups + _, err = s.ParseGroupClaims(ctx, jwt.MapClaims{}) + require.NotNil(t, err) + require.Equal(t, 403, err.Code) + }) + + t.Run("InAllowList", func(t *testing.T) { + t.Parallel() + + s := enidpsync.NewSync(slogtest.Make(t, &slogtest.Options{}), + runtimeconfig.NewManager(), + entitled, + idpsync.DeploymentSyncSettings{ + GroupField: "groups", + GroupAllowList: map[string]struct{}{ + "foo": {}, + }, + }) + + ctx := testutil.Context(t, testutil.WaitMedium) + + claims := jwt.MapClaims{ + "groups": []string{"foo", "bar"}, + } + params, err := s.ParseGroupClaims(ctx, claims) + require.Nil(t, err) + require.True(t, params.SyncEnabled) + require.Equal(t, claims, params.MergedClaims) + }) +} diff --git a/enterprise/coderd/enidpsync/organizations_test.go b/enterprise/coderd/enidpsync/organizations_test.go index 0b2ed1ef6521f..cb6da2723b2f5 100644 --- a/enterprise/coderd/enidpsync/organizations_test.go +++ b/enterprise/coderd/enidpsync/organizations_test.go @@ -19,6 +19,7 @@ import ( "github.com/coder/coder/v2/coderd/entitlements" "github.com/coder/coder/v2/coderd/idpsync" "github.com/coder/coder/v2/coderd/rbac" + "github.com/coder/coder/v2/coderd/runtimeconfig" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/enterprise/coderd/enidpsync" "github.com/coder/coder/v2/testutil" @@ -41,7 +42,7 @@ type Expectations struct { } type OrganizationSyncTestCase struct { - Settings idpsync.SyncSettings + Settings idpsync.DeploymentSyncSettings Entitlements *entitlements.Set Exps []Expectations } @@ -89,7 +90,7 @@ func TestOrganizationSync(t *testing.T) { other := dbgen.Organization(t, db, database.Organization{}) return OrganizationSyncTestCase{ Entitlements: entitled, - Settings: idpsync.SyncSettings{ + Settings: idpsync.DeploymentSyncSettings{ OrganizationField: "", OrganizationMapping: nil, OrganizationAssignDefault: true, @@ -142,7 +143,7 @@ func TestOrganizationSync(t *testing.T) { three := dbgen.Organization(t, db, database.Organization{}) return OrganizationSyncTestCase{ Entitlements: entitled, - Settings: idpsync.SyncSettings{ + Settings: idpsync.DeploymentSyncSettings{ OrganizationField: "organizations", OrganizationMapping: map[string][]uuid.UUID{ "first": {one.ID}, @@ -236,7 +237,7 @@ func TestOrganizationSync(t *testing.T) { } // Create a new sync object - sync := enidpsync.NewSync(logger, caseData.Entitlements, caseData.Settings) + sync := enidpsync.NewSync(logger, runtimeconfig.NewManager(), caseData.Entitlements, caseData.Settings) for _, exp := range caseData.Exps { t.Run(exp.Name, func(t *testing.T) { params, httpErr := sync.ParseOrganizationClaims(ctx, exp.Claims) diff --git a/enterprise/coderd/userauth.go b/enterprise/coderd/userauth.go index 65c4a3473f3f7..60cba28cc37f3 100644 --- a/enterprise/coderd/userauth.go +++ b/enterprise/coderd/userauth.go @@ -8,75 +8,9 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/codersdk" ) -// nolint: revive -func (api *API) setUserGroups(ctx context.Context, logger slog.Logger, db database.Store, userID uuid.UUID, orgGroupNames map[uuid.UUID][]string, createMissingGroups bool) error { - if !api.Entitlements.Enabled(codersdk.FeatureTemplateRBAC) { - return nil - } - - return db.InTx(func(tx database.Store) error { - // When setting the user's groups, it's easier to just clear their groups and re-add them. - // This ensures that the user's groups are always in sync with the auth provider. - orgs, err := tx.GetOrganizationsByUserID(ctx, userID) - if err != nil { - return xerrors.Errorf("get user orgs: %w", err) - } - if len(orgs) != 1 { - return xerrors.Errorf("expected 1 org, got %d", len(orgs)) - } - - // Delete all groups the user belongs to. - // nolint:gocritic // Requires system context to remove user from all groups. - err = tx.RemoveUserFromAllGroups(dbauthz.AsSystemRestricted(ctx), userID) - if err != nil { - return xerrors.Errorf("delete user groups: %w", err) - } - - // TODO: This could likely be improved by making these single queries. - // Either by batching or some other means. This for loop could be really - // inefficient if there are a lot of organizations. There was deployments - // on v1 with >100 orgs. - for orgID, groupNames := range orgGroupNames { - // Create the missing groups for each organization. - if createMissingGroups { - // This is the system creating these additional groups, so we use the system restricted context. - // nolint:gocritic - created, err := tx.InsertMissingGroups(dbauthz.AsSystemRestricted(ctx), database.InsertMissingGroupsParams{ - OrganizationID: orgID, - GroupNames: groupNames, - Source: database.GroupSourceOidc, - }) - if err != nil { - return xerrors.Errorf("insert missing groups: %w", err) - } - if len(created) > 0 { - logger.Debug(ctx, "auto created missing groups", - slog.F("org_id", orgID.ID), - slog.F("created", created), - slog.F("num", len(created)), - ) - } - } - - // Re-add the user to all groups returned by the auth provider. - err = tx.InsertUserGroupsByName(ctx, database.InsertUserGroupsByNameParams{ - UserID: userID, - OrganizationID: orgID, - GroupNames: groupNames, - }) - if err != nil { - return xerrors.Errorf("insert user groups: %w", err) - } - } - - return nil - }, nil) -} - func (api *API) setUserSiteRoles(ctx context.Context, logger slog.Logger, db database.Store, userID uuid.UUID, roles []string) error { if !api.Entitlements.Enabled(codersdk.FeatureUserRoleManagement) { logger.Warn(ctx, "attempted to assign OIDC user roles without enterprise entitlement, roles left unchanged", diff --git a/enterprise/coderd/userauth_test.go b/enterprise/coderd/userauth_test.go index 3e94a25a1c013..3b42dc1aeec5f 100644 --- a/enterprise/coderd/userauth_test.go +++ b/enterprise/coderd/userauth_test.go @@ -402,7 +402,9 @@ func TestUserOIDC(t *testing.T) { runner := setupOIDCTest(t, oidcTestConfig{ Config: func(cfg *coderd.OIDCConfig) { cfg.AllowSignups = true - cfg.GroupField = groupClaim + }, + DeploymentValues: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupField = groupClaim }, }) @@ -433,8 +435,10 @@ func TestUserOIDC(t *testing.T) { runner := setupOIDCTest(t, oidcTestConfig{ Config: func(cfg *coderd.OIDCConfig) { cfg.AllowSignups = true - cfg.GroupField = groupClaim - cfg.GroupMapping = map[string]string{oidcGroupName: coderGroupName} + }, + DeploymentValues: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupField = groupClaim + dv.OIDC.GroupMapping = serpent.Struct[map[string]string]{Value: map[string]string{oidcGroupName: coderGroupName}} }, }) @@ -468,7 +472,9 @@ func TestUserOIDC(t *testing.T) { runner := setupOIDCTest(t, oidcTestConfig{ Config: func(cfg *coderd.OIDCConfig) { cfg.AllowSignups = true - cfg.GroupField = groupClaim + }, + DeploymentValues: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupField = groupClaim }, }) @@ -502,7 +508,9 @@ func TestUserOIDC(t *testing.T) { runner := setupOIDCTest(t, oidcTestConfig{ Config: func(cfg *coderd.OIDCConfig) { cfg.AllowSignups = true - cfg.GroupField = groupClaim + }, + DeploymentValues: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupField = groupClaim }, }) @@ -537,7 +545,9 @@ func TestUserOIDC(t *testing.T) { runner := setupOIDCTest(t, oidcTestConfig{ Config: func(cfg *coderd.OIDCConfig) { cfg.AllowSignups = true - cfg.GroupField = groupClaim + }, + DeploymentValues: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupField = groupClaim }, }) @@ -559,8 +569,10 @@ func TestUserOIDC(t *testing.T) { runner := setupOIDCTest(t, oidcTestConfig{ Config: func(cfg *coderd.OIDCConfig) { cfg.AllowSignups = true - cfg.GroupField = groupClaim - cfg.CreateMissingGroups = true + }, + DeploymentValues: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupField = groupClaim + dv.OIDC.GroupAutoCreate = true }, }) @@ -582,8 +594,10 @@ func TestUserOIDC(t *testing.T) { runner := setupOIDCTest(t, oidcTestConfig{ Config: func(cfg *coderd.OIDCConfig) { cfg.AllowSignups = true - cfg.GroupField = groupClaim - cfg.CreateMissingGroups = true + }, + DeploymentValues: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupField = groupClaim + dv.OIDC.GroupAutoCreate = true }, }) @@ -606,8 +620,10 @@ func TestUserOIDC(t *testing.T) { runner := setupOIDCTest(t, oidcTestConfig{ Config: func(cfg *coderd.OIDCConfig) { cfg.AllowSignups = true - cfg.GroupField = groupClaim - cfg.GroupAllowList = map[string]bool{allowedGroup: true} + }, + DeploymentValues: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupField = groupClaim + dv.OIDC.GroupAllowList = []string{allowedGroup} }, }) @@ -697,6 +713,7 @@ func TestGroupSync(t *testing.T) { testCases := []struct { name string modCfg func(cfg *coderd.OIDCConfig) + modDV func(dv *codersdk.DeploymentValues) // initialOrgGroups is initial groups in the org initialOrgGroups []string // initialUserGroups is initial groups for the user @@ -718,10 +735,10 @@ func TestGroupSync(t *testing.T) { }, { name: "GroupSyncDisabled", - modCfg: func(cfg *coderd.OIDCConfig) { + modDV: func(dv *codersdk.DeploymentValues) { // Disable group sync - cfg.GroupField = "" - cfg.GroupFilter = regexp.MustCompile(".*") + dv.OIDC.GroupField = "" + dv.OIDC.GroupRegexFilter = serpent.Regexp(*regexp.MustCompile(".*")) }, initialOrgGroups: []string{"a", "b", "c", "d"}, initialUserGroups: []string{"b", "c", "d"}, @@ -732,10 +749,8 @@ func TestGroupSync(t *testing.T) { { // From a,c,b -> b,c,d name: "ChangeUserGroups", - modCfg: func(cfg *coderd.OIDCConfig) { - cfg.GroupMapping = map[string]string{ - "D": "d", - } + modDV: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupMapping = serpent.Struct[map[string]string]{Value: map[string]string{"D": "d"}} }, initialOrgGroups: []string{"a", "b", "c", "d"}, initialUserGroups: []string{"a", "b", "c"}, @@ -749,8 +764,8 @@ func TestGroupSync(t *testing.T) { { // From a,c,b -> [] name: "RemoveAllGroups", - modCfg: func(cfg *coderd.OIDCConfig) { - cfg.GroupFilter = regexp.MustCompile(".*") + modDV: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupRegexFilter = serpent.Regexp(*regexp.MustCompile(".*")) }, initialOrgGroups: []string{"a", "b", "c", "d"}, initialUserGroups: []string{"a", "b", "c"}, @@ -763,8 +778,8 @@ func TestGroupSync(t *testing.T) { { // From a,c,b -> b,c,d,e,f name: "CreateMissingGroups", - modCfg: func(cfg *coderd.OIDCConfig) { - cfg.CreateMissingGroups = true + modDV: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupAutoCreate = true }, initialOrgGroups: []string{"a", "b", "c", "d"}, initialUserGroups: []string{"a", "b", "c"}, @@ -777,14 +792,11 @@ func TestGroupSync(t *testing.T) { { // From a,c,b -> b,c,d,e,f name: "CreateMissingGroupsFilter", - modCfg: func(cfg *coderd.OIDCConfig) { - cfg.CreateMissingGroups = true + modDV: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupAutoCreate = true // Only single letter groups - cfg.GroupFilter = regexp.MustCompile("^[a-z]$") - cfg.GroupMapping = map[string]string{ - // Does not match the filter, but does after being mapped! - "zebra": "z", - } + dv.OIDC.GroupRegexFilter = serpent.Regexp(*regexp.MustCompile("^[a-z]$")) + dv.OIDC.GroupMapping = serpent.Struct[map[string]string]{Value: map[string]string{"zebra": "z"}} }, initialOrgGroups: []string{"a", "b", "c", "d"}, initialUserGroups: []string{"a", "b", "c"}, @@ -806,8 +818,15 @@ func TestGroupSync(t *testing.T) { t.Parallel() runner := setupOIDCTest(t, oidcTestConfig{ Config: func(cfg *coderd.OIDCConfig) { - cfg.GroupField = "groups" - tc.modCfg(cfg) + if tc.modCfg != nil { + tc.modCfg(cfg) + } + }, + DeploymentValues: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupField = "groups" + if tc.modDV != nil { + tc.modDV(dv) + } }, })