@@ -130,12 +130,11 @@ func (s *MethodTestSuite) SetupTest() {
130
130
131
131
func (s * MethodTestSuite ) TearDownTest () {
132
132
var (
133
- t = s .T ()
134
- az = s .az
135
- testCase = s .testCase
136
- fakeAuthorizer = s .authz
137
- ctx = s .ctx
138
- rec = s .rec
133
+ t = s .T ()
134
+ az = s .az
135
+ testCase = s .testCase
136
+ ctx = s .ctx
137
+ rec = s .rec
139
138
)
140
139
141
140
require .NotEqualf (t , "" , testCase .MethodName , "Method name must be set" )
@@ -149,43 +148,33 @@ func (s *MethodTestSuite) TearDownTest() {
149
148
MethodLoop:
150
149
for i := 0 ; i < azt .NumMethod (); i ++ {
151
150
method := azt .Method (i )
151
+ callMethod := func () ([]reflect.Value , error ) {
152
+ resp := reflect .ValueOf (az ).Method (i ).Call (append ([]reflect.Value {reflect .ValueOf (ctx )}, testCase .Inputs ... ))
153
+ return splitResp (t , resp )
154
+ }
155
+
152
156
if method .Name == methodName {
153
157
if len (testCase .Assertions ) > 0 {
154
- fakeAuthorizer .AlwaysReturn = xerrors .New ("Always fail authz" )
155
- // If we have assertions, that means the method should FAIL
156
- // if RBAC will disallow the request. The returned error should
157
- // be expected to be a NotAuthorizedError.
158
- erroredResp := reflect .ValueOf (az ).Method (i ).Call (append ([]reflect.Value {reflect .ValueOf (ctx )}, testCase .Inputs ... ))
159
- _ , err := splitResp (t , erroredResp )
160
- // This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
161
- // any case where the error is nil and the response is an empty slice.
162
- if err != nil || ! hasEmptySliceResponse (erroredResp ) {
163
- require .Errorf (t , err , "method %q should an error with disallow authz" , methodName )
164
- require .ErrorIsf (t , err , sql .ErrNoRows , "error should match sql.ErrNoRows" )
165
- require .ErrorAs (t , err , & authzquery.NotAuthorizedError {}, "error should be NotAuthorizedError" )
166
- }
167
- // Set things back to normal.
168
- fakeAuthorizer .AlwaysReturn = nil
169
- rec .Reset ()
158
+ // Run testing on expected errors
159
+ s .TestNotAuthorized (callMethod )
160
+ s .TestNoActor (callMethod )
170
161
}
171
162
172
- resp := reflect .ValueOf (az ).Method (i ).Call (append ([]reflect.Value {reflect .ValueOf (ctx )}, testCase .Inputs ... ))
173
-
174
- outputs , err := splitResp (t , resp )
175
- require .NoError (t , err , "method %q returned an error" , t .Name ())
163
+ outputs , err := callMethod ()
164
+ s .NoError (err , "method %q returned an error" , t .Name ())
176
165
177
166
// Some tests may not care about the outputs, so we only assert if
178
167
// they are provided.
179
168
if testCase .ExpectedOutputs != nil {
180
169
// Assert the required outputs
181
- require .Equal (t , len (testCase .ExpectedOutputs ), len (outputs ), "method %q returned unexpected number of outputs" , methodName )
170
+ s .Equal (len (testCase .ExpectedOutputs ), len (outputs ), "method %q returned unexpected number of outputs" , methodName )
182
171
for i := range outputs {
183
172
a , b := testCase .ExpectedOutputs [i ].Interface (), outputs [i ].Interface ()
184
173
if reflect .TypeOf (a ).Kind () == reflect .Slice || reflect .TypeOf (a ).Kind () == reflect .Array {
185
174
// Order does not matter
186
- require .ElementsMatch (t , a , b , "method %q returned unexpected output %d" , methodName , i )
175
+ s .ElementsMatch (a , b , "method %q returned unexpected output %d" , methodName , i )
187
176
} else {
188
- require .Equal (t , a , b , "method %q returned unexpected output %d" , methodName , i )
177
+ s .Equal (a , b , "method %q returned unexpected output %d" , methodName , i )
189
178
}
190
179
}
191
180
}
@@ -195,7 +184,7 @@ MethodLoop:
195
184
}
196
185
}
197
186
198
- require .True (t , found , "method %q does not exist" , methodName )
187
+ s .True (found , "method %q does not exist" , methodName )
199
188
200
189
var pairs []coderdtest.ActionObjectPair
201
190
for _ , assrt := range testCase .Assertions {
@@ -208,10 +197,40 @@ MethodLoop:
208
197
}
209
198
210
199
rec .AssertActor (t , s .actor , pairs ... )
211
- require .NoError (t , rec .AllAsserted (), "all rbac calls must be asserted" )
200
+ s .NoError (rec .AllAsserted (), "all rbac calls must be asserted" )
212
201
s .clear ()
213
202
}
214
203
204
+ func (s * MethodTestSuite ) TestNoActor (callMethod func () ([]reflect.Value , error )) {
205
+ // TODO:
206
+ }
207
+
208
+ // TestNotAuthorized runs the given method with an authorizer that will fail authz.
209
+ // Asserts that the error returned is a NotAuthorizedError.
210
+ func (s * MethodTestSuite ) TestNotAuthorized (callMethod func () ([]reflect.Value , error )) {
211
+ tmp := s .authz .AlwaysReturn
212
+ defer func () {
213
+ // Set things back to the way they were
214
+ s .rec .Reset ()
215
+ s .authz .AlwaysReturn = tmp
216
+ }()
217
+
218
+ s .authz .AlwaysReturn = xerrors .New ("Always fail authz" )
219
+
220
+ // If we have assertions, that means the method should FAIL
221
+ // if RBAC will disallow the request. The returned error should
222
+ // be expected to be a NotAuthorizedError.
223
+ resp , err := callMethod ()
224
+
225
+ // This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
226
+ // any case where the error is nil and the response is an empty slice.
227
+ if err != nil || ! hasEmptySliceResponse (resp ) {
228
+ s .Errorf (err , "method should an error with disallow authz" )
229
+ s .ErrorIsf (err , sql .ErrNoRows , "error should match sql.ErrNoRows" )
230
+ s .ErrorAs (err , & authzquery.NotAuthorizedError {}, "error should be NotAuthorizedError" )
231
+ }
232
+ }
233
+
215
234
func (s * MethodTestSuite ) Asserts (v ... any ) * MethodTestSuite {
216
235
s .testCase .MethodName = methodName (s .T ())
217
236
s .testCase = s .testCase .Asserts (v ... )
@@ -228,6 +247,10 @@ func (s *MethodTestSuite) Returns(v ...any) *MethodTestSuite {
228
247
return s
229
248
}
230
249
250
+ func (s * MethodTestSuite ) f () {
251
+
252
+ }
253
+
231
254
// RunMethodTest runs a method test case.
232
255
// The method to be tested is inferred from the name of the test case.
233
256
// Deprecated
0 commit comments