Skip to content

Commit aeb4603

Browse files
committed
test: Increase test coverage on auditable resources
When adding a new audit resource, we also need to add it to the function switch statements. This is a likely mistake, now a unit test will check this for you
1 parent e114999 commit aeb4603

File tree

3 files changed

+101
-3
lines changed

3 files changed

+101
-3
lines changed

coderd/audit/request.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ func ResourceTarget[T Auditable](tgt T) string {
7878
return ""
7979
case database.License:
8080
return strconv.Itoa(int(typed.ID))
81+
case database.WorkspaceProxy:
82+
return typed.Name
8183
default:
8284
panic(fmt.Sprintf("unknown resource %T", tgt))
8385
}
@@ -103,13 +105,15 @@ func ResourceID[T Auditable](tgt T) uuid.UUID {
103105
return typed.UserID
104106
case database.License:
105107
return typed.UUID
108+
case database.WorkspaceProxy:
109+
return typed.ID
106110
default:
107111
panic(fmt.Sprintf("unknown resource %T", tgt))
108112
}
109113
}
110114

111115
func ResourceType[T Auditable](tgt T) database.ResourceType {
112-
switch any(tgt).(type) {
116+
switch typed := any(tgt).(type) {
113117
case database.Template:
114118
return database.ResourceTypeTemplate
115119
case database.TemplateVersion:
@@ -128,8 +132,10 @@ func ResourceType[T Auditable](tgt T) database.ResourceType {
128132
return database.ResourceTypeApiKey
129133
case database.License:
130134
return database.ResourceTypeLicense
135+
case database.WorkspaceProxy:
136+
return database.ResourceTypeWorkspaceProxy
131137
default:
132-
panic(fmt.Sprintf("unknown resource %T", tgt))
138+
panic(fmt.Sprintf("unknown resource %T", typed))
133139
}
134140
}
135141

coderd/database/models.go

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

enterprise/audit/table_internal_test.go

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
11
package audit
22

33
import (
4+
"fmt"
45
"go/types"
6+
"strings"
57
"testing"
68

9+
"github.com/coder/coder/coderd/util/slice"
10+
11+
"github.com/coder/coder/coderd/audit"
12+
"github.com/coder/coder/coderd/database"
13+
714
"github.com/stretchr/testify/assert"
815
"github.com/stretchr/testify/require"
916
"golang.org/x/tools/go/packages"
@@ -15,7 +22,7 @@ func TestAuditableResources(t *testing.T) {
1522
t.Parallel()
1623

1724
pkgs, err := packages.Load(&packages.Config{
18-
Mode: packages.NeedTypes,
25+
Mode: packages.NeedTypes | packages.NeedDeps,
1926
}, "../../coderd/audit")
2027
require.NoError(t, err)
2128

@@ -37,13 +44,15 @@ func TestAuditableResources(t *testing.T) {
3744
require.True(t, ok, "expected Auditable to be a union")
3845

3946
found := make(map[string]bool)
47+
expectedList := make([]string, 0)
4048
// Now we check we have all the resources in the AuditableResources
4149
for i := 0; i < unionType.Len(); i++ {
4250
// All types come across like 'github.com/coder/coder/coderd/database.<type>'
4351
typeName := unionType.Term(i).Type().String()
4452
_, ok := AuditableResources[typeName]
4553
assert.True(t, ok, "missing resource %q from AuditableResources", typeName)
4654
found[typeName] = true
55+
expectedList = append(expectedList, typeName)
4756
}
4857

4958
// Also check that all resources in the table are in the union. We could
@@ -52,4 +61,86 @@ func TestAuditableResources(t *testing.T) {
5261
_, ok := found[name]
5362
assert.True(t, ok, "extra resource %q found in AuditableResources", name)
5463
}
64+
65+
// Various functions that have switch statements to include all Auditable
66+
// resources. Make sure we have all types supported.
67+
// nolint:paralleltest
68+
t.Run("ResourceID", func(t *testing.T) {
69+
// The function being tested, provided here to make it easier to find
70+
var _ = audit.ResourceID[database.APIKey]
71+
testAuditFunctionWithSwitch(t, auditPkg, "ResourceID", expectedList)
72+
})
73+
74+
// nolint:paralleltest
75+
t.Run("ResourceType", func(t *testing.T) {
76+
// The function being tested, provided here to make it easier to find
77+
var _ = audit.ResourceType[database.APIKey]
78+
testAuditFunctionWithSwitch(t, auditPkg, "ResourceType", expectedList)
79+
})
80+
81+
// nolint:paralleltest
82+
t.Run("ResourceTarget", func(t *testing.T) {
83+
// The function being tested, provided here to make it easier to find
84+
var _ = audit.ResourceTarget[database.APIKey]
85+
testAuditFunctionWithSwitch(t, auditPkg, "ResourceTarget", expectedList)
86+
})
87+
}
88+
89+
func testAuditFunctionWithSwitch(t *testing.T, pkg *packages.Package, funcName string, expectedTypes []string) {
90+
t.Helper()
91+
92+
f, ok := pkg.Types.Scope().Lookup(funcName).(*types.Func)
93+
require.True(t, ok, fmt.Sprintf("expected %s to be a function", funcName))
94+
switchCases := findSwitchTypes(f)
95+
for _, expected := range expectedTypes {
96+
if !slice.Contains(switchCases, expected) {
97+
t.Errorf("%s switch statement is missing type %q. Include it in the switch case block", funcName, expected)
98+
}
99+
}
100+
for _, sc := range switchCases {
101+
if !slice.Contains(expectedTypes, sc) {
102+
t.Errorf("%s switch statement has unexpected type %q. Remove it from the switch case block", funcName, sc)
103+
}
104+
}
105+
}
106+
107+
// findSwitchTypes is a helper function to find all types a switch statement in
108+
// the function body of f has.
109+
func findSwitchTypes(f *types.Func) []string {
110+
caseTypes := make([]string, 0)
111+
switches := returnSwitchBlocks(f.Scope())
112+
for _, sc := range switches {
113+
scTypes := findCaseTypes(sc)
114+
caseTypes = append(caseTypes, scTypes...)
115+
}
116+
return caseTypes
117+
}
118+
119+
func returnSwitchBlocks(sc *types.Scope) []*types.Scope {
120+
switches := make([]*types.Scope, 0)
121+
for i := 0; i < sc.NumChildren(); i++ {
122+
child := sc.Child(i)
123+
cStr := child.String()
124+
// This is the easiest way to tell if it is a switch statement.
125+
if strings.Contains(cStr, "type switch scope") {
126+
switches = append(switches, child)
127+
}
128+
}
129+
return switches
130+
}
131+
132+
func findCaseTypes(sc *types.Scope) []string {
133+
caseTypes := make([]string, 0)
134+
for i := 0; i < sc.NumChildren(); i++ {
135+
child := sc.Child(i)
136+
for _, name := range child.Names() {
137+
obj := child.Lookup(name).Type()
138+
typeName := obj.String()
139+
// Ignore the "Default:" case
140+
if typeName != "any" {
141+
caseTypes = append(caseTypes, typeName)
142+
}
143+
}
144+
}
145+
return caseTypes
55146
}

0 commit comments

Comments
 (0)