@@ -50,6 +50,18 @@ type MethodTestSuite struct {
50
50
suite.Suite
51
51
// methodAccounting counts all methods called by a 'RunMethodTest'
52
52
methodAccounting map [string ]int
53
+
54
+ // Individual state for each unit test.
55
+ // State used by developer
56
+ DB database.Store
57
+ // State set by setup
58
+ ctx context.Context
59
+ az * authzquery.AuthzQuerier
60
+ rec * coderdtest.RecordingAuthorizer
61
+ authz * coderdtest.FakeAuthorizer
62
+ actor rbac.Subject
63
+ // State set by developer
64
+ testCase MethodCase
53
65
}
54
66
55
67
// SetupSuite sets up the suite by creating a map of all methods on AuthzQuerier
@@ -86,8 +98,139 @@ func (s *MethodTestSuite) TearDownSuite() {
86
98
})
87
99
}
88
100
101
+ func (s * MethodTestSuite ) clear () {
102
+ s .DB = nil
103
+ s .ctx = nil
104
+ s .az = nil
105
+ s .rec = nil
106
+ s .actor = rbac.Subject {}
107
+ s .testCase = MethodCase {}
108
+ s .authz = nil
109
+ }
110
+
111
+ func (s * MethodTestSuite ) SetupTest () {
112
+ s .clear ()
113
+
114
+ s .DB = dbfake .New ()
115
+ s .authz = & coderdtest.FakeAuthorizer {
116
+ AlwaysReturn : nil ,
117
+ }
118
+ s .rec = & coderdtest.RecordingAuthorizer {
119
+ Wrapped : s .authz ,
120
+ }
121
+ s .az = authzquery .New (s .DB , s .rec , slog .Make ())
122
+ s .actor = rbac.Subject {
123
+ ID : uuid .NewString (),
124
+ Roles : rbac.RoleNames {rbac .RoleOwner ()},
125
+ Groups : []string {},
126
+ Scope : rbac .ScopeAll ,
127
+ }
128
+ s .ctx = authzquery .WithAuthorizeContext (context .Background (), s .actor )
129
+ }
130
+
131
+ func (s * MethodTestSuite ) TearDownTest () {
132
+ var (
133
+ t = s .T ()
134
+ az = s .az
135
+ testCase = s .testCase
136
+ fakeAuthorizer = s .authz
137
+ ctx = s .ctx
138
+ rec = s .rec
139
+ )
140
+
141
+ require .NotEqualf (t , "" , testCase .MethodName , "Method name must be set" )
142
+
143
+ methodName := testCase .MethodName
144
+ s .methodAccounting [methodName ]++
145
+
146
+ // Find the method with the name of the test.
147
+ found := false
148
+ azt := reflect .TypeOf (az )
149
+ MethodLoop:
150
+ for i := 0 ; i < azt .NumMethod (); i ++ {
151
+ method := azt .Method (i )
152
+ if method .Name == methodName {
153
+ if len (testCase .Assertions ) > 0 {
154
+ fakeAuthorizer .AlwaysReturn = xerrors .New ("Always fail authz" )
155
+ // If we have assertions, that means the method should FAIL
156
+ // if RBAC will disallow the request. The returned error should
157
+ // be expected to be a NotAuthorizedError.
158
+ erroredResp := reflect .ValueOf (az ).Method (i ).Call (append ([]reflect.Value {reflect .ValueOf (ctx )}, testCase .Inputs ... ))
159
+ _ , err := splitResp (t , erroredResp )
160
+ // This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
161
+ // any case where the error is nil and the response is an empty slice.
162
+ if err != nil || ! hasEmptySliceResponse (erroredResp ) {
163
+ require .Errorf (t , err , "method %q should an error with disallow authz" , methodName )
164
+ require .ErrorIsf (t , err , sql .ErrNoRows , "error should match sql.ErrNoRows" )
165
+ require .ErrorAs (t , err , & authzquery.NotAuthorizedError {}, "error should be NotAuthorizedError" )
166
+ }
167
+ // Set things back to normal.
168
+ fakeAuthorizer .AlwaysReturn = nil
169
+ rec .Reset ()
170
+ }
171
+
172
+ resp := reflect .ValueOf (az ).Method (i ).Call (append ([]reflect.Value {reflect .ValueOf (ctx )}, testCase .Inputs ... ))
173
+
174
+ outputs , err := splitResp (t , resp )
175
+ require .NoError (t , err , "method %q returned an error" , t .Name ())
176
+
177
+ // Some tests may not care about the outputs, so we only assert if
178
+ // they are provided.
179
+ if testCase .ExpectedOutputs != nil {
180
+ // Assert the required outputs
181
+ require .Equal (t , len (testCase .ExpectedOutputs ), len (outputs ), "method %q returned unexpected number of outputs" , methodName )
182
+ for i := range outputs {
183
+ a , b := testCase .ExpectedOutputs [i ].Interface (), outputs [i ].Interface ()
184
+ if reflect .TypeOf (a ).Kind () == reflect .Slice || reflect .TypeOf (a ).Kind () == reflect .Array {
185
+ // Order does not matter
186
+ require .ElementsMatch (t , a , b , "method %q returned unexpected output %d" , methodName , i )
187
+ } else {
188
+ require .Equal (t , a , b , "method %q returned unexpected output %d" , methodName , i )
189
+ }
190
+ }
191
+ }
192
+
193
+ found = true
194
+ break MethodLoop
195
+ }
196
+ }
197
+
198
+ require .True (t , found , "method %q does not exist" , methodName )
199
+
200
+ var pairs []coderdtest.ActionObjectPair
201
+ for _ , assrt := range testCase .Assertions {
202
+ for _ , action := range assrt .Actions {
203
+ pairs = append (pairs , coderdtest.ActionObjectPair {
204
+ Action : action ,
205
+ Object : assrt .Object ,
206
+ })
207
+ }
208
+ }
209
+
210
+ rec .AssertActor (t , s .actor , pairs ... )
211
+ require .NoError (t , rec .AllAsserted (), "all rbac calls must be asserted" )
212
+ s .clear ()
213
+ }
214
+
215
+ func (s * MethodTestSuite ) Asserts (v ... any ) * MethodTestSuite {
216
+ s .testCase .MethodName = methodName (s .T ())
217
+ s .testCase = s .testCase .Asserts (v ... )
218
+ return s
219
+ }
220
+
221
+ func (s * MethodTestSuite ) Args (v ... any ) * MethodTestSuite {
222
+ s .testCase = s .testCase .Args (v ... )
223
+ return s
224
+ }
225
+
226
+ func (s * MethodTestSuite ) Returns (v ... any ) * MethodTestSuite {
227
+ s .testCase = s .testCase .Returns (v ... )
228
+ return s
229
+ }
230
+
89
231
// RunMethodTest runs a method test case.
90
232
// The method to be tested is inferred from the name of the test case.
233
+ // Deprecated
91
234
func (s * MethodTestSuite ) RunMethodTest (testCaseF func (t * testing.T , db database.Store ) MethodCase ) {
92
235
t := s .T ()
93
236
testName := s .T ().Name ()
@@ -215,12 +358,29 @@ func splitResp(t *testing.T, values []reflect.Value) ([]reflect.Value, error) {
215
358
// A MethodCase contains the inputs to be provided to a single method call,
216
359
// and the assertions to be made on the RBAC checks.
217
360
type MethodCase struct {
361
+ // MethodName is the name of the method to be called on the AuthzQuerier.
362
+ MethodName string
218
363
Inputs []reflect.Value
219
364
Assertions []AssertRBAC
220
365
// Output is optional. Can assert non-error return values.
221
366
ExpectedOutputs []reflect.Value
222
367
}
223
368
369
+ func (m MethodCase ) Asserts (pairs ... any ) MethodCase {
370
+ m .Assertions = asserts (pairs ... )
371
+ return m
372
+ }
373
+
374
+ func (m MethodCase ) Args (args ... any ) MethodCase {
375
+ m .Inputs = values (args ... )
376
+ return m
377
+ }
378
+
379
+ func (m MethodCase ) Returns (rets ... any ) MethodCase {
380
+ m .ExpectedOutputs = values (rets ... )
381
+ return m
382
+ }
383
+
224
384
// AssertRBAC contains the object and actions to be asserted.
225
385
type AssertRBAC struct {
226
386
Object rbac.Object
@@ -319,6 +479,13 @@ func asserts(inputs ...any) []AssertRBAC {
319
479
return out
320
480
}
321
481
482
+ func methodName (t * testing.T ) string {
483
+ testName := t .Name ()
484
+ names := strings .Split (testName , "/" )
485
+ methodName := names [len (names )- 1 ]
486
+ return methodName
487
+ }
488
+
322
489
func (s * MethodTestSuite ) TestExtraMethods () {
323
490
s .Run ("GetProvisionerDaemons" , func () {
324
491
s .RunMethodTest (func (t * testing.T , db database.Store ) MethodCase {
0 commit comments