Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add comments + tests
  • Loading branch information
Emyrk committed Sep 28, 2022
commit d516be7840b5c29b41786ce8b580000dbc10f089
8 changes: 8 additions & 0 deletions coderd/authorize.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ import (
"github.com/coder/coder/codersdk"
)

// AuthorizeFilter takes a list of objects and returns the filtered list of
// objects that the user is authorized to perform the given action on.
// This is faster than calling Authorize() on each object.
func AuthorizeFilter[O rbac.Objecter](h *HTTPAuthorizer, r *http.Request, action rbac.Action, objects []O) ([]O, error) {
roles := httpmw.UserAuthorization(r)
objects, err := rbac.Filter(r.Context(), h.Authorizer, roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), action, objects)
Expand Down Expand Up @@ -85,6 +88,11 @@ func (h *HTTPAuthorizer) Authorize(r *http.Request, action rbac.Action, object r
return true
}

// AuthorizeSQLFilter returns an authorization filter that can used in a
// SQL 'WHERE' clause. If the filter is used, the resulting rows returned
// from postgres are already authorized, and the caller does not need to
// call 'Authorize()' on the returned objects.
// Note the authorization is only for the given action and object type.
func (a *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action rbac.Action, objectType string) (rbac.AuthorizeFilter, error) {
roles := httpmw.UserAuthorization(r)
prepared, err := a.Authorizer.PrepareByRoleName(r.Context(), roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), action, objectType)
Expand Down
9 changes: 7 additions & 2 deletions coderd/database/custom_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,19 @@ import (
"context"
"fmt"

"github.com/lib/pq"
"golang.org/x/xerrors"

"github.com/coder/coder/coderd/rbac"

"github.com/lib/pq"
)

type customQuerier interface {
AuthorizedGetWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error)
}

// AuthorizedGetWorkspaces returns all workspaces that the user is authorized to access.
// This code is copied from `GetWorkspaces` and adds the authorized filter WHERE
// clause.
func (q *sqlQuerier) AuthorizedGetWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wish we could add hooks into SQLc to do this. Kinda sucks to have to copy the original function and call it like this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might change the wording of this. Prefixing with Authorized makes me think that I'm authorized to get workspaces, not that the workspaces returned I'm authorized to access.

Thoughts on GetAuthorizedWorkspaces instead?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible for us to get rid of GetWorkspaces in its entirety? There doesn't seem much of a point in getting workspaces that the user doesn't have access to.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kylecarbs that is a good question. For system level things there might be a need? But we can always add back "GetWorkspaces" later.

I was using SQLc's generated code to make this function. To get rid of GetWorkspaces would be to remove all those SQLc calls. I guess that is ok? It is nice to use it to generate most of this code, it just can't support the dynamic parts of the query.

query := fmt.Sprintf("%s AND %s", getWorkspaces, authorizedFilter.SQLString(rbac.SQLConfig{
VariableRenames: map[string]string{
Expand Down
2 changes: 2 additions & 0 deletions coderd/database/databasefake/databasefake.go
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,7 @@ func (q *fakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.U
}

func (q *fakeQuerier) GetWorkspaces(ctx context.Context, arg database.GetWorkspacesParams) ([]database.Workspace, error) {
// A nil auth filter means no auth filter.
workspaces, err := q.AuthorizedGetWorkspaces(ctx, arg, nil)
return workspaces, err
}
Expand Down Expand Up @@ -566,6 +567,7 @@ func (q *fakeQuerier) AuthorizedGetWorkspaces(ctx context.Context, arg database.
}
}

// If the filter exists, ensure the object is authorized.
if authorizedFilter != nil && !authorizedFilter.Eval(workspace.RBACObject()) {
continue
}
Expand Down
7 changes: 1 addition & 6 deletions coderd/database/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,19 @@ import (
"database/sql"
"errors"

"github.com/coder/coder/coderd/rbac"

"golang.org/x/xerrors"
)

// Store contains all queryable database functions.
// It extends the generated interface to add transaction support.
type Store interface {
querier
// customQuerier contains custom queries that are not generated.
customQuerier

InTx(func(Store) error) error
}

type customQuerier interface {
AuthorizedGetWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error)
}

// DBTX represents a database connection or transaction.
type DBTX interface {
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
Expand Down
14 changes: 0 additions & 14 deletions coderd/rbac/authz.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,6 @@ type PreparedAuthorized interface {
Compile() (AuthorizeFilter, error)
}

func (a *RegoAuthorizer) SQLFilter(ctx context.Context, subjID string, subjRoles []string, scope Scope, action Action, objectType string) (AuthorizeFilter, error) {
prepared, err := a.PrepareByRoleName(ctx, subjID, subjRoles, scope, action, objectType)
if err != nil {
return nil, xerrors.Errorf("filter: %w", err)
}

filter, err := prepared.Compile()
if err != nil {
return nil, xerrors.Errorf("filter: %w", err)
}

return filter, nil
}

// Filter takes in a list of objects, and will filter the list removing all
// the elements the subject does not have permission for. All objects must be
// of the same type.
Expand Down
2 changes: 2 additions & 0 deletions coderd/rbac/authz_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,8 @@ func testAuthorize(t *testing.T, name string, subject subject, sets ...[]authTes
partialAuthz, err := authorizer.Prepare(ctx, subject.UserID, subject.Roles, subject.Scope, a, c.resource.Type)
require.NoError(t, err, "make prepared authorizer")

// Ensure the partial can compile to a SQL clause.
// This does not guarantee that the clause is valid SQL.
_, err = Compile(partialAuthz.partialQueries)
require.NoError(t, err, "compile prepared authorizer")

Expand Down
50 changes: 41 additions & 9 deletions coderd/rbac/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,16 @@ import (

type SQLConfig struct {
// VariableRenames renames rego variables to sql columns
// Example:
// "input.object.org_owner": "organization_id::text"
// "input.object.owner": "owner_id::text"
VariableRenames map[string]string
}

type AuthorizeFilter interface {
// RegoString is used in debugging to see the original rego expression.
RegoString() string
// SQLString returns the SQL expression that can be used in a WHERE clause.
SQLString(cfg SQLConfig) string
// Eval is required for the fake in memory database to work. The in memory
// database can use this function to filter the results.
Expand Down Expand Up @@ -92,7 +97,18 @@ func processQuery(query ast.Body) (Expression, error) {

func processExpression(expr *ast.Expr) (Expression, error) {
if !expr.IsCall() {
return nil, xerrors.Errorf("invalid expression: function calls not supported")
// This could be a single term that is a valid expression.
if term, ok := expr.Terms.(*ast.Term); ok {
value, err := processTerm(term)
if err != nil {
return nil, xerrors.Errorf("single term expression: %w", err)
}
if boolExp, ok := value.(Expression); ok {
return boolExp, nil
}
// Default to error.
}
return nil, xerrors.Errorf("invalid expression: single non-boolean terms not supported")
}

op := expr.Operator().String()
Expand Down Expand Up @@ -142,6 +158,11 @@ func processTerms(expected int, terms []*ast.Term) ([]Term, error) {
func processTerm(term *ast.Term) (Term, error) {
base := base{Rego: term.String()}
switch v := term.Value.(type) {
case ast.Boolean:
return &termBoolean{
base: base,
Value: bool(v),
}, nil
case ast.Ref:
// A ref is a set of terms. If the first term is a var, then the
// following terms are the path to the value.
Expand Down Expand Up @@ -210,6 +231,10 @@ type expAnd struct {
}

func (t expAnd) SQLString(cfg SQLConfig) string {
if len(t.Expressions) == 1 {
return t.Expressions[0].SQLString(cfg)
}

exprs := make([]string, 0, len(t.Expressions))
for _, expr := range t.Expressions {
exprs = append(exprs, expr.SQLString(cfg))
Expand All @@ -232,11 +257,14 @@ type expOr struct {
}

func (t expOr) SQLString(cfg SQLConfig) string {
if len(t.Expressions) == 1 {
return t.Expressions[0].SQLString(cfg)
}

exprs := make([]string, 0, len(t.Expressions))
for _, expr := range t.Expressions {
exprs = append(exprs, expr.SQLString(cfg))
}

return "(" + strings.Join(exprs, " OR ") + ")"
}

Expand Down Expand Up @@ -273,7 +301,7 @@ func (t opEqual) SQLString(cfg SQLConfig) string {
}

func (t opEqual) Eval(object Object) bool {
a, b := t.Terms[0].Eval(object), t.Terms[1].Eval(object)
a, b := t.Terms[0].EvalTerm(object), t.Terms[1].EvalTerm(object)
if t.Not {
return a != b
}
Expand All @@ -288,7 +316,7 @@ type opInternalMember2 struct {
}

func (t opInternalMember2) Eval(object Object) bool {
a, b := t.Terms[0].Eval(object), t.Terms[1].Eval(object)
a, b := t.Terms[0].EvalTerm(object), t.Terms[1].EvalTerm(object)
bset, ok := b.([]interface{})
if !ok {
return false
Expand All @@ -314,15 +342,15 @@ type Term interface {
SQLString(cfg SQLConfig) string
// Eval will evaluate the term
// Terms can eval to any type. The operator/expression will type check.
Eval(object Object) interface{}
EvalTerm(object Object) interface{}
}

type termString struct {
base
Value string
}

func (t termString) Eval(_ Object) interface{} {
func (t termString) EvalTerm(_ Object) interface{} {
return t.Value
}

Expand All @@ -335,7 +363,7 @@ type termVariable struct {
Name string
}

func (t termVariable) Eval(obj Object) interface{} {
func (t termVariable) EvalTerm(obj Object) interface{} {
switch t.Name {
case "input.object.org_owner":
return obj.OrgID
Expand All @@ -362,10 +390,10 @@ type termSet struct {
Value []Term
}

func (t termSet) Eval(obj Object) interface{} {
func (t termSet) EvalTerm(obj Object) interface{} {
set := make([]interface{}, 0, len(t.Value))
for _, term := range t.Value {
set = append(set, term.Eval(obj))
set = append(set, term.EvalTerm(obj))
}

return set
Expand All @@ -389,6 +417,10 @@ func (t termBoolean) Eval(_ Object) bool {
return t.Value
}

func (t termBoolean) EvalTerm(_ Object) interface{} {
return t.Value
}

func (t termBoolean) SQLString(_ SQLConfig) string {
return strconv.FormatBool(t.Value)
}
Expand Down
60 changes: 28 additions & 32 deletions coderd/rbac/query_internal_test.go
Original file line number Diff line number Diff line change
@@ -1,42 +1,38 @@
package rbac

import (
"context"
"fmt"
"testing"

"github.com/stretchr/testify/require"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"

"github.com/google/uuid"
"github.com/stretchr/testify/require"
)

func TestCompileQuery(t *testing.T) {
ctx := context.Background()
defOrg := uuid.New()
unuseID := uuid.New()

user := subject{
UserID: "me",
Scope: must(ScopeRole(ScopeAll)),
Roles: []Role{
must(RoleByName(RoleMember())),
must(RoleByName(RoleOrgMember(defOrg))),
},
}
var action Action = ActionRead
object := ResourceWorkspace.InOrg(defOrg).WithOwner(unuseID.String())

auth := NewAuthorizer()
part, err := auth.Prepare(ctx, user.UserID, user.Roles, user.Scope, action, object.Type)
require.NoError(t, err)

result, err := Compile(part.partialQueries)
require.NoError(t, err)

fmt.Println("Rego: ", result.RegoString())
fmt.Println("SQL: ", result.SQLString(SQLConfig{
map[string]string{
"input.object.org_owner": "organization_id",
},
}))
t.Run("EmptyQuery", func(t *testing.T) {
expression, err := Compile(&rego.PartialQueries{
Queries: []ast.Body{
must(ast.ParseBody("")),
},
Support: []*ast.Module{},
})
require.NoError(t, err, "compile empty")

require.Equal(t, "true", expression.RegoString(), "empty query is rego 'true'")
require.Equal(t, "true", expression.SQLString(SQLConfig{}), "empty query is sql 'true'")
})

t.Run("TrueQuery", func(t *testing.T) {
expression, err := Compile(&rego.PartialQueries{
Queries: []ast.Body{
must(ast.ParseBody("true")),
},
Support: []*ast.Module{},
})
require.NoError(t, err, "compile empty")

require.Equal(t, "true", expression.RegoString(), "true query is rego 'true'")
require.Equal(t, "true", expression.SQLString(SQLConfig{}), "true query is sql 'true'")
})
}