@@ -3,6 +3,7 @@ package idpsync
3
3
import (
4
4
"context"
5
5
"encoding/json"
6
+ "fmt"
6
7
"regexp"
7
8
8
9
"github.com/golang-jwt/jwt/v4"
@@ -92,15 +93,15 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat
92
93
93
94
// collect all diffs to do 1 sql update for all orgs
94
95
groupIDsToAdd := make ([]uuid.UUID , 0 )
95
- groupsToRemove := make ([]ExpectedGroup , 0 )
96
+ groupIDsToRemove := make ([]uuid. UUID , 0 )
96
97
// For each org, determine which groups the user should land in
97
98
for orgID , settings := range orgSettings {
98
99
if settings .GroupField == "" {
99
100
// No group sync enabled for this org, so do nothing.
100
101
continue
101
102
}
102
103
103
- expectedGroups , err := settings .ParseClaims (params .MergedClaims )
104
+ expectedGroups , err := settings .ParseClaims (orgID , params .MergedClaims )
104
105
if err != nil {
105
106
s .Logger .Debug (ctx , "failed to parse claims for groups" ,
106
107
slog .F ("organization_field" , s .GroupField ),
@@ -128,6 +129,10 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat
128
129
}
129
130
})
130
131
add , remove := slice .SymmetricDifferenceFunc (existingGroupsTyped , expectedGroups , func (a , b ExpectedGroup ) bool {
132
+ // Must match
133
+ if a .OrganizationID != b .OrganizationID {
134
+ return false
135
+ }
131
136
// Only the name or the name needs to be checked, priority is given to the ID.
132
137
if a .GroupID != nil && b .GroupID != nil {
133
138
return * a .GroupID == * b .GroupID
@@ -138,6 +143,20 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat
138
143
return false
139
144
})
140
145
146
+ for _ , r := range remove {
147
+ // This should never happen. All group removals come from the
148
+ // existing set, which come from the db. All groups from the
149
+ // database have IDs. This code is purely defensive.
150
+ if r .GroupID == nil {
151
+ detail := "user:" + user .Username
152
+ if r .GroupName != nil {
153
+ detail += fmt .Sprintf (" from group %s" , * r .GroupName )
154
+ }
155
+ return xerrors .Errorf ("removal group has nil ID, which should never happen: %s" , detail )
156
+ }
157
+ groupIDsToRemove = append (groupIDsToRemove , * r .GroupID )
158
+ }
159
+
141
160
// HandleMissingGroups will add the new groups to the org if
142
161
// the settings specify. It will convert all group names into uuids
143
162
// for easier assignment.
@@ -146,11 +165,10 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat
146
165
return xerrors .Errorf ("handle missing groups: %w" , err )
147
166
}
148
167
149
- groupsToRemove = append (groupsToRemove , remove ... )
150
168
groupIDsToAdd = append (groupIDsToAdd , assignGroups ... )
151
169
}
152
170
153
- err = s .applyGroupDifference (ctx , tx , user , groupIDsToAdd , groupsToRemove )
171
+ err = s .applyGroupDifference (ctx , tx , user , groupIDsToAdd , groupIDsToRemove )
154
172
if err != nil {
155
173
return xerrors .Errorf ("apply group difference: %w" , err )
156
174
}
@@ -165,28 +183,13 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat
165
183
return nil
166
184
}
167
185
168
- func (s AGPLIDPSync ) applyGroupDifference (ctx context.Context , tx database.Store , user database.User , add []uuid.UUID , remove []ExpectedGroup ) error {
186
+ func (s AGPLIDPSync ) applyGroupDifference (ctx context.Context , tx database.Store , user database.User , add []uuid.UUID , removeIDs []uuid. UUID ) error {
169
187
// Always do group removal before group add. This way if there is an error,
170
188
// we error on the underprivileged side.
171
- removeIDs := make ([]uuid.UUID , 0 )
172
- removeNames := make ([]database.NameOrganizationPair , 0 )
173
- for _ , r := range remove {
174
- if r .GroupID != nil {
175
- removeIDs = append (removeIDs , * r .GroupID )
176
- } else if r .GroupName != nil {
177
- removeNames = append (removeNames , database.NameOrganizationPair {
178
- Name : * r .GroupName ,
179
- OrganizationID : r .OrganizationID ,
180
- })
181
- }
182
- }
183
-
184
- // If there is something to remove, do it.
185
- if len (removeIDs ) > 0 || len (removeNames ) > 0 {
189
+ if len (removeIDs ) > 0 {
186
190
removedGroupIDs , err := tx .RemoveUserFromGroups (ctx , database.RemoveUserFromGroupsParams {
187
- UserID : user .ID ,
188
- GroupNames : removeNames ,
189
- GroupIds : removeIDs ,
191
+ UserID : user .ID ,
192
+ GroupIds : removeIDs ,
190
193
})
191
194
if err != nil {
192
195
return xerrors .Errorf ("remove user from %d groups: %w" , len (removeIDs ), err )
@@ -264,7 +267,7 @@ type ExpectedGroup struct {
264
267
// the group "UUID 1234" is renamed, we want to maintain the mapping.
265
268
// We have to keep names because group sync supports syncing groups by name if
266
269
// the external IDP group name matches the Coder one.
267
- func (s GroupSyncSettings ) ParseClaims (mergedClaims jwt.MapClaims ) ([]ExpectedGroup , error ) {
270
+ func (s GroupSyncSettings ) ParseClaims (orgID uuid. UUID , mergedClaims jwt.MapClaims ) ([]ExpectedGroup , error ) {
268
271
groupsRaw , ok := mergedClaims [s .GroupField ]
269
272
if ! ok {
270
273
return []ExpectedGroup {}, nil
@@ -294,13 +297,13 @@ func (s GroupSyncSettings) ParseClaims(mergedClaims jwt.MapClaims) ([]ExpectedGr
294
297
if ok {
295
298
for _ , gid := range mappedGroupIDs {
296
299
gid := gid
297
- groups = append (groups , ExpectedGroup {GroupID : & gid })
300
+ groups = append (groups , ExpectedGroup {OrganizationID : orgID , GroupID : & gid })
298
301
}
299
302
continue
300
303
}
301
304
302
305
group := group
303
- groups = append (groups , ExpectedGroup {GroupName : & group })
306
+ groups = append (groups , ExpectedGroup {OrganizationID : orgID , GroupName : & group })
304
307
}
305
308
306
309
return groups , nil
@@ -312,38 +315,6 @@ func (s GroupSyncSettings) ParseClaims(mergedClaims jwt.MapClaims) ([]ExpectedGr
312
315
// Missing groups are created if AutoCreate is enabled.
313
316
// TODO: Batching this would be better, as this is 1 or 2 db calls per organization.
314
317
func (s GroupSyncSettings ) HandleMissingGroups (ctx context.Context , tx database.Store , orgID uuid.UUID , add []ExpectedGroup ) ([]uuid.UUID , error ) {
315
- if ! s .AutoCreateMissingGroups {
316
- // If we are not creating groups, then just construct a db lookup for
317
- // all groups by name.
318
- var lookups []string
319
- filter := make ([]uuid.UUID , 0 )
320
- for _ , expected := range add {
321
- if expected .GroupID != nil {
322
- // Groups with IDs are easy!
323
- filter = append (filter , * expected .GroupID )
324
- } else if expected .GroupName != nil {
325
- lookups = append (lookups , * expected .GroupName )
326
- }
327
- }
328
-
329
- if len (lookups ) > 0 {
330
- // Do name lookups for all groups that are missing IDs.
331
- newGroups , err := tx .GetGroups (ctx , database.GetGroupsParams {
332
- OrganizationID : uuid.UUID {},
333
- HasMemberID : uuid.UUID {},
334
- GroupNames : lookups ,
335
- })
336
- if err != nil {
337
- return nil , xerrors .Errorf ("get groups by names: %w" , err )
338
- }
339
- for _ , g := range newGroups {
340
- filter = append (filter , g .Group .ID )
341
- }
342
- }
343
-
344
- return filter , nil
345
- }
346
-
347
318
// All expected that are missing IDs means the group does not exist
348
319
// in the database. Either remove them, or create them if auto create is
349
320
// turned on.
@@ -359,33 +330,33 @@ func (s GroupSyncSettings) HandleMissingGroups(ctx context.Context, tx database.
359
330
}
360
331
}
361
332
362
- createdMissingGroups , err := tx .InsertMissingGroups (ctx , database.InsertMissingGroupsParams {
363
- OrganizationID : orgID ,
364
- Source : database .GroupSourceOidc ,
365
- GroupNames : missingGroups ,
366
- })
367
- if err != nil {
368
- return nil , xerrors .Errorf ("insert missing groups: %w" , err )
333
+ if s .AutoCreateMissingGroups && len (missingGroups ) > 0 {
334
+ // Insert any missing groups. If the groups already exist, this is a noop.
335
+ _ , err := tx .InsertMissingGroups (ctx , database.InsertMissingGroupsParams {
336
+ OrganizationID : orgID ,
337
+ Source : database .GroupSourceOidc ,
338
+ GroupNames : missingGroups ,
339
+ })
340
+ if err != nil {
341
+ return nil , xerrors .Errorf ("insert missing groups: %w" , err )
342
+ }
369
343
}
370
344
371
- if len (missingGroups ) != len (createdMissingGroups ) {
372
- // This is unfortunate, but if legacy params are used, then some existing groups
373
- // can come as params. So we need to fetch them
374
- allGroups , err := tx .GetGroups (ctx , database.GetGroupsParams {
345
+ // Fetch any missing groups by name. If they exist, their IDs will be
346
+ // matched and returned.
347
+ if len (missingGroups ) > 0 {
348
+ // Do name lookups for all groups that are missing IDs.
349
+ newGroups , err := tx .GetGroups (ctx , database.GetGroupsParams {
375
350
OrganizationID : orgID ,
351
+ HasMemberID : uuid.UUID {},
376
352
GroupNames : missingGroups ,
377
353
})
378
354
if err != nil {
379
355
return nil , xerrors .Errorf ("get groups by names: %w" , err )
380
356
}
381
-
382
- createdMissingGroups = db2sdk .List (allGroups , func (g database.GetGroupsRow ) database.Group {
383
- return g .Group
384
- })
385
- }
386
-
387
- for _ , created := range createdMissingGroups {
388
- addIDs = append (addIDs , created .ID )
357
+ for _ , g := range newGroups {
358
+ addIDs = append (addIDs , g .Group .ID )
359
+ }
389
360
}
390
361
391
362
return addIDs , nil
0 commit comments