Skip to content

Commit 87b3454

Browse files
committed
Rework arguments
1 parent 300dc4c commit 87b3454

File tree

5 files changed

+43
-21
lines changed

5 files changed

+43
-21
lines changed

coderd/authorize.go

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -95,19 +95,14 @@ func (h *HTTPAuthorizer) Authorize(r *http.Request, action rbac.Action, object r
9595
// from postgres are already authorized, and the caller does not need to
9696
// call 'Authorize()' on the returned objects.
9797
// Note the authorization is only for the given action and object type.
98-
func (h *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action rbac.Action, objectType string) (rbac.AuthorizeFilter, error) {
98+
func (h *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action rbac.Action, objectType string) (rbac.PreparedAuthorized, error) {
9999
roles := httpmw.UserAuthorization(r)
100100
prepared, err := h.Authorizer.PrepareByRoleName(r.Context(), roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), roles.Groups, action, objectType)
101101
if err != nil {
102102
return nil, xerrors.Errorf("prepare filter: %w", err)
103103
}
104104

105-
filter, err := prepared.Compile()
106-
if err != nil {
107-
return nil, xerrors.Errorf("compile filter: %w", err)
108-
}
109-
110-
return filter, nil
105+
return prepared, nil
111106
}
112107

113108
// checkAuthorization returns if the current API key can use the given

coderd/coderdtest/authorize.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import (
99
"strings"
1010
"testing"
1111

12+
"github.com/coder/coder/coderd/rbac/regosql"
13+
1214
"github.com/go-chi/chi/v5"
1315
"github.com/stretchr/testify/assert"
1416
"github.com/stretchr/testify/require"
@@ -551,7 +553,7 @@ func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Obje
551553

552554
// Compile returns a compiled version of the authorizer that will work for
553555
// in memory databases. This fake version will not work against a SQL database.
554-
func (f *fakePreparedAuthorizer) Compile() (rbac.AuthorizeFilter, error) {
556+
func (f *fakePreparedAuthorizer) Compile(_ regosql.ConvertConfig) (rbac.AuthorizeFilter, error) {
555557
return f, nil
556558
}
557559

@@ -566,7 +568,7 @@ func (f fakePreparedAuthorizer) RegoString() string {
566568
panic("not implemented")
567569
}
568570

569-
func (f fakePreparedAuthorizer) SQLString(_ rbac.SQLConfig) string {
571+
func (f fakePreparedAuthorizer) SQLString() string {
570572
if f.HardCodedSQLString != "" {
571573
return f.HardCodedSQLString
572574
}

coderd/database/databasefake/databasefake.go

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ func (q *fakeQuerier) GetFilteredUserCount(ctx context.Context, arg database.Get
488488
return count, err
489489
}
490490

491-
func (q *fakeQuerier) GetAuthorizedUserCount(_ context.Context, params database.GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error) {
491+
func (q *fakeQuerier) GetAuthorizedUserCount(_ context.Context, params database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
492492
q.mutex.RLock()
493493
defer q.mutex.RUnlock()
494494

@@ -539,6 +539,15 @@ func (q *fakeQuerier) GetAuthorizedUserCount(_ context.Context, params database.
539539
users = usersFilteredByRole
540540
}
541541

542+
var authorizedFilter rbac.AuthorizeFilter
543+
var err error
544+
if prepared != nil {
545+
authorizedFilter, err = prepared.Compile(rbac.ConfigWithoutACL())
546+
if err != nil {
547+
return -1, xerrors.Errorf("compile authorized filter: %w", err)
548+
}
549+
}
550+
542551
for _, user := range q.workspaces {
543552
// If the filter exists, ensure the object is authorized.
544553
if authorizedFilter != nil && !authorizedFilter.Eval(user.RBACObject()) {
@@ -750,10 +759,15 @@ func (q *fakeQuerier) GetWorkspaces(ctx context.Context, arg database.GetWorkspa
750759
}
751760

752761
//nolint:gocyclo
753-
func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]database.GetWorkspacesRow, error) {
762+
func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) {
754763
q.mutex.RLock()
755764
defer q.mutex.RUnlock()
756765

766+
filter, err := prepared.Compile(rbac.ConfigWithoutACL())
767+
if err != nil {
768+
return nil, xerrors.Errorf("compile authorized filter: %w", err)
769+
}
770+
757771
workspaces := make([]database.Workspace, 0)
758772
for _, workspace := range q.workspaces {
759773
if arg.OwnerID != uuid.Nil && workspace.OwnerID != arg.OwnerID {
@@ -885,7 +899,7 @@ func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.
885899
}
886900

887901
// If the filter exists, ensure the object is authorized.
888-
if authorizedFilter != nil && !authorizedFilter.Eval(workspace.RBACObject()) {
902+
if filter != nil && !filter.Eval(workspace.RBACObject()) {
889903
continue
890904
}
891905
workspaces = append(workspaces, workspace)

coderd/database/modelqueries.go

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -112,16 +112,21 @@ func (q *sqlQuerier) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([
112112
}
113113

114114
type workspaceQuerier interface {
115-
GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]GetWorkspacesRow, error)
115+
GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]GetWorkspacesRow, error)
116116
}
117117

118118
// GetAuthorizedWorkspaces returns all workspaces that the user is authorized to access.
119119
// This code is copied from `GetWorkspaces` and adds the authorized filter WHERE
120120
// clause.
121-
func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]GetWorkspacesRow, error) {
121+
func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]GetWorkspacesRow, error) {
122+
authorizedFilter, err := prepared.Compile(rbac.ConfigWithoutACL())
123+
if err != nil {
124+
return nil, xerrors.Errorf("compile authorized filter: %w", err)
125+
}
126+
122127
// In order to properly use ORDER BY, OFFSET, and LIMIT, we need to inject the
123128
// authorizedFilter between the end of the where clause and those statements.
124-
filter := strings.Replace(getWorkspaces, "-- @authorize_filter", fmt.Sprintf(" AND %s", authorizedFilter.SQLString(rbac.NoACLConfig())), 1)
129+
filter := strings.Replace(getWorkspaces, "-- @authorize_filter", fmt.Sprintf(" AND %s", authorizedFilter.SQLString()), 1)
125130
// The name comment is for metric tracking
126131
query := fmt.Sprintf("-- name: GetAuthorizedWorkspaces :many\n%s", filter)
127132
rows, err := q.db.QueryContext(ctx, query,
@@ -170,11 +175,16 @@ func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspa
170175
}
171176

172177
type userQuerier interface {
173-
GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error)
178+
GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error)
174179
}
175180

176-
func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error) {
177-
filter := strings.Replace(getFilteredUserCount, "-- @authorize_filter", fmt.Sprintf(" AND %s", authorizedFilter.SQLString(rbac.NoACLConfig())), 1)
181+
func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
182+
authorizedFilter, err := prepared.Compile(rbac.ConfigWithoutACL())
183+
if err != nil {
184+
return -1, xerrors.Errorf("compile authorized filter: %w", err)
185+
}
186+
187+
filter := strings.Replace(getFilteredUserCount, "-- @authorize_filter", fmt.Sprintf(" AND %s", authorizedFilter.SQLString()), 1)
178188
query := fmt.Sprintf("-- name: GetAuthorizedUserCount :one\n%s", filter)
179189
row := q.db.QueryRowContext(ctx, query,
180190
arg.Deleted,
@@ -183,6 +193,6 @@ func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFiltered
183193
pq.Array(arg.RbacRole),
184194
)
185195
var count int64
186-
err := row.Scan(&count)
196+
err = row.Scan(&count)
187197
return count, err
188198
}

coderd/workspaces.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ func (api *API) workspaces(rw http.ResponseWriter, r *http.Request) {
118118
filter.OwnerUsername = ""
119119
}
120120

121-
sqlFilter, err := api.HTTPAuth.AuthorizeSQLFilter(r, rbac.ActionRead, rbac.ResourceWorkspace.Type)
121+
// Workspaces do not have ACL columns.
122+
prepared, err := api.HTTPAuth.AuthorizeSQLFilter(r, rbac.ActionRead, rbac.ResourceWorkspace.Type)
122123
if err != nil {
123124
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
124125
Message: "Internal error preparing sql filter.",
@@ -127,7 +128,7 @@ func (api *API) workspaces(rw http.ResponseWriter, r *http.Request) {
127128
return
128129
}
129130

130-
workspaceRows, err := api.Database.GetAuthorizedWorkspaces(ctx, filter, sqlFilter)
131+
workspaceRows, err := api.Database.GetAuthorizedWorkspaces(ctx, filter, prepared)
131132
if err != nil {
132133
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
133134
Message: "Internal error fetching workspaces.",

0 commit comments

Comments
 (0)