Skip to content

Commit fdfdd73

Browse files
committed
Fix user tests to use new subtest strategy
1 parent 4c68562 commit fdfdd73

File tree

6 files changed

+495
-584
lines changed

6 files changed

+495
-584
lines changed

coderd/authzquery/authz.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ func fetchWithPostFilter[
241241
// Fetch the rbac subject
242242
act, ok := ActorFromContext(ctx)
243243
if !ok {
244-
return empty, xerrors.Errorf("no authorization actor in context")
244+
return empty, NoActorError
245245
}
246246

247247
// Fetch the database object

coderd/authzquery/methods_test.go

Lines changed: 67 additions & 247 deletions
Original file line numberDiff line numberDiff line change
@@ -50,30 +50,6 @@ type MethodTestSuite struct {
5050
suite.Suite
5151
// methodAccounting counts all methods called by a 'RunMethodTest'
5252
methodAccounting map[string]int
53-
54-
// Individual state for each unit test.
55-
// State used by developer
56-
DB database.Store
57-
// State set by setup
58-
ctx context.Context
59-
az *authzquery.AuthzQuerier
60-
rec *coderdtest.RecordingAuthorizer
61-
authz *coderdtest.FakeAuthorizer
62-
actor rbac.Subject
63-
// State set by developer
64-
testCase MethodCase
65-
}
66-
67-
type testCaseState struct {
68-
DB database.Store
69-
// State set by setup
70-
ctx context.Context
71-
az *authzquery.AuthzQuerier
72-
rec *coderdtest.RecordingAuthorizer
73-
authz *coderdtest.FakeAuthorizer
74-
actor rbac.Subject
75-
// State set by developer
76-
testCase MethodCase
7753
}
7854

7955
// SetupSuite sets up the suite by creating a map of all methods on AuthzQuerier
@@ -110,150 +86,7 @@ func (s *MethodTestSuite) TearDownSuite() {
11086
})
11187
}
11288

113-
func (s *MethodTestSuite) clear() {
114-
s.DB = nil
115-
s.ctx = nil
116-
s.az = nil
117-
s.rec = nil
118-
s.actor = rbac.Subject{}
119-
s.testCase = MethodCase{}
120-
s.authz = nil
121-
}
122-
123-
//func (s *MethodTestSuite) BeforeSubTest(_ string) {
124-
// s.clear()
125-
//
126-
// s.DB = dbfake.New()
127-
// s.authz = &coderdtest.FakeAuthorizer{
128-
// AlwaysReturn: nil,
129-
// }
130-
// s.rec = &coderdtest.RecordingAuthorizer{
131-
// Wrapped: s.authz,
132-
// }
133-
// s.az = authzquery.New(s.DB, s.rec, slog.Make())
134-
// s.actor = rbac.Subject{
135-
// ID: uuid.NewString(),
136-
// Roles: rbac.RoleNames{rbac.RoleOwner()},
137-
// Groups: []string{},
138-
// Scope: rbac.ScopeAll,
139-
// }
140-
// s.ctx = authzquery.WithAuthorizeContext(context.Background(), s.actor)
141-
//}
142-
143-
//func (s *MethodTestSuite) AfterSubTest(testName string) {
144-
// var (
145-
// t = s.T()
146-
// az = s.az
147-
// testCase = s.testCase
148-
// methodName = parseMethodName(testName)
149-
// )
150-
//
151-
// // This ensures the test case has assertion data. If it is missing this,
152-
// // the test is incomplete
153-
// s.NotEqualf("", methodName, "Method name not")
154-
//
155-
// s.methodAccounting[methodName]++
156-
//
157-
// // Find the method with the name of the test.
158-
// var callMethod func(ctx context.Context) ([]reflect.Value, error)
159-
// azt := reflect.TypeOf(az)
160-
//MethodLoop:
161-
// for i := 0; i < azt.NumMethod(); i++ {
162-
// method := azt.Method(i)
163-
// if method.Name == methodName {
164-
// methodF := reflect.ValueOf(az).Method(i)
165-
// callMethod = func(ctx context.Context) ([]reflect.Value, error) {
166-
// resp := methodF.Call(append([]reflect.Value{reflect.ValueOf(ctx)}, testCase.Inputs...))
167-
// return splitResp(t, resp)
168-
// }
169-
// break MethodLoop
170-
// }
171-
// }
172-
//
173-
// s.NotNil(callMethod, "method %q does not exist", methodName)
174-
//
175-
// // Run tests that are only run if the method makes rbac assertions.
176-
// // These tests assert the error conditions of the method.
177-
// if len(testCase.Assertions) > 0 {
178-
// // Only run these tests if we know the underlying call makes
179-
// // rbac assertions.
180-
// s.TestNotAuthorized(callMethod)
181-
// s.TestNoActor(callMethod)
182-
// }
183-
//
184-
// // Always run
185-
// s.TestMethodCall(methodName, callMethod)
186-
//}
187-
188-
// TestMethodCall runs the given method and asserts:
189-
// - The method does not return an error
190-
// - The method makes the expected number of rbac calls
191-
// - The method returns the expected outputs
192-
func (s *MethodTestSuite) TestMethodCall(ctx context.Context, methodName string, rec *coderdtest.RecordingAuthorizer, callMethod func(ctx context.Context) ([]reflect.Value, error)) {
193-
// Reset any recordings and set the authz to always succeed in authorizing.
194-
s.rec.Reset()
195-
s.authz.AlwaysReturn = nil
196-
testCase := s.testCase
197-
198-
outputs, err := callMethod(ctx)
199-
s.NoError(err, "method %q returned an error", methodName)
200-
201-
// Some tests may not care about the outputs, so we only assert if
202-
// they are provided.
203-
if testCase.ExpectedOutputs != nil {
204-
// Assert the required outputs
205-
s.Equal(len(testCase.ExpectedOutputs), len(outputs), "method %q returned unexpected number of outputs", methodName)
206-
for i := range outputs {
207-
a, b := testCase.ExpectedOutputs[i].Interface(), outputs[i].Interface()
208-
if reflect.TypeOf(a).Kind() == reflect.Slice || reflect.TypeOf(a).Kind() == reflect.Array {
209-
// Order does not matter
210-
s.ElementsMatch(a, b, "method %q returned unexpected output %d", methodName, i)
211-
} else {
212-
s.Equal(a, b, "method %q returned unexpected output %d", methodName, i)
213-
}
214-
}
215-
}
216-
217-
var pairs []coderdtest.ActionObjectPair
218-
for _, assrt := range testCase.Assertions {
219-
for _, action := range assrt.Actions {
220-
pairs = append(pairs, coderdtest.ActionObjectPair{
221-
Action: action,
222-
Object: assrt.Object,
223-
})
224-
}
225-
}
226-
227-
s.rec.AssertActor(s.T(), s.actor, pairs...)
228-
s.NoError(s.rec.AllAsserted(), "all rbac calls must be asserted")
229-
}
230-
231-
func (s *MethodTestSuite) TestNoActor(callMethod func(ctx context.Context) ([]reflect.Value, error)) {
232-
// Call without any actor
233-
_, err := callMethod(context.Background())
234-
s.ErrorIs(err, authzquery.NoActorError, "method should return NoActorError error when no actor is provided")
235-
}
236-
237-
// TestNotAuthorized runs the given method with an authorizer that will fail authz.
238-
// Asserts that the error returned is a NotAuthorizedError.
239-
func (s *MethodTestSuite) TestNotAuthorized(ctx context.Context, az *coderdtest.FakeAuthorizer, callMethod func(ctx context.Context) ([]reflect.Value, error)) {
240-
az.AlwaysReturn = xerrors.New("Always fail authz")
241-
242-
// If we have assertions, that means the method should FAIL
243-
// if RBAC will disallow the request. The returned error should
244-
// be expected to be a NotAuthorizedError.
245-
resp, err := callMethod(ctx)
246-
247-
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
248-
// any case where the error is nil and the response is an empty slice.
249-
if err != nil || !hasEmptySliceResponse(resp) {
250-
s.Errorf(err, "method should an error with disallow authz")
251-
s.ErrorIsf(err, sql.ErrNoRows, "error should match sql.ErrNoRows")
252-
s.ErrorAs(err, &authzquery.NotAuthorizedError{}, "error should be NotAuthorizedError")
253-
}
254-
}
255-
256-
func (s *MethodTestSuite) Subtest(testCaseF func(t *testing.T, db database.Store, check *MethodCase)) func() {
89+
func (s *MethodTestSuite) Subtest(testCaseF func(db database.Store, check *MethodCase)) func() {
25790
return func() {
25891
t := s.T()
25992
testName := s.T().Name()
@@ -278,7 +111,7 @@ func (s *MethodTestSuite) Subtest(testCaseF func(t *testing.T, db database.Store
278111
ctx := authzquery.WithAuthorizeContext(context.Background(), actor)
279112

280113
var testCase MethodCase
281-
testCaseF(t, db, &testCase)
114+
testCaseF(db, &testCase)
282115

283116
// Find the method with the name of the test.
284117
var callMethod func(ctx context.Context) ([]reflect.Value, error)
@@ -293,48 +126,6 @@ func (s *MethodTestSuite) Subtest(testCaseF func(t *testing.T, db database.Store
293126
return splitResp(t, resp)
294127
}
295128
break MethodLoop
296-
297-
//if len(testCase.Assertions) > 0 {
298-
// fakeAuthorizer.AlwaysReturn = xerrors.New("Always fail authz")
299-
// // If we have assertions, that means the method should FAIL
300-
// // if RBAC will disallow the request. The returned error should
301-
// // be expected to be a NotAuthorizedError.
302-
// erroredResp := reflect.ValueOf(az).Method(i).Call(append([]reflect.Value{reflect.ValueOf(ctx)}, testCase.Inputs...))
303-
// _, err := splitResp(t, erroredResp)
304-
// // This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
305-
// // any case where the error is nil and the response is an empty slice.
306-
// if err != nil || !hasEmptySliceResponse(erroredResp) {
307-
// require.Errorf(t, err, "method %q should an error with disallow authz", testName)
308-
// require.ErrorIsf(t, err, sql.ErrNoRows, "error should match sql.ErrNoRows")
309-
// require.ErrorAs(t, err, &authzquery.NotAuthorizedError{}, "error should be NotAuthorizedError")
310-
// }
311-
// // Set things back to normal.
312-
// fakeAuthorizer.AlwaysReturn = nil
313-
// rec.Reset()
314-
//}
315-
316-
//resp := reflect.ValueOf(az).Method(i).Call(append([]reflect.Value{reflect.ValueOf(ctx)}, testCase.Inputs...))
317-
//
318-
//outputs, err := splitResp(t, resp)
319-
//require.NoError(t, err, "method %q returned an error", testName)
320-
//
321-
//// Some tests may not care about the outputs, so we only assert if
322-
//// they are provided.
323-
//if testCase.ExpectedOutputs != nil {
324-
// // Assert the required outputs
325-
// require.Equal(t, len(testCase.ExpectedOutputs), len(outputs), "method %q returned unexpected number of outputs", testName)
326-
// for i := range outputs {
327-
// a, b := testCase.ExpectedOutputs[i].Interface(), outputs[i].Interface()
328-
// if reflect.TypeOf(a).Kind() == reflect.Slice || reflect.TypeOf(a).Kind() == reflect.Array {
329-
// // Order does not matter
330-
// require.ElementsMatch(t, a, b, "method %q returned unexpected output %d", testName, i)
331-
// } else {
332-
// require.Equal(t, a, b, "method %q returned unexpected output %d", testName, i)
333-
// }
334-
// }
335-
//}
336-
//
337-
//break MethodLoop
338129
}
339130
}
340131

@@ -350,45 +141,77 @@ func (s *MethodTestSuite) Subtest(testCaseF func(t *testing.T, db database.Store
350141
}
351142

352143
// Always run
353-
rec.Reset()
354-
fakeAuthorizer.AlwaysReturn = nil
355-
356-
outputs, err := callMethod(ctx)
357-
s.NoError(err, "method %q returned an error", methodName)
358-
359-
// Some tests may not care about the outputs, so we only assert if
360-
// they are provided.
361-
if testCase.ExpectedOutputs != nil {
362-
// Assert the required outputs
363-
s.Equal(len(testCase.ExpectedOutputs), len(outputs), "method %q returned unexpected number of outputs", methodName)
364-
for i := range outputs {
365-
a, b := testCase.ExpectedOutputs[i].Interface(), outputs[i].Interface()
366-
if reflect.TypeOf(a).Kind() == reflect.Slice || reflect.TypeOf(a).Kind() == reflect.Array {
367-
// Order does not matter
368-
s.ElementsMatch(a, b, "method %q returned unexpected output %d", methodName, i)
369-
} else {
370-
s.Equal(a, b, "method %q returned unexpected output %d", methodName, i)
144+
s.Run("Success", func() {
145+
rec.Reset()
146+
fakeAuthorizer.AlwaysReturn = nil
147+
148+
outputs, err := callMethod(ctx)
149+
s.NoError(err, "method %q returned an error", methodName)
150+
151+
// Some tests may not care about the outputs, so we only assert if
152+
// they are provided.
153+
if testCase.ExpectedOutputs != nil {
154+
// Assert the required outputs
155+
s.Equal(len(testCase.ExpectedOutputs), len(outputs), "method %q returned unexpected number of outputs", methodName)
156+
for i := range outputs {
157+
a, b := testCase.ExpectedOutputs[i].Interface(), outputs[i].Interface()
158+
if reflect.TypeOf(a).Kind() == reflect.Slice || reflect.TypeOf(a).Kind() == reflect.Array {
159+
// Order does not matter
160+
s.ElementsMatch(a, b, "method %q returned unexpected output %d", methodName, i)
161+
} else {
162+
s.Equal(a, b, "method %q returned unexpected output %d", methodName, i)
163+
}
371164
}
372165
}
373-
}
374166

375-
var pairs []coderdtest.ActionObjectPair
376-
for _, assrt := range testCase.Assertions {
377-
for _, action := range assrt.Actions {
378-
pairs = append(pairs, coderdtest.ActionObjectPair{
379-
Action: action,
380-
Object: assrt.Object,
381-
})
167+
var pairs []coderdtest.ActionObjectPair
168+
for _, assrt := range testCase.Assertions {
169+
for _, action := range assrt.Actions {
170+
pairs = append(pairs, coderdtest.ActionObjectPair{
171+
Action: action,
172+
Object: assrt.Object,
173+
})
174+
}
382175
}
383-
}
384176

385-
rec.AssertActor(s.T(), s.actor, pairs...)
386-
s.NoError(rec.AllAsserted(), "all rbac calls must be asserted")
177+
rec.AssertActor(s.T(), actor, pairs...)
178+
s.NoError(rec.AllAsserted(), "all rbac calls must be asserted")
179+
})
387180
}
388181
}
389182

183+
func (s *MethodTestSuite) TestNoActor(callMethod func(ctx context.Context) ([]reflect.Value, error)) {
184+
s.Run("NoActor", func() {
185+
// Call without any actor
186+
_, err := callMethod(context.Background())
187+
s.ErrorIs(err, authzquery.NoActorError, "method should return NoActorError error when no actor is provided")
188+
})
189+
}
190+
191+
// TestNotAuthorized runs the given method with an authorizer that will fail authz.
192+
// Asserts that the error returned is a NotAuthorizedError.
193+
func (s *MethodTestSuite) TestNotAuthorized(ctx context.Context, az *coderdtest.FakeAuthorizer, callMethod func(ctx context.Context) ([]reflect.Value, error)) {
194+
s.Run("NotAuthorized", func() {
195+
az.AlwaysReturn = xerrors.New("Always fail authz")
196+
197+
// If we have assertions, that means the method should FAIL
198+
// if RBAC will disallow the request. The returned error should
199+
// be expected to be a NotAuthorizedError.
200+
resp, err := callMethod(ctx)
201+
202+
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
203+
// any case where the error is nil and the response is an empty slice.
204+
if err != nil || !hasEmptySliceResponse(resp) {
205+
s.Errorf(err, "method should an error with disallow authz")
206+
s.ErrorIsf(err, sql.ErrNoRows, "error should match sql.ErrNoRows")
207+
s.ErrorAs(err, &authzquery.NotAuthorizedError{}, "error should be NotAuthorizedError")
208+
}
209+
})
210+
}
211+
390212
// RunMethodTest runs a method test case.
391213
// The method to be tested is inferred from the name of the test case.
214+
// Deprecated: Use Subtest instead. Remove this function!
392215
func (s *MethodTestSuite) RunMethodTest(testCaseF func(t *testing.T, db database.Store) MethodCase) {
393216
t := s.T()
394217
testName := s.T().Name()
@@ -516,7 +339,7 @@ MethodLoop:
516339
}
517340
}
518341

519-
rec.AssertActor(s.T(), s.actor, pairs...)
342+
rec.AssertActor(s.T(), actor, pairs...)
520343
s.NoError(rec.AllAsserted(), "all rbac calls must be asserted")
521344
}
522345

@@ -570,6 +393,7 @@ func (m *MethodCase) Args(args ...any) *MethodCase {
570393
return m
571394
}
572395

396+
// Returns is optional. If it is never called, it will not be asserted.
573397
func (m *MethodCase) Returns(rets ...any) *MethodCase {
574398
m.ExpectedOutputs = values(rets...)
575399
return m
@@ -591,6 +415,8 @@ type AssertRBAC struct {
591415
// Inputs: values(workspace, template, ...),
592416
// Assertions: asserts(workspace, rbac.ActionRead, template, rbac.ActionRead, ...),
593417
// }
418+
//
419+
// Deprecated: use MethodCase instead.
594420
func methodCase(ins []reflect.Value, assertions []AssertRBAC, outs []reflect.Value) MethodCase {
595421
return MethodCase{
596422
Inputs: ins,
@@ -673,12 +499,6 @@ func asserts(inputs ...any) []AssertRBAC {
673499
return out
674500
}
675501

676-
func parseMethodName(testName string) string {
677-
names := strings.Split(testName, "/")
678-
methodName := names[len(names)-1]
679-
return methodName
680-
}
681-
682502
func (s *MethodTestSuite) TestExtraMethods() {
683503
s.Run("GetProvisionerDaemons", func() {
684504
s.RunMethodTest(func(t *testing.T, db database.Store) MethodCase {

0 commit comments

Comments
 (0)