@@ -133,59 +133,75 @@ func (s *MethodTestSuite) TearDownTest() {
133
133
t = s .T ()
134
134
az = s .az
135
135
testCase = s .testCase
136
- ctx = s .ctx
137
- rec = s .rec
138
136
)
139
137
140
- require .NotEqualf (t , "" , testCase .MethodName , "Method name must be set" )
138
+ // This ensures the test case has assertion data. If it is missing this,
139
+ // the test is incomplete
140
+ s .NotEqualf ("" , testCase .MethodName , "Method name must be set" )
141
141
142
142
methodName := testCase .MethodName
143
143
s .methodAccounting [methodName ]++
144
144
145
145
// Find the method with the name of the test.
146
- found := false
146
+ var callMethod func ( ctx context. Context ) ([]reflect. Value , error )
147
147
azt := reflect .TypeOf (az )
148
148
MethodLoop:
149
149
for i := 0 ; i < azt .NumMethod (); i ++ {
150
150
method := azt .Method (i )
151
- callMethod := func () ([]reflect.Value , error ) {
152
- resp := reflect .ValueOf (az ).Method (i ).Call (append ([]reflect.Value {reflect .ValueOf (ctx )}, testCase .Inputs ... ))
153
- return splitResp (t , resp )
154
- }
155
-
156
151
if method .Name == methodName {
157
- if len ( testCase . Assertions ) > 0 {
158
- // Run testing on expected errors
159
- s . TestNotAuthorized ( callMethod )
160
- s . TestNoActor ( callMethod )
152
+ methodF := reflect . ValueOf ( az ). Method ( i )
153
+ callMethod = func ( ctx context. Context ) ([]reflect. Value , error ) {
154
+ resp := methodF . Call ( append ([]reflect. Value { reflect . ValueOf ( ctx )}, testCase . Inputs ... ) )
155
+ return splitResp ( t , resp )
161
156
}
157
+ break MethodLoop
158
+ }
159
+ }
162
160
163
- outputs , err := callMethod ()
164
- s .NoError (err , "method %q returned an error" , t .Name ())
161
+ s .NotNil (callMethod , "method %q does not exist" , methodName )
165
162
166
- // Some tests may not care about the outputs, so we only assert if
167
- // they are provided.
168
- if testCase .ExpectedOutputs != nil {
169
- // Assert the required outputs
170
- s .Equal (len (testCase .ExpectedOutputs ), len (outputs ), "method %q returned unexpected number of outputs" , methodName )
171
- for i := range outputs {
172
- a , b := testCase .ExpectedOutputs [i ].Interface (), outputs [i ].Interface ()
173
- if reflect .TypeOf (a ).Kind () == reflect .Slice || reflect .TypeOf (a ).Kind () == reflect .Array {
174
- // Order does not matter
175
- s .ElementsMatch (a , b , "method %q returned unexpected output %d" , methodName , i )
176
- } else {
177
- s .Equal (a , b , "method %q returned unexpected output %d" , methodName , i )
178
- }
179
- }
180
- }
163
+ // Run tests that are only run if the method makes rbac assertions.
164
+ // These tests assert the error conditions of the method.
165
+ if len (testCase .Assertions ) > 0 {
166
+ // Only run these tests if we know the underlying call makes
167
+ // rbac assertions.
168
+ s .TestNotAuthorized (callMethod )
169
+ s .TestNoActor (callMethod )
170
+ }
181
171
182
- found = true
183
- break MethodLoop
172
+ // Always run
173
+ s .TestMethodCall (callMethod )
174
+ }
175
+
176
+ // TestMethodCall runs the given method and asserts:
177
+ // - The method does not return an error
178
+ // - The method makes the expected number of rbac calls
179
+ // - The method returns the expected outputs
180
+ func (s * MethodTestSuite ) TestMethodCall (callMethod func (ctx context.Context ) ([]reflect.Value , error )) {
181
+ // Reset any recordings and set the authz to always succeed in authorizing.
182
+ s .rec .Reset ()
183
+ s .authz .AlwaysReturn = nil
184
+ testCase := s .testCase
185
+
186
+ outputs , err := callMethod (s .ctx )
187
+ s .NoError (err , "method %q returned an error" , testCase .MethodName )
188
+
189
+ // Some tests may not care about the outputs, so we only assert if
190
+ // they are provided.
191
+ if testCase .ExpectedOutputs != nil {
192
+ // Assert the required outputs
193
+ s .Equal (len (testCase .ExpectedOutputs ), len (outputs ), "method %q returned unexpected number of outputs" , testCase .MethodName )
194
+ for i := range outputs {
195
+ a , b := testCase .ExpectedOutputs [i ].Interface (), outputs [i ].Interface ()
196
+ if reflect .TypeOf (a ).Kind () == reflect .Slice || reflect .TypeOf (a ).Kind () == reflect .Array {
197
+ // Order does not matter
198
+ s .ElementsMatch (a , b , "method %q returned unexpected output %d" , testCase .MethodName , i )
199
+ } else {
200
+ s .Equal (a , b , "method %q returned unexpected output %d" , testCase .MethodName , i )
201
+ }
184
202
}
185
203
}
186
204
187
- s .True (found , "method %q does not exist" , methodName )
188
-
189
205
var pairs []coderdtest.ActionObjectPair
190
206
for _ , assrt := range testCase .Assertions {
191
207
for _ , action := range assrt .Actions {
@@ -196,31 +212,25 @@ MethodLoop:
196
212
}
197
213
}
198
214
199
- rec .AssertActor (t , s .actor , pairs ... )
200
- s .NoError (rec .AllAsserted (), "all rbac calls must be asserted" )
201
- s .clear ()
215
+ s .rec .AssertActor (s .T (), s .actor , pairs ... )
216
+ s .NoError (s .rec .AllAsserted (), "all rbac calls must be asserted" )
202
217
}
203
218
204
- func (s * MethodTestSuite ) TestNoActor (callMethod func () ([]reflect.Value , error )) {
205
- // TODO:
219
+ func (s * MethodTestSuite ) TestNoActor (callMethod func (ctx context.Context ) ([]reflect.Value , error )) {
220
+ // Call without any actor
221
+ _ , err := callMethod (context .Background ())
222
+ s .ErrorIs (err , authzquery .NoActorError , "method should return NoActorError error when no actor is provided" )
206
223
}
207
224
208
225
// TestNotAuthorized runs the given method with an authorizer that will fail authz.
209
226
// Asserts that the error returned is a NotAuthorizedError.
210
- func (s * MethodTestSuite ) TestNotAuthorized (callMethod func () ([]reflect.Value , error )) {
211
- tmp := s .authz .AlwaysReturn
212
- defer func () {
213
- // Set things back to the way they were
214
- s .rec .Reset ()
215
- s .authz .AlwaysReturn = tmp
216
- }()
217
-
227
+ func (s * MethodTestSuite ) TestNotAuthorized (callMethod func (ctx context.Context ) ([]reflect.Value , error )) {
218
228
s .authz .AlwaysReturn = xerrors .New ("Always fail authz" )
219
229
220
230
// If we have assertions, that means the method should FAIL
221
231
// if RBAC will disallow the request. The returned error should
222
232
// be expected to be a NotAuthorizedError.
223
- resp , err := callMethod ()
233
+ resp , err := callMethod (s . ctx )
224
234
225
235
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
226
236
// any case where the error is nil and the response is an empty slice.
0 commit comments