@@ -91,8 +91,8 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat
91
91
}
92
92
93
93
// collect all diffs to do 1 sql update for all orgs
94
- groupsToAdd := make ([]uuid.UUID , 0 )
95
- groupsToRemove := make ([]uuid. UUID , 0 )
94
+ groupIDsToAdd := make ([]uuid.UUID , 0 )
95
+ groupsToRemove := make ([]ExpectedGroup , 0 )
96
96
// For each org, determine which groups the user should land in
97
97
for orgID , settings := range orgSettings {
98
98
if settings .GroupField == "" {
@@ -112,7 +112,8 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat
112
112
}
113
113
// Everyone group is always implied.
114
114
expectedGroups = append (expectedGroups , ExpectedGroup {
115
- GroupID : & orgID ,
115
+ OrganizationID : orgID ,
116
+ GroupID : & orgID ,
116
117
})
117
118
118
119
// Now we know what groups the user should be in for a given org,
@@ -121,8 +122,9 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat
121
122
existingGroups := userOrgs [orgID ]
122
123
existingGroupsTyped := db2sdk .List (existingGroups , func (f database.GetGroupsRow ) ExpectedGroup {
123
124
return ExpectedGroup {
124
- GroupID : & f .Group .ID ,
125
- GroupName : & f .Group .Name ,
125
+ OrganizationID : orgID ,
126
+ GroupID : & f .Group .ID ,
127
+ GroupName : & f .Group .Name ,
126
128
}
127
129
})
128
130
add , remove := slice .SymmetricDifferenceFunc (existingGroupsTyped , expectedGroups , func (a , b ExpectedGroup ) bool {
@@ -144,52 +146,75 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat
144
146
return xerrors .Errorf ("handle missing groups: %w" , err )
145
147
}
146
148
147
- for _ , removeGroup := range remove {
148
- // This should always be the case.
149
- // TODO: make sure this is always the case
150
- if removeGroup .GroupID != nil {
151
- groupsToRemove = append (groupsToRemove , * removeGroup .GroupID )
152
- }
153
- }
149
+ groupsToRemove = append (groupsToRemove , remove ... )
150
+ groupIDsToAdd = append (groupIDsToAdd , assignGroups ... )
151
+ }
154
152
155
- groupsToAdd = append (groupsToAdd , assignGroups ... )
153
+ err = s .applyGroupDifference (ctx , tx , user , groupIDsToAdd , groupsToRemove )
154
+ if err != nil {
155
+ return xerrors .Errorf ("apply group difference: %w" , err )
156
156
}
157
157
158
- assignedGroupIDs , err := tx .InsertUserGroupsByID (ctx , database.InsertUserGroupsByIDParams {
159
- UserID : user .ID ,
160
- GroupIds : groupsToAdd ,
158
+ return nil
159
+ }, nil )
160
+
161
+ if err != nil {
162
+ return err
163
+ }
164
+
165
+ return nil
166
+ }
167
+
168
+ func (s AGPLIDPSync ) applyGroupDifference (ctx context.Context , tx database.Store , user database.User , add []uuid.UUID , remove []ExpectedGroup ) error {
169
+ // Always do group removal before group add. This way if there is an error,
170
+ // we error on the underprivileged side.
171
+ removeIDs := make ([]uuid.UUID , 0 )
172
+ removeNames := make ([]database.NameOrganizationPair , 0 )
173
+ for _ , r := range remove {
174
+ if r .GroupID != nil {
175
+ removeIDs = append (removeIDs , * r .GroupID )
176
+ } else if r .GroupName != nil {
177
+ removeNames = append (removeNames , database.NameOrganizationPair {
178
+ Name : * r .GroupName ,
179
+ OrganizationID : r .OrganizationID ,
180
+ })
181
+ }
182
+ }
183
+
184
+ // If there is something to remove, do it.
185
+ if len (removeIDs ) > 0 || len (removeNames ) > 0 {
186
+ removedGroupIDs , err := tx .RemoveUserFromGroups (ctx , database.RemoveUserFromGroupsParams {
187
+ UserID : user .ID ,
188
+ GroupNames : removeNames ,
189
+ GroupIds : removeIDs ,
161
190
})
162
191
if err != nil {
163
- return xerrors .Errorf ("insert user into %d groups: %w" , len (groupsToAdd ), err )
192
+ return xerrors .Errorf ("remove user from %d groups: %w" , len (removeIDs ), err )
164
193
}
165
- if len (assignedGroupIDs ) != len (groupsToAdd ) {
166
- s .Logger .Debug (ctx , "failed to assign all groups to user " ,
194
+ if len (removedGroupIDs ) != len (removeIDs ) {
195
+ s .Logger .Debug (ctx , "failed to remove user from all groups " ,
167
196
slog .F ("user_id" , user .ID ),
168
- slog .F ("groups_assigned_count " , len (assignedGroupIDs )),
169
- slog .F ("expected_count" , len (groupsToAdd )),
197
+ slog .F ("groups_removed_count " , len (removedGroupIDs )),
198
+ slog .F ("expected_count" , len (removeIDs )),
170
199
)
171
200
}
201
+ }
172
202
173
- removedGroupIDs , err := tx .RemoveUserFromGroups (ctx , database.RemoveUserFromGroupsParams {
203
+ if len (add ) > 0 {
204
+ assignedGroupIDs , err := tx .InsertUserGroupsByID (ctx , database.InsertUserGroupsByIDParams {
174
205
UserID : user .ID ,
175
- GroupIds : groupsToRemove ,
206
+ GroupIds : add ,
176
207
})
177
208
if err != nil {
178
- return xerrors .Errorf ("remove user from %d groups: %w" , len (groupsToRemove ), err )
209
+ return xerrors .Errorf ("insert user into %d groups: %w" , len (add ), err )
179
210
}
180
- if len (removedGroupIDs ) != len (groupsToRemove ) {
181
- s .Logger .Debug (ctx , "failed to remove user from all groups" ,
211
+ if len (assignedGroupIDs ) != len (add ) {
212
+ s .Logger .Debug (ctx , "failed to assign all groups to user " ,
182
213
slog .F ("user_id" , user .ID ),
183
- slog .F ("groups_removed_count " , len (removedGroupIDs )),
184
- slog .F ("expected_count" , len (groupsToRemove )),
214
+ slog .F ("groups_assigned_count " , len (assignedGroupIDs )),
215
+ slog .F ("expected_count" , len (add )),
185
216
)
186
217
}
187
-
188
- return nil
189
- }, nil )
190
-
191
- if err != nil {
192
- return err
193
218
}
194
219
195
220
return nil
@@ -226,8 +251,9 @@ func (s *GroupSyncSettings) Type() string {
226
251
}
227
252
228
253
type ExpectedGroup struct {
229
- GroupID * uuid.UUID
230
- GroupName * string
254
+ OrganizationID uuid.UUID
255
+ GroupID * uuid.UUID
256
+ GroupName * string
231
257
}
232
258
233
259
// ParseClaims will take the merged claims from the IDP and return the groups
@@ -280,20 +306,28 @@ func (s GroupSyncSettings) ParseClaims(mergedClaims jwt.MapClaims) ([]ExpectedGr
280
306
return groups , nil
281
307
}
282
308
309
+ // HandleMissingGroups ensures all ExpectedGroups convert to uuids.
310
+ // Groups can be referenced by name via legacy params or IDP group names.
311
+ // These group names are converted to IDs for easier assignment.
312
+ // Missing groups are created if AutoCreate is enabled.
313
+ // TODO: Batching this would be better, as this is 1 or 2 db calls per organization.
283
314
func (s GroupSyncSettings ) HandleMissingGroups (ctx context.Context , tx database.Store , orgID uuid.UUID , add []ExpectedGroup ) ([]uuid.UUID , error ) {
284
315
if ! s .AutoCreateMissingGroups {
285
- // construct the list of groups to search by name to see if they exist.
316
+ // If we are not creating groups, then just construct a db lookup for
317
+ // all groups by name.
286
318
var lookups []string
287
319
filter := make ([]uuid.UUID , 0 )
288
320
for _ , expected := range add {
289
321
if expected .GroupID != nil {
322
+ // Groups with IDs are easy!
290
323
filter = append (filter , * expected .GroupID )
291
324
} else if expected .GroupName != nil {
292
325
lookups = append (lookups , * expected .GroupName )
293
326
}
294
327
}
295
328
296
329
if len (lookups ) > 0 {
330
+ // Do name lookups for all groups that are missing IDs.
297
331
newGroups , err := tx .GetGroups (ctx , database.GetGroupsParams {
298
332
OrganizationID : uuid.UUID {},
299
333
HasMemberID : uuid.UUID {},
0 commit comments