Skip to content

Commit 26c3c12

Browse files
authored
chore(coderd): add MockAuditor.Contains test helper (#10421)
* Adds a Contains() method on MockAuditor to help with asserting the presence of an audit log with specific fields. * Updates existing usages of verifyAuditWorkspaceCreated to use the new helper * Updates test referenced in PR#10396.
1 parent e36b606 commit 26c3c12

File tree

2 files changed

+92
-25
lines changed

2 files changed

+92
-25
lines changed

coderd/audit/audit.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@ package audit
33
import (
44
"context"
55
"sync"
6+
"testing"
7+
8+
"github.com/google/uuid"
9+
"golang.org/x/exp/slices"
610

711
"github.com/coder/coder/v2/coderd/database"
812
)
@@ -68,3 +72,76 @@ func (a *MockAuditor) Export(_ context.Context, alog database.AuditLog) error {
6872
func (*MockAuditor) diff(any, any) Map {
6973
return Map{}
7074
}
75+
76+
// Contains returns true if, for each non-zero-valued field in expected,
77+
// there exists a corresponding audit log in the mock auditor that matches
78+
// the expected values. Returns false otherwise.
79+
func (a *MockAuditor) Contains(t testing.TB, expected database.AuditLog) bool {
80+
a.mutex.Lock()
81+
defer a.mutex.Unlock()
82+
for idx, al := range a.auditLogs {
83+
if expected.ID != uuid.Nil && al.ID != expected.ID {
84+
t.Logf("audit log %d: expected ID %s, got %s", idx+1, expected.ID, al.ID)
85+
continue
86+
}
87+
if !expected.Time.IsZero() && expected.Time != al.Time {
88+
t.Logf("audit log %d: expected Time %s, got %s", idx+1, expected.Time, al.Time)
89+
continue
90+
}
91+
if expected.UserID != uuid.Nil && al.UserID != expected.UserID {
92+
t.Logf("audit log %d: expected UserID %s, got %s", idx+1, expected.UserID, al.UserID)
93+
continue
94+
}
95+
if expected.OrganizationID != uuid.Nil && al.UserID != expected.UserID {
96+
t.Logf("audit log %d: expected OrganizationID %s, got %s", idx+1, expected.OrganizationID, al.OrganizationID)
97+
continue
98+
}
99+
if expected.Ip.Valid && al.Ip.IPNet.String() != expected.Ip.IPNet.String() {
100+
t.Logf("audit log %d: expected Ip %s, got %s", idx+1, expected.Ip.IPNet, al.Ip.IPNet)
101+
continue
102+
}
103+
if expected.UserAgent.Valid && al.UserAgent.String != expected.UserAgent.String {
104+
t.Logf("audit log %d: expected UserAgent %s, got %s", idx+1, expected.UserAgent.String, al.UserAgent.String)
105+
continue
106+
}
107+
if expected.ResourceType != "" && expected.ResourceType != al.ResourceType {
108+
t.Logf("audit log %d: expected ResourceType %s, got %s", idx+1, expected.ResourceType, al.ResourceType)
109+
continue
110+
}
111+
if expected.ResourceID != uuid.Nil && expected.ResourceID != al.ResourceID {
112+
t.Logf("audit log %d: expected ResourceID %s, got %s", idx+1, expected.ResourceID, al.ResourceID)
113+
continue
114+
}
115+
if expected.ResourceTarget != "" && expected.ResourceTarget != al.ResourceTarget {
116+
t.Logf("audit log %d: expected ResourceTarget %s, got %s", idx+1, expected.ResourceTarget, al.ResourceTarget)
117+
continue
118+
}
119+
if expected.Action != "" && expected.Action != al.Action {
120+
t.Logf("audit log %d: expected Action %s, got %s", idx+1, expected.Action, al.Action)
121+
continue
122+
}
123+
if len(expected.Diff) > 0 && slices.Compare(expected.Diff, al.Diff) != 0 {
124+
t.Logf("audit log %d: expected Diff %s, got %s", idx+1, string(expected.Diff), string(al.Diff))
125+
continue
126+
}
127+
if expected.StatusCode != 0 && expected.StatusCode != al.StatusCode {
128+
t.Logf("audit log %d: expected StatusCode %d, got %d", idx+1, expected.StatusCode, al.StatusCode)
129+
continue
130+
}
131+
if len(expected.AdditionalFields) > 0 && slices.Compare(expected.AdditionalFields, al.AdditionalFields) != 0 {
132+
t.Logf("audit log %d: expected AdditionalFields %s, got %s", idx+1, string(expected.AdditionalFields), string(al.AdditionalFields))
133+
continue
134+
}
135+
if expected.RequestID != uuid.Nil && expected.RequestID != al.RequestID {
136+
t.Logf("audit log %d: expected RequestID %s, got %s", idx+1, expected.RequestID, al.RequestID)
137+
continue
138+
}
139+
if expected.ResourceIcon != "" && expected.ResourceIcon != al.ResourceIcon {
140+
t.Logf("audit log %d: expected ResourceIcon %s, got %s", idx+1, expected.ResourceIcon, al.ResourceIcon)
141+
continue
142+
}
143+
return true
144+
}
145+
146+
return false
147+
}

coderd/workspaces_test.go

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,11 @@ func TestPostWorkspacesByOrganization(t *testing.T) {
511511
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
512512
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
513513
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
514-
verifyAuditWorkspaceCreated(t, auditor, workspace.Name)
514+
assert.True(t, auditor.Contains(t, database.AuditLog{
515+
ResourceType: database.ResourceTypeWorkspace,
516+
Action: database.AuditActionCreate,
517+
ResourceTarget: workspace.Name,
518+
}))
515519
})
516520

517521
t.Run("CreateFromVersionWithAuditLogs", func(t *testing.T) {
@@ -535,7 +539,11 @@ func TestPostWorkspacesByOrganization(t *testing.T) {
535539

536540
require.Equal(t, testWorkspaceBuild.TemplateVersionID, versionTest.ID)
537541
require.Equal(t, defaultWorkspaceBuild.TemplateVersionID, versionDefault.ID)
538-
verifyAuditWorkspaceCreated(t, auditor, defaultWorkspace.Name)
542+
assert.True(t, auditor.Contains(t, database.AuditLog{
543+
ResourceType: database.ResourceTypeWorkspace,
544+
Action: database.AuditActionCreate,
545+
ResourceTarget: defaultWorkspace.Name,
546+
}))
539547
})
540548

541549
t.Run("InvalidCombinationOfTemplateAndTemplateVersion", func(t *testing.T) {
@@ -2741,7 +2749,11 @@ func TestWorkspaceDormant(t *testing.T) {
27412749
Dormant: true,
27422750
})
27432751
require.NoError(t, err)
2744-
require.Len(t, auditRecorder.AuditLogs(), 1)
2752+
require.True(t, auditRecorder.Contains(t, database.AuditLog{
2753+
Action: database.AuditActionWrite,
2754+
ResourceType: database.ResourceTypeWorkspace,
2755+
ResourceTarget: workspace.Name,
2756+
}))
27452757

27462758
workspace = coderdtest.MustWorkspace(t, client, workspace.ID)
27472759
require.NoError(t, err, "fetch provisioned workspace")
@@ -2804,25 +2816,3 @@ func TestWorkspaceDormant(t *testing.T) {
28042816
coderdtest.MustTransitionWorkspace(t, client, workspace.ID, database.WorkspaceTransitionStop, database.WorkspaceTransitionStart)
28052817
})
28062818
}
2807-
2808-
func verifyAuditWorkspaceCreated(t *testing.T, auditor *audit.MockAuditor, workspaceName string) {
2809-
var auditLogs []database.AuditLog
2810-
ok := assert.Eventually(t, func() bool {
2811-
auditLogs = auditor.AuditLogs()
2812-
2813-
for _, auditLog := range auditLogs {
2814-
if auditLog.Action == database.AuditActionCreate &&
2815-
auditLog.ResourceType == database.ResourceTypeWorkspace &&
2816-
auditLog.ResourceTarget == workspaceName {
2817-
return true
2818-
}
2819-
}
2820-
return false
2821-
}, testutil.WaitMedium, testutil.IntervalFast)
2822-
2823-
if !ok {
2824-
for i, auditLog := range auditLogs {
2825-
t.Logf("%d. Audit: ID=%s action=%s resourceID=%s resourceType=%s resourceTarget=%s", i+1, auditLog.ID, auditLog.Action, auditLog.ResourceID, auditLog.ResourceType, auditLog.ResourceTarget)
2826-
}
2827-
}
2828-
}

0 commit comments

Comments
 (0)