Skip to content

Commit a911dda

Browse files
authored
fix: access GetUserByID in database fake without lock to resolve race (#5909)
See: https://github.com/coder/coder/actions/runs/4038615993/jobs/6942750837
1 parent 7ad8750 commit a911dda

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

coderd/database/databasefake/databasefake.go

+11-6
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,11 @@ func (q *fakeQuerier) GetUserByID(_ context.Context, id uuid.UUID) (database.Use
556556
q.mutex.RLock()
557557
defer q.mutex.RUnlock()
558558

559+
return q.getUserByIDNoLock(id)
560+
}
561+
562+
// getUserByIDNoLock is used by other functions in the database fake.
563+
func (q *fakeQuerier) getUserByIDNoLock(id uuid.UUID) (database.User, error) {
559564
for _, user := range q.users {
560565
if user.ID == id {
561566
return user, nil
@@ -891,7 +896,7 @@ func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.
891896
}
892897

893898
if arg.OwnerUsername != "" {
894-
owner, err := q.GetUserByID(ctx, workspace.OwnerID)
899+
owner, err := q.getUserByIDNoLock(workspace.OwnerID)
895900
if err == nil && !strings.EqualFold(arg.OwnerUsername, owner.Username) {
896901
continue
897902
}
@@ -2033,7 +2038,7 @@ func (q *fakeQuerier) GetTemplateUserRoles(_ context.Context, id uuid.UUID) ([]d
20332038

20342039
users := make([]database.TemplateUser, 0, len(template.UserACL))
20352040
for k, v := range template.UserACL {
2036-
user, err := q.GetUserByID(context.Background(), uuid.MustParse(k))
2041+
user, err := q.getUserByIDNoLock(uuid.MustParse(k))
20372042
if err != nil && xerrors.Is(err, sql.ErrNoRows) {
20382043
return nil, xerrors.Errorf("get user by ID: %w", err)
20392044
}
@@ -3593,7 +3598,7 @@ func (q *fakeQuerier) DeleteGitSSHKey(_ context.Context, userID uuid.UUID) error
35933598
return sql.ErrNoRows
35943599
}
35953600

3596-
func (q *fakeQuerier) GetAuditLogsOffset(ctx context.Context, arg database.GetAuditLogsOffsetParams) ([]database.GetAuditLogsOffsetRow, error) {
3601+
func (q *fakeQuerier) GetAuditLogsOffset(_ context.Context, arg database.GetAuditLogsOffsetParams) ([]database.GetAuditLogsOffsetRow, error) {
35973602
if err := validateDatabaseType(arg); err != nil {
35983603
return nil, err
35993604
}
@@ -3619,13 +3624,13 @@ func (q *fakeQuerier) GetAuditLogsOffset(ctx context.Context, arg database.GetAu
36193624
continue
36203625
}
36213626
if arg.Username != "" {
3622-
user, err := q.GetUserByID(context.Background(), alog.UserID)
3627+
user, err := q.getUserByIDNoLock(alog.UserID)
36233628
if err == nil && !strings.EqualFold(arg.Username, user.Username) {
36243629
continue
36253630
}
36263631
}
36273632
if arg.Email != "" {
3628-
user, err := q.GetUserByID(context.Background(), alog.UserID)
3633+
user, err := q.getUserByIDNoLock(alog.UserID)
36293634
if err == nil && !strings.EqualFold(arg.Email, user.Email) {
36303635
continue
36313636
}
@@ -3647,7 +3652,7 @@ func (q *fakeQuerier) GetAuditLogsOffset(ctx context.Context, arg database.GetAu
36473652
}
36483653
}
36493654

3650-
user, err := q.GetUserByID(ctx, alog.UserID)
3655+
user, err := q.getUserByIDNoLock(alog.UserID)
36513656
userValid := err == nil
36523657

36533658
logs = append(logs, database.GetAuditLogsOffsetRow{

0 commit comments

Comments
 (0)