Skip to content

Commit 6aa55ac

Browse files
committed
refactor out error test
1 parent 052c531 commit 6aa55ac

File tree

1 file changed

+54
-31
lines changed

1 file changed

+54
-31
lines changed

coderd/authzquery/methods_test.go

Lines changed: 54 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -130,12 +130,11 @@ func (s *MethodTestSuite) SetupTest() {
130130

131131
func (s *MethodTestSuite) TearDownTest() {
132132
var (
133-
t = s.T()
134-
az = s.az
135-
testCase = s.testCase
136-
fakeAuthorizer = s.authz
137-
ctx = s.ctx
138-
rec = s.rec
133+
t = s.T()
134+
az = s.az
135+
testCase = s.testCase
136+
ctx = s.ctx
137+
rec = s.rec
139138
)
140139

141140
require.NotEqualf(t, "", testCase.MethodName, "Method name must be set")
@@ -149,43 +148,33 @@ func (s *MethodTestSuite) TearDownTest() {
149148
MethodLoop:
150149
for i := 0; i < azt.NumMethod(); i++ {
151150
method := azt.Method(i)
151+
callMethod := func() ([]reflect.Value, error) {
152+
resp := reflect.ValueOf(az).Method(i).Call(append([]reflect.Value{reflect.ValueOf(ctx)}, testCase.Inputs...))
153+
return splitResp(t, resp)
154+
}
155+
152156
if method.Name == methodName {
153157
if len(testCase.Assertions) > 0 {
154-
fakeAuthorizer.AlwaysReturn = xerrors.New("Always fail authz")
155-
// If we have assertions, that means the method should FAIL
156-
// if RBAC will disallow the request. The returned error should
157-
// be expected to be a NotAuthorizedError.
158-
erroredResp := reflect.ValueOf(az).Method(i).Call(append([]reflect.Value{reflect.ValueOf(ctx)}, testCase.Inputs...))
159-
_, err := splitResp(t, erroredResp)
160-
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
161-
// any case where the error is nil and the response is an empty slice.
162-
if err != nil || !hasEmptySliceResponse(erroredResp) {
163-
require.Errorf(t, err, "method %q should an error with disallow authz", methodName)
164-
require.ErrorIsf(t, err, sql.ErrNoRows, "error should match sql.ErrNoRows")
165-
require.ErrorAs(t, err, &authzquery.NotAuthorizedError{}, "error should be NotAuthorizedError")
166-
}
167-
// Set things back to normal.
168-
fakeAuthorizer.AlwaysReturn = nil
169-
rec.Reset()
158+
// Run testing on expected errors
159+
s.TestNotAuthorized(callMethod)
160+
s.TestNoActor(callMethod)
170161
}
171162

172-
resp := reflect.ValueOf(az).Method(i).Call(append([]reflect.Value{reflect.ValueOf(ctx)}, testCase.Inputs...))
173-
174-
outputs, err := splitResp(t, resp)
175-
require.NoError(t, err, "method %q returned an error", t.Name())
163+
outputs, err := callMethod()
164+
s.NoError(err, "method %q returned an error", t.Name())
176165

177166
// Some tests may not care about the outputs, so we only assert if
178167
// they are provided.
179168
if testCase.ExpectedOutputs != nil {
180169
// Assert the required outputs
181-
require.Equal(t, len(testCase.ExpectedOutputs), len(outputs), "method %q returned unexpected number of outputs", methodName)
170+
s.Equal(len(testCase.ExpectedOutputs), len(outputs), "method %q returned unexpected number of outputs", methodName)
182171
for i := range outputs {
183172
a, b := testCase.ExpectedOutputs[i].Interface(), outputs[i].Interface()
184173
if reflect.TypeOf(a).Kind() == reflect.Slice || reflect.TypeOf(a).Kind() == reflect.Array {
185174
// Order does not matter
186-
require.ElementsMatch(t, a, b, "method %q returned unexpected output %d", methodName, i)
175+
s.ElementsMatch(a, b, "method %q returned unexpected output %d", methodName, i)
187176
} else {
188-
require.Equal(t, a, b, "method %q returned unexpected output %d", methodName, i)
177+
s.Equal(a, b, "method %q returned unexpected output %d", methodName, i)
189178
}
190179
}
191180
}
@@ -195,7 +184,7 @@ MethodLoop:
195184
}
196185
}
197186

198-
require.True(t, found, "method %q does not exist", methodName)
187+
s.True(found, "method %q does not exist", methodName)
199188

200189
var pairs []coderdtest.ActionObjectPair
201190
for _, assrt := range testCase.Assertions {
@@ -208,10 +197,40 @@ MethodLoop:
208197
}
209198

210199
rec.AssertActor(t, s.actor, pairs...)
211-
require.NoError(t, rec.AllAsserted(), "all rbac calls must be asserted")
200+
s.NoError(rec.AllAsserted(), "all rbac calls must be asserted")
212201
s.clear()
213202
}
214203

204+
func (s *MethodTestSuite) TestNoActor(callMethod func() ([]reflect.Value, error)) {
205+
// TODO:
206+
}
207+
208+
// TestNotAuthorized runs the given method with an authorizer that will fail authz.
209+
// Asserts that the error returned is a NotAuthorizedError.
210+
func (s *MethodTestSuite) TestNotAuthorized(callMethod func() ([]reflect.Value, error)) {
211+
tmp := s.authz.AlwaysReturn
212+
defer func() {
213+
// Set things back to the way they were
214+
s.rec.Reset()
215+
s.authz.AlwaysReturn = tmp
216+
}()
217+
218+
s.authz.AlwaysReturn = xerrors.New("Always fail authz")
219+
220+
// If we have assertions, that means the method should FAIL
221+
// if RBAC will disallow the request. The returned error should
222+
// be expected to be a NotAuthorizedError.
223+
resp, err := callMethod()
224+
225+
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
226+
// any case where the error is nil and the response is an empty slice.
227+
if err != nil || !hasEmptySliceResponse(resp) {
228+
s.Errorf(err, "method should an error with disallow authz")
229+
s.ErrorIsf(err, sql.ErrNoRows, "error should match sql.ErrNoRows")
230+
s.ErrorAs(err, &authzquery.NotAuthorizedError{}, "error should be NotAuthorizedError")
231+
}
232+
}
233+
215234
func (s *MethodTestSuite) Asserts(v ...any) *MethodTestSuite {
216235
s.testCase.MethodName = methodName(s.T())
217236
s.testCase = s.testCase.Asserts(v...)
@@ -228,6 +247,10 @@ func (s *MethodTestSuite) Returns(v ...any) *MethodTestSuite {
228247
return s
229248
}
230249

250+
func (s *MethodTestSuite) f() {
251+
252+
}
253+
231254
// RunMethodTest runs a method test case.
232255
// The method to be tested is inferred from the name of the test case.
233256
// Deprecated

0 commit comments

Comments
 (0)