Skip to content

Commit a37fead

Browse files
committed
Support asserting outputs in authzquery test
1 parent 29e7c46 commit a37fead

File tree

2 files changed

+27
-11
lines changed

2 files changed

+27
-11
lines changed

coderd/authzquery/methods_test.go

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ MethodLoop:
128128
// if RBAC will disallow the request. The returned error should
129129
// be expected to be a NotAuthorizedError.
130130
erroredResp := reflect.ValueOf(az).Method(i).Call(append([]reflect.Value{reflect.ValueOf(ctx)}, testCase.Inputs...))
131-
err := findError(t, erroredResp)
131+
_, err := splitResp(t, erroredResp)
132132
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
133133
// any case where the error is nil and the response is an empty slice.
134134
if err != nil || !hasEmptySliceResponse(erroredResp) {
@@ -143,8 +143,14 @@ MethodLoop:
143143

144144
resp := reflect.ValueOf(az).Method(i).Call(append([]reflect.Value{reflect.ValueOf(ctx)}, testCase.Inputs...))
145145
// TODO: Should we assert the object returned is the correct one?
146-
err := findError(t, resp)
146+
outputs, err := splitResp(t, resp)
147147
require.NoError(t, err, "method %q returned an error", testName)
148+
if testCase.ExpectedOutputs != nil {
149+
require.Equal(t, len(testCase.ExpectedOutputs), len(outputs), "method %q returned unexpected number of outputs", testName)
150+
for i := range outputs {
151+
require.Equal(t, testCase.ExpectedOutputs[i].Interface(), outputs[i].Interface(), "method %q returned unexpected output %d", testName, i)
152+
}
153+
}
148154
found = true
149155
break MethodLoop
150156
}
@@ -177,19 +183,21 @@ func hasEmptySliceResponse(values []reflect.Value) bool {
177183
return false
178184
}
179185

180-
func findError(t *testing.T, values []reflect.Value) error {
186+
func splitResp(t *testing.T, values []reflect.Value) ([]reflect.Value, error) {
187+
outputs := []reflect.Value{}
181188
for _, r := range values {
182189
if r.Type().Implements(reflect.TypeOf((*error)(nil)).Elem()) {
183190
if r.IsNil() {
184191
// Error is found, but it's nil!
185-
return nil
192+
return outputs, nil
186193
}
187194
err, ok := r.Interface().(error)
188195
if !ok {
189196
t.Fatal("error is not an error?!")
190197
}
191-
return err
198+
return outputs, err
192199
}
200+
outputs = append(outputs, r)
193201
}
194202
t.Fatal("no expected error value found in responses (error can be nil)")
195203
panic("unreachable") // For compile reasons
@@ -200,6 +208,8 @@ func findError(t *testing.T, values []reflect.Value) error {
200208
type MethodCase struct {
201209
Inputs []reflect.Value
202210
Assertions []AssertRBAC
211+
// Output is optional. Can assert non-error return values.
212+
ExpectedOutputs []reflect.Value
203213
}
204214

205215
// AssertRBAC contains the object and actions to be asserted.
@@ -218,13 +228,19 @@ type AssertRBAC struct {
218228
// Inputs: inputs(workspace, template, ...),
219229
// Assertions: asserts(workspace, rbac.ActionRead, template, rbac.ActionRead, ...),
220230
// }
221-
func methodCase(inputs []reflect.Value, assertions []AssertRBAC) MethodCase {
231+
func methodCase(ins []reflect.Value, assertions []AssertRBAC) MethodCase {
222232
return MethodCase{
223-
Inputs: inputs,
224-
Assertions: assertions,
233+
Inputs: ins,
234+
Assertions: assertions,
235+
ExpectedOutputs: nil,
225236
}
226237
}
227238

239+
func (m MethodCase) Outputs(outs ...any) MethodCase {
240+
m.ExpectedOutputs = inputs(outs...)
241+
return m
242+
}
243+
228244
// inputs is a convenience method for creating []reflect.Value.
229245
//
230246
// inputs(workspace, template, ...)
@@ -236,9 +252,9 @@ func methodCase(inputs []reflect.Value, assertions []AssertRBAC) MethodCase {
236252
// reflect.ValueOf(template),
237253
// ...
238254
// }
239-
func inputs(inputs ...any) []reflect.Value {
255+
func inputs(ins ...any) []reflect.Value {
240256
out := make([]reflect.Value, 0)
241-
for _, input := range inputs {
257+
for _, input := range ins {
242258
input := input
243259
out = append(out, reflect.ValueOf(input))
244260
}

coderd/authzquery/user_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func (s *MethodTestSuite) TestUser() {
4141
s.Run("GetUserByID", func() {
4242
s.RunMethodTest(func(t *testing.T, db database.Store) MethodCase {
4343
u := dbgen.User(t, db, database.User{})
44-
return methodCase(inputs(u.ID), asserts(u, rbac.ActionRead))
44+
return methodCase(inputs(u.ID), asserts(u, rbac.ActionRead)).Outputs(u)
4545
})
4646
})
4747
s.Run("GetAuthorizedUserCount", func() {

0 commit comments

Comments
 (0)