@@ -73,7 +73,9 @@ func TestAsNoActor(t *testing.T) {
73
73
func TestPing (t * testing.T ) {
74
74
t .Parallel ()
75
75
76
- db , _ := dbtestutil .NewDB (t )
76
+ db := dbmock .NewMockStore (gomock .NewController (t ))
77
+ db .EXPECT ().Wrappers ().Times (1 ).Return ([]string {})
78
+ db .EXPECT ().Ping (gomock .Any ()).Times (1 ).Return (time .Second , nil )
77
79
q := dbauthz .New (db , & coderdtest.RecordingAuthorizer {}, slog .Make (), coderdtest .AccessControlStorePointer ())
78
80
_ , err := q .Ping (context .Background ())
79
81
require .NoError (t , err , "must not error" )
@@ -83,34 +85,39 @@ func TestPing(t *testing.T) {
83
85
func TestInTX (t * testing.T ) {
84
86
t .Parallel ()
85
87
86
- db , _ := dbtestutil .NewDB (t )
88
+ var (
89
+ ctrl = gomock .NewController (t )
90
+ db = dbmock .NewMockStore (ctrl )
91
+ mTx = dbmock .NewMockStore (ctrl ) // to record the 'in tx' calls
92
+ faker = gofakeit .New (0 )
93
+ w = testutil .Fake (t , faker , database.Workspace {})
94
+ actor = rbac.Subject {
95
+ ID : uuid .NewString (),
96
+ Roles : rbac.RoleIdentifiers {rbac .RoleOwner ()},
97
+ Groups : []string {},
98
+ Scope : rbac .ScopeAll ,
99
+ }
100
+ ctx = dbauthz .As (context .Background (), actor )
101
+ )
102
+
103
+ db .EXPECT ().Wrappers ().Times (1 ).Return ([]string {}) // called by dbauthz.New
87
104
q := dbauthz .New (db , & coderdtest.RecordingAuthorizer {
88
105
Wrapped : (& coderdtest.FakeAuthorizer {}).AlwaysReturn (xerrors .New ("custom error" )),
89
106
}, slog .Make (), coderdtest .AccessControlStorePointer ())
90
- actor := rbac.Subject {
91
- ID : uuid .NewString (),
92
- Roles : rbac.RoleIdentifiers {rbac .RoleOwner ()},
93
- Groups : []string {},
94
- Scope : rbac .ScopeAll ,
95
- }
96
- u := dbgen .User (t , db , database.User {})
97
- o := dbgen .Organization (t , db , database.Organization {})
98
- tpl := dbgen .Template (t , db , database.Template {
99
- CreatedBy : u .ID ,
100
- OrganizationID : o .ID ,
101
- })
102
- w := dbgen .Workspace (t , db , database.WorkspaceTable {
103
- OwnerID : u .ID ,
104
- TemplateID : tpl .ID ,
105
- OrganizationID : o .ID ,
106
- })
107
- ctx := dbauthz .As (context .Background (), actor )
107
+
108
+ db .EXPECT ().InTx (gomock .Any (), gomock .Any ()).Times (1 ).DoAndReturn (
109
+ func (f func (database.Store ) error , _ * database.TxOptions ) error {
110
+ return f (mTx )
111
+ },
112
+ )
113
+ mTx .EXPECT ().Wrappers ().Times (1 ).Return ([]string {})
114
+ mTx .EXPECT ().GetWorkspaceByID (gomock .Any (), gomock .Any ()).Times (1 ).Return (w , nil )
108
115
err := q .InTx (func (tx database.Store ) error {
109
116
// The inner tx should use the parent's authz
110
117
_ , err := tx .GetWorkspaceByID (ctx , w .ID )
111
118
return err
112
119
}, nil )
113
- require .Error (t , err , "must error" )
120
+ require .ErrorContains (t , err , "custom error" , " must be our custom error" )
114
121
require .ErrorAs (t , err , & dbauthz.NotAuthorizedError {}, "must be an authorized error" )
115
122
require .True (t , dbauthz .IsNotAuthorizedError (err ), "must be an authorized error" )
116
123
}
@@ -120,32 +127,26 @@ func TestNew(t *testing.T) {
120
127
t .Parallel ()
121
128
122
129
var (
123
- db , _ = dbtestutil .NewDB (t )
130
+ ctrl = gomock .NewController (t )
131
+ db = dbmock .NewMockStore (ctrl )
132
+ faker = gofakeit .New (0 )
124
133
rec = & coderdtest.RecordingAuthorizer {
125
134
Wrapped : & coderdtest.FakeAuthorizer {},
126
135
}
127
136
subj = rbac.Subject {}
128
137
ctx = dbauthz .As (context .Background (), rbac.Subject {})
129
138
)
130
- u := dbgen .User (t , db , database.User {})
131
- org := dbgen .Organization (t , db , database.Organization {})
132
- tpl := dbgen .Template (t , db , database.Template {
133
- OrganizationID : org .ID ,
134
- CreatedBy : u .ID ,
135
- })
136
- exp := dbgen .Workspace (t , db , database.WorkspaceTable {
137
- OwnerID : u .ID ,
138
- OrganizationID : org .ID ,
139
- TemplateID : tpl .ID ,
140
- })
139
+ db .EXPECT ().Wrappers ().Times (1 ).Return ([]string {}).Times (2 ) // two calls to New()
140
+ exp := testutil .Fake (t , faker , database.Workspace {})
141
+ db .EXPECT ().GetWorkspaceByID (gomock .Any (), exp .ID ).Times (1 ).Return (exp , nil )
141
142
// Double wrap should not cause an actual double wrap. So only 1 rbac call
142
143
// should be made.
143
144
az := dbauthz .New (db , rec , slog .Make (), coderdtest .AccessControlStorePointer ())
144
145
az = dbauthz .New (az , rec , slog .Make (), coderdtest .AccessControlStorePointer ())
145
146
146
147
w , err := az .GetWorkspaceByID (ctx , exp .ID )
147
148
require .NoError (t , err , "must not error" )
148
- require .Equal (t , exp , w . WorkspaceTable () , "must be equal" )
149
+ require .Equal (t , exp , w , "must be equal" )
149
150
150
151
rec .AssertActor (t , subj , rec .Pair (policy .ActionRead , exp ))
151
152
require .NoError (t , rec .AllAsserted (), "should only be 1 rbac call" )
@@ -154,6 +155,8 @@ func TestNew(t *testing.T) {
154
155
// TestDBAuthzRecursive is a simple test to search for infinite recursion
155
156
// bugs. It isn't perfect, and only catches a subset of the possible bugs
156
157
// as only the first db call will be made. But it is better than nothing.
158
+ // This can be removed when all tests in this package are migrated to
159
+ // dbmock as it will immediately detect recursive calls.
157
160
func TestDBAuthzRecursive (t * testing.T ) {
158
161
t .Parallel ()
159
162
db , _ := dbtestutil .NewDB (t )
0 commit comments