Skip to content

Commit f5dbd3e

Browse files
committed
Convert more tests to new format
1 parent c902715 commit f5dbd3e

File tree

4 files changed

+589
-745
lines changed

4 files changed

+589
-745
lines changed

coderd/authzquery/apikey_test.go

Lines changed: 38 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package authzquery_test
22

33
import (
4-
"testing"
54
"time"
65

76
"github.com/coder/coder/coderd/database"
@@ -11,55 +10,42 @@ import (
1110
)
1211

1312
func (s *MethodTestSuite) TestAPIKey() {
14-
s.Run("DeleteAPIKeyByID", func() {
15-
s.RunMethodTest(func(t *testing.T, db database.Store) MethodCase {
16-
key, _ := dbgen.APIKey(t, db, database.APIKey{})
17-
return methodCase(values(key.ID), asserts(key, rbac.ActionDelete), values())
18-
})
19-
})
20-
s.Run("GetAPIKeyByID", func() {
21-
s.RunMethodTest(func(t *testing.T, db database.Store) MethodCase {
22-
key, _ := dbgen.APIKey(t, db, database.APIKey{})
23-
return methodCase(values(key.ID), asserts(key, rbac.ActionRead), values(key))
24-
})
25-
})
26-
s.Run("GetAPIKeysByLoginType", func() {
27-
s.RunMethodTest(func(t *testing.T, db database.Store) MethodCase {
28-
a, _ := dbgen.APIKey(t, db, database.APIKey{LoginType: database.LoginTypePassword})
29-
b, _ := dbgen.APIKey(t, db, database.APIKey{LoginType: database.LoginTypePassword})
30-
_, _ = dbgen.APIKey(t, db, database.APIKey{LoginType: database.LoginTypeGithub})
31-
return methodCase(values(database.LoginTypePassword),
32-
asserts(a, rbac.ActionRead, b, rbac.ActionRead),
33-
values(slice.New(a, b)))
34-
})
35-
})
36-
s.Run("GetAPIKeysLastUsedAfter", func() {
37-
s.RunMethodTest(func(t *testing.T, db database.Store) MethodCase {
38-
a, _ := dbgen.APIKey(t, db, database.APIKey{LastUsed: time.Now().Add(time.Hour)})
39-
b, _ := dbgen.APIKey(t, db, database.APIKey{LastUsed: time.Now().Add(time.Hour)})
40-
_, _ = dbgen.APIKey(t, db, database.APIKey{LastUsed: time.Now().Add(-time.Hour)})
41-
return methodCase(values(time.Now()),
42-
asserts(a, rbac.ActionRead, b, rbac.ActionRead),
43-
values(slice.New(a, b)))
44-
})
45-
})
46-
s.Run("InsertAPIKey", func() {
47-
s.RunMethodTest(func(t *testing.T, db database.Store) MethodCase {
48-
u := dbgen.User(t, db, database.User{})
49-
return methodCase(values(database.InsertAPIKeyParams{
50-
UserID: u.ID,
51-
LoginType: database.LoginTypePassword,
52-
Scope: database.APIKeyScopeAll,
53-
}), asserts(rbac.ResourceAPIKey.WithOwner(u.ID.String()), rbac.ActionCreate),
54-
nil)
55-
})
56-
})
57-
s.Run("UpdateAPIKeyByID", func() {
58-
s.RunMethodTest(func(t *testing.T, db database.Store) MethodCase {
59-
a, _ := dbgen.APIKey(t, db, database.APIKey{})
60-
return methodCase(values(database.UpdateAPIKeyByIDParams{
61-
ID: a.ID,
62-
}), asserts(a, rbac.ActionUpdate), values())
63-
})
64-
})
13+
s.Run("DeleteAPIKeyByID", s.Subtest(func(db database.Store, check *MethodCase) {
14+
key, _ := dbgen.APIKey(s.T(), db, database.APIKey{})
15+
check.Args(key.ID).Asserts(key, rbac.ActionDelete).Returns()
16+
}))
17+
s.Run("GetAPIKeyByID", s.Subtest(func(db database.Store, check *MethodCase) {
18+
key, _ := dbgen.APIKey(s.T(), db, database.APIKey{})
19+
check.Args(key.ID).Asserts(key, rbac.ActionRead).Returns(key)
20+
}))
21+
s.Run("GetAPIKeysByLoginType", s.Subtest(func(db database.Store, check *MethodCase) {
22+
a, _ := dbgen.APIKey(s.T(), db, database.APIKey{LoginType: database.LoginTypePassword})
23+
b, _ := dbgen.APIKey(s.T(), db, database.APIKey{LoginType: database.LoginTypePassword})
24+
_, _ = dbgen.APIKey(s.T(), db, database.APIKey{LoginType: database.LoginTypeGithub})
25+
check.Args(database.LoginTypePassword).
26+
Asserts(a, rbac.ActionRead, b, rbac.ActionRead).
27+
Returns(slice.New(a, b))
28+
}))
29+
s.Run("GetAPIKeysLastUsedAfter", s.Subtest(func(db database.Store, check *MethodCase) {
30+
a, _ := dbgen.APIKey(s.T(), db, database.APIKey{LastUsed: time.Now().Add(time.Hour)})
31+
b, _ := dbgen.APIKey(s.T(), db, database.APIKey{LastUsed: time.Now().Add(time.Hour)})
32+
_, _ = dbgen.APIKey(s.T(), db, database.APIKey{LastUsed: time.Now().Add(-time.Hour)})
33+
check.Args(time.Now()).
34+
Asserts(a, rbac.ActionRead, b, rbac.ActionRead).
35+
Returns(slice.New(a, b))
36+
}))
37+
s.Run("InsertAPIKey", s.Subtest(func(db database.Store, check *MethodCase) {
38+
u := dbgen.User(s.T(), db, database.User{})
39+
check.Args(database.InsertAPIKeyParams{
40+
UserID: u.ID,
41+
LoginType: database.LoginTypePassword,
42+
Scope: database.APIKeyScopeAll,
43+
}).Asserts(rbac.ResourceAPIKey.WithOwner(u.ID.String()), rbac.ActionCreate)
44+
}))
45+
s.Run("UpdateAPIKeyByID", s.Subtest(func(db database.Store, check *MethodCase) {
46+
a, _ := dbgen.APIKey(s.T(), db, database.APIKey{})
47+
check.Args(database.UpdateAPIKeyByIDParams{
48+
ID: a.ID,
49+
}).Asserts(a, rbac.ActionUpdate).Returns()
50+
}))
6551
}

coderd/authzquery/methods_test.go

Lines changed: 36 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ func (s *MethodTestSuite) Subtest(testCaseF func(db database.Store, check *Metho
122122
if method.Name == methodName {
123123
methodF := reflect.ValueOf(az).Method(i)
124124
callMethod = func(ctx context.Context) ([]reflect.Value, error) {
125-
resp := methodF.Call(append([]reflect.Value{reflect.ValueOf(ctx)}, testCase.Inputs...))
125+
resp := methodF.Call(append([]reflect.Value{reflect.ValueOf(ctx)}, testCase.inputs...))
126126
return splitResp(t, resp)
127127
}
128128
break MethodLoop
@@ -133,7 +133,7 @@ func (s *MethodTestSuite) Subtest(testCaseF func(db database.Store, check *Metho
133133

134134
// Run tests that are only run if the method makes rbac assertions.
135135
// These tests assert the error conditions of the method.
136-
if len(testCase.Assertions) > 0 {
136+
if len(testCase.assertions) > 0 {
137137
// Only run these tests if we know the underlying call makes
138138
// rbac assertions.
139139
s.NotAuthorizedErrorTest(ctx, fakeAuthorizer, callMethod)
@@ -150,11 +150,11 @@ func (s *MethodTestSuite) Subtest(testCaseF func(db database.Store, check *Metho
150150

151151
// Some tests may not care about the outputs, so we only assert if
152152
// they are provided.
153-
if testCase.ExpectedOutputs != nil {
153+
if testCase.expectedOutputs != nil {
154154
// Assert the required outputs
155-
s.Equal(len(testCase.ExpectedOutputs), len(outputs), "method %q returned unexpected number of outputs", methodName)
155+
s.Equal(len(testCase.expectedOutputs), len(outputs), "method %q returned unexpected number of outputs", methodName)
156156
for i := range outputs {
157-
a, b := testCase.ExpectedOutputs[i].Interface(), outputs[i].Interface()
157+
a, b := testCase.expectedOutputs[i].Interface(), outputs[i].Interface()
158158
if reflect.TypeOf(a).Kind() == reflect.Slice || reflect.TypeOf(a).Kind() == reflect.Array {
159159
// Order does not matter
160160
s.ElementsMatch(a, b, "method %q returned unexpected output %d", methodName, i)
@@ -165,7 +165,7 @@ func (s *MethodTestSuite) Subtest(testCaseF func(db database.Store, check *Metho
165165
}
166166

167167
var pairs []coderdtest.ActionObjectPair
168-
for _, assrt := range testCase.Assertions {
168+
for _, assrt := range testCase.assertions {
169169
for _, action := range assrt.Actions {
170170
pairs = append(pairs, coderdtest.ActionObjectPair{
171171
Action: action,
@@ -246,17 +246,17 @@ MethodLoop:
246246
if method.Name == methodName {
247247
methodF := reflect.ValueOf(az).Method(i)
248248
callMethod = func(ctx context.Context) ([]reflect.Value, error) {
249-
resp := methodF.Call(append([]reflect.Value{reflect.ValueOf(ctx)}, testCase.Inputs...))
249+
resp := methodF.Call(append([]reflect.Value{reflect.ValueOf(ctx)}, testCase.inputs...))
250250
return splitResp(t, resp)
251251
}
252252
break MethodLoop
253253

254-
//if len(testCase.Assertions) > 0 {
254+
//if len(testCase.assertions) > 0 {
255255
// fakeAuthorizer.AlwaysReturn = xerrors.New("Always fail authz")
256256
// // If we have assertions, that means the method should FAIL
257257
// // if RBAC will disallow the request. The returned error should
258258
// // be expected to be a NotAuthorizedError.
259-
// erroredResp := reflect.ValueOf(az).Method(i).Call(append([]reflect.Value{reflect.ValueOf(ctx)}, testCase.Inputs...))
259+
// erroredResp := reflect.ValueOf(az).Method(i).Call(append([]reflect.Value{reflect.ValueOf(ctx)}, testCase.inputs...))
260260
// _, err := splitResp(t, erroredResp)
261261
// // This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
262262
// // any case where the error is nil and the response is an empty slice.
@@ -270,7 +270,7 @@ MethodLoop:
270270
// rec.Reset()
271271
//}
272272

273-
//resp := reflect.ValueOf(az).Method(i).Call(append([]reflect.Value{reflect.ValueOf(ctx)}, testCase.Inputs...))
273+
//resp := reflect.ValueOf(az).Method(i).Call(append([]reflect.Value{reflect.ValueOf(ctx)}, testCase.inputs...))
274274
//
275275
//outputs, err := splitResp(t, resp)
276276
//require.NoError(t, err, "method %q returned an error", testName)
@@ -299,7 +299,7 @@ MethodLoop:
299299

300300
// Run tests that are only run if the method makes rbac assertions.
301301
// These tests assert the error conditions of the method.
302-
if len(testCase.Assertions) > 0 {
302+
if len(testCase.assertions) > 0 {
303303
// Only run these tests if we know the underlying call makes
304304
// rbac assertions.
305305
s.NotAuthorizedErrorTest(ctx, fakeAuthorizer, callMethod)
@@ -315,11 +315,11 @@ MethodLoop:
315315

316316
// Some tests may not care about the outputs, so we only assert if
317317
// they are provided.
318-
if testCase.ExpectedOutputs != nil {
318+
if testCase.expectedOutputs != nil {
319319
// Assert the required outputs
320-
s.Equal(len(testCase.ExpectedOutputs), len(outputs), "method %q returned unexpected number of outputs", methodName)
320+
s.Equal(len(testCase.expectedOutputs), len(outputs), "method %q returned unexpected number of outputs", methodName)
321321
for i := range outputs {
322-
a, b := testCase.ExpectedOutputs[i].Interface(), outputs[i].Interface()
322+
a, b := testCase.expectedOutputs[i].Interface(), outputs[i].Interface()
323323
if reflect.TypeOf(a).Kind() == reflect.Slice || reflect.TypeOf(a).Kind() == reflect.Array {
324324
// Order does not matter
325325
s.ElementsMatch(a, b, "method %q returned unexpected output %d", methodName, i)
@@ -330,7 +330,7 @@ MethodLoop:
330330
}
331331

332332
var pairs []coderdtest.ActionObjectPair
333-
for _, assrt := range testCase.Assertions {
333+
for _, assrt := range testCase.assertions {
334334
for _, action := range assrt.Actions {
335335
pairs = append(pairs, coderdtest.ActionObjectPair{
336336
Action: action,
@@ -377,25 +377,25 @@ func splitResp(t *testing.T, values []reflect.Value) ([]reflect.Value, error) {
377377
// A MethodCase contains the inputs to be provided to a single method call,
378378
// and the assertions to be made on the RBAC checks.
379379
type MethodCase struct {
380-
Inputs []reflect.Value
381-
Assertions []AssertRBAC
382-
// Output is optional. Can assert non-error return values.
383-
ExpectedOutputs []reflect.Value
380+
inputs []reflect.Value
381+
assertions []AssertRBAC
382+
// expectedOutputs is optional. Can assert non-error return values.
383+
expectedOutputs []reflect.Value
384384
}
385385

386386
func (m *MethodCase) Asserts(pairs ...any) *MethodCase {
387-
m.Assertions = asserts(pairs...)
387+
m.assertions = asserts(pairs...)
388388
return m
389389
}
390390

391391
func (m *MethodCase) Args(args ...any) *MethodCase {
392-
m.Inputs = values(args...)
392+
m.inputs = values(args...)
393393
return m
394394
}
395395

396396
// Returns is optional. If it is never called, it will not be asserted.
397397
func (m *MethodCase) Returns(rets ...any) *MethodCase {
398-
m.ExpectedOutputs = values(rets...)
398+
m.expectedOutputs = values(rets...)
399399
return m
400400
}
401401

@@ -412,16 +412,16 @@ type AssertRBAC struct {
412412
// is equivalent to
413413
//
414414
// MethodCase{
415-
// Inputs: values(workspace, template, ...),
416-
// Assertions: asserts(workspace, rbac.ActionRead, template, rbac.ActionRead, ...),
415+
// inputs: values(workspace, template, ...),
416+
// assertions: asserts(workspace, rbac.ActionRead, template, rbac.ActionRead, ...),
417417
// }
418418
//
419419
// Deprecated: use MethodCase instead.
420420
func methodCase(ins []reflect.Value, assertions []AssertRBAC, outs []reflect.Value) MethodCase {
421421
return MethodCase{
422-
Inputs: ins,
423-
Assertions: assertions,
424-
ExpectedOutputs: outs,
422+
inputs: ins,
423+
assertions: assertions,
424+
expectedOutputs: outs,
425425
}
426426
}
427427

@@ -500,20 +500,16 @@ func asserts(inputs ...any) []AssertRBAC {
500500
}
501501

502502
func (s *MethodTestSuite) TestExtraMethods() {
503-
s.Run("GetProvisionerDaemons", func() {
504-
s.RunMethodTest(func(t *testing.T, db database.Store) MethodCase {
505-
d, err := db.InsertProvisionerDaemon(context.Background(), database.InsertProvisionerDaemonParams{
506-
ID: uuid.New(),
507-
})
508-
require.NoError(t, err, "insert provisioner daemon")
509-
return methodCase(values(), asserts(d, rbac.ActionRead), nil)
510-
})
511-
})
512-
s.Run("GetDeploymentDAUs", func() {
513-
s.RunMethodTest(func(t *testing.T, db database.Store) MethodCase {
514-
return methodCase(values(), asserts(rbac.ResourceUser.All(), rbac.ActionRead), nil)
503+
s.Run("GetProvisionerDaemons", s.Subtest(func(db database.Store, check *MethodCase) {
504+
d, err := db.InsertProvisionerDaemon(context.Background(), database.InsertProvisionerDaemonParams{
505+
ID: uuid.New(),
515506
})
516-
})
507+
s.NoError(err, "insert provisioner daemon")
508+
check.Args().Asserts(d, rbac.ActionRead)
509+
}))
510+
s.Run("GetDeploymentDAUs", s.Subtest(func(db database.Store, check *MethodCase) {
511+
check.Args().Asserts(rbac.ResourceUser.All(), rbac.ActionRead)
512+
}))
517513
}
518514

519515
type emptyPreparedAuthorized struct{}

0 commit comments

Comments
 (0)