Skip to content

Commit 7627933

Browse files
committed
fix tests
1 parent 9ef0e0d commit 7627933

File tree

7 files changed

+203
-65
lines changed

7 files changed

+203
-65
lines changed

coderd/database/dbauthz/dbauthz_test.go

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ func (s *MethodTestSuite) TestGroup() {
305305
}))
306306
s.Run("DeleteGroupMemberFromGroup", s.Subtest(func(db database.Store, check *expects) {
307307
g := dbgen.Group(s.T(), db, database.Group{})
308-
m := dbgen.GroupMember(s.T(), db, database.GroupMember{
308+
m := dbgen.GroupMember(s.T(), db, database.GroupMemberTable{
309309
GroupID: g.ID,
310310
})
311311
check.Args(database.DeleteGroupMemberFromGroupParams{
@@ -326,11 +326,15 @@ func (s *MethodTestSuite) TestGroup() {
326326
}))
327327
s.Run("GetGroupMembersByGroupID", s.Subtest(func(db database.Store, check *expects) {
328328
g := dbgen.Group(s.T(), db, database.Group{})
329-
_ = dbgen.GroupMember(s.T(), db, database.GroupMember{})
329+
gm := dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g.ID})
330+
check.Args(g.ID).Asserts(gm, policy.ActionRead)
331+
}))
332+
s.Run("GetGroupMembersCountByGroupID", s.Subtest(func(db database.Store, check *expects) {
333+
g := dbgen.Group(s.T(), db, database.Group{})
330334
check.Args(g.ID).Asserts(g, policy.ActionRead)
331335
}))
332336
s.Run("GetGroupMembers", s.Subtest(func(db database.Store, check *expects) {
333-
_ = dbgen.GroupMember(s.T(), db, database.GroupMember{})
337+
dbgen.GroupMember(s.T(), db, database.GroupMemberTable{})
334338
check.Asserts(rbac.ResourceSystem, policy.ActionRead)
335339
}))
336340
s.Run("GetGroups", s.Subtest(func(db database.Store, check *expects) {
@@ -339,7 +343,7 @@ func (s *MethodTestSuite) TestGroup() {
339343
}))
340344
s.Run("GetGroupsByOrganizationAndUserID", s.Subtest(func(db database.Store, check *expects) {
341345
g := dbgen.Group(s.T(), db, database.Group{})
342-
gm := dbgen.GroupMember(s.T(), db, database.GroupMember{GroupID: g.ID})
346+
gm := dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g.ID})
343347
check.Args(database.GetGroupsByOrganizationAndUserIDParams{
344348
OrganizationID: g.OrganizationID,
345349
UserID: gm.UserID,
@@ -368,7 +372,7 @@ func (s *MethodTestSuite) TestGroup() {
368372
u1 := dbgen.User(s.T(), db, database.User{})
369373
g1 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
370374
g2 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
371-
_ = dbgen.GroupMember(s.T(), db, database.GroupMember{GroupID: g1.ID, UserID: u1.ID})
375+
_ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g1.ID, UserID: u1.ID})
372376
check.Args(database.InsertUserGroupsByNameParams{
373377
OrganizationID: o.ID,
374378
UserID: u1.ID,
@@ -380,8 +384,8 @@ func (s *MethodTestSuite) TestGroup() {
380384
u1 := dbgen.User(s.T(), db, database.User{})
381385
g1 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
382386
g2 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
383-
_ = dbgen.GroupMember(s.T(), db, database.GroupMember{GroupID: g1.ID, UserID: u1.ID})
384-
_ = dbgen.GroupMember(s.T(), db, database.GroupMember{GroupID: g2.ID, UserID: u1.ID})
387+
_ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g1.ID, UserID: u1.ID})
388+
_ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g2.ID, UserID: u1.ID})
385389
check.Args(u1.ID).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns()
386390
}))
387391
s.Run("UpdateGroupByID", s.Subtest(func(db database.Store, check *expects) {

coderd/database/dbgen/dbgen.go

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"database/sql"
77
"encoding/hex"
88
"encoding/json"
9+
"errors"
910
"fmt"
1011
"net"
1112
"strings"
@@ -374,8 +375,8 @@ func Group(t testing.TB, db database.Store, orig database.Group) database.Group
374375
return group
375376
}
376377

377-
func GroupMember(t testing.TB, db database.Store, orig database.GroupMember) database.GroupMember {
378-
member := database.GroupMember{
378+
func GroupMember(t testing.TB, db database.Store, orig database.GroupMemberTable) database.GroupMember {
379+
member := database.GroupMemberTable{
379380
UserID: takeFirst(orig.UserID, uuid.New()),
380381
GroupID: takeFirst(orig.GroupID, uuid.New()),
381382
}
@@ -385,7 +386,44 @@ func GroupMember(t testing.TB, db database.Store, orig database.GroupMember) dat
385386
GroupID: member.GroupID,
386387
})
387388
require.NoError(t, err, "insert group member")
388-
return member
389+
390+
user, err := db.GetUserByID(genCtx, member.UserID)
391+
if errors.Is(err, sql.ErrNoRows) {
392+
user = User(t, db, database.User{ID: member.UserID})
393+
} else {
394+
require.NoError(t, err, "get user by id")
395+
}
396+
397+
group, err := db.GetGroupByID(genCtx, member.GroupID)
398+
if errors.Is(err, sql.ErrNoRows) {
399+
group = Group(t, db, database.Group{ID: member.GroupID})
400+
} else {
401+
require.NoError(t, err, "get group by id")
402+
}
403+
404+
groupMember := database.GroupMember{
405+
UserID: user.ID,
406+
UserEmail: user.Email,
407+
UserUsername: user.Username,
408+
UserHashedPassword: user.HashedPassword,
409+
UserCreatedAt: user.CreatedAt,
410+
UserUpdatedAt: user.UpdatedAt,
411+
UserStatus: user.Status,
412+
UserRbacRoles: user.RBACRoles,
413+
UserLoginType: user.LoginType,
414+
UserAvatarUrl: user.AvatarURL,
415+
UserDeleted: user.Deleted,
416+
UserLastSeenAt: user.LastSeenAt,
417+
UserQuietHoursSchedule: user.QuietHoursSchedule,
418+
UserThemePreference: user.ThemePreference,
419+
UserName: user.Name,
420+
UserGithubComUserID: user.GithubComUserID,
421+
OrganizationID: group.OrganizationID,
422+
GroupName: group.Name,
423+
GroupID: group.ID,
424+
}
425+
426+
return groupMember
389427
}
390428

391429
// ProvisionerJob is a bit more involved to get the values such as "completedAt", "startedAt", "cancelledAt" set. ps

coderd/database/dbgen/dbgen_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,8 @@ func TestGenerator(t *testing.T) {
102102
db := dbmem.New()
103103
g := dbgen.Group(t, db, database.Group{})
104104
u := dbgen.User(t, db, database.User{})
105-
exp := []database.User{u}
106-
dbgen.GroupMember(t, db, database.GroupMember{GroupID: g.ID, UserID: u.ID})
105+
gm := dbgen.GroupMember(t, db, database.GroupMemberTable{GroupID: g.ID, UserID: u.ID})
106+
exp := []database.GroupMember{gm}
107107

108108
require.Equal(t, exp, must(db.GetGroupMembersByGroupID(context.Background(), g.ID)))
109109
})

coderd/database/dbmem/dbmem.go

Lines changed: 97 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ func New() database.Store {
6060
dbcryptKeys: make([]database.DBCryptKey, 0),
6161
externalAuthLinks: make([]database.ExternalAuthLink, 0),
6262
groups: make([]database.Group, 0),
63-
groupMembers: make([]database.GroupMember, 0),
63+
groupMembers: make([]database.GroupMemberTable, 0),
6464
auditLogs: make([]database.AuditLog, 0),
6565
files: make([]database.File, 0),
6666
gitSSHKey: make([]database.GitSSHKey, 0),
@@ -156,7 +156,7 @@ type data struct {
156156
files []database.File
157157
externalAuthLinks []database.ExternalAuthLink
158158
gitSSHKey []database.GitSSHKey
159-
groupMembers []database.GroupMember
159+
groupMembers []database.GroupMemberTable
160160
groups []database.Group
161161
jfrogXRayScans []database.JfrogXrayScan
162162
licenses []database.License
@@ -723,41 +723,68 @@ func (q *FakeQuerier) getOrganizationMemberNoLock(orgID uuid.UUID) []database.Or
723723
return members
724724
}
725725

726+
var ErrUserDeleted = xerrors.New("user deleted")
727+
728+
// getGroupMemberNoLock fetches a group member by user ID and group ID.
729+
func (q *FakeQuerier) getGroupMemberNoLock(ctx context.Context, userID, groupID uuid.UUID) (database.GroupMember, error) {
730+
groupName := "Everyone"
731+
orgID := groupID
732+
groupIsEveryone := q.isEveryoneGroup(groupID)
733+
if !groupIsEveryone {
734+
group, err := q.getGroupByIDNoLock(ctx, groupID)
735+
if err != nil {
736+
return database.GroupMember{}, err
737+
}
738+
groupName = group.Name
739+
orgID = group.OrganizationID
740+
}
741+
742+
user, err := q.getUserByIDNoLock(userID)
743+
if err != nil {
744+
return database.GroupMember{}, err
745+
}
746+
if user.Deleted {
747+
return database.GroupMember{}, ErrUserDeleted
748+
}
749+
750+
return database.GroupMember{
751+
UserID: user.ID,
752+
UserEmail: user.Email,
753+
UserUsername: user.Username,
754+
UserHashedPassword: user.HashedPassword,
755+
UserCreatedAt: user.CreatedAt,
756+
UserUpdatedAt: user.UpdatedAt,
757+
UserStatus: user.Status,
758+
UserRbacRoles: user.RBACRoles,
759+
UserLoginType: user.LoginType,
760+
UserAvatarUrl: user.AvatarURL,
761+
UserDeleted: user.Deleted,
762+
UserLastSeenAt: user.LastSeenAt,
763+
UserQuietHoursSchedule: user.QuietHoursSchedule,
764+
UserThemePreference: user.ThemePreference,
765+
UserName: user.Name,
766+
UserGithubComUserID: user.GithubComUserID,
767+
OrganizationID: orgID,
768+
GroupName: groupName,
769+
GroupID: groupID,
770+
}, nil
771+
}
772+
726773
// getEveryoneGroupMembersNoLock fetches all the users in an organization.
727-
func (q *FakeQuerier) getEveryoneGroupMembersNoLock(orgID uuid.UUID) []database.GroupMember {
774+
func (q *FakeQuerier) getEveryoneGroupMembersNoLock(ctx context.Context, orgID uuid.UUID) []database.GroupMember {
728775
var (
729776
everyone []database.GroupMember
730777
orgMembers = q.getOrganizationMemberNoLock(orgID)
731778
)
732779
for _, member := range orgMembers {
733-
user, err := q.getUserByIDNoLock(member.UserID)
780+
groupMember, err := q.getGroupMemberNoLock(ctx, member.UserID, orgID)
781+
if errors.Is(err, ErrUserDeleted) {
782+
continue
783+
}
734784
if err != nil {
735785
return nil
736786
}
737-
if user.Deleted {
738-
continue
739-
}
740-
everyone = append(everyone, database.GroupMember{
741-
UserID: user.ID,
742-
UserEmail: user.Email,
743-
UserUsername: user.Username,
744-
UserHashedPassword: user.HashedPassword,
745-
UserCreatedAt: user.CreatedAt,
746-
UserUpdatedAt: user.UpdatedAt,
747-
UserStatus: user.Status,
748-
UserRbacRoles: user.RBACRoles,
749-
UserLoginType: user.LoginType,
750-
UserAvatarUrl: user.AvatarURL,
751-
UserDeleted: user.Deleted,
752-
UserLastSeenAt: user.LastSeenAt,
753-
UserQuietHoursSchedule: user.QuietHoursSchedule,
754-
UserThemePreference: user.ThemePreference,
755-
UserName: user.Name,
756-
UserGithubComUserID: user.GithubComUserID,
757-
OrganizationID: orgID,
758-
GroupName: "Everyone",
759-
GroupID: orgID,
760-
})
787+
everyone = append(everyone, groupMember)
761788
}
762789
return everyone
763790
}
@@ -2509,31 +2536,59 @@ func (q *FakeQuerier) GetGroupByOrgAndName(_ context.Context, arg database.GetGr
25092536
return database.Group{}, sql.ErrNoRows
25102537
}
25112538

2512-
func (q *FakeQuerier) GetGroupMembers(_ context.Context) ([]database.GroupMember, error) {
2539+
func (q *FakeQuerier) GetGroupMembers(ctx context.Context) ([]database.GroupMember, error) {
25132540
q.mutex.RLock()
25142541
defer q.mutex.RUnlock()
25152542

2516-
out := make([]database.GroupMember, len(q.groupMembers))
2517-
copy(out, q.groupMembers)
2518-
return out, nil
2543+
members := make([]database.GroupMemberTable, 0, len(q.groupMembers))
2544+
members = append(members, q.groupMembers...)
2545+
for _, org := range q.organizations {
2546+
for _, user := range q.users {
2547+
members = append(members, database.GroupMemberTable{
2548+
UserID: user.ID,
2549+
GroupID: org.ID,
2550+
})
2551+
}
2552+
}
2553+
2554+
var groupMembers []database.GroupMember
2555+
for _, member := range members {
2556+
groupMember, err := q.getGroupMemberNoLock(ctx, member.UserID, member.GroupID)
2557+
if errors.Is(err, ErrUserDeleted) {
2558+
continue
2559+
}
2560+
if err != nil {
2561+
return nil, err
2562+
}
2563+
groupMembers = append(groupMembers, groupMember)
2564+
}
2565+
2566+
return groupMembers, nil
25192567
}
25202568

2521-
func (q *FakeQuerier) GetGroupMembersByGroupID(_ context.Context, id uuid.UUID) ([]database.GroupMember, error) {
2569+
func (q *FakeQuerier) GetGroupMembersByGroupID(ctx context.Context, id uuid.UUID) ([]database.GroupMember, error) {
25222570
q.mutex.RLock()
25232571
defer q.mutex.RUnlock()
25242572

25252573
if q.isEveryoneGroup(id) {
2526-
return q.getEveryoneGroupMembersNoLock(id), nil
2574+
return q.getEveryoneGroupMembersNoLock(ctx, id), nil
25272575
}
25282576

2529-
var members []database.GroupMember
2577+
var groupMembers []database.GroupMember
25302578
for _, member := range q.groupMembers {
2579+
groupMember, err := q.getGroupMemberNoLock(ctx, member.UserID, member.GroupID)
2580+
if errors.Is(err, ErrUserDeleted) {
2581+
continue
2582+
}
2583+
if err != nil {
2584+
return nil, err
2585+
}
25312586
if member.GroupID == id {
2532-
members = append(members, member)
2587+
groupMembers = append(groupMembers, groupMember)
25332588
}
25342589
}
25352590

2536-
return members, nil
2591+
return groupMembers, nil
25372592
}
25382593

25392594
func (q *FakeQuerier) GetGroupMembersCountByGroupID(ctx context.Context, groupID uuid.UUID) (int64, error) {
@@ -2561,15 +2616,15 @@ func (q *FakeQuerier) GetGroupsByOrganizationAndUserID(_ context.Context, arg da
25612616

25622617
q.mutex.RLock()
25632618
defer q.mutex.RUnlock()
2564-
var groupIds []uuid.UUID
2619+
var groupIDs []uuid.UUID
25652620
for _, member := range q.groupMembers {
25662621
if member.UserID == arg.UserID {
2567-
groupIds = append(groupIds, member.GroupID)
2622+
groupIDs = append(groupIDs, member.GroupID)
25682623
}
25692624
}
25702625
groups := []database.Group{}
25712626
for _, group := range q.groups {
2572-
if slices.Contains(groupIds, group.ID) && group.OrganizationID == arg.OrganizationID {
2627+
if slices.Contains(groupIDs, group.ID) && group.OrganizationID == arg.OrganizationID {
25732628
groups = append(groups, group)
25742629
}
25752630
}
@@ -6254,7 +6309,7 @@ func (q *FakeQuerier) InsertGroupMember(_ context.Context, arg database.InsertGr
62546309
}
62556310

62566311
//nolint:gosimple
6257-
q.groupMembers = append(q.groupMembers, database.GroupMember{
6312+
q.groupMembers = append(q.groupMembers, database.GroupMemberTable{
62586313
GroupID: arg.GroupID,
62596314
UserID: arg.UserID,
62606315
})
@@ -6794,7 +6849,7 @@ func (q *FakeQuerier) InsertUserGroupsByName(_ context.Context, arg database.Ins
67946849
}
67956850

67966851
for _, groupID := range groupIDs {
6797-
q.groupMembers = append(q.groupMembers, database.GroupMember{
6852+
q.groupMembers = append(q.groupMembers, database.GroupMemberTable{
67986853
UserID: arg.UserID,
67996854
GroupID: groupID,
68006855
})

0 commit comments

Comments
 (0)