Skip to content

Commit 9887057

Browse files
committed
add unit test for ApplyGroupDifference
1 parent 8fefc9f commit 9887057

File tree

3 files changed

+159
-6
lines changed

3 files changed

+159
-6
lines changed

coderd/database/dbmem/dbmem.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -2703,18 +2703,18 @@ func (q *FakeQuerier) GetGroups(_ context.Context, arg database.GetGroupsParams)
27032703
q.mutex.RLock()
27042704
defer q.mutex.RUnlock()
27052705

2706-
groupIDs := make(map[uuid.UUID]struct{})
2706+
userGroupIDs := make(map[uuid.UUID]struct{})
27072707
if arg.HasMemberID != uuid.Nil {
27082708
for _, member := range q.groupMembers {
27092709
if member.UserID == arg.HasMemberID {
2710-
groupIDs[member.GroupID] = struct{}{}
2710+
userGroupIDs[member.GroupID] = struct{}{}
27112711
}
27122712
}
27132713

27142714
// Handle the everyone group
27152715
for _, orgMember := range q.organizationMembers {
27162716
if orgMember.UserID == arg.HasMemberID {
2717-
groupIDs[orgMember.OrganizationID] = struct{}{}
2717+
userGroupIDs[orgMember.OrganizationID] = struct{}{}
27182718
}
27192719
}
27202720
}
@@ -2726,7 +2726,7 @@ func (q *FakeQuerier) GetGroups(_ context.Context, arg database.GetGroupsParams)
27262726
continue
27272727
}
27282728

2729-
_, ok := groupIDs[group.ID]
2729+
_, ok := userGroupIDs[group.ID]
27302730
if arg.HasMemberID != uuid.Nil && !ok {
27312731
continue
27322732
}

coderd/idpsync/group.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat
168168
groupIDsToAdd = append(groupIDsToAdd, assignGroups...)
169169
}
170170

171-
err = s.applyGroupDifference(ctx, tx, user, groupIDsToAdd, groupIDsToRemove)
171+
err = s.ApplyGroupDifference(ctx, tx, user, groupIDsToAdd, groupIDsToRemove)
172172
if err != nil {
173173
return xerrors.Errorf("apply group difference: %w", err)
174174
}
@@ -183,7 +183,8 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat
183183
return nil
184184
}
185185

186-
func (s AGPLIDPSync) applyGroupDifference(ctx context.Context, tx database.Store, user database.User, add []uuid.UUID, removeIDs []uuid.UUID) error {
186+
// ApplyGroupDifference will add and remove the user from the specified groups.
187+
func (s AGPLIDPSync) ApplyGroupDifference(ctx context.Context, tx database.Store, user database.User, add []uuid.UUID, removeIDs []uuid.UUID) error {
187188
// Always do group removal before group add. This way if there is an error,
188189
// we error on the underprivileged side.
189190
if len(removeIDs) > 0 {

coderd/idpsync/group_test.go

+152
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,158 @@ func TestGroupSyncTable(t *testing.T) {
342342
})
343343
}
344344

345+
// TestApplyGroupDifference is mainly testing the database functions
346+
func TestApplyGroupDifference(t *testing.T) {
347+
t.Parallel()
348+
349+
ids := coderdtest.NewDeterministicUUIDGenerator()
350+
testCase := []struct {
351+
Name string
352+
Before map[uuid.UUID]bool
353+
Add []uuid.UUID
354+
Remove []uuid.UUID
355+
Expect []uuid.UUID
356+
}{
357+
{
358+
Name: "Empty",
359+
},
360+
{
361+
Name: "AddFromNone",
362+
Before: map[uuid.UUID]bool{
363+
ids.ID("g1"): false,
364+
},
365+
Add: []uuid.UUID{
366+
ids.ID("g1"),
367+
},
368+
Expect: []uuid.UUID{
369+
ids.ID("g1"),
370+
},
371+
},
372+
{
373+
Name: "AddSome",
374+
Before: map[uuid.UUID]bool{
375+
ids.ID("g1"): true,
376+
ids.ID("g2"): false,
377+
ids.ID("g3"): false,
378+
uuid.New(): false,
379+
},
380+
Add: []uuid.UUID{
381+
ids.ID("g2"),
382+
ids.ID("g3"),
383+
},
384+
Expect: []uuid.UUID{
385+
ids.ID("g1"),
386+
ids.ID("g2"),
387+
ids.ID("g3"),
388+
},
389+
},
390+
{
391+
Name: "RemoveAll",
392+
Before: map[uuid.UUID]bool{
393+
uuid.New(): false,
394+
ids.ID("g2"): true,
395+
ids.ID("g3"): true,
396+
},
397+
Remove: []uuid.UUID{
398+
ids.ID("g2"),
399+
ids.ID("g3"),
400+
},
401+
Expect: []uuid.UUID{},
402+
},
403+
{
404+
Name: "Mixed",
405+
Before: map[uuid.UUID]bool{
406+
// adds
407+
ids.ID("a1"): true,
408+
ids.ID("a2"): true,
409+
ids.ID("a3"): false,
410+
ids.ID("a4"): false,
411+
// removes
412+
ids.ID("r1"): true,
413+
ids.ID("r2"): true,
414+
ids.ID("r3"): false,
415+
ids.ID("r4"): false,
416+
// stable
417+
ids.ID("s1"): true,
418+
ids.ID("s2"): true,
419+
// noise
420+
uuid.New(): false,
421+
uuid.New(): false,
422+
},
423+
Add: []uuid.UUID{
424+
ids.ID("a1"), ids.ID("a2"),
425+
ids.ID("a3"), ids.ID("a4"),
426+
// Double up to try and confuse
427+
ids.ID("a1"),
428+
ids.ID("a4"),
429+
},
430+
Remove: []uuid.UUID{
431+
ids.ID("r1"), ids.ID("r2"),
432+
ids.ID("r3"), ids.ID("r4"),
433+
// Double up to try and confuse
434+
ids.ID("r1"),
435+
ids.ID("r4"),
436+
},
437+
Expect: []uuid.UUID{
438+
ids.ID("a1"), ids.ID("a2"), ids.ID("a3"), ids.ID("a4"),
439+
ids.ID("s1"), ids.ID("s2"),
440+
},
441+
},
442+
}
443+
444+
for _, tc := range testCase {
445+
tc := tc
446+
t.Run(tc.Name, func(t *testing.T) {
447+
mgr := runtimeconfig.NewStoreManager()
448+
db, _ := dbtestutil.NewDB(t)
449+
450+
ctx := testutil.Context(t, testutil.WaitMedium)
451+
//nolint:gocritic // testing
452+
ctx = dbauthz.AsSystemRestricted(ctx)
453+
454+
org := dbgen.Organization(t, db, database.Organization{})
455+
_, err := db.InsertAllUsersGroup(ctx, org.ID)
456+
require.NoError(t, err)
457+
458+
user := dbgen.User(t, db, database.User{})
459+
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
460+
UserID: user.ID,
461+
OrganizationID: org.ID,
462+
})
463+
464+
for gid, in := range tc.Before {
465+
group := dbgen.Group(t, db, database.Group{
466+
ID: gid,
467+
OrganizationID: org.ID,
468+
})
469+
if in {
470+
_ = dbgen.GroupMember(t, db, database.GroupMemberTable{
471+
UserID: user.ID,
472+
GroupID: group.ID,
473+
})
474+
}
475+
}
476+
477+
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), mgr, idpsync.FromDeploymentValues(coderdtest.DeploymentValues(t)))
478+
err = s.ApplyGroupDifference(context.Background(), db, user, tc.Add, tc.Remove)
479+
require.NoError(t, err)
480+
481+
userGroups, err := db.GetGroups(ctx, database.GetGroupsParams{
482+
HasMemberID: user.ID,
483+
})
484+
require.NoError(t, err)
485+
486+
// assert
487+
found := db2sdk.List(userGroups, func(g database.GetGroupsRow) uuid.UUID {
488+
return g.Group.ID
489+
})
490+
491+
// Add everyone group
492+
require.ElementsMatch(t, append(tc.Expect, org.ID), found)
493+
})
494+
}
495+
}
496+
345497
func SetupOrganization(t *testing.T, s *idpsync.AGPLIDPSync, db database.Store, user database.User, orgID uuid.UUID, def orgSetupDefinition) {
346498
t.Helper()
347499

0 commit comments

Comments
 (0)