From 98801b9987936c968b6bf29dcbb12aa67eb75a66 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Sun, 29 Jan 2023 21:52:32 +0000 Subject: [PATCH] fix: access GetUserByID in database fake without lock to resolve race See: https://github.com/coder/coder/actions/runs/4038615993/jobs/6942750837 --- coderd/database/databasefake/databasefake.go | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index 40aba2a8eeed1..b4ca053b0cc73 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -556,6 +556,11 @@ func (q *fakeQuerier) GetUserByID(_ context.Context, id uuid.UUID) (database.Use q.mutex.RLock() defer q.mutex.RUnlock() + return q.getUserByIDNoLock(id) +} + +// getUserByIDNoLock is used by other functions in the database fake. +func (q *fakeQuerier) getUserByIDNoLock(id uuid.UUID) (database.User, error) { for _, user := range q.users { if user.ID == id { return user, nil @@ -891,7 +896,7 @@ func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database. } if arg.OwnerUsername != "" { - owner, err := q.GetUserByID(ctx, workspace.OwnerID) + owner, err := q.getUserByIDNoLock(workspace.OwnerID) if err == nil && !strings.EqualFold(arg.OwnerUsername, owner.Username) { continue } @@ -2033,7 +2038,7 @@ func (q *fakeQuerier) GetTemplateUserRoles(_ context.Context, id uuid.UUID) ([]d users := make([]database.TemplateUser, 0, len(template.UserACL)) for k, v := range template.UserACL { - user, err := q.GetUserByID(context.Background(), uuid.MustParse(k)) + user, err := q.getUserByIDNoLock(uuid.MustParse(k)) if err != nil && xerrors.Is(err, sql.ErrNoRows) { return nil, xerrors.Errorf("get user by ID: %w", err) } @@ -3593,7 +3598,7 @@ func (q *fakeQuerier) DeleteGitSSHKey(_ context.Context, userID uuid.UUID) error return sql.ErrNoRows } -func (q *fakeQuerier) GetAuditLogsOffset(ctx context.Context, arg database.GetAuditLogsOffsetParams) ([]database.GetAuditLogsOffsetRow, error) { +func (q *fakeQuerier) GetAuditLogsOffset(_ context.Context, arg database.GetAuditLogsOffsetParams) ([]database.GetAuditLogsOffsetRow, error) { if err := validateDatabaseType(arg); err != nil { return nil, err } @@ -3619,13 +3624,13 @@ func (q *fakeQuerier) GetAuditLogsOffset(ctx context.Context, arg database.GetAu continue } if arg.Username != "" { - user, err := q.GetUserByID(context.Background(), alog.UserID) + user, err := q.getUserByIDNoLock(alog.UserID) if err == nil && !strings.EqualFold(arg.Username, user.Username) { continue } } if arg.Email != "" { - user, err := q.GetUserByID(context.Background(), alog.UserID) + user, err := q.getUserByIDNoLock(alog.UserID) if err == nil && !strings.EqualFold(arg.Email, user.Email) { continue } @@ -3647,7 +3652,7 @@ func (q *fakeQuerier) GetAuditLogsOffset(ctx context.Context, arg database.GetAu } } - user, err := q.GetUserByID(ctx, alog.UserID) + user, err := q.getUserByIDNoLock(alog.UserID) userValid := err == nil logs = append(logs, database.GetAuditLogsOffsetRow{