Skip to content

Commit ab9298f

Browse files
authored
chore: Rewrite rbac rego -> SQL clause (#5138)
* chore: Rewrite rbac rego -> SQL clause Previous code was challenging to read with edge cases - bug: OrgAdmin could not make new groups - Also refactor some function names
1 parent d5ab4fd commit ab9298f

39 files changed

+2075
-823
lines changed

coderd/authorize.go

+2-8
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"net/http"
66

77
"github.com/google/uuid"
8-
98
"golang.org/x/xerrors"
109

1110
"cdr.dev/slog"
@@ -95,19 +94,14 @@ func (h *HTTPAuthorizer) Authorize(r *http.Request, action rbac.Action, object r
9594
// from postgres are already authorized, and the caller does not need to
9695
// call 'Authorize()' on the returned objects.
9796
// Note the authorization is only for the given action and object type.
98-
func (h *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action rbac.Action, objectType string) (rbac.AuthorizeFilter, error) {
97+
func (h *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action rbac.Action, objectType string) (rbac.PreparedAuthorized, error) {
9998
roles := httpmw.UserAuthorization(r)
10099
prepared, err := h.Authorizer.PrepareByRoleName(r.Context(), roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), roles.Groups, action, objectType)
101100
if err != nil {
102101
return nil, xerrors.Errorf("prepare filter: %w", err)
103102
}
104103

105-
filter, err := prepared.Compile()
106-
if err != nil {
107-
return nil, xerrors.Errorf("compile filter: %w", err)
108-
}
109-
110-
return filter, nil
104+
return prepared, nil
111105
}
112106

113107
// checkAuthorization returns if the current API key can use the given

coderd/coderdtest/authorize.go

+24-16
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,28 @@ import (
99
"strings"
1010
"testing"
1111

12+
"github.com/coder/coder/coderd/database/databasefake"
13+
1214
"github.com/go-chi/chi/v5"
1315
"github.com/stretchr/testify/assert"
1416
"github.com/stretchr/testify/require"
1517
"golang.org/x/xerrors"
1618

1719
"github.com/coder/coder/coderd"
1820
"github.com/coder/coder/coderd/rbac"
21+
"github.com/coder/coder/coderd/rbac/regosql"
1922
"github.com/coder/coder/codersdk"
2023
"github.com/coder/coder/provisioner/echo"
2124
"github.com/coder/coder/provisionersdk/proto"
2225
)
2326

2427
func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
28+
// For any route using SQL filters, we need to know if the database is an
29+
// in memory fake. This is because the in memory fake does not use SQL, and
30+
// still uses rego. So this boolean indicates how to assert the expected
31+
// behavior.
32+
_, isMemoryDB := a.api.Database.(databasefake.FakeDatabase)
33+
2534
// Some quick reused objects
2635
workspaceRBACObj := rbac.ResourceWorkspace.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String())
2736
workspaceExecObj := rbac.ResourceWorkspaceExecution.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String())
@@ -125,11 +134,6 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
125134
AssertAction: rbac.ActionCreate,
126135
AssertObject: workspaceExecObj,
127136
},
128-
"GET:/api/v2/organizations/{organization}/templates": {
129-
StatusCode: http.StatusOK,
130-
AssertAction: rbac.ActionRead,
131-
AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID),
132-
},
133137
"POST:/api/v2/organizations/{organization}/templates": {
134138
AssertAction: rbac.ActionCreate,
135139
AssertObject: rbac.ResourceTemplate.InOrg(a.Organization.ID),
@@ -240,7 +244,18 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
240244
"GET:/api/v2/organizations/{organization}/templateversions/{templateversionname}": {StatusCode: http.StatusBadRequest, NoAuthorize: true},
241245

242246
// Endpoints that use the SQLQuery filter.
243-
"GET:/api/v2/workspaces/": {StatusCode: http.StatusOK, NoAuthorize: true},
247+
"GET:/api/v2/workspaces/": {
248+
StatusCode: http.StatusOK,
249+
NoAuthorize: !isMemoryDB,
250+
AssertAction: rbac.ActionRead,
251+
AssertObject: rbac.ResourceWorkspace,
252+
},
253+
"GET:/api/v2/organizations/{organization}/templates": {
254+
StatusCode: http.StatusOK,
255+
NoAuthorize: !isMemoryDB,
256+
AssertAction: rbac.ActionRead,
257+
AssertObject: rbac.ResourceTemplate,
258+
},
244259
}
245260

246261
// Routes like proxy routes support all HTTP methods. A helper func to expand
@@ -549,10 +564,10 @@ func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Obje
549564
return f.Original.ByRoleName(ctx, f.SubjectID, f.Roles, f.Scope, f.Groups, f.Action, object)
550565
}
551566

552-
// Compile returns a compiled version of the authorizer that will work for
567+
// CompileToSQL returns a compiled version of the authorizer that will work for
553568
// in memory databases. This fake version will not work against a SQL database.
554-
func (f *fakePreparedAuthorizer) Compile() (rbac.AuthorizeFilter, error) {
555-
return f, nil
569+
func (fakePreparedAuthorizer) CompileToSQL(_ regosql.ConvertConfig) (string, error) {
570+
return "", xerrors.New("not implemented")
556571
}
557572

558573
func (f *fakePreparedAuthorizer) Eval(object rbac.Object) bool {
@@ -565,10 +580,3 @@ func (f fakePreparedAuthorizer) RegoString() string {
565580
}
566581
panic("not implemented")
567582
}
568-
569-
func (f fakePreparedAuthorizer) SQLString(_ rbac.SQLConfig) string {
570-
if f.HardCodedSQLString != "" {
571-
return f.HardCodedSQLString
572-
}
573-
panic("not implemented")
574-
}

coderd/database/databasefake/databasefake.go

+30-12
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@ import (
2020
"github.com/coder/coder/coderd/util/slice"
2121
)
2222

23+
// FakeDatabase is helpful for knowing if the underlying db is an in memory fake
24+
// database. This is only in the databasefake package, so will only be used
25+
// by unit tests.
26+
type FakeDatabase interface {
27+
IsFakeDB()
28+
}
29+
2330
var errDuplicateKey = &pq.Error{
2431
Code: "23505",
2532
Message: "duplicate key value violates unique constraint",
@@ -117,6 +124,7 @@ type data struct {
117124
lastLicenseID int32
118125
}
119126

127+
func (fakeQuerier) IsFakeDB() {}
120128
func (*fakeQuerier) Ping(_ context.Context) (time.Duration, error) {
121129
return 0, nil
122130
}
@@ -488,11 +496,20 @@ func (q *fakeQuerier) GetFilteredUserCount(ctx context.Context, arg database.Get
488496
return count, err
489497
}
490498

491-
func (q *fakeQuerier) GetAuthorizedUserCount(_ context.Context, params database.GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error) {
499+
func (q *fakeQuerier) GetAuthorizedUserCount(ctx context.Context, params database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
492500
q.mutex.RLock()
493501
defer q.mutex.RUnlock()
494502

495-
users := append([]database.User{}, q.users...)
503+
users := make([]database.User, 0, len(q.users))
504+
505+
for _, user := range q.users {
506+
// If the filter exists, ensure the object is authorized.
507+
if prepared != nil && prepared.Authorize(ctx, user.RBACObject()) != nil {
508+
continue
509+
}
510+
511+
users = append(users, user)
512+
}
496513

497514
if params.Deleted {
498515
tmp := make([]database.User, 0, len(users))
@@ -539,13 +556,6 @@ func (q *fakeQuerier) GetAuthorizedUserCount(_ context.Context, params database.
539556
users = usersFilteredByRole
540557
}
541558

542-
for _, user := range q.workspaces {
543-
// If the filter exists, ensure the object is authorized.
544-
if authorizedFilter != nil && !authorizedFilter.Eval(user.RBACObject()) {
545-
continue
546-
}
547-
}
548-
549559
return int64(len(users)), nil
550560
}
551561

@@ -750,7 +760,7 @@ func (q *fakeQuerier) GetWorkspaces(ctx context.Context, arg database.GetWorkspa
750760
}
751761

752762
//nolint:gocyclo
753-
func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]database.GetWorkspacesRow, error) {
763+
func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) {
754764
q.mutex.RLock()
755765
defer q.mutex.RUnlock()
756766

@@ -923,7 +933,7 @@ func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.
923933
}
924934

925935
// If the filter exists, ensure the object is authorized.
926-
if authorizedFilter != nil && !authorizedFilter.Eval(workspace.RBACObject()) {
936+
if prepared != nil && prepared.Authorize(ctx, workspace.RBACObject()) != nil {
927937
continue
928938
}
929939
workspaces = append(workspaces, workspace)
@@ -1505,12 +1515,20 @@ func (q *fakeQuerier) UpdateTemplateMetaByID(_ context.Context, arg database.Upd
15051515
return database.Template{}, sql.ErrNoRows
15061516
}
15071517

1508-
func (q *fakeQuerier) GetTemplatesWithFilter(_ context.Context, arg database.GetTemplatesWithFilterParams) ([]database.Template, error) {
1518+
func (q *fakeQuerier) GetTemplatesWithFilter(ctx context.Context, arg database.GetTemplatesWithFilterParams) ([]database.Template, error) {
1519+
return q.GetAuthorizedTemplates(ctx, arg, nil)
1520+
}
1521+
1522+
func (q *fakeQuerier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) {
15091523
q.mutex.RLock()
15101524
defer q.mutex.RUnlock()
15111525

15121526
var templates []database.Template
15131527
for _, template := range q.templates {
1528+
if prepared != nil && prepared.Authorize(ctx, template.RBACObject()) != nil {
1529+
continue
1530+
}
1531+
15141532
if template.Deleted != arg.Deleted {
15151533
continue
15161534
}

coderd/database/databasefake/databasefake_test.go

+1
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ func TestExactMethods(t *testing.T) {
7474
extraFakeMethods := map[string]string{
7575
// Example
7676
// "SortFakeLists": "Helper function used",
77+
"IsFakeDB": "Helper function used for unit testing",
7778
}
7879

7980
fake := reflect.TypeOf(databasefake.New())

coderd/database/modelqueries.go

+101-11
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,16 @@ import (
55
"fmt"
66
"strings"
77

8+
"github.com/google/uuid"
89
"github.com/lib/pq"
10+
"golang.org/x/xerrors"
911

1012
"github.com/coder/coder/coderd/rbac"
13+
"github.com/coder/coder/coderd/rbac/regosql"
14+
)
1115

12-
"github.com/google/uuid"
13-
"golang.org/x/xerrors"
16+
const (
17+
authorizedQueryPlaceholder = "-- @authorize_filter"
1418
)
1519

1620
// customQuerier encompasses all non-generated queries.
@@ -23,10 +27,70 @@ type customQuerier interface {
2327
}
2428

2529
type templateQuerier interface {
30+
GetAuthorizedTemplates(ctx context.Context, arg GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]Template, error)
2631
GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]TemplateGroup, error)
2732
GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]TemplateUser, error)
2833
}
2934

35+
func (q *sqlQuerier) GetAuthorizedTemplates(ctx context.Context, arg GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]Template, error) {
36+
authorizedFilter, err := prepared.CompileToSQL(regosql.ConvertConfig{
37+
VariableConverter: regosql.TemplateConverter(),
38+
})
39+
if err != nil {
40+
return nil, xerrors.Errorf("compile authorized filter: %w", err)
41+
}
42+
43+
filtered, err := insertAuthorizedFilter(getTemplatesWithFilter, fmt.Sprintf(" AND %s", authorizedFilter))
44+
if err != nil {
45+
return nil, xerrors.Errorf("insert authorized filter: %w", err)
46+
}
47+
48+
// The name comment is for metric tracking
49+
query := fmt.Sprintf("-- name: GetAuthorizedTemplates :many\n%s", filtered)
50+
rows, err := q.db.QueryContext(ctx, query,
51+
arg.Deleted,
52+
arg.OrganizationID,
53+
arg.ExactName,
54+
pq.Array(arg.IDs),
55+
)
56+
if err != nil {
57+
return nil, err
58+
}
59+
defer rows.Close()
60+
var items []Template
61+
for rows.Next() {
62+
var i Template
63+
if err := rows.Scan(
64+
&i.ID,
65+
&i.CreatedAt,
66+
&i.UpdatedAt,
67+
&i.OrganizationID,
68+
&i.Deleted,
69+
&i.Name,
70+
&i.Provisioner,
71+
&i.ActiveVersionID,
72+
&i.Description,
73+
&i.DefaultTTL,
74+
&i.CreatedBy,
75+
&i.Icon,
76+
&i.UserACL,
77+
&i.GroupACL,
78+
&i.DisplayName,
79+
&i.AllowUserCancelWorkspaceJobs,
80+
); err != nil {
81+
return nil, err
82+
}
83+
items = append(items, i)
84+
}
85+
if err := rows.Close(); err != nil {
86+
return nil, err
87+
}
88+
if err := rows.Err(); err != nil {
89+
return nil, err
90+
}
91+
return items, nil
92+
}
93+
3094
type TemplateUser struct {
3195
User
3296
Actions Actions `db:"actions"`
@@ -112,18 +176,27 @@ func (q *sqlQuerier) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([
112176
}
113177

114178
type workspaceQuerier interface {
115-
GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]GetWorkspacesRow, error)
179+
GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]GetWorkspacesRow, error)
116180
}
117181

118182
// GetAuthorizedWorkspaces returns all workspaces that the user is authorized to access.
119183
// This code is copied from `GetWorkspaces` and adds the authorized filter WHERE
120184
// clause.
121-
func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]GetWorkspacesRow, error) {
185+
func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]GetWorkspacesRow, error) {
186+
authorizedFilter, err := prepared.CompileToSQL(rbac.ConfigWithoutACL())
187+
if err != nil {
188+
return nil, xerrors.Errorf("compile authorized filter: %w", err)
189+
}
190+
122191
// In order to properly use ORDER BY, OFFSET, and LIMIT, we need to inject the
123192
// authorizedFilter between the end of the where clause and those statements.
124-
filter := strings.Replace(getWorkspaces, "-- @authorize_filter", fmt.Sprintf(" AND %s", authorizedFilter.SQLString(rbac.NoACLConfig())), 1)
193+
filtered, err := insertAuthorizedFilter(getWorkspaces, fmt.Sprintf(" AND %s", authorizedFilter))
194+
if err != nil {
195+
return nil, xerrors.Errorf("insert authorized filter: %w", err)
196+
}
197+
125198
// The name comment is for metric tracking
126-
query := fmt.Sprintf("-- name: GetAuthorizedWorkspaces :many\n%s", filter)
199+
query := fmt.Sprintf("-- name: GetAuthorizedWorkspaces :many\n%s", filtered)
127200
rows, err := q.db.QueryContext(ctx, query,
128201
arg.Deleted,
129202
arg.Status,
@@ -172,19 +245,36 @@ func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspa
172245
}
173246

174247
type userQuerier interface {
175-
GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error)
248+
GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error)
176249
}
177250

178-
func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error) {
179-
filter := strings.Replace(getFilteredUserCount, "-- @authorize_filter", fmt.Sprintf(" AND %s", authorizedFilter.SQLString(rbac.NoACLConfig())), 1)
180-
query := fmt.Sprintf("-- name: GetAuthorizedUserCount :one\n%s", filter)
251+
func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
252+
authorizedFilter, err := prepared.CompileToSQL(rbac.ConfigWithoutACL())
253+
if err != nil {
254+
return -1, xerrors.Errorf("compile authorized filter: %w", err)
255+
}
256+
257+
filtered, err := insertAuthorizedFilter(getFilteredUserCount, fmt.Sprintf(" AND %s", authorizedFilter))
258+
if err != nil {
259+
return -1, xerrors.Errorf("insert authorized filter: %w", err)
260+
}
261+
262+
query := fmt.Sprintf("-- name: GetAuthorizedUserCount :one\n%s", filtered)
181263
row := q.db.QueryRowContext(ctx, query,
182264
arg.Deleted,
183265
arg.Search,
184266
pq.Array(arg.Status),
185267
pq.Array(arg.RbacRole),
186268
)
187269
var count int64
188-
err := row.Scan(&count)
270+
err = row.Scan(&count)
189271
return count, err
190272
}
273+
274+
func insertAuthorizedFilter(query string, replaceWith string) (string, error) {
275+
if !strings.Contains(query, authorizedQueryPlaceholder) {
276+
return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query")
277+
}
278+
filtered := strings.Replace(query, authorizedQueryPlaceholder, replaceWith, 1)
279+
return filtered, nil
280+
}

0 commit comments

Comments
 (0)