@@ -508,18 +508,19 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck
508
508
assert .Equal (t , http .StatusForbidden , resp .StatusCode , "expect unauthorized" )
509
509
}
510
510
}
511
- if a .authorizer .Called != nil {
511
+ if a .authorizer .LastCall () != nil {
512
+ last := a .authorizer .LastCall ()
512
513
if routeAssertions .AssertAction != "" {
513
- assert .Equal (t , routeAssertions .AssertAction , a . authorizer . Called .Action , "resource action" )
514
+ assert .Equal (t , routeAssertions .AssertAction , last .Action , "resource action" )
514
515
}
515
516
if routeAssertions .AssertObject .Type != "" {
516
- assert .Equal (t , routeAssertions .AssertObject .Type , a . authorizer . Called .Object .Type , "resource type" )
517
+ assert .Equal (t , routeAssertions .AssertObject .Type , last .Object .Type , "resource type" )
517
518
}
518
519
if routeAssertions .AssertObject .Owner != "" {
519
- assert .Equal (t , routeAssertions .AssertObject .Owner , a . authorizer . Called .Object .Owner , "resource owner" )
520
+ assert .Equal (t , routeAssertions .AssertObject .Owner , last .Object .Owner , "resource owner" )
520
521
}
521
522
if routeAssertions .AssertObject .OrgID != "" {
522
- assert .Equal (t , routeAssertions .AssertObject .OrgID , a . authorizer . Called .Object .OrgID , "resource org" )
523
+ assert .Equal (t , routeAssertions .AssertObject .OrgID , last .Object .OrgID , "resource org" )
523
524
}
524
525
}
525
526
} else {
@@ -533,30 +534,81 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck
533
534
}
534
535
535
536
type authCall struct {
536
- Subject rbac.Subject
537
- Action rbac.Action
538
- Object rbac.Object
537
+ Actor rbac.Subject
538
+ Action rbac.Action
539
+ Object rbac.Object
540
+
541
+ asserted bool
539
542
}
540
543
541
544
type RecordingAuthorizer struct {
542
- Called * authCall
545
+ Called [] authCall
543
546
AlwaysReturn error
544
547
}
545
548
546
549
var _ rbac.Authorizer = (* RecordingAuthorizer )(nil )
547
550
551
+ type ActionObjectPair struct {
552
+ Action rbac.Action
553
+ Object rbac.Object
554
+ }
555
+
556
+ // Pair is on the RecordingAuthorizer to be easy to find and keep the pkg
557
+ // interface smaller.
558
+ func (r * RecordingAuthorizer ) Pair (action rbac.Action , object rbac.Objecter ) ActionObjectPair {
559
+ return ActionObjectPair {
560
+ Action : action ,
561
+ Object : object .RBACObject (),
562
+ }
563
+ }
564
+
565
+ func (r * RecordingAuthorizer ) AllAsserted () error {
566
+ missed := 0
567
+ for _ , c := range r .Called {
568
+ if ! c .asserted {
569
+ missed ++
570
+ }
571
+ }
572
+
573
+ if missed > 0 {
574
+ return xerrors .Errorf ("missed %d calls" , missed )
575
+ }
576
+ return nil
577
+ }
578
+
579
+ // AssertActor asserts in order.
580
+ func (r * RecordingAuthorizer ) AssertActor (t * testing.T , actor rbac.Subject , did ... ActionObjectPair ) {
581
+ ptr := 0
582
+ for i , call := range r .Called {
583
+ if ptr == len (did ) {
584
+ // Finished all assertions
585
+ return
586
+ }
587
+ if call .Actor .ID == actor .ID {
588
+ //action, object := did[ptr], on[ptr]
589
+ action , object := did [ptr ].Action , did [ptr ].Object
590
+ assert .Equalf (t , action , call .Action , "assert action %d" , ptr )
591
+ assert .Equalf (t , object , call .Object , "assert object %d" , ptr )
592
+ r .Called [i ].asserted = true
593
+ ptr ++
594
+ }
595
+ }
596
+
597
+ assert .Equalf (t , len (did ), ptr , "assert actor: didn't find all actions, %d missing actions" , len (did )- ptr )
598
+ }
599
+
548
600
// AuthorizeSQL does not record the call. This matches the postgres behavior
549
601
// of not calling Authorize()
550
602
func (r * RecordingAuthorizer ) AuthorizeSQL (_ context.Context , _ rbac.Subject , _ rbac.Action , _ rbac.Object ) error {
551
603
return r .AlwaysReturn
552
604
}
553
605
554
606
func (r * RecordingAuthorizer ) Authorize (_ context.Context , subject rbac.Subject , action rbac.Action , object rbac.Object ) error {
555
- r .Called = & authCall {
556
- Subject : subject ,
557
- Action : action ,
558
- Object : object ,
559
- }
607
+ r .Called = append ( r . Called , authCall {
608
+ Actor : subject ,
609
+ Action : action ,
610
+ Object : object ,
611
+ })
560
612
return r .AlwaysReturn
561
613
}
562
614
@@ -601,3 +653,12 @@ func (f fakePreparedAuthorizer) RegoString() string {
601
653
}
602
654
panic ("not implemented" )
603
655
}
656
+
657
+ // LastCall is implemented to support legacy tests.
658
+ // Deprecated
659
+ func (r * RecordingAuthorizer ) LastCall () * authCall {
660
+ if len (r .Called ) == 0 {
661
+ return nil
662
+ }
663
+ return & r .Called [len (r .Called )- 1 ]
664
+ }
0 commit comments