Skip to content

Commit b359dbb

Browse files
authored
chore: Allow RecordingAuthorizer to record multiple rbac authz calls (#6024)
* chore: Allow RecordingAuthorizer to record multiple rbac authz calls Prior iteration only recorded the last call. This is required for more comprehensive testing
1 parent 571f5d0 commit b359dbb

File tree

9 files changed

+719
-52
lines changed

9 files changed

+719
-52
lines changed

coderd/coderdtest/authorize.go

+185-50
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,17 @@ import (
77
"net/http"
88
"strconv"
99
"strings"
10+
"sync"
1011
"testing"
1112
"time"
1213

13-
"github.com/coder/coder/coderd/database/dbfake"
14-
1514
"github.com/go-chi/chi/v5"
1615
"github.com/stretchr/testify/assert"
1716
"github.com/stretchr/testify/require"
1817
"golang.org/x/xerrors"
1918

2019
"github.com/coder/coder/coderd"
20+
"github.com/coder/coder/coderd/database/dbfake"
2121
"github.com/coder/coder/coderd/rbac"
2222
"github.com/coder/coder/coderd/rbac/regosql"
2323
"github.com/coder/coder/codersdk"
@@ -443,7 +443,9 @@ func NewAuthTester(ctx context.Context, t *testing.T, client *codersdk.Client, a
443443

444444
func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck, skipRoutes map[string]string) {
445445
// Always fail auth from this point forward
446-
a.authorizer.AlwaysReturn = rbac.ForbiddenWithInternal(xerrors.New("fake implementation"), nil, nil)
446+
a.authorizer.Wrapped = &FakeAuthorizer{
447+
AlwaysReturn: rbac.ForbiddenWithInternal(xerrors.New("fake implementation"), nil, nil),
448+
}
447449

448450
routeMissing := make(map[string]bool)
449451
for k, v := range assertRoute {
@@ -483,7 +485,7 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck
483485
return nil
484486
}
485487
a.t.Run(name, func(t *testing.T) {
486-
a.authorizer.reset()
488+
a.authorizer.Reset()
487489
routeKey := strings.TrimRight(name, "/")
488490

489491
routeAssertions, ok := assertRoute[routeKey]
@@ -514,18 +516,19 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck
514516
assert.Equal(t, http.StatusForbidden, resp.StatusCode, "expect unauthorized")
515517
}
516518
}
517-
if a.authorizer.Called != nil {
519+
if a.authorizer.lastCall() != nil {
520+
last := a.authorizer.lastCall()
518521
if routeAssertions.AssertAction != "" {
519-
assert.Equal(t, routeAssertions.AssertAction, a.authorizer.Called.Action, "resource action")
522+
assert.Equal(t, routeAssertions.AssertAction, last.Action, "resource action")
520523
}
521524
if routeAssertions.AssertObject.Type != "" {
522-
assert.Equal(t, routeAssertions.AssertObject.Type, a.authorizer.Called.Object.Type, "resource type")
525+
assert.Equal(t, routeAssertions.AssertObject.Type, last.Object.Type, "resource type")
523526
}
524527
if routeAssertions.AssertObject.Owner != "" {
525-
assert.Equal(t, routeAssertions.AssertObject.Owner, a.authorizer.Called.Object.Owner, "resource owner")
528+
assert.Equal(t, routeAssertions.AssertObject.Owner, last.Object.Owner, "resource owner")
526529
}
527530
if routeAssertions.AssertObject.OrgID != "" {
528-
assert.Equal(t, routeAssertions.AssertObject.OrgID, a.authorizer.Called.Object.OrgID, "resource org")
531+
assert.Equal(t, routeAssertions.AssertObject.OrgID, last.Object.OrgID, "resource org")
529532
}
530533
}
531534
} else {
@@ -539,52 +542,195 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck
539542
}
540543

541544
type authCall struct {
542-
Subject rbac.Subject
543-
Action rbac.Action
544-
Object rbac.Object
545+
Actor rbac.Subject
546+
Action rbac.Action
547+
Object rbac.Object
548+
549+
asserted bool
545550
}
546551

552+
var _ rbac.Authorizer = (*RecordingAuthorizer)(nil)
553+
554+
// RecordingAuthorizer wraps any rbac.Authorizer and records all Authorize()
555+
// calls made. This is useful for testing as these calls can later be asserted.
547556
type RecordingAuthorizer struct {
548-
Called *authCall
549-
AlwaysReturn error
557+
sync.RWMutex
558+
Called []authCall
559+
Wrapped rbac.Authorizer
550560
}
551561

552-
var _ rbac.Authorizer = (*RecordingAuthorizer)(nil)
562+
type ActionObjectPair struct {
563+
Action rbac.Action
564+
Object rbac.Object
565+
}
553566

554-
// AuthorizeSQL does not record the call. This matches the postgres behavior
555-
// of not calling Authorize()
556-
func (r *RecordingAuthorizer) AuthorizeSQL(_ context.Context, _ rbac.Subject, _ rbac.Action, _ rbac.Object) error {
557-
return r.AlwaysReturn
567+
// Pair is on the RecordingAuthorizer to be easy to find and keep the pkg
568+
// interface smaller.
569+
func (*RecordingAuthorizer) Pair(action rbac.Action, object rbac.Objecter) ActionObjectPair {
570+
return ActionObjectPair{
571+
Action: action,
572+
Object: object.RBACObject(),
573+
}
558574
}
559575

560-
func (r *RecordingAuthorizer) Authorize(_ context.Context, subject rbac.Subject, action rbac.Action, object rbac.Object) error {
561-
r.Called = &authCall{
562-
Subject: subject,
563-
Action: action,
564-
Object: object,
576+
// AllAsserted returns an error if all calls to Authorize() have not been
577+
// asserted and checked. This is useful for testing to ensure that all
578+
// Authorize() calls are checked in the unit test.
579+
func (r *RecordingAuthorizer) AllAsserted() error {
580+
r.RLock()
581+
defer r.RUnlock()
582+
missed := []authCall{}
583+
for _, c := range r.Called {
584+
if !c.asserted {
585+
missed = append(missed, c)
586+
}
565587
}
566-
return r.AlwaysReturn
588+
589+
if len(missed) > 0 {
590+
return xerrors.Errorf("missed calls: %+v", missed)
591+
}
592+
return nil
567593
}
568594

569-
func (r *RecordingAuthorizer) Prepare(_ context.Context, subject rbac.Subject, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) {
570-
return &fakePreparedAuthorizer{
571-
Original: r,
572-
Subject: subject,
573-
Action: action,
574-
HardCodedSQLString: "true",
595+
// AssertActor asserts in order. If the order of authz calls does not match,
596+
// this will fail.
597+
func (r *RecordingAuthorizer) AssertActor(t *testing.T, actor rbac.Subject, did ...ActionObjectPair) {
598+
r.RLock()
599+
defer r.RUnlock()
600+
ptr := 0
601+
for i, call := range r.Called {
602+
if ptr == len(did) {
603+
// Finished all assertions
604+
return
605+
}
606+
if call.Actor.ID == actor.ID {
607+
action, object := did[ptr].Action, did[ptr].Object
608+
assert.Equalf(t, action, call.Action, "assert action %d", ptr)
609+
assert.Equalf(t, object, call.Object, "assert object %d", ptr)
610+
r.Called[i].asserted = true
611+
ptr++
612+
}
613+
}
614+
615+
assert.Equalf(t, len(did), ptr, "assert actor: didn't find all actions, %d missing actions", len(did)-ptr)
616+
}
617+
618+
// recordAuthorize is the internal method that records the Authorize() call.
619+
func (r *RecordingAuthorizer) recordAuthorize(subject rbac.Subject, action rbac.Action, object rbac.Object) {
620+
r.Lock()
621+
defer r.Unlock()
622+
r.Called = append(r.Called, authCall{
623+
Actor: subject,
624+
Action: action,
625+
Object: object,
626+
})
627+
}
628+
629+
func (r *RecordingAuthorizer) Authorize(ctx context.Context, subject rbac.Subject, action rbac.Action, object rbac.Object) error {
630+
r.recordAuthorize(subject, action, object)
631+
if r.Wrapped == nil {
632+
panic("Developer error: RecordingAuthorizer.Wrapped is nil")
633+
}
634+
return r.Wrapped.Authorize(ctx, subject, action, object)
635+
}
636+
637+
func (r *RecordingAuthorizer) Prepare(ctx context.Context, subject rbac.Subject, action rbac.Action, objectType string) (rbac.PreparedAuthorized, error) {
638+
r.RLock()
639+
defer r.RUnlock()
640+
if r.Wrapped == nil {
641+
panic("Developer error: RecordingAuthorizer.Wrapped is nil")
642+
}
643+
644+
prep, err := r.Wrapped.Prepare(ctx, subject, action, objectType)
645+
if err != nil {
646+
return nil, err
647+
}
648+
return &PreparedRecorder{
649+
rec: r,
650+
prepped: prep,
651+
subject: subject,
652+
action: action,
575653
}, nil
576654
}
577655

578-
func (r *RecordingAuthorizer) reset() {
656+
// Reset clears the recorded Authorize() calls.
657+
func (r *RecordingAuthorizer) Reset() {
658+
r.Lock()
659+
defer r.Unlock()
579660
r.Called = nil
580661
}
581662

663+
// lastCall is implemented to support legacy tests.
664+
// Deprecated
665+
func (r *RecordingAuthorizer) lastCall() *authCall {
666+
r.RLock()
667+
defer r.RUnlock()
668+
if len(r.Called) == 0 {
669+
return nil
670+
}
671+
return &r.Called[len(r.Called)-1]
672+
}
673+
674+
// PreparedRecorder is the prepared version of the RecordingAuthorizer.
675+
// It records the Authorize() calls to the original recorder. If the caller
676+
// uses CompileToSQL, all recording stops. This is to support parity between
677+
// memory and SQL backed dbs.
678+
type PreparedRecorder struct {
679+
rec *RecordingAuthorizer
680+
prepped rbac.PreparedAuthorized
681+
subject rbac.Subject
682+
action rbac.Action
683+
684+
rw sync.Mutex
685+
usingSQL bool
686+
}
687+
688+
func (s *PreparedRecorder) Authorize(ctx context.Context, object rbac.Object) error {
689+
s.rw.Lock()
690+
defer s.rw.Unlock()
691+
692+
if !s.usingSQL {
693+
s.rec.recordAuthorize(s.subject, s.action, object)
694+
}
695+
return s.prepped.Authorize(ctx, object)
696+
}
697+
func (s *PreparedRecorder) CompileToSQL(ctx context.Context, cfg regosql.ConvertConfig) (string, error) {
698+
s.rw.Lock()
699+
defer s.rw.Unlock()
700+
701+
s.usingSQL = true
702+
return s.prepped.CompileToSQL(ctx, cfg)
703+
}
704+
705+
// FakeAuthorizer is an Authorizer that always returns the same error.
706+
type FakeAuthorizer struct {
707+
// AlwaysReturn is the error that will be returned by Authorize.
708+
AlwaysReturn error
709+
}
710+
711+
var _ rbac.Authorizer = (*FakeAuthorizer)(nil)
712+
713+
func (d *FakeAuthorizer) Authorize(_ context.Context, _ rbac.Subject, _ rbac.Action, _ rbac.Object) error {
714+
return d.AlwaysReturn
715+
}
716+
717+
func (d *FakeAuthorizer) Prepare(_ context.Context, subject rbac.Subject, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) {
718+
return &fakePreparedAuthorizer{
719+
Original: d,
720+
Subject: subject,
721+
Action: action,
722+
}, nil
723+
}
724+
725+
var _ rbac.PreparedAuthorized = (*fakePreparedAuthorizer)(nil)
726+
727+
// fakePreparedAuthorizer is the prepared version of a FakeAuthorizer. It will
728+
// return the same error as the original FakeAuthorizer.
582729
type fakePreparedAuthorizer struct {
583-
Original *RecordingAuthorizer
584-
Subject rbac.Subject
585-
Action rbac.Action
586-
HardCodedSQLString string
587-
HardCodedRegoString string
730+
sync.RWMutex
731+
Original *FakeAuthorizer
732+
Subject rbac.Subject
733+
Action rbac.Action
588734
}
589735

590736
func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Object) error {
@@ -593,17 +739,6 @@ func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Obje
593739

594740
// CompileToSQL returns a compiled version of the authorizer that will work for
595741
// in memory databases. This fake version will not work against a SQL database.
596-
func (fakePreparedAuthorizer) CompileToSQL(_ context.Context, _ regosql.ConvertConfig) (string, error) {
597-
return "", xerrors.New("not implemented")
598-
}
599-
600-
func (f *fakePreparedAuthorizer) Eval(object rbac.Object) bool {
601-
return f.Original.AuthorizeSQL(context.Background(), f.Subject, f.Action, object) == nil
602-
}
603-
604-
func (f fakePreparedAuthorizer) RegoString() string {
605-
if f.HardCodedRegoString != "" {
606-
return f.HardCodedRegoString
607-
}
608-
panic("not implemented")
742+
func (*fakePreparedAuthorizer) CompileToSQL(_ context.Context, _ regosql.ConvertConfig) (string, error) {
743+
return "not a valid sql string", nil
609744
}

0 commit comments

Comments
 (0)