Skip to content

Commit 2cf1cad

Browse files
committed
Rename to "querier", add unit test for double wrap protection
1 parent 951d74f commit 2cf1cad

File tree

5 files changed

+221
-224
lines changed

5 files changed

+221
-224
lines changed

coderd/database/dbauthz/dbauthz.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import (
1515
"github.com/coder/coder/coderd/rbac"
1616
)
1717

18-
var _ database.Store = (*authzQuerier)(nil)
18+
var _ database.Store = (*querier)(nil)
1919

2020
var (
2121
// NoActorError wraps ErrNoRows for the api to return a 404. This is the correct
@@ -55,34 +55,34 @@ func logNotAuthorizedError(ctx context.Context, logger slog.Logger, err error) e
5555
}
5656
}
5757

58-
// authzQuerier is a wrapper around the database store that performs authorization
59-
// checks before returning data. All authzQuerier methods expect an authorization
58+
// querier is a wrapper around the database store that performs authorization
59+
// checks before returning data. All querier methods expect an authorization
6060
// subject present in the context. If no subject is present, most methods will
6161
// fail.
6262
//
6363
// Use WithAuthorizeContext to set the authorization subject in the context for
6464
// the common user case.
65-
type authzQuerier struct {
65+
type querier struct {
6666
db database.Store
6767
auth rbac.Authorizer
6868
log slog.Logger
6969
}
7070

7171
func New(db database.Store, authorizer rbac.Authorizer, logger slog.Logger) database.Store {
72-
// If the underlying db store is already an authzquerier, return it.
72+
// If the underlying db store is already a querier, return it.
7373
// Do not double wrap.
74-
if _, ok := db.(*authzQuerier); ok {
74+
if _, ok := db.(*querier); ok {
7575
return db
7676
}
77-
return &authzQuerier{
77+
return &querier{
7878
db: db,
7979
auth: authorizer,
8080
log: logger,
8181
}
8282
}
8383

8484
// authorizeContext is a helper function to authorize an action on an object.
85-
func (q *authzQuerier) authorizeContext(ctx context.Context, action rbac.Action, object rbac.Objecter) error {
85+
func (q *querier) authorizeContext(ctx context.Context, action rbac.Action, object rbac.Objecter) error {
8686
act, ok := ActorFromContext(ctx)
8787
if !ok {
8888
return NoActorError

coderd/database/dbauthz/dbauthz_test.go

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,9 @@ package dbauthz_test
22

33
import (
44
"context"
5-
"database/sql"
65
"reflect"
76
"testing"
87

9-
"cdr.dev/slog/sloggers/slogtest"
10-
118
"github.com/google/uuid"
129
"github.com/stretchr/testify/require"
1310
"golang.org/x/xerrors"
@@ -55,31 +52,31 @@ func TestInTX(t *testing.T) {
5552
require.ErrorAs(t, err, &dbauthz.NotAuthorizedError{}, "must be an authorized error")
5653
}
5754

58-
func TestNotAuthorizedError(t *testing.T) {
55+
// TestNew should not double wrap a querier.
56+
func TestNew(t *testing.T) {
5957
t.Parallel()
6058

61-
t.Run("Is404", func(t *testing.T) {
62-
t.Parallel()
63-
64-
testErr := xerrors.New("custom error")
59+
var (
60+
db = dbfake.New()
61+
exp = dbgen.Workspace(t, db, database.Workspace{})
62+
rec = &coderdtest.RecordingAuthorizer{
63+
Wrapped: &coderdtest.FakeAuthorizer{AlwaysReturn: nil},
64+
}
65+
subj = rbac.Subject{}
66+
ctx = dbauthz.WithAuthorizeContext(context.Background(), rbac.Subject{})
67+
)
6568

66-
err := dbauthz.logNotAuthorizedError(context.Background(), slogtest.Make(t, nil), testErr)
67-
require.ErrorIs(t, err, sql.ErrNoRows, "must be a sql.ErrNoRows")
69+
// Double wrap should not cause an actual double wrap. So only 1 rbac call
70+
// should be made.
71+
az := dbauthz.New(db, rec, slog.Make())
72+
az = dbauthz.New(az, rec, slog.Make())
6873

69-
var authErr dbauthz.NotAuthorizedError
70-
require.ErrorAs(t, err, &authErr, "must be a NotAuthorizedError")
71-
require.ErrorIs(t, authErr.Err, testErr, "internal error must match")
72-
})
74+
w, err := az.GetWorkspaceByID(ctx, exp.ID)
75+
require.NoError(t, err, "must not error")
76+
require.Equal(t, exp, w, "must be equal")
7377

74-
t.Run("MissingActor", func(t *testing.T) {
75-
t.Parallel()
76-
q := dbauthz.New(dbfake.New(), &coderdtest.RecordingAuthorizer{
77-
Wrapped: &coderdtest.FakeAuthorizer{AlwaysReturn: nil},
78-
}, slog.Make())
79-
// This should fail because the actor is missing.
80-
_, err := q.GetWorkspaceByID(context.Background(), uuid.New())
81-
require.ErrorIs(t, err, dbauthz.NoActorError, "must be a NoActorError")
82-
})
78+
rec.AssertActor(t, subj, rec.Pair(rbac.ActionRead, exp))
79+
require.NoError(t, rec.AllAsserted(), "should only be 1 rbac call")
8380
}
8481

8582
// TestDBAuthzRecursive is a simple test to search for infinite recursion

0 commit comments

Comments
 (0)