Skip to content

Commit e152d5f

Browse files
committed
authzquery: add some more convenience methods, comments etc.
1 parent 44ca906 commit e152d5f

File tree

5 files changed

+105
-51
lines changed

5 files changed

+105
-51
lines changed

coderd/authzquery/authz_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func TestAuthzQueryRecursive(t *testing.T) {
4141
}
4242
// Log the name of the last method, so if there is a panic, it is
4343
// easy to know which method failed.
44-
//t.Log(method.Name)
44+
// t.Log(method.Name)
4545
// Call the function. Any infinite recursion will stack overflow.
4646
reflect.ValueOf(q).Method(i).Call(ins)
4747
}

coderd/authzquery/methods_test.go

Lines changed: 96 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,18 @@ import (
44
"context"
55
"fmt"
66
"reflect"
7+
"sort"
78
"strings"
89
"testing"
910

1011
"github.com/google/uuid"
11-
1212
"github.com/stretchr/testify/require"
13+
"github.com/stretchr/testify/suite"
1314

1415
"github.com/coder/coder/coderd/authzquery"
1516
"github.com/coder/coder/coderd/coderdtest"
16-
"github.com/coder/coder/coderd/database/databasefake"
17-
"github.com/stretchr/testify/suite"
18-
1917
"github.com/coder/coder/coderd/database"
18+
"github.com/coder/coder/coderd/database/databasefake"
2019
"github.com/coder/coder/coderd/rbac"
2120
)
2221

@@ -27,8 +26,16 @@ var (
2726
}
2827
)
2928

30-
// MethodTestSuite runs all methods tests for AuthzQuerier. The reason we use
31-
// a test suite, is so we can account for all functions tested on the AuthzQuerier.
29+
// TestMethodTestSuite runs MethodTestSuite.
30+
// In order for 'go test' to run this suite, we need to create
31+
// a normal test function and pass our suite to suite.Run
32+
// nolint: paralleltest
33+
func TestMethodTestSuite(t *testing.T) {
34+
suite.Run(t, new(MethodTestSuite))
35+
}
36+
37+
// MethodTestSuite runs all methods tests for AuthzQuerier. We use
38+
// a test suite so we can account for all functions tested on the AuthzQuerier.
3239
// We can then assert all methods were tested and asserted for proper RBAC
3340
// checks. This forces RBAC checks to be written for all methods.
3441
// Additionally, the way unit tests are written allows for easily executing
@@ -39,52 +46,46 @@ type MethodTestSuite struct {
3946
methodAccounting map[string]int
4047
}
4148

42-
func (suite *MethodTestSuite) SetupSuite() {
49+
// SetupSuite sets up the suite by creating a map of all methods on AuthzQuerier
50+
// and setting their count to 0.
51+
func (s *MethodTestSuite) SetupSuite() {
4352
az := &authzquery.AuthzQuerier{}
4453
azt := reflect.TypeOf(az)
45-
suite.methodAccounting = make(map[string]int)
54+
s.methodAccounting = make(map[string]int)
4655
for i := 0; i < azt.NumMethod(); i++ {
4756
method := azt.Method(i)
4857
if _, ok := skipMethods[method.Name]; ok {
4958
continue
5059
}
51-
suite.methodAccounting[method.Name] = 0
60+
s.methodAccounting[method.Name] = 0
5261
}
5362
}
5463

55-
func (suite *MethodTestSuite) TearDownSuite() {
56-
suite.Run("Accounting", func() {
57-
t := suite.T()
58-
for m, c := range suite.methodAccounting {
64+
// TearDownSuite asserts that all methods were called at least once.
65+
func (s *MethodTestSuite) TearDownSuite() {
66+
s.Run("Accounting", func() {
67+
t := s.T()
68+
notCalled := []string{}
69+
for m, c := range s.methodAccounting {
5970
if c <= 0 {
60-
t.Errorf("Method %q never called", m)
71+
notCalled = append(notCalled, m)
6172
}
6273
}
74+
sort.Strings(notCalled)
75+
for _, m := range notCalled {
76+
t.Errorf("Method never called: %q", m)
77+
}
6378
})
6479
}
6580

66-
// In order for 'go test' to run this suite, we need to create
67-
// a normal test function and pass our suite to suite.Run
68-
func TestMethodTestSuite(t *testing.T) {
69-
suite.Run(t, new(MethodTestSuite))
70-
}
71-
72-
type MethodCase struct {
73-
Inputs []reflect.Value
74-
Assertions []AssertRBAC
75-
}
76-
77-
type AssertRBAC struct {
78-
Object rbac.Object
79-
Actions []rbac.Action
80-
}
81-
82-
func (suite *MethodTestSuite) RunMethodTest(testCaseF func(t *testing.T, db database.Store) MethodCase) {
83-
t := suite.T()
84-
testName := suite.T().Name()
81+
// RunMethodTest runs a method test case.
82+
// The method to be tested is inferred from the name of the test case.
83+
func (s *MethodTestSuite) RunMethodTest(testCaseF func(t *testing.T, db database.Store) MethodCase) {
84+
t := s.T()
85+
testName := s.T().Name()
8586
names := strings.Split(testName, "/")
8687
methodName := names[len(names)-1]
87-
suite.methodAccounting[methodName]++
88+
s.methodAccounting[methodName]++
8889

8990
db := databasefake.New()
9091
rec := &coderdtest.RecordingAuthorizer{
@@ -131,7 +132,48 @@ MethodLoop:
131132
require.NoError(t, rec.AllAsserted(), "all rbac calls must be asserted")
132133
}
133134

134-
func methodInputs(inputs ...any) []reflect.Value {
135+
// A MethodCase contains the inputs to be provided to a single method call,
136+
// and the assertions to be made on the RBAC checks.
137+
type MethodCase struct {
138+
Inputs []reflect.Value
139+
Assertions []AssertRBAC
140+
}
141+
142+
// AssertRBAC contains the object and actions to be asserted.
143+
type AssertRBAC struct {
144+
Object rbac.Object
145+
Actions []rbac.Action
146+
}
147+
148+
// methodCase is a convenience method for creating MethodCases.
149+
//
150+
// methodCase(inputs(workspace, template, ...), asserts(workspace, rbac.ActionRead, template, rbac.ActionRead, ...))
151+
//
152+
// is equivalent to
153+
//
154+
// MethodCase{
155+
// Inputs: inputs(workspace, template, ...),
156+
// Assertions: asserts(workspace, rbac.ActionRead, template, rbac.ActionRead, ...),
157+
// }
158+
func methodCase(inputs []reflect.Value, assertions []AssertRBAC) MethodCase {
159+
return MethodCase{
160+
Inputs: inputs,
161+
Assertions: assertions,
162+
}
163+
}
164+
165+
// inputs is a convenience method for creating []reflect.Value.
166+
//
167+
// inputs(workspace, template, ...)
168+
//
169+
// is equivalent to
170+
//
171+
// []reflect.Value{
172+
// reflect.ValueOf(workspace),
173+
// reflect.ValueOf(template),
174+
// ...
175+
// }
176+
func inputs(inputs ...any) []reflect.Value {
135177
out := make([]reflect.Value, 0)
136178
for _, input := range inputs {
137179
input := input
@@ -140,6 +182,24 @@ func methodInputs(inputs ...any) []reflect.Value {
140182
return out
141183
}
142184

185+
// asserts is a convenience method for creating AssertRBACs.
186+
//
187+
// The number of inputs must be an even number.
188+
// asserts() will panic if this is not the case.
189+
//
190+
// Even-numbered inputs are the objects, and odd-numbered inputs are the actions.
191+
// Objects must implement rbac.Objecter.
192+
// Inputs can be a single rbac.Action, or a slice of rbac.Action.
193+
//
194+
// asserts(workspace, rbac.ActionRead, template, slice(rbac.ActionRead, rbac.ActionWrite), ...)
195+
//
196+
// is equivalent to
197+
//
198+
// []AssertRBAC{
199+
// {Object: workspace, Actions: []rbac.Action{rbac.ActionRead}},
200+
// {Object: template, Actions: []rbac.Action{rbac.ActionRead, rbac.ActionWrite)}},
201+
// ...
202+
// }
143203
func asserts(inputs ...any) []AssertRBAC {
144204
if len(inputs)%2 != 0 {
145205
panic(fmt.Sprintf("Must be an even length number of args, found %d", len(inputs)))
@@ -149,7 +209,7 @@ func asserts(inputs ...any) []AssertRBAC {
149209
for i := 0; i < len(inputs); i += 2 {
150210
obj, ok := inputs[i].(rbac.Objecter)
151211
if !ok {
152-
panic(fmt.Sprintf("object type '%T' not a supported key", obj))
212+
panic(fmt.Sprintf("object type '%T' does not implement rbac.Objecter", obj))
153213
}
154214
rbacObj := obj.RBACObject()
155215

coderd/authzquery/template_test.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,7 @@ func (suite *MethodTestSuite) TestTemplate() {
1212
suite.Run("GetTemplateByID", func() {
1313
suite.RunMethodTest(func(t *testing.T, db database.Store) MethodCase {
1414
obj := dbgen.Template(t, db, database.Template{})
15-
return MethodCase{
16-
Inputs: methodInputs(obj.ID),
17-
Assertions: asserts(obj, rbac.ActionRead),
18-
}
15+
return methodCase(inputs(obj.ID), asserts(obj, rbac.ActionRead))
1916
})
2017
})
2118
}

coderd/authzquery/workspace.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -209,9 +209,9 @@ func (q *AuthzQuerier) GetWorkspaceByID(ctx context.Context, id uuid.UUID) (data
209209
return authorizedFetch(q.authorizer, q.database.GetWorkspaceByID)(ctx, id)
210210
}
211211

212-
//OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
213-
//Deleted bool `db:"deleted" json:"deleted"`
214-
//Name string `db:"name" json:"name"`
212+
// OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
213+
// Deleted bool `db:"deleted" json:"deleted"`
214+
// Name string `db:"name" json:"name"`
215215

216216
// GetWorkspaceByOwnerIDAndName
217217
// Gen: Workspace

coderd/authzquery/workspace_test.go

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,11 @@ import (
88
"github.com/coder/coder/coderd/rbac"
99
)
1010

11-
func (suite *MethodTestSuite) TestWorkspace() {
12-
suite.Run("GetWorkspaceByID", func() {
13-
suite.RunMethodTest(func(t *testing.T, db database.Store) MethodCase {
11+
func (s *MethodTestSuite) TestWorkspace() {
12+
s.Run("GetWorkspaceByID", func() {
13+
s.RunMethodTest(func(t *testing.T, db database.Store) MethodCase {
1414
workspace := dbgen.Workspace(t, db, database.Workspace{})
15-
return MethodCase{
16-
Inputs: methodInputs(workspace.ID),
17-
Assertions: asserts(workspace, rbac.ActionRead),
18-
}
15+
return methodCase(inputs(workspace.ID), asserts(workspace, rbac.ActionRead))
1916
})
2017
})
2118
}

0 commit comments

Comments
 (0)