Skip to content

Commit 5accbfe

Browse files
committed
Allow asserting many rbac checks in recording authorizer
1 parent cb40686 commit 5accbfe

File tree

2 files changed

+127
-37
lines changed

2 files changed

+127
-37
lines changed

coderd/authzquery/workspace_test.go

Lines changed: 52 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import (
55
"testing"
66
"time"
77

8+
"github.com/moby/moby/pkg/namesgenerator"
9+
810
"github.com/coder/coder/coderd/rbac"
911

1012
"github.com/google/uuid"
@@ -24,34 +26,61 @@ func TestWorkspace(t *testing.T) {
2426
// TODO: Recorder should record all authz calls
2527
rec = &coderdtest.RecordingAuthorizer{}
2628
q = authzquery.NewAuthzQuerier(db, rec)
27-
ctx = context.Background()
28-
actor = authzquery.WithAuthorizeContext(ctx,
29-
uuid.New(),
30-
rbac.RoleNames{rbac.RoleOwner()},
31-
[]string{},
32-
rbac.ScopeAll,
33-
)
29+
actor = rbac.Subject{
30+
ID: uuid.New().String(),
31+
Roles: rbac.RoleNames{rbac.RoleOwner()},
32+
Groups: []string{},
33+
Scope: rbac.ScopeAll,
34+
}
35+
ctx = authzquery.WithAuthorizeContext(context.Background(), actor)
3436
)
3537

36-
// Seed db
37-
workspace, err := db.InsertWorkspace(ctx, database.InsertWorkspaceParams{
38+
workspace := insertRandomWorkspace(t, db)
39+
40+
// Test recorder
41+
_, err := q.GetWorkspaceByID(ctx, workspace.ID)
42+
require.NoError(t, err)
43+
44+
_, err = q.UpdateWorkspace(ctx, database.UpdateWorkspaceParams{
45+
ID: workspace.ID,
46+
Name: "new-name",
47+
})
48+
require.NoError(t, err)
49+
50+
rec.AssertActor(t, actor,
51+
rec.Pair(rbac.ActionRead, workspace),
52+
rec.Pair(rbac.ActionUpdate, workspace),
53+
)
54+
require.NoError(t, rec.AllAsserted())
55+
}
56+
57+
func insertRandomWorkspace(t *testing.T, db database.Store, opts ...func(w *database.Workspace)) database.Workspace {
58+
workspace := &database.Workspace{
3859
ID: uuid.New(),
39-
CreatedAt: time.Time{},
40-
UpdatedAt: time.Time{},
60+
CreatedAt: time.Now().Add(time.Hour * -1),
61+
UpdatedAt: time.Now(),
4162
OwnerID: uuid.New(),
4263
OrganizationID: uuid.New(),
4364
TemplateID: uuid.New(),
44-
Name: "fake-workspace",
45-
})
46-
require.NoError(t, err)
65+
Deleted: false,
66+
Name: namesgenerator.GetRandomName(1),
67+
LastUsedAt: time.Now(),
68+
}
69+
for _, opt := range opts {
70+
opt(workspace)
71+
}
4772

48-
// Test
49-
// NoAuth
50-
_, err = q.GetWorkspaceByID(ctx, workspace.ID)
51-
require.Error(t, err, "no actor in context")
52-
53-
// Test recorder
54-
_, err = q.GetWorkspaceByID(actor, workspace.ID)
55-
require.NoError(t, err)
56-
require.Equal(t, rec.Called.Object, workspace.RBACObject())
73+
newWorkspace, err := db.InsertWorkspace(context.Background(), database.InsertWorkspaceParams{
74+
ID: workspace.ID,
75+
CreatedAt: workspace.CreatedAt,
76+
UpdatedAt: workspace.UpdatedAt,
77+
OwnerID: workspace.OwnerID,
78+
OrganizationID: workspace.OrganizationID,
79+
TemplateID: workspace.TemplateID,
80+
Name: workspace.Name,
81+
AutostartSchedule: workspace.AutostartSchedule,
82+
Ttl: workspace.Ttl,
83+
})
84+
require.NoError(t, err, "insert workspace")
85+
return newWorkspace
5786
}

coderd/coderdtest/authorize.go

Lines changed: 75 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -508,18 +508,19 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck
508508
assert.Equal(t, http.StatusForbidden, resp.StatusCode, "expect unauthorized")
509509
}
510510
}
511-
if a.authorizer.Called != nil {
511+
if a.authorizer.LastCall() != nil {
512+
last := a.authorizer.LastCall()
512513
if routeAssertions.AssertAction != "" {
513-
assert.Equal(t, routeAssertions.AssertAction, a.authorizer.Called.Action, "resource action")
514+
assert.Equal(t, routeAssertions.AssertAction, last.Action, "resource action")
514515
}
515516
if routeAssertions.AssertObject.Type != "" {
516-
assert.Equal(t, routeAssertions.AssertObject.Type, a.authorizer.Called.Object.Type, "resource type")
517+
assert.Equal(t, routeAssertions.AssertObject.Type, last.Object.Type, "resource type")
517518
}
518519
if routeAssertions.AssertObject.Owner != "" {
519-
assert.Equal(t, routeAssertions.AssertObject.Owner, a.authorizer.Called.Object.Owner, "resource owner")
520+
assert.Equal(t, routeAssertions.AssertObject.Owner, last.Object.Owner, "resource owner")
520521
}
521522
if routeAssertions.AssertObject.OrgID != "" {
522-
assert.Equal(t, routeAssertions.AssertObject.OrgID, a.authorizer.Called.Object.OrgID, "resource org")
523+
assert.Equal(t, routeAssertions.AssertObject.OrgID, last.Object.OrgID, "resource org")
523524
}
524525
}
525526
} else {
@@ -533,30 +534,81 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck
533534
}
534535

535536
type authCall struct {
536-
Subject rbac.Subject
537-
Action rbac.Action
538-
Object rbac.Object
537+
Actor rbac.Subject
538+
Action rbac.Action
539+
Object rbac.Object
540+
541+
asserted bool
539542
}
540543

541544
type RecordingAuthorizer struct {
542-
Called *authCall
545+
Called []authCall
543546
AlwaysReturn error
544547
}
545548

546549
var _ rbac.Authorizer = (*RecordingAuthorizer)(nil)
547550

551+
type ActionObjectPair struct {
552+
Action rbac.Action
553+
Object rbac.Object
554+
}
555+
556+
// Pair is on the RecordingAuthorizer to be easy to find and keep the pkg
557+
// interface smaller.
558+
func (r *RecordingAuthorizer) Pair(action rbac.Action, object rbac.Objecter) ActionObjectPair {
559+
return ActionObjectPair{
560+
Action: action,
561+
Object: object.RBACObject(),
562+
}
563+
}
564+
565+
func (r *RecordingAuthorizer) AllAsserted() error {
566+
missed := 0
567+
for _, c := range r.Called {
568+
if !c.asserted {
569+
missed++
570+
}
571+
}
572+
573+
if missed > 0 {
574+
return xerrors.Errorf("missed %d calls", missed)
575+
}
576+
return nil
577+
}
578+
579+
// AssertActor asserts in order.
580+
func (r *RecordingAuthorizer) AssertActor(t *testing.T, actor rbac.Subject, did ...ActionObjectPair) {
581+
ptr := 0
582+
for i, call := range r.Called {
583+
if ptr == len(did) {
584+
// Finished all assertions
585+
return
586+
}
587+
if call.Actor.ID == actor.ID {
588+
//action, object := did[ptr], on[ptr]
589+
action, object := did[ptr].Action, did[ptr].Object
590+
assert.Equalf(t, action, call.Action, "assert action %d", ptr)
591+
assert.Equalf(t, object, call.Object, "assert object %d", ptr)
592+
r.Called[i].asserted = true
593+
ptr++
594+
}
595+
}
596+
597+
assert.Equalf(t, len(did), ptr, "assert actor: didn't find all actions, %d missing actions", len(did)-ptr)
598+
}
599+
548600
// AuthorizeSQL does not record the call. This matches the postgres behavior
549601
// of not calling Authorize()
550602
func (r *RecordingAuthorizer) AuthorizeSQL(_ context.Context, _ rbac.Subject, _ rbac.Action, _ rbac.Object) error {
551603
return r.AlwaysReturn
552604
}
553605

554606
func (r *RecordingAuthorizer) Authorize(_ context.Context, subject rbac.Subject, action rbac.Action, object rbac.Object) error {
555-
r.Called = &authCall{
556-
Subject: subject,
557-
Action: action,
558-
Object: object,
559-
}
607+
r.Called = append(r.Called, authCall{
608+
Actor: subject,
609+
Action: action,
610+
Object: object,
611+
})
560612
return r.AlwaysReturn
561613
}
562614

@@ -601,3 +653,12 @@ func (f fakePreparedAuthorizer) RegoString() string {
601653
}
602654
panic("not implemented")
603655
}
656+
657+
// LastCall is implemented to support legacy tests.
658+
// Deprecated
659+
func (r *RecordingAuthorizer) LastCall() *authCall {
660+
if len(r.Called) == 0 {
661+
return nil
662+
}
663+
return &r.Called[len(r.Called)-1]
664+
}

0 commit comments

Comments
 (0)