Skip to content

Commit da6c21c

Browse files
committed
feat: set groupsync to use default org
1 parent bf35196 commit da6c21c

File tree

11 files changed

+139
-123
lines changed

11 files changed

+139
-123
lines changed

coderd/coderd.go

+5-3
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ type Options struct {
134134
BaseDERPMap *tailcfg.DERPMap
135135
DERPMapUpdateFrequency time.Duration
136136
SwaggerEndpoint bool
137-
SetUserGroups func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, groupNames []string, createMissingGroups bool) error
137+
SetUserGroups func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, orgGroupNames map[uuid.UUID][]string, createMissingGroups bool) error
138138
SetUserSiteRoles func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, roles []string) error
139139
TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore]
140140
UserQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore]
@@ -301,9 +301,11 @@ func New(options *Options) *API {
301301
options.TracerProvider = trace.NewNoopTracerProvider()
302302
}
303303
if options.SetUserGroups == nil {
304-
options.SetUserGroups = func(ctx context.Context, logger slog.Logger, _ database.Store, userID uuid.UUID, groups []string, createMissingGroups bool) error {
304+
options.SetUserGroups = func(ctx context.Context, logger slog.Logger, _ database.Store, userID uuid.UUID, orgGroupNames map[uuid.UUID][]string, createMissingGroups bool) error {
305305
logger.Warn(ctx, "attempted to assign OIDC groups without enterprise license",
306-
slog.F("user_id", userID), slog.F("groups", groups), slog.F("create_missing_groups", createMissingGroups),
306+
slog.F("user_id", userID),
307+
slog.F("groups", orgGroupNames),
308+
slog.F("create_missing_groups", createMissingGroups),
307309
)
308310
return nil
309311
}

coderd/database/dbauthz/dbauthz.go

+8-10
Original file line numberDiff line numberDiff line change
@@ -793,16 +793,6 @@ func (q *querier) DeleteGroupMemberFromGroup(ctx context.Context, arg database.D
793793
return update(q.log, q.auth, fetch, q.db.DeleteGroupMemberFromGroup)(ctx, arg)
794794
}
795795

796-
func (q *querier) DeleteGroupMembersByOrgAndUser(ctx context.Context, arg database.DeleteGroupMembersByOrgAndUserParams) error {
797-
// This will remove the user from all groups in the org. This counts as updating a group.
798-
// NOTE: instead of fetching all groups in the org with arg.UserID as a member, we instead
799-
// check if the caller has permission to update any group in the org.
800-
fetch := func(ctx context.Context, arg database.DeleteGroupMembersByOrgAndUserParams) (rbac.Objecter, error) {
801-
return rbac.ResourceGroup.InOrg(arg.OrganizationID), nil
802-
}
803-
return update(q.log, q.auth, fetch, q.db.DeleteGroupMembersByOrgAndUser)(ctx, arg)
804-
}
805-
806796
func (q *querier) DeleteLicense(ctx context.Context, id int32) (int32, error) {
807797
err := deleteQ(q.log, q.auth, q.db.GetLicenseByID, func(ctx context.Context, id int32) error {
808798
_, err := q.db.DeleteLicense(ctx, id)
@@ -2549,6 +2539,14 @@ func (q *querier) RegisterWorkspaceProxy(ctx context.Context, arg database.Regis
25492539
return updateWithReturn(q.log, q.auth, fetch, q.db.RegisterWorkspaceProxy)(ctx, arg)
25502540
}
25512541

2542+
func (q *querier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error {
2543+
// This is a system function to clear user groups in group sync.
2544+
if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil {
2545+
return err
2546+
}
2547+
return q.db.RemoveUserFromAllGroups(ctx, userID)
2548+
}
2549+
25522550
func (q *querier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
25532551
if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil {
25542552
return err

coderd/database/dbauthz/dbauthz_test.go

+2-5
Original file line numberDiff line numberDiff line change
@@ -344,17 +344,14 @@ func (s *MethodTestSuite) TestGroup() {
344344
GroupNames: slice.New(g1.Name, g2.Name),
345345
}).Asserts(rbac.ResourceGroup.InOrg(o.ID), rbac.ActionUpdate).Returns()
346346
}))
347-
s.Run("DeleteGroupMembersByOrgAndUser", s.Subtest(func(db database.Store, check *expects) {
347+
s.Run("RemoveUserFromAllGroups", s.Subtest(func(db database.Store, check *expects) {
348348
o := dbgen.Organization(s.T(), db, database.Organization{})
349349
u1 := dbgen.User(s.T(), db, database.User{})
350350
g1 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
351351
g2 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
352352
_ = dbgen.GroupMember(s.T(), db, database.GroupMember{GroupID: g1.ID, UserID: u1.ID})
353353
_ = dbgen.GroupMember(s.T(), db, database.GroupMember{GroupID: g2.ID, UserID: u1.ID})
354-
check.Args(database.DeleteGroupMembersByOrgAndUserParams{
355-
OrganizationID: o.ID,
356-
UserID: u1.ID,
357-
}).Asserts(rbac.ResourceGroup.InOrg(o.ID), rbac.ActionUpdate).Returns()
354+
check.Args(u1.ID).Asserts(rbac.ResourceSystem, rbac.ActionUpdate).Returns()
358355
}))
359356
s.Run("UpdateGroupByID", s.Subtest(func(db database.Store, check *expects) {
360357
g := dbgen.Group(s.T(), db, database.Group{})

coderd/database/dbmem/dbmem.go

+16-30
Original file line numberDiff line numberDiff line change
@@ -1135,36 +1135,6 @@ func (q *FakeQuerier) DeleteGroupMemberFromGroup(_ context.Context, arg database
11351135
return nil
11361136
}
11371137

1138-
func (q *FakeQuerier) DeleteGroupMembersByOrgAndUser(_ context.Context, arg database.DeleteGroupMembersByOrgAndUserParams) error {
1139-
q.mutex.Lock()
1140-
defer q.mutex.Unlock()
1141-
1142-
newMembers := q.groupMembers[:0]
1143-
for _, member := range q.groupMembers {
1144-
if member.UserID != arg.UserID {
1145-
// Do not delete the other members
1146-
newMembers = append(newMembers, member)
1147-
} else if member.UserID == arg.UserID {
1148-
// We only want to delete from groups in the organization in the args.
1149-
for _, group := range q.groups {
1150-
// Find the group that the member is apartof.
1151-
if group.ID == member.GroupID {
1152-
// Only add back the member if the organization ID does not match
1153-
// the arg organization ID. Since the arg is saying which
1154-
// org to delete.
1155-
if group.OrganizationID != arg.OrganizationID {
1156-
newMembers = append(newMembers, member)
1157-
}
1158-
break
1159-
}
1160-
}
1161-
}
1162-
}
1163-
q.groupMembers = newMembers
1164-
1165-
return nil
1166-
}
1167-
11681138
func (q *FakeQuerier) DeleteLicense(_ context.Context, id int32) (int32, error) {
11691139
q.mutex.Lock()
11701140
defer q.mutex.Unlock()
@@ -6083,6 +6053,22 @@ func (q *FakeQuerier) RegisterWorkspaceProxy(_ context.Context, arg database.Reg
60836053
return database.WorkspaceProxy{}, sql.ErrNoRows
60846054
}
60856055

6056+
func (q *FakeQuerier) RemoveUserFromAllGroups(_ context.Context, userID uuid.UUID) error {
6057+
q.mutex.Lock()
6058+
defer q.mutex.Unlock()
6059+
6060+
newMembers := q.groupMembers[:0]
6061+
for _, member := range q.groupMembers {
6062+
if member.UserID == userID {
6063+
continue
6064+
}
6065+
newMembers = append(newMembers, member)
6066+
}
6067+
q.groupMembers = newMembers
6068+
6069+
return nil
6070+
}
6071+
60866072
func (q *FakeQuerier) RevokeDBCryptKey(_ context.Context, activeKeyDigest string) error {
60876073
q.mutex.Lock()
60886074
defer q.mutex.Unlock()

coderd/database/dbmetrics/dbmetrics.go

+7-7
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dbmock/dbmock.go

+14-14
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/querier.go

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries.sql.go

+12-18
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries/groupmembers.sql

+2-3
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,11 @@ SELECT
4242
FROM
4343
groups;
4444

45-
-- name: DeleteGroupMembersByOrgAndUser :exec
45+
-- name: RemoveUserFromAllGroups :exec
4646
DELETE FROM
4747
group_members
4848
WHERE
49-
group_members.user_id = @user_id
50-
AND group_id = ANY(SELECT id FROM groups WHERE organization_id = @organization_id);
49+
user_id = @user_id;
5150

5251
-- name: InsertGroupMember :exec
5352
INSERT INTO

coderd/userauth.go

+36-4
Original file line numberDiff line numberDiff line change
@@ -1217,8 +1217,10 @@ type oauthLoginParams struct {
12171217
// to the Groups provided.
12181218
UsingGroups bool
12191219
CreateMissingGroups bool
1220-
Groups []string
1221-
GroupFilter *regexp.Regexp
1220+
// These are the group names from the IDP. Internally, they will map to
1221+
// some organization groups.
1222+
Groups []string
1223+
GroupFilter *regexp.Regexp
12221224
// Is UsingRoles is true, then the user will be assigned
12231225
// the roles provided.
12241226
UsingRoles bool
@@ -1301,7 +1303,6 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
13011303
link database.UserLink
13021304
err error
13031305
)
1304-
13051306
user = params.User
13061307
link = params.Link
13071308

@@ -1457,6 +1458,9 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
14571458
}
14581459

14591460
// Ensure groups are correct.
1461+
// This places all groups into the default organization.
1462+
// To go multi-org, we need to add a mapping feature here to know which
1463+
// groups go to which orgs.
14601464
if params.UsingGroups {
14611465
filtered := params.Groups
14621466
if params.GroupFilter != nil {
@@ -1468,8 +1472,36 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
14681472
}
14691473
}
14701474

1475+
//nolint:gocritic // No user present in the context.
1476+
defaultOrganization, err := tx.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx))
1477+
if err != nil {
1478+
// If there is no default org, then we can't assign groups.
1479+
// By default, we assume all groups belong to the default org.
1480+
return xerrors.Errorf("get default organization: %w", err)
1481+
}
1482+
1483+
//nolint:gocritic // No user present in the context.
1484+
memberships, err := tx.GetOrganizationMembershipsByUserID(dbauthz.AsSystemRestricted(ctx), user.ID)
1485+
if err != nil {
1486+
return xerrors.Errorf("get organization memberships: %w", err)
1487+
}
1488+
1489+
inDefault := false
1490+
for _, membership := range memberships {
1491+
if membership.OrganizationID == defaultOrganization.ID {
1492+
inDefault = true
1493+
break
1494+
}
1495+
}
1496+
1497+
if !inDefault {
1498+
return xerrors.Errorf("user %s is not a member of the default organization, cannot assign to groups in the org", user.ID)
1499+
}
1500+
14711501
//nolint:gocritic
1472-
err := api.Options.SetUserGroups(dbauthz.AsSystemRestricted(ctx), logger, tx, user.ID, filtered, params.CreateMissingGroups)
1502+
err = api.Options.SetUserGroups(dbauthz.AsSystemRestricted(ctx), logger, tx, user.ID, map[uuid.UUID][]string{
1503+
defaultOrganization.ID: filtered,
1504+
}, params.CreateMissingGroups)
14731505
if err != nil {
14741506
return xerrors.Errorf("set user groups: %w", err)
14751507
}

0 commit comments

Comments
 (0)