From 99c97c215e634a6ef118f5366dd9499ec68f8d64 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 3 Sep 2024 10:15:52 -0500 Subject: [PATCH 01/38] wip --- coderd/coderd.go | 15 ++--- coderd/idpsync/group.go | 79 ++++++++++++++++++++++++ coderd/idpsync/idpsync.go | 22 ++++--- coderd/idpsync/organization.go | 11 ++++ enterprise/coderd/coderd.go | 16 ++--- enterprise/coderd/enidpsync/enidpsync.go | 1 - enterprise/coderd/enidpsync/groups.go | 28 +++++++++ 7 files changed, 147 insertions(+), 25 deletions(-) create mode 100644 coderd/idpsync/group.go create mode 100644 enterprise/coderd/enidpsync/groups.go diff --git a/coderd/coderd.go b/coderd/coderd.go index 51b6780e4dc47..895aa3e501c27 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -276,13 +276,6 @@ func New(options *Options) *API { if options.Entitlements == nil { options.Entitlements = entitlements.New() } - if options.IDPSync == nil { - options.IDPSync = idpsync.NewAGPLSync(options.Logger, idpsync.SyncSettings{ - OrganizationField: options.DeploymentValues.OIDC.OrganizationField.Value(), - OrganizationMapping: options.DeploymentValues.OIDC.OrganizationMapping.Value, - OrganizationAssignDefault: options.DeploymentValues.OIDC.OrganizationAssignDefault.Value(), - }) - } if options.NewTicker == nil { options.NewTicker = func(duration time.Duration) (tick <-chan time.Time, done func()) { ticker := time.NewTicker(duration) @@ -318,6 +311,14 @@ func New(options *Options) *API { options.AccessControlStore, ) + if options.IDPSync == nil { + options.IDPSync = idpsync.NewAGPLSync(options.Logger, idpsync.SyncSettings{ + OrganizationField: options.DeploymentValues.OIDC.OrganizationField.Value(), + OrganizationMapping: options.DeploymentValues.OIDC.OrganizationMapping.Value, + OrganizationAssignDefault: options.DeploymentValues.OIDC.OrganizationAssignDefault.Value(), + }) + } + experiments := ReadExperiments( options.Logger, options.DeploymentValues.Experiments.Value(), ) diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go new file mode 100644 index 0000000000000..1bbc6a09a34d5 --- /dev/null +++ b/coderd/idpsync/group.go @@ -0,0 +1,79 @@ +package idpsync + +import ( + "context" + + "github.com/golang-jwt/jwt/v4" + "github.com/google/uuid" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" +) + +type GroupParams struct { + // SyncEnabled if false will skip syncing the user's groups + SyncEnabled bool + MergedClaims jwt.MapClaims +} + +func (AGPLIDPSync) GroupSyncEnabled() bool { + // AGPL does not support syncing groups. + return false +} + +func (s AGPLIDPSync) ParseGroupClaims(_ context.Context, _ jwt.MapClaims) (GroupParams, *HTTPError) { + return GroupParams{ + SyncEnabled: s.GroupSyncEnabled(), + }, nil +} + +// TODO: Group allowlist behavior should probably happen at this step. +func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user database.User, params GroupParams) error { + // Nothing happens if sync is not enabled + if !params.SyncEnabled { + return nil + } + + // nolint:gocritic // all syncing is done as a system user + ctx = dbauthz.AsSystemRestricted(ctx) + + db.InTx(func(tx database.Store) error { + userGroups, err := db.GetGroups(ctx, database.GetGroupsParams{ + HasMemberID: user.ID, + }) + if err != nil { + return xerrors.Errorf("get user groups: %w", err) + } + + // Figure out which organizations the user is a member of. + userOrgs := make(map[uuid.UUID][]database.GetGroupsRow) + for _, g := range userGroups { + g := g + userOrgs[g.Group.OrganizationID] = append(userOrgs[g.Group.OrganizationID], g) + } + + // Force each organization, we sync the groups. + db.RemoveUserFromAllGroups(ctx, user.ID) + + return nil + }, nil) + + // + //tx.InTx(func(tx database.Store) error { + // // When setting the user's groups, it's easier to just clear their groups and re-add them. + // // This ensures that the user's groups are always in sync with the auth provider. + // err := tx.RemoveUserFromAllGroups(ctx, user.ID) + // if err != nil { + // return err + // } + // + // for _, org := range userOrgs { + // + // } + // + // return nil + //}, nil) + + return nil +} diff --git a/coderd/idpsync/idpsync.go b/coderd/idpsync/idpsync.go index 73a7b9b6f530d..227436cfab998 100644 --- a/coderd/idpsync/idpsync.go +++ b/coderd/idpsync/idpsync.go @@ -3,6 +3,7 @@ package idpsync import ( "context" "net/http" + "regexp" "strings" "github.com/golang-jwt/jwt/v4" @@ -29,6 +30,11 @@ type IDPSync interface { // SyncOrganizations assigns and removed users from organizations based on the // provided params. SyncOrganizations(ctx context.Context, tx database.Store, user database.User, params OrganizationParams) error + + GroupSyncEnabled() bool + // ParseGroupClaims takes claims from an OIDC provider, and returns the + // group sync params for assigning users into groups. + ParseGroupClaims(ctx context.Context, _ jwt.MapClaims) (GroupParams, *HTTPError) } // AGPLIDPSync is the configuration for syncing user information from an external @@ -50,17 +56,13 @@ type SyncSettings struct { // placed into the default organization. This is mostly a hack to support // legacy deployments. OrganizationAssignDefault bool -} -type OrganizationParams struct { - // SyncEnabled if false will skip syncing the user's organizations. - SyncEnabled bool - // IncludeDefault is primarily for single org deployments. It will ensure - // a user is always inserted into the default org. - IncludeDefault bool - // Organizations is the list of organizations the user should be a member of - // assuming syncing is turned on. - Organizations []uuid.UUID + // Group options here are set by the deployment config and only apply to + // the default organization. + GroupField string + CreateMissingGroups bool + GroupMapping map[string]string + GroupFilter *regexp.Regexp } func NewAGPLSync(logger slog.Logger, settings SyncSettings) *AGPLIDPSync { diff --git a/coderd/idpsync/organization.go b/coderd/idpsync/organization.go index 6d475f28ea0ef..fa091eba441ad 100644 --- a/coderd/idpsync/organization.go +++ b/coderd/idpsync/organization.go @@ -16,6 +16,17 @@ import ( "github.com/coder/coder/v2/coderd/util/slice" ) +type OrganizationParams struct { + // SyncEnabled if false will skip syncing the user's organizations. + SyncEnabled bool + // IncludeDefault is primarily for single org deployments. It will ensure + // a user is always inserted into the default org. + IncludeDefault bool + // Organizations is the list of organizations the user should be a member of + // assuming syncing is turned on. + Organizations []uuid.UUID +} + func (AGPLIDPSync) OrganizationSyncEnabled() bool { // AGPL does not support syncing organizations. return false diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 6cd3e796d1825..bc6491a41198f 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -80,13 +80,6 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { if options.Entitlements == nil { options.Entitlements = entitlements.New() } - if options.IDPSync == nil { - options.IDPSync = enidpsync.NewSync(options.Logger, options.Entitlements, idpsync.SyncSettings{ - OrganizationField: options.DeploymentValues.OIDC.OrganizationField.Value(), - OrganizationMapping: options.DeploymentValues.OIDC.OrganizationMapping.Value, - OrganizationAssignDefault: options.DeploymentValues.OIDC.OrganizationAssignDefault.Value(), - }) - } ctx, cancelFunc := context.WithCancel(ctx) @@ -118,6 +111,15 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { } options.Database = cryptDB + + if options.IDPSync == nil { + options.IDPSync = enidpsync.NewSync(options.Logger, options.Entitlements, idpsync.SyncSettings{ + OrganizationField: options.DeploymentValues.OIDC.OrganizationField.Value(), + OrganizationMapping: options.DeploymentValues.OIDC.OrganizationMapping.Value, + OrganizationAssignDefault: options.DeploymentValues.OIDC.OrganizationAssignDefault.Value(), + }) + } + api := &API{ ctx: ctx, cancel: cancelFunc, diff --git a/enterprise/coderd/enidpsync/enidpsync.go b/enterprise/coderd/enidpsync/enidpsync.go index bb21c68501e1b..918b9f8edb118 100644 --- a/enterprise/coderd/enidpsync/enidpsync.go +++ b/enterprise/coderd/enidpsync/enidpsync.go @@ -2,7 +2,6 @@ package enidpsync import ( "cdr.dev/slog" - "github.com/coder/coder/v2/coderd/entitlements" "github.com/coder/coder/v2/coderd/idpsync" ) diff --git a/enterprise/coderd/enidpsync/groups.go b/enterprise/coderd/enidpsync/groups.go new file mode 100644 index 0000000000000..5c8328f039068 --- /dev/null +++ b/enterprise/coderd/enidpsync/groups.go @@ -0,0 +1,28 @@ +package enidpsync + +import ( + "context" + + "github.com/golang-jwt/jwt/v4" + + "github.com/coder/coder/v2/coderd/idpsync" + "github.com/coder/coder/v2/codersdk" +) + +func (e EnterpriseIDPSync) GroupSyncEnabled() bool { + return e.entitlements.Enabled(codersdk.FeatureTemplateRBAC) + +} + +// ParseGroupClaims returns the groups from the external IDP. +// TODO: Implement group allow_list behavior here since that is deployment wide. +func (e EnterpriseIDPSync) ParseGroupClaims(ctx context.Context, mergedClaims jwt.MapClaims) (idpsync.GroupParams, *idpsync.HTTPError) { + if !e.GroupSyncEnabled() { + return e.AGPLIDPSync.ParseGroupClaims(ctx, mergedClaims) + } + + return idpsync.GroupParams{ + SyncEnabled: e.OrganizationSyncEnabled(), + MergedClaims: mergedClaims, + }, nil +} From bfddeb644f7c10d27ae1abd63ff6585fb61094af Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 3 Sep 2024 16:46:19 -0500 Subject: [PATCH 02/38] begin group sync main work --- coderd/coderd.go | 2 +- coderd/database/dbauthz/dbauthz.go | 8 + coderd/database/dbauthz/dbauthz_test.go | 11 ++ coderd/database/dbmem/dbmem.go | 30 ++++ coderd/database/dbmetrics/dbmetrics.go | 7 + coderd/database/models.go | 2 +- coderd/database/querier.go | 4 +- coderd/database/queries.sql.go | 50 +++++- coderd/database/queries/groupmembers.sql | 19 +++ coderd/idpsync/group.go | 187 ++++++++++++++++++++++- coderd/idpsync/idpsync.go | 19 ++- enterprise/coderd/enidpsync/groups.go | 4 +- 12 files changed, 331 insertions(+), 12 deletions(-) diff --git a/coderd/coderd.go b/coderd/coderd.go index 895aa3e501c27..97c2d9f883713 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -312,7 +312,7 @@ func New(options *Options) *API { ) if options.IDPSync == nil { - options.IDPSync = idpsync.NewAGPLSync(options.Logger, idpsync.SyncSettings{ + options.IDPSync = idpsync.NewAGPLSync(options.Logger, idpsync.DeploymentSyncSettings{ OrganizationField: options.DeploymentValues.OIDC.OrganizationField.Value(), OrganizationMapping: options.DeploymentValues.OIDC.OrganizationMapping.Value, OrganizationAssignDefault: options.DeploymentValues.OIDC.OrganizationAssignDefault.Value(), diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 5782bdc8e7155..3e5e3e39164b6 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2892,6 +2892,14 @@ func (q *querier) InsertUser(ctx context.Context, arg database.InsertUserParams) return insert(q.log, q.auth, obj, q.db.InsertUser)(ctx, arg) } +func (q *querier) InsertUserGroupsByID(ctx context.Context, arg database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) { + // This is used by OIDC sync. So only used by a system user. + if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.InsertUserGroupsByID(ctx, arg) +} + func (q *querier) InsertUserGroupsByName(ctx context.Context, arg database.InsertUserGroupsByNameParams) error { // This will add the user to all named groups. This counts as updating a group. // NOTE: instead of checking if the user has permission to update each group, we instead diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index d23bb48184b61..2bd55c4bec499 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -388,6 +388,17 @@ func (s *MethodTestSuite) TestGroup() { GroupNames: slice.New(g1.Name, g2.Name), }).Asserts(rbac.ResourceGroup.InOrg(o.ID), policy.ActionUpdate).Returns() })) + s.Run("InsertUserGroupsByID", s.Subtest(func(db database.Store, check *expects) { + o := dbgen.Organization(s.T(), db, database.Organization{}) + u1 := dbgen.User(s.T(), db, database.User{}) + g1 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) + g2 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) + _ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g1.ID, UserID: u1.ID}) + check.Args(database.InsertUserGroupsByIDParams{ + UserID: u1.ID, + GroupIds: slice.New(g1.ID, g2.ID), + }).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns(slice.New(g1, g2)) + })) s.Run("RemoveUserFromAllGroups", s.Subtest(func(db database.Store, check *expects) { o := dbgen.Organization(s.T(), db, database.Organization{}) u1 := dbgen.User(s.T(), db, database.User{}) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 04f0d32537f90..c3d04e8f9f201 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -7015,7 +7015,37 @@ func (q *FakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParam return user, nil } +func (q *FakeQuerier) InsertUserGroupsByID(ctx context.Context, arg database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) { + err := validateDatabaseType(arg) + if err != nil { + return nil, err + } + + q.mutex.Lock() + defer q.mutex.Unlock() + + var groupIDs []uuid.UUID + for _, group := range q.groups { + for _, groupID := range arg.GroupIds { + if group.ID == groupID { + q.groupMembers = append(q.groupMembers, database.GroupMemberTable{ + UserID: arg.UserID, + GroupID: groupID, + }) + groupIDs = append(groupIDs, group.ID) + } + } + } + + return groupIDs, nil +} + func (q *FakeQuerier) InsertUserGroupsByName(_ context.Context, arg database.InsertUserGroupsByNameParams) error { + err := validateDatabaseType(arg) + if err != nil { + return err + } + q.mutex.Lock() defer q.mutex.Unlock() diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index 5aa3a0c8d8cfb..510af865fc1c4 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -1789,6 +1789,13 @@ func (m metricsStore) InsertUser(ctx context.Context, arg database.InsertUserPar return user, err } +func (m metricsStore) InsertUserGroupsByID(ctx context.Context, arg database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) { + start := time.Now() + r0 := m.s.InsertUserGroupsByID(ctx, arg) + m.queryLatencies.WithLabelValues("InsertUserGroupsByID").Observe(time.Since(start).Seconds()) + return r0 +} + func (m metricsStore) InsertUserGroupsByName(ctx context.Context, arg database.InsertUserGroupsByNameParams) error { start := time.Now() err := m.s.InsertUserGroupsByName(ctx, arg) diff --git a/coderd/database/models.go b/coderd/database/models.go index 9e0283ba859c1..950c2674ab310 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.25.0 +// sqlc v1.26.0 package database diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 3432bac7dada1..3499f9cf702b3 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.25.0 +// sqlc v1.26.0 package database @@ -369,6 +369,8 @@ type sqlcQuerier interface { InsertTemplateVersionVariable(ctx context.Context, arg InsertTemplateVersionVariableParams) (TemplateVersionVariable, error) InsertTemplateVersionWorkspaceTag(ctx context.Context, arg InsertTemplateVersionWorkspaceTagParams) (TemplateVersionWorkspaceTag, error) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) + // InsertUserGroupsByID adds a user to all provided groups, if they exist. + InsertUserGroupsByID(ctx context.Context, arg InsertUserGroupsByIDParams) ([]uuid.UUID, error) // InsertUserGroupsByName adds a user to all provided groups, if they exist. InsertUserGroupsByName(ctx context.Context, arg InsertUserGroupsByNameParams) error InsertUserLink(ctx context.Context, arg InsertUserLinkParams) (UserLink, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 89822a72a7855..2816dad13e6ba 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.25.0 +// sqlc v1.26.0 package database @@ -1446,6 +1446,54 @@ func (q *sqlQuerier) InsertGroupMember(ctx context.Context, arg InsertGroupMembe return err } +const insertUserGroupsByID = `-- name: InsertUserGroupsByID :many +WITH groups AS ( + SELECT + id + FROM + groups + WHERE + groups.id = ANY($2 :: uuid []) +) +INSERT INTO + group_members (user_id, group_id) +SELECT + $1, + groups.id +FROM + groups +RETURNING group_id +` + +type InsertUserGroupsByIDParams struct { + UserID uuid.UUID `db:"user_id" json:"user_id"` + GroupIds []uuid.UUID `db:"group_ids" json:"group_ids"` +} + +// InsertUserGroupsByID adds a user to all provided groups, if they exist. +func (q *sqlQuerier) InsertUserGroupsByID(ctx context.Context, arg InsertUserGroupsByIDParams) ([]uuid.UUID, error) { + rows, err := q.db.QueryContext(ctx, insertUserGroupsByID, arg.UserID, pq.Array(arg.GroupIds)) + if err != nil { + return nil, err + } + defer rows.Close() + var items []uuid.UUID + for rows.Next() { + var group_id uuid.UUID + if err := rows.Scan(&group_id); err != nil { + return nil, err + } + items = append(items, group_id) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const insertUserGroupsByName = `-- name: InsertUserGroupsByName :exec WITH groups AS ( SELECT diff --git a/coderd/database/queries/groupmembers.sql b/coderd/database/queries/groupmembers.sql index 0ef2c72323cc9..867b1ba75d0e7 100644 --- a/coderd/database/queries/groupmembers.sql +++ b/coderd/database/queries/groupmembers.sql @@ -29,6 +29,25 @@ SELECT FROM groups; +-- InsertUserGroupsByID adds a user to all provided groups, if they exist. +-- name: InsertUserGroupsByID :many +WITH groups AS ( + SELECT + id + FROM + groups + WHERE + groups.id = ANY(@group_ids :: uuid []) +) +INSERT INTO + group_members (user_id, group_id) +SELECT + @user_id, + groups.id +FROM + groups +RETURNING group_id; + -- name: RemoveUserFromAllGroups :exec DELETE FROM group_members diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 1bbc6a09a34d5..d47a7f69045d5 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -2,13 +2,18 @@ package idpsync import ( "context" + "regexp" "github.com/golang-jwt/jwt/v4" "github.com/google/uuid" "golang.org/x/xerrors" + "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/runtimeconfig" + "github.com/coder/coder/v2/coderd/util/slice" ) type GroupParams struct { @@ -39,7 +44,8 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat ctx = dbauthz.AsSystemRestricted(ctx) db.InTx(func(tx database.Store) error { - userGroups, err := db.GetGroups(ctx, database.GetGroupsParams{ + resolver := runtimeconfig.NewStoreResolver(tx) + userGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{ HasMemberID: user.ID, }) if err != nil { @@ -53,9 +59,86 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat userOrgs[g.Group.OrganizationID] = append(userOrgs[g.Group.OrganizationID], g) } - // Force each organization, we sync the groups. - db.RemoveUserFromAllGroups(ctx, user.ID) + // For each org, we need to fetch the sync settings + orgSettings := make(map[uuid.UUID]GroupSyncSettings) + for orgID := range userOrgs { + orgResolver := runtimeconfig.NewOrgResolver(orgID, resolver) + settings, err := s.SyncSettings.Group.Resolve(ctx, orgResolver) + if err != nil { + return xerrors.Errorf("resolve group sync settings: %w", err) + } + orgSettings[orgID] = settings.Value + } + + // collect all diffs to do 1 sql update for all orgs + groupsToAdd := make([]uuid.UUID, 0) + groupsToRemove := make([]uuid.UUID, 0) + // For each org, determine which groups the user should land in + for orgID, settings := range orgSettings { + if settings.GroupField == "" { + // No group sync enabled for this org, so do nothing. + continue + } + + expectedGroups, err := settings.ParseClaims(params.MergedClaims) + if err != nil { + s.Logger.Debug(ctx, "failed to parse claims for groups", + slog.F("organization_field", s.GroupField), + slog.F("organization_id", orgID), + slog.Error(err), + ) + // Unsure where to raise this error on the UI or database. + continue + } + // Everyone group is always implied. + expectedGroups = append(expectedGroups, ExpectedGroup{ + GroupID: &orgID, + }) + + // Now we know what groups the user should be in for a given org, + // determine if we have to do any group updates to sync the user's + // state. + existingGroups := userOrgs[orgID] + existingGroupsTyped := db2sdk.List(existingGroups, func(f database.GetGroupsRow) ExpectedGroup { + return ExpectedGroup{ + GroupID: &f.Group.ID, + GroupName: &f.Group.Name, + } + }) + add, remove := slice.SymmetricDifferenceFunc(existingGroupsTyped, expectedGroups, func(a, b ExpectedGroup) bool { + // Only the name or the name needs to be checked, priority is given to the ID. + if a.GroupID != nil && b.GroupID != nil { + return *a.GroupID == *b.GroupID + } + if a.GroupName != nil && b.GroupName != nil { + return *a.GroupName == *b.GroupName + } + return false + }) + + // HandleMissingGroups will add the new groups to the org if + // the settings specify. It will convert all group names into uuids + // for easier assignment. + assignGroups, err := settings.HandleMissingGroups(ctx, tx, orgID, add) + if err != nil { + return xerrors.Errorf("handle missing groups: %w", err) + } + for _, removeGroup := range remove { + // This should always be the case. + // TODO: make sure this is always the case + if removeGroup.GroupID != nil { + groupsToRemove = append(groupsToRemove, *removeGroup.GroupID) + } + } + + groupsToAdd = append(groupsToAdd, assignGroups...) + } + + tx.InsertUserGroupsByID(ctx, database.InsertUserGroupsByIDParams{ + UserID: user.ID, + GroupIds: groupsToAdd, + }) return nil }, nil) @@ -77,3 +160,101 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat return nil } + +type GroupSyncSettings struct { + GroupField string `json:"field"` + // GroupMapping maps from an OIDC group --> Coder group ID + GroupMapping map[string][]uuid.UUID `json:"mapping"` + RegexFilter *regexp.Regexp `json:"regex_filter"` + AutoCreateMissingGroups bool `json:"auto_create_missing_groups"` +} + +type ExpectedGroup struct { + GroupID *uuid.UUID + GroupName *string +} + +// ParseClaims will take the merged claims from the IDP and return the groups +// the user is expected to be a member of. The expected group can either be a +// name or an ID. +// It is unfortunate we cannot use exclusively names or exclusively IDs. +// When configuring though, if a group is mapped from "A" -> "UUID 1234", and +// the group "UUID 1234" is renamed, we want to maintain the mapping. +// We have to keep names because group sync supports syncing groups by name if +// the external IDP group name matches the Coder one. +func (s GroupSyncSettings) ParseClaims(mergedClaims jwt.MapClaims) ([]ExpectedGroup, error) { + groupsRaw, ok := mergedClaims[s.GroupField] + if !ok { + return []ExpectedGroup{}, nil + } + + parsedGroups, err := ParseStringSliceClaim(groupsRaw) + if err != nil { + return nil, xerrors.Errorf("parse groups field, unexpected type %T: %w", groupsRaw, err) + } + + groups := make([]ExpectedGroup, 0) + for _, group := range parsedGroups { + // Only allow through groups that pass the regex + if s.RegexFilter != nil { + if !s.RegexFilter.MatchString(group) { + continue + } + } + + mappedGroupIDs, ok := s.GroupMapping[group] + if ok { + for _, gid := range mappedGroupIDs { + gid := gid + groups = append(groups, ExpectedGroup{GroupID: &gid}) + } + continue + } + group := group + groups = append(groups, ExpectedGroup{GroupName: &group}) + } + + return groups, nil +} + +func (s GroupSyncSettings) HandleMissingGroups(ctx context.Context, tx database.Store, orgID uuid.UUID, add []ExpectedGroup) ([]uuid.UUID, error) { + if !s.AutoCreateMissingGroups { + // Remove all groups that are missing, they will not be created + filter := make([]uuid.UUID, 0) + for _, expected := range add { + if expected.GroupID != nil { + filter = append(filter, *expected.GroupID) + } + } + + return filter, nil + } + // All expected that are missing IDs means the group does not exist + // in the database. Either remove them, or create them if auto create is + // turned on. + var missingGroups []string + addIDs := make([]uuid.UUID, 0) + + for _, expected := range add { + if expected.GroupID == nil && expected.GroupName != nil { + missingGroups = append(missingGroups, *expected.GroupName) + } else if expected.GroupID != nil { + // Keep the IDs to sync the groups. + addIDs = append(addIDs, *expected.GroupID) + } + } + + createdMissingGroups, err := tx.InsertMissingGroups(ctx, database.InsertMissingGroupsParams{ + OrganizationID: orgID, + Source: database.GroupSourceOidc, + GroupNames: missingGroups, + }) + if err != nil { + return nil, xerrors.Errorf("insert missing groups: %w", err) + } + for _, created := range createdMissingGroups { + addIDs = append(addIDs, created.ID) + } + + return addIDs, nil +} diff --git a/coderd/idpsync/idpsync.go b/coderd/idpsync/idpsync.go index 227436cfab998..2d02b941bcc80 100644 --- a/coderd/idpsync/idpsync.go +++ b/coderd/idpsync/idpsync.go @@ -13,8 +13,10 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/coderd/runtimeconfig" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/site" + "github.com/coder/serpent" ) // IDPSync is an interface, so we can implement this as AGPL and as enterprise, @@ -45,7 +47,8 @@ type AGPLIDPSync struct { SyncSettings } -type SyncSettings struct { +// DeploymentSyncSettings are static and are sourced from the deployment config. +type DeploymentSyncSettings struct { // OrganizationField selects the claim field to be used as the created user's // organizations. If the field is the empty string, then no organization updates // will ever come from the OIDC provider. @@ -56,6 +59,12 @@ type SyncSettings struct { // placed into the default organization. This is mostly a hack to support // legacy deployments. OrganizationAssignDefault bool +} + +type SyncSettings struct { + DeploymentSyncSettings + + Group runtimeconfig.Entry[*serpent.Struct[GroupSyncSettings]] // Group options here are set by the deployment config and only apply to // the default organization. @@ -65,10 +74,12 @@ type SyncSettings struct { GroupFilter *regexp.Regexp } -func NewAGPLSync(logger slog.Logger, settings SyncSettings) *AGPLIDPSync { +func NewAGPLSync(logger slog.Logger, settings DeploymentSyncSettings) *AGPLIDPSync { return &AGPLIDPSync{ - Logger: logger.Named("idp-sync"), - SyncSettings: settings, + Logger: logger.Named("idp-sync"), + SyncSettings: SyncSettings{ + DeploymentSyncSettings: settings, + }, } } diff --git a/enterprise/coderd/enidpsync/groups.go b/enterprise/coderd/enidpsync/groups.go index 5c8328f039068..02f012b8e14c3 100644 --- a/enterprise/coderd/enidpsync/groups.go +++ b/enterprise/coderd/enidpsync/groups.go @@ -14,7 +14,9 @@ func (e EnterpriseIDPSync) GroupSyncEnabled() bool { } -// ParseGroupClaims returns the groups from the external IDP. +// ParseGroupClaims parses the user claims and handles deployment wide group behavior. +// Almost all behavior is deferred since each organization configures it's own +// group sync settings. // TODO: Implement group allow_list behavior here since that is deployment wide. func (e EnterpriseIDPSync) ParseGroupClaims(ctx context.Context, mergedClaims jwt.MapClaims) (idpsync.GroupParams, *idpsync.HTTPError) { if !e.GroupSyncEnabled() { From f2857c69a3e7fad35a17a23d72667e376ac78966 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 3 Sep 2024 17:36:24 -0500 Subject: [PATCH 03/38] initial implementation of group sync --- coderd/database/dbauthz/dbauthz.go | 4 ++ coderd/database/dbauthz/dbauthz_test.go | 2 +- coderd/database/dbmem/dbmem.go | 9 +++++ coderd/database/dbmetrics/dbmetrics.go | 11 +++++- coderd/database/dbmock/dbmock.go | 15 ++++++++ coderd/database/querier.go | 1 + coderd/database/queries.sql.go | 37 ++++++++++++++++++ coderd/database/queries/groupmembers.sql | 8 ++++ coderd/idpsync/group.go | 48 ++++++++++++++---------- enterprise/coderd/enidpsync/enidpsync.go | 2 +- 10 files changed, 114 insertions(+), 23 deletions(-) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 3e5e3e39164b6..eaf994e849fc5 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -3108,6 +3108,10 @@ func (q *querier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) return q.db.RemoveUserFromAllGroups(ctx, userID) } +func (q *querier) RemoveUserFromGroups(ctx context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) { + panic("not implemented") +} + func (q *querier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error { if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil { return err diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 2bd55c4bec499..f9b9fb49b71fc 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -397,7 +397,7 @@ func (s *MethodTestSuite) TestGroup() { check.Args(database.InsertUserGroupsByIDParams{ UserID: u1.ID, GroupIds: slice.New(g1.ID, g2.ID), - }).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns(slice.New(g1, g2)) + }).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns(slice.New(g1.ID, g2.ID)) })) s.Run("RemoveUserFromAllGroups", s.Subtest(func(db database.Store, check *expects) { o := dbgen.Organization(s.T(), db, database.Organization{}) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index c3d04e8f9f201..423b13ef4a774 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -7637,6 +7637,15 @@ func (q *FakeQuerier) RemoveUserFromAllGroups(_ context.Context, userID uuid.UUI return nil } +func (q *FakeQuerier) RemoveUserFromGroups(ctx context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) { + err := validateDatabaseType(arg) + if err != nil { + return nil, err + } + + panic("not implemented") +} + func (q *FakeQuerier) RevokeDBCryptKey(_ context.Context, activeKeyDigest string) error { q.mutex.Lock() defer q.mutex.Unlock() diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index 510af865fc1c4..0ec70c1736d43 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -1791,9 +1791,9 @@ func (m metricsStore) InsertUser(ctx context.Context, arg database.InsertUserPar func (m metricsStore) InsertUserGroupsByID(ctx context.Context, arg database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) { start := time.Now() - r0 := m.s.InsertUserGroupsByID(ctx, arg) + r0, r1 := m.s.InsertUserGroupsByID(ctx, arg) m.queryLatencies.WithLabelValues("InsertUserGroupsByID").Observe(time.Since(start).Seconds()) - return r0 + return r0, r1 } func (m metricsStore) InsertUserGroupsByName(ctx context.Context, arg database.InsertUserGroupsByNameParams) error { @@ -1950,6 +1950,13 @@ func (m metricsStore) RemoveUserFromAllGroups(ctx context.Context, userID uuid.U return r0 } +func (m metricsStore) RemoveUserFromGroups(ctx context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) { + start := time.Now() + r0, r1 := m.s.RemoveUserFromGroups(ctx, arg) + m.queryLatencies.WithLabelValues("RemoveUserFromGroups").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error { start := time.Now() r0 := m.s.RevokeDBCryptKey(ctx, activeKeyDigest) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 6d881cfe6fc1b..fe2e444ff5c67 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -3766,6 +3766,21 @@ func (mr *MockStoreMockRecorder) InsertUser(arg0, arg1 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertUser", reflect.TypeOf((*MockStore)(nil).InsertUser), arg0, arg1) } +// InsertUserGroupsByID mocks base method. +func (m *MockStore) InsertUserGroupsByID(arg0 context.Context, arg1 database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InsertUserGroupsByID", arg0, arg1) + ret0, _ := ret[0].([]uuid.UUID) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// InsertUserGroupsByID indicates an expected call of InsertUserGroupsByID. +func (mr *MockStoreMockRecorder) InsertUserGroupsByID(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertUserGroupsByID", reflect.TypeOf((*MockStore)(nil).InsertUserGroupsByID), arg0, arg1) +} + // InsertUserGroupsByName mocks base method. func (m *MockStore) InsertUserGroupsByName(arg0 context.Context, arg1 database.InsertUserGroupsByNameParams) error { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 3499f9cf702b3..3cedeeade34b7 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -398,6 +398,7 @@ type sqlcQuerier interface { ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error + RemoveUserFromGroups(ctx context.Context, arg RemoveUserFromGroupsParams) ([]uuid.UUID, error) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error // Non blocking lock. Returns true if the lock was acquired, false otherwise. // diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 2816dad13e6ba..3e6d6ce61c6fb 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1537,6 +1537,43 @@ func (q *sqlQuerier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UU return err } +const removeUserFromGroups = `-- name: RemoveUserFromGroups :many +DELETE FROM + group_members +WHERE + user_id = $1 AND + group_id = ANY($2 :: uuid []) +RETURNING group_id +` + +type RemoveUserFromGroupsParams struct { + UserID uuid.UUID `db:"user_id" json:"user_id"` + GroupIds []uuid.UUID `db:"group_ids" json:"group_ids"` +} + +func (q *sqlQuerier) RemoveUserFromGroups(ctx context.Context, arg RemoveUserFromGroupsParams) ([]uuid.UUID, error) { + rows, err := q.db.QueryContext(ctx, removeUserFromGroups, arg.UserID, pq.Array(arg.GroupIds)) + if err != nil { + return nil, err + } + defer rows.Close() + var items []uuid.UUID + for rows.Next() { + var group_id uuid.UUID + if err := rows.Scan(&group_id); err != nil { + return nil, err + } + items = append(items, group_id) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const deleteGroupByID = `-- name: DeleteGroupByID :exec DELETE FROM groups diff --git a/coderd/database/queries/groupmembers.sql b/coderd/database/queries/groupmembers.sql index 867b1ba75d0e7..814f878cb9232 100644 --- a/coderd/database/queries/groupmembers.sql +++ b/coderd/database/queries/groupmembers.sql @@ -54,6 +54,14 @@ DELETE FROM WHERE user_id = @user_id; +-- name: RemoveUserFromGroups :many +DELETE FROM + group_members +WHERE + user_id = @user_id AND + group_id = ANY(@group_ids :: uuid []) +RETURNING group_id; + -- name: InsertGroupMember :exec INSERT INTO group_members (user_id, group_id) diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index d47a7f69045d5..6d5fd11a52e5a 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -135,29 +135,39 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat groupsToAdd = append(groupsToAdd, assignGroups...) } - tx.InsertUserGroupsByID(ctx, database.InsertUserGroupsByIDParams{ - UserID: user.ID, - GroupIds: groupsToAdd, + assignedGroupIDs, err := tx.InsertUserGroupsByID(ctx, database.InsertUserGroupsByIDParams{ + UserID: user.ID, + GroupIds: groupsToAdd, }) + if err != nil { + return xerrors.Errorf("insert user into %d groups: %w", len(groupsToAdd), err) + } + if len(assignedGroupIDs) != len(groupsToAdd) { + s.Logger.Debug(ctx, "failed to assign all groups to user", + slog.F("user_id", user.ID), + slog.F("groups_assigned_count", len(assignedGroupIDs)), + slog.F("expected_count", len(groupsToAdd)), + ) + } + + removedGroupIDs, err := tx.RemoveUserFromGroups(ctx, database.RemoveUserFromGroupsParams{ + UserID: user.ID, + GroupIds: groupsToRemove, + }) + if err != nil { + return xerrors.Errorf("remove user from %d groups: %w", len(groupsToRemove), err) + } + if len(removedGroupIDs) != len(groupsToRemove) { + s.Logger.Debug(ctx, "failed to remove user from all groups", + slog.F("user_id", user.ID), + slog.F("groups_removed_count", len(removedGroupIDs)), + slog.F("expected_count", len(groupsToRemove)), + ) + } + return nil }, nil) - // - //tx.InTx(func(tx database.Store) error { - // // When setting the user's groups, it's easier to just clear their groups and re-add them. - // // This ensures that the user's groups are always in sync with the auth provider. - // err := tx.RemoveUserFromAllGroups(ctx, user.ID) - // if err != nil { - // return err - // } - // - // for _, org := range userOrgs { - // - // } - // - // return nil - //}, nil) - return nil } diff --git a/enterprise/coderd/enidpsync/enidpsync.go b/enterprise/coderd/enidpsync/enidpsync.go index 918b9f8edb118..10988832743da 100644 --- a/enterprise/coderd/enidpsync/enidpsync.go +++ b/enterprise/coderd/enidpsync/enidpsync.go @@ -16,7 +16,7 @@ type EnterpriseIDPSync struct { *idpsync.AGPLIDPSync } -func NewSync(logger slog.Logger, set *entitlements.Set, settings idpsync.SyncSettings) *EnterpriseIDPSync { +func NewSync(logger slog.Logger, set *entitlements.Set, settings idpsync.DeploymentSyncSettings) *EnterpriseIDPSync { return &EnterpriseIDPSync{ entitlements: set, AGPLIDPSync: idpsync.NewAGPLSync(logger.With(slog.F("enterprise_capable", "true")), settings), From 791a05977df0ae118f2e88beec96876ba69e64d4 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 4 Sep 2024 09:24:16 -0500 Subject: [PATCH 04/38] work on moving to the manager --- coderd/idpsync/idpsync.go | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/coderd/idpsync/idpsync.go b/coderd/idpsync/idpsync.go index 2d02b941bcc80..6400977387536 100644 --- a/coderd/idpsync/idpsync.go +++ b/coderd/idpsync/idpsync.go @@ -34,15 +34,19 @@ type IDPSync interface { SyncOrganizations(ctx context.Context, tx database.Store, user database.User, params OrganizationParams) error GroupSyncEnabled() bool - // ParseGroupClaims takes claims from an OIDC provider, and returns the - // group sync params for assigning users into groups. + // ParseGroupClaims takes claims from an OIDC provider, and returns the params + // for group syncing. Most of the logic happens in SyncGroups. ParseGroupClaims(ctx context.Context, _ jwt.MapClaims) (GroupParams, *HTTPError) + + // SyncGroups assigns and removes users from groups based on the provided params. + SyncGroups(ctx context.Context, db database.Store, user database.User, params GroupParams) error } // AGPLIDPSync is the configuration for syncing user information from an external // IDP. All related code to syncing user information should be in this package. type AGPLIDPSync struct { - Logger slog.Logger + Logger slog.Logger + Manager runtimeconfig.Manager SyncSettings } @@ -74,9 +78,10 @@ type SyncSettings struct { GroupFilter *regexp.Regexp } -func NewAGPLSync(logger slog.Logger, settings DeploymentSyncSettings) *AGPLIDPSync { +func NewAGPLSync(logger slog.Logger, manager runtimeconfig.Manager, settings DeploymentSyncSettings) *AGPLIDPSync { return &AGPLIDPSync{ - Logger: logger.Named("idp-sync"), + Logger: logger.Named("idp-sync"), + Manager: manager, SyncSettings: SyncSettings{ DeploymentSyncSettings: settings, }, From 4326e9d94af1915ac1d109e72f808ed9cb3acd5c Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 4 Sep 2024 10:22:22 -0500 Subject: [PATCH 05/38] fixup compile issues --- coderd/idpsync/group.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 6d5fd11a52e5a..11e14260a7f3d 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -12,7 +12,6 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbauthz" - "github.com/coder/coder/v2/coderd/runtimeconfig" "github.com/coder/coder/v2/coderd/util/slice" ) @@ -44,7 +43,6 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat ctx = dbauthz.AsSystemRestricted(ctx) db.InTx(func(tx database.Store) error { - resolver := runtimeconfig.NewStoreResolver(tx) userGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{ HasMemberID: user.ID, }) @@ -62,7 +60,7 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat // For each org, we need to fetch the sync settings orgSettings := make(map[uuid.UUID]GroupSyncSettings) for orgID := range userOrgs { - orgResolver := runtimeconfig.NewOrgResolver(orgID, resolver) + orgResolver := s.Manager.Scoped(orgID.String()) settings, err := s.SyncSettings.Group.Resolve(ctx, orgResolver) if err != nil { return xerrors.Errorf("resolve group sync settings: %w", err) From 6d3ed2e57043c7eada5f619697c4d2997b3bf790 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 4 Sep 2024 10:25:12 -0500 Subject: [PATCH 06/38] fixup some tests --- coderd/idpsync/organizations_test.go | 29 +++++++++++-------- enterprise/coderd/enidpsync/enidpsync.go | 5 ++-- .../coderd/enidpsync/organizations_test.go | 9 +++--- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/coderd/idpsync/organizations_test.go b/coderd/idpsync/organizations_test.go index 03b1ebfa4b27b..b0e7728b0640a 100644 --- a/coderd/idpsync/organizations_test.go +++ b/coderd/idpsync/organizations_test.go @@ -9,6 +9,7 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/coderd/idpsync" + "github.com/coder/coder/v2/coderd/runtimeconfig" "github.com/coder/coder/v2/testutil" ) @@ -18,11 +19,13 @@ func TestParseOrganizationClaims(t *testing.T) { t.Run("SingleOrgDeployment", func(t *testing.T) { t.Parallel() - s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), idpsync.SyncSettings{ - OrganizationField: "", - OrganizationMapping: nil, - OrganizationAssignDefault: true, - }) + s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), + runtimeconfig.NewNoopManager(), + idpsync.DeploymentSyncSettings{ + OrganizationField: "", + OrganizationMapping: nil, + OrganizationAssignDefault: true, + }) ctx := testutil.Context(t, testutil.WaitMedium) @@ -38,13 +41,15 @@ func TestParseOrganizationClaims(t *testing.T) { t.Parallel() // AGPL has limited behavior - s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), idpsync.SyncSettings{ - OrganizationField: "orgs", - OrganizationMapping: map[string][]uuid.UUID{ - "random": {uuid.New()}, - }, - OrganizationAssignDefault: false, - }) + s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), + runtimeconfig.NewNoopManager(), + idpsync.DeploymentSyncSettings{ + OrganizationField: "orgs", + OrganizationMapping: map[string][]uuid.UUID{ + "random": {uuid.New()}, + }, + OrganizationAssignDefault: false, + }) ctx := testutil.Context(t, testutil.WaitMedium) diff --git a/enterprise/coderd/enidpsync/enidpsync.go b/enterprise/coderd/enidpsync/enidpsync.go index 10988832743da..a7ff1eaa07257 100644 --- a/enterprise/coderd/enidpsync/enidpsync.go +++ b/enterprise/coderd/enidpsync/enidpsync.go @@ -4,6 +4,7 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd/entitlements" "github.com/coder/coder/v2/coderd/idpsync" + "github.com/coder/coder/v2/coderd/runtimeconfig" ) // EnterpriseIDPSync enabled syncing user information from an external IDP. @@ -16,9 +17,9 @@ type EnterpriseIDPSync struct { *idpsync.AGPLIDPSync } -func NewSync(logger slog.Logger, set *entitlements.Set, settings idpsync.DeploymentSyncSettings) *EnterpriseIDPSync { +func NewSync(logger slog.Logger, manager runtimeconfig.Manager, set *entitlements.Set, settings idpsync.DeploymentSyncSettings) *EnterpriseIDPSync { return &EnterpriseIDPSync{ entitlements: set, - AGPLIDPSync: idpsync.NewAGPLSync(logger.With(slog.F("enterprise_capable", "true")), settings), + AGPLIDPSync: idpsync.NewAGPLSync(logger.With(slog.F("enterprise_capable", "true")), manager, settings), } } diff --git a/enterprise/coderd/enidpsync/organizations_test.go b/enterprise/coderd/enidpsync/organizations_test.go index 0b2ed1ef6521f..8978fa6b46ee1 100644 --- a/enterprise/coderd/enidpsync/organizations_test.go +++ b/enterprise/coderd/enidpsync/organizations_test.go @@ -19,6 +19,7 @@ import ( "github.com/coder/coder/v2/coderd/entitlements" "github.com/coder/coder/v2/coderd/idpsync" "github.com/coder/coder/v2/coderd/rbac" + "github.com/coder/coder/v2/coderd/runtimeconfig" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/enterprise/coderd/enidpsync" "github.com/coder/coder/v2/testutil" @@ -41,7 +42,7 @@ type Expectations struct { } type OrganizationSyncTestCase struct { - Settings idpsync.SyncSettings + Settings idpsync.DeploymentSyncSettings Entitlements *entitlements.Set Exps []Expectations } @@ -89,7 +90,7 @@ func TestOrganizationSync(t *testing.T) { other := dbgen.Organization(t, db, database.Organization{}) return OrganizationSyncTestCase{ Entitlements: entitled, - Settings: idpsync.SyncSettings{ + Settings: idpsync.DeploymentSyncSettings{ OrganizationField: "", OrganizationMapping: nil, OrganizationAssignDefault: true, @@ -142,7 +143,7 @@ func TestOrganizationSync(t *testing.T) { three := dbgen.Organization(t, db, database.Organization{}) return OrganizationSyncTestCase{ Entitlements: entitled, - Settings: idpsync.SyncSettings{ + Settings: idpsync.DeploymentSyncSettings{ OrganizationField: "organizations", OrganizationMapping: map[string][]uuid.UUID{ "first": {one.ID}, @@ -236,7 +237,7 @@ func TestOrganizationSync(t *testing.T) { } // Create a new sync object - sync := enidpsync.NewSync(logger, caseData.Entitlements, caseData.Settings) + sync := enidpsync.NewSync(logger, runtimeconfig.NewStoreManager(rdb), caseData.Entitlements, caseData.Settings) for _, exp := range caseData.Exps { t.Run(exp.Name, func(t *testing.T) { params, httpErr := sync.ParseOrganizationClaims(ctx, exp.Claims) From 0803619e8c65eb1bb584abae3137e403bf07f8fe Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 4 Sep 2024 10:51:49 -0500 Subject: [PATCH 07/38] handle allow list --- coderd/coderd.go | 6 +--- coderd/idpsync/group.go | 8 +++++ coderd/idpsync/idpsync.go | 23 ++++++++++++ coderd/userauth.go | 2 +- enterprise/coderd/coderd.go | 6 +--- enterprise/coderd/enidpsync/groups.go | 42 +++++++++++++++++++++- enterprise/coderd/enidpsync/groups_test.go | 35 ++++++++++++++++++ 7 files changed, 110 insertions(+), 12 deletions(-) create mode 100644 enterprise/coderd/enidpsync/groups_test.go diff --git a/coderd/coderd.go b/coderd/coderd.go index 97c2d9f883713..b829d37a06773 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -312,11 +312,7 @@ func New(options *Options) *API { ) if options.IDPSync == nil { - options.IDPSync = idpsync.NewAGPLSync(options.Logger, idpsync.DeploymentSyncSettings{ - OrganizationField: options.DeploymentValues.OIDC.OrganizationField.Value(), - OrganizationMapping: options.DeploymentValues.OIDC.OrganizationMapping.Value, - OrganizationAssignDefault: options.DeploymentValues.OIDC.OrganizationAssignDefault.Value(), - }) + options.IDPSync = idpsync.NewAGPLSync(options.Logger, options.RuntimeConfig, idpsync.FromDeploymentValues(options.DeploymentValues)) } experiments := ReadExperiments( diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 11e14260a7f3d..0257801ae2a7a 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -266,3 +266,11 @@ func (s GroupSyncSettings) HandleMissingGroups(ctx context.Context, tx database. return addIDs, nil } + +func ConvertAllowList(allowList []string) map[string]struct{} { + allowMap := make(map[string]struct{}, len(allowList)) + for _, group := range allowList { + allowMap[group] = struct{}{} + } + return allowMap +} diff --git a/coderd/idpsync/idpsync.go b/coderd/idpsync/idpsync.go index 6400977387536..5ad2ffb52ff12 100644 --- a/coderd/idpsync/idpsync.go +++ b/coderd/idpsync/idpsync.go @@ -63,6 +63,29 @@ type DeploymentSyncSettings struct { // placed into the default organization. This is mostly a hack to support // legacy deployments. OrganizationAssignDefault bool + + // GroupField at the deployment level is used for deployment level group claim + // settings. + GroupField string + // GroupAllowList (if set) will restrict authentication to only users who + // have at least one group in this list. + // A map representation is used for easier lookup. + GroupAllowList map[string]struct{} +} + +func FromDeploymentValues(dv *codersdk.DeploymentValues) DeploymentSyncSettings { + if dv == nil { + panic("Developer error: DeploymentValues should not be nil") + } + return DeploymentSyncSettings{ + OrganizationField: dv.OIDC.OrganizationField.Value(), + OrganizationMapping: dv.OIDC.OrganizationMapping.Value, + OrganizationAssignDefault: dv.OIDC.OrganizationAssignDefault.Value(), + + GroupField: dv.OIDC.GroupField.Value(), + GroupAllowList: ConvertAllowList(dv.OIDC.GroupAllowList.Value()), + } + } type SyncSettings struct { diff --git a/coderd/userauth.go b/coderd/userauth.go index bb149d9d07379..a1abadc63f31a 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -1142,7 +1142,7 @@ func (api *API) oidcGroups(ctx context.Context, mergedClaims map[string]interfac slog.F("allow_list_count", len(api.OIDCConfig.GroupAllowList)), slog.F("user_group_count", len(groups)), ) - detail := "Ask an administrator to add one of your groups to the whitelist" + detail := "Ask an administrator to add one of your groups to the allow list" if len(groups) == 0 { detail = "You are currently not a member of any groups! Ask an administrator to add you to an authorized group to login." } diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index bc6491a41198f..ce55bae8ec8d0 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -113,11 +113,7 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { options.Database = cryptDB if options.IDPSync == nil { - options.IDPSync = enidpsync.NewSync(options.Logger, options.Entitlements, idpsync.SyncSettings{ - OrganizationField: options.DeploymentValues.OIDC.OrganizationField.Value(), - OrganizationMapping: options.DeploymentValues.OIDC.OrganizationMapping.Value, - OrganizationAssignDefault: options.DeploymentValues.OIDC.OrganizationAssignDefault.Value(), - }) + options.IDPSync = enidpsync.NewSync(options.Logger, options.RuntimeConfig, options.Entitlements, idpsync.FromDeploymentValues(options.DeploymentValues)) } api := &API{ diff --git a/enterprise/coderd/enidpsync/groups.go b/enterprise/coderd/enidpsync/groups.go index 02f012b8e14c3..441f847c6a450 100644 --- a/enterprise/coderd/enidpsync/groups.go +++ b/enterprise/coderd/enidpsync/groups.go @@ -2,6 +2,7 @@ package enidpsync import ( "context" + "net/http" "github.com/golang-jwt/jwt/v4" @@ -11,7 +12,6 @@ import ( func (e EnterpriseIDPSync) GroupSyncEnabled() bool { return e.entitlements.Enabled(codersdk.FeatureTemplateRBAC) - } // ParseGroupClaims parses the user claims and handles deployment wide group behavior. @@ -23,6 +23,46 @@ func (e EnterpriseIDPSync) ParseGroupClaims(ctx context.Context, mergedClaims jw return e.AGPLIDPSync.ParseGroupClaims(ctx, mergedClaims) } + if e.GroupField != "" && len(e.GroupAllowList) > 0 { + groupsRaw, ok := mergedClaims[e.GroupField] + if !ok { + return idpsync.GroupParams{}, &idpsync.HTTPError{ + Code: http.StatusForbidden, + Msg: "Not a member of an allowed group", + Detail: "You have no groups in your claims!", + RenderStaticPage: true, + } + } + parsedGroups, err := idpsync.ParseStringSliceClaim(groupsRaw) + if err != nil { + return idpsync.GroupParams{}, &idpsync.HTTPError{ + Code: http.StatusBadRequest, + Msg: "Failed read groups from claims for allow list check. Ask an administrator for help.", + Detail: err.Error(), + RenderStaticPage: true, + } + } + + inAllowList := false + AllowListCheckLoop: + for _, group := range parsedGroups { + if _, ok := e.GroupAllowList[group]; ok { + inAllowList = true + break AllowListCheckLoop + } + } + + if !inAllowList { + return idpsync.GroupParams{}, &idpsync.HTTPError{ + Code: http.StatusForbidden, + Msg: "Not a member of an allowed group", + Detail: "Ask an administrator to add one of your groups to the allow list.", + RenderStaticPage: true, + } + } + + } + return idpsync.GroupParams{ SyncEnabled: e.OrganizationSyncEnabled(), MergedClaims: mergedClaims, diff --git a/enterprise/coderd/enidpsync/groups_test.go b/enterprise/coderd/enidpsync/groups_test.go new file mode 100644 index 0000000000000..149c57dadd79a --- /dev/null +++ b/enterprise/coderd/enidpsync/groups_test.go @@ -0,0 +1,35 @@ +package enidpsync_test + +import ( + "testing" + + "github.com/golang-jwt/jwt/v4" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/entitlements" + "github.com/coder/coder/v2/coderd/idpsync" + "github.com/coder/coder/v2/coderd/runtimeconfig" + "github.com/coder/coder/v2/enterprise/coderd/enidpsync" + "github.com/coder/coder/v2/testutil" +) + +func TestEnterpriseParseGroupClaims(t *testing.T) { + t.Parallel() + + t.Run("NoEntitlements", func(t *testing.T) { + t.Parallel() + + s := enidpsync.NewSync(slogtest.Make(t, &slogtest.Options{}), + runtimeconfig.NewNoopManager(), + entitlements.New(), + idpsync.DeploymentSyncSettings{}) + + ctx := testutil.Context(t, testutil.WaitMedium) + + params, err := s.ParseGroupClaims(ctx, jwt.MapClaims{}) + require.Nil(t, err) + + require.False(t, params.SyncEnabled) + }) +} From 596e7b467feef913f10ff70768783bb6b5f826e5 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 4 Sep 2024 12:25:51 -0500 Subject: [PATCH 08/38] WIP unit test for group sync --- coderd/coderdtest/uuids.go | 21 ++ coderd/database/dbmem/dbmem.go | 17 +- coderd/idpsync/group.go | 23 +- coderd/idpsync/group_test.go | 289 +++++++++++++++++++++ coderd/idpsync/idpsync.go | 26 +- enterprise/coderd/enidpsync/groups.go | 2 +- enterprise/coderd/enidpsync/groups_test.go | 61 +++++ 7 files changed, 421 insertions(+), 18 deletions(-) create mode 100644 coderd/coderdtest/uuids.go create mode 100644 coderd/idpsync/group_test.go diff --git a/coderd/coderdtest/uuids.go b/coderd/coderdtest/uuids.go new file mode 100644 index 0000000000000..aefa6e83c0b3c --- /dev/null +++ b/coderd/coderdtest/uuids.go @@ -0,0 +1,21 @@ +package coderdtest + +import "github.com/google/uuid" + +type DeterministicUUIDGenerator struct { + Named map[string]uuid.UUID +} + +func NewDeterministicUUIDGenerator() *DeterministicUUIDGenerator { + return &DeterministicUUIDGenerator{ + Named: make(map[string]uuid.UUID), + } +} + +func (d *DeterministicUUIDGenerator) ID(name string) uuid.UUID { + if v, ok := d.Named[name]; ok { + return v + } + d.Named[name] = uuid.New() + return d.Named[name] +} diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 423b13ef4a774..37811063997db 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -7643,7 +7643,22 @@ func (q *FakeQuerier) RemoveUserFromGroups(ctx context.Context, arg database.Rem return nil, err } - panic("not implemented") + q.mutex.Lock() + defer q.mutex.Unlock() + + removed := make([]uuid.UUID, 0) + q.data.groupMembers = slices.DeleteFunc(q.data.groupMembers, func(groupMember database.GroupMemberTable) bool { + if groupMember.UserID != arg.UserID { + return false + } + if !slices.Contains(arg.GroupIds, groupMember.GroupID) { + return false + } + removed = append(removed, groupMember.GroupID) + return true + }) + + return removed, nil } func (q *FakeQuerier) RevokeDBCryptKey(_ context.Context, activeKeyDigest string) error { diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 0257801ae2a7a..d45d79bf04cac 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -2,6 +2,7 @@ package idpsync import ( "context" + "encoding/json" "regexp" "github.com/golang-jwt/jwt/v4" @@ -12,6 +13,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/runtimeconfig" "github.com/coder/coder/v2/coderd/util/slice" ) @@ -32,7 +34,6 @@ func (s AGPLIDPSync) ParseGroupClaims(_ context.Context, _ jwt.MapClaims) (Group }, nil } -// TODO: Group allowlist behavior should probably happen at this step. func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user database.User, params GroupParams) error { // Nothing happens if sync is not enabled if !params.SyncEnabled { @@ -43,6 +44,8 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat ctx = dbauthz.AsSystemRestricted(ctx) db.InTx(func(tx database.Store) error { + manager := runtimeconfig.NewStoreManager(tx) + userGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{ HasMemberID: user.ID, }) @@ -60,12 +63,12 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat // For each org, we need to fetch the sync settings orgSettings := make(map[uuid.UUID]GroupSyncSettings) for orgID := range userOrgs { - orgResolver := s.Manager.Scoped(orgID.String()) + orgResolver := manager.Scoped(orgID.String()) settings, err := s.SyncSettings.Group.Resolve(ctx, orgResolver) if err != nil { return xerrors.Errorf("resolve group sync settings: %w", err) } - orgSettings[orgID] = settings.Value + orgSettings[orgID] = *settings } // collect all diffs to do 1 sql update for all orgs @@ -177,6 +180,20 @@ type GroupSyncSettings struct { AutoCreateMissingGroups bool `json:"auto_create_missing_groups"` } +func (s *GroupSyncSettings) Set(v string) error { + return json.Unmarshal([]byte(v), s) +} +func (s *GroupSyncSettings) String() string { + v, err := json.Marshal(s) + if err != nil { + return "decode failed: " + err.Error() + } + return string(v) +} +func (s *GroupSyncSettings) Type() string { + return "GroupSyncSettings" +} + type ExpectedGroup struct { GroupID *uuid.UUID GroupName *string diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go new file mode 100644 index 0000000000000..42465f115488e --- /dev/null +++ b/coderd/idpsync/group_test.go @@ -0,0 +1,289 @@ +package idpsync_test + +import ( + "context" + "regexp" + "testing" + + "github.com/golang-jwt/jwt/v4" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/db2sdk" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/idpsync" + "github.com/coder/coder/v2/coderd/runtimeconfig" + "github.com/coder/coder/v2/testutil" +) + +func TestParseGroupClaims(t *testing.T) { + t.Parallel() + + t.Run("EmptyConfig", func(t *testing.T) { + t.Parallel() + + s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), + runtimeconfig.NewNoopManager(), + idpsync.DeploymentSyncSettings{}) + + ctx := testutil.Context(t, testutil.WaitMedium) + + params, err := s.ParseGroupClaims(ctx, jwt.MapClaims{}) + require.Nil(t, err) + + require.False(t, params.SyncEnabled) + }) + + // AllowList has no effect in AGPL + t.Run("AllowList", func(t *testing.T) { + t.Parallel() + + s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), + runtimeconfig.NewNoopManager(), + idpsync.DeploymentSyncSettings{ + GroupField: "groups", + GroupAllowList: map[string]struct{}{ + "foo": {}, + }, + }) + + ctx := testutil.Context(t, testutil.WaitMedium) + + params, err := s.ParseGroupClaims(ctx, jwt.MapClaims{}) + require.Nil(t, err) + require.False(t, params.SyncEnabled) + }) +} + +func TestGroupSyncTable(t *testing.T) { + t.Parallel() + + if dbtestutil.WillUsePostgres() { + t.Skip("Skipping test because it populates a lot of db entries, which is slow on postgres.") + } + + userClaims := jwt.MapClaims{ + "groups": []string{ + "foo", "bar", "baz", + "create-bar", "create-baz", + }, + } + + ids := coderdtest.NewDeterministicUUIDGenerator() + testCases := []orgSetupDefinition{ + { + Name: "SwitchGroups", + Settings: &idpsync.GroupSyncSettings{ + GroupField: "groups", + GroupMapping: map[string][]uuid.UUID{ + "foo": {ids.ID("sg-foo"), ids.ID("sg-foo-2")}, + "bar": {ids.ID("sg-bar")}, + "baz": {ids.ID("sg-baz")}, + }, + }, + Groups: map[uuid.UUID]bool{ + uuid.New(): true, + uuid.New(): true, + // Extra groups + ids.ID("sg-foo"): false, + ids.ID("sg-foo-2"): false, + ids.ID("sg-bar"): false, + ids.ID("sg-baz"): false, + }, + ExpectedGroups: []uuid.UUID{ + ids.ID("sg-foo"), + ids.ID("sg-foo-2"), + ids.ID("sg-bar"), + ids.ID("sg-baz"), + }, + }, + { + Name: "StayInGroup", + Settings: &idpsync.GroupSyncSettings{ + GroupField: "groups", + // Only match foo, so bar does not map + RegexFilter: regexp.MustCompile("^foo$"), + GroupMapping: map[string][]uuid.UUID{ + "foo": {ids.ID("gg-foo"), uuid.New()}, + "bar": {ids.ID("gg-bar")}, + "baz": {ids.ID("gg-baz")}, + }, + }, + Groups: map[uuid.UUID]bool{ + ids.ID("gg-foo"): true, + ids.ID("gg-bar"): false, + }, + ExpectedGroups: []uuid.UUID{ + ids.ID("gg-foo"), + }, + }, + { + Name: "UserJoinsGroups", + Settings: &idpsync.GroupSyncSettings{ + GroupField: "groups", + GroupMapping: map[string][]uuid.UUID{ + "foo": {ids.ID("ng-foo"), uuid.New()}, + "bar": {ids.ID("ng-bar"), ids.ID("ng-bar-2")}, + "baz": {ids.ID("ng-baz")}, + }, + }, + Groups: map[uuid.UUID]bool{ + ids.ID("ng-foo"): false, + ids.ID("ng-bar"): false, + ids.ID("ng-bar-2"): false, + ids.ID("ng-baz"): false, + }, + ExpectedGroups: []uuid.UUID{ + ids.ID("ng-foo"), + ids.ID("ng-bar"), + ids.ID("ng-bar-2"), + ids.ID("ng-baz"), + }, + }, + { + Name: "CreateGroups", + Settings: &idpsync.GroupSyncSettings{ + GroupField: "groups", + RegexFilter: regexp.MustCompile("^create"), + AutoCreateMissingGroups: true, + }, + Groups: map[uuid.UUID]bool{}, + ExpectedGroups: []uuid.UUID{ + ids.ID("create-bar"), + ids.ID("create-baz"), + }, + }, + { + Name: "NoUser", + Settings: &idpsync.GroupSyncSettings{ + GroupField: "groups", + GroupMapping: map[string][]uuid.UUID{ + // Extra ID that does not map to a group + "foo": {ids.ID("ow-foo"), uuid.New()}, + }, + RegexFilter: nil, + AutoCreateMissingGroups: false, + }, + NotMember: true, + Groups: map[uuid.UUID]bool{ + ids.ID("ow-foo"): false, + ids.ID("ow-bar"): false, + }, + }, + { + Name: "NoSettingsNoUser", + Settings: nil, + Groups: map[uuid.UUID]bool{}, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.Name, func(t *testing.T) { + t.Parallel() + + if tc.OrgID == uuid.Nil { + tc.OrgID = uuid.New() + } + + db, _ := dbtestutil.NewDB(t) + manager := runtimeconfig.NewStoreManager(db) + s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), + manager, + idpsync.DeploymentSyncSettings{ + GroupField: "groups", + }, + ) + + ctx := testutil.Context(t, testutil.WaitMedium) + user := dbgen.User(t, db, database.User{}) + SetupOrganization(t, s, db, user, tc) + + // Do the group sync! + err := s.SyncGroups(ctx, db, user, idpsync.GroupParams{ + SyncEnabled: true, + MergedClaims: userClaims, + }) + require.NoError(t, err) + + tc.Assert(t, tc.OrgID, db, user) + }) + } +} + +func SetupOrganization(t *testing.T, s *idpsync.AGPLIDPSync, db database.Store, user database.User, def orgSetupDefinition) { + org := dbgen.Organization(t, db, database.Organization{ + ID: def.OrgID, + }) + + manager := runtimeconfig.NewStoreManager(db) + orgResolver := manager.Scoped(org.ID.String()) + err := s.Group.SetRuntimeValue(context.Background(), orgResolver, def.Settings) + require.NoError(t, err) + + if !def.NotMember { + dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + } + for groupID, in := range def.Groups { + dbgen.Group(t, db, database.Group{ + ID: groupID, + OrganizationID: org.ID, + }) + if in { + dbgen.GroupMember(t, db, database.GroupMemberTable{ + UserID: user.ID, + GroupID: groupID, + }) + } + } +} + +type orgSetupDefinition struct { + Name string + OrgID uuid.UUID + // True if the user is a member of the group + Groups map[uuid.UUID]bool + NotMember bool + + Settings *idpsync.GroupSyncSettings + ExpectedGroups []uuid.UUID +} + +func (o orgSetupDefinition) Assert(t *testing.T, orgID uuid.UUID, db database.Store, user database.User) { + t.Helper() + + t.Run(o.Name+"-Assert", func(t *testing.T) { + ctx := context.Background() + + members, err := db.OrganizationMembers(ctx, database.OrganizationMembersParams{ + OrganizationID: orgID, + UserID: user.ID, + }) + require.NoError(t, err) + if o.NotMember { + require.Len(t, members, 0, "should not be a member") + } else { + require.Len(t, members, 1, "should be a member") + } + + userGroups, err := db.GetGroups(ctx, database.GetGroupsParams{ + OrganizationID: orgID, + HasMemberID: user.ID, + }) + require.NoError(t, err) + if o.ExpectedGroups == nil { + o.ExpectedGroups = make([]uuid.UUID, 0) + } + found := db2sdk.List(userGroups, func(g database.GetGroupsRow) uuid.UUID { + return g.Group.ID + }) + require.ElementsMatch(t, o.ExpectedGroups, found, "user groups") + }) +} diff --git a/coderd/idpsync/idpsync.go b/coderd/idpsync/idpsync.go index 5ad2ffb52ff12..3ff8d78fd5174 100644 --- a/coderd/idpsync/idpsync.go +++ b/coderd/idpsync/idpsync.go @@ -3,7 +3,6 @@ package idpsync import ( "context" "net/http" - "regexp" "strings" "github.com/golang-jwt/jwt/v4" @@ -16,7 +15,6 @@ import ( "github.com/coder/coder/v2/coderd/runtimeconfig" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/site" - "github.com/coder/serpent" ) // IDPSync is an interface, so we can implement this as AGPL and as enterprise, @@ -45,8 +43,7 @@ type IDPSync interface { // AGPLIDPSync is the configuration for syncing user information from an external // IDP. All related code to syncing user information should be in this package. type AGPLIDPSync struct { - Logger slog.Logger - Manager runtimeconfig.Manager + Logger slog.Logger SyncSettings } @@ -91,22 +88,25 @@ func FromDeploymentValues(dv *codersdk.DeploymentValues) DeploymentSyncSettings type SyncSettings struct { DeploymentSyncSettings - Group runtimeconfig.Entry[*serpent.Struct[GroupSyncSettings]] + Group runtimeconfig.Entry[*GroupSyncSettings] - // Group options here are set by the deployment config and only apply to - // the default organization. - GroupField string - CreateMissingGroups bool - GroupMapping map[string]string - GroupFilter *regexp.Regexp + //// Group options here are set by the deployment config and only apply to + //// the default organization. + //GroupField string + //CreateMissingGroups bool + //GroupMapping map[string]string + //GroupFilter *regexp.Regexp } func NewAGPLSync(logger slog.Logger, manager runtimeconfig.Manager, settings DeploymentSyncSettings) *AGPLIDPSync { return &AGPLIDPSync{ - Logger: logger.Named("idp-sync"), - Manager: manager, + Logger: logger.Named("idp-sync"), SyncSettings: SyncSettings{ DeploymentSyncSettings: settings, + // Default to '{}' if the group sync settings are not set. + // TODO: Feels strange to have to define the type as a string. I should be + // able to pass in an empty struct. + Group: runtimeconfig.MustNew[*GroupSyncSettings]("group-sync-settings", "{}"), }, } } diff --git a/enterprise/coderd/enidpsync/groups.go b/enterprise/coderd/enidpsync/groups.go index 441f847c6a450..2ecc8703e29cd 100644 --- a/enterprise/coderd/enidpsync/groups.go +++ b/enterprise/coderd/enidpsync/groups.go @@ -64,7 +64,7 @@ func (e EnterpriseIDPSync) ParseGroupClaims(ctx context.Context, mergedClaims jw } return idpsync.GroupParams{ - SyncEnabled: e.OrganizationSyncEnabled(), + SyncEnabled: true, MergedClaims: mergedClaims, }, nil } diff --git a/enterprise/coderd/enidpsync/groups_test.go b/enterprise/coderd/enidpsync/groups_test.go index 149c57dadd79a..138d2954712de 100644 --- a/enterprise/coderd/enidpsync/groups_test.go +++ b/enterprise/coderd/enidpsync/groups_test.go @@ -10,6 +10,7 @@ import ( "github.com/coder/coder/v2/coderd/entitlements" "github.com/coder/coder/v2/coderd/idpsync" "github.com/coder/coder/v2/coderd/runtimeconfig" + "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/enterprise/coderd/enidpsync" "github.com/coder/coder/v2/testutil" ) @@ -17,6 +18,14 @@ import ( func TestEnterpriseParseGroupClaims(t *testing.T) { t.Parallel() + entitled := entitlements.New() + entitled.Update(func(entitlements *codersdk.Entitlements) { + entitlements.Features[codersdk.FeatureTemplateRBAC] = codersdk.Feature{ + Entitlement: codersdk.EntitlementEntitled, + Enabled: true, + } + }) + t.Run("NoEntitlements", func(t *testing.T) { t.Parallel() @@ -32,4 +41,56 @@ func TestEnterpriseParseGroupClaims(t *testing.T) { require.False(t, params.SyncEnabled) }) + + t.Run("NotInAllowList", func(t *testing.T) { + t.Parallel() + + s := enidpsync.NewSync(slogtest.Make(t, &slogtest.Options{}), + runtimeconfig.NewNoopManager(), + entitled, + idpsync.DeploymentSyncSettings{ + GroupField: "groups", + GroupAllowList: map[string]struct{}{ + "foo": {}, + }, + }) + + ctx := testutil.Context(t, testutil.WaitMedium) + + // Try with incorrect group + _, err := s.ParseGroupClaims(ctx, jwt.MapClaims{ + "groups": []string{"bar"}, + }) + require.NotNil(t, err) + require.Equal(t, 403, err.Code) + + // Try with no groups + _, err = s.ParseGroupClaims(ctx, jwt.MapClaims{}) + require.NotNil(t, err) + require.Equal(t, 403, err.Code) + }) + + t.Run("InAllowList", func(t *testing.T) { + t.Parallel() + + s := enidpsync.NewSync(slogtest.Make(t, &slogtest.Options{}), + runtimeconfig.NewNoopManager(), + entitled, + idpsync.DeploymentSyncSettings{ + GroupField: "groups", + GroupAllowList: map[string]struct{}{ + "foo": {}, + }, + }) + + ctx := testutil.Context(t, testutil.WaitMedium) + + claims := jwt.MapClaims{ + "groups": []string{"foo", "bar"}, + } + params, err := s.ParseGroupClaims(ctx, claims) + require.Nil(t, err) + require.True(t, params.SyncEnabled) + require.Equal(t, claims, params.MergedClaims) + }) } From b9476ac14070ee3917ac935606bfb6bbc4523a2b Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 4 Sep 2024 14:29:54 -0500 Subject: [PATCH 09/38] fixup tests, account for existing groups --- coderd/database/queries.sql.go | 7 +- coderd/database/queries/groups.sql | 4 ++ coderd/idpsync/group.go | 20 +++++- coderd/idpsync/group_test.go | 105 +++++++++++++++++++++-------- 4 files changed, 106 insertions(+), 30 deletions(-) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 3e6d6ce61c6fb..b87ad6f857bb9 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1677,11 +1677,16 @@ WHERE ) ELSE true END + AND CASE WHEN array_length($3 :: text[], 1) > 0 THEN + name = ANY($3) + ELSE true + END ` type GetGroupsParams struct { OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` HasMemberID uuid.UUID `db:"has_member_id" json:"has_member_id"` + GroupNames []string `db:"group_names" json:"group_names"` } type GetGroupsRow struct { @@ -1691,7 +1696,7 @@ type GetGroupsRow struct { } func (q *sqlQuerier) GetGroups(ctx context.Context, arg GetGroupsParams) ([]GetGroupsRow, error) { - rows, err := q.db.QueryContext(ctx, getGroups, arg.OrganizationID, arg.HasMemberID) + rows, err := q.db.QueryContext(ctx, getGroups, arg.OrganizationID, arg.HasMemberID, pq.Array(arg.GroupNames)) if err != nil { return nil, err } diff --git a/coderd/database/queries/groups.sql b/coderd/database/queries/groups.sql index 1752ccd112ea7..628395b8a81b0 100644 --- a/coderd/database/queries/groups.sql +++ b/coderd/database/queries/groups.sql @@ -52,6 +52,10 @@ WHERE ) ELSE true END + AND CASE WHEN array_length(@group_names :: text[], 1) > 0 THEN + name = ANY(@group_names) + ELSE true + END ; -- name: InsertGroup :one diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index d45d79bf04cac..de1a3eee6597a 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -244,16 +244,34 @@ func (s GroupSyncSettings) ParseClaims(mergedClaims jwt.MapClaims) ([]ExpectedGr func (s GroupSyncSettings) HandleMissingGroups(ctx context.Context, tx database.Store, orgID uuid.UUID, add []ExpectedGroup) ([]uuid.UUID, error) { if !s.AutoCreateMissingGroups { - // Remove all groups that are missing, they will not be created + // construct the list of groups to search by name to see if they exist. + var lookups []string filter := make([]uuid.UUID, 0) for _, expected := range add { if expected.GroupID != nil { filter = append(filter, *expected.GroupID) + } else if expected.GroupName != nil { + lookups = append(lookups, *expected.GroupName) + } + } + + if len(lookups) > 0 { + newGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{ + OrganizationID: uuid.UUID{}, + HasMemberID: uuid.UUID{}, + GroupNames: lookups, + }) + if err != nil { + return nil, xerrors.Errorf("get groups by names: %w", err) + } + for _, g := range newGroups { + filter = append(filter, g.Group.ID) } } return filter, nil } + // All expected that are missing IDs means the group does not exist // in the database. Either remove them, or create them if auto create is // turned on. diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go index 42465f115488e..6b63b13e76ae5 100644 --- a/coderd/idpsync/group_test.go +++ b/coderd/idpsync/group_test.go @@ -8,6 +8,7 @@ import ( "github.com/golang-jwt/jwt/v4" "github.com/google/uuid" "github.com/stretchr/testify/require" + "golang.org/x/exp/slices" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/coderd/coderdtest" @@ -152,9 +153,26 @@ func TestGroupSyncTable(t *testing.T) { AutoCreateMissingGroups: true, }, Groups: map[uuid.UUID]bool{}, - ExpectedGroups: []uuid.UUID{ - ids.ID("create-bar"), - ids.ID("create-baz"), + ExpectedGroupNames: []string{ + "create-bar", + "create-baz", + }, + }, + { + Name: "GroupNamesNoMapping", + Settings: &idpsync.GroupSyncSettings{ + GroupField: "groups", + RegexFilter: regexp.MustCompile(".*"), + AutoCreateMissingGroups: false, + }, + GroupNames: map[string]bool{ + "foo": false, + "bar": false, + "goob": true, + }, + ExpectedGroupNames: []string{ + "foo", + "bar", }, }, { @@ -219,10 +237,12 @@ func SetupOrganization(t *testing.T, s *idpsync.AGPLIDPSync, db database.Store, org := dbgen.Organization(t, db, database.Organization{ ID: def.OrgID, }) + _, err := db.InsertAllUsersGroup(context.Background(), org.ID) + require.NoError(t, err, "Everyone group for an org") manager := runtimeconfig.NewStoreManager(db) orgResolver := manager.Scoped(org.ID.String()) - err := s.Group.SetRuntimeValue(context.Background(), orgResolver, def.Settings) + err = s.Group.SetRuntimeValue(context.Background(), orgResolver, def.Settings) require.NoError(t, err) if !def.NotMember { @@ -243,47 +263,76 @@ func SetupOrganization(t *testing.T, s *idpsync.AGPLIDPSync, db database.Store, }) } } + for groupName, in := range def.GroupNames { + group := dbgen.Group(t, db, database.Group{ + Name: groupName, + OrganizationID: org.ID, + }) + if in { + dbgen.GroupMember(t, db, database.GroupMemberTable{ + UserID: user.ID, + GroupID: group.ID, + }) + } + } } type orgSetupDefinition struct { Name string OrgID uuid.UUID // True if the user is a member of the group - Groups map[uuid.UUID]bool - NotMember bool + Groups map[uuid.UUID]bool + GroupNames map[string]bool + NotMember bool - Settings *idpsync.GroupSyncSettings - ExpectedGroups []uuid.UUID + Settings *idpsync.GroupSyncSettings + ExpectedGroups []uuid.UUID + ExpectedGroupNames []string } func (o orgSetupDefinition) Assert(t *testing.T, orgID uuid.UUID, db database.Store, user database.User) { t.Helper() - t.Run(o.Name+"-Assert", func(t *testing.T) { - ctx := context.Background() + ctx := context.Background() - members, err := db.OrganizationMembers(ctx, database.OrganizationMembersParams{ - OrganizationID: orgID, - UserID: user.ID, - }) - require.NoError(t, err) - if o.NotMember { - require.Len(t, members, 0, "should not be a member") - } else { - require.Len(t, members, 1, "should be a member") - } + members, err := db.OrganizationMembers(ctx, database.OrganizationMembersParams{ + OrganizationID: orgID, + UserID: user.ID, + }) + require.NoError(t, err) + if o.NotMember { + require.Len(t, members, 0, "should not be a member") + } else { + require.Len(t, members, 1, "should be a member") + } + + userGroups, err := db.GetGroups(ctx, database.GetGroupsParams{ + OrganizationID: orgID, + HasMemberID: user.ID, + }) + require.NoError(t, err) + if o.ExpectedGroups == nil { + o.ExpectedGroups = make([]uuid.UUID, 0) + } + if len(o.ExpectedGroupNames) > 0 && len(o.ExpectedGroups) > 0 { + t.Fatal("ExpectedGroups and ExpectedGroupNames are mutually exclusive") + } + + // Everyone groups mess up our asserts + userGroups = slices.DeleteFunc(userGroups, func(row database.GetGroupsRow) bool { + return row.Group.ID == row.Group.OrganizationID + }) - userGroups, err := db.GetGroups(ctx, database.GetGroupsParams{ - OrganizationID: orgID, - HasMemberID: user.ID, + if len(o.ExpectedGroupNames) > 0 { + found := db2sdk.List(userGroups, func(g database.GetGroupsRow) string { + return g.Group.Name }) - require.NoError(t, err) - if o.ExpectedGroups == nil { - o.ExpectedGroups = make([]uuid.UUID, 0) - } + require.ElementsMatch(t, o.ExpectedGroupNames, found, "user groups by name") + } else { + // Check by ID, recommended found := db2sdk.List(userGroups, func(g database.GetGroupsRow) uuid.UUID { return g.Group.ID }) require.ElementsMatch(t, o.ExpectedGroups, found, "user groups") - }) + } } From ee8e4e4b07e54611a4f99c0fec383493c0198bf7 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 5 Sep 2024 12:28:30 -0500 Subject: [PATCH 10/38] fix compile issues --- coderd/idpsync/group.go | 5 +---- coderd/idpsync/group_test.go | 10 +++++----- coderd/idpsync/idpsync.go | 13 ++++++------- coderd/idpsync/organizations_test.go | 4 ++-- 4 files changed, 14 insertions(+), 18 deletions(-) diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index de1a3eee6597a..cedcb8ba8eaae 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -13,7 +13,6 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbauthz" - "github.com/coder/coder/v2/coderd/runtimeconfig" "github.com/coder/coder/v2/coderd/util/slice" ) @@ -44,8 +43,6 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat ctx = dbauthz.AsSystemRestricted(ctx) db.InTx(func(tx database.Store) error { - manager := runtimeconfig.NewStoreManager(tx) - userGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{ HasMemberID: user.ID, }) @@ -63,7 +60,7 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat // For each org, we need to fetch the sync settings orgSettings := make(map[uuid.UUID]GroupSyncSettings) for orgID := range userOrgs { - orgResolver := manager.Scoped(orgID.String()) + orgResolver := s.Manager.OrganizationResolver(tx, orgID) settings, err := s.SyncSettings.Group.Resolve(ctx, orgResolver) if err != nil { return xerrors.Errorf("resolve group sync settings: %w", err) diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go index 6b63b13e76ae5..456e0752ebc1e 100644 --- a/coderd/idpsync/group_test.go +++ b/coderd/idpsync/group_test.go @@ -28,7 +28,7 @@ func TestParseGroupClaims(t *testing.T) { t.Parallel() s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), - runtimeconfig.NewNoopManager(), + runtimeconfig.NewStoreManager(), idpsync.DeploymentSyncSettings{}) ctx := testutil.Context(t, testutil.WaitMedium) @@ -44,7 +44,7 @@ func TestParseGroupClaims(t *testing.T) { t.Parallel() s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), - runtimeconfig.NewNoopManager(), + runtimeconfig.NewStoreManager(), idpsync.DeploymentSyncSettings{ GroupField: "groups", GroupAllowList: map[string]struct{}{ @@ -209,7 +209,7 @@ func TestGroupSyncTable(t *testing.T) { } db, _ := dbtestutil.NewDB(t) - manager := runtimeconfig.NewStoreManager(db) + manager := runtimeconfig.NewStoreManager() s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), manager, idpsync.DeploymentSyncSettings{ @@ -240,8 +240,8 @@ func SetupOrganization(t *testing.T, s *idpsync.AGPLIDPSync, db database.Store, _, err := db.InsertAllUsersGroup(context.Background(), org.ID) require.NoError(t, err, "Everyone group for an org") - manager := runtimeconfig.NewStoreManager(db) - orgResolver := manager.Scoped(org.ID.String()) + manager := runtimeconfig.NewStoreManager() + orgResolver := manager.OrganizationResolver(db, org.ID) err = s.Group.SetRuntimeValue(context.Background(), orgResolver, def.Settings) require.NoError(t, err) diff --git a/coderd/idpsync/idpsync.go b/coderd/idpsync/idpsync.go index 3ff8d78fd5174..b462f5da01bdb 100644 --- a/coderd/idpsync/idpsync.go +++ b/coderd/idpsync/idpsync.go @@ -43,7 +43,8 @@ type IDPSync interface { // AGPLIDPSync is the configuration for syncing user information from an external // IDP. All related code to syncing user information should be in this package. type AGPLIDPSync struct { - Logger slog.Logger + Logger slog.Logger + Manager runtimeconfig.Manager SyncSettings } @@ -88,7 +89,7 @@ func FromDeploymentValues(dv *codersdk.DeploymentValues) DeploymentSyncSettings type SyncSettings struct { DeploymentSyncSettings - Group runtimeconfig.Entry[*GroupSyncSettings] + Group runtimeconfig.RuntimeEntry[*GroupSyncSettings] //// Group options here are set by the deployment config and only apply to //// the default organization. @@ -100,13 +101,11 @@ type SyncSettings struct { func NewAGPLSync(logger slog.Logger, manager runtimeconfig.Manager, settings DeploymentSyncSettings) *AGPLIDPSync { return &AGPLIDPSync{ - Logger: logger.Named("idp-sync"), + Logger: logger.Named("idp-sync"), + Manager: manager, SyncSettings: SyncSettings{ DeploymentSyncSettings: settings, - // Default to '{}' if the group sync settings are not set. - // TODO: Feels strange to have to define the type as a string. I should be - // able to pass in an empty struct. - Group: runtimeconfig.MustNew[*GroupSyncSettings]("group-sync-settings", "{}"), + Group: runtimeconfig.MustNew[*GroupSyncSettings]("group-sync-settings"), }, } } diff --git a/coderd/idpsync/organizations_test.go b/coderd/idpsync/organizations_test.go index b0e7728b0640a..934d7d83816ab 100644 --- a/coderd/idpsync/organizations_test.go +++ b/coderd/idpsync/organizations_test.go @@ -20,7 +20,7 @@ func TestParseOrganizationClaims(t *testing.T) { t.Parallel() s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), - runtimeconfig.NewNoopManager(), + runtimeconfig.NewStoreManager(), idpsync.DeploymentSyncSettings{ OrganizationField: "", OrganizationMapping: nil, @@ -42,7 +42,7 @@ func TestParseOrganizationClaims(t *testing.T) { // AGPL has limited behavior s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), - runtimeconfig.NewNoopManager(), + runtimeconfig.NewStoreManager(), idpsync.DeploymentSyncSettings{ OrganizationField: "orgs", OrganizationMapping: map[string][]uuid.UUID{ From d5ff0f7bfa82b6abe2be2f5d7c030c66393da913 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 5 Sep 2024 12:36:56 -0500 Subject: [PATCH 11/38] add comment for test helper --- coderd/coderdtest/uuids.go | 4 ++++ coderd/coderdtest/uuids_test.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) create mode 100644 coderd/coderdtest/uuids_test.go diff --git a/coderd/coderdtest/uuids.go b/coderd/coderdtest/uuids.go index aefa6e83c0b3c..1ff60bf26c572 100644 --- a/coderd/coderdtest/uuids.go +++ b/coderd/coderdtest/uuids.go @@ -2,6 +2,10 @@ package coderdtest import "github.com/google/uuid" +// DeterministicUUIDGenerator allows "naming" uuids for unit tests. +// An example of where this is useful, is when a tabled test references +// a UUID that is not yet known. An alternative to this would be to +// hard code some UUID strings, but these strings are not human friendly. type DeterministicUUIDGenerator struct { Named map[string]uuid.UUID } diff --git a/coderd/coderdtest/uuids_test.go b/coderd/coderdtest/uuids_test.go new file mode 100644 index 0000000000000..bb92d6faffabd --- /dev/null +++ b/coderd/coderdtest/uuids_test.go @@ -0,0 +1,33 @@ +package coderdtest_test + +import ( + "github.com/google/uuid" + + "github.com/coder/coder/v2/coderd/coderdtest" +) + +func ExampleNewDeterministicUUIDGenerator() { + det := coderdtest.NewDeterministicUUIDGenerator() + testCases := []struct { + CreateUsers []uuid.UUID + ExpectedIDs []uuid.UUID + }{ + { + CreateUsers: []uuid.UUID{ + det.ID("player1"), + det.ID("player2"), + }, + ExpectedIDs: []uuid.UUID{ + det.ID("player1"), + det.ID("player2"), + }, + }, + } + + for _, tc := range testCases { + tc := tc + var _ = tc + // Do the test with CreateUsers as the setup, and the expected IDs + // will match. + } +} From 86c0f6f52eb79c142920a9abce15b6d40e4f315c Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 5 Sep 2024 14:41:42 -0500 Subject: [PATCH 12/38] handle legacy params --- coderd/database/queries.sql.go | 2 +- coderd/database/queries/groups.sql | 2 +- coderd/idpsync/group.go | 34 ++++++++++++++++++++++++++++++ coderd/idpsync/idpsync.go | 24 +++++++++++++++------ 4 files changed, 53 insertions(+), 9 deletions(-) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index b87ad6f857bb9..7c7fbbf0f88f0 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1678,7 +1678,7 @@ WHERE ELSE true END AND CASE WHEN array_length($3 :: text[], 1) > 0 THEN - name = ANY($3) + groups.name = ANY($3) ELSE true END ` diff --git a/coderd/database/queries/groups.sql b/coderd/database/queries/groups.sql index 628395b8a81b0..0df848d6a6d05 100644 --- a/coderd/database/queries/groups.sql +++ b/coderd/database/queries/groups.sql @@ -53,7 +53,7 @@ WHERE ELSE true END AND CASE WHEN array_length(@group_names :: text[], 1) > 0 THEN - name = ANY(@group_names) + groups.name = ANY(@group_names) ELSE true END ; diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index cedcb8ba8eaae..5acf9665f80ce 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -39,6 +39,18 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat return nil } + // Only care about the default org for deployment settings if the + // legacy deployment settings exist. + defaultOrgID := uuid.Nil + // Default organization is configured via legacy deployment values + if s.DeploymentSyncSettings.Legacy.GroupField != "" { + defaultOrganization, err := db.GetDefaultOrganization(ctx) + if err != nil { + return xerrors.Errorf("get default organization: %w", err) + } + defaultOrgID = defaultOrganization.ID + } + // nolint:gocritic // all syncing is done as a system user ctx = dbauthz.AsSystemRestricted(ctx) @@ -66,6 +78,16 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat return xerrors.Errorf("resolve group sync settings: %w", err) } orgSettings[orgID] = *settings + + // Legacy deployment settings will override empty settings. + if orgID == defaultOrgID && settings.GroupField == "" { + settings = &GroupSyncSettings{ + GroupField: s.Legacy.GroupField, + LegacyGroupNameMapping: s.Legacy.GroupMapping, + RegexFilter: s.Legacy.GroupFilter, + AutoCreateMissingGroups: s.Legacy.CreateMissingGroups, + } + } } // collect all diffs to do 1 sql update for all orgs @@ -175,6 +197,12 @@ type GroupSyncSettings struct { GroupMapping map[string][]uuid.UUID `json:"mapping"` RegexFilter *regexp.Regexp `json:"regex_filter"` AutoCreateMissingGroups bool `json:"auto_create_missing_groups"` + // LegacyGroupNameMapping is deprecated. It remaps an IDP group name to + // a Coder group name. Since configuration is now done at runtime, + // group IDs are used to account for group renames. + // For legacy configurations, this config option has to remain. + // Deprecated: Use GroupMapping instead. + LegacyGroupNameMapping map[string]string } func (s *GroupSyncSettings) Set(v string) error { @@ -232,6 +260,12 @@ func (s GroupSyncSettings) ParseClaims(mergedClaims jwt.MapClaims) ([]ExpectedGr } continue } + + mappedGroupName, ok := s.LegacyGroupNameMapping[group] + if ok { + groups = append(groups, ExpectedGroup{GroupName: &mappedGroupName}) + continue + } group := group groups = append(groups, ExpectedGroup{GroupName: &group}) } diff --git a/coderd/idpsync/idpsync.go b/coderd/idpsync/idpsync.go index b462f5da01bdb..a01e3bc14f745 100644 --- a/coderd/idpsync/idpsync.go +++ b/coderd/idpsync/idpsync.go @@ -3,6 +3,7 @@ package idpsync import ( "context" "net/http" + "regexp" "strings" "github.com/golang-jwt/jwt/v4" @@ -69,6 +70,15 @@ type DeploymentSyncSettings struct { // have at least one group in this list. // A map representation is used for easier lookup. GroupAllowList map[string]struct{} + // Legacy deployment settings that only apply to the default org. + Legacy DefaultOrgLegacySettings +} + +type DefaultOrgLegacySettings struct { + GroupField string + GroupMapping map[string]string + GroupFilter *regexp.Regexp + CreateMissingGroups bool } func FromDeploymentValues(dv *codersdk.DeploymentValues) DeploymentSyncSettings { @@ -80,8 +90,15 @@ func FromDeploymentValues(dv *codersdk.DeploymentValues) DeploymentSyncSettings OrganizationMapping: dv.OIDC.OrganizationMapping.Value, OrganizationAssignDefault: dv.OIDC.OrganizationAssignDefault.Value(), + // TODO: Separate group field for allow list from default org GroupField: dv.OIDC.GroupField.Value(), GroupAllowList: ConvertAllowList(dv.OIDC.GroupAllowList.Value()), + Legacy: DefaultOrgLegacySettings{ + GroupField: dv.OIDC.GroupField.Value(), + GroupMapping: dv.OIDC.GroupMapping.Value, + GroupFilter: dv.OIDC.GroupRegexFilter.Value(), + CreateMissingGroups: dv.OIDC.GroupAutoCreate.Value(), + }, } } @@ -90,13 +107,6 @@ type SyncSettings struct { DeploymentSyncSettings Group runtimeconfig.RuntimeEntry[*GroupSyncSettings] - - //// Group options here are set by the deployment config and only apply to - //// the default organization. - //GroupField string - //CreateMissingGroups bool - //GroupMapping map[string]string - //GroupFilter *regexp.Regexp } func NewAGPLSync(logger slog.Logger, manager runtimeconfig.Manager, settings DeploymentSyncSettings) *AGPLIDPSync { From 2f03e182b2554c449e1ee1eb13a16cdc5321270a Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 5 Sep 2024 14:44:00 -0500 Subject: [PATCH 13/38] make gen --- coderd/database/dbmock/dbmock.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index fe2e444ff5c67..c5d579e1c2656 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -4118,6 +4118,21 @@ func (mr *MockStoreMockRecorder) RemoveUserFromAllGroups(arg0, arg1 any) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveUserFromAllGroups", reflect.TypeOf((*MockStore)(nil).RemoveUserFromAllGroups), arg0, arg1) } +// RemoveUserFromGroups mocks base method. +func (m *MockStore) RemoveUserFromGroups(arg0 context.Context, arg1 database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoveUserFromGroups", arg0, arg1) + ret0, _ := ret[0].([]uuid.UUID) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RemoveUserFromGroups indicates an expected call of RemoveUserFromGroups. +func (mr *MockStoreMockRecorder) RemoveUserFromGroups(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveUserFromGroups", reflect.TypeOf((*MockStore)(nil).RemoveUserFromGroups), arg0, arg1) +} + // RevokeDBCryptKey mocks base method. func (m *MockStore) RevokeDBCryptKey(arg0 context.Context, arg1 string) error { m.ctrl.T.Helper() From ec8092d25c3fe8c81edef30afb2af9790c9117a1 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 5 Sep 2024 14:46:36 -0500 Subject: [PATCH 14/38] cleanup --- coderd/coderdtest/uuids_test.go | 2 +- coderd/database/dbauthz/dbauthz.go | 6 +++++- coderd/database/dbmem/dbmem.go | 4 ++-- coderd/idpsync/group.go | 8 +++++++- coderd/idpsync/idpsync.go | 1 - enterprise/coderd/enidpsync/groups_test.go | 6 +++--- 6 files changed, 18 insertions(+), 9 deletions(-) diff --git a/coderd/coderdtest/uuids_test.go b/coderd/coderdtest/uuids_test.go index bb92d6faffabd..5a0e10935bd50 100644 --- a/coderd/coderdtest/uuids_test.go +++ b/coderd/coderdtest/uuids_test.go @@ -26,7 +26,7 @@ func ExampleNewDeterministicUUIDGenerator() { for _, tc := range testCases { tc := tc - var _ = tc + _ = tc // Do the test with CreateUsers as the setup, and the expected IDs // will match. } diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index eaf994e849fc5..077d704be1300 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -3109,7 +3109,11 @@ func (q *querier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) } func (q *querier) RemoveUserFromGroups(ctx context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) { - panic("not implemented") + // This is a system function to clear user groups in group sync. + if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.RemoveUserFromGroups(ctx, arg) } func (q *querier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error { diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 37811063997db..6f0c04eb4e512 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -7015,7 +7015,7 @@ func (q *FakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParam return user, nil } -func (q *FakeQuerier) InsertUserGroupsByID(ctx context.Context, arg database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) { +func (q *FakeQuerier) InsertUserGroupsByID(_ context.Context, arg database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) { err := validateDatabaseType(arg) if err != nil { return nil, err @@ -7637,7 +7637,7 @@ func (q *FakeQuerier) RemoveUserFromAllGroups(_ context.Context, userID uuid.UUI return nil } -func (q *FakeQuerier) RemoveUserFromGroups(ctx context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) { +func (q *FakeQuerier) RemoveUserFromGroups(_ context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) { err := validateDatabaseType(arg) if err != nil { return nil, err diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 5acf9665f80ce..07ead53cd52c2 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -54,7 +54,7 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat // nolint:gocritic // all syncing is done as a system user ctx = dbauthz.AsSystemRestricted(ctx) - db.InTx(func(tx database.Store) error { + err := db.InTx(func(tx database.Store) error { userGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{ HasMemberID: user.ID, }) @@ -188,6 +188,10 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat return nil }, nil) + if err != nil { + return err + } + return nil } @@ -208,6 +212,7 @@ type GroupSyncSettings struct { func (s *GroupSyncSettings) Set(v string) error { return json.Unmarshal([]byte(v), s) } + func (s *GroupSyncSettings) String() string { v, err := json.Marshal(s) if err != nil { @@ -215,6 +220,7 @@ func (s *GroupSyncSettings) String() string { } return string(v) } + func (s *GroupSyncSettings) Type() string { return "GroupSyncSettings" } diff --git a/coderd/idpsync/idpsync.go b/coderd/idpsync/idpsync.go index a01e3bc14f745..bc3e5cd479064 100644 --- a/coderd/idpsync/idpsync.go +++ b/coderd/idpsync/idpsync.go @@ -100,7 +100,6 @@ func FromDeploymentValues(dv *codersdk.DeploymentValues) DeploymentSyncSettings CreateMissingGroups: dv.OIDC.GroupAutoCreate.Value(), }, } - } type SyncSettings struct { diff --git a/enterprise/coderd/enidpsync/groups_test.go b/enterprise/coderd/enidpsync/groups_test.go index 138d2954712de..8103f8a002937 100644 --- a/enterprise/coderd/enidpsync/groups_test.go +++ b/enterprise/coderd/enidpsync/groups_test.go @@ -30,7 +30,7 @@ func TestEnterpriseParseGroupClaims(t *testing.T) { t.Parallel() s := enidpsync.NewSync(slogtest.Make(t, &slogtest.Options{}), - runtimeconfig.NewNoopManager(), + runtimeconfig.NewStoreManager(), entitlements.New(), idpsync.DeploymentSyncSettings{}) @@ -46,7 +46,7 @@ func TestEnterpriseParseGroupClaims(t *testing.T) { t.Parallel() s := enidpsync.NewSync(slogtest.Make(t, &slogtest.Options{}), - runtimeconfig.NewNoopManager(), + runtimeconfig.NewStoreManager(), entitled, idpsync.DeploymentSyncSettings{ GroupField: "groups", @@ -74,7 +74,7 @@ func TestEnterpriseParseGroupClaims(t *testing.T) { t.Parallel() s := enidpsync.NewSync(slogtest.Make(t, &slogtest.Options{}), - runtimeconfig.NewNoopManager(), + runtimeconfig.NewStoreManager(), entitled, idpsync.DeploymentSyncSettings{ GroupField: "groups", From d63727d5288de948555f6355643e2d45bc608d21 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 5 Sep 2024 15:08:52 -0500 Subject: [PATCH 15/38] add unit test for legacy behavior --- coderd/idpsync/group.go | 30 ++++++++++++++++++++++++------ coderd/idpsync/group_test.go | 20 ++++++++++++++++++++ 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 07ead53cd52c2..c64b08ee07553 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -206,7 +206,7 @@ type GroupSyncSettings struct { // group IDs are used to account for group renames. // For legacy configurations, this config option has to remain. // Deprecated: Use GroupMapping instead. - LegacyGroupNameMapping map[string]string + LegacyGroupNameMapping map[string]string `json:"legacy_group_name_mapping,omitempty"` } func (s *GroupSyncSettings) Set(v string) error { @@ -251,6 +251,12 @@ func (s GroupSyncSettings) ParseClaims(mergedClaims jwt.MapClaims) ([]ExpectedGr groups := make([]ExpectedGroup, 0) for _, group := range parsedGroups { + // Legacy group mappings happen before the regex filter. + mappedGroupName, ok := s.LegacyGroupNameMapping[group] + if ok { + group = mappedGroupName + } + // Only allow through groups that pass the regex if s.RegexFilter != nil { if !s.RegexFilter.MatchString(group) { @@ -267,11 +273,6 @@ func (s GroupSyncSettings) ParseClaims(mergedClaims jwt.MapClaims) ([]ExpectedGr continue } - mappedGroupName, ok := s.LegacyGroupNameMapping[group] - if ok { - groups = append(groups, ExpectedGroup{GroupName: &mappedGroupName}) - continue - } group := group groups = append(groups, ExpectedGroup{GroupName: &group}) } @@ -332,6 +333,23 @@ func (s GroupSyncSettings) HandleMissingGroups(ctx context.Context, tx database. if err != nil { return nil, xerrors.Errorf("insert missing groups: %w", err) } + + if len(missingGroups) != len(createdMissingGroups) { + // This is unfortunate, but if legacy params are used, then some existing groups + // can come as params. So we need to fetch them + allGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{ + OrganizationID: orgID, + GroupNames: missingGroups, + }) + if err != nil { + return nil, xerrors.Errorf("get groups by names: %w", err) + } + + createdMissingGroups = db2sdk.List(allGroups, func(g database.GetGroupsRow) database.Group { + return g.Group + }) + } + for _, created := range createdMissingGroups { addIDs = append(addIDs, created.ID) } diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go index 456e0752ebc1e..406df099167c3 100644 --- a/coderd/idpsync/group_test.go +++ b/coderd/idpsync/group_test.go @@ -197,6 +197,26 @@ func TestGroupSyncTable(t *testing.T) { Settings: nil, Groups: map[uuid.UUID]bool{}, }, + { + Name: "LegacyMapping", + Settings: &idpsync.GroupSyncSettings{ + GroupField: "groups", + RegexFilter: regexp.MustCompile("^legacy"), + LegacyGroupNameMapping: map[string]string{ + "create-bar": "legacy-bar", + "foo": "legacy-foo", + }, + AutoCreateMissingGroups: true, + }, + Groups: map[uuid.UUID]bool{}, + GroupNames: map[string]bool{ + "legacy-foo": false, + }, + ExpectedGroupNames: []string{ + "legacy-bar", + "legacy-foo", + }, + }, } for _, tc := range testCases { From 2a1769c7fdcd34b5c7a82f335d7f1e4f05b696ce Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 5 Sep 2024 15:41:41 -0500 Subject: [PATCH 16/38] work on batching removal by name or id --- coderd/database/dbmem/dbmem.go | 41 ++++++--- coderd/database/queries.sql.go | 19 +++- coderd/database/queries/groupmembers.sql | 12 ++- coderd/idpsync/group.go | 108 +++++++++++++++-------- coderd/idpsync/group_test.go | 5 +- 5 files changed, 130 insertions(+), 55 deletions(-) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 6f0c04eb4e512..fd97fb0d701bf 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -682,6 +682,17 @@ func (q *FakeQuerier) getWorkspaceResourcesByJobIDNoLock(_ context.Context, jobI return resources, nil } +func (q *FakeQuerier) getGroupByNameNoLock(arg database.NameOrganizationPair) (database.Group, error) { + for _, group := range q.groups { + if group.OrganizationID == arg.OrganizationID && + group.Name == arg.Name { + return group, nil + } + } + + return database.Group{}, sql.ErrNoRows +} + func (q *FakeQuerier) getGroupByIDNoLock(_ context.Context, id uuid.UUID) (database.Group, error) { for _, group := range q.groups { if group.ID == id { @@ -2613,14 +2624,10 @@ func (q *FakeQuerier) GetGroupByOrgAndName(_ context.Context, arg database.GetGr q.mutex.RLock() defer q.mutex.RUnlock() - for _, group := range q.groups { - if group.OrganizationID == arg.OrganizationID && - group.Name == arg.Name { - return group, nil - } - } - - return database.Group{}, sql.ErrNoRows + return q.getGroupByNameNoLock(database.NameOrganizationPair{ + Name: arg.Name, + OrganizationID: arg.OrganizationID, + }) } func (q *FakeQuerier) GetGroupMembers(ctx context.Context) ([]database.GroupMember, error) { @@ -7648,14 +7655,24 @@ func (q *FakeQuerier) RemoveUserFromGroups(_ context.Context, arg database.Remov removed := make([]uuid.UUID, 0) q.data.groupMembers = slices.DeleteFunc(q.data.groupMembers, func(groupMember database.GroupMemberTable) bool { + // Delete all group members that match the arguments. if groupMember.UserID != arg.UserID { + // Not the right user, ignore. return false } - if !slices.Contains(arg.GroupIds, groupMember.GroupID) { - return false + + matchesByID := slices.Contains(arg.GroupIds, groupMember.GroupID) + matchesByName := slices.ContainsFunc(arg.GroupNames, func(name database.NameOrganizationPair) bool { + _, err := q.getGroupByNameNoLock(name) + return err == nil + }) + + if matchesByName || matchesByID { + removed = append(removed, groupMember.GroupID) + return true } - removed = append(removed, groupMember.GroupID) - return true + + return false }) return removed, nil diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 7c7fbbf0f88f0..04c111bcee78f 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1540,19 +1540,30 @@ func (q *sqlQuerier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UU const removeUserFromGroups = `-- name: RemoveUserFromGroups :many DELETE FROM group_members + USING groups WHERE + group_members.group_id = groups.id AND user_id = $1 AND - group_id = ANY($2 :: uuid []) + ( + CASE WHEN array_length($2 :: name_organization_pair[], 1) > 0 THEN + -- Using 'coalesce' to avoid troubles with null literals being an empty string. + (groups.name, coalesce(groups.organization_id, '00000000-0000-0000-0000-000000000000' ::uuid)) = ANY ($2::name_organization_pair[]) + ELSE false + END + OR + group_id = ANY ($3 :: uuid[]) + ) RETURNING group_id ` type RemoveUserFromGroupsParams struct { - UserID uuid.UUID `db:"user_id" json:"user_id"` - GroupIds []uuid.UUID `db:"group_ids" json:"group_ids"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + GroupNames []NameOrganizationPair `db:"group_names" json:"group_names"` + GroupIds []uuid.UUID `db:"group_ids" json:"group_ids"` } func (q *sqlQuerier) RemoveUserFromGroups(ctx context.Context, arg RemoveUserFromGroupsParams) ([]uuid.UUID, error) { - rows, err := q.db.QueryContext(ctx, removeUserFromGroups, arg.UserID, pq.Array(arg.GroupIds)) + rows, err := q.db.QueryContext(ctx, removeUserFromGroups, arg.UserID, pq.Array(arg.GroupNames), pq.Array(arg.GroupIds)) if err != nil { return nil, err } diff --git a/coderd/database/queries/groupmembers.sql b/coderd/database/queries/groupmembers.sql index 814f878cb9232..5345d976fcd4e 100644 --- a/coderd/database/queries/groupmembers.sql +++ b/coderd/database/queries/groupmembers.sql @@ -57,9 +57,19 @@ WHERE -- name: RemoveUserFromGroups :many DELETE FROM group_members + USING groups WHERE + group_members.group_id = groups.id AND user_id = @user_id AND - group_id = ANY(@group_ids :: uuid []) + ( + CASE WHEN array_length(@group_names :: name_organization_pair[], 1) > 0 THEN + -- Using 'coalesce' to avoid troubles with null literals being an empty string. + (groups.name, coalesce(groups.organization_id, '00000000-0000-0000-0000-000000000000' ::uuid)) = ANY (@group_names::name_organization_pair[]) + ELSE false + END + OR + group_id = ANY (@group_ids :: uuid[]) + ) RETURNING group_id; -- name: InsertGroupMember :exec diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index c64b08ee07553..3560238e13b4a 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -91,8 +91,8 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat } // collect all diffs to do 1 sql update for all orgs - groupsToAdd := make([]uuid.UUID, 0) - groupsToRemove := make([]uuid.UUID, 0) + groupIDsToAdd := make([]uuid.UUID, 0) + groupsToRemove := make([]ExpectedGroup, 0) // For each org, determine which groups the user should land in for orgID, settings := range orgSettings { if settings.GroupField == "" { @@ -112,7 +112,8 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat } // Everyone group is always implied. expectedGroups = append(expectedGroups, ExpectedGroup{ - GroupID: &orgID, + OrganizationID: orgID, + GroupID: &orgID, }) // Now we know what groups the user should be in for a given org, @@ -121,8 +122,9 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat existingGroups := userOrgs[orgID] existingGroupsTyped := db2sdk.List(existingGroups, func(f database.GetGroupsRow) ExpectedGroup { return ExpectedGroup{ - GroupID: &f.Group.ID, - GroupName: &f.Group.Name, + OrganizationID: orgID, + GroupID: &f.Group.ID, + GroupName: &f.Group.Name, } }) add, remove := slice.SymmetricDifferenceFunc(existingGroupsTyped, expectedGroups, func(a, b ExpectedGroup) bool { @@ -144,52 +146,75 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat return xerrors.Errorf("handle missing groups: %w", err) } - for _, removeGroup := range remove { - // This should always be the case. - // TODO: make sure this is always the case - if removeGroup.GroupID != nil { - groupsToRemove = append(groupsToRemove, *removeGroup.GroupID) - } - } + groupsToRemove = append(groupsToRemove, remove...) + groupIDsToAdd = append(groupIDsToAdd, assignGroups...) + } - groupsToAdd = append(groupsToAdd, assignGroups...) + err = s.applyGroupDifference(ctx, tx, user, groupIDsToAdd, groupsToRemove) + if err != nil { + return xerrors.Errorf("apply group difference: %w", err) } - assignedGroupIDs, err := tx.InsertUserGroupsByID(ctx, database.InsertUserGroupsByIDParams{ - UserID: user.ID, - GroupIds: groupsToAdd, + return nil + }, nil) + + if err != nil { + return err + } + + return nil +} + +func (s AGPLIDPSync) applyGroupDifference(ctx context.Context, tx database.Store, user database.User, add []uuid.UUID, remove []ExpectedGroup) error { + // Always do group removal before group add. This way if there is an error, + // we error on the underprivileged side. + removeIDs := make([]uuid.UUID, 0) + removeNames := make([]database.NameOrganizationPair, 0) + for _, r := range remove { + if r.GroupID != nil { + removeIDs = append(removeIDs, *r.GroupID) + } else if r.GroupName != nil { + removeNames = append(removeNames, database.NameOrganizationPair{ + Name: *r.GroupName, + OrganizationID: r.OrganizationID, + }) + } + } + + // If there is something to remove, do it. + if len(removeIDs) > 0 || len(removeNames) > 0 { + removedGroupIDs, err := tx.RemoveUserFromGroups(ctx, database.RemoveUserFromGroupsParams{ + UserID: user.ID, + GroupNames: removeNames, + GroupIds: removeIDs, }) if err != nil { - return xerrors.Errorf("insert user into %d groups: %w", len(groupsToAdd), err) + return xerrors.Errorf("remove user from %d groups: %w", len(removeIDs), err) } - if len(assignedGroupIDs) != len(groupsToAdd) { - s.Logger.Debug(ctx, "failed to assign all groups to user", + if len(removedGroupIDs) != len(removeIDs) { + s.Logger.Debug(ctx, "failed to remove user from all groups", slog.F("user_id", user.ID), - slog.F("groups_assigned_count", len(assignedGroupIDs)), - slog.F("expected_count", len(groupsToAdd)), + slog.F("groups_removed_count", len(removedGroupIDs)), + slog.F("expected_count", len(removeIDs)), ) } + } - removedGroupIDs, err := tx.RemoveUserFromGroups(ctx, database.RemoveUserFromGroupsParams{ + if len(add) > 0 { + assignedGroupIDs, err := tx.InsertUserGroupsByID(ctx, database.InsertUserGroupsByIDParams{ UserID: user.ID, - GroupIds: groupsToRemove, + GroupIds: add, }) if err != nil { - return xerrors.Errorf("remove user from %d groups: %w", len(groupsToRemove), err) + return xerrors.Errorf("insert user into %d groups: %w", len(add), err) } - if len(removedGroupIDs) != len(groupsToRemove) { - s.Logger.Debug(ctx, "failed to remove user from all groups", + if len(assignedGroupIDs) != len(add) { + s.Logger.Debug(ctx, "failed to assign all groups to user", slog.F("user_id", user.ID), - slog.F("groups_removed_count", len(removedGroupIDs)), - slog.F("expected_count", len(groupsToRemove)), + slog.F("groups_assigned_count", len(assignedGroupIDs)), + slog.F("expected_count", len(add)), ) } - - return nil - }, nil) - - if err != nil { - return err } return nil @@ -226,8 +251,9 @@ func (s *GroupSyncSettings) Type() string { } type ExpectedGroup struct { - GroupID *uuid.UUID - GroupName *string + OrganizationID uuid.UUID + GroupID *uuid.UUID + GroupName *string } // ParseClaims will take the merged claims from the IDP and return the groups @@ -280,13 +306,20 @@ func (s GroupSyncSettings) ParseClaims(mergedClaims jwt.MapClaims) ([]ExpectedGr return groups, nil } +// HandleMissingGroups ensures all ExpectedGroups convert to uuids. +// Groups can be referenced by name via legacy params or IDP group names. +// These group names are converted to IDs for easier assignment. +// Missing groups are created if AutoCreate is enabled. +// TODO: Batching this would be better, as this is 1 or 2 db calls per organization. func (s GroupSyncSettings) HandleMissingGroups(ctx context.Context, tx database.Store, orgID uuid.UUID, add []ExpectedGroup) ([]uuid.UUID, error) { if !s.AutoCreateMissingGroups { - // construct the list of groups to search by name to see if they exist. + // If we are not creating groups, then just construct a db lookup for + // all groups by name. var lookups []string filter := make([]uuid.UUID, 0) for _, expected := range add { if expected.GroupID != nil { + // Groups with IDs are easy! filter = append(filter, *expected.GroupID) } else if expected.GroupName != nil { lookups = append(lookups, *expected.GroupName) @@ -294,6 +327,7 @@ func (s GroupSyncSettings) HandleMissingGroups(ctx context.Context, tx database. } if len(lookups) > 0 { + // Do name lookups for all groups that are missing IDs. newGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{ OrganizationID: uuid.UUID{}, HasMemberID: uuid.UUID{}, diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go index 406df099167c3..e1d0ac9d6c095 100644 --- a/coderd/idpsync/group_test.go +++ b/coderd/idpsync/group_test.go @@ -208,9 +208,12 @@ func TestGroupSyncTable(t *testing.T) { }, AutoCreateMissingGroups: true, }, - Groups: map[uuid.UUID]bool{}, + Groups: map[uuid.UUID]bool{ + ids.ID("lg-foo"): true, + }, GroupNames: map[string]bool{ "legacy-foo": false, + "extra": true, }, ExpectedGroupNames: []string{ "legacy-bar", From 640e86e47d633feb767de942b7faa3a437c8bc03 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 5 Sep 2024 16:08:36 -0500 Subject: [PATCH 17/38] group sync adjustments --- coderd/database/dbmem/dbmem.go | 18 ++-- coderd/database/queries.sql.go | 19 +--- coderd/database/queries/groupmembers.sql | 12 +-- coderd/idpsync/group.go | 125 +++++++++-------------- coderd/idpsync/group_test.go | 2 + 5 files changed, 63 insertions(+), 113 deletions(-) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index fd97fb0d701bf..7e761de411615 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -2730,6 +2730,10 @@ func (q *FakeQuerier) GetGroups(_ context.Context, arg database.GetGroupsParams) continue } + if len(arg.GroupNames) > 0 && !slices.Contains(arg.GroupNames, group.Name) { + continue + } + orgDetails, ok := orgDetailsCache[group.ID] if !ok { for _, org := range q.organizations { @@ -7661,18 +7665,12 @@ func (q *FakeQuerier) RemoveUserFromGroups(_ context.Context, arg database.Remov return false } - matchesByID := slices.Contains(arg.GroupIds, groupMember.GroupID) - matchesByName := slices.ContainsFunc(arg.GroupNames, func(name database.NameOrganizationPair) bool { - _, err := q.getGroupByNameNoLock(name) - return err == nil - }) - - if matchesByName || matchesByID { - removed = append(removed, groupMember.GroupID) - return true + if !slices.Contains(arg.GroupIds, groupMember.GroupID) { + return false } - return false + removed = append(removed, groupMember.GroupID) + return true }) return removed, nil diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 04c111bcee78f..7c7fbbf0f88f0 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1540,30 +1540,19 @@ func (q *sqlQuerier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UU const removeUserFromGroups = `-- name: RemoveUserFromGroups :many DELETE FROM group_members - USING groups WHERE - group_members.group_id = groups.id AND user_id = $1 AND - ( - CASE WHEN array_length($2 :: name_organization_pair[], 1) > 0 THEN - -- Using 'coalesce' to avoid troubles with null literals being an empty string. - (groups.name, coalesce(groups.organization_id, '00000000-0000-0000-0000-000000000000' ::uuid)) = ANY ($2::name_organization_pair[]) - ELSE false - END - OR - group_id = ANY ($3 :: uuid[]) - ) + group_id = ANY($2 :: uuid []) RETURNING group_id ` type RemoveUserFromGroupsParams struct { - UserID uuid.UUID `db:"user_id" json:"user_id"` - GroupNames []NameOrganizationPair `db:"group_names" json:"group_names"` - GroupIds []uuid.UUID `db:"group_ids" json:"group_ids"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + GroupIds []uuid.UUID `db:"group_ids" json:"group_ids"` } func (q *sqlQuerier) RemoveUserFromGroups(ctx context.Context, arg RemoveUserFromGroupsParams) ([]uuid.UUID, error) { - rows, err := q.db.QueryContext(ctx, removeUserFromGroups, arg.UserID, pq.Array(arg.GroupNames), pq.Array(arg.GroupIds)) + rows, err := q.db.QueryContext(ctx, removeUserFromGroups, arg.UserID, pq.Array(arg.GroupIds)) if err != nil { return nil, err } diff --git a/coderd/database/queries/groupmembers.sql b/coderd/database/queries/groupmembers.sql index 5345d976fcd4e..814f878cb9232 100644 --- a/coderd/database/queries/groupmembers.sql +++ b/coderd/database/queries/groupmembers.sql @@ -57,19 +57,9 @@ WHERE -- name: RemoveUserFromGroups :many DELETE FROM group_members - USING groups WHERE - group_members.group_id = groups.id AND user_id = @user_id AND - ( - CASE WHEN array_length(@group_names :: name_organization_pair[], 1) > 0 THEN - -- Using 'coalesce' to avoid troubles with null literals being an empty string. - (groups.name, coalesce(groups.organization_id, '00000000-0000-0000-0000-000000000000' ::uuid)) = ANY (@group_names::name_organization_pair[]) - ELSE false - END - OR - group_id = ANY (@group_ids :: uuid[]) - ) + group_id = ANY(@group_ids :: uuid []) RETURNING group_id; -- name: InsertGroupMember :exec diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 3560238e13b4a..f076c7c5d5c87 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -3,6 +3,7 @@ package idpsync import ( "context" "encoding/json" + "fmt" "regexp" "github.com/golang-jwt/jwt/v4" @@ -92,7 +93,7 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat // collect all diffs to do 1 sql update for all orgs groupIDsToAdd := make([]uuid.UUID, 0) - groupsToRemove := make([]ExpectedGroup, 0) + groupIDsToRemove := make([]uuid.UUID, 0) // For each org, determine which groups the user should land in for orgID, settings := range orgSettings { if settings.GroupField == "" { @@ -100,7 +101,7 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat continue } - expectedGroups, err := settings.ParseClaims(params.MergedClaims) + expectedGroups, err := settings.ParseClaims(orgID, params.MergedClaims) if err != nil { s.Logger.Debug(ctx, "failed to parse claims for groups", slog.F("organization_field", s.GroupField), @@ -128,6 +129,10 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat } }) add, remove := slice.SymmetricDifferenceFunc(existingGroupsTyped, expectedGroups, func(a, b ExpectedGroup) bool { + // Must match + if a.OrganizationID != b.OrganizationID { + return false + } // Only the name or the name needs to be checked, priority is given to the ID. if a.GroupID != nil && b.GroupID != nil { return *a.GroupID == *b.GroupID @@ -138,6 +143,20 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat return false }) + for _, r := range remove { + // This should never happen. All group removals come from the + // existing set, which come from the db. All groups from the + // database have IDs. This code is purely defensive. + if r.GroupID == nil { + detail := "user:" + user.Username + if r.GroupName != nil { + detail += fmt.Sprintf(" from group %s", *r.GroupName) + } + return xerrors.Errorf("removal group has nil ID, which should never happen: %s", detail) + } + groupIDsToRemove = append(groupIDsToRemove, *r.GroupID) + } + // HandleMissingGroups will add the new groups to the org if // the settings specify. It will convert all group names into uuids // for easier assignment. @@ -146,11 +165,10 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat return xerrors.Errorf("handle missing groups: %w", err) } - groupsToRemove = append(groupsToRemove, remove...) groupIDsToAdd = append(groupIDsToAdd, assignGroups...) } - err = s.applyGroupDifference(ctx, tx, user, groupIDsToAdd, groupsToRemove) + err = s.applyGroupDifference(ctx, tx, user, groupIDsToAdd, groupIDsToRemove) if err != nil { return xerrors.Errorf("apply group difference: %w", err) } @@ -165,28 +183,13 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat return nil } -func (s AGPLIDPSync) applyGroupDifference(ctx context.Context, tx database.Store, user database.User, add []uuid.UUID, remove []ExpectedGroup) error { +func (s AGPLIDPSync) applyGroupDifference(ctx context.Context, tx database.Store, user database.User, add []uuid.UUID, removeIDs []uuid.UUID) error { // Always do group removal before group add. This way if there is an error, // we error on the underprivileged side. - removeIDs := make([]uuid.UUID, 0) - removeNames := make([]database.NameOrganizationPair, 0) - for _, r := range remove { - if r.GroupID != nil { - removeIDs = append(removeIDs, *r.GroupID) - } else if r.GroupName != nil { - removeNames = append(removeNames, database.NameOrganizationPair{ - Name: *r.GroupName, - OrganizationID: r.OrganizationID, - }) - } - } - - // If there is something to remove, do it. - if len(removeIDs) > 0 || len(removeNames) > 0 { + if len(removeIDs) > 0 { removedGroupIDs, err := tx.RemoveUserFromGroups(ctx, database.RemoveUserFromGroupsParams{ - UserID: user.ID, - GroupNames: removeNames, - GroupIds: removeIDs, + UserID: user.ID, + GroupIds: removeIDs, }) if err != nil { return xerrors.Errorf("remove user from %d groups: %w", len(removeIDs), err) @@ -264,7 +267,7 @@ type ExpectedGroup struct { // the group "UUID 1234" is renamed, we want to maintain the mapping. // We have to keep names because group sync supports syncing groups by name if // the external IDP group name matches the Coder one. -func (s GroupSyncSettings) ParseClaims(mergedClaims jwt.MapClaims) ([]ExpectedGroup, error) { +func (s GroupSyncSettings) ParseClaims(orgID uuid.UUID, mergedClaims jwt.MapClaims) ([]ExpectedGroup, error) { groupsRaw, ok := mergedClaims[s.GroupField] if !ok { return []ExpectedGroup{}, nil @@ -294,13 +297,13 @@ func (s GroupSyncSettings) ParseClaims(mergedClaims jwt.MapClaims) ([]ExpectedGr if ok { for _, gid := range mappedGroupIDs { gid := gid - groups = append(groups, ExpectedGroup{GroupID: &gid}) + groups = append(groups, ExpectedGroup{OrganizationID: orgID, GroupID: &gid}) } continue } group := group - groups = append(groups, ExpectedGroup{GroupName: &group}) + groups = append(groups, ExpectedGroup{OrganizationID: orgID, GroupName: &group}) } return groups, nil @@ -312,38 +315,6 @@ func (s GroupSyncSettings) ParseClaims(mergedClaims jwt.MapClaims) ([]ExpectedGr // Missing groups are created if AutoCreate is enabled. // TODO: Batching this would be better, as this is 1 or 2 db calls per organization. func (s GroupSyncSettings) HandleMissingGroups(ctx context.Context, tx database.Store, orgID uuid.UUID, add []ExpectedGroup) ([]uuid.UUID, error) { - if !s.AutoCreateMissingGroups { - // If we are not creating groups, then just construct a db lookup for - // all groups by name. - var lookups []string - filter := make([]uuid.UUID, 0) - for _, expected := range add { - if expected.GroupID != nil { - // Groups with IDs are easy! - filter = append(filter, *expected.GroupID) - } else if expected.GroupName != nil { - lookups = append(lookups, *expected.GroupName) - } - } - - if len(lookups) > 0 { - // Do name lookups for all groups that are missing IDs. - newGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{ - OrganizationID: uuid.UUID{}, - HasMemberID: uuid.UUID{}, - GroupNames: lookups, - }) - if err != nil { - return nil, xerrors.Errorf("get groups by names: %w", err) - } - for _, g := range newGroups { - filter = append(filter, g.Group.ID) - } - } - - return filter, nil - } - // All expected that are missing IDs means the group does not exist // in the database. Either remove them, or create them if auto create is // turned on. @@ -359,33 +330,33 @@ func (s GroupSyncSettings) HandleMissingGroups(ctx context.Context, tx database. } } - createdMissingGroups, err := tx.InsertMissingGroups(ctx, database.InsertMissingGroupsParams{ - OrganizationID: orgID, - Source: database.GroupSourceOidc, - GroupNames: missingGroups, - }) - if err != nil { - return nil, xerrors.Errorf("insert missing groups: %w", err) + if s.AutoCreateMissingGroups && len(missingGroups) > 0 { + // Insert any missing groups. If the groups already exist, this is a noop. + _, err := tx.InsertMissingGroups(ctx, database.InsertMissingGroupsParams{ + OrganizationID: orgID, + Source: database.GroupSourceOidc, + GroupNames: missingGroups, + }) + if err != nil { + return nil, xerrors.Errorf("insert missing groups: %w", err) + } } - if len(missingGroups) != len(createdMissingGroups) { - // This is unfortunate, but if legacy params are used, then some existing groups - // can come as params. So we need to fetch them - allGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{ + // Fetch any missing groups by name. If they exist, their IDs will be + // matched and returned. + if len(missingGroups) > 0 { + // Do name lookups for all groups that are missing IDs. + newGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{ OrganizationID: orgID, + HasMemberID: uuid.UUID{}, GroupNames: missingGroups, }) if err != nil { return nil, xerrors.Errorf("get groups by names: %w", err) } - - createdMissingGroups = db2sdk.List(allGroups, func(g database.GetGroupsRow) database.Group { - return g.Group - }) - } - - for _, created := range createdMissingGroups { - addIDs = append(addIDs, created.ID) + for _, g := range newGroups { + addIDs = append(addIDs, g.Group.ID) + } } return addIDs, nil diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go index e1d0ac9d6c095..2207c52fd6830 100644 --- a/coderd/idpsync/group_test.go +++ b/coderd/idpsync/group_test.go @@ -205,6 +205,7 @@ func TestGroupSyncTable(t *testing.T) { LegacyGroupNameMapping: map[string]string{ "create-bar": "legacy-bar", "foo": "legacy-foo", + "bop": "legacy-bop", }, AutoCreateMissingGroups: true, }, @@ -214,6 +215,7 @@ func TestGroupSyncTable(t *testing.T) { GroupNames: map[string]bool{ "legacy-foo": false, "extra": true, + "legacy-bop": true, }, ExpectedGroupNames: []string{ "legacy-bar", From c544a293e30ffe9f087c4bdeb8f5a3b92ced1209 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 5 Sep 2024 16:46:30 -0500 Subject: [PATCH 18/38] test legacy params --- coderd/idpsync/group.go | 5 +- coderd/idpsync/group_test.go | 113 +++++++++++++++++++++++++++++++---- 2 files changed, 104 insertions(+), 14 deletions(-) diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index f076c7c5d5c87..a6799d5e50ece 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -78,7 +78,6 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat if err != nil { return xerrors.Errorf("resolve group sync settings: %w", err) } - orgSettings[orgID] = *settings // Legacy deployment settings will override empty settings. if orgID == defaultOrgID && settings.GroupField == "" { @@ -89,6 +88,7 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat AutoCreateMissingGroups: s.Legacy.CreateMissingGroups, } } + orgSettings[orgID] = *settings } // collect all diffs to do 1 sql update for all orgs @@ -280,6 +280,8 @@ func (s GroupSyncSettings) ParseClaims(orgID uuid.UUID, mergedClaims jwt.MapClai groups := make([]ExpectedGroup, 0) for _, group := range parsedGroups { + group := group + // Legacy group mappings happen before the regex filter. mappedGroupName, ok := s.LegacyGroupNameMapping[group] if ok { @@ -302,7 +304,6 @@ func (s GroupSyncSettings) ParseClaims(orgID uuid.UUID, mergedClaims jwt.MapClai continue } - group := group groups = append(groups, ExpectedGroup{OrganizationID: orgID, GroupName: &group}) } diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go index 2207c52fd6830..82b057422a787 100644 --- a/coderd/idpsync/group_test.go +++ b/coderd/idpsync/group_test.go @@ -14,6 +14,7 @@ import ( "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" + "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/idpsync" @@ -71,6 +72,7 @@ func TestGroupSyncTable(t *testing.T) { "groups": []string{ "foo", "bar", "baz", "create-bar", "create-baz", + "legacy-bar", }, } @@ -229,10 +231,6 @@ func TestGroupSyncTable(t *testing.T) { t.Run(tc.Name, func(t *testing.T) { t.Parallel() - if tc.OrgID == uuid.Nil { - tc.OrgID = uuid.New() - } - db, _ := dbtestutil.NewDB(t) manager := runtimeconfig.NewStoreManager() s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), @@ -242,9 +240,10 @@ func TestGroupSyncTable(t *testing.T) { }, ) - ctx := testutil.Context(t, testutil.WaitMedium) + ctx := testutil.Context(t, testutil.WaitSuperLong) user := dbgen.User(t, db, database.User{}) - SetupOrganization(t, s, db, user, tc) + orgID := uuid.New() + SetupOrganization(t, s, db, user, orgID, tc) // Do the group sync! err := s.SyncGroups(ctx, db, user, idpsync.GroupParams{ @@ -253,17 +252,106 @@ func TestGroupSyncTable(t *testing.T) { }) require.NoError(t, err) - tc.Assert(t, tc.OrgID, db, user) + tc.Assert(t, orgID, db, user) }) } + + // AllTogether runs the entire tabled test as a singular user and + // deployment. This tests all organizations being synced together. + // The reason we do them individually, is that it is much easier to + // debug a single test case. + t.Run("AllTogether", func(t *testing.T) { + t.Parallel() + + db, _ := dbtestutil.NewDB(t) + manager := runtimeconfig.NewStoreManager() + s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), + manager, + // Also sync the default org! + idpsync.DeploymentSyncSettings{ + GroupField: "groups", + Legacy: idpsync.DefaultOrgLegacySettings{ + GroupField: "groups", + GroupMapping: map[string]string{ + "foo": "legacy-foo", + "baz": "legacy-baz", + }, + GroupFilter: regexp.MustCompile("^legacy"), + CreateMissingGroups: true, + }, + }, + ) + + ctx := testutil.Context(t, testutil.WaitSuperLong) + user := dbgen.User(t, db, database.User{}) + + var asserts []func(t *testing.T) + // The default org is also going to do something + def := orgSetupDefinition{ + Name: "DefaultOrg", + GroupNames: map[string]bool{ + "legacy-foo": false, + "legacy-baz": true, + "random": true, + }, + // No settings, because they come from the deployment values + Settings: nil, + ExpectedGroups: nil, + ExpectedGroupNames: []string{"legacy-foo", "legacy-baz", "legacy-bar"}, + } + + //nolint:gocritic // testing + defOrg, err := db.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx)) + require.NoError(t, err) + SetupOrganization(t, s, db, user, defOrg.ID, def) + asserts = append(asserts, func(t *testing.T) { + t.Run(def.Name, func(t *testing.T) { + t.Parallel() + def.Assert(t, defOrg.ID, db, user) + }) + }) + + for _, tc := range testCases { + tc := tc + + orgID := uuid.New() + SetupOrganization(t, s, db, user, orgID, tc) + asserts = append(asserts, func(t *testing.T) { + t.Run(tc.Name, func(t *testing.T) { + t.Parallel() + tc.Assert(t, orgID, db, user) + }) + }) + } + + asserts = append(asserts, func(t *testing.T) { + t.Helper() + def.Assert(t, defOrg.ID, db, user) + }) + + // Do the group sync! + err = s.SyncGroups(ctx, db, user, idpsync.GroupParams{ + SyncEnabled: true, + MergedClaims: userClaims, + }) + require.NoError(t, err) + + for _, assert := range asserts { + assert(t) + } + }) } -func SetupOrganization(t *testing.T, s *idpsync.AGPLIDPSync, db database.Store, user database.User, def orgSetupDefinition) { +func SetupOrganization(t *testing.T, s *idpsync.AGPLIDPSync, db database.Store, user database.User, orgID uuid.UUID, def orgSetupDefinition) { + t.Helper() + org := dbgen.Organization(t, db, database.Organization{ - ID: def.OrgID, + ID: orgID, }) _, err := db.InsertAllUsersGroup(context.Background(), org.ID) - require.NoError(t, err, "Everyone group for an org") + if !database.IsUniqueViolation(err) { + require.NoError(t, err, "Everyone group for an org") + } manager := runtimeconfig.NewStoreManager() orgResolver := manager.OrganizationResolver(db, org.ID) @@ -303,8 +391,7 @@ func SetupOrganization(t *testing.T, s *idpsync.AGPLIDPSync, db database.Store, } type orgSetupDefinition struct { - Name string - OrgID uuid.UUID + Name string // True if the user is a member of the group Groups map[uuid.UUID]bool GroupNames map[string]bool @@ -353,11 +440,13 @@ func (o orgSetupDefinition) Assert(t *testing.T, orgID uuid.UUID, db database.St return g.Group.Name }) require.ElementsMatch(t, o.ExpectedGroupNames, found, "user groups by name") + require.Len(t, o.ExpectedGroups, 0, "ExpectedGroups should be empty") } else { // Check by ID, recommended found := db2sdk.List(userGroups, func(g database.GetGroupsRow) uuid.UUID { return g.Group.ID }) require.ElementsMatch(t, o.ExpectedGroups, found, "user groups") + require.Len(t, o.ExpectedGroupNames, 0, "ExpectedGroupNames should be empty") } } From 476be45195870bdc133d77533e137e87925b12da Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 5 Sep 2024 17:07:00 -0500 Subject: [PATCH 19/38] add unit test for ApplyGroupDifference --- coderd/database/dbmem/dbmem.go | 8 +- coderd/idpsync/group.go | 5 +- coderd/idpsync/group_test.go | 152 +++++++++++++++++++++++++++++++++ 3 files changed, 159 insertions(+), 6 deletions(-) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 7e761de411615..2e4e737ed5428 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -2702,18 +2702,18 @@ func (q *FakeQuerier) GetGroups(_ context.Context, arg database.GetGroupsParams) q.mutex.RLock() defer q.mutex.RUnlock() - groupIDs := make(map[uuid.UUID]struct{}) + userGroupIDs := make(map[uuid.UUID]struct{}) if arg.HasMemberID != uuid.Nil { for _, member := range q.groupMembers { if member.UserID == arg.HasMemberID { - groupIDs[member.GroupID] = struct{}{} + userGroupIDs[member.GroupID] = struct{}{} } } // Handle the everyone group for _, orgMember := range q.organizationMembers { if orgMember.UserID == arg.HasMemberID { - groupIDs[orgMember.OrganizationID] = struct{}{} + userGroupIDs[orgMember.OrganizationID] = struct{}{} } } } @@ -2725,7 +2725,7 @@ func (q *FakeQuerier) GetGroups(_ context.Context, arg database.GetGroupsParams) continue } - _, ok := groupIDs[group.ID] + _, ok := userGroupIDs[group.ID] if arg.HasMemberID != uuid.Nil && !ok { continue } diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index a6799d5e50ece..0930ede7cc545 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -168,7 +168,7 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat groupIDsToAdd = append(groupIDsToAdd, assignGroups...) } - err = s.applyGroupDifference(ctx, tx, user, groupIDsToAdd, groupIDsToRemove) + err = s.ApplyGroupDifference(ctx, tx, user, groupIDsToAdd, groupIDsToRemove) if err != nil { return xerrors.Errorf("apply group difference: %w", err) } @@ -183,7 +183,8 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat return nil } -func (s AGPLIDPSync) applyGroupDifference(ctx context.Context, tx database.Store, user database.User, add []uuid.UUID, removeIDs []uuid.UUID) error { +// ApplyGroupDifference will add and remove the user from the specified groups. +func (s AGPLIDPSync) ApplyGroupDifference(ctx context.Context, tx database.Store, user database.User, add []uuid.UUID, removeIDs []uuid.UUID) error { // Always do group removal before group add. This way if there is an error, // we error on the underprivileged side. if len(removeIDs) > 0 { diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go index 82b057422a787..aa9e3e6c68b46 100644 --- a/coderd/idpsync/group_test.go +++ b/coderd/idpsync/group_test.go @@ -342,6 +342,158 @@ func TestGroupSyncTable(t *testing.T) { }) } +// TestApplyGroupDifference is mainly testing the database functions +func TestApplyGroupDifference(t *testing.T) { + t.Parallel() + + ids := coderdtest.NewDeterministicUUIDGenerator() + testCase := []struct { + Name string + Before map[uuid.UUID]bool + Add []uuid.UUID + Remove []uuid.UUID + Expect []uuid.UUID + }{ + { + Name: "Empty", + }, + { + Name: "AddFromNone", + Before: map[uuid.UUID]bool{ + ids.ID("g1"): false, + }, + Add: []uuid.UUID{ + ids.ID("g1"), + }, + Expect: []uuid.UUID{ + ids.ID("g1"), + }, + }, + { + Name: "AddSome", + Before: map[uuid.UUID]bool{ + ids.ID("g1"): true, + ids.ID("g2"): false, + ids.ID("g3"): false, + uuid.New(): false, + }, + Add: []uuid.UUID{ + ids.ID("g2"), + ids.ID("g3"), + }, + Expect: []uuid.UUID{ + ids.ID("g1"), + ids.ID("g2"), + ids.ID("g3"), + }, + }, + { + Name: "RemoveAll", + Before: map[uuid.UUID]bool{ + uuid.New(): false, + ids.ID("g2"): true, + ids.ID("g3"): true, + }, + Remove: []uuid.UUID{ + ids.ID("g2"), + ids.ID("g3"), + }, + Expect: []uuid.UUID{}, + }, + { + Name: "Mixed", + Before: map[uuid.UUID]bool{ + // adds + ids.ID("a1"): true, + ids.ID("a2"): true, + ids.ID("a3"): false, + ids.ID("a4"): false, + // removes + ids.ID("r1"): true, + ids.ID("r2"): true, + ids.ID("r3"): false, + ids.ID("r4"): false, + // stable + ids.ID("s1"): true, + ids.ID("s2"): true, + // noise + uuid.New(): false, + uuid.New(): false, + }, + Add: []uuid.UUID{ + ids.ID("a1"), ids.ID("a2"), + ids.ID("a3"), ids.ID("a4"), + // Double up to try and confuse + ids.ID("a1"), + ids.ID("a4"), + }, + Remove: []uuid.UUID{ + ids.ID("r1"), ids.ID("r2"), + ids.ID("r3"), ids.ID("r4"), + // Double up to try and confuse + ids.ID("r1"), + ids.ID("r4"), + }, + Expect: []uuid.UUID{ + ids.ID("a1"), ids.ID("a2"), ids.ID("a3"), ids.ID("a4"), + ids.ID("s1"), ids.ID("s2"), + }, + }, + } + + for _, tc := range testCase { + tc := tc + t.Run(tc.Name, func(t *testing.T) { + mgr := runtimeconfig.NewStoreManager() + db, _ := dbtestutil.NewDB(t) + + ctx := testutil.Context(t, testutil.WaitMedium) + //nolint:gocritic // testing + ctx = dbauthz.AsSystemRestricted(ctx) + + org := dbgen.Organization(t, db, database.Organization{}) + _, err := db.InsertAllUsersGroup(ctx, org.ID) + require.NoError(t, err) + + user := dbgen.User(t, db, database.User{}) + _ = dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + + for gid, in := range tc.Before { + group := dbgen.Group(t, db, database.Group{ + ID: gid, + OrganizationID: org.ID, + }) + if in { + _ = dbgen.GroupMember(t, db, database.GroupMemberTable{ + UserID: user.ID, + GroupID: group.ID, + }) + } + } + + s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), mgr, idpsync.FromDeploymentValues(coderdtest.DeploymentValues(t))) + err = s.ApplyGroupDifference(context.Background(), db, user, tc.Add, tc.Remove) + require.NoError(t, err) + + userGroups, err := db.GetGroups(ctx, database.GetGroupsParams{ + HasMemberID: user.ID, + }) + require.NoError(t, err) + + // assert + found := db2sdk.List(userGroups, func(g database.GetGroupsRow) uuid.UUID { + return g.Group.ID + }) + + // Add everyone group + require.ElementsMatch(t, append(tc.Expect, org.ID), found) + }) + } +} + func SetupOrganization(t *testing.T, s *idpsync.AGPLIDPSync, db database.Store, user database.User, orgID uuid.UUID, def orgSetupDefinition) { t.Helper() From 164aeacebac6f544d85c73d4cd83fe028ce5259a Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 5 Sep 2024 17:19:20 -0500 Subject: [PATCH 20/38] chore: remove old group sync code --- coderd/coderd.go | 11 --- coderd/idpsync/group.go | 8 +- coderd/idpsync/idpsync.go | 4 +- coderd/userauth.go | 178 ++++++---------------------------- enterprise/coderd/coderd.go | 1 - enterprise/coderd/userauth.go | 66 ------------- 6 files changed, 36 insertions(+), 232 deletions(-) diff --git a/coderd/coderd.go b/coderd/coderd.go index b829d37a06773..e04f13d367c6e 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -181,7 +181,6 @@ type Options struct { NetworkTelemetryBatchFrequency time.Duration NetworkTelemetryBatchMaxSize int SwaggerEndpoint bool - SetUserGroups func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, orgGroupNames map[uuid.UUID][]string, createMissingGroups bool) error SetUserSiteRoles func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, roles []string) error TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore] UserQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore] @@ -374,16 +373,6 @@ func New(options *Options) *API { if options.TracerProvider == nil { options.TracerProvider = trace.NewNoopTracerProvider() } - if options.SetUserGroups == nil { - options.SetUserGroups = func(ctx context.Context, logger slog.Logger, _ database.Store, userID uuid.UUID, orgGroupNames map[uuid.UUID][]string, createMissingGroups bool) error { - logger.Warn(ctx, "attempted to assign OIDC groups without enterprise license", - slog.F("user_id", userID), - slog.F("groups", orgGroupNames), - slog.F("create_missing_groups", createMissingGroups), - ) - return nil - } - } if options.SetUserSiteRoles == nil { options.SetUserSiteRoles = func(ctx context.Context, logger slog.Logger, _ database.Store, userID uuid.UUID, roles []string) error { logger.Warn(ctx, "attempted to assign OIDC user roles without enterprise license", diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 0930ede7cc545..660d0b9b9c23e 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -14,6 +14,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/runtimeconfig" "github.com/coder/coder/v2/coderd/util/slice" ) @@ -76,7 +77,12 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat orgResolver := s.Manager.OrganizationResolver(tx, orgID) settings, err := s.SyncSettings.Group.Resolve(ctx, orgResolver) if err != nil { - return xerrors.Errorf("resolve group sync settings: %w", err) + if xerrors.Is(err, runtimeconfig.EntryNotFound) { + // Default to not being configured + settings = &GroupSyncSettings{} + } else { + return xerrors.Errorf("resolve group sync settings: %w", err) + } } // Legacy deployment settings will override empty settings. diff --git a/coderd/idpsync/idpsync.go b/coderd/idpsync/idpsync.go index bc3e5cd479064..7fac0e7329d3d 100644 --- a/coderd/idpsync/idpsync.go +++ b/coderd/idpsync/idpsync.go @@ -27,7 +27,7 @@ type IDPSync interface { OrganizationSyncEnabled() bool // ParseOrganizationClaims takes claims from an OIDC provider, and returns the // organization sync params for assigning users into organizations. - ParseOrganizationClaims(ctx context.Context, _ jwt.MapClaims) (OrganizationParams, *HTTPError) + ParseOrganizationClaims(ctx context.Context, mergedClaims jwt.MapClaims) (OrganizationParams, *HTTPError) // SyncOrganizations assigns and removed users from organizations based on the // provided params. SyncOrganizations(ctx context.Context, tx database.Store, user database.User, params OrganizationParams) error @@ -35,7 +35,7 @@ type IDPSync interface { GroupSyncEnabled() bool // ParseGroupClaims takes claims from an OIDC provider, and returns the params // for group syncing. Most of the logic happens in SyncGroups. - ParseGroupClaims(ctx context.Context, _ jwt.MapClaims) (GroupParams, *HTTPError) + ParseGroupClaims(ctx context.Context, mergedClaims jwt.MapClaims) (GroupParams, *HTTPError) // SyncGroups assigns and removes users from groups based on the provided params. SyncGroups(ctx context.Context, db database.Store, user database.User, params GroupParams) error diff --git a/coderd/userauth.go b/coderd/userauth.go index a1abadc63f31a..76d29a7c1a9ec 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -20,7 +20,6 @@ import ( "github.com/google/go-github/v43/github" "github.com/google/uuid" "github.com/moby/moby/pkg/namesgenerator" - "golang.org/x/exp/slices" "golang.org/x/oauth2" "golang.org/x/xerrors" @@ -659,6 +658,9 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { AvatarURL: ghUser.GetAvatarURL(), Name: normName, DebugContext: OauthDebugContext{}, + GroupSync: idpsync.GroupParams{ + SyncEnabled: false, + }, OrganizationSync: idpsync.OrganizationParams{ SyncEnabled: false, IncludeDefault: true, @@ -1004,11 +1006,6 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { } ctx = slog.With(ctx, slog.F("email", email), slog.F("username", username), slog.F("name", name)) - usingGroups, groups, groupErr := api.oidcGroups(ctx, mergedClaims) - if groupErr != nil { - groupErr.Write(rw, r) - return - } roles, roleErr := api.oidcRoles(ctx, mergedClaims) if roleErr != nil { @@ -1032,6 +1029,12 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { return } + groupSync, groupSyncErr := api.IDPSync.ParseGroupClaims(ctx, mergedClaims) + if groupSyncErr != nil { + groupSyncErr.Write(rw, r) + return + } + // If a new user is authenticating for the first time // the audit action is 'register', not 'login' if user.ID == uuid.Nil { @@ -1039,23 +1042,20 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { } params := (&oauthLoginParams{ - User: user, - Link: link, - State: state, - LinkedID: oidcLinkedID(idToken), - LoginType: database.LoginTypeOIDC, - AllowSignups: api.OIDCConfig.AllowSignups, - Email: email, - Username: username, - Name: name, - AvatarURL: picture, - UsingRoles: api.OIDCConfig.RoleSyncEnabled(), - Roles: roles, - UsingGroups: usingGroups, - Groups: groups, - OrganizationSync: orgSync, - CreateMissingGroups: api.OIDCConfig.CreateMissingGroups, - GroupFilter: api.OIDCConfig.GroupFilter, + User: user, + Link: link, + State: state, + LinkedID: oidcLinkedID(idToken), + LoginType: database.LoginTypeOIDC, + AllowSignups: api.OIDCConfig.AllowSignups, + Email: email, + Username: username, + Name: name, + AvatarURL: picture, + UsingRoles: api.OIDCConfig.RoleSyncEnabled(), + Roles: roles, + OrganizationSync: orgSync, + GroupSync: groupSync, DebugContext: OauthDebugContext{ IDTokenClaims: idtokenClaims, UserInfoClaims: userInfoClaims, @@ -1091,79 +1091,6 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect) } -// oidcGroups returns the groups for the user from the OIDC claims. -func (api *API) oidcGroups(ctx context.Context, mergedClaims map[string]interface{}) (bool, []string, *idpsync.HTTPError) { - logger := api.Logger.Named(userAuthLoggerName) - usingGroups := false - var groups []string - - // If the GroupField is the empty string, then groups from OIDC are not used. - // This is so we can support manual group assignment. - if api.OIDCConfig.GroupField != "" { - // If the allow list is empty, then the user is allowed to log in. - // Otherwise, they must belong to at least 1 group in the allow list. - inAllowList := len(api.OIDCConfig.GroupAllowList) == 0 - - usingGroups = true - groupsRaw, ok := mergedClaims[api.OIDCConfig.GroupField] - if ok { - parsedGroups, err := idpsync.ParseStringSliceClaim(groupsRaw) - if err != nil { - api.Logger.Debug(ctx, "groups field was an unknown type in oidc claims", - slog.F("type", fmt.Sprintf("%T", groupsRaw)), - slog.Error(err), - ) - return false, nil, &idpsync.HTTPError{ - Code: http.StatusBadRequest, - Msg: "Failed to sync groups from OIDC claims", - Detail: err.Error(), - RenderStaticPage: false, - } - } - - api.Logger.Debug(ctx, "groups returned in oidc claims", - slog.F("len", len(parsedGroups)), - slog.F("groups", parsedGroups), - ) - - for _, group := range parsedGroups { - if mappedGroup, ok := api.OIDCConfig.GroupMapping[group]; ok { - group = mappedGroup - } - if _, ok := api.OIDCConfig.GroupAllowList[group]; ok { - inAllowList = true - } - groups = append(groups, group) - } - } - - if !inAllowList { - logger.Debug(ctx, "oidc group claim not in allow list, rejecting login", - slog.F("allow_list_count", len(api.OIDCConfig.GroupAllowList)), - slog.F("user_group_count", len(groups)), - ) - detail := "Ask an administrator to add one of your groups to the allow list" - if len(groups) == 0 { - detail = "You are currently not a member of any groups! Ask an administrator to add you to an authorized group to login." - } - return usingGroups, groups, &idpsync.HTTPError{ - Code: http.StatusForbidden, - Msg: "Not a member of an allowed group", - Detail: detail, - RenderStaticPage: true, - } - } - } - - // This conditional is purely to warn the user they might have misconfigured their OIDC - // configuration. - if _, groupClaimExists := mergedClaims["groups"]; !usingGroups && groupClaimExists { - logger.Debug(ctx, "claim 'groups' was returned, but 'oidc-group-field' is not set, check your coder oidc settings") - } - - return usingGroups, groups, nil -} - // oidcRoles returns the roles for the user from the OIDC claims. // If the function returns false, then the caller should return early. // All writes to the response writer are handled by this function. @@ -1278,14 +1205,7 @@ type oauthLoginParams struct { AvatarURL string // OrganizationSync has the organizations that the user will be assigned to. OrganizationSync idpsync.OrganizationParams - // Is UsingGroups is true, then the user will be assigned - // to the Groups provided. - UsingGroups bool - CreateMissingGroups bool - // These are the group names from the IDP. Internally, they will map to - // some organization groups. - Groups []string - GroupFilter *regexp.Regexp + GroupSync idpsync.GroupParams // Is UsingRoles is true, then the user will be assigned // the roles provided. UsingRoles bool @@ -1491,53 +1411,9 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C return xerrors.Errorf("sync organizations: %w", err) } - // Ensure groups are correct. - // This places all groups into the default organization. - // To go multi-org, we need to add a mapping feature here to know which - // groups go to which orgs. - if params.UsingGroups { - filtered := params.Groups - if params.GroupFilter != nil { - filtered = make([]string, 0, len(params.Groups)) - for _, group := range params.Groups { - if params.GroupFilter.MatchString(group) { - filtered = append(filtered, group) - } - } - } - - //nolint:gocritic // No user present in the context. - defaultOrganization, err := tx.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx)) - if err != nil { - // If there is no default org, then we can't assign groups. - // By default, we assume all groups belong to the default org. - return xerrors.Errorf("get default organization: %w", err) - } - - //nolint:gocritic // No user present in the context. - memberships, err := tx.OrganizationMembers(dbauthz.AsSystemRestricted(ctx), database.OrganizationMembersParams{ - UserID: user.ID, - OrganizationID: uuid.Nil, - }) - if err != nil { - return xerrors.Errorf("get organization memberships: %w", err) - } - - // If the user is not in the default organization, then we can't assign groups. - // A user cannot be in groups to an org they are not a member of. - if !slices.ContainsFunc(memberships, func(member database.OrganizationMembersRow) bool { - return member.OrganizationMember.OrganizationID == defaultOrganization.ID - }) { - return xerrors.Errorf("user %s is not a member of the default organization, cannot assign to groups in the org", user.ID) - } - - //nolint:gocritic - err = api.Options.SetUserGroups(dbauthz.AsSystemRestricted(ctx), logger, tx, user.ID, map[uuid.UUID][]string{ - defaultOrganization.ID: filtered, - }, params.CreateMissingGroups) - if err != nil { - return xerrors.Errorf("set user groups: %w", err) - } + err = api.IDPSync.SyncGroups(ctx, tx, user, params.GroupSync) + if err != nil { + return xerrors.Errorf("sync groups: %w", err) } // Ensure roles are correct. diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index ce55bae8ec8d0..f9ab3e452ac04 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -145,7 +145,6 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { } return c.Subject, c.Trial, nil } - api.AGPL.Options.SetUserGroups = api.setUserGroups api.AGPL.Options.SetUserSiteRoles = api.setUserSiteRoles api.AGPL.SiteHandler.RegionsFetcher = func(ctx context.Context) (any, error) { // If the user can read the workspace proxy resource, return that. diff --git a/enterprise/coderd/userauth.go b/enterprise/coderd/userauth.go index 65c4a3473f3f7..60cba28cc37f3 100644 --- a/enterprise/coderd/userauth.go +++ b/enterprise/coderd/userauth.go @@ -8,75 +8,9 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/codersdk" ) -// nolint: revive -func (api *API) setUserGroups(ctx context.Context, logger slog.Logger, db database.Store, userID uuid.UUID, orgGroupNames map[uuid.UUID][]string, createMissingGroups bool) error { - if !api.Entitlements.Enabled(codersdk.FeatureTemplateRBAC) { - return nil - } - - return db.InTx(func(tx database.Store) error { - // When setting the user's groups, it's easier to just clear their groups and re-add them. - // This ensures that the user's groups are always in sync with the auth provider. - orgs, err := tx.GetOrganizationsByUserID(ctx, userID) - if err != nil { - return xerrors.Errorf("get user orgs: %w", err) - } - if len(orgs) != 1 { - return xerrors.Errorf("expected 1 org, got %d", len(orgs)) - } - - // Delete all groups the user belongs to. - // nolint:gocritic // Requires system context to remove user from all groups. - err = tx.RemoveUserFromAllGroups(dbauthz.AsSystemRestricted(ctx), userID) - if err != nil { - return xerrors.Errorf("delete user groups: %w", err) - } - - // TODO: This could likely be improved by making these single queries. - // Either by batching or some other means. This for loop could be really - // inefficient if there are a lot of organizations. There was deployments - // on v1 with >100 orgs. - for orgID, groupNames := range orgGroupNames { - // Create the missing groups for each organization. - if createMissingGroups { - // This is the system creating these additional groups, so we use the system restricted context. - // nolint:gocritic - created, err := tx.InsertMissingGroups(dbauthz.AsSystemRestricted(ctx), database.InsertMissingGroupsParams{ - OrganizationID: orgID, - GroupNames: groupNames, - Source: database.GroupSourceOidc, - }) - if err != nil { - return xerrors.Errorf("insert missing groups: %w", err) - } - if len(created) > 0 { - logger.Debug(ctx, "auto created missing groups", - slog.F("org_id", orgID.ID), - slog.F("created", created), - slog.F("num", len(created)), - ) - } - } - - // Re-add the user to all groups returned by the auth provider. - err = tx.InsertUserGroupsByName(ctx, database.InsertUserGroupsByNameParams{ - UserID: userID, - OrganizationID: orgID, - GroupNames: groupNames, - }) - if err != nil { - return xerrors.Errorf("insert user groups: %w", err) - } - } - - return nil - }, nil) -} - func (api *API) setUserSiteRoles(ctx context.Context, logger slog.Logger, db database.Store, userID uuid.UUID, roles []string) error { if !api.Entitlements.Enabled(codersdk.FeatureUserRoleManagement) { logger.Warn(ctx, "attempted to assign OIDC user roles without enterprise entitlement, roles left unchanged", From 986498d5fb0ad4fe8a79b481617a71a233b99db7 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 5 Sep 2024 17:38:33 -0500 Subject: [PATCH 21/38] switch oidc test config to deployment values --- cli/server.go | 5 -- coderd/idpsync/group.go | 50 ++++++++++-------- coderd/idpsync/group_test.go | 38 +++++++------- coderd/userauth.go | 24 +-------- enterprise/coderd/userauth_test.go | 83 ++++++++++++++++++------------ 5 files changed, 100 insertions(+), 100 deletions(-) diff --git a/cli/server.go b/cli/server.go index 4e3b1e16a1482..c2cd476edfaa4 100644 --- a/cli/server.go +++ b/cli/server.go @@ -187,11 +187,6 @@ func createOIDCConfig(ctx context.Context, logger slog.Logger, vals *codersdk.De EmailField: vals.OIDC.EmailField.String(), AuthURLParams: vals.OIDC.AuthURLParams.Value, IgnoreUserInfo: vals.OIDC.IgnoreUserInfo.Value(), - GroupField: vals.OIDC.GroupField.String(), - GroupFilter: vals.OIDC.GroupRegexFilter.Value(), - GroupAllowList: groupAllowList, - CreateMissingGroups: vals.OIDC.GroupAutoCreate.Value(), - GroupMapping: vals.OIDC.GroupMapping.Value, UserRoleField: vals.OIDC.UserRoleField.String(), UserRoleMapping: vals.OIDC.UserRoleMapping.Value, UserRolesDefault: vals.OIDC.UserRolesDefault.GetSlice(), diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 660d0b9b9c23e..69915125acc71 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -41,6 +41,9 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat return nil } + // nolint:gocritic // all syncing is done as a system user + ctx = dbauthz.AsSystemRestricted(ctx) + // Only care about the default org for deployment settings if the // legacy deployment settings exist. defaultOrgID := uuid.Nil @@ -53,9 +56,6 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat defaultOrgID = defaultOrganization.ID } - // nolint:gocritic // all syncing is done as a system user - ctx = dbauthz.AsSystemRestricted(ctx) - err := db.InTx(func(tx database.Store) error { userGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{ HasMemberID: user.ID, @@ -86,12 +86,12 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat } // Legacy deployment settings will override empty settings. - if orgID == defaultOrgID && settings.GroupField == "" { + if orgID == defaultOrgID && settings.Field == "" { settings = &GroupSyncSettings{ - GroupField: s.Legacy.GroupField, - LegacyGroupNameMapping: s.Legacy.GroupMapping, - RegexFilter: s.Legacy.GroupFilter, - AutoCreateMissingGroups: s.Legacy.CreateMissingGroups, + Field: s.Legacy.GroupField, + LegacyNameMapping: s.Legacy.GroupMapping, + RegexFilter: s.Legacy.GroupFilter, + AutoCreateMissing: s.Legacy.CreateMissingGroups, } } orgSettings[orgID] = *settings @@ -102,7 +102,7 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat groupIDsToRemove := make([]uuid.UUID, 0) // For each org, determine which groups the user should land in for orgID, settings := range orgSettings { - if settings.GroupField == "" { + if settings.Field == "" { // No group sync enabled for this org, so do nothing. continue } @@ -231,17 +231,25 @@ func (s AGPLIDPSync) ApplyGroupDifference(ctx context.Context, tx database.Store } type GroupSyncSettings struct { - GroupField string `json:"field"` - // GroupMapping maps from an OIDC group --> Coder group ID - GroupMapping map[string][]uuid.UUID `json:"mapping"` - RegexFilter *regexp.Regexp `json:"regex_filter"` - AutoCreateMissingGroups bool `json:"auto_create_missing_groups"` - // LegacyGroupNameMapping is deprecated. It remaps an IDP group name to + // Field selects the claim field to be used as the created user's + // groups. If the group field is the empty string, then no group updates + // will ever come from the OIDC provider. + Field string `json:"field"` + // Mapping maps from an OIDC group --> Coder group ID + Mapping map[string][]uuid.UUID `json:"mapping"` + // RegexFilter is a regular expression that filters the groups returned by + // the OIDC provider. Any group not matched by this regex will be ignored. + // If the group filter is nil, then no group filtering will occur. + RegexFilter *regexp.Regexp `json:"regex_filter"` + // AutoCreateMissing controls whether groups returned by the OIDC provider + // are automatically created in Coder if they are missing. + AutoCreateMissing bool `json:"auto_create_missing_groups"` + // LegacyNameMapping is deprecated. It remaps an IDP group name to // a Coder group name. Since configuration is now done at runtime, // group IDs are used to account for group renames. // For legacy configurations, this config option has to remain. - // Deprecated: Use GroupMapping instead. - LegacyGroupNameMapping map[string]string `json:"legacy_group_name_mapping,omitempty"` + // Deprecated: Use Mapping instead. + LegacyNameMapping map[string]string `json:"legacy_group_name_mapping,omitempty"` } func (s *GroupSyncSettings) Set(v string) error { @@ -275,7 +283,7 @@ type ExpectedGroup struct { // We have to keep names because group sync supports syncing groups by name if // the external IDP group name matches the Coder one. func (s GroupSyncSettings) ParseClaims(orgID uuid.UUID, mergedClaims jwt.MapClaims) ([]ExpectedGroup, error) { - groupsRaw, ok := mergedClaims[s.GroupField] + groupsRaw, ok := mergedClaims[s.Field] if !ok { return []ExpectedGroup{}, nil } @@ -290,7 +298,7 @@ func (s GroupSyncSettings) ParseClaims(orgID uuid.UUID, mergedClaims jwt.MapClai group := group // Legacy group mappings happen before the regex filter. - mappedGroupName, ok := s.LegacyGroupNameMapping[group] + mappedGroupName, ok := s.LegacyNameMapping[group] if ok { group = mappedGroupName } @@ -302,7 +310,7 @@ func (s GroupSyncSettings) ParseClaims(orgID uuid.UUID, mergedClaims jwt.MapClai } } - mappedGroupIDs, ok := s.GroupMapping[group] + mappedGroupIDs, ok := s.Mapping[group] if ok { for _, gid := range mappedGroupIDs { gid := gid @@ -338,7 +346,7 @@ func (s GroupSyncSettings) HandleMissingGroups(ctx context.Context, tx database. } } - if s.AutoCreateMissingGroups && len(missingGroups) > 0 { + if s.AutoCreateMissing && len(missingGroups) > 0 { // Insert any missing groups. If the groups already exist, this is a noop. _, err := tx.InsertMissingGroups(ctx, database.InsertMissingGroupsParams{ OrganizationID: orgID, diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go index aa9e3e6c68b46..4e56260400114 100644 --- a/coderd/idpsync/group_test.go +++ b/coderd/idpsync/group_test.go @@ -81,8 +81,8 @@ func TestGroupSyncTable(t *testing.T) { { Name: "SwitchGroups", Settings: &idpsync.GroupSyncSettings{ - GroupField: "groups", - GroupMapping: map[string][]uuid.UUID{ + Field: "groups", + Mapping: map[string][]uuid.UUID{ "foo": {ids.ID("sg-foo"), ids.ID("sg-foo-2")}, "bar": {ids.ID("sg-bar")}, "baz": {ids.ID("sg-baz")}, @@ -107,10 +107,10 @@ func TestGroupSyncTable(t *testing.T) { { Name: "StayInGroup", Settings: &idpsync.GroupSyncSettings{ - GroupField: "groups", + Field: "groups", // Only match foo, so bar does not map RegexFilter: regexp.MustCompile("^foo$"), - GroupMapping: map[string][]uuid.UUID{ + Mapping: map[string][]uuid.UUID{ "foo": {ids.ID("gg-foo"), uuid.New()}, "bar": {ids.ID("gg-bar")}, "baz": {ids.ID("gg-baz")}, @@ -127,8 +127,8 @@ func TestGroupSyncTable(t *testing.T) { { Name: "UserJoinsGroups", Settings: &idpsync.GroupSyncSettings{ - GroupField: "groups", - GroupMapping: map[string][]uuid.UUID{ + Field: "groups", + Mapping: map[string][]uuid.UUID{ "foo": {ids.ID("ng-foo"), uuid.New()}, "bar": {ids.ID("ng-bar"), ids.ID("ng-bar-2")}, "baz": {ids.ID("ng-baz")}, @@ -150,9 +150,9 @@ func TestGroupSyncTable(t *testing.T) { { Name: "CreateGroups", Settings: &idpsync.GroupSyncSettings{ - GroupField: "groups", - RegexFilter: regexp.MustCompile("^create"), - AutoCreateMissingGroups: true, + Field: "groups", + RegexFilter: regexp.MustCompile("^create"), + AutoCreateMissing: true, }, Groups: map[uuid.UUID]bool{}, ExpectedGroupNames: []string{ @@ -163,9 +163,9 @@ func TestGroupSyncTable(t *testing.T) { { Name: "GroupNamesNoMapping", Settings: &idpsync.GroupSyncSettings{ - GroupField: "groups", - RegexFilter: regexp.MustCompile(".*"), - AutoCreateMissingGroups: false, + Field: "groups", + RegexFilter: regexp.MustCompile(".*"), + AutoCreateMissing: false, }, GroupNames: map[string]bool{ "foo": false, @@ -180,13 +180,13 @@ func TestGroupSyncTable(t *testing.T) { { Name: "NoUser", Settings: &idpsync.GroupSyncSettings{ - GroupField: "groups", - GroupMapping: map[string][]uuid.UUID{ + Field: "groups", + Mapping: map[string][]uuid.UUID{ // Extra ID that does not map to a group "foo": {ids.ID("ow-foo"), uuid.New()}, }, - RegexFilter: nil, - AutoCreateMissingGroups: false, + RegexFilter: nil, + AutoCreateMissing: false, }, NotMember: true, Groups: map[uuid.UUID]bool{ @@ -202,14 +202,14 @@ func TestGroupSyncTable(t *testing.T) { { Name: "LegacyMapping", Settings: &idpsync.GroupSyncSettings{ - GroupField: "groups", + Field: "groups", RegexFilter: regexp.MustCompile("^legacy"), - LegacyGroupNameMapping: map[string]string{ + LegacyNameMapping: map[string]string{ "create-bar": "legacy-bar", "foo": "legacy-foo", "bop": "legacy-bop", }, - AutoCreateMissingGroups: true, + AutoCreateMissing: true, }, Groups: map[uuid.UUID]bool{ ids.ID("lg-foo"): true, diff --git a/coderd/userauth.go b/coderd/userauth.go index 76d29a7c1a9ec..a2c8140c65be5 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -8,7 +8,6 @@ import ( "fmt" "net/http" "net/mail" - "regexp" "sort" "strconv" "strings" @@ -659,7 +658,7 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { Name: normName, DebugContext: OauthDebugContext{}, GroupSync: idpsync.GroupParams{ - SyncEnabled: false, + SyncEnabled: false, }, OrganizationSync: idpsync.OrganizationParams{ SyncEnabled: false, @@ -743,27 +742,6 @@ type OIDCConfig struct { // support the userinfo endpoint, or if the userinfo endpoint causes // undesirable behavior. IgnoreUserInfo bool - - // TODO: Move all idp fields into the IDPSync struct - // GroupField selects the claim field to be used as the created user's - // groups. If the group field is the empty string, then no group updates - // will ever come from the OIDC provider. - GroupField string - // CreateMissingGroups controls whether groups returned by the OIDC provider - // are automatically created in Coder if they are missing. - CreateMissingGroups bool - // GroupFilter is a regular expression that filters the groups returned by - // the OIDC provider. Any group not matched by this regex will be ignored. - // If the group filter is nil, then no group filtering will occur. - GroupFilter *regexp.Regexp - // GroupAllowList is a list of groups that are allowed to log in. - // If the list length is 0, then the allow list will not be applied and - // this feature is disabled. - GroupAllowList map[string]bool - // GroupMapping controls how groups returned by the OIDC provider get mapped - // to groups within Coder. - // map[oidcGroupName]coderGroupName - GroupMapping map[string]string // UserRoleField selects the claim field to be used as the created user's // roles. If the field is the empty string, then no role updates // will ever come from the OIDC provider. diff --git a/enterprise/coderd/userauth_test.go b/enterprise/coderd/userauth_test.go index 3e94a25a1c013..0ab67542cc2c7 100644 --- a/enterprise/coderd/userauth_test.go +++ b/enterprise/coderd/userauth_test.go @@ -402,7 +402,9 @@ func TestUserOIDC(t *testing.T) { runner := setupOIDCTest(t, oidcTestConfig{ Config: func(cfg *coderd.OIDCConfig) { cfg.AllowSignups = true - cfg.GroupField = groupClaim + }, + DeploymentValues: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupField = groupClaim }, }) @@ -433,8 +435,10 @@ func TestUserOIDC(t *testing.T) { runner := setupOIDCTest(t, oidcTestConfig{ Config: func(cfg *coderd.OIDCConfig) { cfg.AllowSignups = true - cfg.GroupField = groupClaim - cfg.GroupMapping = map[string]string{oidcGroupName: coderGroupName} + }, + DeploymentValues: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupField = groupClaim + dv.OIDC.GroupMapping = serpent.Struct[map[string]string]{map[string]string{oidcGroupName: coderGroupName}} }, }) @@ -468,7 +472,9 @@ func TestUserOIDC(t *testing.T) { runner := setupOIDCTest(t, oidcTestConfig{ Config: func(cfg *coderd.OIDCConfig) { cfg.AllowSignups = true - cfg.GroupField = groupClaim + }, + DeploymentValues: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupField = groupClaim }, }) @@ -502,7 +508,9 @@ func TestUserOIDC(t *testing.T) { runner := setupOIDCTest(t, oidcTestConfig{ Config: func(cfg *coderd.OIDCConfig) { cfg.AllowSignups = true - cfg.GroupField = groupClaim + }, + DeploymentValues: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupField = groupClaim }, }) @@ -537,7 +545,9 @@ func TestUserOIDC(t *testing.T) { runner := setupOIDCTest(t, oidcTestConfig{ Config: func(cfg *coderd.OIDCConfig) { cfg.AllowSignups = true - cfg.GroupField = groupClaim + }, + DeploymentValues: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupField = groupClaim }, }) @@ -559,8 +569,10 @@ func TestUserOIDC(t *testing.T) { runner := setupOIDCTest(t, oidcTestConfig{ Config: func(cfg *coderd.OIDCConfig) { cfg.AllowSignups = true - cfg.GroupField = groupClaim - cfg.CreateMissingGroups = true + }, + DeploymentValues: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupField = groupClaim + dv.OIDC.GroupAutoCreate = true }, }) @@ -582,8 +594,10 @@ func TestUserOIDC(t *testing.T) { runner := setupOIDCTest(t, oidcTestConfig{ Config: func(cfg *coderd.OIDCConfig) { cfg.AllowSignups = true - cfg.GroupField = groupClaim - cfg.CreateMissingGroups = true + }, + DeploymentValues: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupField = groupClaim + dv.OIDC.GroupAutoCreate = true }, }) @@ -606,8 +620,10 @@ func TestUserOIDC(t *testing.T) { runner := setupOIDCTest(t, oidcTestConfig{ Config: func(cfg *coderd.OIDCConfig) { cfg.AllowSignups = true - cfg.GroupField = groupClaim - cfg.GroupAllowList = map[string]bool{allowedGroup: true} + }, + DeploymentValues: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupField = groupClaim + dv.OIDC.GroupAllowList = []string{allowedGroup} }, }) @@ -697,6 +713,7 @@ func TestGroupSync(t *testing.T) { testCases := []struct { name string modCfg func(cfg *coderd.OIDCConfig) + modDV func(dv *codersdk.DeploymentValues) // initialOrgGroups is initial groups in the org initialOrgGroups []string // initialUserGroups is initial groups for the user @@ -718,10 +735,10 @@ func TestGroupSync(t *testing.T) { }, { name: "GroupSyncDisabled", - modCfg: func(cfg *coderd.OIDCConfig) { + modDV: func(dv *codersdk.DeploymentValues) { // Disable group sync - cfg.GroupField = "" - cfg.GroupFilter = regexp.MustCompile(".*") + dv.OIDC.GroupField = "" + dv.OIDC.GroupRegexFilter = serpent.Regexp(*regexp.MustCompile(".*")) }, initialOrgGroups: []string{"a", "b", "c", "d"}, initialUserGroups: []string{"b", "c", "d"}, @@ -732,10 +749,8 @@ func TestGroupSync(t *testing.T) { { // From a,c,b -> b,c,d name: "ChangeUserGroups", - modCfg: func(cfg *coderd.OIDCConfig) { - cfg.GroupMapping = map[string]string{ - "D": "d", - } + modDV: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupMapping = serpent.Struct[map[string]string]{map[string]string{"D": "d"}} }, initialOrgGroups: []string{"a", "b", "c", "d"}, initialUserGroups: []string{"a", "b", "c"}, @@ -749,8 +764,8 @@ func TestGroupSync(t *testing.T) { { // From a,c,b -> [] name: "RemoveAllGroups", - modCfg: func(cfg *coderd.OIDCConfig) { - cfg.GroupFilter = regexp.MustCompile(".*") + modDV: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupRegexFilter = serpent.Regexp(*regexp.MustCompile(".*")) }, initialOrgGroups: []string{"a", "b", "c", "d"}, initialUserGroups: []string{"a", "b", "c"}, @@ -763,8 +778,8 @@ func TestGroupSync(t *testing.T) { { // From a,c,b -> b,c,d,e,f name: "CreateMissingGroups", - modCfg: func(cfg *coderd.OIDCConfig) { - cfg.CreateMissingGroups = true + modDV: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupAutoCreate = true }, initialOrgGroups: []string{"a", "b", "c", "d"}, initialUserGroups: []string{"a", "b", "c"}, @@ -777,14 +792,11 @@ func TestGroupSync(t *testing.T) { { // From a,c,b -> b,c,d,e,f name: "CreateMissingGroupsFilter", - modCfg: func(cfg *coderd.OIDCConfig) { - cfg.CreateMissingGroups = true + modDV: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupAutoCreate = true // Only single letter groups - cfg.GroupFilter = regexp.MustCompile("^[a-z]$") - cfg.GroupMapping = map[string]string{ - // Does not match the filter, but does after being mapped! - "zebra": "z", - } + dv.OIDC.GroupRegexFilter = serpent.Regexp(*regexp.MustCompile("^[a-z]$")) + dv.OIDC.GroupMapping = serpent.Struct[map[string]string]{map[string]string{"zebra": "z"}} }, initialOrgGroups: []string{"a", "b", "c", "d"}, initialUserGroups: []string{"a", "b", "c"}, @@ -806,8 +818,15 @@ func TestGroupSync(t *testing.T) { t.Parallel() runner := setupOIDCTest(t, oidcTestConfig{ Config: func(cfg *coderd.OIDCConfig) { - cfg.GroupField = "groups" - tc.modCfg(cfg) + if tc.modCfg != nil { + tc.modCfg(cfg) + } + }, + DeploymentValues: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupField = "groups" + if tc.modDV != nil { + tc.modDV(dv) + } }, }) From 290cfa51aeaaa14d7edc47a841995e096173ed97 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 6 Sep 2024 09:44:36 -0500 Subject: [PATCH 22/38] fix err name --- coderd/idpsync/group.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 69915125acc71..38c7260b80b0a 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -77,7 +77,7 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat orgResolver := s.Manager.OrganizationResolver(tx, orgID) settings, err := s.SyncSettings.Group.Resolve(ctx, orgResolver) if err != nil { - if xerrors.Is(err, runtimeconfig.EntryNotFound) { + if xerrors.Is(err, runtimeconfig.ErrEntryNotFound) { // Default to not being configured settings = &GroupSyncSettings{} } else { From c563b10717abb7bb1cb509ddf1785cc540eddc6d Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 6 Sep 2024 09:52:32 -0500 Subject: [PATCH 23/38] some linting cleanup --- coderd/database/models.go | 2 +- coderd/database/querier.go | 2 +- coderd/database/queries.sql.go | 10 +++++----- coderd/idpsync/group.go | 1 - enterprise/coderd/enidpsync/organizations_test.go | 2 +- 5 files changed, 8 insertions(+), 9 deletions(-) diff --git a/coderd/database/models.go b/coderd/database/models.go index 950c2674ab310..9e0283ba859c1 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.26.0 +// sqlc v1.25.0 package database diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 3cedeeade34b7..315f2d6fa1cfd 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.26.0 +// sqlc v1.25.0 package database diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 7c7fbbf0f88f0..52044e4e7e90d 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.26.0 +// sqlc v1.25.0 package database @@ -3126,7 +3126,7 @@ func (q *sqlQuerier) GetJFrogXrayScanByWorkspaceAndAgentID(ctx context.Context, } const upsertJFrogXrayScanByWorkspaceAndAgentID = `-- name: UpsertJFrogXrayScanByWorkspaceAndAgentID :exec -INSERT INTO +INSERT INTO jfrog_xray_scans ( agent_id, workspace_id, @@ -3135,7 +3135,7 @@ INSERT INTO medium, results_url ) -VALUES +VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (agent_id, workspace_id) DO UPDATE SET critical = $3, high = $4, medium = $5, results_url = $6 @@ -5863,7 +5863,7 @@ FROM provisioner_keys WHERE organization_id = $1 -AND +AND lower(name) = lower($2) ` @@ -7616,7 +7616,7 @@ func (q *sqlQuerier) GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUI } const updateTailnetPeerStatusByCoordinator = `-- name: UpdateTailnetPeerStatusByCoordinator :exec -UPDATE +UPDATE tailnet_peers SET status = $2 diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 38c7260b80b0a..d5709b5b9f722 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -181,7 +181,6 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat return nil }, nil) - if err != nil { return err } diff --git a/enterprise/coderd/enidpsync/organizations_test.go b/enterprise/coderd/enidpsync/organizations_test.go index 8978fa6b46ee1..e01ae5a18d98b 100644 --- a/enterprise/coderd/enidpsync/organizations_test.go +++ b/enterprise/coderd/enidpsync/organizations_test.go @@ -237,7 +237,7 @@ func TestOrganizationSync(t *testing.T) { } // Create a new sync object - sync := enidpsync.NewSync(logger, runtimeconfig.NewStoreManager(rdb), caseData.Entitlements, caseData.Settings) + sync := enidpsync.NewSync(logger, runtimeconfig.NewStoreManager(), caseData.Entitlements, caseData.Settings) for _, exp := range caseData.Exps { t.Run(exp.Name, func(t *testing.T) { params, httpErr := sync.ParseOrganizationClaims(ctx, exp.Claims) From d2c247fc8bba073a75ed9598cc28420ce0c7c5b4 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 6 Sep 2024 10:11:05 -0500 Subject: [PATCH 24/38] dbauthz test for new query --- coderd/database/dbauthz/dbauthz_test.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index f9b9fb49b71fc..4b4874f34247c 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -408,6 +408,18 @@ func (s *MethodTestSuite) TestGroup() { _ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g2.ID, UserID: u1.ID}) check.Args(u1.ID).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns() })) + s.Run("RemoveUserFromGroups", s.Subtest(func(db database.Store, check *expects) { + o := dbgen.Organization(s.T(), db, database.Organization{}) + u1 := dbgen.User(s.T(), db, database.User{}) + g1 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) + g2 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) + _ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g1.ID, UserID: u1.ID}) + _ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g2.ID, UserID: u1.ID}) + check.Args(database.RemoveUserFromGroupsParams{ + UserID: u1.ID, + GroupIds: []uuid.UUID{g1.ID, g2.ID}, + }).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns(slice.New(g1.ID, g2.ID)) + })) s.Run("UpdateGroupByID", s.Subtest(func(db database.Store, check *expects) { g := dbgen.Group(s.T(), db, database.Group{}) check.Args(database.UpdateGroupByIDParams{ From 12685bd985c4d3446edd31f718dbe2cabe8ba6bc Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 6 Sep 2024 10:39:49 -0500 Subject: [PATCH 25/38] fixup comments --- coderd/idpsync/group.go | 31 +++++++++++++++++++++------ enterprise/coderd/enidpsync/groups.go | 3 ++- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index d5709b5b9f722..743b368b094f3 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -65,6 +65,8 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat } // Figure out which organizations the user is a member of. + // The "Everyone" group is always included, so we can infer organization + // membership via the groups the user is in. userOrgs := make(map[uuid.UUID][]database.GetGroupsRow) for _, g := range userGroups { g := g @@ -72,6 +74,8 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat } // For each org, we need to fetch the sync settings + // This loop also handles any legacy settings for the default + // organization. orgSettings := make(map[uuid.UUID]GroupSyncSettings) for orgID := range userOrgs { orgResolver := s.Manager.OrganizationResolver(tx, orgID) @@ -97,16 +101,23 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat orgSettings[orgID] = *settings } - // collect all diffs to do 1 sql update for all orgs + // groupIDsToAdd & groupIDsToRemove are the final group differences + // needed to be applied to user. The loop below will iterate over all + // organizations the user is in, and determine the diffs. + // The diffs are applied as a batch sql query, rather than each + // organization having to execute a query. groupIDsToAdd := make([]uuid.UUID, 0) groupIDsToRemove := make([]uuid.UUID, 0) // For each org, determine which groups the user should land in for orgID, settings := range orgSettings { if settings.Field == "" { // No group sync enabled for this org, so do nothing. + // The user can remain in their groups for this org. continue } + // expectedGroups is the set of groups the IDP expects the + // user to be a member of. expectedGroups, err := settings.ParseClaims(orgID, params.MergedClaims) if err != nil { s.Logger.Debug(ctx, "failed to parse claims for groups", @@ -117,7 +128,7 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat // Unsure where to raise this error on the UI or database. continue } - // Everyone group is always implied. + // Everyone group is always implied, so include it. expectedGroups = append(expectedGroups, ExpectedGroup{ OrganizationID: orgID, GroupID: &orgID, @@ -134,6 +145,7 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat GroupName: &f.Group.Name, } }) + add, remove := slice.SymmetricDifferenceFunc(existingGroupsTyped, expectedGroups, func(a, b ExpectedGroup) bool { // Must match if a.OrganizationID != b.OrganizationID { @@ -150,10 +162,10 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat }) for _, r := range remove { - // This should never happen. All group removals come from the - // existing set, which come from the db. All groups from the - // database have IDs. This code is purely defensive. if r.GroupID == nil { + // This should never happen. All group removals come from the + // existing set, which come from the db. All groups from the + // database have IDs. This code is purely defensive. detail := "user:" + user.Username if r.GroupName != nil { detail += fmt.Sprintf(" from group %s", *r.GroupName) @@ -166,6 +178,11 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat // HandleMissingGroups will add the new groups to the org if // the settings specify. It will convert all group names into uuids // for easier assignment. + // TODO: This code should be batched at the end of the for loop. + // Optimizing this is being pushed because if AutoCreate is disabled, + // this code will only add cost on the first login for each user. + // AutoCreate is usually disabled for large deployments. + // For small deployments, this is less of a problem. assignGroups, err := settings.HandleMissingGroups(ctx, tx, orgID, add) if err != nil { return xerrors.Errorf("handle missing groups: %w", err) @@ -174,6 +191,8 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat groupIDsToAdd = append(groupIDsToAdd, assignGroups...) } + // ApplyGroupDifference will take the total adds and removes, and apply + // them. err = s.ApplyGroupDifference(ctx, tx, user, groupIDsToAdd, groupIDsToRemove) if err != nil { return xerrors.Errorf("apply group difference: %w", err) @@ -190,8 +209,6 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat // ApplyGroupDifference will add and remove the user from the specified groups. func (s AGPLIDPSync) ApplyGroupDifference(ctx context.Context, tx database.Store, user database.User, add []uuid.UUID, removeIDs []uuid.UUID) error { - // Always do group removal before group add. This way if there is an error, - // we error on the underprivileged side. if len(removeIDs) > 0 { removedGroupIDs, err := tx.RemoveUserFromGroups(ctx, database.RemoveUserFromGroupsParams{ UserID: user.ID, diff --git a/enterprise/coderd/enidpsync/groups.go b/enterprise/coderd/enidpsync/groups.go index 2ecc8703e29cd..932357e2772fe 100644 --- a/enterprise/coderd/enidpsync/groups.go +++ b/enterprise/coderd/enidpsync/groups.go @@ -17,7 +17,8 @@ func (e EnterpriseIDPSync) GroupSyncEnabled() bool { // ParseGroupClaims parses the user claims and handles deployment wide group behavior. // Almost all behavior is deferred since each organization configures it's own // group sync settings. -// TODO: Implement group allow_list behavior here since that is deployment wide. +// GroupAllowList is implemented here to prevent login by unauthorized users. +// TODO: GroupAllowList overlaps with the default organization group sync settings. func (e EnterpriseIDPSync) ParseGroupClaims(ctx context.Context, mergedClaims jwt.MapClaims) (idpsync.GroupParams, *idpsync.HTTPError) { if !e.GroupSyncEnabled() { return e.AGPLIDPSync.ParseGroupClaims(ctx, mergedClaims) From bf0d4edac865146f84456658bf768775dfd27e77 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 6 Sep 2024 10:42:46 -0500 Subject: [PATCH 26/38] fixup compile issues from rebase --- coderd/idpsync/group_test.go | 12 ++++++------ coderd/idpsync/idpsync.go | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go index 4e56260400114..0ef4e18b40bec 100644 --- a/coderd/idpsync/group_test.go +++ b/coderd/idpsync/group_test.go @@ -29,7 +29,7 @@ func TestParseGroupClaims(t *testing.T) { t.Parallel() s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), - runtimeconfig.NewStoreManager(), + runtimeconfig.NewManager(), idpsync.DeploymentSyncSettings{}) ctx := testutil.Context(t, testutil.WaitMedium) @@ -45,7 +45,7 @@ func TestParseGroupClaims(t *testing.T) { t.Parallel() s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), - runtimeconfig.NewStoreManager(), + runtimeconfig.NewManager(), idpsync.DeploymentSyncSettings{ GroupField: "groups", GroupAllowList: map[string]struct{}{ @@ -232,7 +232,7 @@ func TestGroupSyncTable(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) - manager := runtimeconfig.NewStoreManager() + manager := runtimeconfig.NewManager() s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), manager, idpsync.DeploymentSyncSettings{ @@ -264,7 +264,7 @@ func TestGroupSyncTable(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) - manager := runtimeconfig.NewStoreManager() + manager := runtimeconfig.NewManager() s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), manager, // Also sync the default org! @@ -444,7 +444,7 @@ func TestApplyGroupDifference(t *testing.T) { for _, tc := range testCase { tc := tc t.Run(tc.Name, func(t *testing.T) { - mgr := runtimeconfig.NewStoreManager() + mgr := runtimeconfig.NewManager() db, _ := dbtestutil.NewDB(t) ctx := testutil.Context(t, testutil.WaitMedium) @@ -505,7 +505,7 @@ func SetupOrganization(t *testing.T, s *idpsync.AGPLIDPSync, db database.Store, require.NoError(t, err, "Everyone group for an org") } - manager := runtimeconfig.NewStoreManager() + manager := runtimeconfig.NewManager() orgResolver := manager.OrganizationResolver(db, org.ID) err = s.Group.SetRuntimeValue(context.Background(), orgResolver, def.Settings) require.NoError(t, err) diff --git a/coderd/idpsync/idpsync.go b/coderd/idpsync/idpsync.go index 7fac0e7329d3d..2c2b185c619c9 100644 --- a/coderd/idpsync/idpsync.go +++ b/coderd/idpsync/idpsync.go @@ -45,7 +45,7 @@ type IDPSync interface { // IDP. All related code to syncing user information should be in this package. type AGPLIDPSync struct { Logger slog.Logger - Manager runtimeconfig.Manager + Manager *runtimeconfig.Manager SyncSettings } @@ -108,7 +108,7 @@ type SyncSettings struct { Group runtimeconfig.RuntimeEntry[*GroupSyncSettings] } -func NewAGPLSync(logger slog.Logger, manager runtimeconfig.Manager, settings DeploymentSyncSettings) *AGPLIDPSync { +func NewAGPLSync(logger slog.Logger, manager *runtimeconfig.Manager, settings DeploymentSyncSettings) *AGPLIDPSync { return &AGPLIDPSync{ Logger: logger.Named("idp-sync"), Manager: manager, From f95128e14401ba41eca811df86be8e2398c8dbe5 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 6 Sep 2024 11:23:28 -0500 Subject: [PATCH 27/38] add test for disabled sync --- coderd/idpsync/group.go | 17 ++---- coderd/idpsync/group_test.go | 54 +++++++++++++++++++ coderd/idpsync/organizations_test.go | 4 +- coderd/runtimeconfig/entry.go | 9 ++++ enterprise/coderd/enidpsync/enidpsync.go | 2 +- enterprise/coderd/enidpsync/groups_test.go | 6 +-- .../coderd/enidpsync/organizations_test.go | 2 +- 7 files changed, 74 insertions(+), 20 deletions(-) diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 743b368b094f3..a54f6fbfa09cf 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -81,12 +81,11 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat orgResolver := s.Manager.OrganizationResolver(tx, orgID) settings, err := s.SyncSettings.Group.Resolve(ctx, orgResolver) if err != nil { - if xerrors.Is(err, runtimeconfig.ErrEntryNotFound) { - // Default to not being configured - settings = &GroupSyncSettings{} - } else { + if !xerrors.Is(err, runtimeconfig.ErrEntryNotFound) { return xerrors.Errorf("resolve group sync settings: %w", err) } + // Default to not being configured + settings = &GroupSyncSettings{} } // Legacy deployment settings will override empty settings. @@ -273,15 +272,7 @@ func (s *GroupSyncSettings) Set(v string) error { } func (s *GroupSyncSettings) String() string { - v, err := json.Marshal(s) - if err != nil { - return "decode failed: " + err.Error() - } - return string(v) -} - -func (s *GroupSyncSettings) Type() string { - return "GroupSyncSettings" + return runtimeconfig.JSONString(s) } type ExpectedGroup struct { diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go index 0ef4e18b40bec..07c9052881fad 100644 --- a/coderd/idpsync/group_test.go +++ b/coderd/idpsync/group_test.go @@ -342,6 +342,60 @@ func TestGroupSyncTable(t *testing.T) { }) } +func TestSyncDisabled(t *testing.T) { + t.Parallel() + + if dbtestutil.WillUsePostgres() { + t.Skip("Skipping test because it populates a lot of db entries, which is slow on postgres.") + } + + db, _ := dbtestutil.NewDB(t) + manager := runtimeconfig.NewManager() + s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), + manager, + idpsync.DeploymentSyncSettings{}, + ) + + ids := coderdtest.NewDeterministicUUIDGenerator() + ctx := testutil.Context(t, testutil.WaitSuperLong) + user := dbgen.User(t, db, database.User{}) + orgID := uuid.New() + + def := orgSetupDefinition{ + Name: "SyncDisabled", + Groups: map[uuid.UUID]bool{ + ids.ID("foo"): true, + ids.ID("bar"): true, + ids.ID("baz"): false, + ids.ID("bop"): false, + }, + Settings: &idpsync.GroupSyncSettings{ + Field: "groups", + Mapping: map[string][]uuid.UUID{ + "foo": {ids.ID("foo")}, + "baz": {ids.ID("baz")}, + }, + }, + ExpectedGroups: []uuid.UUID{ + ids.ID("foo"), + ids.ID("bar"), + }, + } + + SetupOrganization(t, s, db, user, orgID, def) + + // Do the group sync! + err := s.SyncGroups(ctx, db, user, idpsync.GroupParams{ + SyncEnabled: false, + MergedClaims: jwt.MapClaims{ + "groups": []string{"baz", "bop"}, + }, + }) + require.NoError(t, err) + + def.Assert(t, orgID, db, user) +} + // TestApplyGroupDifference is mainly testing the database functions func TestApplyGroupDifference(t *testing.T) { t.Parallel() diff --git a/coderd/idpsync/organizations_test.go b/coderd/idpsync/organizations_test.go index 934d7d83816ab..1670beaaedc75 100644 --- a/coderd/idpsync/organizations_test.go +++ b/coderd/idpsync/organizations_test.go @@ -20,7 +20,7 @@ func TestParseOrganizationClaims(t *testing.T) { t.Parallel() s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), - runtimeconfig.NewStoreManager(), + runtimeconfig.NewManager(), idpsync.DeploymentSyncSettings{ OrganizationField: "", OrganizationMapping: nil, @@ -42,7 +42,7 @@ func TestParseOrganizationClaims(t *testing.T) { // AGPL has limited behavior s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), - runtimeconfig.NewStoreManager(), + runtimeconfig.NewManager(), idpsync.DeploymentSyncSettings{ OrganizationField: "orgs", OrganizationMapping: map[string][]uuid.UUID{ diff --git a/coderd/runtimeconfig/entry.go b/coderd/runtimeconfig/entry.go index 780138a89d03b..c0260b0268ddb 100644 --- a/coderd/runtimeconfig/entry.go +++ b/coderd/runtimeconfig/entry.go @@ -2,6 +2,7 @@ package runtimeconfig import ( "context" + "encoding/json" "fmt" "golang.org/x/xerrors" @@ -93,3 +94,11 @@ func (e *RuntimeEntry[T]) name() (string, error) { return e.n, nil } + +func JSONString(v any) string { + s, err := json.Marshal(v) + if err != nil { + return "decode failed: " + err.Error() + } + return string(s) +} diff --git a/enterprise/coderd/enidpsync/enidpsync.go b/enterprise/coderd/enidpsync/enidpsync.go index a7ff1eaa07257..c7ba8dd3ecdc6 100644 --- a/enterprise/coderd/enidpsync/enidpsync.go +++ b/enterprise/coderd/enidpsync/enidpsync.go @@ -17,7 +17,7 @@ type EnterpriseIDPSync struct { *idpsync.AGPLIDPSync } -func NewSync(logger slog.Logger, manager runtimeconfig.Manager, set *entitlements.Set, settings idpsync.DeploymentSyncSettings) *EnterpriseIDPSync { +func NewSync(logger slog.Logger, manager *runtimeconfig.Manager, set *entitlements.Set, settings idpsync.DeploymentSyncSettings) *EnterpriseIDPSync { return &EnterpriseIDPSync{ entitlements: set, AGPLIDPSync: idpsync.NewAGPLSync(logger.With(slog.F("enterprise_capable", "true")), manager, settings), diff --git a/enterprise/coderd/enidpsync/groups_test.go b/enterprise/coderd/enidpsync/groups_test.go index 8103f8a002937..77b078cd9e3f0 100644 --- a/enterprise/coderd/enidpsync/groups_test.go +++ b/enterprise/coderd/enidpsync/groups_test.go @@ -30,7 +30,7 @@ func TestEnterpriseParseGroupClaims(t *testing.T) { t.Parallel() s := enidpsync.NewSync(slogtest.Make(t, &slogtest.Options{}), - runtimeconfig.NewStoreManager(), + runtimeconfig.NewManager(), entitlements.New(), idpsync.DeploymentSyncSettings{}) @@ -46,7 +46,7 @@ func TestEnterpriseParseGroupClaims(t *testing.T) { t.Parallel() s := enidpsync.NewSync(slogtest.Make(t, &slogtest.Options{}), - runtimeconfig.NewStoreManager(), + runtimeconfig.NewManager(), entitled, idpsync.DeploymentSyncSettings{ GroupField: "groups", @@ -74,7 +74,7 @@ func TestEnterpriseParseGroupClaims(t *testing.T) { t.Parallel() s := enidpsync.NewSync(slogtest.Make(t, &slogtest.Options{}), - runtimeconfig.NewStoreManager(), + runtimeconfig.NewManager(), entitled, idpsync.DeploymentSyncSettings{ GroupField: "groups", diff --git a/enterprise/coderd/enidpsync/organizations_test.go b/enterprise/coderd/enidpsync/organizations_test.go index e01ae5a18d98b..cb6da2723b2f5 100644 --- a/enterprise/coderd/enidpsync/organizations_test.go +++ b/enterprise/coderd/enidpsync/organizations_test.go @@ -237,7 +237,7 @@ func TestOrganizationSync(t *testing.T) { } // Create a new sync object - sync := enidpsync.NewSync(logger, runtimeconfig.NewStoreManager(), caseData.Entitlements, caseData.Settings) + sync := enidpsync.NewSync(logger, runtimeconfig.NewManager(), caseData.Entitlements, caseData.Settings) for _, exp := range caseData.Exps { t.Run(exp.Name, func(t *testing.T) { params, httpErr := sync.ParseOrganizationClaims(ctx, exp.Claims) From 88b0ad9b86a9be67e83210015e332e43a5dbe10f Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 6 Sep 2024 12:51:52 -0500 Subject: [PATCH 28/38] linting --- coderd/database/queries.sql.go | 8 ++++---- coderd/idpsync/group_test.go | 2 ++ enterprise/coderd/userauth_test.go | 6 +++--- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 52044e4e7e90d..191cf291102ad 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -3126,7 +3126,7 @@ func (q *sqlQuerier) GetJFrogXrayScanByWorkspaceAndAgentID(ctx context.Context, } const upsertJFrogXrayScanByWorkspaceAndAgentID = `-- name: UpsertJFrogXrayScanByWorkspaceAndAgentID :exec -INSERT INTO +INSERT INTO jfrog_xray_scans ( agent_id, workspace_id, @@ -3135,7 +3135,7 @@ INSERT INTO medium, results_url ) -VALUES +VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (agent_id, workspace_id) DO UPDATE SET critical = $3, high = $4, medium = $5, results_url = $6 @@ -5863,7 +5863,7 @@ FROM provisioner_keys WHERE organization_id = $1 -AND +AND lower(name) = lower($2) ` @@ -7616,7 +7616,7 @@ func (q *sqlQuerier) GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUI } const updateTailnetPeerStatusByCoordinator = `-- name: UpdateTailnetPeerStatusByCoordinator :exec -UPDATE +UPDATE tailnet_peers SET status = $2 diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go index 07c9052881fad..a3c9140577b8c 100644 --- a/coderd/idpsync/group_test.go +++ b/coderd/idpsync/group_test.go @@ -498,6 +498,8 @@ func TestApplyGroupDifference(t *testing.T) { for _, tc := range testCase { tc := tc t.Run(tc.Name, func(t *testing.T) { + t.Parallel() + mgr := runtimeconfig.NewManager() db, _ := dbtestutil.NewDB(t) diff --git a/enterprise/coderd/userauth_test.go b/enterprise/coderd/userauth_test.go index 0ab67542cc2c7..3b42dc1aeec5f 100644 --- a/enterprise/coderd/userauth_test.go +++ b/enterprise/coderd/userauth_test.go @@ -438,7 +438,7 @@ func TestUserOIDC(t *testing.T) { }, DeploymentValues: func(dv *codersdk.DeploymentValues) { dv.OIDC.GroupField = groupClaim - dv.OIDC.GroupMapping = serpent.Struct[map[string]string]{map[string]string{oidcGroupName: coderGroupName}} + dv.OIDC.GroupMapping = serpent.Struct[map[string]string]{Value: map[string]string{oidcGroupName: coderGroupName}} }, }) @@ -750,7 +750,7 @@ func TestGroupSync(t *testing.T) { // From a,c,b -> b,c,d name: "ChangeUserGroups", modDV: func(dv *codersdk.DeploymentValues) { - dv.OIDC.GroupMapping = serpent.Struct[map[string]string]{map[string]string{"D": "d"}} + dv.OIDC.GroupMapping = serpent.Struct[map[string]string]{Value: map[string]string{"D": "d"}} }, initialOrgGroups: []string{"a", "b", "c", "d"}, initialUserGroups: []string{"a", "b", "c"}, @@ -796,7 +796,7 @@ func TestGroupSync(t *testing.T) { dv.OIDC.GroupAutoCreate = true // Only single letter groups dv.OIDC.GroupRegexFilter = serpent.Regexp(*regexp.MustCompile("^[a-z]$")) - dv.OIDC.GroupMapping = serpent.Struct[map[string]string]{map[string]string{"zebra": "z"}} + dv.OIDC.GroupMapping = serpent.Struct[map[string]string]{Value: map[string]string{"zebra": "z"}} }, initialOrgGroups: []string{"a", "b", "c", "d"}, initialUserGroups: []string{"a", "b", "c"}, From 6491f6ac295dfc3a765d95ef883d14f61647d4fc Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 6 Sep 2024 13:06:23 -0500 Subject: [PATCH 29/38] chore: handle db conflicts gracefully --- coderd/database/querier.go | 1 + coderd/database/queries.sql.go | 2 ++ coderd/database/queries/groupmembers.sql | 2 ++ coderd/idpsync/group.go | 2 ++ coderd/idpsync/group_test.go | 16 ++++++++++++---- 5 files changed, 19 insertions(+), 4 deletions(-) diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 315f2d6fa1cfd..ee9a64f12076d 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -370,6 +370,7 @@ type sqlcQuerier interface { InsertTemplateVersionWorkspaceTag(ctx context.Context, arg InsertTemplateVersionWorkspaceTagParams) (TemplateVersionWorkspaceTag, error) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) // InsertUserGroupsByID adds a user to all provided groups, if they exist. + // If there is a conflict, the user is already a member InsertUserGroupsByID(ctx context.Context, arg InsertUserGroupsByIDParams) ([]uuid.UUID, error) // InsertUserGroupsByName adds a user to all provided groups, if they exist. InsertUserGroupsByName(ctx context.Context, arg InsertUserGroupsByNameParams) error diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 191cf291102ad..c9f1d1de145d9 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1462,6 +1462,7 @@ SELECT groups.id FROM groups +ON CONFLICT DO NOTHING RETURNING group_id ` @@ -1471,6 +1472,7 @@ type InsertUserGroupsByIDParams struct { } // InsertUserGroupsByID adds a user to all provided groups, if they exist. +// If there is a conflict, the user is already a member func (q *sqlQuerier) InsertUserGroupsByID(ctx context.Context, arg InsertUserGroupsByIDParams) ([]uuid.UUID, error) { rows, err := q.db.QueryContext(ctx, insertUserGroupsByID, arg.UserID, pq.Array(arg.GroupIds)) if err != nil { diff --git a/coderd/database/queries/groupmembers.sql b/coderd/database/queries/groupmembers.sql index 814f878cb9232..4efe9bf488590 100644 --- a/coderd/database/queries/groupmembers.sql +++ b/coderd/database/queries/groupmembers.sql @@ -46,6 +46,8 @@ SELECT groups.id FROM groups +-- If there is a conflict, the user is already a member +ON CONFLICT DO NOTHING RETURNING group_id; -- name: RemoveUserFromAllGroups :exec diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index a54f6fbfa09cf..704fd1b10ea75 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -226,6 +226,8 @@ func (s AGPLIDPSync) ApplyGroupDifference(ctx context.Context, tx database.Store } if len(add) > 0 { + add = slice.Unique(add) + // Defensive programming to only insert uniques. assignedGroupIDs, err := tx.InsertUserGroupsByID(ctx, database.InsertUserGroupsByIDParams{ UserID: user.ID, GroupIds: add, diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go index a3c9140577b8c..0f9d0345f1e60 100644 --- a/coderd/idpsync/group_test.go +++ b/coderd/idpsync/group_test.go @@ -2,6 +2,7 @@ package idpsync_test import ( "context" + "database/sql" "regexp" "testing" @@ -9,6 +10,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/require" "golang.org/x/exp/slices" + "golang.org/x/xerrors" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/coderd/coderdtest" @@ -64,6 +66,7 @@ func TestParseGroupClaims(t *testing.T) { func TestGroupSyncTable(t *testing.T) { t.Parallel() + // Last checked, takes 30s with postgres on a fast machine. if dbtestutil.WillUsePostgres() { t.Skip("Skipping test because it populates a lot of db entries, which is slow on postgres.") } @@ -553,10 +556,15 @@ func TestApplyGroupDifference(t *testing.T) { func SetupOrganization(t *testing.T, s *idpsync.AGPLIDPSync, db database.Store, user database.User, orgID uuid.UUID, def orgSetupDefinition) { t.Helper() - org := dbgen.Organization(t, db, database.Organization{ - ID: orgID, - }) - _, err := db.InsertAllUsersGroup(context.Background(), org.ID) + // Account that the org might be the default organization + org, err := db.GetOrganizationByID(context.Background(), orgID) + if xerrors.Is(err, sql.ErrNoRows) { + org = dbgen.Organization(t, db, database.Organization{ + ID: orgID, + }) + } + + _, err = db.InsertAllUsersGroup(context.Background(), org.ID) if !database.IsUniqueViolation(err) { require.NoError(t, err, "Everyone group for an org") } From bd2328836d951ae48993e0d19cd6e07e0b881e44 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 6 Sep 2024 13:21:54 -0500 Subject: [PATCH 30/38] test expected group equality --- coderd/idpsync/group.go | 38 ++++++--- coderd/idpsync/group_test.go | 146 +++++++++++++++++++++++++++++++++++ 2 files changed, 172 insertions(+), 12 deletions(-) diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 704fd1b10ea75..c779b7ed15df3 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -146,18 +146,7 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat }) add, remove := slice.SymmetricDifferenceFunc(existingGroupsTyped, expectedGroups, func(a, b ExpectedGroup) bool { - // Must match - if a.OrganizationID != b.OrganizationID { - return false - } - // Only the name or the name needs to be checked, priority is given to the ID. - if a.GroupID != nil && b.GroupID != nil { - return *a.GroupID == *b.GroupID - } - if a.GroupName != nil && b.GroupName != nil { - return *a.GroupName == *b.GroupName - } - return false + return a.Equal(b) }) for _, r := range remove { @@ -283,6 +272,31 @@ type ExpectedGroup struct { GroupName *string } +// Equal compares two ExpectedGroups. The org id must be the same. +// If the group ID is set, it will be compared and take priorty, ignoring the +// name value. So 2 groups with the same ID but different names will be +// considered equal. +func (a ExpectedGroup) Equal(b ExpectedGroup) bool { + // Must match + if a.OrganizationID != b.OrganizationID { + return false + } + // Only the name or the name needs to be checked, priority is given to the ID. + if a.GroupID != nil && b.GroupID != nil { + return *a.GroupID == *b.GroupID + } + if a.GroupName != nil && b.GroupName != nil { + return *a.GroupName == *b.GroupName + } + + // If everything is nil, it is equal. Although a bit pointless + if a.GroupID == nil && b.GroupID == nil && + a.GroupName == nil && b.GroupName == nil { + return true + } + return false +} + // ParseClaims will take the merged claims from the IDP and return the groups // the user is expected to be a member of. The expected group can either be a // name or an ID. diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go index 0f9d0345f1e60..cf312a576d720 100644 --- a/coderd/idpsync/group_test.go +++ b/coderd/idpsync/group_test.go @@ -21,6 +21,7 @@ import ( "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/idpsync" "github.com/coder/coder/v2/coderd/runtimeconfig" + "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/testutil" ) @@ -553,6 +554,151 @@ func TestApplyGroupDifference(t *testing.T) { } } +func TestExpectedGroupEqual(t *testing.T) { + t.Parallel() + + ids := coderdtest.NewDeterministicUUIDGenerator() + testCases := []struct { + Name string + A idpsync.ExpectedGroup + B idpsync.ExpectedGroup + Equal bool + }{ + { + Name: "Empty", + A: idpsync.ExpectedGroup{}, + B: idpsync.ExpectedGroup{}, + Equal: true, + }, + { + Name: "DifferentOrgs", + A: idpsync.ExpectedGroup{ + OrganizationID: uuid.New(), + GroupID: ptr.Ref(ids.ID("g1")), + GroupName: nil, + }, + B: idpsync.ExpectedGroup{ + OrganizationID: uuid.New(), + GroupID: ptr.Ref(ids.ID("g1")), + GroupName: nil, + }, + Equal: false, + }, + { + Name: "SameID", + A: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: ptr.Ref(ids.ID("g1")), + GroupName: nil, + }, + B: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: ptr.Ref(ids.ID("g1")), + GroupName: nil, + }, + Equal: true, + }, + { + Name: "DifferentIDs", + A: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: ptr.Ref(uuid.New()), + GroupName: nil, + }, + B: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: ptr.Ref(uuid.New()), + GroupName: nil, + }, + Equal: false, + }, + { + Name: "SameName", + A: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: nil, + GroupName: ptr.Ref("foo"), + }, + B: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: nil, + GroupName: ptr.Ref("foo"), + }, + Equal: true, + }, + { + Name: "DifferentName", + A: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: nil, + GroupName: ptr.Ref("foo"), + }, + B: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: nil, + GroupName: ptr.Ref("bar"), + }, + Equal: false, + }, + // Edge cases + { + // A bit strange, but valid as ID takes priority. + // We assume 2 groups with the same ID are equal, even if + // their names are different. Names are mutable, IDs are not, + // so there is 0% chance they are different groups. + Name: "DifferentIDSameName", + A: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: ptr.Ref(ids.ID("g1")), + GroupName: ptr.Ref("foo"), + }, + B: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: ptr.Ref(ids.ID("g1")), + GroupName: ptr.Ref("bar"), + }, + Equal: true, + }, + { + Name: "MixedNils", + A: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: ptr.Ref(ids.ID("g1")), + GroupName: nil, + }, + B: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: nil, + GroupName: ptr.Ref("bar"), + }, + Equal: false, + }, + { + Name: "NoComparable", + A: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: ptr.Ref(ids.ID("g1")), + GroupName: nil, + }, + B: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: nil, + GroupName: nil, + }, + Equal: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.Name, func(t *testing.T) { + t.Parallel() + + require.Equal(t, tc.Equal, tc.A.Equal(tc.B)) + }) + } +} + func SetupOrganization(t *testing.T, s *idpsync.AGPLIDPSync, db database.Store, user database.User, orgID uuid.UUID, def orgSetupDefinition) { t.Helper() From a390ec4cba6db241cbd1ff42115c7ad42907656b Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 6 Sep 2024 13:30:50 -0500 Subject: [PATCH 31/38] cleanup comments --- coderd/idpsync/group.go | 7 +++---- coderd/idpsync/idpsync.go | 4 +++- coderd/userauth.go | 2 ++ 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index c779b7ed15df3..8a097dca37f47 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -206,7 +206,7 @@ func (s AGPLIDPSync) ApplyGroupDifference(ctx context.Context, tx database.Store return xerrors.Errorf("remove user from %d groups: %w", len(removeIDs), err) } if len(removedGroupIDs) != len(removeIDs) { - s.Logger.Debug(ctx, "failed to remove user from all groups", + s.Logger.Debug(ctx, "user not removed from expected number of groups", slog.F("user_id", user.ID), slog.F("groups_removed_count", len(removedGroupIDs)), slog.F("expected_count", len(removeIDs)), @@ -225,7 +225,7 @@ func (s AGPLIDPSync) ApplyGroupDifference(ctx context.Context, tx database.Store return xerrors.Errorf("insert user into %d groups: %w", len(add), err) } if len(assignedGroupIDs) != len(add) { - s.Logger.Debug(ctx, "failed to assign all groups to user", + s.Logger.Debug(ctx, "user not assigned to expected number of groups", slog.F("user_id", user.ID), slog.F("groups_assigned_count", len(assignedGroupIDs)), slog.F("expected_count", len(add)), @@ -355,8 +355,7 @@ func (s GroupSyncSettings) ParseClaims(orgID uuid.UUID, mergedClaims jwt.MapClai // TODO: Batching this would be better, as this is 1 or 2 db calls per organization. func (s GroupSyncSettings) HandleMissingGroups(ctx context.Context, tx database.Store, orgID uuid.UUID, add []ExpectedGroup) ([]uuid.UUID, error) { // All expected that are missing IDs means the group does not exist - // in the database. Either remove them, or create them if auto create is - // turned on. + // in the database, or it is a legacy mapping, and we need to do a lookup. var missingGroups []string addIDs := make([]uuid.UUID, 0) diff --git a/coderd/idpsync/idpsync.go b/coderd/idpsync/idpsync.go index 2c2b185c619c9..2c99e780ffee6 100644 --- a/coderd/idpsync/idpsync.go +++ b/coderd/idpsync/idpsync.go @@ -90,7 +90,9 @@ func FromDeploymentValues(dv *codersdk.DeploymentValues) DeploymentSyncSettings OrganizationMapping: dv.OIDC.OrganizationMapping.Value, OrganizationAssignDefault: dv.OIDC.OrganizationAssignDefault.Value(), - // TODO: Separate group field for allow list from default org + // TODO: Separate group field for allow list from default org. + // Right now you cannot disable group sync from the default org and + // configure an allow list. GroupField: dv.OIDC.GroupField.Value(), GroupAllowList: ConvertAllowList(dv.OIDC.GroupAllowList.Value()), Legacy: DefaultOrgLegacySettings{ diff --git a/coderd/userauth.go b/coderd/userauth.go index a2c8140c65be5..223f697c09bb9 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -1389,6 +1389,8 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C return xerrors.Errorf("sync organizations: %w", err) } + // Group sync needs to occur after org sync, since a user can join an org, + // then have their groups sync to said org. err = api.IDPSync.SyncGroups(ctx, tx, user, params.GroupSync) if err != nil { return xerrors.Errorf("sync groups: %w", err) From a0a1c53bdfcdd7f2259403ce0165ff5c081c98b3 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 6 Sep 2024 13:34:33 -0500 Subject: [PATCH 32/38] spelling mistake --- coderd/idpsync/group.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 8a097dca37f47..91e440c38b668 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -273,7 +273,7 @@ type ExpectedGroup struct { } // Equal compares two ExpectedGroups. The org id must be the same. -// If the group ID is set, it will be compared and take priorty, ignoring the +// If the group ID is set, it will be compared and take priority, ignoring the // name value. So 2 groups with the same ID but different names will be // considered equal. func (a ExpectedGroup) Equal(b ExpectedGroup) bool { From a86ba834180aad44a6ac675783abb1e2a6263dbd Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 6 Sep 2024 14:35:13 -0500 Subject: [PATCH 33/38] linting: --- enterprise/coderd/enidpsync/groups.go | 1 - 1 file changed, 1 deletion(-) diff --git a/enterprise/coderd/enidpsync/groups.go b/enterprise/coderd/enidpsync/groups.go index 932357e2772fe..dc8456fc6b1c9 100644 --- a/enterprise/coderd/enidpsync/groups.go +++ b/enterprise/coderd/enidpsync/groups.go @@ -61,7 +61,6 @@ func (e EnterpriseIDPSync) ParseGroupClaims(ctx context.Context, mergedClaims jw RenderStaticPage: true, } } - } return idpsync.GroupParams{ From 0df7f28209e5b32747f7859871cfa7c954faf6f6 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 9 Sep 2024 14:31:15 -0500 Subject: [PATCH 34/38] add interface method to allow api crud --- coderd/idpsync/group.go | 3 +++ coderd/idpsync/idpsync.go | 5 ++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 91e440c38b668..153f5db91199f 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -28,6 +28,9 @@ func (AGPLIDPSync) GroupSyncEnabled() bool { // AGPL does not support syncing groups. return false } +func (s AGPLIDPSync) GroupSyncSettings() runtimeconfig.RuntimeEntry[*GroupSyncSettings] { + return s.Group +} func (s AGPLIDPSync) ParseGroupClaims(_ context.Context, _ jwt.MapClaims) (GroupParams, *HTTPError) { return GroupParams{ diff --git a/coderd/idpsync/idpsync.go b/coderd/idpsync/idpsync.go index 2c99e780ffee6..2c8ed10ce9bcc 100644 --- a/coderd/idpsync/idpsync.go +++ b/coderd/idpsync/idpsync.go @@ -36,9 +36,12 @@ type IDPSync interface { // ParseGroupClaims takes claims from an OIDC provider, and returns the params // for group syncing. Most of the logic happens in SyncGroups. ParseGroupClaims(ctx context.Context, mergedClaims jwt.MapClaims) (GroupParams, *HTTPError) - // SyncGroups assigns and removes users from groups based on the provided params. SyncGroups(ctx context.Context, db database.Store, user database.User, params GroupParams) error + // GroupSyncSettings is exposed for the API to implement CRUD operations + // on the settings used by IDPSync. This entry is thread safe and can be + // accessed concurrently. The settings are stored in the database. + GroupSyncSettings() runtimeconfig.RuntimeEntry[*GroupSyncSettings] } // AGPLIDPSync is the configuration for syncing user information from an external From 7a802a9196b1501b77fb1d979bb1c14456d090c7 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 11 Sep 2024 11:49:14 -0500 Subject: [PATCH 35/38] Remove testable example --- coderd/coderdtest/uuids_test.go | 32 ++++++++------------------------ 1 file changed, 8 insertions(+), 24 deletions(-) diff --git a/coderd/coderdtest/uuids_test.go b/coderd/coderdtest/uuids_test.go index 5a0e10935bd50..935be36eb8b15 100644 --- a/coderd/coderdtest/uuids_test.go +++ b/coderd/coderdtest/uuids_test.go @@ -1,33 +1,17 @@ package coderdtest_test import ( - "github.com/google/uuid" + "testing" + + "github.com/stretchr/testify/require" "github.com/coder/coder/v2/coderd/coderdtest" ) -func ExampleNewDeterministicUUIDGenerator() { - det := coderdtest.NewDeterministicUUIDGenerator() - testCases := []struct { - CreateUsers []uuid.UUID - ExpectedIDs []uuid.UUID - }{ - { - CreateUsers: []uuid.UUID{ - det.ID("player1"), - det.ID("player2"), - }, - ExpectedIDs: []uuid.UUID{ - det.ID("player1"), - det.ID("player2"), - }, - }, - } +func TestDeterministicUUIDGenerator(t *testing.T) { + t.Parallel() - for _, tc := range testCases { - tc := tc - _ = tc - // Do the test with CreateUsers as the setup, and the expected IDs - // will match. - } + ids := coderdtest.NewDeterministicUUIDGenerator() + require.Equal(t, ids.ID("g1"), ids.ID("g1")) + require.NotEqual(t, ids.ID("g1"), ids.ID("g2")) } From 611f1e3a6a7b74ce159d561369352db81bab8143 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 11 Sep 2024 11:51:51 -0500 Subject: [PATCH 36/38] fix formatting of sql, add a comment --- coderd/database/queries.sql.go | 2 +- coderd/database/queries/groups.sql | 2 +- coderd/idpsync/group.go | 3 +++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index c9f1d1de145d9..3616fcb66d3fb 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1680,7 +1680,7 @@ WHERE ELSE true END AND CASE WHEN array_length($3 :: text[], 1) > 0 THEN - groups.name = ANY($3) + groups.name = ANY($3) ELSE true END ` diff --git a/coderd/database/queries/groups.sql b/coderd/database/queries/groups.sql index 0df848d6a6d05..780c0d0154740 100644 --- a/coderd/database/queries/groups.sql +++ b/coderd/database/queries/groups.sql @@ -53,7 +53,7 @@ WHERE ELSE true END AND CASE WHEN array_length(@group_names :: text[], 1) > 0 THEN - groups.name = ANY(@group_names) + groups.name = ANY(@group_names) ELSE true END ; diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 153f5db91199f..7c61aeb2fe4ef 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -128,6 +128,9 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat slog.Error(err), ) // Unsure where to raise this error on the UI or database. + // TODO: This error prevents group sync, but we have no way + // to raise this to an org admin. Come up with a solution to + // notify the admin and user of this issue. continue } // Everyone group is always implied, so include it. From 7f28a5359be59fb978170f2ddd3444391b200510 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 11 Sep 2024 12:09:27 -0500 Subject: [PATCH 37/38] remove function only used in 1 place --- coderd/database/dbmem/dbmem.go | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 2e4e737ed5428..ed766d48ecd43 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -682,17 +682,6 @@ func (q *FakeQuerier) getWorkspaceResourcesByJobIDNoLock(_ context.Context, jobI return resources, nil } -func (q *FakeQuerier) getGroupByNameNoLock(arg database.NameOrganizationPair) (database.Group, error) { - for _, group := range q.groups { - if group.OrganizationID == arg.OrganizationID && - group.Name == arg.Name { - return group, nil - } - } - - return database.Group{}, sql.ErrNoRows -} - func (q *FakeQuerier) getGroupByIDNoLock(_ context.Context, id uuid.UUID) (database.Group, error) { for _, group := range q.groups { if group.ID == id { @@ -2624,10 +2613,14 @@ func (q *FakeQuerier) GetGroupByOrgAndName(_ context.Context, arg database.GetGr q.mutex.RLock() defer q.mutex.RUnlock() - return q.getGroupByNameNoLock(database.NameOrganizationPair{ - Name: arg.Name, - OrganizationID: arg.OrganizationID, - }) + for _, group := range q.groups { + if group.OrganizationID == arg.OrganizationID && + group.Name == arg.Name { + return group, nil + } + } + + return database.Group{}, sql.ErrNoRows } func (q *FakeQuerier) GetGroupMembers(ctx context.Context) ([]database.GroupMember, error) { From 41994d2195e5980b7a5fea352a77aeb41a61cbfb Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 11 Sep 2024 12:14:14 -0500 Subject: [PATCH 38/38] make fmt --- coderd/idpsync/group.go | 1 + 1 file changed, 1 insertion(+) diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 7c61aeb2fe4ef..1b6b8f76dc685 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -28,6 +28,7 @@ func (AGPLIDPSync) GroupSyncEnabled() bool { // AGPL does not support syncing groups. return false } + func (s AGPLIDPSync) GroupSyncSettings() runtimeconfig.RuntimeEntry[*GroupSyncSettings] { return s.Group }