@@ -7,17 +7,17 @@ import (
7
7
"net/http"
8
8
"strconv"
9
9
"strings"
10
+ "sync"
10
11
"testing"
11
12
"time"
12
13
13
- "github.com/coder/coder/coderd/database/dbfake"
14
-
15
14
"github.com/go-chi/chi/v5"
16
15
"github.com/stretchr/testify/assert"
17
16
"github.com/stretchr/testify/require"
18
17
"golang.org/x/xerrors"
19
18
20
19
"github.com/coder/coder/coderd"
20
+ "github.com/coder/coder/coderd/database/dbfake"
21
21
"github.com/coder/coder/coderd/rbac"
22
22
"github.com/coder/coder/coderd/rbac/regosql"
23
23
"github.com/coder/coder/codersdk"
@@ -443,7 +443,9 @@ func NewAuthTester(ctx context.Context, t *testing.T, client *codersdk.Client, a
443
443
444
444
func (a * AuthTester ) Test (ctx context.Context , assertRoute map [string ]RouteCheck , skipRoutes map [string ]string ) {
445
445
// Always fail auth from this point forward
446
- a .authorizer .AlwaysReturn = rbac .ForbiddenWithInternal (xerrors .New ("fake implementation" ), nil , nil )
446
+ a .authorizer .Wrapped = & FakeAuthorizer {
447
+ AlwaysReturn : rbac .ForbiddenWithInternal (xerrors .New ("fake implementation" ), nil , nil ),
448
+ }
447
449
448
450
routeMissing := make (map [string ]bool )
449
451
for k , v := range assertRoute {
@@ -483,7 +485,7 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck
483
485
return nil
484
486
}
485
487
a .t .Run (name , func (t * testing.T ) {
486
- a .authorizer .reset ()
488
+ a .authorizer .Reset ()
487
489
routeKey := strings .TrimRight (name , "/" )
488
490
489
491
routeAssertions , ok := assertRoute [routeKey ]
@@ -514,18 +516,19 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck
514
516
assert .Equal (t , http .StatusForbidden , resp .StatusCode , "expect unauthorized" )
515
517
}
516
518
}
517
- if a .authorizer .Called != nil {
519
+ if a .authorizer .lastCall () != nil {
520
+ last := a .authorizer .lastCall ()
518
521
if routeAssertions .AssertAction != "" {
519
- assert .Equal (t , routeAssertions .AssertAction , a . authorizer . Called .Action , "resource action" )
522
+ assert .Equal (t , routeAssertions .AssertAction , last .Action , "resource action" )
520
523
}
521
524
if routeAssertions .AssertObject .Type != "" {
522
- assert .Equal (t , routeAssertions .AssertObject .Type , a . authorizer . Called .Object .Type , "resource type" )
525
+ assert .Equal (t , routeAssertions .AssertObject .Type , last .Object .Type , "resource type" )
523
526
}
524
527
if routeAssertions .AssertObject .Owner != "" {
525
- assert .Equal (t , routeAssertions .AssertObject .Owner , a . authorizer . Called .Object .Owner , "resource owner" )
528
+ assert .Equal (t , routeAssertions .AssertObject .Owner , last .Object .Owner , "resource owner" )
526
529
}
527
530
if routeAssertions .AssertObject .OrgID != "" {
528
- assert .Equal (t , routeAssertions .AssertObject .OrgID , a . authorizer . Called .Object .OrgID , "resource org" )
531
+ assert .Equal (t , routeAssertions .AssertObject .OrgID , last .Object .OrgID , "resource org" )
529
532
}
530
533
}
531
534
} else {
@@ -539,52 +542,195 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck
539
542
}
540
543
541
544
type authCall struct {
542
- Subject rbac.Subject
543
- Action rbac.Action
544
- Object rbac.Object
545
+ Actor rbac.Subject
546
+ Action rbac.Action
547
+ Object rbac.Object
548
+
549
+ asserted bool
545
550
}
546
551
552
+ var _ rbac.Authorizer = (* RecordingAuthorizer )(nil )
553
+
554
+ // RecordingAuthorizer wraps any rbac.Authorizer and records all Authorize()
555
+ // calls made. This is useful for testing as these calls can later be asserted.
547
556
type RecordingAuthorizer struct {
548
- Called * authCall
549
- AlwaysReturn error
557
+ sync.RWMutex
558
+ Called []authCall
559
+ Wrapped rbac.Authorizer
550
560
}
551
561
552
- var _ rbac.Authorizer = (* RecordingAuthorizer )(nil )
562
+ type ActionObjectPair struct {
563
+ Action rbac.Action
564
+ Object rbac.Object
565
+ }
553
566
554
- // AuthorizeSQL does not record the call. This matches the postgres behavior
555
- // of not calling Authorize()
556
- func (r * RecordingAuthorizer ) AuthorizeSQL (_ context.Context , _ rbac.Subject , _ rbac.Action , _ rbac.Object ) error {
557
- return r .AlwaysReturn
567
+ // Pair is on the RecordingAuthorizer to be easy to find and keep the pkg
568
+ // interface smaller.
569
+ func (* RecordingAuthorizer ) Pair (action rbac.Action , object rbac.Objecter ) ActionObjectPair {
570
+ return ActionObjectPair {
571
+ Action : action ,
572
+ Object : object .RBACObject (),
573
+ }
558
574
}
559
575
560
- func (r * RecordingAuthorizer ) Authorize (_ context.Context , subject rbac.Subject , action rbac.Action , object rbac.Object ) error {
561
- r .Called = & authCall {
562
- Subject : subject ,
563
- Action : action ,
564
- Object : object ,
576
+ // AllAsserted returns an error if all calls to Authorize() have not been
577
+ // asserted and checked. This is useful for testing to ensure that all
578
+ // Authorize() calls are checked in the unit test.
579
+ func (r * RecordingAuthorizer ) AllAsserted () error {
580
+ r .RLock ()
581
+ defer r .RUnlock ()
582
+ missed := []authCall {}
583
+ for _ , c := range r .Called {
584
+ if ! c .asserted {
585
+ missed = append (missed , c )
586
+ }
565
587
}
566
- return r .AlwaysReturn
588
+
589
+ if len (missed ) > 0 {
590
+ return xerrors .Errorf ("missed calls: %+v" , missed )
591
+ }
592
+ return nil
567
593
}
568
594
569
- func (r * RecordingAuthorizer ) Prepare (_ context.Context , subject rbac.Subject , action rbac.Action , _ string ) (rbac.PreparedAuthorized , error ) {
570
- return & fakePreparedAuthorizer {
571
- Original : r ,
572
- Subject : subject ,
573
- Action : action ,
574
- HardCodedSQLString : "true" ,
595
+ // AssertActor asserts in order. If the order of authz calls does not match,
596
+ // this will fail.
597
+ func (r * RecordingAuthorizer ) AssertActor (t * testing.T , actor rbac.Subject , did ... ActionObjectPair ) {
598
+ r .RLock ()
599
+ defer r .RUnlock ()
600
+ ptr := 0
601
+ for i , call := range r .Called {
602
+ if ptr == len (did ) {
603
+ // Finished all assertions
604
+ return
605
+ }
606
+ if call .Actor .ID == actor .ID {
607
+ action , object := did [ptr ].Action , did [ptr ].Object
608
+ assert .Equalf (t , action , call .Action , "assert action %d" , ptr )
609
+ assert .Equalf (t , object , call .Object , "assert object %d" , ptr )
610
+ r .Called [i ].asserted = true
611
+ ptr ++
612
+ }
613
+ }
614
+
615
+ assert .Equalf (t , len (did ), ptr , "assert actor: didn't find all actions, %d missing actions" , len (did )- ptr )
616
+ }
617
+
618
+ // recordAuthorize is the internal method that records the Authorize() call.
619
+ func (r * RecordingAuthorizer ) recordAuthorize (subject rbac.Subject , action rbac.Action , object rbac.Object ) {
620
+ r .Lock ()
621
+ defer r .Unlock ()
622
+ r .Called = append (r .Called , authCall {
623
+ Actor : subject ,
624
+ Action : action ,
625
+ Object : object ,
626
+ })
627
+ }
628
+
629
+ func (r * RecordingAuthorizer ) Authorize (ctx context.Context , subject rbac.Subject , action rbac.Action , object rbac.Object ) error {
630
+ r .recordAuthorize (subject , action , object )
631
+ if r .Wrapped == nil {
632
+ panic ("Developer error: RecordingAuthorizer.Wrapped is nil" )
633
+ }
634
+ return r .Wrapped .Authorize (ctx , subject , action , object )
635
+ }
636
+
637
+ func (r * RecordingAuthorizer ) Prepare (ctx context.Context , subject rbac.Subject , action rbac.Action , objectType string ) (rbac.PreparedAuthorized , error ) {
638
+ r .RLock ()
639
+ defer r .RUnlock ()
640
+ if r .Wrapped == nil {
641
+ panic ("Developer error: RecordingAuthorizer.Wrapped is nil" )
642
+ }
643
+
644
+ prep , err := r .Wrapped .Prepare (ctx , subject , action , objectType )
645
+ if err != nil {
646
+ return nil , err
647
+ }
648
+ return & PreparedRecorder {
649
+ rec : r ,
650
+ prepped : prep ,
651
+ subject : subject ,
652
+ action : action ,
575
653
}, nil
576
654
}
577
655
578
- func (r * RecordingAuthorizer ) reset () {
656
+ // Reset clears the recorded Authorize() calls.
657
+ func (r * RecordingAuthorizer ) Reset () {
658
+ r .Lock ()
659
+ defer r .Unlock ()
579
660
r .Called = nil
580
661
}
581
662
663
+ // lastCall is implemented to support legacy tests.
664
+ // Deprecated
665
+ func (r * RecordingAuthorizer ) lastCall () * authCall {
666
+ r .RLock ()
667
+ defer r .RUnlock ()
668
+ if len (r .Called ) == 0 {
669
+ return nil
670
+ }
671
+ return & r .Called [len (r .Called )- 1 ]
672
+ }
673
+
674
+ // PreparedRecorder is the prepared version of the RecordingAuthorizer.
675
+ // It records the Authorize() calls to the original recorder. If the caller
676
+ // uses CompileToSQL, all recording stops. This is to support parity between
677
+ // memory and SQL backed dbs.
678
+ type PreparedRecorder struct {
679
+ rec * RecordingAuthorizer
680
+ prepped rbac.PreparedAuthorized
681
+ subject rbac.Subject
682
+ action rbac.Action
683
+
684
+ rw sync.Mutex
685
+ usingSQL bool
686
+ }
687
+
688
+ func (s * PreparedRecorder ) Authorize (ctx context.Context , object rbac.Object ) error {
689
+ s .rw .Lock ()
690
+ defer s .rw .Unlock ()
691
+
692
+ if ! s .usingSQL {
693
+ s .rec .recordAuthorize (s .subject , s .action , object )
694
+ }
695
+ return s .prepped .Authorize (ctx , object )
696
+ }
697
+ func (s * PreparedRecorder ) CompileToSQL (ctx context.Context , cfg regosql.ConvertConfig ) (string , error ) {
698
+ s .rw .Lock ()
699
+ defer s .rw .Unlock ()
700
+
701
+ s .usingSQL = true
702
+ return s .prepped .CompileToSQL (ctx , cfg )
703
+ }
704
+
705
+ // FakeAuthorizer is an Authorizer that always returns the same error.
706
+ type FakeAuthorizer struct {
707
+ // AlwaysReturn is the error that will be returned by Authorize.
708
+ AlwaysReturn error
709
+ }
710
+
711
+ var _ rbac.Authorizer = (* FakeAuthorizer )(nil )
712
+
713
+ func (d * FakeAuthorizer ) Authorize (_ context.Context , _ rbac.Subject , _ rbac.Action , _ rbac.Object ) error {
714
+ return d .AlwaysReturn
715
+ }
716
+
717
+ func (d * FakeAuthorizer ) Prepare (_ context.Context , subject rbac.Subject , action rbac.Action , _ string ) (rbac.PreparedAuthorized , error ) {
718
+ return & fakePreparedAuthorizer {
719
+ Original : d ,
720
+ Subject : subject ,
721
+ Action : action ,
722
+ }, nil
723
+ }
724
+
725
+ var _ rbac.PreparedAuthorized = (* fakePreparedAuthorizer )(nil )
726
+
727
+ // fakePreparedAuthorizer is the prepared version of a FakeAuthorizer. It will
728
+ // return the same error as the original FakeAuthorizer.
582
729
type fakePreparedAuthorizer struct {
583
- Original * RecordingAuthorizer
584
- Subject rbac.Subject
585
- Action rbac.Action
586
- HardCodedSQLString string
587
- HardCodedRegoString string
730
+ sync.RWMutex
731
+ Original * FakeAuthorizer
732
+ Subject rbac.Subject
733
+ Action rbac.Action
588
734
}
589
735
590
736
func (f * fakePreparedAuthorizer ) Authorize (ctx context.Context , object rbac.Object ) error {
@@ -593,17 +739,6 @@ func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Obje
593
739
594
740
// CompileToSQL returns a compiled version of the authorizer that will work for
595
741
// in memory databases. This fake version will not work against a SQL database.
596
- func (fakePreparedAuthorizer ) CompileToSQL (_ context.Context , _ regosql.ConvertConfig ) (string , error ) {
597
- return "" , xerrors .New ("not implemented" )
598
- }
599
-
600
- func (f * fakePreparedAuthorizer ) Eval (object rbac.Object ) bool {
601
- return f .Original .AuthorizeSQL (context .Background (), f .Subject , f .Action , object ) == nil
602
- }
603
-
604
- func (f fakePreparedAuthorizer ) RegoString () string {
605
- if f .HardCodedRegoString != "" {
606
- return f .HardCodedRegoString
607
- }
608
- panic ("not implemented" )
742
+ func (* fakePreparedAuthorizer ) CompileToSQL (_ context.Context , _ regosql.ConvertConfig ) (string , error ) {
743
+ return "not a valid sql string" , nil
609
744
}
0 commit comments