Skip to content

Commit 528a068

Browse files
authored
chore: fix deadlock in dbfake and incorrect lock types (#7218)
I manually went through every single dbfake function and ensured it has the correct lock type depending on whether it writes or only reads. There were a surprising amount of methods that had the wrong lock type (Lock when only reading, or RLock when writing (!!!)). This also manually fixes every method that acquires a RLock and then calls a method that also acquires it's own RLock to use noLock methods instead. You cannot rely on acquiring a RLock twice in the same goroutine as RWMutex prioritizes any waiting Lock calls. I tried writing a ruleguard rule for this but because of limitations in ruleguard it doesn't seem possible.
1 parent 5f5edb1 commit 528a068

File tree

2 files changed

+77
-47
lines changed

2 files changed

+77
-47
lines changed

coderd/database/dbfake/databasefake.go

Lines changed: 76 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ func (q *fakeQuerier) GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx context.C
332332
}
333333

334334
// Get resources for build.
335-
resources, err := q.GetWorkspaceResourcesByJobID(ctx, workspaceBuild.JobID)
335+
resources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, workspaceBuild.JobID)
336336
if err != nil {
337337
return nil, xerrors.Errorf("get workspace resources: %w", err)
338338
}
@@ -345,7 +345,7 @@ func (q *fakeQuerier) GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx context.C
345345
resourceIDs[i] = resource.ID
346346
}
347347

348-
agents, err := q.GetWorkspaceAgentsByResourceIDs(ctx, resourceIDs)
348+
agents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, resourceIDs)
349349
if err != nil {
350350
return nil, xerrors.Errorf("get workspace agents: %w", err)
351351
}
@@ -435,8 +435,8 @@ func (q *fakeQuerier) InsertWorkspaceAgentStat(_ context.Context, p database.Ins
435435
}
436436

437437
func (q *fakeQuerier) GetTemplateDAUs(_ context.Context, templateID uuid.UUID) ([]database.GetTemplateDAUsRow, error) {
438-
q.mutex.Lock()
439-
defer q.mutex.Unlock()
438+
q.mutex.RLock()
439+
defer q.mutex.RUnlock()
440440

441441
seens := make(map[time.Time]map[uuid.UUID]struct{})
442442

@@ -478,8 +478,8 @@ func (q *fakeQuerier) GetTemplateDAUs(_ context.Context, templateID uuid.UUID) (
478478
}
479479

480480
func (q *fakeQuerier) GetDeploymentDAUs(_ context.Context) ([]database.GetDeploymentDAUsRow, error) {
481-
q.mutex.Lock()
482-
defer q.mutex.Unlock()
481+
q.mutex.RLock()
482+
defer q.mutex.RUnlock()
483483

484484
seens := make(map[time.Time]map[uuid.UUID]struct{})
485485

@@ -571,8 +571,8 @@ func (q *fakeQuerier) GetTemplateAverageBuildTime(ctx context.Context, arg datab
571571
}
572572

573573
func (q *fakeQuerier) ParameterValue(_ context.Context, id uuid.UUID) (database.ParameterValue, error) {
574-
q.mutex.Lock()
575-
defer q.mutex.Unlock()
574+
q.mutex.RLock()
575+
defer q.mutex.RUnlock()
576576

577577
for _, parameterValue := range q.parameterValues {
578578
if parameterValue.ID != id {
@@ -1181,7 +1181,7 @@ func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.
11811181
return nil, xerrors.Errorf("get latest build: %w", err)
11821182
}
11831183

1184-
job, err := q.GetProvisionerJobByID(ctx, build.JobID)
1184+
job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID)
11851185
if err != nil {
11861186
return nil, xerrors.Errorf("get provisioner job: %w", err)
11871187
}
@@ -1270,12 +1270,12 @@ func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.
12701270
return nil, xerrors.Errorf("get latest build: %w", err)
12711271
}
12721272

1273-
job, err := q.GetProvisionerJobByID(ctx, build.JobID)
1273+
job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID)
12741274
if err != nil {
12751275
return nil, xerrors.Errorf("get provisioner job: %w", err)
12761276
}
12771277

1278-
workspaceResources, err := q.GetWorkspaceResourcesByJobID(ctx, job.ID)
1278+
workspaceResources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, job.ID)
12791279
if err != nil {
12801280
return nil, xerrors.Errorf("get workspace resources: %w", err)
12811281
}
@@ -1285,7 +1285,7 @@ func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.
12851285
workspaceResourceIDs = append(workspaceResourceIDs, wr.ID)
12861286
}
12871287

1288-
workspaceAgents, err := q.GetWorkspaceAgentsByResourceIDs(ctx, workspaceResourceIDs)
1288+
workspaceAgents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, workspaceResourceIDs)
12891289
if err != nil {
12901290
return nil, xerrors.Errorf("get workspace agents: %w", err)
12911291
}
@@ -1395,10 +1395,14 @@ func convertToWorkspaceRows(workspaces []database.Workspace, count int64) []data
13951395
return rows
13961396
}
13971397

1398-
func (q *fakeQuerier) GetWorkspaceByID(_ context.Context, id uuid.UUID) (database.Workspace, error) {
1398+
func (q *fakeQuerier) GetWorkspaceByID(ctx context.Context, id uuid.UUID) (database.Workspace, error) {
13991399
q.mutex.RLock()
14001400
defer q.mutex.RUnlock()
14011401

1402+
return q.getWorkspaceByIDNoLock(ctx, id)
1403+
}
1404+
1405+
func (q *fakeQuerier) getWorkspaceByIDNoLock(_ context.Context, id uuid.UUID) (database.Workspace, error) {
14021406
for _, workspace := range q.workspaces {
14031407
if workspace.ID == id {
14041408
return workspace, nil
@@ -1407,10 +1411,14 @@ func (q *fakeQuerier) GetWorkspaceByID(_ context.Context, id uuid.UUID) (databas
14071411
return database.Workspace{}, sql.ErrNoRows
14081412
}
14091413

1410-
func (q *fakeQuerier) GetWorkspaceByAgentID(_ context.Context, agentID uuid.UUID) (database.Workspace, error) {
1414+
func (q *fakeQuerier) GetWorkspaceByAgentID(ctx context.Context, agentID uuid.UUID) (database.Workspace, error) {
14111415
q.mutex.RLock()
14121416
defer q.mutex.RUnlock()
14131417

1418+
return q.getWorkspaceByAgentIDNoLock(ctx, agentID)
1419+
}
1420+
1421+
func (q *fakeQuerier) getWorkspaceByAgentIDNoLock(_ context.Context, agentID uuid.UUID) (database.Workspace, error) {
14141422
var agent database.WorkspaceAgent
14151423
for _, _agent := range q.workspaceAgents {
14161424
if _agent.ID == agentID {
@@ -1496,7 +1504,7 @@ func (q *fakeQuerier) GetWorkspaceByWorkspaceAppID(_ context.Context, workspaceA
14961504
for _, workspaceApp := range q.workspaceApps {
14971505
workspaceApp := workspaceApp
14981506
if workspaceApp.ID == workspaceAppID {
1499-
return q.GetWorkspaceByAgentID(context.Background(), workspaceApp.AgentID)
1507+
return q.getWorkspaceByAgentIDNoLock(context.Background(), workspaceApp.AgentID)
15001508
}
15011509
}
15021510
return database.Workspace{}, sql.ErrNoRows
@@ -1547,10 +1555,14 @@ func (q *fakeQuerier) GetWorkspaceAppsByAgentIDs(_ context.Context, ids []uuid.U
15471555
return apps, nil
15481556
}
15491557

1550-
func (q *fakeQuerier) GetWorkspaceBuildByID(_ context.Context, id uuid.UUID) (database.WorkspaceBuild, error) {
1558+
func (q *fakeQuerier) GetWorkspaceBuildByID(ctx context.Context, id uuid.UUID) (database.WorkspaceBuild, error) {
15511559
q.mutex.RLock()
15521560
defer q.mutex.RUnlock()
15531561

1562+
return q.getWorkspaceBuildByIDNoLock(ctx, id)
1563+
}
1564+
1565+
func (q *fakeQuerier) getWorkspaceBuildByIDNoLock(_ context.Context, id uuid.UUID) (database.WorkspaceBuild, error) {
15541566
for _, history := range q.workspaceBuilds {
15551567
if history.ID == id {
15561568
return history, nil
@@ -2359,7 +2371,7 @@ func (q *fakeQuerier) GetTemplateGroupRoles(_ context.Context, id uuid.UUID) ([]
23592371

23602372
groups := make([]database.TemplateGroup, 0, len(template.GroupACL))
23612373
for k, v := range template.GroupACL {
2362-
group, err := q.GetGroupByID(context.Background(), uuid.MustParse(k))
2374+
group, err := q.getGroupByIDNoLock(context.Background(), uuid.MustParse(k))
23632375
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
23642376
return nil, xerrors.Errorf("get group by ID: %w", err)
23652377
}
@@ -2490,10 +2502,14 @@ func (q *fakeQuerier) GetWorkspaceAgentByAuthToken(_ context.Context, authToken
24902502
return database.WorkspaceAgent{}, sql.ErrNoRows
24912503
}
24922504

2493-
func (q *fakeQuerier) GetWorkspaceAgentByID(_ context.Context, id uuid.UUID) (database.WorkspaceAgent, error) {
2505+
func (q *fakeQuerier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (database.WorkspaceAgent, error) {
24942506
q.mutex.RLock()
24952507
defer q.mutex.RUnlock()
24962508

2509+
return q.getWorkspaceAgentByIDNoLock(ctx, id)
2510+
}
2511+
2512+
func (q *fakeQuerier) getWorkspaceAgentByIDNoLock(_ context.Context, id uuid.UUID) (database.WorkspaceAgent, error) {
24972513
// The schema sorts this by created at, so we iterate the array backwards.
24982514
for i := len(q.workspaceAgents) - 1; i >= 0; i-- {
24992515
agent := q.workspaceAgents[i]
@@ -2518,10 +2534,14 @@ func (q *fakeQuerier) GetWorkspaceAgentByInstanceID(_ context.Context, instanceI
25182534
return database.WorkspaceAgent{}, sql.ErrNoRows
25192535
}
25202536

2521-
func (q *fakeQuerier) GetWorkspaceAgentsByResourceIDs(_ context.Context, resourceIDs []uuid.UUID) ([]database.WorkspaceAgent, error) {
2537+
func (q *fakeQuerier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, resourceIDs []uuid.UUID) ([]database.WorkspaceAgent, error) {
25222538
q.mutex.RLock()
25232539
defer q.mutex.RUnlock()
25242540

2541+
return q.getWorkspaceAgentsByResourceIDsNoLock(ctx, resourceIDs)
2542+
}
2543+
2544+
func (q *fakeQuerier) getWorkspaceAgentsByResourceIDsNoLock(_ context.Context, resourceIDs []uuid.UUID) ([]database.WorkspaceAgent, error) {
25252545
workspaceAgents := make([]database.WorkspaceAgent, 0)
25262546
for _, agent := range q.workspaceAgents {
25272547
for _, resourceID := range resourceIDs {
@@ -2596,10 +2616,14 @@ func (q *fakeQuerier) GetWorkspaceResourceByID(_ context.Context, id uuid.UUID)
25962616
return database.WorkspaceResource{}, sql.ErrNoRows
25972617
}
25982618

2599-
func (q *fakeQuerier) GetWorkspaceResourcesByJobID(_ context.Context, jobID uuid.UUID) ([]database.WorkspaceResource, error) {
2619+
func (q *fakeQuerier) GetWorkspaceResourcesByJobID(ctx context.Context, jobID uuid.UUID) ([]database.WorkspaceResource, error) {
26002620
q.mutex.RLock()
26012621
defer q.mutex.RUnlock()
26022622

2623+
return q.getWorkspaceResourcesByJobIDNoLock(ctx, jobID)
2624+
}
2625+
2626+
func (q *fakeQuerier) getWorkspaceResourcesByJobIDNoLock(_ context.Context, jobID uuid.UUID) ([]database.WorkspaceResource, error) {
26032627
resources := make([]database.WorkspaceResource, 0)
26042628
for _, resource := range q.workspaceResources {
26052629
if resource.JobID != jobID {
@@ -3674,8 +3698,8 @@ func (q *fakeQuerier) GetWorkspaceAgentStartupLogsAfter(_ context.Context, arg d
36743698
return nil, err
36753699
}
36763700

3677-
q.mutex.Lock()
3678-
defer q.mutex.Unlock()
3701+
q.mutex.RLock()
3702+
defer q.mutex.RUnlock()
36793703

36803704
logs := []database.WorkspaceAgentStartupLog{}
36813705
for _, log := range q.workspaceAgentLogs {
@@ -4051,13 +4075,13 @@ func (q *fakeQuerier) GetWorkspaceAgentStatsAndLabels(ctx context.Context, creat
40514075

40524076
stat.Username = user.Username
40534077

4054-
workspace, err := q.GetWorkspaceByID(ctx, agentStat.WorkspaceID)
4078+
workspace, err := q.getWorkspaceByIDNoLock(ctx, agentStat.WorkspaceID)
40554079
if err != nil {
40564080
return nil, err
40574081
}
40584082
stat.WorkspaceName = workspace.Name
40594083

4060-
agent, err := q.GetWorkspaceAgentByID(ctx, agentStat.AgentID)
4084+
agent, err := q.getWorkspaceAgentByIDNoLock(ctx, agentStat.AgentID)
40614085
if err != nil {
40624086
return nil, err
40634087
}
@@ -4403,7 +4427,7 @@ func (q *fakeQuerier) GetAuditLogsOffset(_ context.Context, arg database.GetAudi
44034427
}
44044428
}
44054429
if arg.BuildReason != "" {
4406-
workspaceBuild, err := q.GetWorkspaceBuildByID(context.Background(), alog.ResourceID)
4430+
workspaceBuild, err := q.getWorkspaceBuildByIDNoLock(context.Background(), alog.ResourceID)
44074431
if err == nil && !strings.EqualFold(arg.BuildReason, string(workspaceBuild.Reason)) {
44084432
continue
44094433
}
@@ -4497,8 +4521,8 @@ func (q *fakeQuerier) GetDERPMeshKey(_ context.Context) (string, error) {
44974521
}
44984522

44994523
func (q *fakeQuerier) UpsertLastUpdateCheck(_ context.Context, data string) error {
4500-
q.mutex.RLock()
4501-
defer q.mutex.RUnlock()
4524+
q.mutex.Lock()
4525+
defer q.mutex.Unlock()
45024526

45034527
q.lastUpdateCheck = []byte(data)
45044528
return nil
@@ -4672,8 +4696,8 @@ func (q *fakeQuerier) GetUserLinkByUserIDLoginType(_ context.Context, params dat
46724696
}
46734697

46744698
func (q *fakeQuerier) InsertUserLink(_ context.Context, args database.InsertUserLinkParams) (database.UserLink, error) {
4675-
q.mutex.RLock()
4676-
defer q.mutex.RUnlock()
4699+
q.mutex.Lock()
4700+
defer q.mutex.Unlock()
46774701

46784702
//nolint:gosimple
46794703
link := database.UserLink{
@@ -4695,8 +4719,8 @@ func (q *fakeQuerier) UpdateUserLinkedID(_ context.Context, params database.Upda
46954719
return database.UserLink{}, err
46964720
}
46974721

4698-
q.mutex.RLock()
4699-
defer q.mutex.RUnlock()
4722+
q.mutex.Lock()
4723+
defer q.mutex.Unlock()
47004724

47014725
for i, link := range q.userLinks {
47024726
if link.UserID == params.UserID && link.LoginType == params.LoginType {
@@ -4715,8 +4739,8 @@ func (q *fakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUs
47154739
return database.UserLink{}, err
47164740
}
47174741

4718-
q.mutex.RLock()
4719-
defer q.mutex.RUnlock()
4742+
q.mutex.Lock()
4743+
defer q.mutex.Unlock()
47204744

47214745
for i, link := range q.userLinks {
47224746
if link.UserID == params.UserID && link.LoginType == params.LoginType {
@@ -4732,10 +4756,14 @@ func (q *fakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUs
47324756
return database.UserLink{}, sql.ErrNoRows
47334757
}
47344758

4735-
func (q *fakeQuerier) GetGroupByID(_ context.Context, id uuid.UUID) (database.Group, error) {
4759+
func (q *fakeQuerier) GetGroupByID(ctx context.Context, id uuid.UUID) (database.Group, error) {
47364760
q.mutex.RLock()
47374761
defer q.mutex.RUnlock()
47384762

4763+
return q.getGroupByIDNoLock(ctx, id)
4764+
}
4765+
4766+
func (q *fakeQuerier) getGroupByIDNoLock(_ context.Context, id uuid.UUID) (database.Group, error) {
47394767
for _, group := range q.groups {
47404768
if group.ID == id {
47414769
return group, nil
@@ -4776,8 +4804,8 @@ func (q *fakeQuerier) InsertGroup(_ context.Context, arg database.InsertGroupPar
47764804
return database.Group{}, err
47774805
}
47784806

4779-
q.mutex.RLock()
4780-
defer q.mutex.RUnlock()
4807+
q.mutex.Lock()
4808+
defer q.mutex.Unlock()
47814809

47824810
for _, group := range q.groups {
47834811
if group.OrganizationID == arg.OrganizationID &&
@@ -4995,8 +5023,9 @@ func (q *fakeQuerier) UpdateGitAuthLink(_ context.Context, arg database.UpdateGi
49955023
}
49965024

49975025
func (q *fakeQuerier) GetQuotaAllowanceForUser(_ context.Context, userID uuid.UUID) (int64, error) {
4998-
q.mutex.Lock()
4999-
defer q.mutex.Unlock()
5026+
q.mutex.RLock()
5027+
defer q.mutex.RUnlock()
5028+
50005029
var sum int64
50015030
for _, member := range q.groupMembers {
50025031
if member.UserID != userID {
@@ -5012,8 +5041,9 @@ func (q *fakeQuerier) GetQuotaAllowanceForUser(_ context.Context, userID uuid.UU
50125041
}
50135042

50145043
func (q *fakeQuerier) GetQuotaConsumedForUser(_ context.Context, userID uuid.UUID) (int64, error) {
5015-
q.mutex.Lock()
5016-
defer q.mutex.Unlock()
5044+
q.mutex.RLock()
5045+
defer q.mutex.RUnlock()
5046+
50175047
var sum int64
50185048
for _, workspace := range q.workspaces {
50195049
if workspace.OwnerID != userID {
@@ -5072,8 +5102,8 @@ func (q *fakeQuerier) UpdateWorkspaceAgentStartupLogOverflowByID(_ context.Conte
50725102
}
50735103

50745104
func (q *fakeQuerier) GetWorkspaceProxies(_ context.Context) ([]database.WorkspaceProxy, error) {
5075-
q.mutex.Lock()
5076-
defer q.mutex.Unlock()
5105+
q.mutex.RLock()
5106+
defer q.mutex.RUnlock()
50775107

50785108
cpy := make([]database.WorkspaceProxy, 0, len(q.workspaceProxies))
50795109

@@ -5086,8 +5116,8 @@ func (q *fakeQuerier) GetWorkspaceProxies(_ context.Context) ([]database.Workspa
50865116
}
50875117

50885118
func (q *fakeQuerier) GetWorkspaceProxyByID(_ context.Context, id uuid.UUID) (database.WorkspaceProxy, error) {
5089-
q.mutex.Lock()
5090-
defer q.mutex.Unlock()
5119+
q.mutex.RLock()
5120+
defer q.mutex.RUnlock()
50915121

50925122
for _, proxy := range q.workspaceProxies {
50935123
if proxy.ID == id {
@@ -5098,8 +5128,8 @@ func (q *fakeQuerier) GetWorkspaceProxyByID(_ context.Context, id uuid.UUID) (da
50985128
}
50995129

51005130
func (q *fakeQuerier) GetWorkspaceProxyByHostname(_ context.Context, hostname string) (database.WorkspaceProxy, error) {
5101-
q.mutex.Lock()
5102-
defer q.mutex.Unlock()
5131+
q.mutex.RLock()
5132+
defer q.mutex.RUnlock()
51035133

51045134
// Return zero rows if this is called with a non-sanitized hostname. The SQL
51055135
// version of this query does the same thing.

coderd/rbac/authz_internal_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ func TestFilterError(t *testing.T) {
6868

6969
auth := &MockAuthorizer{
7070
AuthorizeFunc: func(ctx context.Context, subject Subject, action Action, object Object) error {
71-
// Authorize func always returns nil, unless the context is cancelled.
71+
// Authorize func always returns nil, unless the context is canceled.
7272
return ctx.Err()
7373
},
7474
}

0 commit comments

Comments
 (0)