diff --git a/coderd/audit/audit.go b/coderd/audit/audit.go index 4d256541d05f6..83675ab477309 100644 --- a/coderd/audit/audit.go +++ b/coderd/audit/audit.go @@ -3,6 +3,10 @@ package audit import ( "context" "sync" + "testing" + + "github.com/google/uuid" + "golang.org/x/exp/slices" "github.com/coder/coder/v2/coderd/database" ) @@ -68,3 +72,76 @@ func (a *MockAuditor) Export(_ context.Context, alog database.AuditLog) error { func (*MockAuditor) diff(any, any) Map { return Map{} } + +// Contains returns true if, for each non-zero-valued field in expected, +// there exists a corresponding audit log in the mock auditor that matches +// the expected values. Returns false otherwise. +func (a *MockAuditor) Contains(t testing.TB, expected database.AuditLog) bool { + a.mutex.Lock() + defer a.mutex.Unlock() + for idx, al := range a.auditLogs { + if expected.ID != uuid.Nil && al.ID != expected.ID { + t.Logf("audit log %d: expected ID %s, got %s", idx+1, expected.ID, al.ID) + continue + } + if !expected.Time.IsZero() && expected.Time != al.Time { + t.Logf("audit log %d: expected Time %s, got %s", idx+1, expected.Time, al.Time) + continue + } + if expected.UserID != uuid.Nil && al.UserID != expected.UserID { + t.Logf("audit log %d: expected UserID %s, got %s", idx+1, expected.UserID, al.UserID) + continue + } + if expected.OrganizationID != uuid.Nil && al.UserID != expected.UserID { + t.Logf("audit log %d: expected OrganizationID %s, got %s", idx+1, expected.OrganizationID, al.OrganizationID) + continue + } + if expected.Ip.Valid && al.Ip.IPNet.String() != expected.Ip.IPNet.String() { + t.Logf("audit log %d: expected Ip %s, got %s", idx+1, expected.Ip.IPNet, al.Ip.IPNet) + continue + } + if expected.UserAgent.Valid && al.UserAgent.String != expected.UserAgent.String { + t.Logf("audit log %d: expected UserAgent %s, got %s", idx+1, expected.UserAgent.String, al.UserAgent.String) + continue + } + if expected.ResourceType != "" && expected.ResourceType != al.ResourceType { + t.Logf("audit log %d: expected ResourceType %s, got %s", idx+1, expected.ResourceType, al.ResourceType) + continue + } + if expected.ResourceID != uuid.Nil && expected.ResourceID != al.ResourceID { + t.Logf("audit log %d: expected ResourceID %s, got %s", idx+1, expected.ResourceID, al.ResourceID) + continue + } + if expected.ResourceTarget != "" && expected.ResourceTarget != al.ResourceTarget { + t.Logf("audit log %d: expected ResourceTarget %s, got %s", idx+1, expected.ResourceTarget, al.ResourceTarget) + continue + } + if expected.Action != "" && expected.Action != al.Action { + t.Logf("audit log %d: expected Action %s, got %s", idx+1, expected.Action, al.Action) + continue + } + if len(expected.Diff) > 0 && slices.Compare(expected.Diff, al.Diff) != 0 { + t.Logf("audit log %d: expected Diff %s, got %s", idx+1, string(expected.Diff), string(al.Diff)) + continue + } + if expected.StatusCode != 0 && expected.StatusCode != al.StatusCode { + t.Logf("audit log %d: expected StatusCode %d, got %d", idx+1, expected.StatusCode, al.StatusCode) + continue + } + if len(expected.AdditionalFields) > 0 && slices.Compare(expected.AdditionalFields, al.AdditionalFields) != 0 { + t.Logf("audit log %d: expected AdditionalFields %s, got %s", idx+1, string(expected.AdditionalFields), string(al.AdditionalFields)) + continue + } + if expected.RequestID != uuid.Nil && expected.RequestID != al.RequestID { + t.Logf("audit log %d: expected RequestID %s, got %s", idx+1, expected.RequestID, al.RequestID) + continue + } + if expected.ResourceIcon != "" && expected.ResourceIcon != al.ResourceIcon { + t.Logf("audit log %d: expected ResourceIcon %s, got %s", idx+1, expected.ResourceIcon, al.ResourceIcon) + continue + } + return true + } + + return false +} diff --git a/coderd/workspaces_test.go b/coderd/workspaces_test.go index a12c262cb14f5..d3f5bfa00e276 100644 --- a/coderd/workspaces_test.go +++ b/coderd/workspaces_test.go @@ -511,7 +511,11 @@ func TestPostWorkspacesByOrganization(t *testing.T) { coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - verifyAuditWorkspaceCreated(t, auditor, workspace.Name) + assert.True(t, auditor.Contains(t, database.AuditLog{ + ResourceType: database.ResourceTypeWorkspace, + Action: database.AuditActionCreate, + ResourceTarget: workspace.Name, + })) }) t.Run("CreateFromVersionWithAuditLogs", func(t *testing.T) { @@ -535,7 +539,11 @@ func TestPostWorkspacesByOrganization(t *testing.T) { require.Equal(t, testWorkspaceBuild.TemplateVersionID, versionTest.ID) require.Equal(t, defaultWorkspaceBuild.TemplateVersionID, versionDefault.ID) - verifyAuditWorkspaceCreated(t, auditor, defaultWorkspace.Name) + assert.True(t, auditor.Contains(t, database.AuditLog{ + ResourceType: database.ResourceTypeWorkspace, + Action: database.AuditActionCreate, + ResourceTarget: defaultWorkspace.Name, + })) }) t.Run("InvalidCombinationOfTemplateAndTemplateVersion", func(t *testing.T) { @@ -2741,7 +2749,11 @@ func TestWorkspaceDormant(t *testing.T) { Dormant: true, }) require.NoError(t, err) - require.Len(t, auditRecorder.AuditLogs(), 1) + require.True(t, auditRecorder.Contains(t, database.AuditLog{ + Action: database.AuditActionWrite, + ResourceType: database.ResourceTypeWorkspace, + ResourceTarget: workspace.Name, + })) workspace = coderdtest.MustWorkspace(t, client, workspace.ID) require.NoError(t, err, "fetch provisioned workspace") @@ -2804,25 +2816,3 @@ func TestWorkspaceDormant(t *testing.T) { coderdtest.MustTransitionWorkspace(t, client, workspace.ID, database.WorkspaceTransitionStop, database.WorkspaceTransitionStart) }) } - -func verifyAuditWorkspaceCreated(t *testing.T, auditor *audit.MockAuditor, workspaceName string) { - var auditLogs []database.AuditLog - ok := assert.Eventually(t, func() bool { - auditLogs = auditor.AuditLogs() - - for _, auditLog := range auditLogs { - if auditLog.Action == database.AuditActionCreate && - auditLog.ResourceType == database.ResourceTypeWorkspace && - auditLog.ResourceTarget == workspaceName { - return true - } - } - return false - }, testutil.WaitMedium, testutil.IntervalFast) - - if !ok { - for i, auditLog := range auditLogs { - t.Logf("%d. Audit: ID=%s action=%s resourceID=%s resourceType=%s resourceTarget=%s", i+1, auditLog.ID, auditLog.Action, auditLog.ResourceID, auditLog.ResourceType, auditLog.ResourceTarget) - } - } -}