diff --git a/coderd/coderd.go b/coderd/coderd.go index 6110733edecc3..d6ec155dc42f2 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -134,7 +134,7 @@ type Options struct { BaseDERPMap *tailcfg.DERPMap DERPMapUpdateFrequency time.Duration SwaggerEndpoint bool - SetUserGroups func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, groupNames []string, createMissingGroups bool) error + SetUserGroups func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, orgGroupNames map[uuid.UUID][]string, createMissingGroups bool) error SetUserSiteRoles func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, roles []string) error TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore] UserQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore] @@ -301,9 +301,11 @@ func New(options *Options) *API { options.TracerProvider = trace.NewNoopTracerProvider() } if options.SetUserGroups == nil { - options.SetUserGroups = func(ctx context.Context, logger slog.Logger, _ database.Store, userID uuid.UUID, groups []string, createMissingGroups bool) error { + options.SetUserGroups = func(ctx context.Context, logger slog.Logger, _ database.Store, userID uuid.UUID, orgGroupNames map[uuid.UUID][]string, createMissingGroups bool) error { logger.Warn(ctx, "attempted to assign OIDC groups without enterprise license", - slog.F("user_id", userID), slog.F("groups", groups), slog.F("create_missing_groups", createMissingGroups), + slog.F("user_id", userID), + slog.F("groups", orgGroupNames), + slog.F("create_missing_groups", createMissingGroups), ) return nil } diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index a6c6b34f2dafa..28d9a4fafb1e4 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -793,16 +793,6 @@ func (q *querier) DeleteGroupMemberFromGroup(ctx context.Context, arg database.D return update(q.log, q.auth, fetch, q.db.DeleteGroupMemberFromGroup)(ctx, arg) } -func (q *querier) DeleteGroupMembersByOrgAndUser(ctx context.Context, arg database.DeleteGroupMembersByOrgAndUserParams) error { - // This will remove the user from all groups in the org. This counts as updating a group. - // NOTE: instead of fetching all groups in the org with arg.UserID as a member, we instead - // check if the caller has permission to update any group in the org. - fetch := func(ctx context.Context, arg database.DeleteGroupMembersByOrgAndUserParams) (rbac.Objecter, error) { - return rbac.ResourceGroup.InOrg(arg.OrganizationID), nil - } - return update(q.log, q.auth, fetch, q.db.DeleteGroupMembersByOrgAndUser)(ctx, arg) -} - func (q *querier) DeleteLicense(ctx context.Context, id int32) (int32, error) { err := deleteQ(q.log, q.auth, q.db.GetLicenseByID, func(ctx context.Context, id int32) error { _, err := q.db.DeleteLicense(ctx, id) @@ -2555,6 +2545,14 @@ func (q *querier) RegisterWorkspaceProxy(ctx context.Context, arg database.Regis return updateWithReturn(q.log, q.auth, fetch, q.db.RegisterWorkspaceProxy)(ctx, arg) } +func (q *querier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error { + // This is a system function to clear user groups in group sync. + if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil { + return err + } + return q.db.RemoveUserFromAllGroups(ctx, userID) +} + func (q *querier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error { if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil { return err diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index c55b55a3d164d..207f4a64a9b78 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -344,17 +344,14 @@ func (s *MethodTestSuite) TestGroup() { GroupNames: slice.New(g1.Name, g2.Name), }).Asserts(rbac.ResourceGroup.InOrg(o.ID), rbac.ActionUpdate).Returns() })) - s.Run("DeleteGroupMembersByOrgAndUser", s.Subtest(func(db database.Store, check *expects) { + s.Run("RemoveUserFromAllGroups", s.Subtest(func(db database.Store, check *expects) { o := dbgen.Organization(s.T(), db, database.Organization{}) u1 := dbgen.User(s.T(), db, database.User{}) g1 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) g2 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) _ = dbgen.GroupMember(s.T(), db, database.GroupMember{GroupID: g1.ID, UserID: u1.ID}) _ = dbgen.GroupMember(s.T(), db, database.GroupMember{GroupID: g2.ID, UserID: u1.ID}) - check.Args(database.DeleteGroupMembersByOrgAndUserParams{ - OrganizationID: o.ID, - UserID: u1.ID, - }).Asserts(rbac.ResourceGroup.InOrg(o.ID), rbac.ActionUpdate).Returns() + check.Args(u1.ID).Asserts(rbac.ResourceSystem, rbac.ActionUpdate).Returns() })) s.Run("UpdateGroupByID", s.Subtest(func(db database.Store, check *expects) { g := dbgen.Group(s.T(), db, database.Group{}) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index ae0a0d7e48d33..5c837130b6a52 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -1135,36 +1135,6 @@ func (q *FakeQuerier) DeleteGroupMemberFromGroup(_ context.Context, arg database return nil } -func (q *FakeQuerier) DeleteGroupMembersByOrgAndUser(_ context.Context, arg database.DeleteGroupMembersByOrgAndUserParams) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - newMembers := q.groupMembers[:0] - for _, member := range q.groupMembers { - if member.UserID != arg.UserID { - // Do not delete the other members - newMembers = append(newMembers, member) - } else if member.UserID == arg.UserID { - // We only want to delete from groups in the organization in the args. - for _, group := range q.groups { - // Find the group that the member is apartof. - if group.ID == member.GroupID { - // Only add back the member if the organization ID does not match - // the arg organization ID. Since the arg is saying which - // org to delete. - if group.OrganizationID != arg.OrganizationID { - newMembers = append(newMembers, member) - } - break - } - } - } - } - q.groupMembers = newMembers - - return nil -} - func (q *FakeQuerier) DeleteLicense(_ context.Context, id int32) (int32, error) { q.mutex.Lock() defer q.mutex.Unlock() @@ -6096,6 +6066,22 @@ func (q *FakeQuerier) RegisterWorkspaceProxy(_ context.Context, arg database.Reg return database.WorkspaceProxy{}, sql.ErrNoRows } +func (q *FakeQuerier) RemoveUserFromAllGroups(_ context.Context, userID uuid.UUID) error { + q.mutex.Lock() + defer q.mutex.Unlock() + + newMembers := q.groupMembers[:0] + for _, member := range q.groupMembers { + if member.UserID == userID { + continue + } + newMembers = append(newMembers, member) + } + q.groupMembers = newMembers + + return nil +} + func (q *FakeQuerier) RevokeDBCryptKey(_ context.Context, activeKeyDigest string) error { q.mutex.Lock() defer q.mutex.Unlock() diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index b07b7b0305d9c..11d0d275920ac 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -211,13 +211,6 @@ func (m metricsStore) DeleteGroupMemberFromGroup(ctx context.Context, arg databa return err } -func (m metricsStore) DeleteGroupMembersByOrgAndUser(ctx context.Context, arg database.DeleteGroupMembersByOrgAndUserParams) error { - start := time.Now() - err := m.s.DeleteGroupMembersByOrgAndUser(ctx, arg) - m.queryLatencies.WithLabelValues("DeleteGroupMembersByOrgAndUser").Observe(time.Since(start).Seconds()) - return err -} - func (m metricsStore) DeleteLicense(ctx context.Context, id int32) (int32, error) { start := time.Now() licenseID, err := m.s.DeleteLicense(ctx, id) @@ -1642,6 +1635,13 @@ func (m metricsStore) RegisterWorkspaceProxy(ctx context.Context, arg database.R return proxy, err } +func (m metricsStore) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error { + start := time.Now() + r0 := m.s.RemoveUserFromAllGroups(ctx, userID) + m.queryLatencies.WithLabelValues("RemoveUserFromAllGroups").Observe(time.Since(start).Seconds()) + return r0 +} + func (m metricsStore) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error { start := time.Now() r0 := m.s.RevokeDBCryptKey(ctx, activeKeyDigest) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index cbe91468c2a6d..1ec23fbc970f6 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -313,20 +313,6 @@ func (mr *MockStoreMockRecorder) DeleteGroupMemberFromGroup(arg0, arg1 any) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteGroupMemberFromGroup", reflect.TypeOf((*MockStore)(nil).DeleteGroupMemberFromGroup), arg0, arg1) } -// DeleteGroupMembersByOrgAndUser mocks base method. -func (m *MockStore) DeleteGroupMembersByOrgAndUser(arg0 context.Context, arg1 database.DeleteGroupMembersByOrgAndUserParams) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteGroupMembersByOrgAndUser", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteGroupMembersByOrgAndUser indicates an expected call of DeleteGroupMembersByOrgAndUser. -func (mr *MockStoreMockRecorder) DeleteGroupMembersByOrgAndUser(arg0, arg1 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteGroupMembersByOrgAndUser", reflect.TypeOf((*MockStore)(nil).DeleteGroupMembersByOrgAndUser), arg0, arg1) -} - // DeleteLicense mocks base method. func (m *MockStore) DeleteLicense(arg0 context.Context, arg1 int32) (int32, error) { m.ctrl.T.Helper() @@ -3470,6 +3456,20 @@ func (mr *MockStoreMockRecorder) RegisterWorkspaceProxy(arg0, arg1 any) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterWorkspaceProxy", reflect.TypeOf((*MockStore)(nil).RegisterWorkspaceProxy), arg0, arg1) } +// RemoveUserFromAllGroups mocks base method. +func (m *MockStore) RemoveUserFromAllGroups(arg0 context.Context, arg1 uuid.UUID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoveUserFromAllGroups", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// RemoveUserFromAllGroups indicates an expected call of RemoveUserFromAllGroups. +func (mr *MockStoreMockRecorder) RemoveUserFromAllGroups(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveUserFromAllGroups", reflect.TypeOf((*MockStore)(nil).RemoveUserFromAllGroups), arg0, arg1) +} + // RevokeDBCryptKey mocks base method. func (m *MockStore) RevokeDBCryptKey(arg0 context.Context, arg1 string) error { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 4b459e3141216..00353daaef876 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -58,7 +58,6 @@ type sqlcQuerier interface { DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error DeleteGroupByID(ctx context.Context, id uuid.UUID) error DeleteGroupMemberFromGroup(ctx context.Context, arg DeleteGroupMemberFromGroupParams) error - DeleteGroupMembersByOrgAndUser(ctx context.Context, arg DeleteGroupMembersByOrgAndUserParams) error DeleteLicense(ctx context.Context, id int32) (int32, error) DeleteOAuth2ProviderAppByID(ctx context.Context, id uuid.UUID) error DeleteOAuth2ProviderAppSecretByID(ctx context.Context, id uuid.UUID) error @@ -322,6 +321,7 @@ type sqlcQuerier interface { InsertWorkspaceResourceMetadata(ctx context.Context, arg InsertWorkspaceResourceMetadataParams) ([]WorkspaceResourceMetadatum, error) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]WorkspaceAgentPortShare, error) RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error) + RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error // Non blocking lock. Returns true if the lock was acquired, false otherwise. // diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 002da316cbccf..5e9577f264f5e 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1288,24 +1288,6 @@ func (q *sqlQuerier) DeleteGroupMemberFromGroup(ctx context.Context, arg DeleteG return err } -const deleteGroupMembersByOrgAndUser = `-- name: DeleteGroupMembersByOrgAndUser :exec -DELETE FROM - group_members -WHERE - group_members.user_id = $1 - AND group_id = ANY(SELECT id FROM groups WHERE organization_id = $2) -` - -type DeleteGroupMembersByOrgAndUserParams struct { - UserID uuid.UUID `db:"user_id" json:"user_id"` - OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` -} - -func (q *sqlQuerier) DeleteGroupMembersByOrgAndUser(ctx context.Context, arg DeleteGroupMembersByOrgAndUserParams) error { - _, err := q.db.ExecContext(ctx, deleteGroupMembersByOrgAndUser, arg.UserID, arg.OrganizationID) - return err -} - const getGroupMembers = `-- name: GetGroupMembers :many SELECT users.id, users.email, users.username, users.hashed_password, users.created_at, users.updated_at, users.status, users.rbac_roles, users.login_type, users.avatar_url, users.deleted, users.last_seen_at, users.quiet_hours_schedule, users.theme_preference, users.name @@ -1419,6 +1401,18 @@ func (q *sqlQuerier) InsertUserGroupsByName(ctx context.Context, arg InsertUserG return err } +const removeUserFromAllGroups = `-- name: RemoveUserFromAllGroups :exec +DELETE FROM + group_members +WHERE + user_id = $1 +` + +func (q *sqlQuerier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error { + _, err := q.db.ExecContext(ctx, removeUserFromAllGroups, userID) + return err +} + const deleteGroupByID = `-- name: DeleteGroupByID :exec DELETE FROM groups diff --git a/coderd/database/queries/groupmembers.sql b/coderd/database/queries/groupmembers.sql index 4999df7930044..d755212132383 100644 --- a/coderd/database/queries/groupmembers.sql +++ b/coderd/database/queries/groupmembers.sql @@ -42,12 +42,11 @@ SELECT FROM groups; --- name: DeleteGroupMembersByOrgAndUser :exec +-- name: RemoveUserFromAllGroups :exec DELETE FROM group_members WHERE - group_members.user_id = @user_id - AND group_id = ANY(SELECT id FROM groups WHERE organization_id = @organization_id); + user_id = @user_id; -- name: InsertGroupMember :exec INSERT INTO diff --git a/coderd/userauth.go b/coderd/userauth.go index 3b83d1ed696e1..188a877e51055 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -20,6 +20,7 @@ import ( "github.com/google/go-github/v43/github" "github.com/google/uuid" "github.com/moby/moby/pkg/namesgenerator" + "golang.org/x/exp/slices" "golang.org/x/oauth2" "golang.org/x/xerrors" @@ -1217,8 +1218,10 @@ type oauthLoginParams struct { // to the Groups provided. UsingGroups bool CreateMissingGroups bool - Groups []string - GroupFilter *regexp.Regexp + // These are the group names from the IDP. Internally, they will map to + // some organization groups. + Groups []string + GroupFilter *regexp.Regexp // Is UsingRoles is true, then the user will be assigned // the roles provided. UsingRoles bool @@ -1301,7 +1304,6 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C link database.UserLink err error ) - user = params.User link = params.Link @@ -1460,6 +1462,9 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C } // Ensure groups are correct. + // This places all groups into the default organization. + // To go multi-org, we need to add a mapping feature here to know which + // groups go to which orgs. if params.UsingGroups { filtered := params.Groups if params.GroupFilter != nil { @@ -1471,8 +1476,32 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C } } + //nolint:gocritic // No user present in the context. + defaultOrganization, err := tx.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx)) + if err != nil { + // If there is no default org, then we can't assign groups. + // By default, we assume all groups belong to the default org. + return xerrors.Errorf("get default organization: %w", err) + } + + //nolint:gocritic // No user present in the context. + memberships, err := tx.GetOrganizationMembershipsByUserID(dbauthz.AsSystemRestricted(ctx), user.ID) + if err != nil { + return xerrors.Errorf("get organization memberships: %w", err) + } + + // If the user is not in the default organization, then we can't assign groups. + // A user cannot be in groups to an org they are not a member of. + if !slices.ContainsFunc(memberships, func(member database.OrganizationMember) bool { + return member.OrganizationID == defaultOrganization.ID + }) { + return xerrors.Errorf("user %s is not a member of the default organization, cannot assign to groups in the org", user.ID) + } + //nolint:gocritic - err := api.Options.SetUserGroups(dbauthz.AsSystemRestricted(ctx), logger, tx, user.ID, filtered, params.CreateMissingGroups) + err = api.Options.SetUserGroups(dbauthz.AsSystemRestricted(ctx), logger, tx, user.ID, map[uuid.UUID][]string{ + defaultOrganization.ID: filtered, + }, params.CreateMissingGroups) if err != nil { return xerrors.Errorf("set user groups: %w", err) } diff --git a/enterprise/coderd/userauth.go b/enterprise/coderd/userauth.go index f504a6c0325c4..f35d38ca448d9 100644 --- a/enterprise/coderd/userauth.go +++ b/enterprise/coderd/userauth.go @@ -14,7 +14,7 @@ import ( ) // nolint: revive -func (api *API) setUserGroups(ctx context.Context, logger slog.Logger, db database.Store, userID uuid.UUID, groupNames []string, createMissingGroups bool) error { +func (api *API) setUserGroups(ctx context.Context, logger slog.Logger, db database.Store, userID uuid.UUID, orgGroupNames map[uuid.UUID][]string, createMissingGroups bool) error { api.entitlementsMu.RLock() enabled := api.entitlements.Features[codersdk.FeatureTemplateRBAC].Enabled api.entitlementsMu.RUnlock() @@ -24,6 +24,8 @@ func (api *API) setUserGroups(ctx context.Context, logger slog.Logger, db databa } return db.InTx(func(tx database.Store) error { + // When setting the user's groups, it's easier to just clear their groups and re-add them. + // This ensures that the user's groups are always in sync with the auth provider. orgs, err := tx.GetOrganizationsByUserID(ctx, userID) if err != nil { return xerrors.Errorf("get user orgs: %w", err) @@ -33,43 +35,49 @@ func (api *API) setUserGroups(ctx context.Context, logger slog.Logger, db databa } // Delete all groups the user belongs to. - err = tx.DeleteGroupMembersByOrgAndUser(ctx, database.DeleteGroupMembersByOrgAndUserParams{ - UserID: userID, - OrganizationID: orgs[0].ID, - }) + // nolint:gocritic // Requires system context to remove user from all groups. + err = tx.RemoveUserFromAllGroups(dbauthz.AsSystemRestricted(ctx), userID) if err != nil { return xerrors.Errorf("delete user groups: %w", err) } - if createMissingGroups { - // This is the system creating these additional groups, so we use the system restricted context. - // nolint:gocritic - created, err := tx.InsertMissingGroups(dbauthz.AsSystemRestricted(ctx), database.InsertMissingGroupsParams{ - OrganizationID: orgs[0].ID, + // TODO: This could likely be improved by making these single queries. + // Either by batching or some other means. This for loop could be really + // inefficient if there are a lot of organizations. There was deployments + // on v1 with >100 orgs. + for orgID, groupNames := range orgGroupNames { + // Create the missing groups for each organization. + if createMissingGroups { + // This is the system creating these additional groups, so we use the system restricted context. + // nolint:gocritic + created, err := tx.InsertMissingGroups(dbauthz.AsSystemRestricted(ctx), database.InsertMissingGroupsParams{ + OrganizationID: orgID, + GroupNames: groupNames, + Source: database.GroupSourceOidc, + }) + if err != nil { + return xerrors.Errorf("insert missing groups: %w", err) + } + if len(created) > 0 { + logger.Debug(ctx, "auto created missing groups", + slog.F("org_id", orgID.ID), + slog.F("created", created), + slog.F("num", len(created)), + ) + } + } + + // Re-add the user to all groups returned by the auth provider. + err = tx.InsertUserGroupsByName(ctx, database.InsertUserGroupsByNameParams{ + UserID: userID, + OrganizationID: orgID, GroupNames: groupNames, - Source: database.GroupSourceOidc, }) if err != nil { - return xerrors.Errorf("insert missing groups: %w", err) - } - if len(created) > 0 { - logger.Debug(ctx, "auto created missing groups", - slog.F("org_id", orgs[0].ID), - slog.F("created", created), - ) + return xerrors.Errorf("insert user groups: %w", err) } } - // Re-add the user to all groups returned by the auth provider. - err = tx.InsertUserGroupsByName(ctx, database.InsertUserGroupsByNameParams{ - UserID: userID, - OrganizationID: orgs[0].ID, - GroupNames: groupNames, - }) - if err != nil { - return xerrors.Errorf("insert user groups: %w", err) - } - return nil }, nil) }