Skip to content

Commit c2bc20e

Browse files
committed
Reduce LoC by using setup and teardown test
1 parent 8295eb3 commit c2bc20e

File tree

1 file changed

+167
-0
lines changed

1 file changed

+167
-0
lines changed

coderd/authzquery/methods_test.go

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,18 @@ type MethodTestSuite struct {
5050
suite.Suite
5151
// methodAccounting counts all methods called by a 'RunMethodTest'
5252
methodAccounting map[string]int
53+
54+
// Individual state for each unit test.
55+
// State used by developer
56+
DB database.Store
57+
// State set by setup
58+
ctx context.Context
59+
az *authzquery.AuthzQuerier
60+
rec *coderdtest.RecordingAuthorizer
61+
authz *coderdtest.FakeAuthorizer
62+
actor rbac.Subject
63+
// State set by developer
64+
testCase MethodCase
5365
}
5466

5567
// SetupSuite sets up the suite by creating a map of all methods on AuthzQuerier
@@ -86,8 +98,139 @@ func (s *MethodTestSuite) TearDownSuite() {
8698
})
8799
}
88100

101+
func (s *MethodTestSuite) clear() {
102+
s.DB = nil
103+
s.ctx = nil
104+
s.az = nil
105+
s.rec = nil
106+
s.actor = rbac.Subject{}
107+
s.testCase = MethodCase{}
108+
s.authz = nil
109+
}
110+
111+
func (s *MethodTestSuite) SetupTest() {
112+
s.clear()
113+
114+
s.DB = dbfake.New()
115+
s.authz = &coderdtest.FakeAuthorizer{
116+
AlwaysReturn: nil,
117+
}
118+
s.rec = &coderdtest.RecordingAuthorizer{
119+
Wrapped: s.authz,
120+
}
121+
s.az = authzquery.New(s.DB, s.rec, slog.Make())
122+
s.actor = rbac.Subject{
123+
ID: uuid.NewString(),
124+
Roles: rbac.RoleNames{rbac.RoleOwner()},
125+
Groups: []string{},
126+
Scope: rbac.ScopeAll,
127+
}
128+
s.ctx = authzquery.WithAuthorizeContext(context.Background(), s.actor)
129+
}
130+
131+
func (s *MethodTestSuite) TearDownTest() {
132+
var (
133+
t = s.T()
134+
az = s.az
135+
testCase = s.testCase
136+
fakeAuthorizer = s.authz
137+
ctx = s.ctx
138+
rec = s.rec
139+
)
140+
141+
require.NotEqualf(t, "", testCase.MethodName, "Method name must be set")
142+
143+
methodName := testCase.MethodName
144+
s.methodAccounting[methodName]++
145+
146+
// Find the method with the name of the test.
147+
found := false
148+
azt := reflect.TypeOf(az)
149+
MethodLoop:
150+
for i := 0; i < azt.NumMethod(); i++ {
151+
method := azt.Method(i)
152+
if method.Name == methodName {
153+
if len(testCase.Assertions) > 0 {
154+
fakeAuthorizer.AlwaysReturn = xerrors.New("Always fail authz")
155+
// If we have assertions, that means the method should FAIL
156+
// if RBAC will disallow the request. The returned error should
157+
// be expected to be a NotAuthorizedError.
158+
erroredResp := reflect.ValueOf(az).Method(i).Call(append([]reflect.Value{reflect.ValueOf(ctx)}, testCase.Inputs...))
159+
_, err := splitResp(t, erroredResp)
160+
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
161+
// any case where the error is nil and the response is an empty slice.
162+
if err != nil || !hasEmptySliceResponse(erroredResp) {
163+
require.Errorf(t, err, "method %q should an error with disallow authz", methodName)
164+
require.ErrorIsf(t, err, sql.ErrNoRows, "error should match sql.ErrNoRows")
165+
require.ErrorAs(t, err, &authzquery.NotAuthorizedError{}, "error should be NotAuthorizedError")
166+
}
167+
// Set things back to normal.
168+
fakeAuthorizer.AlwaysReturn = nil
169+
rec.Reset()
170+
}
171+
172+
resp := reflect.ValueOf(az).Method(i).Call(append([]reflect.Value{reflect.ValueOf(ctx)}, testCase.Inputs...))
173+
174+
outputs, err := splitResp(t, resp)
175+
require.NoError(t, err, "method %q returned an error", t.Name())
176+
177+
// Some tests may not care about the outputs, so we only assert if
178+
// they are provided.
179+
if testCase.ExpectedOutputs != nil {
180+
// Assert the required outputs
181+
require.Equal(t, len(testCase.ExpectedOutputs), len(outputs), "method %q returned unexpected number of outputs", methodName)
182+
for i := range outputs {
183+
a, b := testCase.ExpectedOutputs[i].Interface(), outputs[i].Interface()
184+
if reflect.TypeOf(a).Kind() == reflect.Slice || reflect.TypeOf(a).Kind() == reflect.Array {
185+
// Order does not matter
186+
require.ElementsMatch(t, a, b, "method %q returned unexpected output %d", methodName, i)
187+
} else {
188+
require.Equal(t, a, b, "method %q returned unexpected output %d", methodName, i)
189+
}
190+
}
191+
}
192+
193+
found = true
194+
break MethodLoop
195+
}
196+
}
197+
198+
require.True(t, found, "method %q does not exist", methodName)
199+
200+
var pairs []coderdtest.ActionObjectPair
201+
for _, assrt := range testCase.Assertions {
202+
for _, action := range assrt.Actions {
203+
pairs = append(pairs, coderdtest.ActionObjectPair{
204+
Action: action,
205+
Object: assrt.Object,
206+
})
207+
}
208+
}
209+
210+
rec.AssertActor(t, s.actor, pairs...)
211+
require.NoError(t, rec.AllAsserted(), "all rbac calls must be asserted")
212+
s.clear()
213+
}
214+
215+
func (s *MethodTestSuite) Asserts(v ...any) *MethodTestSuite {
216+
s.testCase.MethodName = methodName(s.T())
217+
s.testCase = s.testCase.Asserts(v...)
218+
return s
219+
}
220+
221+
func (s *MethodTestSuite) Args(v ...any) *MethodTestSuite {
222+
s.testCase = s.testCase.Args(v...)
223+
return s
224+
}
225+
226+
func (s *MethodTestSuite) Returns(v ...any) *MethodTestSuite {
227+
s.testCase = s.testCase.Returns(v...)
228+
return s
229+
}
230+
89231
// RunMethodTest runs a method test case.
90232
// The method to be tested is inferred from the name of the test case.
233+
// Deprecated
91234
func (s *MethodTestSuite) RunMethodTest(testCaseF func(t *testing.T, db database.Store) MethodCase) {
92235
t := s.T()
93236
testName := s.T().Name()
@@ -215,12 +358,29 @@ func splitResp(t *testing.T, values []reflect.Value) ([]reflect.Value, error) {
215358
// A MethodCase contains the inputs to be provided to a single method call,
216359
// and the assertions to be made on the RBAC checks.
217360
type MethodCase struct {
361+
// MethodName is the name of the method to be called on the AuthzQuerier.
362+
MethodName string
218363
Inputs []reflect.Value
219364
Assertions []AssertRBAC
220365
// Output is optional. Can assert non-error return values.
221366
ExpectedOutputs []reflect.Value
222367
}
223368

369+
func (m MethodCase) Asserts(pairs ...any) MethodCase {
370+
m.Assertions = asserts(pairs...)
371+
return m
372+
}
373+
374+
func (m MethodCase) Args(args ...any) MethodCase {
375+
m.Inputs = values(args...)
376+
return m
377+
}
378+
379+
func (m MethodCase) Returns(rets ...any) MethodCase {
380+
m.ExpectedOutputs = values(rets...)
381+
return m
382+
}
383+
224384
// AssertRBAC contains the object and actions to be asserted.
225385
type AssertRBAC struct {
226386
Object rbac.Object
@@ -319,6 +479,13 @@ func asserts(inputs ...any) []AssertRBAC {
319479
return out
320480
}
321481

482+
func methodName(t *testing.T) string {
483+
testName := t.Name()
484+
names := strings.Split(testName, "/")
485+
methodName := names[len(names)-1]
486+
return methodName
487+
}
488+
322489
func (s *MethodTestSuite) TestExtraMethods() {
323490
s.Run("GetProvisionerDaemons", func() {
324491
s.RunMethodTest(func(t *testing.T, db database.Store) MethodCase {

0 commit comments

Comments
 (0)