@@ -4,19 +4,18 @@ import (
4
4
"context"
5
5
"fmt"
6
6
"reflect"
7
+ "sort"
7
8
"strings"
8
9
"testing"
9
10
10
11
"github.com/google/uuid"
11
-
12
12
"github.com/stretchr/testify/require"
13
+ "github.com/stretchr/testify/suite"
13
14
14
15
"github.com/coder/coder/coderd/authzquery"
15
16
"github.com/coder/coder/coderd/coderdtest"
16
- "github.com/coder/coder/coderd/database/databasefake"
17
- "github.com/stretchr/testify/suite"
18
-
19
17
"github.com/coder/coder/coderd/database"
18
+ "github.com/coder/coder/coderd/database/databasefake"
20
19
"github.com/coder/coder/coderd/rbac"
21
20
)
22
21
27
26
}
28
27
)
29
28
30
- // MethodTestSuite runs all methods tests for AuthzQuerier. The reason we use
31
- // a test suite, is so we can account for all functions tested on the AuthzQuerier.
29
+ // TestMethodTestSuite runs MethodTestSuite.
30
+ // In order for 'go test' to run this suite, we need to create
31
+ // a normal test function and pass our suite to suite.Run
32
+ // nolint: paralleltest
33
+ func TestMethodTestSuite (t * testing.T ) {
34
+ suite .Run (t , new (MethodTestSuite ))
35
+ }
36
+
37
+ // MethodTestSuite runs all methods tests for AuthzQuerier. We use
38
+ // a test suite so we can account for all functions tested on the AuthzQuerier.
32
39
// We can then assert all methods were tested and asserted for proper RBAC
33
40
// checks. This forces RBAC checks to be written for all methods.
34
41
// Additionally, the way unit tests are written allows for easily executing
@@ -39,52 +46,46 @@ type MethodTestSuite struct {
39
46
methodAccounting map [string ]int
40
47
}
41
48
42
- func (suite * MethodTestSuite ) SetupSuite () {
49
+ // SetupSuite sets up the suite by creating a map of all methods on AuthzQuerier
50
+ // and setting their count to 0.
51
+ func (s * MethodTestSuite ) SetupSuite () {
43
52
az := & authzquery.AuthzQuerier {}
44
53
azt := reflect .TypeOf (az )
45
- suite .methodAccounting = make (map [string ]int )
54
+ s .methodAccounting = make (map [string ]int )
46
55
for i := 0 ; i < azt .NumMethod (); i ++ {
47
56
method := azt .Method (i )
48
57
if _ , ok := skipMethods [method .Name ]; ok {
49
58
continue
50
59
}
51
- suite .methodAccounting [method .Name ] = 0
60
+ s .methodAccounting [method .Name ] = 0
52
61
}
53
62
}
54
63
55
- func (suite * MethodTestSuite ) TearDownSuite () {
56
- suite .Run ("Accounting" , func () {
57
- t := suite .T ()
58
- for m , c := range suite .methodAccounting {
64
+ // TearDownSuite asserts that all methods were called at least once.
65
+ func (s * MethodTestSuite ) TearDownSuite () {
66
+ s .Run ("Accounting" , func () {
67
+ t := s .T ()
68
+ notCalled := []string {}
69
+ for m , c := range s .methodAccounting {
59
70
if c <= 0 {
60
- t . Errorf ( "Method %q never called" , m )
71
+ notCalled = append ( notCalled , m )
61
72
}
62
73
}
74
+ sort .Strings (notCalled )
75
+ for _ , m := range notCalled {
76
+ t .Errorf ("Method never called: %q" , m )
77
+ }
63
78
})
64
79
}
65
80
66
- // In order for 'go test' to run this suite, we need to create
67
- // a normal test function and pass our suite to suite.Run
68
- func TestMethodTestSuite (t * testing.T ) {
69
- suite .Run (t , new (MethodTestSuite ))
70
- }
71
-
72
- type MethodCase struct {
73
- Inputs []reflect.Value
74
- Assertions []AssertRBAC
75
- }
76
-
77
- type AssertRBAC struct {
78
- Object rbac.Object
79
- Actions []rbac.Action
80
- }
81
-
82
- func (suite * MethodTestSuite ) RunMethodTest (testCaseF func (t * testing.T , db database.Store ) MethodCase ) {
83
- t := suite .T ()
84
- testName := suite .T ().Name ()
81
+ // RunMethodTest runs a method test case.
82
+ // The method to be tested is inferred from the name of the test case.
83
+ func (s * MethodTestSuite ) RunMethodTest (testCaseF func (t * testing.T , db database.Store ) MethodCase ) {
84
+ t := s .T ()
85
+ testName := s .T ().Name ()
85
86
names := strings .Split (testName , "/" )
86
87
methodName := names [len (names )- 1 ]
87
- suite .methodAccounting [methodName ]++
88
+ s .methodAccounting [methodName ]++
88
89
89
90
db := databasefake .New ()
90
91
rec := & coderdtest.RecordingAuthorizer {
@@ -131,7 +132,48 @@ MethodLoop:
131
132
require .NoError (t , rec .AllAsserted (), "all rbac calls must be asserted" )
132
133
}
133
134
134
- func methodInputs (inputs ... any ) []reflect.Value {
135
+ // A MethodCase contains the inputs to be provided to a single method call,
136
+ // and the assertions to be made on the RBAC checks.
137
+ type MethodCase struct {
138
+ Inputs []reflect.Value
139
+ Assertions []AssertRBAC
140
+ }
141
+
142
+ // AssertRBAC contains the object and actions to be asserted.
143
+ type AssertRBAC struct {
144
+ Object rbac.Object
145
+ Actions []rbac.Action
146
+ }
147
+
148
+ // methodCase is a convenience method for creating MethodCases.
149
+ //
150
+ // methodCase(inputs(workspace, template, ...), asserts(workspace, rbac.ActionRead, template, rbac.ActionRead, ...))
151
+ //
152
+ // is equivalent to
153
+ //
154
+ // MethodCase{
155
+ // Inputs: inputs(workspace, template, ...),
156
+ // Assertions: asserts(workspace, rbac.ActionRead, template, rbac.ActionRead, ...),
157
+ // }
158
+ func methodCase (inputs []reflect.Value , assertions []AssertRBAC ) MethodCase {
159
+ return MethodCase {
160
+ Inputs : inputs ,
161
+ Assertions : assertions ,
162
+ }
163
+ }
164
+
165
+ // inputs is a convenience method for creating []reflect.Value.
166
+ //
167
+ // inputs(workspace, template, ...)
168
+ //
169
+ // is equivalent to
170
+ //
171
+ // []reflect.Value{
172
+ // reflect.ValueOf(workspace),
173
+ // reflect.ValueOf(template),
174
+ // ...
175
+ // }
176
+ func inputs (inputs ... any ) []reflect.Value {
135
177
out := make ([]reflect.Value , 0 )
136
178
for _ , input := range inputs {
137
179
input := input
@@ -140,6 +182,24 @@ func methodInputs(inputs ...any) []reflect.Value {
140
182
return out
141
183
}
142
184
185
+ // asserts is a convenience method for creating AssertRBACs.
186
+ //
187
+ // The number of inputs must be an even number.
188
+ // asserts() will panic if this is not the case.
189
+ //
190
+ // Even-numbered inputs are the objects, and odd-numbered inputs are the actions.
191
+ // Objects must implement rbac.Objecter.
192
+ // Inputs can be a single rbac.Action, or a slice of rbac.Action.
193
+ //
194
+ // asserts(workspace, rbac.ActionRead, template, slice(rbac.ActionRead, rbac.ActionWrite), ...)
195
+ //
196
+ // is equivalent to
197
+ //
198
+ // []AssertRBAC{
199
+ // {Object: workspace, Actions: []rbac.Action{rbac.ActionRead}},
200
+ // {Object: template, Actions: []rbac.Action{rbac.ActionRead, rbac.ActionWrite)}},
201
+ // ...
202
+ // }
143
203
func asserts (inputs ... any ) []AssertRBAC {
144
204
if len (inputs )% 2 != 0 {
145
205
panic (fmt .Sprintf ("Must be an even length number of args, found %d" , len (inputs )))
@@ -149,7 +209,7 @@ func asserts(inputs ...any) []AssertRBAC {
149
209
for i := 0 ; i < len (inputs ); i += 2 {
150
210
obj , ok := inputs [i ].(rbac.Objecter )
151
211
if ! ok {
152
- panic (fmt .Sprintf ("object type '%T' not a supported key " , obj ))
212
+ panic (fmt .Sprintf ("object type '%T' does not implement rbac.Objecter " , obj ))
153
213
}
154
214
rbacObj := obj .RBACObject ()
155
215
0 commit comments