Skip to content

Commit 6ffdea8

Browse files
committed
work on batching removal by name or id
1 parent 4798911 commit 6ffdea8

File tree

5 files changed

+130
-55
lines changed

5 files changed

+130
-55
lines changed

coderd/database/dbmem/dbmem.go

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,17 @@ func (q *FakeQuerier) getWorkspaceResourcesByJobIDNoLock(_ context.Context, jobI
683683
return resources, nil
684684
}
685685

686+
func (q *FakeQuerier) getGroupByNameNoLock(arg database.NameOrganizationPair) (database.Group, error) {
687+
for _, group := range q.groups {
688+
if group.OrganizationID == arg.OrganizationID &&
689+
group.Name == arg.Name {
690+
return group, nil
691+
}
692+
}
693+
694+
return database.Group{}, sql.ErrNoRows
695+
}
696+
686697
func (q *FakeQuerier) getGroupByIDNoLock(_ context.Context, id uuid.UUID) (database.Group, error) {
687698
for _, group := range q.groups {
688699
if group.ID == id {
@@ -2614,14 +2625,10 @@ func (q *FakeQuerier) GetGroupByOrgAndName(_ context.Context, arg database.GetGr
26142625
q.mutex.RLock()
26152626
defer q.mutex.RUnlock()
26162627

2617-
for _, group := range q.groups {
2618-
if group.OrganizationID == arg.OrganizationID &&
2619-
group.Name == arg.Name {
2620-
return group, nil
2621-
}
2622-
}
2623-
2624-
return database.Group{}, sql.ErrNoRows
2628+
return q.getGroupByNameNoLock(database.NameOrganizationPair{
2629+
Name: arg.Name,
2630+
OrganizationID: arg.OrganizationID,
2631+
})
26252632
}
26262633

26272634
func (q *FakeQuerier) GetGroupMembers(ctx context.Context) ([]database.GroupMember, error) {
@@ -7648,14 +7655,24 @@ func (q *FakeQuerier) RemoveUserFromGroups(_ context.Context, arg database.Remov
76487655

76497656
removed := make([]uuid.UUID, 0)
76507657
q.data.groupMembers = slices.DeleteFunc(q.data.groupMembers, func(groupMember database.GroupMemberTable) bool {
7658+
// Delete all group members that match the arguments.
76517659
if groupMember.UserID != arg.UserID {
7660+
// Not the right user, ignore.
76527661
return false
76537662
}
7654-
if !slices.Contains(arg.GroupIds, groupMember.GroupID) {
7655-
return false
7663+
7664+
matchesByID := slices.Contains(arg.GroupIds, groupMember.GroupID)
7665+
matchesByName := slices.ContainsFunc(arg.GroupNames, func(name database.NameOrganizationPair) bool {
7666+
_, err := q.getGroupByNameNoLock(name)
7667+
return err == nil
7668+
})
7669+
7670+
if matchesByName || matchesByID {
7671+
removed = append(removed, groupMember.GroupID)
7672+
return true
76567673
}
7657-
removed = append(removed, groupMember.GroupID)
7658-
return true
7674+
7675+
return false
76597676
})
76607677

76617678
return removed, nil

coderd/database/queries.sql.go

Lines changed: 15 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries/groupmembers.sql

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,19 @@ WHERE
5757
-- name: RemoveUserFromGroups :many
5858
DELETE FROM
5959
group_members
60+
USING groups
6061
WHERE
62+
group_members.group_id = groups.id AND
6163
user_id = @user_id AND
62-
group_id = ANY(@group_ids :: uuid [])
64+
(
65+
CASE WHEN array_length(@group_names :: name_organization_pair[], 1) > 0 THEN
66+
-- Using 'coalesce' to avoid troubles with null literals being an empty string.
67+
(groups.name, coalesce(groups.organization_id, '00000000-0000-0000-0000-000000000000' ::uuid)) = ANY (@group_names::name_organization_pair[])
68+
ELSE false
69+
END
70+
OR
71+
group_id = ANY (@group_ids :: uuid[])
72+
)
6373
RETURNING group_id;
6474

6575
-- name: InsertGroupMember :exec

coderd/idpsync/group.go

Lines changed: 71 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,8 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat
9191
}
9292

9393
// collect all diffs to do 1 sql update for all orgs
94-
groupsToAdd := make([]uuid.UUID, 0)
95-
groupsToRemove := make([]uuid.UUID, 0)
94+
groupIDsToAdd := make([]uuid.UUID, 0)
95+
groupsToRemove := make([]ExpectedGroup, 0)
9696
// For each org, determine which groups the user should land in
9797
for orgID, settings := range orgSettings {
9898
if settings.GroupField == "" {
@@ -112,7 +112,8 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat
112112
}
113113
// Everyone group is always implied.
114114
expectedGroups = append(expectedGroups, ExpectedGroup{
115-
GroupID: &orgID,
115+
OrganizationID: orgID,
116+
GroupID: &orgID,
116117
})
117118

118119
// Now we know what groups the user should be in for a given org,
@@ -121,8 +122,9 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat
121122
existingGroups := userOrgs[orgID]
122123
existingGroupsTyped := db2sdk.List(existingGroups, func(f database.GetGroupsRow) ExpectedGroup {
123124
return ExpectedGroup{
124-
GroupID: &f.Group.ID,
125-
GroupName: &f.Group.Name,
125+
OrganizationID: orgID,
126+
GroupID: &f.Group.ID,
127+
GroupName: &f.Group.Name,
126128
}
127129
})
128130
add, remove := slice.SymmetricDifferenceFunc(existingGroupsTyped, expectedGroups, func(a, b ExpectedGroup) bool {
@@ -144,52 +146,75 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat
144146
return xerrors.Errorf("handle missing groups: %w", err)
145147
}
146148

147-
for _, removeGroup := range remove {
148-
// This should always be the case.
149-
// TODO: make sure this is always the case
150-
if removeGroup.GroupID != nil {
151-
groupsToRemove = append(groupsToRemove, *removeGroup.GroupID)
152-
}
153-
}
149+
groupsToRemove = append(groupsToRemove, remove...)
150+
groupIDsToAdd = append(groupIDsToAdd, assignGroups...)
151+
}
154152

155-
groupsToAdd = append(groupsToAdd, assignGroups...)
153+
err = s.applyGroupDifference(ctx, tx, user, groupIDsToAdd, groupsToRemove)
154+
if err != nil {
155+
return xerrors.Errorf("apply group difference: %w", err)
156156
}
157157

158-
assignedGroupIDs, err := tx.InsertUserGroupsByID(ctx, database.InsertUserGroupsByIDParams{
159-
UserID: user.ID,
160-
GroupIds: groupsToAdd,
158+
return nil
159+
}, nil)
160+
161+
if err != nil {
162+
return err
163+
}
164+
165+
return nil
166+
}
167+
168+
func (s AGPLIDPSync) applyGroupDifference(ctx context.Context, tx database.Store, user database.User, add []uuid.UUID, remove []ExpectedGroup) error {
169+
// Always do group removal before group add. This way if there is an error,
170+
// we error on the underprivileged side.
171+
removeIDs := make([]uuid.UUID, 0)
172+
removeNames := make([]database.NameOrganizationPair, 0)
173+
for _, r := range remove {
174+
if r.GroupID != nil {
175+
removeIDs = append(removeIDs, *r.GroupID)
176+
} else if r.GroupName != nil {
177+
removeNames = append(removeNames, database.NameOrganizationPair{
178+
Name: *r.GroupName,
179+
OrganizationID: r.OrganizationID,
180+
})
181+
}
182+
}
183+
184+
// If there is something to remove, do it.
185+
if len(removeIDs) > 0 || len(removeNames) > 0 {
186+
removedGroupIDs, err := tx.RemoveUserFromGroups(ctx, database.RemoveUserFromGroupsParams{
187+
UserID: user.ID,
188+
GroupNames: removeNames,
189+
GroupIds: removeIDs,
161190
})
162191
if err != nil {
163-
return xerrors.Errorf("insert user into %d groups: %w", len(groupsToAdd), err)
192+
return xerrors.Errorf("remove user from %d groups: %w", len(removeIDs), err)
164193
}
165-
if len(assignedGroupIDs) != len(groupsToAdd) {
166-
s.Logger.Debug(ctx, "failed to assign all groups to user",
194+
if len(removedGroupIDs) != len(removeIDs) {
195+
s.Logger.Debug(ctx, "failed to remove user from all groups",
167196
slog.F("user_id", user.ID),
168-
slog.F("groups_assigned_count", len(assignedGroupIDs)),
169-
slog.F("expected_count", len(groupsToAdd)),
197+
slog.F("groups_removed_count", len(removedGroupIDs)),
198+
slog.F("expected_count", len(removeIDs)),
170199
)
171200
}
201+
}
172202

173-
removedGroupIDs, err := tx.RemoveUserFromGroups(ctx, database.RemoveUserFromGroupsParams{
203+
if len(add) > 0 {
204+
assignedGroupIDs, err := tx.InsertUserGroupsByID(ctx, database.InsertUserGroupsByIDParams{
174205
UserID: user.ID,
175-
GroupIds: groupsToRemove,
206+
GroupIds: add,
176207
})
177208
if err != nil {
178-
return xerrors.Errorf("remove user from %d groups: %w", len(groupsToRemove), err)
209+
return xerrors.Errorf("insert user into %d groups: %w", len(add), err)
179210
}
180-
if len(removedGroupIDs) != len(groupsToRemove) {
181-
s.Logger.Debug(ctx, "failed to remove user from all groups",
211+
if len(assignedGroupIDs) != len(add) {
212+
s.Logger.Debug(ctx, "failed to assign all groups to user",
182213
slog.F("user_id", user.ID),
183-
slog.F("groups_removed_count", len(removedGroupIDs)),
184-
slog.F("expected_count", len(groupsToRemove)),
214+
slog.F("groups_assigned_count", len(assignedGroupIDs)),
215+
slog.F("expected_count", len(add)),
185216
)
186217
}
187-
188-
return nil
189-
}, nil)
190-
191-
if err != nil {
192-
return err
193218
}
194219

195220
return nil
@@ -226,8 +251,9 @@ func (s *GroupSyncSettings) Type() string {
226251
}
227252

228253
type ExpectedGroup struct {
229-
GroupID *uuid.UUID
230-
GroupName *string
254+
OrganizationID uuid.UUID
255+
GroupID *uuid.UUID
256+
GroupName *string
231257
}
232258

233259
// ParseClaims will take the merged claims from the IDP and return the groups
@@ -280,20 +306,28 @@ func (s GroupSyncSettings) ParseClaims(mergedClaims jwt.MapClaims) ([]ExpectedGr
280306
return groups, nil
281307
}
282308

309+
// HandleMissingGroups ensures all ExpectedGroups convert to uuids.
310+
// Groups can be referenced by name via legacy params or IDP group names.
311+
// These group names are converted to IDs for easier assignment.
312+
// Missing groups are created if AutoCreate is enabled.
313+
// TODO: Batching this would be better, as this is 1 or 2 db calls per organization.
283314
func (s GroupSyncSettings) HandleMissingGroups(ctx context.Context, tx database.Store, orgID uuid.UUID, add []ExpectedGroup) ([]uuid.UUID, error) {
284315
if !s.AutoCreateMissingGroups {
285-
// construct the list of groups to search by name to see if they exist.
316+
// If we are not creating groups, then just construct a db lookup for
317+
// all groups by name.
286318
var lookups []string
287319
filter := make([]uuid.UUID, 0)
288320
for _, expected := range add {
289321
if expected.GroupID != nil {
322+
// Groups with IDs are easy!
290323
filter = append(filter, *expected.GroupID)
291324
} else if expected.GroupName != nil {
292325
lookups = append(lookups, *expected.GroupName)
293326
}
294327
}
295328

296329
if len(lookups) > 0 {
330+
// Do name lookups for all groups that are missing IDs.
297331
newGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{
298332
OrganizationID: uuid.UUID{},
299333
HasMemberID: uuid.UUID{},

coderd/idpsync/group_test.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,9 +208,12 @@ func TestGroupSyncTable(t *testing.T) {
208208
},
209209
AutoCreateMissingGroups: true,
210210
},
211-
Groups: map[uuid.UUID]bool{},
211+
Groups: map[uuid.UUID]bool{
212+
ids.ID("lg-foo"): true,
213+
},
212214
GroupNames: map[string]bool{
213215
"legacy-foo": false,
216+
"extra": true,
214217
},
215218
ExpectedGroupNames: []string{
216219
"legacy-bar",

0 commit comments

Comments
 (0)