From b4aced90c423269c575d16b44614d05c4bf0801f Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Thu, 20 Apr 2023 11:31:41 +0000 Subject: [PATCH] chore: fix deadlock in dbfake and incorrect lock types I manually went through every single dbfake function and ensured it has the correct lock type depending on whether it writes or only reads. There were a surprising amount of methods that had the wrong lock type (Lock when only reading, or RLock when writing (!!!)). This also manually fixes every method that acquires a RLock and then calls a method that also acquires it's own RLock to use noLock methods instead. You cannot rely on acquiring a RLock twice in the same goroutine as RWMutex prioritizes any waiting Lock calls. I tried writing a ruleguard rule for this but because of limitations in ruleguard it doesn't seem possible. --- coderd/database/dbfake/databasefake.go | 122 +++++++++++++++---------- coderd/rbac/authz_internal_test.go | 2 +- 2 files changed, 77 insertions(+), 47 deletions(-) diff --git a/coderd/database/dbfake/databasefake.go b/coderd/database/dbfake/databasefake.go index 9d37f195dd01b..59f42ce864e2d 100644 --- a/coderd/database/dbfake/databasefake.go +++ b/coderd/database/dbfake/databasefake.go @@ -332,7 +332,7 @@ func (q *fakeQuerier) GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx context.C } // Get resources for build. - resources, err := q.GetWorkspaceResourcesByJobID(ctx, workspaceBuild.JobID) + resources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, workspaceBuild.JobID) if err != nil { return nil, xerrors.Errorf("get workspace resources: %w", err) } @@ -345,7 +345,7 @@ func (q *fakeQuerier) GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx context.C resourceIDs[i] = resource.ID } - agents, err := q.GetWorkspaceAgentsByResourceIDs(ctx, resourceIDs) + agents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, resourceIDs) if err != nil { return nil, xerrors.Errorf("get workspace agents: %w", err) } @@ -435,8 +435,8 @@ func (q *fakeQuerier) InsertWorkspaceAgentStat(_ context.Context, p database.Ins } func (q *fakeQuerier) GetTemplateDAUs(_ context.Context, templateID uuid.UUID) ([]database.GetTemplateDAUsRow, error) { - q.mutex.Lock() - defer q.mutex.Unlock() + q.mutex.RLock() + defer q.mutex.RUnlock() seens := make(map[time.Time]map[uuid.UUID]struct{}) @@ -478,8 +478,8 @@ func (q *fakeQuerier) GetTemplateDAUs(_ context.Context, templateID uuid.UUID) ( } func (q *fakeQuerier) GetDeploymentDAUs(_ context.Context) ([]database.GetDeploymentDAUsRow, error) { - q.mutex.Lock() - defer q.mutex.Unlock() + q.mutex.RLock() + defer q.mutex.RUnlock() seens := make(map[time.Time]map[uuid.UUID]struct{}) @@ -571,8 +571,8 @@ func (q *fakeQuerier) GetTemplateAverageBuildTime(ctx context.Context, arg datab } func (q *fakeQuerier) ParameterValue(_ context.Context, id uuid.UUID) (database.ParameterValue, error) { - q.mutex.Lock() - defer q.mutex.Unlock() + q.mutex.RLock() + defer q.mutex.RUnlock() for _, parameterValue := range q.parameterValues { if parameterValue.ID != id { @@ -1181,7 +1181,7 @@ func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database. return nil, xerrors.Errorf("get latest build: %w", err) } - job, err := q.GetProvisionerJobByID(ctx, build.JobID) + job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) if err != nil { return nil, xerrors.Errorf("get provisioner job: %w", err) } @@ -1270,12 +1270,12 @@ func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database. return nil, xerrors.Errorf("get latest build: %w", err) } - job, err := q.GetProvisionerJobByID(ctx, build.JobID) + job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) if err != nil { return nil, xerrors.Errorf("get provisioner job: %w", err) } - workspaceResources, err := q.GetWorkspaceResourcesByJobID(ctx, job.ID) + workspaceResources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, job.ID) if err != nil { return nil, xerrors.Errorf("get workspace resources: %w", err) } @@ -1285,7 +1285,7 @@ func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database. workspaceResourceIDs = append(workspaceResourceIDs, wr.ID) } - workspaceAgents, err := q.GetWorkspaceAgentsByResourceIDs(ctx, workspaceResourceIDs) + workspaceAgents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, workspaceResourceIDs) if err != nil { return nil, xerrors.Errorf("get workspace agents: %w", err) } @@ -1395,10 +1395,14 @@ func convertToWorkspaceRows(workspaces []database.Workspace, count int64) []data return rows } -func (q *fakeQuerier) GetWorkspaceByID(_ context.Context, id uuid.UUID) (database.Workspace, error) { +func (q *fakeQuerier) GetWorkspaceByID(ctx context.Context, id uuid.UUID) (database.Workspace, error) { q.mutex.RLock() defer q.mutex.RUnlock() + return q.getWorkspaceByIDNoLock(ctx, id) +} + +func (q *fakeQuerier) getWorkspaceByIDNoLock(_ context.Context, id uuid.UUID) (database.Workspace, error) { for _, workspace := range q.workspaces { if workspace.ID == id { return workspace, nil @@ -1407,10 +1411,14 @@ func (q *fakeQuerier) GetWorkspaceByID(_ context.Context, id uuid.UUID) (databas return database.Workspace{}, sql.ErrNoRows } -func (q *fakeQuerier) GetWorkspaceByAgentID(_ context.Context, agentID uuid.UUID) (database.Workspace, error) { +func (q *fakeQuerier) GetWorkspaceByAgentID(ctx context.Context, agentID uuid.UUID) (database.Workspace, error) { q.mutex.RLock() defer q.mutex.RUnlock() + return q.getWorkspaceByAgentIDNoLock(ctx, agentID) +} + +func (q *fakeQuerier) getWorkspaceByAgentIDNoLock(_ context.Context, agentID uuid.UUID) (database.Workspace, error) { var agent database.WorkspaceAgent for _, _agent := range q.workspaceAgents { if _agent.ID == agentID { @@ -1496,7 +1504,7 @@ func (q *fakeQuerier) GetWorkspaceByWorkspaceAppID(_ context.Context, workspaceA for _, workspaceApp := range q.workspaceApps { workspaceApp := workspaceApp if workspaceApp.ID == workspaceAppID { - return q.GetWorkspaceByAgentID(context.Background(), workspaceApp.AgentID) + return q.getWorkspaceByAgentIDNoLock(context.Background(), workspaceApp.AgentID) } } return database.Workspace{}, sql.ErrNoRows @@ -1547,10 +1555,14 @@ func (q *fakeQuerier) GetWorkspaceAppsByAgentIDs(_ context.Context, ids []uuid.U return apps, nil } -func (q *fakeQuerier) GetWorkspaceBuildByID(_ context.Context, id uuid.UUID) (database.WorkspaceBuild, error) { +func (q *fakeQuerier) GetWorkspaceBuildByID(ctx context.Context, id uuid.UUID) (database.WorkspaceBuild, error) { q.mutex.RLock() defer q.mutex.RUnlock() + return q.getWorkspaceBuildByIDNoLock(ctx, id) +} + +func (q *fakeQuerier) getWorkspaceBuildByIDNoLock(_ context.Context, id uuid.UUID) (database.WorkspaceBuild, error) { for _, history := range q.workspaceBuilds { if history.ID == id { return history, nil @@ -2359,7 +2371,7 @@ func (q *fakeQuerier) GetTemplateGroupRoles(_ context.Context, id uuid.UUID) ([] groups := make([]database.TemplateGroup, 0, len(template.GroupACL)) for k, v := range template.GroupACL { - group, err := q.GetGroupByID(context.Background(), uuid.MustParse(k)) + group, err := q.getGroupByIDNoLock(context.Background(), uuid.MustParse(k)) if err != nil && !xerrors.Is(err, sql.ErrNoRows) { return nil, xerrors.Errorf("get group by ID: %w", err) } @@ -2490,10 +2502,14 @@ func (q *fakeQuerier) GetWorkspaceAgentByAuthToken(_ context.Context, authToken return database.WorkspaceAgent{}, sql.ErrNoRows } -func (q *fakeQuerier) GetWorkspaceAgentByID(_ context.Context, id uuid.UUID) (database.WorkspaceAgent, error) { +func (q *fakeQuerier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (database.WorkspaceAgent, error) { q.mutex.RLock() defer q.mutex.RUnlock() + return q.getWorkspaceAgentByIDNoLock(ctx, id) +} + +func (q *fakeQuerier) getWorkspaceAgentByIDNoLock(_ context.Context, id uuid.UUID) (database.WorkspaceAgent, error) { // The schema sorts this by created at, so we iterate the array backwards. for i := len(q.workspaceAgents) - 1; i >= 0; i-- { agent := q.workspaceAgents[i] @@ -2518,10 +2534,14 @@ func (q *fakeQuerier) GetWorkspaceAgentByInstanceID(_ context.Context, instanceI return database.WorkspaceAgent{}, sql.ErrNoRows } -func (q *fakeQuerier) GetWorkspaceAgentsByResourceIDs(_ context.Context, resourceIDs []uuid.UUID) ([]database.WorkspaceAgent, error) { +func (q *fakeQuerier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, resourceIDs []uuid.UUID) ([]database.WorkspaceAgent, error) { q.mutex.RLock() defer q.mutex.RUnlock() + return q.getWorkspaceAgentsByResourceIDsNoLock(ctx, resourceIDs) +} + +func (q *fakeQuerier) getWorkspaceAgentsByResourceIDsNoLock(_ context.Context, resourceIDs []uuid.UUID) ([]database.WorkspaceAgent, error) { workspaceAgents := make([]database.WorkspaceAgent, 0) for _, agent := range q.workspaceAgents { for _, resourceID := range resourceIDs { @@ -2596,10 +2616,14 @@ func (q *fakeQuerier) GetWorkspaceResourceByID(_ context.Context, id uuid.UUID) return database.WorkspaceResource{}, sql.ErrNoRows } -func (q *fakeQuerier) GetWorkspaceResourcesByJobID(_ context.Context, jobID uuid.UUID) ([]database.WorkspaceResource, error) { +func (q *fakeQuerier) GetWorkspaceResourcesByJobID(ctx context.Context, jobID uuid.UUID) ([]database.WorkspaceResource, error) { q.mutex.RLock() defer q.mutex.RUnlock() + return q.getWorkspaceResourcesByJobIDNoLock(ctx, jobID) +} + +func (q *fakeQuerier) getWorkspaceResourcesByJobIDNoLock(_ context.Context, jobID uuid.UUID) ([]database.WorkspaceResource, error) { resources := make([]database.WorkspaceResource, 0) for _, resource := range q.workspaceResources { if resource.JobID != jobID { @@ -3674,8 +3698,8 @@ func (q *fakeQuerier) GetWorkspaceAgentStartupLogsAfter(_ context.Context, arg d return nil, err } - q.mutex.Lock() - defer q.mutex.Unlock() + q.mutex.RLock() + defer q.mutex.RUnlock() logs := []database.WorkspaceAgentStartupLog{} for _, log := range q.workspaceAgentLogs { @@ -4051,13 +4075,13 @@ func (q *fakeQuerier) GetWorkspaceAgentStatsAndLabels(ctx context.Context, creat stat.Username = user.Username - workspace, err := q.GetWorkspaceByID(ctx, agentStat.WorkspaceID) + workspace, err := q.getWorkspaceByIDNoLock(ctx, agentStat.WorkspaceID) if err != nil { return nil, err } stat.WorkspaceName = workspace.Name - agent, err := q.GetWorkspaceAgentByID(ctx, agentStat.AgentID) + agent, err := q.getWorkspaceAgentByIDNoLock(ctx, agentStat.AgentID) if err != nil { return nil, err } @@ -4403,7 +4427,7 @@ func (q *fakeQuerier) GetAuditLogsOffset(_ context.Context, arg database.GetAudi } } if arg.BuildReason != "" { - workspaceBuild, err := q.GetWorkspaceBuildByID(context.Background(), alog.ResourceID) + workspaceBuild, err := q.getWorkspaceBuildByIDNoLock(context.Background(), alog.ResourceID) if err == nil && !strings.EqualFold(arg.BuildReason, string(workspaceBuild.Reason)) { continue } @@ -4497,8 +4521,8 @@ func (q *fakeQuerier) GetDERPMeshKey(_ context.Context) (string, error) { } func (q *fakeQuerier) UpsertLastUpdateCheck(_ context.Context, data string) error { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() q.lastUpdateCheck = []byte(data) return nil @@ -4672,8 +4696,8 @@ func (q *fakeQuerier) GetUserLinkByUserIDLoginType(_ context.Context, params dat } func (q *fakeQuerier) InsertUserLink(_ context.Context, args database.InsertUserLinkParams) (database.UserLink, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() //nolint:gosimple link := database.UserLink{ @@ -4695,8 +4719,8 @@ func (q *fakeQuerier) UpdateUserLinkedID(_ context.Context, params database.Upda return database.UserLink{}, err } - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for i, link := range q.userLinks { if link.UserID == params.UserID && link.LoginType == params.LoginType { @@ -4715,8 +4739,8 @@ func (q *fakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUs return database.UserLink{}, err } - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for i, link := range q.userLinks { if link.UserID == params.UserID && link.LoginType == params.LoginType { @@ -4732,10 +4756,14 @@ func (q *fakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUs return database.UserLink{}, sql.ErrNoRows } -func (q *fakeQuerier) GetGroupByID(_ context.Context, id uuid.UUID) (database.Group, error) { +func (q *fakeQuerier) GetGroupByID(ctx context.Context, id uuid.UUID) (database.Group, error) { q.mutex.RLock() defer q.mutex.RUnlock() + return q.getGroupByIDNoLock(ctx, id) +} + +func (q *fakeQuerier) getGroupByIDNoLock(_ context.Context, id uuid.UUID) (database.Group, error) { for _, group := range q.groups { if group.ID == id { return group, nil @@ -4776,8 +4804,8 @@ func (q *fakeQuerier) InsertGroup(_ context.Context, arg database.InsertGroupPar return database.Group{}, err } - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() for _, group := range q.groups { if group.OrganizationID == arg.OrganizationID && @@ -4995,8 +5023,9 @@ func (q *fakeQuerier) UpdateGitAuthLink(_ context.Context, arg database.UpdateGi } func (q *fakeQuerier) GetQuotaAllowanceForUser(_ context.Context, userID uuid.UUID) (int64, error) { - q.mutex.Lock() - defer q.mutex.Unlock() + q.mutex.RLock() + defer q.mutex.RUnlock() + var sum int64 for _, member := range q.groupMembers { if member.UserID != userID { @@ -5012,8 +5041,9 @@ func (q *fakeQuerier) GetQuotaAllowanceForUser(_ context.Context, userID uuid.UU } func (q *fakeQuerier) GetQuotaConsumedForUser(_ context.Context, userID uuid.UUID) (int64, error) { - q.mutex.Lock() - defer q.mutex.Unlock() + q.mutex.RLock() + defer q.mutex.RUnlock() + var sum int64 for _, workspace := range q.workspaces { if workspace.OwnerID != userID { @@ -5072,8 +5102,8 @@ func (q *fakeQuerier) UpdateWorkspaceAgentStartupLogOverflowByID(_ context.Conte } func (q *fakeQuerier) GetWorkspaceProxies(_ context.Context) ([]database.WorkspaceProxy, error) { - q.mutex.Lock() - defer q.mutex.Unlock() + q.mutex.RLock() + defer q.mutex.RUnlock() cpy := make([]database.WorkspaceProxy, 0, len(q.workspaceProxies)) @@ -5086,8 +5116,8 @@ func (q *fakeQuerier) GetWorkspaceProxies(_ context.Context) ([]database.Workspa } func (q *fakeQuerier) GetWorkspaceProxyByID(_ context.Context, id uuid.UUID) (database.WorkspaceProxy, error) { - q.mutex.Lock() - defer q.mutex.Unlock() + q.mutex.RLock() + defer q.mutex.RUnlock() for _, proxy := range q.workspaceProxies { if proxy.ID == id { @@ -5098,8 +5128,8 @@ func (q *fakeQuerier) GetWorkspaceProxyByID(_ context.Context, id uuid.UUID) (da } func (q *fakeQuerier) GetWorkspaceProxyByHostname(_ context.Context, hostname string) (database.WorkspaceProxy, error) { - q.mutex.Lock() - defer q.mutex.Unlock() + q.mutex.RLock() + defer q.mutex.RUnlock() // Return zero rows if this is called with a non-sanitized hostname. The SQL // version of this query does the same thing. diff --git a/coderd/rbac/authz_internal_test.go b/coderd/rbac/authz_internal_test.go index 6a52b12e9b25d..ad4e180775f91 100644 --- a/coderd/rbac/authz_internal_test.go +++ b/coderd/rbac/authz_internal_test.go @@ -68,7 +68,7 @@ func TestFilterError(t *testing.T) { auth := &MockAuthorizer{ AuthorizeFunc: func(ctx context.Context, subject Subject, action Action, object Object) error { - // Authorize func always returns nil, unless the context is cancelled. + // Authorize func always returns nil, unless the context is canceled. return ctx.Err() }, }