Skip to content

Commit 5d0f729

Browse files
committed
Begin unit testing work
1 parent b1ece73 commit 5d0f729

File tree

3 files changed

+139
-85
lines changed

3 files changed

+139
-85
lines changed

coderd/database/dbmem/dbmem.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8630,7 +8630,7 @@ func (q *FakeQuerier) UpdateUserRoles(_ context.Context, arg database.UpdateUser
86308630
}
86318631

86328632
// Set new roles
8633-
user.RBACRoles = arg.GrantedRoles
8633+
user.RBACRoles = slice.Unique(arg.GrantedRoles)
86348634
// Remove duplicates and sort
86358635
uniqueRoles := make([]string, 0, len(user.RBACRoles))
86368636
exist := make(map[string]struct{})

coderd/idpsync/role.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/coder/coder/v2/coderd/rbac"
1616
"github.com/coder/coder/v2/coderd/rbac/rolestore"
1717
"github.com/coder/coder/v2/coderd/runtimeconfig"
18+
"github.com/coder/coder/v2/coderd/util/slice"
1819
)
1920

2021
type RoleParams struct {
@@ -159,8 +160,15 @@ func (s AGPLIDPSync) SyncRoles(ctx context.Context, db database.Store, user data
159160
return s == rbac.RoleOrgMember()
160161
})
161162

163+
// Only care about unique roles. So remove all duplicates
164+
existingFound = slice.Unique(existingFound)
165+
validExpected = slice.Unique(validExpected)
166+
// A sort is required for the equality check
167+
slices.Sort(existingFound)
168+
slices.Sort(validExpected)
162169
// Is there a difference between the expected roles and the existing roles?
163170
if !slices.Equal(existingFound, validExpected) {
171+
// TODO: Write a unit test to verify we do no db call on no diff
164172
_, err = tx.UpdateMemberRoles(ctx, database.UpdateMemberRolesParams{
165173
GrantedRoles: validExpected,
166174
UserID: user.ID,
@@ -189,6 +197,8 @@ func (s AGPLIDPSync) syncSiteWideRoles(ctx context.Context, tx database.Store, u
189197
for _, role := range params.SiteWideRoles {
190198
// Because we are only syncing site wide roles, we intentionally will always
191199
// omit 'OrganizationID' from the RoleIdentifier.
200+
// TODO: If custom site wide roles are introduced, this needs to use the
201+
// database to verify the role exists.
192202
if _, err := rbac.RoleByName(rbac.RoleIdentifier{Name: role}); err == nil {
193203
filtered = append(filtered, role)
194204
} else {

coderd/idpsync/role_test.go

Lines changed: 128 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@ import (
66
"github.com/golang-jwt/jwt/v4"
77
"github.com/google/uuid"
88
"github.com/stretchr/testify/require"
9+
"golang.org/x/exp/slices"
910

1011
"cdr.dev/slog/sloggers/slogtest"
1112
"github.com/coder/coder/v2/coderd/database"
13+
"github.com/coder/coder/v2/coderd/database/dbauthz"
1214
"github.com/coder/coder/v2/coderd/database/dbgen"
1315
"github.com/coder/coder/v2/coderd/database/dbtestutil"
1416
"github.com/coder/coder/v2/coderd/idpsync"
@@ -20,7 +22,6 @@ import (
2022
func TestRoleSyncTable(t *testing.T) {
2123
t.Parallel()
2224

23-
// Last checked, takes 30s with postgres on a fast machine.
2425
if dbtestutil.WillUsePostgres() {
2526
t.Skip("Skipping test because it populates a lot of db entries, which is slow on postgres.")
2627
}
@@ -33,9 +34,9 @@ func TestRoleSyncTable(t *testing.T) {
3334
},
3435
// bad-claim is a number, and will fail any role sync
3536
"bad-claim": 100,
37+
"empty": []string{},
3638
}
3739

38-
//ids := coderdtest.NewDeterministicUUIDGenerator()
3940
testCases := []orgSetupDefinition{
4041
{
4142
Name: "NoSync",
@@ -125,6 +126,62 @@ func TestRoleSyncTable(t *testing.T) {
125126
},
126127
},
127128
},
129+
{
130+
Name: "NoChange",
131+
OrganizationRoles: []string{rbac.RoleOrgAdmin(), rbac.RoleOrgTemplateAdmin(), rbac.RoleOrgAuditor()},
132+
RoleSettings: &idpsync.RoleSyncSettings{
133+
Field: "roles",
134+
Mapping: map[string][]string{
135+
"foo": {rbac.RoleOrgAuditor(), rbac.RoleOrgTemplateAdmin()},
136+
"bar": {rbac.RoleOrgAdmin()},
137+
},
138+
},
139+
assertRoles: &orgRoleAssert{
140+
ExpectedOrgRoles: []string{
141+
rbac.RoleOrgAdmin(), rbac.RoleOrgAuditor(), rbac.RoleOrgTemplateAdmin(),
142+
},
143+
},
144+
},
145+
{
146+
// InvalidOriginalRole starts the user with an invalid role.
147+
// In practice, this should not happen, as it means a role was
148+
// inserted into the database that does not exist.
149+
// For the purposes of syncing, it does not matter, and the sync
150+
// should succeed.
151+
Name: "InvalidOriginalRole",
152+
OrganizationRoles: []string{"something-bad"},
153+
RoleSettings: &idpsync.RoleSyncSettings{
154+
Field: "roles",
155+
Mapping: map[string][]string{},
156+
},
157+
assertRoles: &orgRoleAssert{
158+
ExpectedOrgRoles: []string{
159+
rbac.RoleOrgAuditor(),
160+
},
161+
},
162+
},
163+
{
164+
Name: "NonExistentClaim",
165+
OrganizationRoles: []string{rbac.RoleOrgAuditor()},
166+
RoleSettings: &idpsync.RoleSyncSettings{
167+
Field: "not-exists",
168+
Mapping: map[string][]string{},
169+
},
170+
assertRoles: &orgRoleAssert{
171+
ExpectedOrgRoles: []string{},
172+
},
173+
},
174+
{
175+
Name: "EmptyClaim",
176+
OrganizationRoles: []string{rbac.RoleOrgAuditor()},
177+
RoleSettings: &idpsync.RoleSyncSettings{
178+
Field: "empty",
179+
Mapping: map[string][]string{},
180+
},
181+
assertRoles: &orgRoleAssert{
182+
ExpectedOrgRoles: []string{},
183+
},
184+
},
128185
}
129186

130187
for _, tc := range testCases {
@@ -148,7 +205,7 @@ func TestRoleSyncTable(t *testing.T) {
148205
orgID := uuid.New()
149206
SetupOrganization(t, s, db, user, orgID, tc)
150207

151-
// Do the group sync!
208+
// Do the role sync!
152209
err := s.SyncRoles(ctx, db, user, idpsync.RoleParams{
153210
SyncEnabled: true,
154211
SyncSiteWide: false,
@@ -164,85 +221,72 @@ func TestRoleSyncTable(t *testing.T) {
164221
// deployment. This tests all organizations being synced together.
165222
// The reason we do them individually, is that it is much easier to
166223
// debug a single test case.
167-
//t.Run("AllTogether", func(t *testing.T) {
168-
// t.Parallel()
169-
//
170-
// db, _ := dbtestutil.NewDB(t)
171-
// manager := runtimeconfig.NewManager()
172-
// s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}),
173-
// manager,
174-
// // Also sync the default org!
175-
// idpsync.DeploymentSyncSettings{
176-
// GroupField: "groups",
177-
// Legacy: idpsync.DefaultOrgLegacySettings{
178-
// GroupField: "groups",
179-
// GroupMapping: map[string]string{
180-
// "foo": "legacy-foo",
181-
// "baz": "legacy-baz",
182-
// },
183-
// GroupFilter: regexp.MustCompile("^legacy"),
184-
// CreateMissingGroups: true,
185-
// },
186-
// },
187-
// )
188-
//
189-
// ctx := testutil.Context(t, testutil.WaitSuperLong)
190-
// user := dbgen.User(t, db, database.User{})
191-
//
192-
// var asserts []func(t *testing.T)
193-
// // The default org is also going to do something
194-
// def := orgSetupDefinition{
195-
// Name: "DefaultOrg",
196-
// GroupNames: map[string]bool{
197-
// "legacy-foo": false,
198-
// "legacy-baz": true,
199-
// "random": true,
200-
// },
201-
// // No settings, because they come from the deployment values
202-
// GroupSettings: nil,
203-
// assertGroups: &orgGroupAssert{
204-
// ExpectedGroupNames: []string{"legacy-foo", "legacy-baz", "legacy-bar"},
205-
// },
206-
// }
207-
//
208-
// //nolint:gocritic // testing
209-
// defOrg, err := db.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx))
210-
// require.NoError(t, err)
211-
// SetupOrganization(t, s, db, user, defOrg.ID, def)
212-
// asserts = append(asserts, func(t *testing.T) {
213-
// t.Run(def.Name, func(t *testing.T) {
214-
// t.Parallel()
215-
// def.Assert(t, defOrg.ID, db, user)
216-
// })
217-
// })
218-
//
219-
// for _, tc := range testCases {
220-
// tc := tc
221-
//
222-
// orgID := uuid.New()
223-
// SetupOrganization(t, s, db, user, orgID, tc)
224-
// asserts = append(asserts, func(t *testing.T) {
225-
// t.Run(tc.Name, func(t *testing.T) {
226-
// t.Parallel()
227-
// tc.Assert(t, orgID, db, user)
228-
// })
229-
// })
230-
// }
231-
//
232-
// asserts = append(asserts, func(t *testing.T) {
233-
// t.Helper()
234-
// def.Assert(t, defOrg.ID, db, user)
235-
// })
236-
//
237-
// // Do the group sync!
238-
// err = s.SyncGroups(ctx, db, user, idpsync.GroupParams{
239-
// SyncEnabled: true,
240-
// MergedClaims: userClaims,
241-
// })
242-
// require.NoError(t, err)
243-
//
244-
// for _, assert := range asserts {
245-
// assert(t)
246-
// }
247-
//})
224+
t.Run("AllTogether", func(t *testing.T) {
225+
t.Parallel()
226+
227+
db, _ := dbtestutil.NewDB(t)
228+
manager := runtimeconfig.NewManager()
229+
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{
230+
IgnoreErrors: true,
231+
}),
232+
manager,
233+
// Also sync some site wide roles
234+
idpsync.DeploymentSyncSettings{
235+
GroupField: "groups",
236+
SiteRoleField: "roles",
237+
// Site sync settings do not matter,
238+
// as we are not testing the site parse here.
239+
// Only the sync, assuming the parse is correct.
240+
},
241+
)
242+
243+
ctx := testutil.Context(t, testutil.WaitSuperLong)
244+
user := dbgen.User(t, db, database.User{})
245+
246+
var asserts []func(t *testing.T)
247+
248+
for _, tc := range testCases {
249+
tc := tc
250+
251+
orgID := uuid.New()
252+
SetupOrganization(t, s, db, user, orgID, tc)
253+
asserts = append(asserts, func(t *testing.T) {
254+
t.Run(tc.Name, func(t *testing.T) {
255+
t.Parallel()
256+
tc.Assert(t, orgID, db, user)
257+
})
258+
})
259+
}
260+
261+
err := s.SyncRoles(ctx, db, user, idpsync.RoleParams{
262+
SyncEnabled: true,
263+
SyncSiteWide: true,
264+
SiteWideRoles: []string{
265+
rbac.RoleTemplateAdmin().Name, // Duplicate this value to test deduplication
266+
rbac.RoleTemplateAdmin().Name, rbac.RoleAuditor().Name,
267+
},
268+
MergedClaims: userClaims,
269+
})
270+
require.NoError(t, err)
271+
272+
for _, assert := range asserts {
273+
assert(t)
274+
}
275+
276+
// Also assert site wide roles
277+
//nolint:gocritic // unit testing assertions
278+
allRoles, err := db.GetAuthorizationUserRoles(dbauthz.AsSystemRestricted(ctx), user.ID)
279+
require.NoError(t, err)
280+
281+
allRoleIDs, err := allRoles.RoleNames()
282+
require.NoError(t, err)
283+
284+
siteRoles := slices.DeleteFunc(allRoleIDs, func(r rbac.RoleIdentifier) bool {
285+
return r.IsOrgRole()
286+
})
287+
288+
require.ElementsMatch(t, []rbac.RoleIdentifier{
289+
rbac.RoleTemplateAdmin(), rbac.RoleAuditor(), rbac.RoleMember(),
290+
}, siteRoles)
291+
})
248292
}

0 commit comments

Comments
 (0)