Skip to content

Commit 8c731a0

Browse files
authored
chore(coderd/database/dbauthz): refactor TestPing, TestNew, TestInTX to use dbmock (coder#19604)
Part of coder/internal#869
1 parent 43fe44d commit 8c731a0

File tree

1 file changed

+37
-34
lines changed

1 file changed

+37
-34
lines changed

coderd/database/dbauthz/dbauthz_test.go

Lines changed: 37 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@ func TestAsNoActor(t *testing.T) {
7373
func TestPing(t *testing.T) {
7474
t.Parallel()
7575

76-
db, _ := dbtestutil.NewDB(t)
76+
db := dbmock.NewMockStore(gomock.NewController(t))
77+
db.EXPECT().Wrappers().Times(1).Return([]string{})
78+
db.EXPECT().Ping(gomock.Any()).Times(1).Return(time.Second, nil)
7779
q := dbauthz.New(db, &coderdtest.RecordingAuthorizer{}, slog.Make(), coderdtest.AccessControlStorePointer())
7880
_, err := q.Ping(context.Background())
7981
require.NoError(t, err, "must not error")
@@ -83,34 +85,39 @@ func TestPing(t *testing.T) {
8385
func TestInTX(t *testing.T) {
8486
t.Parallel()
8587

86-
db, _ := dbtestutil.NewDB(t)
88+
var (
89+
ctrl = gomock.NewController(t)
90+
db = dbmock.NewMockStore(ctrl)
91+
mTx = dbmock.NewMockStore(ctrl) // to record the 'in tx' calls
92+
faker = gofakeit.New(0)
93+
w = testutil.Fake(t, faker, database.Workspace{})
94+
actor = rbac.Subject{
95+
ID: uuid.NewString(),
96+
Roles: rbac.RoleIdentifiers{rbac.RoleOwner()},
97+
Groups: []string{},
98+
Scope: rbac.ScopeAll,
99+
}
100+
ctx = dbauthz.As(context.Background(), actor)
101+
)
102+
103+
db.EXPECT().Wrappers().Times(1).Return([]string{}) // called by dbauthz.New
87104
q := dbauthz.New(db, &coderdtest.RecordingAuthorizer{
88105
Wrapped: (&coderdtest.FakeAuthorizer{}).AlwaysReturn(xerrors.New("custom error")),
89106
}, slog.Make(), coderdtest.AccessControlStorePointer())
90-
actor := rbac.Subject{
91-
ID: uuid.NewString(),
92-
Roles: rbac.RoleIdentifiers{rbac.RoleOwner()},
93-
Groups: []string{},
94-
Scope: rbac.ScopeAll,
95-
}
96-
u := dbgen.User(t, db, database.User{})
97-
o := dbgen.Organization(t, db, database.Organization{})
98-
tpl := dbgen.Template(t, db, database.Template{
99-
CreatedBy: u.ID,
100-
OrganizationID: o.ID,
101-
})
102-
w := dbgen.Workspace(t, db, database.WorkspaceTable{
103-
OwnerID: u.ID,
104-
TemplateID: tpl.ID,
105-
OrganizationID: o.ID,
106-
})
107-
ctx := dbauthz.As(context.Background(), actor)
107+
108+
db.EXPECT().InTx(gomock.Any(), gomock.Any()).Times(1).DoAndReturn(
109+
func(f func(database.Store) error, _ *database.TxOptions) error {
110+
return f(mTx)
111+
},
112+
)
113+
mTx.EXPECT().Wrappers().Times(1).Return([]string{})
114+
mTx.EXPECT().GetWorkspaceByID(gomock.Any(), gomock.Any()).Times(1).Return(w, nil)
108115
err := q.InTx(func(tx database.Store) error {
109116
// The inner tx should use the parent's authz
110117
_, err := tx.GetWorkspaceByID(ctx, w.ID)
111118
return err
112119
}, nil)
113-
require.Error(t, err, "must error")
120+
require.ErrorContains(t, err, "custom error", "must be our custom error")
114121
require.ErrorAs(t, err, &dbauthz.NotAuthorizedError{}, "must be an authorized error")
115122
require.True(t, dbauthz.IsNotAuthorizedError(err), "must be an authorized error")
116123
}
@@ -120,32 +127,26 @@ func TestNew(t *testing.T) {
120127
t.Parallel()
121128

122129
var (
123-
db, _ = dbtestutil.NewDB(t)
130+
ctrl = gomock.NewController(t)
131+
db = dbmock.NewMockStore(ctrl)
132+
faker = gofakeit.New(0)
124133
rec = &coderdtest.RecordingAuthorizer{
125134
Wrapped: &coderdtest.FakeAuthorizer{},
126135
}
127136
subj = rbac.Subject{}
128137
ctx = dbauthz.As(context.Background(), rbac.Subject{})
129138
)
130-
u := dbgen.User(t, db, database.User{})
131-
org := dbgen.Organization(t, db, database.Organization{})
132-
tpl := dbgen.Template(t, db, database.Template{
133-
OrganizationID: org.ID,
134-
CreatedBy: u.ID,
135-
})
136-
exp := dbgen.Workspace(t, db, database.WorkspaceTable{
137-
OwnerID: u.ID,
138-
OrganizationID: org.ID,
139-
TemplateID: tpl.ID,
140-
})
139+
db.EXPECT().Wrappers().Times(1).Return([]string{}).Times(2) // two calls to New()
140+
exp := testutil.Fake(t, faker, database.Workspace{})
141+
db.EXPECT().GetWorkspaceByID(gomock.Any(), exp.ID).Times(1).Return(exp, nil)
141142
// Double wrap should not cause an actual double wrap. So only 1 rbac call
142143
// should be made.
143144
az := dbauthz.New(db, rec, slog.Make(), coderdtest.AccessControlStorePointer())
144145
az = dbauthz.New(az, rec, slog.Make(), coderdtest.AccessControlStorePointer())
145146

146147
w, err := az.GetWorkspaceByID(ctx, exp.ID)
147148
require.NoError(t, err, "must not error")
148-
require.Equal(t, exp, w.WorkspaceTable(), "must be equal")
149+
require.Equal(t, exp, w, "must be equal")
149150

150151
rec.AssertActor(t, subj, rec.Pair(policy.ActionRead, exp))
151152
require.NoError(t, rec.AllAsserted(), "should only be 1 rbac call")
@@ -154,6 +155,8 @@ func TestNew(t *testing.T) {
154155
// TestDBAuthzRecursive is a simple test to search for infinite recursion
155156
// bugs. It isn't perfect, and only catches a subset of the possible bugs
156157
// as only the first db call will be made. But it is better than nothing.
158+
// This can be removed when all tests in this package are migrated to
159+
// dbmock as it will immediately detect recursive calls.
157160
func TestDBAuthzRecursive(t *testing.T) {
158161
t.Parallel()
159162
db, _ := dbtestutil.NewDB(t)

0 commit comments

Comments
 (0)