Skip to content

chore: wrap audit logs in a mutex to fix data race #6898

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions coderd/apikey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func TestTokenCRUD(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
auditor := audit.NewMock()
numLogs := len(auditor.AuditLogs)
numLogs := len(auditor.AuditLogs())
client := coderdtest.New(t, &coderdtest.Options{Auditor: auditor})
_ = coderdtest.CreateFirstUser(t, client)
numLogs++ // add an audit log for user creation
Expand Down Expand Up @@ -58,9 +58,9 @@ func TestTokenCRUD(t *testing.T) {
require.Empty(t, keys)

// ensure audit log count is correct
require.Len(t, auditor.AuditLogs, numLogs)
require.Equal(t, database.AuditActionCreate, auditor.AuditLogs[numLogs-2].Action)
require.Equal(t, database.AuditActionDelete, auditor.AuditLogs[numLogs-1].Action)
require.Len(t, auditor.AuditLogs(), numLogs)
require.Equal(t, database.AuditActionCreate, auditor.AuditLogs()[numLogs-2].Action)
require.Equal(t, database.AuditActionDelete, auditor.AuditLogs()[numLogs-1].Action)
}

func TestTokenScoped(t *testing.T) {
Expand Down
12 changes: 10 additions & 2 deletions coderd/audit/audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,21 @@ func NewMock() *MockAuditor {

type MockAuditor struct {
mutex sync.Mutex
AuditLogs []database.AuditLog
auditLogs []database.AuditLog
}

func (a *MockAuditor) AuditLogs() []database.AuditLog {
a.mutex.Lock()
defer a.mutex.Unlock()
logs := make([]database.AuditLog, len(a.auditLogs))
copy(logs, a.auditLogs)
return logs
}

func (a *MockAuditor) Export(_ context.Context, alog database.AuditLog) error {
a.mutex.Lock()
defer a.mutex.Unlock()
a.AuditLogs = append(a.AuditLogs, alog)
a.auditLogs = append(a.auditLogs, alog)
return nil
}

Expand Down
4 changes: 2 additions & 2 deletions coderd/gitsshkey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ func TestGitSSHKey(t *testing.T) {
require.NotEmpty(t, key2.PublicKey)
require.NotEqual(t, key2.PublicKey, key1.PublicKey)

require.Len(t, auditor.AuditLogs, 2)
assert.Equal(t, database.AuditActionWrite, auditor.AuditLogs[1].Action)
require.Len(t, auditor.AuditLogs(), 2)
assert.Equal(t, database.AuditActionWrite, auditor.AuditLogs()[1].Action)
})
}

Expand Down
18 changes: 9 additions & 9 deletions coderd/templates_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@ func TestPostTemplateByOrganization(t *testing.T) {
assert.Equal(t, expected.Name, got.Name)
assert.Equal(t, expected.Description, got.Description)

require.Len(t, auditor.AuditLogs, 4)
assert.Equal(t, database.AuditActionLogin, auditor.AuditLogs[0].Action)
assert.Equal(t, database.AuditActionCreate, auditor.AuditLogs[1].Action)
assert.Equal(t, database.AuditActionWrite, auditor.AuditLogs[2].Action)
assert.Equal(t, database.AuditActionCreate, auditor.AuditLogs[3].Action)
require.Len(t, auditor.AuditLogs(), 4)
assert.Equal(t, database.AuditActionLogin, auditor.AuditLogs()[0].Action)
assert.Equal(t, database.AuditActionCreate, auditor.AuditLogs()[1].Action)
assert.Equal(t, database.AuditActionWrite, auditor.AuditLogs()[2].Action)
assert.Equal(t, database.AuditActionCreate, auditor.AuditLogs()[3].Action)
})

t.Run("AlreadyExists", func(t *testing.T) {
Expand Down Expand Up @@ -376,8 +376,8 @@ func TestPatchTemplateMeta(t *testing.T) {
assert.Equal(t, req.DefaultTTLMillis, updated.DefaultTTLMillis)
assert.False(t, req.AllowUserCancelWorkspaceJobs)

require.Len(t, auditor.AuditLogs, 5)
assert.Equal(t, database.AuditActionWrite, auditor.AuditLogs[4].Action)
require.Len(t, auditor.AuditLogs(), 5)
assert.Equal(t, database.AuditActionWrite, auditor.AuditLogs()[4].Action)
})

t.Run("NoDefaultTTL", func(t *testing.T) {
Expand Down Expand Up @@ -677,8 +677,8 @@ func TestDeleteTemplate(t *testing.T) {
err := client.DeleteTemplate(ctx, template.ID)
require.NoError(t, err)

require.Len(t, auditor.AuditLogs, 5)
assert.Equal(t, database.AuditActionDelete, auditor.AuditLogs[4].Action)
require.Len(t, auditor.AuditLogs(), 5)
assert.Equal(t, database.AuditActionDelete, auditor.AuditLogs()[4].Action)
})

t.Run("Workspaces", func(t *testing.T) {
Expand Down
8 changes: 4 additions & 4 deletions coderd/templateversions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ func TestPostTemplateVersionsByOrganization(t *testing.T) {
require.Equal(t, "bananas", version.Name)
require.Equal(t, provisionerdserver.ScopeOrganization, version.Job.Tags[provisionerdserver.TagScope])

require.Len(t, auditor.AuditLogs, 2)
assert.Equal(t, database.AuditActionCreate, auditor.AuditLogs[1].Action)
require.Len(t, auditor.AuditLogs(), 2)
assert.Equal(t, database.AuditActionCreate, auditor.AuditLogs()[1].Action)
})
t.Run("Example", func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -715,8 +715,8 @@ func TestPatchActiveTemplateVersion(t *testing.T) {
})
require.NoError(t, err)

require.Len(t, auditor.AuditLogs, 5)
assert.Equal(t, database.AuditActionWrite, auditor.AuditLogs[4].Action)
require.Len(t, auditor.AuditLogs(), 5)
assert.Equal(t, database.AuditActionWrite, auditor.AuditLogs()[4].Action)
})
}

Expand Down
50 changes: 25 additions & 25 deletions coderd/userauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ func TestUserOAuth2Github(t *testing.T) {
},
},
})
numLogs := len(auditor.AuditLogs)
numLogs := len(auditor.AuditLogs())

resp := oauth2Callback(t, client)
numLogs++ // add an audit log for login
Expand All @@ -257,9 +257,9 @@ func TestUserOAuth2Github(t *testing.T) {
require.Equal(t, "kyle", user.Username)
require.Equal(t, "/hello-world", user.AvatarURL)

require.Len(t, auditor.AuditLogs, numLogs)
require.NotEqual(t, auditor.AuditLogs[numLogs-1].UserID, uuid.Nil)
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs[numLogs-1].Action)
require.Len(t, auditor.AuditLogs(), numLogs)
require.NotEqual(t, auditor.AuditLogs()[numLogs-1].UserID, uuid.Nil)
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs()[numLogs-1].Action)
})
t.Run("SignupAllowedTeam", func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -296,14 +296,14 @@ func TestUserOAuth2Github(t *testing.T) {
},
},
})
numLogs := len(auditor.AuditLogs)
numLogs := len(auditor.AuditLogs())

resp := oauth2Callback(t, client)
numLogs++ // add an audit log for login

require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
require.Len(t, auditor.AuditLogs, numLogs)
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs[numLogs-1].Action)
require.Len(t, auditor.AuditLogs(), numLogs)
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs()[numLogs-1].Action)
})
t.Run("SignupAllowedTeamInFirstOrganization", func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -348,14 +348,14 @@ func TestUserOAuth2Github(t *testing.T) {
},
},
})
numLogs := len(auditor.AuditLogs)
numLogs := len(auditor.AuditLogs())

resp := oauth2Callback(t, client)
numLogs++ // add an audit log for login

require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
require.Len(t, auditor.AuditLogs, numLogs)
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs[numLogs-1].Action)
require.Len(t, auditor.AuditLogs(), numLogs)
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs()[numLogs-1].Action)
})
t.Run("SignupAllowedTeamInSecondOrganization", func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -400,14 +400,14 @@ func TestUserOAuth2Github(t *testing.T) {
},
},
})
numLogs := len(auditor.AuditLogs)
numLogs := len(auditor.AuditLogs())

resp := oauth2Callback(t, client)
numLogs++ // add an audit log for login

require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
require.Len(t, auditor.AuditLogs, numLogs)
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs[numLogs-1].Action)
require.Len(t, auditor.AuditLogs(), numLogs)
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs()[numLogs-1].Action)
})
t.Run("SignupAllowEveryone", func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -438,14 +438,14 @@ func TestUserOAuth2Github(t *testing.T) {
},
},
})
numLogs := len(auditor.AuditLogs)
numLogs := len(auditor.AuditLogs())

resp := oauth2Callback(t, client)
numLogs++ // add an audit log for login

require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
require.Len(t, auditor.AuditLogs, numLogs)
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs[numLogs-1].Action)
require.Len(t, auditor.AuditLogs(), numLogs)
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs()[numLogs-1].Action)
})
t.Run("SignupFailedInactiveInOrg", func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -659,7 +659,7 @@ func TestUserOIDC(t *testing.T) {
Auditor: auditor,
OIDCConfig: config,
})
numLogs := len(auditor.AuditLogs)
numLogs := len(auditor.AuditLogs())

resp := oidcCallback(t, client, conf.EncodeClaims(t, tc.IDTokenClaims))
numLogs++ // add an audit log for login
Expand All @@ -673,9 +673,9 @@ func TestUserOIDC(t *testing.T) {
require.NoError(t, err)
require.Equal(t, tc.Username, user.Username)

require.Len(t, auditor.AuditLogs, numLogs)
require.NotEqual(t, auditor.AuditLogs[numLogs-1].UserID, uuid.Nil)
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs[numLogs-1].Action)
require.Len(t, auditor.AuditLogs(), numLogs)
require.NotEqual(t, auditor.AuditLogs()[numLogs-1].UserID, uuid.Nil)
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs()[numLogs-1].Action)
}

if tc.AvatarURL != "" {
Expand All @@ -684,8 +684,8 @@ func TestUserOIDC(t *testing.T) {
require.NoError(t, err)
require.Equal(t, tc.AvatarURL, user.AvatarURL)

require.Len(t, auditor.AuditLogs, numLogs)
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs[numLogs-1].Action)
require.Len(t, auditor.AuditLogs(), numLogs)
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs()[numLogs-1].Action)
}
})
}
Expand All @@ -702,7 +702,7 @@ func TestUserOIDC(t *testing.T) {
Auditor: auditor,
OIDCConfig: config,
})
numLogs := len(auditor.AuditLogs)
numLogs := len(auditor.AuditLogs())

code := conf.EncodeClaims(t, jwt.MapClaims{
"email": "jon@coder.com",
Expand Down Expand Up @@ -735,8 +735,8 @@ func TestUserOIDC(t *testing.T) {
require.NoError(t, err)
require.True(t, strings.HasPrefix(user.Username, "jon-"), "username %q should have prefix %q", user.Username, "jon-")

require.Len(t, auditor.AuditLogs, numLogs)
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs[numLogs-1].Action)
require.Len(t, auditor.AuditLogs(), numLogs)
require.Equal(t, database.AuditActionLogin, auditor.AuditLogs()[numLogs-1].Action)
})

t.Run("Disabled", func(t *testing.T) {
Expand Down
Loading