Skip to content

Commit a63c97b

Browse files
authored
chore: wrap audit logs in a mutex to fix data race (#6898)
This was seen in `main`!
1 parent 5780006 commit a63c97b

11 files changed

+109
-101
lines changed

coderd/apikey_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ func TestTokenCRUD(t *testing.T) {
2525
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
2626
defer cancel()
2727
auditor := audit.NewMock()
28-
numLogs := len(auditor.AuditLogs)
28+
numLogs := len(auditor.AuditLogs())
2929
client := coderdtest.New(t, &coderdtest.Options{Auditor: auditor})
3030
_ = coderdtest.CreateFirstUser(t, client)
3131
numLogs++ // add an audit log for user creation
@@ -58,9 +58,9 @@ func TestTokenCRUD(t *testing.T) {
5858
require.Empty(t, keys)
5959

6060
// ensure audit log count is correct
61-
require.Len(t, auditor.AuditLogs, numLogs)
62-
require.Equal(t, database.AuditActionCreate, auditor.AuditLogs[numLogs-2].Action)
63-
require.Equal(t, database.AuditActionDelete, auditor.AuditLogs[numLogs-1].Action)
61+
require.Len(t, auditor.AuditLogs(), numLogs)
62+
require.Equal(t, database.AuditActionCreate, auditor.AuditLogs()[numLogs-2].Action)
63+
require.Equal(t, database.AuditActionDelete, auditor.AuditLogs()[numLogs-1].Action)
6464
}
6565

6666
func TestTokenScoped(t *testing.T) {

coderd/audit/audit.go

+10-2
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,21 @@ func NewMock() *MockAuditor {
3939

4040
type MockAuditor struct {
4141
mutex sync.Mutex
42-
AuditLogs []database.AuditLog
42+
auditLogs []database.AuditLog
43+
}
44+
45+
func (a *MockAuditor) AuditLogs() []database.AuditLog {
46+
a.mutex.Lock()
47+
defer a.mutex.Unlock()
48+
logs := make([]database.AuditLog, len(a.auditLogs))
49+
copy(logs, a.auditLogs)
50+
return logs
4351
}
4452

4553
func (a *MockAuditor) Export(_ context.Context, alog database.AuditLog) error {
4654
a.mutex.Lock()
4755
defer a.mutex.Unlock()
48-
a.AuditLogs = append(a.AuditLogs, alog)
56+
a.auditLogs = append(a.auditLogs, alog)
4957
return nil
5058
}
5159

coderd/gitsshkey_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ func TestGitSSHKey(t *testing.T) {
9494
require.NotEmpty(t, key2.PublicKey)
9595
require.NotEqual(t, key2.PublicKey, key1.PublicKey)
9696

97-
require.Len(t, auditor.AuditLogs, 2)
98-
assert.Equal(t, database.AuditActionWrite, auditor.AuditLogs[1].Action)
97+
require.Len(t, auditor.AuditLogs(), 2)
98+
assert.Equal(t, database.AuditActionWrite, auditor.AuditLogs()[1].Action)
9999
})
100100
}
101101

coderd/templates_test.go

+9-9
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,11 @@ func TestPostTemplateByOrganization(t *testing.T) {
6262
assert.Equal(t, expected.Name, got.Name)
6363
assert.Equal(t, expected.Description, got.Description)
6464

65-
require.Len(t, auditor.AuditLogs, 4)
66-
assert.Equal(t, database.AuditActionLogin, auditor.AuditLogs[0].Action)
67-
assert.Equal(t, database.AuditActionCreate, auditor.AuditLogs[1].Action)
68-
assert.Equal(t, database.AuditActionWrite, auditor.AuditLogs[2].Action)
69-
assert.Equal(t, database.AuditActionCreate, auditor.AuditLogs[3].Action)
65+
require.Len(t, auditor.AuditLogs(), 4)
66+
assert.Equal(t, database.AuditActionLogin, auditor.AuditLogs()[0].Action)
67+
assert.Equal(t, database.AuditActionCreate, auditor.AuditLogs()[1].Action)
68+
assert.Equal(t, database.AuditActionWrite, auditor.AuditLogs()[2].Action)
69+
assert.Equal(t, database.AuditActionCreate, auditor.AuditLogs()[3].Action)
7070
})
7171

7272
t.Run("AlreadyExists", func(t *testing.T) {
@@ -376,8 +376,8 @@ func TestPatchTemplateMeta(t *testing.T) {
376376
assert.Equal(t, req.DefaultTTLMillis, updated.DefaultTTLMillis)
377377
assert.False(t, req.AllowUserCancelWorkspaceJobs)
378378

379-
require.Len(t, auditor.AuditLogs, 5)
380-
assert.Equal(t, database.AuditActionWrite, auditor.AuditLogs[4].Action)
379+
require.Len(t, auditor.AuditLogs(), 5)
380+
assert.Equal(t, database.AuditActionWrite, auditor.AuditLogs()[4].Action)
381381
})
382382

383383
t.Run("NoDefaultTTL", func(t *testing.T) {
@@ -677,8 +677,8 @@ func TestDeleteTemplate(t *testing.T) {
677677
err := client.DeleteTemplate(ctx, template.ID)
678678
require.NoError(t, err)
679679

680-
require.Len(t, auditor.AuditLogs, 5)
681-
assert.Equal(t, database.AuditActionDelete, auditor.AuditLogs[4].Action)
680+
require.Len(t, auditor.AuditLogs(), 5)
681+
assert.Equal(t, database.AuditActionDelete, auditor.AuditLogs()[4].Action)
682682
})
683683

684684
t.Run("Workspaces", func(t *testing.T) {

coderd/templateversions_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,8 @@ func TestPostTemplateVersionsByOrganization(t *testing.T) {
135135
require.Equal(t, "bananas", version.Name)
136136
require.Equal(t, provisionerdserver.ScopeOrganization, version.Job.Tags[provisionerdserver.TagScope])
137137

138-
require.Len(t, auditor.AuditLogs, 2)
139-
assert.Equal(t, database.AuditActionCreate, auditor.AuditLogs[1].Action)
138+
require.Len(t, auditor.AuditLogs(), 2)
139+
assert.Equal(t, database.AuditActionCreate, auditor.AuditLogs()[1].Action)
140140
})
141141
t.Run("Example", func(t *testing.T) {
142142
t.Parallel()
@@ -715,8 +715,8 @@ func TestPatchActiveTemplateVersion(t *testing.T) {
715715
})
716716
require.NoError(t, err)
717717

718-
require.Len(t, auditor.AuditLogs, 5)
719-
assert.Equal(t, database.AuditActionWrite, auditor.AuditLogs[4].Action)
718+
require.Len(t, auditor.AuditLogs(), 5)
719+
assert.Equal(t, database.AuditActionWrite, auditor.AuditLogs()[4].Action)
720720
})
721721
}
722722

coderd/userauth_test.go

+25-25
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ func TestUserOAuth2Github(t *testing.T) {
243243
},
244244
},
245245
})
246-
numLogs := len(auditor.AuditLogs)
246+
numLogs := len(auditor.AuditLogs())
247247

248248
resp := oauth2Callback(t, client)
249249
numLogs++ // add an audit log for login
@@ -257,9 +257,9 @@ func TestUserOAuth2Github(t *testing.T) {
257257
require.Equal(t, "kyle", user.Username)
258258
require.Equal(t, "/hello-world", user.AvatarURL)
259259

260-
require.Len(t, auditor.AuditLogs, numLogs)
261-
require.NotEqual(t, auditor.AuditLogs[numLogs-1].UserID, uuid.Nil)
262-
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs[numLogs-1].Action)
260+
require.Len(t, auditor.AuditLogs(), numLogs)
261+
require.NotEqual(t, auditor.AuditLogs()[numLogs-1].UserID, uuid.Nil)
262+
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs()[numLogs-1].Action)
263263
})
264264
t.Run("SignupAllowedTeam", func(t *testing.T) {
265265
t.Parallel()
@@ -296,14 +296,14 @@ func TestUserOAuth2Github(t *testing.T) {
296296
},
297297
},
298298
})
299-
numLogs := len(auditor.AuditLogs)
299+
numLogs := len(auditor.AuditLogs())
300300

301301
resp := oauth2Callback(t, client)
302302
numLogs++ // add an audit log for login
303303

304304
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
305-
require.Len(t, auditor.AuditLogs, numLogs)
306-
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs[numLogs-1].Action)
305+
require.Len(t, auditor.AuditLogs(), numLogs)
306+
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs()[numLogs-1].Action)
307307
})
308308
t.Run("SignupAllowedTeamInFirstOrganization", func(t *testing.T) {
309309
t.Parallel()
@@ -348,14 +348,14 @@ func TestUserOAuth2Github(t *testing.T) {
348348
},
349349
},
350350
})
351-
numLogs := len(auditor.AuditLogs)
351+
numLogs := len(auditor.AuditLogs())
352352

353353
resp := oauth2Callback(t, client)
354354
numLogs++ // add an audit log for login
355355

356356
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
357-
require.Len(t, auditor.AuditLogs, numLogs)
358-
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs[numLogs-1].Action)
357+
require.Len(t, auditor.AuditLogs(), numLogs)
358+
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs()[numLogs-1].Action)
359359
})
360360
t.Run("SignupAllowedTeamInSecondOrganization", func(t *testing.T) {
361361
t.Parallel()
@@ -400,14 +400,14 @@ func TestUserOAuth2Github(t *testing.T) {
400400
},
401401
},
402402
})
403-
numLogs := len(auditor.AuditLogs)
403+
numLogs := len(auditor.AuditLogs())
404404

405405
resp := oauth2Callback(t, client)
406406
numLogs++ // add an audit log for login
407407

408408
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
409-
require.Len(t, auditor.AuditLogs, numLogs)
410-
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs[numLogs-1].Action)
409+
require.Len(t, auditor.AuditLogs(), numLogs)
410+
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs()[numLogs-1].Action)
411411
})
412412
t.Run("SignupAllowEveryone", func(t *testing.T) {
413413
t.Parallel()
@@ -438,14 +438,14 @@ func TestUserOAuth2Github(t *testing.T) {
438438
},
439439
},
440440
})
441-
numLogs := len(auditor.AuditLogs)
441+
numLogs := len(auditor.AuditLogs())
442442

443443
resp := oauth2Callback(t, client)
444444
numLogs++ // add an audit log for login
445445

446446
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
447-
require.Len(t, auditor.AuditLogs, numLogs)
448-
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs[numLogs-1].Action)
447+
require.Len(t, auditor.AuditLogs(), numLogs)
448+
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs()[numLogs-1].Action)
449449
})
450450
t.Run("SignupFailedInactiveInOrg", func(t *testing.T) {
451451
t.Parallel()
@@ -659,7 +659,7 @@ func TestUserOIDC(t *testing.T) {
659659
Auditor: auditor,
660660
OIDCConfig: config,
661661
})
662-
numLogs := len(auditor.AuditLogs)
662+
numLogs := len(auditor.AuditLogs())
663663

664664
resp := oidcCallback(t, client, conf.EncodeClaims(t, tc.IDTokenClaims))
665665
numLogs++ // add an audit log for login
@@ -673,9 +673,9 @@ func TestUserOIDC(t *testing.T) {
673673
require.NoError(t, err)
674674
require.Equal(t, tc.Username, user.Username)
675675

676-
require.Len(t, auditor.AuditLogs, numLogs)
677-
require.NotEqual(t, auditor.AuditLogs[numLogs-1].UserID, uuid.Nil)
678-
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs[numLogs-1].Action)
676+
require.Len(t, auditor.AuditLogs(), numLogs)
677+
require.NotEqual(t, auditor.AuditLogs()[numLogs-1].UserID, uuid.Nil)
678+
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs()[numLogs-1].Action)
679679
}
680680

681681
if tc.AvatarURL != "" {
@@ -684,8 +684,8 @@ func TestUserOIDC(t *testing.T) {
684684
require.NoError(t, err)
685685
require.Equal(t, tc.AvatarURL, user.AvatarURL)
686686

687-
require.Len(t, auditor.AuditLogs, numLogs)
688-
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs[numLogs-1].Action)
687+
require.Len(t, auditor.AuditLogs(), numLogs)
688+
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs()[numLogs-1].Action)
689689
}
690690
})
691691
}
@@ -702,7 +702,7 @@ func TestUserOIDC(t *testing.T) {
702702
Auditor: auditor,
703703
OIDCConfig: config,
704704
})
705-
numLogs := len(auditor.AuditLogs)
705+
numLogs := len(auditor.AuditLogs())
706706

707707
code := conf.EncodeClaims(t, jwt.MapClaims{
708708
"email": "jon@coder.com",
@@ -735,8 +735,8 @@ func TestUserOIDC(t *testing.T) {
735735
require.NoError(t, err)
736736
require.True(t, strings.HasPrefix(user.Username, "jon-"), "username %q should have prefix %q", user.Username, "jon-")
737737

738-
require.Len(t, auditor.AuditLogs, numLogs)
739-
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs[numLogs-1].Action)
738+
require.Len(t, auditor.AuditLogs(), numLogs)
739+
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs()[numLogs-1].Action)
740740
})
741741

742742
t.Run("Disabled", func(t *testing.T) {

0 commit comments

Comments
 (0)