Skip to content

Commit 290cabc

Browse files
committed
fix dbauthz edge cases
1 parent 820a033 commit 290cabc

File tree

2 files changed

+42
-25
lines changed

2 files changed

+42
-25
lines changed

coderd/database/dbauthz/dbauthz_test.go

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -594,19 +594,6 @@ func (s *MethodTestSuite) TestOrganization() {
594594
check.Args([]uuid.UUID{ma.UserID, mb.UserID}).
595595
Asserts(rbac.ResourceUserObject(ma.UserID), policy.ActionRead, rbac.ResourceUserObject(mb.UserID), policy.ActionRead)
596596
}))
597-
s.Run("GetOrganizationMemberByUserID", s.Subtest(func(db database.Store, check *expects) {
598-
mem := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{})
599-
check.Args(database.GetOrganizationMemberByUserIDParams{
600-
OrganizationID: mem.OrganizationID,
601-
UserID: mem.UserID,
602-
}).Asserts(mem, policy.ActionRead).Returns(mem)
603-
}))
604-
s.Run("GetOrganizationMembershipsByUserID", s.Subtest(func(db database.Store, check *expects) {
605-
u := dbgen.User(s.T(), db, database.User{})
606-
a := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{UserID: u.ID})
607-
b := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{UserID: u.ID})
608-
check.Args(u.ID).Asserts(a, policy.ActionRead, b, policy.ActionRead).Returns(slice.New(a, b))
609-
}))
610597
s.Run("GetOrganizations", s.Subtest(func(db database.Store, check *expects) {
611598
def, _ := db.GetDefaultOrganization(context.Background())
612599
a := dbgen.Organization(s.T(), db, database.Organization{})
@@ -671,11 +658,14 @@ func (s *MethodTestSuite) TestOrganization() {
671658
GrantedRoles: []string{},
672659
UserID: u.ID,
673660
OrgID: o.ID,
674-
}).Asserts(
675-
mem, policy.ActionRead,
676-
rbac.ResourceAssignRole.InOrg(o.ID), policy.ActionAssign, // org-mem
677-
rbac.ResourceAssignRole.InOrg(o.ID), policy.ActionDelete, // org-admin
678-
).Returns(out)
661+
}).
662+
WithNotAuthorized(sql.ErrNoRows.Error()).
663+
WithCancelled(sql.ErrNoRows.Error()).
664+
Asserts(
665+
mem, policy.ActionRead,
666+
rbac.ResourceAssignRole.InOrg(o.ID), policy.ActionAssign, // org-mem
667+
rbac.ResourceAssignRole.InOrg(o.ID), policy.ActionDelete, // org-admin
668+
).Returns(out)
679669
}))
680670
}
681671

coderd/database/dbauthz/setup_test.go

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ func (s *MethodTestSuite) Subtest(testCaseF func(db database.Store, check *expec
157157
if len(testCase.assertions) > 0 {
158158
// Only run these tests if we know the underlying call makes
159159
// rbac assertions.
160-
s.NotAuthorizedErrorTest(ctx, fakeAuthorizer, callMethod)
160+
s.NotAuthorizedErrorTest(ctx, fakeAuthorizer, testCase, callMethod)
161161
}
162162

163163
if len(testCase.assertions) > 0 ||
@@ -230,7 +230,7 @@ func (s *MethodTestSuite) NoActorErrorTest(callMethod func(ctx context.Context)
230230

231231
// NotAuthorizedErrorTest runs the given method with an authorizer that will fail authz.
232232
// Asserts that the error returned is a NotAuthorizedError.
233-
func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderdtest.FakeAuthorizer, callMethod func(ctx context.Context) ([]reflect.Value, error)) {
233+
func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderdtest.FakeAuthorizer, testCase expects, callMethod func(ctx context.Context) ([]reflect.Value, error)) {
234234
s.Run("NotAuthorized", func() {
235235
az.AlwaysReturn = rbac.ForbiddenWithInternal(xerrors.New("Always fail authz"), rbac.Subject{}, "", rbac.Object{}, nil)
236236

@@ -242,9 +242,14 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd
242242
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
243243
// any case where the error is nil and the response is an empty slice.
244244
if err != nil || !hasEmptySliceResponse(resp) {
245-
s.ErrorContainsf(err, "unauthorized", "error string should have a good message")
246-
s.Errorf(err, "method should an error with disallow authz")
247-
s.ErrorAs(err, &dbauthz.NotAuthorizedError{}, "error should be NotAuthorizedError")
245+
// Expect the default error
246+
if testCase.notAuthorizedExpect == "" {
247+
s.ErrorContainsf(err, "unauthorized", "error string should have a good message")
248+
s.Errorf(err, "method should an error with disallow authz")
249+
s.ErrorAs(err, &dbauthz.NotAuthorizedError{}, "error should be NotAuthorizedError")
250+
} else {
251+
s.ErrorContains(err, testCase.notAuthorizedExpect)
252+
}
248253
}
249254
})
250255

@@ -263,8 +268,13 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd
263268
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
264269
// any case where the error is nil and the response is an empty slice.
265270
if err != nil || !hasEmptySliceResponse(resp) {
266-
s.Errorf(err, "method should an error with cancellation")
267-
s.ErrorIsf(err, context.Canceled, "error should match context.Canceled")
271+
if testCase.cancelledCtxExpect == "" {
272+
s.Errorf(err, "method should an error with cancellation")
273+
s.ErrorIsf(err, context.Canceled, "error should match context.Canceled")
274+
} else {
275+
s.ErrorContains(err, testCase.cancelledCtxExpect)
276+
}
277+
268278
}
269279
})
270280
}
@@ -308,6 +318,23 @@ type expects struct {
308318
// outputs is optional. Can assert non-error return values.
309319
outputs []reflect.Value
310320
err error
321+
322+
// Optional override of the default error checks.
323+
// By default, we search for the expected error strings.
324+
// If these strings are present, these strings will be searched
325+
// instead.
326+
notAuthorizedExpect string
327+
cancelledCtxExpect string
328+
}
329+
330+
func (m *expects) WithNotAuthorized(contains string) *expects {
331+
m.notAuthorizedExpect = contains
332+
return m
333+
}
334+
335+
func (m *expects) WithCancelled(contains string) *expects {
336+
m.cancelledCtxExpect = contains
337+
return m
311338
}
312339

313340
// Asserts is required. Asserts the RBAC authorize calls that should be made.

0 commit comments

Comments
 (0)