Skip to content

Commit b379725

Browse files
committed
group sync adjustments
1 parent 6ffdea8 commit b379725

File tree

5 files changed

+63
-113
lines changed

5 files changed

+63
-113
lines changed

coderd/database/dbmem/dbmem.go

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2731,6 +2731,10 @@ func (q *FakeQuerier) GetGroups(_ context.Context, arg database.GetGroupsParams)
27312731
continue
27322732
}
27332733

2734+
if len(arg.GroupNames) > 0 && !slices.Contains(arg.GroupNames, group.Name) {
2735+
continue
2736+
}
2737+
27342738
orgDetails, ok := orgDetailsCache[group.ID]
27352739
if !ok {
27362740
for _, org := range q.organizations {
@@ -7661,18 +7665,12 @@ func (q *FakeQuerier) RemoveUserFromGroups(_ context.Context, arg database.Remov
76617665
return false
76627666
}
76637667

7664-
matchesByID := slices.Contains(arg.GroupIds, groupMember.GroupID)
7665-
matchesByName := slices.ContainsFunc(arg.GroupNames, func(name database.NameOrganizationPair) bool {
7666-
_, err := q.getGroupByNameNoLock(name)
7667-
return err == nil
7668-
})
7669-
7670-
if matchesByName || matchesByID {
7671-
removed = append(removed, groupMember.GroupID)
7672-
return true
7668+
if !slices.Contains(arg.GroupIds, groupMember.GroupID) {
7669+
return false
76737670
}
76747671

7675-
return false
7672+
removed = append(removed, groupMember.GroupID)
7673+
return true
76767674
})
76777675

76787676
return removed, nil

coderd/database/queries.sql.go

Lines changed: 4 additions & 15 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries/groupmembers.sql

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -57,19 +57,9 @@ WHERE
5757
-- name: RemoveUserFromGroups :many
5858
DELETE FROM
5959
group_members
60-
USING groups
6160
WHERE
62-
group_members.group_id = groups.id AND
6361
user_id = @user_id AND
64-
(
65-
CASE WHEN array_length(@group_names :: name_organization_pair[], 1) > 0 THEN
66-
-- Using 'coalesce' to avoid troubles with null literals being an empty string.
67-
(groups.name, coalesce(groups.organization_id, '00000000-0000-0000-0000-000000000000' ::uuid)) = ANY (@group_names::name_organization_pair[])
68-
ELSE false
69-
END
70-
OR
71-
group_id = ANY (@group_ids :: uuid[])
72-
)
62+
group_id = ANY(@group_ids :: uuid [])
7363
RETURNING group_id;
7464

7565
-- name: InsertGroupMember :exec

coderd/idpsync/group.go

Lines changed: 48 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package idpsync
33
import (
44
"context"
55
"encoding/json"
6+
"fmt"
67
"regexp"
78

89
"github.com/golang-jwt/jwt/v4"
@@ -92,15 +93,15 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat
9293

9394
// collect all diffs to do 1 sql update for all orgs
9495
groupIDsToAdd := make([]uuid.UUID, 0)
95-
groupsToRemove := make([]ExpectedGroup, 0)
96+
groupIDsToRemove := make([]uuid.UUID, 0)
9697
// For each org, determine which groups the user should land in
9798
for orgID, settings := range orgSettings {
9899
if settings.GroupField == "" {
99100
// No group sync enabled for this org, so do nothing.
100101
continue
101102
}
102103

103-
expectedGroups, err := settings.ParseClaims(params.MergedClaims)
104+
expectedGroups, err := settings.ParseClaims(orgID, params.MergedClaims)
104105
if err != nil {
105106
s.Logger.Debug(ctx, "failed to parse claims for groups",
106107
slog.F("organization_field", s.GroupField),
@@ -128,6 +129,10 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat
128129
}
129130
})
130131
add, remove := slice.SymmetricDifferenceFunc(existingGroupsTyped, expectedGroups, func(a, b ExpectedGroup) bool {
132+
// Must match
133+
if a.OrganizationID != b.OrganizationID {
134+
return false
135+
}
131136
// Only the name or the name needs to be checked, priority is given to the ID.
132137
if a.GroupID != nil && b.GroupID != nil {
133138
return *a.GroupID == *b.GroupID
@@ -138,6 +143,20 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat
138143
return false
139144
})
140145

146+
for _, r := range remove {
147+
// This should never happen. All group removals come from the
148+
// existing set, which come from the db. All groups from the
149+
// database have IDs. This code is purely defensive.
150+
if r.GroupID == nil {
151+
detail := "user:" + user.Username
152+
if r.GroupName != nil {
153+
detail += fmt.Sprintf(" from group %s", *r.GroupName)
154+
}
155+
return xerrors.Errorf("removal group has nil ID, which should never happen: %s", detail)
156+
}
157+
groupIDsToRemove = append(groupIDsToRemove, *r.GroupID)
158+
}
159+
141160
// HandleMissingGroups will add the new groups to the org if
142161
// the settings specify. It will convert all group names into uuids
143162
// for easier assignment.
@@ -146,11 +165,10 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat
146165
return xerrors.Errorf("handle missing groups: %w", err)
147166
}
148167

149-
groupsToRemove = append(groupsToRemove, remove...)
150168
groupIDsToAdd = append(groupIDsToAdd, assignGroups...)
151169
}
152170

153-
err = s.applyGroupDifference(ctx, tx, user, groupIDsToAdd, groupsToRemove)
171+
err = s.applyGroupDifference(ctx, tx, user, groupIDsToAdd, groupIDsToRemove)
154172
if err != nil {
155173
return xerrors.Errorf("apply group difference: %w", err)
156174
}
@@ -165,28 +183,13 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat
165183
return nil
166184
}
167185

168-
func (s AGPLIDPSync) applyGroupDifference(ctx context.Context, tx database.Store, user database.User, add []uuid.UUID, remove []ExpectedGroup) error {
186+
func (s AGPLIDPSync) applyGroupDifference(ctx context.Context, tx database.Store, user database.User, add []uuid.UUID, removeIDs []uuid.UUID) error {
169187
// Always do group removal before group add. This way if there is an error,
170188
// we error on the underprivileged side.
171-
removeIDs := make([]uuid.UUID, 0)
172-
removeNames := make([]database.NameOrganizationPair, 0)
173-
for _, r := range remove {
174-
if r.GroupID != nil {
175-
removeIDs = append(removeIDs, *r.GroupID)
176-
} else if r.GroupName != nil {
177-
removeNames = append(removeNames, database.NameOrganizationPair{
178-
Name: *r.GroupName,
179-
OrganizationID: r.OrganizationID,
180-
})
181-
}
182-
}
183-
184-
// If there is something to remove, do it.
185-
if len(removeIDs) > 0 || len(removeNames) > 0 {
189+
if len(removeIDs) > 0 {
186190
removedGroupIDs, err := tx.RemoveUserFromGroups(ctx, database.RemoveUserFromGroupsParams{
187-
UserID: user.ID,
188-
GroupNames: removeNames,
189-
GroupIds: removeIDs,
191+
UserID: user.ID,
192+
GroupIds: removeIDs,
190193
})
191194
if err != nil {
192195
return xerrors.Errorf("remove user from %d groups: %w", len(removeIDs), err)
@@ -264,7 +267,7 @@ type ExpectedGroup struct {
264267
// the group "UUID 1234" is renamed, we want to maintain the mapping.
265268
// We have to keep names because group sync supports syncing groups by name if
266269
// the external IDP group name matches the Coder one.
267-
func (s GroupSyncSettings) ParseClaims(mergedClaims jwt.MapClaims) ([]ExpectedGroup, error) {
270+
func (s GroupSyncSettings) ParseClaims(orgID uuid.UUID, mergedClaims jwt.MapClaims) ([]ExpectedGroup, error) {
268271
groupsRaw, ok := mergedClaims[s.GroupField]
269272
if !ok {
270273
return []ExpectedGroup{}, nil
@@ -294,13 +297,13 @@ func (s GroupSyncSettings) ParseClaims(mergedClaims jwt.MapClaims) ([]ExpectedGr
294297
if ok {
295298
for _, gid := range mappedGroupIDs {
296299
gid := gid
297-
groups = append(groups, ExpectedGroup{GroupID: &gid})
300+
groups = append(groups, ExpectedGroup{OrganizationID: orgID, GroupID: &gid})
298301
}
299302
continue
300303
}
301304

302305
group := group
303-
groups = append(groups, ExpectedGroup{GroupName: &group})
306+
groups = append(groups, ExpectedGroup{OrganizationID: orgID, GroupName: &group})
304307
}
305308

306309
return groups, nil
@@ -312,38 +315,6 @@ func (s GroupSyncSettings) ParseClaims(mergedClaims jwt.MapClaims) ([]ExpectedGr
312315
// Missing groups are created if AutoCreate is enabled.
313316
// TODO: Batching this would be better, as this is 1 or 2 db calls per organization.
314317
func (s GroupSyncSettings) HandleMissingGroups(ctx context.Context, tx database.Store, orgID uuid.UUID, add []ExpectedGroup) ([]uuid.UUID, error) {
315-
if !s.AutoCreateMissingGroups {
316-
// If we are not creating groups, then just construct a db lookup for
317-
// all groups by name.
318-
var lookups []string
319-
filter := make([]uuid.UUID, 0)
320-
for _, expected := range add {
321-
if expected.GroupID != nil {
322-
// Groups with IDs are easy!
323-
filter = append(filter, *expected.GroupID)
324-
} else if expected.GroupName != nil {
325-
lookups = append(lookups, *expected.GroupName)
326-
}
327-
}
328-
329-
if len(lookups) > 0 {
330-
// Do name lookups for all groups that are missing IDs.
331-
newGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{
332-
OrganizationID: uuid.UUID{},
333-
HasMemberID: uuid.UUID{},
334-
GroupNames: lookups,
335-
})
336-
if err != nil {
337-
return nil, xerrors.Errorf("get groups by names: %w", err)
338-
}
339-
for _, g := range newGroups {
340-
filter = append(filter, g.Group.ID)
341-
}
342-
}
343-
344-
return filter, nil
345-
}
346-
347318
// All expected that are missing IDs means the group does not exist
348319
// in the database. Either remove them, or create them if auto create is
349320
// turned on.
@@ -359,33 +330,33 @@ func (s GroupSyncSettings) HandleMissingGroups(ctx context.Context, tx database.
359330
}
360331
}
361332

362-
createdMissingGroups, err := tx.InsertMissingGroups(ctx, database.InsertMissingGroupsParams{
363-
OrganizationID: orgID,
364-
Source: database.GroupSourceOidc,
365-
GroupNames: missingGroups,
366-
})
367-
if err != nil {
368-
return nil, xerrors.Errorf("insert missing groups: %w", err)
333+
if s.AutoCreateMissingGroups && len(missingGroups) > 0 {
334+
// Insert any missing groups. If the groups already exist, this is a noop.
335+
_, err := tx.InsertMissingGroups(ctx, database.InsertMissingGroupsParams{
336+
OrganizationID: orgID,
337+
Source: database.GroupSourceOidc,
338+
GroupNames: missingGroups,
339+
})
340+
if err != nil {
341+
return nil, xerrors.Errorf("insert missing groups: %w", err)
342+
}
369343
}
370344

371-
if len(missingGroups) != len(createdMissingGroups) {
372-
// This is unfortunate, but if legacy params are used, then some existing groups
373-
// can come as params. So we need to fetch them
374-
allGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{
345+
// Fetch any missing groups by name. If they exist, their IDs will be
346+
// matched and returned.
347+
if len(missingGroups) > 0 {
348+
// Do name lookups for all groups that are missing IDs.
349+
newGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{
375350
OrganizationID: orgID,
351+
HasMemberID: uuid.UUID{},
376352
GroupNames: missingGroups,
377353
})
378354
if err != nil {
379355
return nil, xerrors.Errorf("get groups by names: %w", err)
380356
}
381-
382-
createdMissingGroups = db2sdk.List(allGroups, func(g database.GetGroupsRow) database.Group {
383-
return g.Group
384-
})
385-
}
386-
387-
for _, created := range createdMissingGroups {
388-
addIDs = append(addIDs, created.ID)
357+
for _, g := range newGroups {
358+
addIDs = append(addIDs, g.Group.ID)
359+
}
389360
}
390361

391362
return addIDs, nil

coderd/idpsync/group_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ func TestGroupSyncTable(t *testing.T) {
205205
LegacyGroupNameMapping: map[string]string{
206206
"create-bar": "legacy-bar",
207207
"foo": "legacy-foo",
208+
"bop": "legacy-bop",
208209
},
209210
AutoCreateMissingGroups: true,
210211
},
@@ -214,6 +215,7 @@ func TestGroupSyncTable(t *testing.T) {
214215
GroupNames: map[string]bool{
215216
"legacy-foo": false,
216217
"extra": true,
218+
"legacy-bop": true,
217219
},
218220
ExpectedGroupNames: []string{
219221
"legacy-bar",

0 commit comments

Comments
 (0)