1
1
package audit
2
2
3
3
import (
4
+ "fmt"
4
5
"go/types"
6
+ "strings"
5
7
"testing"
6
8
9
+ "github.com/coder/coder/coderd/util/slice"
10
+
11
+ "github.com/coder/coder/coderd/audit"
12
+ "github.com/coder/coder/coderd/database"
13
+
7
14
"github.com/stretchr/testify/assert"
8
15
"github.com/stretchr/testify/require"
9
16
"golang.org/x/tools/go/packages"
@@ -15,7 +22,7 @@ func TestAuditableResources(t *testing.T) {
15
22
t .Parallel ()
16
23
17
24
pkgs , err := packages .Load (& packages.Config {
18
- Mode : packages .NeedTypes ,
25
+ Mode : packages .NeedTypes | packages . NeedDeps ,
19
26
}, "../../coderd/audit" )
20
27
require .NoError (t , err )
21
28
@@ -37,13 +44,15 @@ func TestAuditableResources(t *testing.T) {
37
44
require .True (t , ok , "expected Auditable to be a union" )
38
45
39
46
found := make (map [string ]bool )
47
+ expectedList := make ([]string , 0 )
40
48
// Now we check we have all the resources in the AuditableResources
41
49
for i := 0 ; i < unionType .Len (); i ++ {
42
50
// All types come across like 'github.com/coder/coder/coderd/database.<type>'
43
51
typeName := unionType .Term (i ).Type ().String ()
44
52
_ , ok := AuditableResources [typeName ]
45
53
assert .True (t , ok , "missing resource %q from AuditableResources" , typeName )
46
54
found [typeName ] = true
55
+ expectedList = append (expectedList , typeName )
47
56
}
48
57
49
58
// Also check that all resources in the table are in the union. We could
@@ -52,4 +61,86 @@ func TestAuditableResources(t *testing.T) {
52
61
_ , ok := found [name ]
53
62
assert .True (t , ok , "extra resource %q found in AuditableResources" , name )
54
63
}
64
+
65
+ // Various functions that have switch statements to include all Auditable
66
+ // resources. Make sure we have all types supported.
67
+ // nolint:paralleltest
68
+ t .Run ("ResourceID" , func (t * testing.T ) {
69
+ // The function being tested, provided here to make it easier to find
70
+ var _ = audit .ResourceID [database .APIKey ]
71
+ testAuditFunctionWithSwitch (t , auditPkg , "ResourceID" , expectedList )
72
+ })
73
+
74
+ // nolint:paralleltest
75
+ t .Run ("ResourceType" , func (t * testing.T ) {
76
+ // The function being tested, provided here to make it easier to find
77
+ var _ = audit .ResourceType [database .APIKey ]
78
+ testAuditFunctionWithSwitch (t , auditPkg , "ResourceType" , expectedList )
79
+ })
80
+
81
+ // nolint:paralleltest
82
+ t .Run ("ResourceTarget" , func (t * testing.T ) {
83
+ // The function being tested, provided here to make it easier to find
84
+ var _ = audit .ResourceTarget [database .APIKey ]
85
+ testAuditFunctionWithSwitch (t , auditPkg , "ResourceTarget" , expectedList )
86
+ })
87
+ }
88
+
89
+ func testAuditFunctionWithSwitch (t * testing.T , pkg * packages.Package , funcName string , expectedTypes []string ) {
90
+ t .Helper ()
91
+
92
+ f , ok := pkg .Types .Scope ().Lookup (funcName ).(* types.Func )
93
+ require .True (t , ok , fmt .Sprintf ("expected %s to be a function" , funcName ))
94
+ switchCases := findSwitchTypes (f )
95
+ for _ , expected := range expectedTypes {
96
+ if ! slice .Contains (switchCases , expected ) {
97
+ t .Errorf ("%s switch statement is missing type %q. Include it in the switch case block" , funcName , expected )
98
+ }
99
+ }
100
+ for _ , sc := range switchCases {
101
+ if ! slice .Contains (expectedTypes , sc ) {
102
+ t .Errorf ("%s switch statement has unexpected type %q. Remove it from the switch case block" , funcName , sc )
103
+ }
104
+ }
105
+ }
106
+
107
+ // findSwitchTypes is a helper function to find all types a switch statement in
108
+ // the function body of f has.
109
+ func findSwitchTypes (f * types.Func ) []string {
110
+ caseTypes := make ([]string , 0 )
111
+ switches := returnSwitchBlocks (f .Scope ())
112
+ for _ , sc := range switches {
113
+ scTypes := findCaseTypes (sc )
114
+ caseTypes = append (caseTypes , scTypes ... )
115
+ }
116
+ return caseTypes
117
+ }
118
+
119
+ func returnSwitchBlocks (sc * types.Scope ) []* types.Scope {
120
+ switches := make ([]* types.Scope , 0 )
121
+ for i := 0 ; i < sc .NumChildren (); i ++ {
122
+ child := sc .Child (i )
123
+ cStr := child .String ()
124
+ // This is the easiest way to tell if it is a switch statement.
125
+ if strings .Contains (cStr , "type switch scope" ) {
126
+ switches = append (switches , child )
127
+ }
128
+ }
129
+ return switches
130
+ }
131
+
132
+ func findCaseTypes (sc * types.Scope ) []string {
133
+ caseTypes := make ([]string , 0 )
134
+ for i := 0 ; i < sc .NumChildren (); i ++ {
135
+ child := sc .Child (i )
136
+ for _ , name := range child .Names () {
137
+ obj := child .Lookup (name ).Type ()
138
+ typeName := obj .String ()
139
+ // Ignore the "Default:" case
140
+ if typeName != "any" {
141
+ caseTypes = append (caseTypes , typeName )
142
+ }
143
+ }
144
+ }
145
+ return caseTypes
55
146
}
0 commit comments