From ff931dcc2155bab9066594554125fc688064a94a Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 27 Sep 2022 13:25:06 -0400 Subject: [PATCH 01/18] feat: Convert rego queries into SQL clauses --- coderd/rbac/query.go | 269 +++++++++++++++++++++++++++++ coderd/rbac/query_internal_test.go | 38 ++++ 2 files changed, 307 insertions(+) create mode 100644 coderd/rbac/query.go create mode 100644 coderd/rbac/query_internal_test.go diff --git a/coderd/rbac/query.go b/coderd/rbac/query.go new file mode 100644 index 0000000000000..1dab38ecea429 --- /dev/null +++ b/coderd/rbac/query.go @@ -0,0 +1,269 @@ +package rbac + +import ( + "context" + "fmt" + "strings" + + "github.com/open-policy-agent/opa/ast" + + "golang.org/x/xerrors" + + "github.com/open-policy-agent/opa/rego" +) + +// Example python: https://github.com/open-policy-agent/contrib/tree/main/data_filter_example +// + +func Compile(ctx context.Context, partialQueries *rego.PartialQueries) (Expression, error) { + if len(partialQueries.Support) > 0 { + return nil, xerrors.Errorf("cannot convert support rules, expect 0 found %d", len(partialQueries.Support)) + } + + result := make([]Expression, 0, len(partialQueries.Queries)) + var builder strings.Builder + for i := range partialQueries.Queries { + query, err := processQuery(partialQueries.Queries[i]) + if err != nil { + return nil, err + } + result = append(result, query) + if i != 0 { + builder.WriteString("\n") + } + builder.WriteString(partialQueries.Queries[i].String()) + } + return ExpOr{ + Base: Base{ + Rego: builder.String(), + }, + Expressions: result, + }, nil +} + +func processQuery(query ast.Body) (Expression, error) { + expressions := make([]Expression, 0, len(query)) + for _, astExpr := range query { + expr, err := processExpression(astExpr) + if err != nil { + return nil, err + } + expressions = append(expressions, expr) + } + + return ExpAnd{ + Base: Base{ + Rego: query.String(), + }, + Expressions: expressions, + }, nil +} + +func processExpression(expr *ast.Expr) (Expression, error) { + if !expr.IsCall() { + return nil, xerrors.Errorf("invalid expression: function calls not supported") + } + + op := expr.Operator().String() + base := Base{Rego: op} + switch op { + case "neq", "eq", "equal": + terms, err := processTerms(2, expr.Operands()) + if err != nil { + return nil, xerrors.Errorf("invalid '%s' expression: %w", op, err) + } + return &OpEqual{ + Base: base, + Terms: [2]Term{terms[0], terms[1]}, + Not: op == "neq", + }, nil + case "internal.member_2": + terms, err := processTerms(2, expr.Operands()) + if err != nil { + return nil, xerrors.Errorf("invalid '%s' expression: %w", op, err) + } + return &OpInternalMember2{ + Base: base, + Terms: [2]Term{terms[0], terms[1]}, + }, nil + + //case "eq", "equal": + default: + return nil, xerrors.Errorf("invalid expression: operator %s not supported", op) + } +} + +func processTerms(expected int, terms []*ast.Term) ([]Term, error) { + if len(terms) != expected { + return nil, xerrors.Errorf("too many arguments, expect %d found %d", expected, len(terms)) + } + + result := make([]Term, 0, len(terms)) + for _, term := range terms { + processed, err := processTerm(term) + if err != nil { + return nil, xerrors.Errorf("invalid term: %w", err) + } + result = append(result, processed) + } + + return result, nil +} + +func processTerm(term *ast.Term) (Term, error) { + base := Base{Rego: term.String()} + switch v := term.Value.(type) { + case ast.Ref: + // A ref is a set of terms. If the first term is a var, then the + // following terms are the path to the value. + if v0, ok := v[0].Value.(ast.Var); ok { + name := v0.String() + for _, p := range v[1:] { + name += "." + p.String() + } + return &TermVariable{ + Base: base, + Name: name, + }, nil + } else { + return nil, xerrors.Errorf("invalid term: ref must start with a var, started with %T", v[0]) + } + case ast.Var: + return &TermVariable{ + Name: v.String(), + Base: base, + }, nil + case ast.String: + return &TermString{ + Value: v.String(), + Base: base, + }, nil + case ast.Set: + return &TermSet{ + Value: v, + Base: base, + }, nil + default: + return nil, xerrors.Errorf("invalid term: %T not supported, %q", v, term.String()) + } +} + +type Base struct { + // Rego is the original rego string + Rego string +} + +func (b Base) RegoString() string { + return b.Rego +} + +// Expression comprises a set of terms, operators, and functions. All +// expressions return a boolean value. +// +// Eg: neq(input.object.org_owner, "") +type Expression interface { + RegoString() string + SQLString() string +} + +type ExpAnd struct { + Base + Expressions []Expression +} + +func (t ExpAnd) SQLString() string { + exprs := make([]string, 0, len(t.Expressions)) + for _, expr := range t.Expressions { + exprs = append(exprs, expr.SQLString()) + } + return strings.Join(exprs, " AND ") +} + +type ExpOr struct { + Base + Expressions []Expression +} + +func (t ExpOr) SQLString() string { + exprs := make([]string, 0, len(t.Expressions)) + for _, expr := range t.Expressions { + exprs = append(exprs, expr.SQLString()) + } + return strings.Join(exprs, " OR ") +} + +// Operator joins terms together to form an expression. +// Operators are also expressions. +// +// Eg: "=", "neq", "internal.member_2", etc. +type Operator interface { + RegoString() string + SQLString() string +} + +type OpEqual struct { + Base + Terms [2]Term + // For NotEqual + Not bool +} + +func (t OpEqual) SQLString() string { + op := "=" + if t.Not { + op = "!=" + } + return fmt.Sprintf("%s %s %s", t.Terms[0].SQLString(), op, t.Terms[1].SQLString()) +} + +type OpInternalMember2 struct { + Base + Terms [2]Term +} + +func (t OpInternalMember2) SQLString() string { + return fmt.Sprintf("%s = ANY(%s)", t.Terms[0].SQLString(), t.Terms[1].SQLString()) +} + +// Term is a single value in an expression. Terms can be variables or constants. +// +// Eg: "f9d6fb75-b59b-4363-ab6b-ae9d26b679d7", "input.object.org_owner", +// "{"f9d6fb75-b59b-4363-ab6b-ae9d26b679d7"}" +type Term interface { + SQLString() string + RegoString() string +} + +type TermString struct { + Base + Value string +} + +func (t TermString) SQLString() string { + return t.Value +} + +type TermVariable struct { + Base + Name string +} + +func (t TermVariable) SQLString() string { + return t.Name +} + +type TermSet struct { + Base + Value ast.Set +} + +func (t TermSet) SQLString() string { + values := t.Value.Slice() + elems := make([]string, 0, len(values)) + // TODO: Handle different typed terms? + for _, v := range t.Value.Slice() { + elems = append(elems, v.String()) + } + + return fmt.Sprintf("ARRAY [%s]", strings.Join(elems, ",")) +} diff --git a/coderd/rbac/query_internal_test.go b/coderd/rbac/query_internal_test.go new file mode 100644 index 0000000000000..c24a48d91ed17 --- /dev/null +++ b/coderd/rbac/query_internal_test.go @@ -0,0 +1,38 @@ +package rbac + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/google/uuid" +) + +func TestCompileQuery(t *testing.T) { + ctx := context.Background() + defOrg := uuid.New() + unuseID := uuid.New() + + user := subject{ + UserID: "me", + Scope: must(ScopeRole(ScopeAll)), + Roles: []Role{ + must(RoleByName(RoleMember())), + must(RoleByName(RoleOrgMember(defOrg))), + }, + } + var action Action = ActionRead + object := ResourceWorkspace.InOrg(defOrg).WithOwner(unuseID.String()) + + auth := NewAuthorizer() + part, err := auth.Prepare(ctx, user.UserID, user.Roles, user.Scope, action, object.Type) + require.NoError(t, err) + + result, err := Compile(ctx, part.partialQueries) + require.NoError(t, err) + + fmt.Println("Rego: ", result.RegoString()) + fmt.Println("SQL: ", result.SQLString()) +} From e535e5a2a6254b8515029453773b95eef9fa5156 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 27 Sep 2022 13:37:52 -0400 Subject: [PATCH 02/18] Fix postgres quotes to single quotes --- coderd/rbac/query.go | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/coderd/rbac/query.go b/coderd/rbac/query.go index 1dab38ecea429..04f0290e0159e 100644 --- a/coderd/rbac/query.go +++ b/coderd/rbac/query.go @@ -117,9 +117,9 @@ func processTerm(term *ast.Term) (Term, error) { // A ref is a set of terms. If the first term is a var, then the // following terms are the path to the value. if v0, ok := v[0].Value.(ast.Var); ok { - name := v0.String() + name := trimQuotes(v0.String()) for _, p := range v[1:] { - name += "." + p.String() + name += "." + trimQuotes(p.String()) } return &TermVariable{ Base: base, @@ -130,12 +130,12 @@ func processTerm(term *ast.Term) (Term, error) { } case ast.Var: return &TermVariable{ - Name: v.String(), + Name: trimQuotes(v.String()), Base: base, }, nil case ast.String: return &TermString{ - Value: v.String(), + Value: trimQuotes(v.String()), Base: base, }, nil case ast.Set: @@ -176,7 +176,7 @@ func (t ExpAnd) SQLString() string { for _, expr := range t.Expressions { exprs = append(exprs, expr.SQLString()) } - return strings.Join(exprs, " AND ") + return "(" + strings.Join(exprs, " AND ") + ")" } type ExpOr struct { @@ -189,7 +189,8 @@ func (t ExpOr) SQLString() string { for _, expr := range t.Expressions { exprs = append(exprs, expr.SQLString()) } - return strings.Join(exprs, " OR ") + + return "(" + strings.Join(exprs, " OR ") + ")" } // Operator joins terms together to form an expression. @@ -240,7 +241,7 @@ type TermString struct { } func (t TermString) SQLString() string { - return t.Value + return "'" + t.Value + "'" } type TermVariable struct { @@ -262,8 +263,16 @@ func (t TermSet) SQLString() string { elems := make([]string, 0, len(values)) // TODO: Handle different typed terms? for _, v := range t.Value.Slice() { - elems = append(elems, v.String()) + t, err := processTerm(v) + if err != nil { + panic(err) + } + elems = append(elems, t.SQLString()) } return fmt.Sprintf("ARRAY [%s]", strings.Join(elems, ",")) } + +func trimQuotes(s string) string { + return strings.Trim(s, "\"") +} From 8f9295316c131fd1680b90226599aa9be06010aa Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 27 Sep 2022 13:52:17 -0400 Subject: [PATCH 03/18] Ensure all test cases can compile into SQL clauses --- coderd/rbac/authz_internal_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/coderd/rbac/authz_internal_test.go b/coderd/rbac/authz_internal_test.go index bf5342c3b27e4..67049e350f8db 100644 --- a/coderd/rbac/authz_internal_test.go +++ b/coderd/rbac/authz_internal_test.go @@ -781,6 +781,9 @@ func testAuthorize(t *testing.T, name string, subject subject, sets ...[]authTes partialAuthz, err := authorizer.Prepare(ctx, subject.UserID, subject.Roles, subject.Scope, a, c.resource.Type) require.NoError(t, err, "make prepared authorizer") + _, err = Compile(ctx, partialAuthz.partialQueries) + require.NoError(t, err, "compile prepared authorizer") + // Also check the rego policy can form a valid partial query result. // This ensures we can convert the queries into SQL WHERE clauses in the future. // If this function returns 'Support' sections, then we cannot convert the query into SQL. From cb5d5198bf02d569a2a69616adda99438dcd219f Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 27 Sep 2022 14:53:53 -0400 Subject: [PATCH 04/18] Do not export extra types --- coderd/rbac/query.go | 132 ++++++++++++++++++++++++++----------------- 1 file changed, 80 insertions(+), 52 deletions(-) diff --git a/coderd/rbac/query.go b/coderd/rbac/query.go index 04f0290e0159e..4c3388aa7e80a 100644 --- a/coderd/rbac/query.go +++ b/coderd/rbac/query.go @@ -1,25 +1,41 @@ package rbac import ( - "context" "fmt" + "strconv" "strings" "github.com/open-policy-agent/opa/ast" - - "golang.org/x/xerrors" - "github.com/open-policy-agent/opa/rego" + "golang.org/x/xerrors" ) -// Example python: https://github.com/open-policy-agent/contrib/tree/main/data_filter_example -// - -func Compile(ctx context.Context, partialQueries *rego.PartialQueries) (Expression, error) { +// Compile will convert a rego query AST into our custom types. The output is +// an AST that can be used to generate SQL. +func Compile(partialQueries *rego.PartialQueries) (Expression, error) { if len(partialQueries.Support) > 0 { return nil, xerrors.Errorf("cannot convert support rules, expect 0 found %d", len(partialQueries.Support)) } + // 0 queries means the result is "undefined". This is the same as "false". + if len(partialQueries.Queries) == 0 { + return &termBoolean{ + base: base{Rego: "false"}, + Value: false, + }, nil + } + + // Abort early if any of the "OR"'d expressions are the empty string. + // This is the same as "true". + for _, query := range partialQueries.Queries { + if query.String() == "" { + return &termBoolean{ + base: base{Rego: "true"}, + Value: true, + }, nil + } + } + result := make([]Expression, 0, len(partialQueries.Queries)) var builder strings.Builder for i := range partialQueries.Queries { @@ -33,14 +49,16 @@ func Compile(ctx context.Context, partialQueries *rego.PartialQueries) (Expressi } builder.WriteString(partialQueries.Queries[i].String()) } - return ExpOr{ - Base: Base{ + return expOr{ + base: base{ Rego: builder.String(), }, Expressions: result, }, nil } +// processQuery processes an entire set of expressions and joins them with +// "AND". func processQuery(query ast.Body) (Expression, error) { expressions := make([]Expression, 0, len(query)) for _, astExpr := range query { @@ -51,8 +69,8 @@ func processQuery(query ast.Body) (Expression, error) { expressions = append(expressions, expr) } - return ExpAnd{ - Base: Base{ + return expAnd{ + base: base{ Rego: query.String(), }, Expressions: expressions, @@ -65,15 +83,15 @@ func processExpression(expr *ast.Expr) (Expression, error) { } op := expr.Operator().String() - base := Base{Rego: op} + base := base{Rego: op} switch op { case "neq", "eq", "equal": terms, err := processTerms(2, expr.Operands()) if err != nil { return nil, xerrors.Errorf("invalid '%s' expression: %w", op, err) } - return &OpEqual{ - Base: base, + return &opEqual{ + base: base, Terms: [2]Term{terms[0], terms[1]}, Not: op == "neq", }, nil @@ -82,12 +100,10 @@ func processExpression(expr *ast.Expr) (Expression, error) { if err != nil { return nil, xerrors.Errorf("invalid '%s' expression: %w", op, err) } - return &OpInternalMember2{ - Base: base, + return &opInternalMember2{ + base: base, Terms: [2]Term{terms[0], terms[1]}, }, nil - - //case "eq", "equal": default: return nil, xerrors.Errorf("invalid expression: operator %s not supported", op) } @@ -111,7 +127,7 @@ func processTerms(expected int, terms []*ast.Term) ([]Term, error) { } func processTerm(term *ast.Term) (Term, error) { - base := Base{Rego: term.String()} + base := base{Rego: term.String()} switch v := term.Value.(type) { case ast.Ref: // A ref is a set of terms. If the first term is a var, then the @@ -121,57 +137,57 @@ func processTerm(term *ast.Term) (Term, error) { for _, p := range v[1:] { name += "." + trimQuotes(p.String()) } - return &TermVariable{ - Base: base, + return &termVariable{ + base: base, Name: name, }, nil } else { return nil, xerrors.Errorf("invalid term: ref must start with a var, started with %T", v[0]) } case ast.Var: - return &TermVariable{ + return &termVariable{ Name: trimQuotes(v.String()), - Base: base, + base: base, }, nil case ast.String: - return &TermString{ + return &termString{ Value: trimQuotes(v.String()), - Base: base, + base: base, }, nil case ast.Set: - return &TermSet{ + return &termSet{ Value: v, - Base: base, + base: base, }, nil default: return nil, xerrors.Errorf("invalid term: %T not supported, %q", v, term.String()) } } -type Base struct { +type base struct { // Rego is the original rego string Rego string } -func (b Base) RegoString() string { +func (b base) RegoString() string { return b.Rego } // Expression comprises a set of terms, operators, and functions. All // expressions return a boolean value. // -// Eg: neq(input.object.org_owner, "") +// Eg: neq(input.object.org_owner, "") AND input.object.org_owner == "foo" type Expression interface { RegoString() string SQLString() string } -type ExpAnd struct { - Base +type expAnd struct { + base Expressions []Expression } -func (t ExpAnd) SQLString() string { +func (t expAnd) SQLString() string { exprs := make([]string, 0, len(t.Expressions)) for _, expr := range t.Expressions { exprs = append(exprs, expr.SQLString()) @@ -179,12 +195,12 @@ func (t ExpAnd) SQLString() string { return "(" + strings.Join(exprs, " AND ") + ")" } -type ExpOr struct { - Base +type expOr struct { + base Expressions []Expression } -func (t ExpOr) SQLString() string { +func (t expOr) SQLString() string { exprs := make([]string, 0, len(t.Expressions)) for _, expr := range t.Expressions { exprs = append(exprs, expr.SQLString()) @@ -202,14 +218,14 @@ type Operator interface { SQLString() string } -type OpEqual struct { - Base +type opEqual struct { + base Terms [2]Term // For NotEqual Not bool } -func (t OpEqual) SQLString() string { +func (t opEqual) SQLString() string { op := "=" if t.Not { op = "!=" @@ -217,12 +233,14 @@ func (t OpEqual) SQLString() string { return fmt.Sprintf("%s %s %s", t.Terms[0].SQLString(), op, t.Terms[1].SQLString()) } -type OpInternalMember2 struct { - Base +// opInternalMember2 is checking if the first term is a member of the second term. +// The second term is a set or list. +type opInternalMember2 struct { + base Terms [2]Term } -func (t OpInternalMember2) SQLString() string { +func (t opInternalMember2) SQLString() string { return fmt.Sprintf("%s = ANY(%s)", t.Terms[0].SQLString(), t.Terms[1].SQLString()) } @@ -235,30 +253,31 @@ type Term interface { RegoString() string } -type TermString struct { - Base +type termString struct { + base Value string } -func (t TermString) SQLString() string { +func (t termString) SQLString() string { return "'" + t.Value + "'" } -type TermVariable struct { - Base +type termVariable struct { + base Name string } -func (t TermVariable) SQLString() string { +func (t termVariable) SQLString() string { return t.Name } -type TermSet struct { - Base +// termSet is a set of unique terms. +type termSet struct { + base Value ast.Set } -func (t TermSet) SQLString() string { +func (t termSet) SQLString() string { values := t.Value.Slice() elems := make([]string, 0, len(values)) // TODO: Handle different typed terms? @@ -273,6 +292,15 @@ func (t TermSet) SQLString() string { return fmt.Sprintf("ARRAY [%s]", strings.Join(elems, ",")) } +type termBoolean struct { + base + Value bool +} + +func (t termBoolean) SQLString() string { + return strconv.FormatBool(t.Value) +} + func trimQuotes(s string) string { return strings.Trim(s, "\"") } From 2bd01658a4c9706556bb5bb1a398cadb1deffdbe Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 27 Sep 2022 16:31:25 -0400 Subject: [PATCH 05/18] Add custom query with rbac filter --- coderd/database/custom_queries.go | 59 ++++++++++++++++++++ coderd/database/databasefake/databasefake.go | 10 ++++ coderd/database/db.go | 7 +++ coderd/rbac/authz_internal_test.go | 2 +- coderd/rbac/query.go | 49 +++++++++------- coderd/rbac/query_internal_test.go | 8 ++- 6 files changed, 113 insertions(+), 22 deletions(-) create mode 100644 coderd/database/custom_queries.go diff --git a/coderd/database/custom_queries.go b/coderd/database/custom_queries.go new file mode 100644 index 0000000000000..6dda9021e3317 --- /dev/null +++ b/coderd/database/custom_queries.go @@ -0,0 +1,59 @@ +package database + +import ( + "context" + "fmt" + + "github.com/coder/coder/coderd/rbac" + + "github.com/lib/pq" +) + +// AuthorizedGetWorkspaces returns all workspaces that the user is authorized to access. +func (q *sqlQuerier) AuthorizedGetWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error) { + query := fmt.Sprintf("%s AND %s", getWorkspaces, authorizedFilter.SQLString(rbac.SQLConfig{ + VariableRenames: map[string]string{ + "input.object.org_owner": "organization_id", + "input.object.owner": "owner_id", + }, + })) + rows, err := q.db.QueryContext(ctx, query, + arg.Deleted, + arg.OwnerID, + arg.OwnerUsername, + arg.TemplateName, + pq.Array(arg.TemplateIds), + arg.Name, + ) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Workspace + for rows.Next() { + var i Workspace + if err := rows.Scan( + &i.ID, + &i.CreatedAt, + &i.UpdatedAt, + &i.OwnerID, + &i.OrganizationID, + &i.TemplateID, + &i.Deleted, + &i.Name, + &i.AutostartSchedule, + &i.Ttl, + &i.LastUsedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index f0cdee99f254b..d7fa4cd248f4b 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -566,6 +566,16 @@ func (q *fakeQuerier) GetWorkspaces(_ context.Context, arg database.GetWorkspace return workspaces, nil } +func (q *fakeQuerier) AuthorizedGetWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]database.Workspace, error) { + workspaces, err := q.GetWorkspaces(ctx, arg) + if err != nil { + return nil, err + } + + // TODO: Filter workspaces + return workspaces, nil +} + func (q *fakeQuerier) GetWorkspaceByID(_ context.Context, id uuid.UUID) (database.Workspace, error) { q.mutex.RLock() defer q.mutex.RUnlock() diff --git a/coderd/database/db.go b/coderd/database/db.go index 0a9e8928df253..80a5748de7263 100644 --- a/coderd/database/db.go +++ b/coderd/database/db.go @@ -13,6 +13,8 @@ import ( "database/sql" "errors" + "github.com/coder/coder/coderd/rbac" + "golang.org/x/xerrors" ) @@ -20,10 +22,15 @@ import ( // It extends the generated interface to add transaction support. type Store interface { querier + customQuerier InTx(func(Store) error) error } +type customQuerier interface { + AuthorizedGetWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error) +} + // DBTX represents a database connection or transaction. type DBTX interface { ExecContext(context.Context, string, ...interface{}) (sql.Result, error) diff --git a/coderd/rbac/authz_internal_test.go b/coderd/rbac/authz_internal_test.go index 67049e350f8db..7a646de754ab5 100644 --- a/coderd/rbac/authz_internal_test.go +++ b/coderd/rbac/authz_internal_test.go @@ -781,7 +781,7 @@ func testAuthorize(t *testing.T, name string, subject subject, sets ...[]authTes partialAuthz, err := authorizer.Prepare(ctx, subject.UserID, subject.Roles, subject.Scope, a, c.resource.Type) require.NoError(t, err, "make prepared authorizer") - _, err = Compile(ctx, partialAuthz.partialQueries) + _, err = Compile(partialAuthz.partialQueries) require.NoError(t, err, "compile prepared authorizer") // Also check the rego policy can form a valid partial query result. diff --git a/coderd/rbac/query.go b/coderd/rbac/query.go index 4c3388aa7e80a..e715160e5587c 100644 --- a/coderd/rbac/query.go +++ b/coderd/rbac/query.go @@ -10,6 +10,16 @@ import ( "golang.org/x/xerrors" ) +type SQLConfig struct { + // VariableRenames renames rego variables to sql columns + VariableRenames map[string]string +} + +type AuthorizeFilter interface { + RegoString() string + SQLString(cfg SQLConfig) string +} + // Compile will convert a rego query AST into our custom types. The output is // an AST that can be used to generate SQL. func Compile(partialQueries *rego.PartialQueries) (Expression, error) { @@ -178,8 +188,7 @@ func (b base) RegoString() string { // // Eg: neq(input.object.org_owner, "") AND input.object.org_owner == "foo" type Expression interface { - RegoString() string - SQLString() string + AuthorizeFilter } type expAnd struct { @@ -187,10 +196,10 @@ type expAnd struct { Expressions []Expression } -func (t expAnd) SQLString() string { +func (t expAnd) SQLString(cfg SQLConfig) string { exprs := make([]string, 0, len(t.Expressions)) for _, expr := range t.Expressions { - exprs = append(exprs, expr.SQLString()) + exprs = append(exprs, expr.SQLString(cfg)) } return "(" + strings.Join(exprs, " AND ") + ")" } @@ -200,10 +209,10 @@ type expOr struct { Expressions []Expression } -func (t expOr) SQLString() string { +func (t expOr) SQLString(cfg SQLConfig) string { exprs := make([]string, 0, len(t.Expressions)) for _, expr := range t.Expressions { - exprs = append(exprs, expr.SQLString()) + exprs = append(exprs, expr.SQLString(cfg)) } return "(" + strings.Join(exprs, " OR ") + ")" @@ -214,8 +223,7 @@ func (t expOr) SQLString() string { // // Eg: "=", "neq", "internal.member_2", etc. type Operator interface { - RegoString() string - SQLString() string + AuthorizeFilter } type opEqual struct { @@ -225,12 +233,12 @@ type opEqual struct { Not bool } -func (t opEqual) SQLString() string { +func (t opEqual) SQLString(cfg SQLConfig) string { op := "=" if t.Not { op = "!=" } - return fmt.Sprintf("%s %s %s", t.Terms[0].SQLString(), op, t.Terms[1].SQLString()) + return fmt.Sprintf("%s %s %s", t.Terms[0].SQLString(cfg), op, t.Terms[1].SQLString(cfg)) } // opInternalMember2 is checking if the first term is a member of the second term. @@ -240,8 +248,8 @@ type opInternalMember2 struct { Terms [2]Term } -func (t opInternalMember2) SQLString() string { - return fmt.Sprintf("%s = ANY(%s)", t.Terms[0].SQLString(), t.Terms[1].SQLString()) +func (t opInternalMember2) SQLString(cfg SQLConfig) string { + return fmt.Sprintf("%s = ANY(%s)", t.Terms[0].SQLString(cfg), t.Terms[1].SQLString(cfg)) } // Term is a single value in an expression. Terms can be variables or constants. @@ -249,8 +257,7 @@ func (t opInternalMember2) SQLString() string { // Eg: "f9d6fb75-b59b-4363-ab6b-ae9d26b679d7", "input.object.org_owner", // "{"f9d6fb75-b59b-4363-ab6b-ae9d26b679d7"}" type Term interface { - SQLString() string - RegoString() string + AuthorizeFilter } type termString struct { @@ -258,7 +265,7 @@ type termString struct { Value string } -func (t termString) SQLString() string { +func (t termString) SQLString(_ SQLConfig) string { return "'" + t.Value + "'" } @@ -267,7 +274,11 @@ type termVariable struct { Name string } -func (t termVariable) SQLString() string { +func (t termVariable) SQLString(cfg SQLConfig) string { + rename, ok := cfg.VariableRenames[t.Name] + if ok { + return rename + } return t.Name } @@ -277,7 +288,7 @@ type termSet struct { Value ast.Set } -func (t termSet) SQLString() string { +func (t termSet) SQLString(cfg SQLConfig) string { values := t.Value.Slice() elems := make([]string, 0, len(values)) // TODO: Handle different typed terms? @@ -286,7 +297,7 @@ func (t termSet) SQLString() string { if err != nil { panic(err) } - elems = append(elems, t.SQLString()) + elems = append(elems, t.SQLString(cfg)) } return fmt.Sprintf("ARRAY [%s]", strings.Join(elems, ",")) @@ -297,7 +308,7 @@ type termBoolean struct { Value bool } -func (t termBoolean) SQLString() string { +func (t termBoolean) SQLString(_ SQLConfig) string { return strconv.FormatBool(t.Value) } diff --git a/coderd/rbac/query_internal_test.go b/coderd/rbac/query_internal_test.go index c24a48d91ed17..67964e754f708 100644 --- a/coderd/rbac/query_internal_test.go +++ b/coderd/rbac/query_internal_test.go @@ -30,9 +30,13 @@ func TestCompileQuery(t *testing.T) { part, err := auth.Prepare(ctx, user.UserID, user.Roles, user.Scope, action, object.Type) require.NoError(t, err) - result, err := Compile(ctx, part.partialQueries) + result, err := Compile(part.partialQueries) require.NoError(t, err) fmt.Println("Rego: ", result.RegoString()) - fmt.Println("SQL: ", result.SQLString()) + fmt.Println("SQL: ", result.SQLString(SQLConfig{ + map[string]string{ + "input.object.org_owner": "organization_id", + }, + })) } From 364498c4cc8d5ffa4ec327e479cd6ff2b0cc07ed Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 27 Sep 2022 21:07:59 -0400 Subject: [PATCH 06/18] First draft of a custom authorized db call --- coderd/authorize.go | 15 +++ coderd/coderdtest/authorize.go | 18 ++++ coderd/database/custom_queries.go | 8 +- coderd/database/databasefake/databasefake.go | 21 ++-- coderd/rbac/authz.go | 15 +++ coderd/rbac/partial.go | 8 ++ coderd/rbac/query.go | 107 ++++++++++++++++--- coderd/workspaces.go | 7 +- 8 files changed, 168 insertions(+), 31 deletions(-) diff --git a/coderd/authorize.go b/coderd/authorize.go index 0a6953cb1231e..7ed0e404612d1 100644 --- a/coderd/authorize.go +++ b/coderd/authorize.go @@ -85,6 +85,21 @@ func (h *HTTPAuthorizer) Authorize(r *http.Request, action rbac.Action, object r return true } +func (a *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action rbac.Action, objectType string) (rbac.AuthorizeFilter, error) { + roles := httpmw.UserAuthorization(r) + prepared, err := a.Authorizer.PrepareByRoleName(r.Context(), roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), action, objectType) + if err != nil { + return nil, xerrors.Errorf("prepare filter: %w", err) + } + + filter, err := prepared.Compile() + if err != nil { + return nil, xerrors.Errorf("compile filter: %w", err) + } + + return filter, nil +} + // checkAuthorization returns if the current API key can use the given // permissions, factoring in the current user's roles and the API key scopes. func (api *API) checkAuthorization(rw http.ResponseWriter, r *http.Request) { diff --git a/coderd/coderdtest/authorize.go b/coderd/coderdtest/authorize.go index 60fca71cd0062..9ee6313e66efc 100644 --- a/coderd/coderdtest/authorize.go +++ b/coderd/coderdtest/authorize.go @@ -553,3 +553,21 @@ type fakePreparedAuthorizer struct { func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Object) error { return f.Original.ByRoleName(ctx, f.SubjectID, f.Roles, f.Scope, f.Action, object) } + +// Compile returns a compiled version of the authorizer that will work for +// in memory databases. This fake version will not work against a SQL database. +func (f *fakePreparedAuthorizer) Compile() (rbac.AuthorizeFilter, error) { + return f, nil +} + +func (f *fakePreparedAuthorizer) Eval(object rbac.Object) bool { + return f.Original.ByRoleName(context.Background(), f.SubjectID, f.Roles, f.Scope, f.Action, object) == nil +} + +func (f *fakePreparedAuthorizer) RegoString() string { + panic("not implemented") +} + +func (f *fakePreparedAuthorizer) SQLString(_ rbac.SQLConfig) string { + panic("not implemented") +} diff --git a/coderd/database/custom_queries.go b/coderd/database/custom_queries.go index 6dda9021e3317..756f98bcae062 100644 --- a/coderd/database/custom_queries.go +++ b/coderd/database/custom_queries.go @@ -4,6 +4,8 @@ import ( "context" "fmt" + "golang.org/x/xerrors" + "github.com/coder/coder/coderd/rbac" "github.com/lib/pq" @@ -13,8 +15,8 @@ import ( func (q *sqlQuerier) AuthorizedGetWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error) { query := fmt.Sprintf("%s AND %s", getWorkspaces, authorizedFilter.SQLString(rbac.SQLConfig{ VariableRenames: map[string]string{ - "input.object.org_owner": "organization_id", - "input.object.owner": "owner_id", + "input.object.org_owner": "organization_id::text", + "input.object.owner": "owner_id::text", }, })) rows, err := q.db.QueryContext(ctx, query, @@ -26,7 +28,7 @@ func (q *sqlQuerier) AuthorizedGetWorkspaces(ctx context.Context, arg GetWorkspa arg.Name, ) if err != nil { - return nil, err + return nil, xerrors.Errorf("get authorized workspaces: %w", err) } defer rows.Close() var items []Workspace diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index d7fa4cd248f4b..1dcecf080dca5 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -520,7 +520,12 @@ func (q *fakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.U }, nil } -func (q *fakeQuerier) GetWorkspaces(_ context.Context, arg database.GetWorkspacesParams) ([]database.Workspace, error) { +func (q *fakeQuerier) GetWorkspaces(ctx context.Context, arg database.GetWorkspacesParams) ([]database.Workspace, error) { + workspaces, err := q.AuthorizedGetWorkspaces(ctx, arg, nil) + return workspaces, err +} + +func (q *fakeQuerier) AuthorizedGetWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]database.Workspace, error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -560,19 +565,13 @@ func (q *fakeQuerier) GetWorkspaces(_ context.Context, arg database.GetWorkspace continue } } - workspaces = append(workspaces, workspace) - } - - return workspaces, nil -} -func (q *fakeQuerier) AuthorizedGetWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]database.Workspace, error) { - workspaces, err := q.GetWorkspaces(ctx, arg) - if err != nil { - return nil, err + if authorizedFilter != nil && !authorizedFilter.Eval(workspace.RBACObject()) { + continue + } + workspaces = append(workspaces, workspace) } - // TODO: Filter workspaces return workspaces, nil } diff --git a/coderd/rbac/authz.go b/coderd/rbac/authz.go index 3bc37d056400a..526e028c7359d 100644 --- a/coderd/rbac/authz.go +++ b/coderd/rbac/authz.go @@ -21,6 +21,21 @@ type Authorizer interface { type PreparedAuthorized interface { Authorize(ctx context.Context, object Object) error + Compile() (AuthorizeFilter, error) +} + +func (a *RegoAuthorizer) SQLFilter(ctx context.Context, subjID string, subjRoles []string, scope Scope, action Action, objectType string) (AuthorizeFilter, error) { + prepared, err := a.PrepareByRoleName(ctx, subjID, subjRoles, scope, action, objectType) + if err != nil { + return nil, xerrors.Errorf("filter: %w", err) + } + + filter, err := prepared.Compile() + if err != nil { + return nil, xerrors.Errorf("filter: %w", err) + } + + return filter, nil } // Filter takes in a list of objects, and will filter the list removing all diff --git a/coderd/rbac/partial.go b/coderd/rbac/partial.go index 7fea09d374e17..6dfd0827f39f4 100644 --- a/coderd/rbac/partial.go +++ b/coderd/rbac/partial.go @@ -28,6 +28,14 @@ type PartialAuthorizer struct { var _ PreparedAuthorized = (*PartialAuthorizer)(nil) +func (pa *PartialAuthorizer) Compile() (AuthorizeFilter, error) { + filter, err := Compile(pa.partialQueries) + if err != nil { + return nil, xerrors.Errorf("compile: %w", err) + } + return filter, nil +} + func (pa *PartialAuthorizer) Authorize(ctx context.Context, object Object) error { if pa.alwaysTrue { return nil diff --git a/coderd/rbac/query.go b/coderd/rbac/query.go index e715160e5587c..9d6dccf109a8a 100644 --- a/coderd/rbac/query.go +++ b/coderd/rbac/query.go @@ -18,6 +18,9 @@ type SQLConfig struct { type AuthorizeFilter interface { RegoString() string SQLString(cfg SQLConfig) string + // Eval is required for the fake in memory database to work. The in memory + // database can use this function to filter the results. + Eval(object Object) bool } // Compile will convert a rego query AST into our custom types. The output is @@ -165,8 +168,18 @@ func processTerm(term *ast.Term) (Term, error) { base: base, }, nil case ast.Set: + slice := v.Slice() + set := make([]Term, 0, len(slice)) + for _, elem := range slice { + processed, err := processTerm(elem) + if err != nil { + return nil, xerrors.Errorf("invalid set term %s: %w", elem.String(), err) + } + set = append(set, processed) + } + return &termSet{ - Value: v, + Value: set, base: base, }, nil default: @@ -204,6 +217,15 @@ func (t expAnd) SQLString(cfg SQLConfig) string { return "(" + strings.Join(exprs, " AND ") + ")" } +func (t expAnd) Eval(object Object) bool { + for _, expr := range t.Expressions { + if !expr.Eval(object) { + return false + } + } + return true +} + type expOr struct { base Expressions []Expression @@ -218,12 +240,21 @@ func (t expOr) SQLString(cfg SQLConfig) string { return "(" + strings.Join(exprs, " OR ") + ")" } +func (t expOr) Eval(object Object) bool { + for _, expr := range t.Expressions { + if expr.Eval(object) { + return true + } + } + return false +} + // Operator joins terms together to form an expression. // Operators are also expressions. // // Eg: "=", "neq", "internal.member_2", etc. type Operator interface { - AuthorizeFilter + Expression } type opEqual struct { @@ -241,6 +272,14 @@ func (t opEqual) SQLString(cfg SQLConfig) string { return fmt.Sprintf("%s %s %s", t.Terms[0].SQLString(cfg), op, t.Terms[1].SQLString(cfg)) } +func (t opEqual) Eval(object Object) bool { + a, b := t.Terms[0].Eval(object), t.Terms[1].Eval(object) + if t.Not { + return a != b + } + return a == b +} + // opInternalMember2 is checking if the first term is a member of the second term. // The second term is a set or list. type opInternalMember2 struct { @@ -248,6 +287,20 @@ type opInternalMember2 struct { Terms [2]Term } +func (t opInternalMember2) Eval(object Object) bool { + a, b := t.Terms[0].Eval(object), t.Terms[1].Eval(object) + bset, ok := b.([]interface{}) + if !ok { + return false + } + for _, elem := range bset { + if a == elem { + return true + } + } + return false +} + func (t opInternalMember2) SQLString(cfg SQLConfig) string { return fmt.Sprintf("%s = ANY(%s)", t.Terms[0].SQLString(cfg), t.Terms[1].SQLString(cfg)) } @@ -257,7 +310,11 @@ func (t opInternalMember2) SQLString(cfg SQLConfig) string { // Eg: "f9d6fb75-b59b-4363-ab6b-ae9d26b679d7", "input.object.org_owner", // "{"f9d6fb75-b59b-4363-ab6b-ae9d26b679d7"}" type Term interface { - AuthorizeFilter + RegoString() string + SQLString(cfg SQLConfig) string + // Eval will evaluate the term + // Terms can eval to any type. The operator/expression will type check. + Eval(object Object) interface{} } type termString struct { @@ -265,6 +322,10 @@ type termString struct { Value string } +func (t termString) Eval(_ Object) interface{} { + return t.Value +} + func (t termString) SQLString(_ SQLConfig) string { return "'" + t.Value + "'" } @@ -274,6 +335,19 @@ type termVariable struct { Name string } +func (t termVariable) Eval(obj Object) interface{} { + switch t.Name { + case "input.object.org_owner": + return obj.OrgID + case "input.object.owner": + return obj.Owner + case "input.object.type": + return obj.Type + default: + return fmt.Sprintf("'Unknown variable %s'", t.Name) + } +} + func (t termVariable) SQLString(cfg SQLConfig) string { rename, ok := cfg.VariableRenames[t.Name] if ok { @@ -285,19 +359,22 @@ func (t termVariable) SQLString(cfg SQLConfig) string { // termSet is a set of unique terms. type termSet struct { base - Value ast.Set + Value []Term +} + +func (t termSet) Eval(obj Object) interface{} { + set := make([]interface{}, 0, len(t.Value)) + for _, term := range t.Value { + set = append(set, term.Eval(obj)) + } + + return set } func (t termSet) SQLString(cfg SQLConfig) string { - values := t.Value.Slice() - elems := make([]string, 0, len(values)) - // TODO: Handle different typed terms? - for _, v := range t.Value.Slice() { - t, err := processTerm(v) - if err != nil { - panic(err) - } - elems = append(elems, t.SQLString(cfg)) + elems := make([]string, 0, len(t.Value)) + for _, v := range t.Value { + elems = append(elems, v.SQLString(cfg)) } return fmt.Sprintf("ARRAY [%s]", strings.Join(elems, ",")) @@ -308,6 +385,10 @@ type termBoolean struct { Value bool } +func (t termBoolean) Eval(_ Object) bool { + return t.Value +} + func (t termBoolean) SQLString(_ SQLConfig) string { return strconv.FormatBool(t.Value) } diff --git a/coderd/workspaces.go b/coderd/workspaces.go index c809242d8fed5..f2b92fea54fa3 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -111,17 +111,16 @@ func (api *API) workspaces(rw http.ResponseWriter, r *http.Request) { filter.OwnerUsername = "" } - workspaces, err := api.Database.GetWorkspaces(ctx, filter) + sqlFilter, err := api.HTTPAuth.AuthorizeSQLFilter(r, rbac.ActionRead, rbac.ResourceWorkspace.Type) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Internal error fetching workspaces.", + Message: "Internal error preparing sql filter.", Detail: err.Error(), }) return } - // Only return workspaces the user can read - workspaces, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, workspaces) + workspaces, err := api.Database.AuthorizedGetWorkspaces(ctx, filter, sqlFilter) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching workspaces.", From d516be7840b5c29b41786ce8b580000dbc10f089 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 28 Sep 2022 15:00:10 -0400 Subject: [PATCH 07/18] Add comments + tests --- coderd/authorize.go | 8 +++ coderd/database/custom_queries.go | 9 ++- coderd/database/databasefake/databasefake.go | 2 + coderd/database/db.go | 7 +-- coderd/rbac/authz.go | 14 ----- coderd/rbac/authz_internal_test.go | 2 + coderd/rbac/query.go | 50 +++++++++++++--- coderd/rbac/query_internal_test.go | 60 +++++++++----------- 8 files changed, 89 insertions(+), 63 deletions(-) diff --git a/coderd/authorize.go b/coderd/authorize.go index 7ed0e404612d1..166cae76e841c 100644 --- a/coderd/authorize.go +++ b/coderd/authorize.go @@ -13,6 +13,9 @@ import ( "github.com/coder/coder/codersdk" ) +// AuthorizeFilter takes a list of objects and returns the filtered list of +// objects that the user is authorized to perform the given action on. +// This is faster than calling Authorize() on each object. func AuthorizeFilter[O rbac.Objecter](h *HTTPAuthorizer, r *http.Request, action rbac.Action, objects []O) ([]O, error) { roles := httpmw.UserAuthorization(r) objects, err := rbac.Filter(r.Context(), h.Authorizer, roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), action, objects) @@ -85,6 +88,11 @@ func (h *HTTPAuthorizer) Authorize(r *http.Request, action rbac.Action, object r return true } +// AuthorizeSQLFilter returns an authorization filter that can used in a +// SQL 'WHERE' clause. If the filter is used, the resulting rows returned +// from postgres are already authorized, and the caller does not need to +// call 'Authorize()' on the returned objects. +// Note the authorization is only for the given action and object type. func (a *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action rbac.Action, objectType string) (rbac.AuthorizeFilter, error) { roles := httpmw.UserAuthorization(r) prepared, err := a.Authorizer.PrepareByRoleName(r.Context(), roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), action, objectType) diff --git a/coderd/database/custom_queries.go b/coderd/database/custom_queries.go index 756f98bcae062..956fbeb031a7f 100644 --- a/coderd/database/custom_queries.go +++ b/coderd/database/custom_queries.go @@ -4,14 +4,19 @@ import ( "context" "fmt" + "github.com/lib/pq" "golang.org/x/xerrors" "github.com/coder/coder/coderd/rbac" - - "github.com/lib/pq" ) +type customQuerier interface { + AuthorizedGetWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error) +} + // AuthorizedGetWorkspaces returns all workspaces that the user is authorized to access. +// This code is copied from `GetWorkspaces` and adds the authorized filter WHERE +// clause. func (q *sqlQuerier) AuthorizedGetWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error) { query := fmt.Sprintf("%s AND %s", getWorkspaces, authorizedFilter.SQLString(rbac.SQLConfig{ VariableRenames: map[string]string{ diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index d40194812f94d..973da51bd0a01 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -521,6 +521,7 @@ func (q *fakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.U } func (q *fakeQuerier) GetWorkspaces(ctx context.Context, arg database.GetWorkspacesParams) ([]database.Workspace, error) { + // A nil auth filter means no auth filter. workspaces, err := q.AuthorizedGetWorkspaces(ctx, arg, nil) return workspaces, err } @@ -566,6 +567,7 @@ func (q *fakeQuerier) AuthorizedGetWorkspaces(ctx context.Context, arg database. } } + // If the filter exists, ensure the object is authorized. if authorizedFilter != nil && !authorizedFilter.Eval(workspace.RBACObject()) { continue } diff --git a/coderd/database/db.go b/coderd/database/db.go index 80a5748de7263..9997a88f1e148 100644 --- a/coderd/database/db.go +++ b/coderd/database/db.go @@ -13,8 +13,6 @@ import ( "database/sql" "errors" - "github.com/coder/coder/coderd/rbac" - "golang.org/x/xerrors" ) @@ -22,15 +20,12 @@ import ( // It extends the generated interface to add transaction support. type Store interface { querier + // customQuerier contains custom queries that are not generated. customQuerier InTx(func(Store) error) error } -type customQuerier interface { - AuthorizedGetWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error) -} - // DBTX represents a database connection or transaction. type DBTX interface { ExecContext(context.Context, string, ...interface{}) (sql.Result, error) diff --git a/coderd/rbac/authz.go b/coderd/rbac/authz.go index 526e028c7359d..237fb7f365188 100644 --- a/coderd/rbac/authz.go +++ b/coderd/rbac/authz.go @@ -24,20 +24,6 @@ type PreparedAuthorized interface { Compile() (AuthorizeFilter, error) } -func (a *RegoAuthorizer) SQLFilter(ctx context.Context, subjID string, subjRoles []string, scope Scope, action Action, objectType string) (AuthorizeFilter, error) { - prepared, err := a.PrepareByRoleName(ctx, subjID, subjRoles, scope, action, objectType) - if err != nil { - return nil, xerrors.Errorf("filter: %w", err) - } - - filter, err := prepared.Compile() - if err != nil { - return nil, xerrors.Errorf("filter: %w", err) - } - - return filter, nil -} - // Filter takes in a list of objects, and will filter the list removing all // the elements the subject does not have permission for. All objects must be // of the same type. diff --git a/coderd/rbac/authz_internal_test.go b/coderd/rbac/authz_internal_test.go index 7a646de754ab5..e7b3b8522bcd1 100644 --- a/coderd/rbac/authz_internal_test.go +++ b/coderd/rbac/authz_internal_test.go @@ -781,6 +781,8 @@ func testAuthorize(t *testing.T, name string, subject subject, sets ...[]authTes partialAuthz, err := authorizer.Prepare(ctx, subject.UserID, subject.Roles, subject.Scope, a, c.resource.Type) require.NoError(t, err, "make prepared authorizer") + // Ensure the partial can compile to a SQL clause. + // This does not guarantee that the clause is valid SQL. _, err = Compile(partialAuthz.partialQueries) require.NoError(t, err, "compile prepared authorizer") diff --git a/coderd/rbac/query.go b/coderd/rbac/query.go index 9d6dccf109a8a..8cc3e10d05aab 100644 --- a/coderd/rbac/query.go +++ b/coderd/rbac/query.go @@ -12,11 +12,16 @@ import ( type SQLConfig struct { // VariableRenames renames rego variables to sql columns + // Example: + // "input.object.org_owner": "organization_id::text" + // "input.object.owner": "owner_id::text" VariableRenames map[string]string } type AuthorizeFilter interface { + // RegoString is used in debugging to see the original rego expression. RegoString() string + // SQLString returns the SQL expression that can be used in a WHERE clause. SQLString(cfg SQLConfig) string // Eval is required for the fake in memory database to work. The in memory // database can use this function to filter the results. @@ -92,7 +97,18 @@ func processQuery(query ast.Body) (Expression, error) { func processExpression(expr *ast.Expr) (Expression, error) { if !expr.IsCall() { - return nil, xerrors.Errorf("invalid expression: function calls not supported") + // This could be a single term that is a valid expression. + if term, ok := expr.Terms.(*ast.Term); ok { + value, err := processTerm(term) + if err != nil { + return nil, xerrors.Errorf("single term expression: %w", err) + } + if boolExp, ok := value.(Expression); ok { + return boolExp, nil + } + // Default to error. + } + return nil, xerrors.Errorf("invalid expression: single non-boolean terms not supported") } op := expr.Operator().String() @@ -142,6 +158,11 @@ func processTerms(expected int, terms []*ast.Term) ([]Term, error) { func processTerm(term *ast.Term) (Term, error) { base := base{Rego: term.String()} switch v := term.Value.(type) { + case ast.Boolean: + return &termBoolean{ + base: base, + Value: bool(v), + }, nil case ast.Ref: // A ref is a set of terms. If the first term is a var, then the // following terms are the path to the value. @@ -210,6 +231,10 @@ type expAnd struct { } func (t expAnd) SQLString(cfg SQLConfig) string { + if len(t.Expressions) == 1 { + return t.Expressions[0].SQLString(cfg) + } + exprs := make([]string, 0, len(t.Expressions)) for _, expr := range t.Expressions { exprs = append(exprs, expr.SQLString(cfg)) @@ -232,11 +257,14 @@ type expOr struct { } func (t expOr) SQLString(cfg SQLConfig) string { + if len(t.Expressions) == 1 { + return t.Expressions[0].SQLString(cfg) + } + exprs := make([]string, 0, len(t.Expressions)) for _, expr := range t.Expressions { exprs = append(exprs, expr.SQLString(cfg)) } - return "(" + strings.Join(exprs, " OR ") + ")" } @@ -273,7 +301,7 @@ func (t opEqual) SQLString(cfg SQLConfig) string { } func (t opEqual) Eval(object Object) bool { - a, b := t.Terms[0].Eval(object), t.Terms[1].Eval(object) + a, b := t.Terms[0].EvalTerm(object), t.Terms[1].EvalTerm(object) if t.Not { return a != b } @@ -288,7 +316,7 @@ type opInternalMember2 struct { } func (t opInternalMember2) Eval(object Object) bool { - a, b := t.Terms[0].Eval(object), t.Terms[1].Eval(object) + a, b := t.Terms[0].EvalTerm(object), t.Terms[1].EvalTerm(object) bset, ok := b.([]interface{}) if !ok { return false @@ -314,7 +342,7 @@ type Term interface { SQLString(cfg SQLConfig) string // Eval will evaluate the term // Terms can eval to any type. The operator/expression will type check. - Eval(object Object) interface{} + EvalTerm(object Object) interface{} } type termString struct { @@ -322,7 +350,7 @@ type termString struct { Value string } -func (t termString) Eval(_ Object) interface{} { +func (t termString) EvalTerm(_ Object) interface{} { return t.Value } @@ -335,7 +363,7 @@ type termVariable struct { Name string } -func (t termVariable) Eval(obj Object) interface{} { +func (t termVariable) EvalTerm(obj Object) interface{} { switch t.Name { case "input.object.org_owner": return obj.OrgID @@ -362,10 +390,10 @@ type termSet struct { Value []Term } -func (t termSet) Eval(obj Object) interface{} { +func (t termSet) EvalTerm(obj Object) interface{} { set := make([]interface{}, 0, len(t.Value)) for _, term := range t.Value { - set = append(set, term.Eval(obj)) + set = append(set, term.EvalTerm(obj)) } return set @@ -389,6 +417,10 @@ func (t termBoolean) Eval(_ Object) bool { return t.Value } +func (t termBoolean) EvalTerm(_ Object) interface{} { + return t.Value +} + func (t termBoolean) SQLString(_ SQLConfig) string { return strconv.FormatBool(t.Value) } diff --git a/coderd/rbac/query_internal_test.go b/coderd/rbac/query_internal_test.go index 67964e754f708..63e5f0c578856 100644 --- a/coderd/rbac/query_internal_test.go +++ b/coderd/rbac/query_internal_test.go @@ -1,42 +1,38 @@ package rbac import ( - "context" - "fmt" "testing" - "github.com/stretchr/testify/require" + "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/rego" - "github.com/google/uuid" + "github.com/stretchr/testify/require" ) func TestCompileQuery(t *testing.T) { - ctx := context.Background() - defOrg := uuid.New() - unuseID := uuid.New() - - user := subject{ - UserID: "me", - Scope: must(ScopeRole(ScopeAll)), - Roles: []Role{ - must(RoleByName(RoleMember())), - must(RoleByName(RoleOrgMember(defOrg))), - }, - } - var action Action = ActionRead - object := ResourceWorkspace.InOrg(defOrg).WithOwner(unuseID.String()) - - auth := NewAuthorizer() - part, err := auth.Prepare(ctx, user.UserID, user.Roles, user.Scope, action, object.Type) - require.NoError(t, err) - - result, err := Compile(part.partialQueries) - require.NoError(t, err) - - fmt.Println("Rego: ", result.RegoString()) - fmt.Println("SQL: ", result.SQLString(SQLConfig{ - map[string]string{ - "input.object.org_owner": "organization_id", - }, - })) + t.Run("EmptyQuery", func(t *testing.T) { + expression, err := Compile(®o.PartialQueries{ + Queries: []ast.Body{ + must(ast.ParseBody("")), + }, + Support: []*ast.Module{}, + }) + require.NoError(t, err, "compile empty") + + require.Equal(t, "true", expression.RegoString(), "empty query is rego 'true'") + require.Equal(t, "true", expression.SQLString(SQLConfig{}), "empty query is sql 'true'") + }) + + t.Run("TrueQuery", func(t *testing.T) { + expression, err := Compile(®o.PartialQueries{ + Queries: []ast.Body{ + must(ast.ParseBody("true")), + }, + Support: []*ast.Module{}, + }) + require.NoError(t, err, "compile empty") + + require.Equal(t, "true", expression.RegoString(), "true query is rego 'true'") + require.Equal(t, "true", expression.SQLString(SQLConfig{}), "true query is sql 'true'") + }) } From 125931a4fbde70c227314d9626df25c36cfc0f5f Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 28 Sep 2022 17:15:13 -0400 Subject: [PATCH 08/18] Support better regex style matching for variables --- coderd/database/custom_queries.go | 7 +-- coderd/rbac/query.go | 79 ++++++++++++++++++++++++++++--- 2 files changed, 73 insertions(+), 13 deletions(-) diff --git a/coderd/database/custom_queries.go b/coderd/database/custom_queries.go index 956fbeb031a7f..bbcfc3e6ba813 100644 --- a/coderd/database/custom_queries.go +++ b/coderd/database/custom_queries.go @@ -18,12 +18,7 @@ type customQuerier interface { // This code is copied from `GetWorkspaces` and adds the authorized filter WHERE // clause. func (q *sqlQuerier) AuthorizedGetWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error) { - query := fmt.Sprintf("%s AND %s", getWorkspaces, authorizedFilter.SQLString(rbac.SQLConfig{ - VariableRenames: map[string]string{ - "input.object.org_owner": "organization_id::text", - "input.object.owner": "owner_id::text", - }, - })) + query := fmt.Sprintf("%s AND %s", getWorkspaces, authorizedFilter.SQLString(rbac.DefaultConfig())) rows, err := q.db.QueryContext(ctx, query, arg.Deleted, arg.OwnerID, diff --git a/coderd/rbac/query.go b/coderd/rbac/query.go index 8cc3e10d05aab..3851fac76eec0 100644 --- a/coderd/rbac/query.go +++ b/coderd/rbac/query.go @@ -2,6 +2,7 @@ package rbac import ( "fmt" + "regexp" "strconv" "strings" @@ -10,12 +11,66 @@ import ( "golang.org/x/xerrors" ) +const ( + VarTypeJsonbArray = "jsonb-array" + VarTypeUUID = "uuid" + VarTypeText = "text" +) + +type SQLColumn struct { + // RegoMatch matches the original variable string. + // If it is a match, then this variable config will apply. + RegoMatch *regexp.Regexp + // ColumnSelect is the name of the postgres column to select. + // Can use capture groups from RegoMatch with $1, $2, etc. + ColumnSelect string + + // Type indicates the postgres type of the column. Some expressions will + // need to know this in order to determine what SQL to produce. + // An example is if the variable is a jsonb array, the "contains" SQL + // query is `"value"' @> variable.` instead of `'value' = ANY(variable)`. + // This type is only needed to be provided + Type string +} + type SQLConfig struct { - // VariableRenames renames rego variables to sql columns + // Variables is a map of rego variable names to SQL columns. // Example: - // "input.object.org_owner": "organization_id::text" - // "input.object.owner": "owner_id::text" - VariableRenames map[string]string + // "input\.object\.org_owner": SQLColumn{ + // ColumnSelect: "organization_id", + // Type: VarTypeUUID + // } + // "input\.object\.owner": SQLColumn{ + // ColumnSelect: "owner_id", + // Type: VarTypeUUID + // } + // "input\.object\.group_acl\.(.*)": SQLColumn{ + // ColumnSelect: "group_acl->$1", + // Type: VarTypeJsonb + // } + Variables []SQLColumn +} + +func DefaultConfig() SQLConfig { + return SQLConfig{ + Variables: []SQLColumn{ + { + RegoMatch: regexp.MustCompile(`^input\.object\.acl_group_list\.([^.]*)$`), + ColumnSelect: "group_acl->$1", + Type: VarTypeJsonbArray, + }, + { + RegoMatch: regexp.MustCompile(`^input\.object\.org_owner$`), + ColumnSelect: "organization_id :: text", + Type: VarTypeUUID, + }, + { + RegoMatch: regexp.MustCompile(`^input\.object\.owner$`), + ColumnSelect: "owner_id :: text", + Type: VarTypeUUID, + }, + }, + } } type AuthorizeFilter interface { @@ -377,10 +432,20 @@ func (t termVariable) EvalTerm(obj Object) interface{} { } func (t termVariable) SQLString(cfg SQLConfig) string { - rename, ok := cfg.VariableRenames[t.Name] - if ok { - return rename + for _, col := range cfg.Variables { + matches := col.RegoMatch.FindStringSubmatch(t.Name) + if len(matches) > 0 { + // This config matches this variable. + replace := make([]string, 0, len(matches)*2) + for i, m := range matches { + replace = append(replace, fmt.Sprintf("$%d", i)) + replace = append(replace, m) + } + replacer := strings.NewReplacer(replace...) + return replacer.Replace(col.ColumnSelect) + } } + return t.Name } From fc58da5cfd7cfc5b5d678d48c72cd8ac72c70f70 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 28 Sep 2022 18:27:41 -0400 Subject: [PATCH 09/18] Handle jsonb arrays --- coderd/authorize.go | 4 +- coderd/coderdtest/authorize.go | 4 +- coderd/database/databasefake/databasefake.go | 2 +- coderd/rbac/query.go | 73 +++++++++++++---- coderd/rbac/query_internal_test.go | 84 +++++++++++++++++++- 5 files changed, 146 insertions(+), 21 deletions(-) diff --git a/coderd/authorize.go b/coderd/authorize.go index 166cae76e841c..c0b8eaba757ed 100644 --- a/coderd/authorize.go +++ b/coderd/authorize.go @@ -93,9 +93,9 @@ func (h *HTTPAuthorizer) Authorize(r *http.Request, action rbac.Action, object r // from postgres are already authorized, and the caller does not need to // call 'Authorize()' on the returned objects. // Note the authorization is only for the given action and object type. -func (a *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action rbac.Action, objectType string) (rbac.AuthorizeFilter, error) { +func (h *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action rbac.Action, objectType string) (rbac.AuthorizeFilter, error) { roles := httpmw.UserAuthorization(r) - prepared, err := a.Authorizer.PrepareByRoleName(r.Context(), roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), action, objectType) + prepared, err := h.Authorizer.PrepareByRoleName(r.Context(), roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), action, objectType) if err != nil { return nil, xerrors.Errorf("prepare filter: %w", err) } diff --git a/coderd/coderdtest/authorize.go b/coderd/coderdtest/authorize.go index 9ee6313e66efc..f3b759845a217 100644 --- a/coderd/coderdtest/authorize.go +++ b/coderd/coderdtest/authorize.go @@ -564,10 +564,10 @@ func (f *fakePreparedAuthorizer) Eval(object rbac.Object) bool { return f.Original.ByRoleName(context.Background(), f.SubjectID, f.Roles, f.Scope, f.Action, object) == nil } -func (f *fakePreparedAuthorizer) RegoString() string { +func (fakePreparedAuthorizer) RegoString() string { panic("not implemented") } -func (f *fakePreparedAuthorizer) SQLString(_ rbac.SQLConfig) string { +func (fakePreparedAuthorizer) SQLString(_ rbac.SQLConfig) string { panic("not implemented") } diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index 973da51bd0a01..d66cb662a6636 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -526,7 +526,7 @@ func (q *fakeQuerier) GetWorkspaces(ctx context.Context, arg database.GetWorkspa return workspaces, err } -func (q *fakeQuerier) AuthorizedGetWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]database.Workspace, error) { +func (q *fakeQuerier) AuthorizedGetWorkspaces(_ context.Context, arg database.GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]database.Workspace, error) { q.mutex.RLock() defer q.mutex.RUnlock() diff --git a/coderd/rbac/query.go b/coderd/rbac/query.go index 3851fac76eec0..397b05bcc3552 100644 --- a/coderd/rbac/query.go +++ b/coderd/rbac/query.go @@ -11,10 +11,11 @@ import ( "golang.org/x/xerrors" ) +type TermType string + const ( - VarTypeJsonbArray = "jsonb-array" - VarTypeUUID = "uuid" - VarTypeText = "text" + VarTypeJsonbTextArray TermType = "jsonb-text-array" + VarTypeText TermType = "text" ) type SQLColumn struct { @@ -30,7 +31,7 @@ type SQLColumn struct { // An example is if the variable is a jsonb array, the "contains" SQL // query is `"value"' @> variable.` instead of `'value' = ANY(variable)`. // This type is only needed to be provided - Type string + Type TermType } type SQLConfig struct { @@ -57,17 +58,22 @@ func DefaultConfig() SQLConfig { { RegoMatch: regexp.MustCompile(`^input\.object\.acl_group_list\.([^.]*)$`), ColumnSelect: "group_acl->$1", - Type: VarTypeJsonbArray, + Type: VarTypeJsonbTextArray, + }, + { + RegoMatch: regexp.MustCompile(`^input\.object\.acl_user_list\.([^.]*)$`), + ColumnSelect: "user_acl->$1", + Type: VarTypeJsonbTextArray, }, { RegoMatch: regexp.MustCompile(`^input\.object\.org_owner$`), ColumnSelect: "organization_id :: text", - Type: VarTypeUUID, + Type: VarTypeText, }, { RegoMatch: regexp.MustCompile(`^input\.object\.owner$`), ColumnSelect: "owner_id :: text", - Type: VarTypeUUID, + Type: VarTypeText, }, }, } @@ -185,8 +191,9 @@ func processExpression(expr *ast.Expr) (Expression, error) { return nil, xerrors.Errorf("invalid '%s' expression: %w", op, err) } return &opInternalMember2{ - base: base, - Terms: [2]Term{terms[0], terms[1]}, + base: base, + Needle: terms[0], + Haystack: terms[1], }, nil default: return nil, xerrors.Errorf("invalid expression: operator %s not supported", op) @@ -230,9 +237,8 @@ func processTerm(term *ast.Term) (Term, error) { base: base, Name: name, }, nil - } else { - return nil, xerrors.Errorf("invalid term: ref must start with a var, started with %T", v[0]) } + return nil, xerrors.Errorf("invalid term: ref must start with a var, started with %T", v[0]) case ast.Var: return &termVariable{ Name: trimQuotes(v.String()), @@ -367,11 +373,12 @@ func (t opEqual) Eval(object Object) bool { // The second term is a set or list. type opInternalMember2 struct { base - Terms [2]Term + Needle Term + Haystack Term } func (t opInternalMember2) Eval(object Object) bool { - a, b := t.Terms[0].EvalTerm(object), t.Terms[1].EvalTerm(object) + a, b := t.Needle.EvalTerm(object), t.Haystack.EvalTerm(object) bset, ok := b.([]interface{}) if !ok { return false @@ -385,7 +392,20 @@ func (t opInternalMember2) Eval(object Object) bool { } func (t opInternalMember2) SQLString(cfg SQLConfig) string { - return fmt.Sprintf("%s = ANY(%s)", t.Terms[0].SQLString(cfg), t.Terms[1].SQLString(cfg)) + if haystack, ok := t.Haystack.(*termVariable); ok { + // This is a special case where the haystack is a jsonb array. + // The more general way to solve this would be to implement a fuller type + // system and handle type conversions for each supported type. + // Then we could determine that the haystack is always an "array" and + // implement the "contains" function on the array type. + // But that requires a lot more code to handle a lot of cases we don't + // actually care about. + if haystack.SQLType(cfg) == VarTypeJsonbTextArray { + return fmt.Sprintf("%s ? %s", haystack.SQLString(cfg), t.Needle.SQLString(cfg)) + } + } + + return fmt.Sprintf("%s = ANY(%s)", t.Needle.SQLString(cfg), t.Haystack.SQLString(cfg)) } // Term is a single value in an expression. Terms can be variables or constants. @@ -413,6 +433,10 @@ func (t termString) SQLString(_ SQLConfig) string { return "'" + t.Value + "'" } +func (t termString) SQLType(_ SQLConfig) TermType { + return VarTypeText +} + type termVariable struct { base Name string @@ -431,8 +455,15 @@ func (t termVariable) EvalTerm(obj Object) interface{} { } } +func (t termVariable) SQLType(cfg SQLConfig) TermType { + if col := t.ColumnConfig(cfg); col != nil { + return col.Type + } + return VarTypeText +} + func (t termVariable) SQLString(cfg SQLConfig) string { - for _, col := range cfg.Variables { + if col := t.ColumnConfig(cfg); col != nil { matches := col.RegoMatch.FindStringSubmatch(t.Name) if len(matches) > 0 { // This config matches this variable. @@ -449,6 +480,18 @@ func (t termVariable) SQLString(cfg SQLConfig) string { return t.Name } +// ColumnConfig returns the correct SQLColumn settings for the +// term. If there is no configured column, it will return nil. +func (t termVariable) ColumnConfig(cfg SQLConfig) *SQLColumn { + for _, col := range cfg.Variables { + matches := col.RegoMatch.MatchString(t.Name) + if matches { + return &col + } + } + return nil +} + // termSet is a set of unique terms. type termSet struct { base diff --git a/coderd/rbac/query_internal_test.go b/coderd/rbac/query_internal_test.go index 63e5f0c578856..50c1efc35b7aa 100644 --- a/coderd/rbac/query_internal_test.go +++ b/coderd/rbac/query_internal_test.go @@ -1,8 +1,11 @@ package rbac import ( + "context" + "fmt" "testing" + "github.com/google/uuid" "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/rego" @@ -10,7 +13,12 @@ import ( ) func TestCompileQuery(t *testing.T) { + t.Parallel() + opts := ast.ParserOptions{ + AllFutureKeywords: true, + } t.Run("EmptyQuery", func(t *testing.T) { + t.Parallel() expression, err := Compile(®o.PartialQueries{ Queries: []ast.Body{ must(ast.ParseBody("")), @@ -24,15 +32,89 @@ func TestCompileQuery(t *testing.T) { }) t.Run("TrueQuery", func(t *testing.T) { + t.Parallel() expression, err := Compile(®o.PartialQueries{ Queries: []ast.Body{ must(ast.ParseBody("true")), }, Support: []*ast.Module{}, }) - require.NoError(t, err, "compile empty") + require.NoError(t, err, "compile") require.Equal(t, "true", expression.RegoString(), "true query is rego 'true'") require.Equal(t, "true", expression.SQLString(SQLConfig{}), "true query is sql 'true'") }) + + t.Run("ACLIn", func(t *testing.T) { + t.Parallel() + expression, err := Compile(®o.PartialQueries{ + Queries: []ast.Body{ + ast.MustParseBodyWithOpts(`"*" in input.object.acl_group_list.allUsers`, opts), + }, + Support: []*ast.Module{}, + }) + require.NoError(t, err, "compile") + + require.Equal(t, `internal.member_2("*", input.object.acl_group_list.allUsers)`, expression.RegoString(), "convert to internal_member") + require.Equal(t, `group_acl->allUsers ? '*'`, expression.SQLString(DefaultConfig()), "jsonb in") + }) + + t.Run("Complex", func(t *testing.T) { + t.Parallel() + expression, err := Compile(®o.PartialQueries{ + Queries: []ast.Body{ + ast.MustParseBodyWithOpts(`input.object.org_owner != ""`, opts), + ast.MustParseBodyWithOpts(`input.object.org_owner in {"a", "b", "c"}`, opts), + ast.MustParseBodyWithOpts(`input.object.org_owner != ""`, opts), + ast.MustParseBodyWithOpts(`"read" in input.object.acl_group_list.allUsers`, opts), + ast.MustParseBodyWithOpts(`"read" in input.object.acl_user_list.me`, opts), + }, + Support: []*ast.Module{}, + }) + require.NoError(t, err, "compile") + require.Equal(t, `(organization_id :: text != '' OR `+ + `organization_id :: text = ANY(ARRAY ['a','b','c']) OR `+ + `organization_id :: text != '' OR `+ + `group_acl->allUsers ? 'read' OR `+ + `user_acl->me ? 'read')`, + expression.SQLString(DefaultConfig()), "complex") + }) +} + +//func TestRE(t *testing.T) { +// // ^input\.object\.group_acl\.([^.]*)$ +// re := regexp.MustCompile(`^input\.object\.group_acl\.([^.]*)$`) +// +// x := []string{"test"} +// fmt.Sprintf("test", x) +// +// //re.FindStringSubmatch("input.object.group_acl.allUsers") +// fmt.Println(re.FindStringSubmatch("input.object.group_acl.allUsers")) +//} + +func TestPartialCompileQuery(t *testing.T) { + ctx := context.Background() + defOrg := uuid.New() + unuseID := uuid.New() + + user := subject{ + UserID: "me", + Scope: must(ScopeRole(ScopeAll)), + Roles: []Role{ + must(RoleByName(RoleMember())), + must(RoleByName(RoleOrgMember(defOrg))), + }, + } + var action Action = ActionRead + object := ResourceWorkspace.InOrg(defOrg).WithOwner(unuseID.String()) + + auth := NewAuthorizer() + part, err := auth.Prepare(ctx, user.UserID, user.Roles, user.Scope, action, object.Type) + require.NoError(t, err) + + result, err := Compile(part.partialQueries) + require.NoError(t, err) + + fmt.Println("Rego: ", result.RegoString()) + fmt.Println("SQL: ", result.SQLString(DefaultConfig())) } From 04c1d6a132eb3cfc4d2eb92dde3446f9eff6aa9e Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 28 Sep 2022 18:31:59 -0400 Subject: [PATCH 10/18] Remove auth call on workspaces --- coderd/coderdtest/authorize.go | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/coderd/coderdtest/authorize.go b/coderd/coderdtest/authorize.go index f3b759845a217..ddba865e2c8ac 100644 --- a/coderd/coderdtest/authorize.go +++ b/coderd/coderdtest/authorize.go @@ -128,11 +128,6 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { AssertAction: rbac.ActionCreate, AssertObject: workspaceExecObj, }, - "GET:/api/v2/workspaces/": { - StatusCode: http.StatusOK, - AssertAction: rbac.ActionRead, - AssertObject: workspaceRBACObj, - }, "GET:/api/v2/organizations/{organization}/templates": { StatusCode: http.StatusOK, AssertAction: rbac.ActionRead, @@ -250,6 +245,9 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { "PUT:/api/v2/organizations/{organization}/members/{user}/roles": {NoAuthorize: true}, "POST:/api/v2/workspaces/{workspace}/builds": {StatusCode: http.StatusBadRequest, NoAuthorize: true}, "POST:/api/v2/organizations/{organization}/templateversions": {StatusCode: http.StatusBadRequest, NoAuthorize: true}, + + // Endpoints that use the SQLQuery filter. + "GET:/api/v2/workspaces/": {StatusCode: http.StatusOK}, } // Routes like proxy routes support all HTTP methods. A helper func to expand From 98b405e8339c7a633886b97217adcbdaf82b95b1 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 28 Sep 2022 19:02:22 -0400 Subject: [PATCH 11/18] Fix PG endpoints test --- coderd/coderdtest/authorize.go | 35 +++++++++++++++---------- coderd/rbac/query_internal_test.go | 41 ------------------------------ 2 files changed, 22 insertions(+), 54 deletions(-) diff --git a/coderd/coderdtest/authorize.go b/coderd/coderdtest/authorize.go index ddba865e2c8ac..63be497e0b622 100644 --- a/coderd/coderdtest/authorize.go +++ b/coderd/coderdtest/authorize.go @@ -247,7 +247,7 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { "POST:/api/v2/organizations/{organization}/templateversions": {StatusCode: http.StatusBadRequest, NoAuthorize: true}, // Endpoints that use the SQLQuery filter. - "GET:/api/v2/workspaces/": {StatusCode: http.StatusOK}, + "GET:/api/v2/workspaces/": {StatusCode: http.StatusOK, NoAuthorize: true}, } // Routes like proxy routes support all HTTP methods. A helper func to expand @@ -528,11 +528,12 @@ func (r *RecordingAuthorizer) ByRoleName(_ context.Context, subjectID string, ro func (r *RecordingAuthorizer) PrepareByRoleName(_ context.Context, subjectID string, roles []string, scope rbac.Scope, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) { return &fakePreparedAuthorizer{ - Original: r, - SubjectID: subjectID, - Roles: roles, - Scope: scope, - Action: action, + Original: r, + SubjectID: subjectID, + Roles: roles, + Scope: scope, + Action: action, + HardCodedSQLString: "true", }, nil } @@ -541,11 +542,13 @@ func (r *RecordingAuthorizer) reset() { } type fakePreparedAuthorizer struct { - Original *RecordingAuthorizer - SubjectID string - Roles []string - Scope rbac.Scope - Action rbac.Action + Original *RecordingAuthorizer + SubjectID string + Roles []string + Scope rbac.Scope + Action rbac.Action + HardCodedSQLString string + HardCodedRegoString string } func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Object) error { @@ -562,10 +565,16 @@ func (f *fakePreparedAuthorizer) Eval(object rbac.Object) bool { return f.Original.ByRoleName(context.Background(), f.SubjectID, f.Roles, f.Scope, f.Action, object) == nil } -func (fakePreparedAuthorizer) RegoString() string { +func (f fakePreparedAuthorizer) RegoString() string { + if f.HardCodedRegoString != "" { + return f.HardCodedRegoString + } panic("not implemented") } -func (fakePreparedAuthorizer) SQLString(_ rbac.SQLConfig) string { +func (f fakePreparedAuthorizer) SQLString(_ rbac.SQLConfig) string { + if f.HardCodedSQLString != "" { + return f.HardCodedSQLString + } panic("not implemented") } diff --git a/coderd/rbac/query_internal_test.go b/coderd/rbac/query_internal_test.go index 50c1efc35b7aa..f7923062b18eb 100644 --- a/coderd/rbac/query_internal_test.go +++ b/coderd/rbac/query_internal_test.go @@ -1,11 +1,8 @@ package rbac import ( - "context" - "fmt" "testing" - "github.com/google/uuid" "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/rego" @@ -80,41 +77,3 @@ func TestCompileQuery(t *testing.T) { expression.SQLString(DefaultConfig()), "complex") }) } - -//func TestRE(t *testing.T) { -// // ^input\.object\.group_acl\.([^.]*)$ -// re := regexp.MustCompile(`^input\.object\.group_acl\.([^.]*)$`) -// -// x := []string{"test"} -// fmt.Sprintf("test", x) -// -// //re.FindStringSubmatch("input.object.group_acl.allUsers") -// fmt.Println(re.FindStringSubmatch("input.object.group_acl.allUsers")) -//} - -func TestPartialCompileQuery(t *testing.T) { - ctx := context.Background() - defOrg := uuid.New() - unuseID := uuid.New() - - user := subject{ - UserID: "me", - Scope: must(ScopeRole(ScopeAll)), - Roles: []Role{ - must(RoleByName(RoleMember())), - must(RoleByName(RoleOrgMember(defOrg))), - }, - } - var action Action = ActionRead - object := ResourceWorkspace.InOrg(defOrg).WithOwner(unuseID.String()) - - auth := NewAuthorizer() - part, err := auth.Prepare(ctx, user.UserID, user.Roles, user.Scope, action, object.Type) - require.NoError(t, err) - - result, err := Compile(part.partialQueries) - require.NoError(t, err) - - fmt.Println("Rego: ", result.RegoString()) - fmt.Println("SQL: ", result.SQLString(DefaultConfig())) -} From 6ad0b51c01c175aa57c0a7601113449a2022705d Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 28 Sep 2022 20:32:23 -0400 Subject: [PATCH 12/18] Match psql implementation --- coderd/coderdtest/authorize.go | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/coderd/coderdtest/authorize.go b/coderd/coderdtest/authorize.go index 63be497e0b622..3c291dfefcb58 100644 --- a/coderd/coderdtest/authorize.go +++ b/coderd/coderdtest/authorize.go @@ -515,17 +515,23 @@ type RecordingAuthorizer struct { var _ rbac.Authorizer = (*RecordingAuthorizer)(nil) -func (r *RecordingAuthorizer) ByRoleName(_ context.Context, subjectID string, roleNames []string, scope rbac.Scope, action rbac.Action, object rbac.Object) error { - r.Called = &authCall{ - SubjectID: subjectID, - Roles: roleNames, - Scope: scope, - Action: action, - Object: object, +func (r *RecordingAuthorizer) FakeByRoleName(_ context.Context, subjectID string, roleNames []string, scope rbac.Scope, action rbac.Action, object rbac.Object, record bool) error { + if record { + r.Called = &authCall{ + SubjectID: subjectID, + Roles: roleNames, + Scope: scope, + Action: action, + Object: object, + } } return r.AlwaysReturn } +func (r *RecordingAuthorizer) ByRoleName(ctx context.Context, subjectID string, roleNames []string, scope rbac.Scope, action rbac.Action, object rbac.Object) error { + return r.FakeByRoleName(ctx, subjectID, roleNames, scope, action, object, true) +} + func (r *RecordingAuthorizer) PrepareByRoleName(_ context.Context, subjectID string, roles []string, scope rbac.Scope, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) { return &fakePreparedAuthorizer{ Original: r, @@ -552,7 +558,7 @@ type fakePreparedAuthorizer struct { } func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Object) error { - return f.Original.ByRoleName(ctx, f.SubjectID, f.Roles, f.Scope, f.Action, object) + return f.Original.FakeByRoleName(ctx, f.SubjectID, f.Roles, f.Scope, f.Action, object, true) } // Compile returns a compiled version of the authorizer that will work for @@ -562,7 +568,7 @@ func (f *fakePreparedAuthorizer) Compile() (rbac.AuthorizeFilter, error) { } func (f *fakePreparedAuthorizer) Eval(object rbac.Object) bool { - return f.Original.ByRoleName(context.Background(), f.SubjectID, f.Roles, f.Scope, f.Action, object) == nil + return f.Original.FakeByRoleName(context.Background(), f.SubjectID, f.Roles, f.Scope, f.Action, object, false) == nil } func (f fakePreparedAuthorizer) RegoString() string { From 7cfad877ff795523c4a02e65ba8d51053e1b979c Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 28 Sep 2022 20:52:30 -0400 Subject: [PATCH 13/18] Add some comments --- coderd/coderdtest/authorize.go | 26 +++++++++++++------------- coderd/rbac/query.go | 20 +++++++++++--------- 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/coderd/coderdtest/authorize.go b/coderd/coderdtest/authorize.go index 3c291dfefcb58..2549f9179e5d3 100644 --- a/coderd/coderdtest/authorize.go +++ b/coderd/coderdtest/authorize.go @@ -515,21 +515,21 @@ type RecordingAuthorizer struct { var _ rbac.Authorizer = (*RecordingAuthorizer)(nil) -func (r *RecordingAuthorizer) FakeByRoleName(_ context.Context, subjectID string, roleNames []string, scope rbac.Scope, action rbac.Action, object rbac.Object, record bool) error { - if record { - r.Called = &authCall{ - SubjectID: subjectID, - Roles: roleNames, - Scope: scope, - Action: action, - Object: object, - } - } +// ByRoleNameSQL does not record the call. This matches the postgres behavior +// of not calling Authorize() +func (r *RecordingAuthorizer) ByRoleNameSQL(_ context.Context, _ string, _ []string, _ rbac.Scope, _ rbac.Action, _ rbac.Object) error { return r.AlwaysReturn } func (r *RecordingAuthorizer) ByRoleName(ctx context.Context, subjectID string, roleNames []string, scope rbac.Scope, action rbac.Action, object rbac.Object) error { - return r.FakeByRoleName(ctx, subjectID, roleNames, scope, action, object, true) + r.Called = &authCall{ + SubjectID: subjectID, + Roles: roleNames, + Scope: scope, + Action: action, + Object: object, + } + return r.AlwaysReturn } func (r *RecordingAuthorizer) PrepareByRoleName(_ context.Context, subjectID string, roles []string, scope rbac.Scope, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) { @@ -558,7 +558,7 @@ type fakePreparedAuthorizer struct { } func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Object) error { - return f.Original.FakeByRoleName(ctx, f.SubjectID, f.Roles, f.Scope, f.Action, object, true) + return f.Original.ByRoleName(ctx, f.SubjectID, f.Roles, f.Scope, f.Action, object) } // Compile returns a compiled version of the authorizer that will work for @@ -568,7 +568,7 @@ func (f *fakePreparedAuthorizer) Compile() (rbac.AuthorizeFilter, error) { } func (f *fakePreparedAuthorizer) Eval(object rbac.Object) bool { - return f.Original.FakeByRoleName(context.Background(), f.SubjectID, f.Roles, f.Scope, f.Action, object, false) == nil + return f.Original.ByRoleNameSQL(context.Background(), f.SubjectID, f.Roles, f.Scope, f.Action, object) == nil } func (f fakePreparedAuthorizer) RegoString() string { diff --git a/coderd/rbac/query.go b/coderd/rbac/query.go index 397b05bcc3552..40866d9c58fb6 100644 --- a/coderd/rbac/query.go +++ b/coderd/rbac/query.go @@ -29,7 +29,7 @@ type SQLColumn struct { // Type indicates the postgres type of the column. Some expressions will // need to know this in order to determine what SQL to produce. // An example is if the variable is a jsonb array, the "contains" SQL - // query is `"value"' @> variable.` instead of `'value' = ANY(variable)`. + // query is `variable ? 'value'` instead of `'value' = ANY(variable)`. // This type is only needed to be provided Type TermType } @@ -47,7 +47,7 @@ type SQLConfig struct { // } // "input\.object\.group_acl\.(.*)": SQLColumn{ // ColumnSelect: "group_acl->$1", - // Type: VarTypeJsonb + // Type: VarTypeJsonbTextArray // } Variables []SQLColumn } @@ -394,12 +394,14 @@ func (t opInternalMember2) Eval(object Object) bool { func (t opInternalMember2) SQLString(cfg SQLConfig) string { if haystack, ok := t.Haystack.(*termVariable); ok { // This is a special case where the haystack is a jsonb array. - // The more general way to solve this would be to implement a fuller type - // system and handle type conversions for each supported type. - // Then we could determine that the haystack is always an "array" and - // implement the "contains" function on the array type. - // But that requires a lot more code to handle a lot of cases we don't - // actually care about. + // There is a more general way to solve this, but that requires a lot + // more code to cover a lot more cases that we do not care about. + // To handle this more generally we should implement "Array" as a type. + // Then have the `contains` function on the Array type. This would defer + // knowing the element type to the Array and cover more cases without + // having to add more "if" branches here. + // But until we need more cases, our basic type system is ok, and + // this is the only case we need to handle. if haystack.SQLType(cfg) == VarTypeJsonbTextArray { return fmt.Sprintf("%s ? %s", haystack.SQLString(cfg), t.Needle.SQLString(cfg)) } @@ -433,7 +435,7 @@ func (t termString) SQLString(_ SQLConfig) string { return "'" + t.Value + "'" } -func (t termString) SQLType(_ SQLConfig) TermType { +func (termString) SQLType(_ SQLConfig) TermType { return VarTypeText } From f10e9b70a823932eeb9dc1081769138db82af3f0 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 28 Sep 2022 21:20:35 -0400 Subject: [PATCH 14/18] Remove unused argument --- coderd/coderdtest/authorize.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coderd/coderdtest/authorize.go b/coderd/coderdtest/authorize.go index 2549f9179e5d3..8677b305c1e20 100644 --- a/coderd/coderdtest/authorize.go +++ b/coderd/coderdtest/authorize.go @@ -521,7 +521,7 @@ func (r *RecordingAuthorizer) ByRoleNameSQL(_ context.Context, _ string, _ []str return r.AlwaysReturn } -func (r *RecordingAuthorizer) ByRoleName(ctx context.Context, subjectID string, roleNames []string, scope rbac.Scope, action rbac.Action, object rbac.Object) error { +func (r *RecordingAuthorizer) ByRoleName(_ context.Context, subjectID string, roleNames []string, scope rbac.Scope, action rbac.Action, object rbac.Object) error { r.Called = &authCall{ SubjectID: subjectID, Roles: roleNames, From d89c4d22976ce3d968c009070d03b9f022edb202 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 29 Sep 2022 15:07:41 -0400 Subject: [PATCH 15/18] Add query name for tracking --- coderd/database/custom_queries.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/coderd/database/custom_queries.go b/coderd/database/custom_queries.go index bbcfc3e6ba813..b8c004ebb0c41 100644 --- a/coderd/database/custom_queries.go +++ b/coderd/database/custom_queries.go @@ -18,7 +18,8 @@ type customQuerier interface { // This code is copied from `GetWorkspaces` and adds the authorized filter WHERE // clause. func (q *sqlQuerier) AuthorizedGetWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error) { - query := fmt.Sprintf("%s AND %s", getWorkspaces, authorizedFilter.SQLString(rbac.DefaultConfig())) + // The name comment is for metric tracking + query := fmt.Sprintf("-- name: AuthorizedGetWorkspaces :many\n%s AND %s", getWorkspaces, authorizedFilter.SQLString(rbac.DefaultConfig())) rows, err := q.db.QueryContext(ctx, query, arg.Deleted, arg.OwnerID, From 3828c29e580ef4095f7ab60bfafb8117a3f17055 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 30 Sep 2022 10:14:26 -0400 Subject: [PATCH 16/18] Handle nested types This solves it without proper types in our AST. Might bite the bullet and implement some better types --- coderd/rbac/query.go | 98 ++++++++++++++++++++++++++---- coderd/rbac/query_internal_test.go | 13 ++++ 2 files changed, 99 insertions(+), 12 deletions(-) diff --git a/coderd/rbac/query.go b/coderd/rbac/query.go index 40866d9c58fb6..0211a1ff538b0 100644 --- a/coderd/rbac/query.go +++ b/coderd/rbac/query.go @@ -56,12 +56,12 @@ func DefaultConfig() SQLConfig { return SQLConfig{ Variables: []SQLColumn{ { - RegoMatch: regexp.MustCompile(`^input\.object\.acl_group_list\.([^.]*)$`), + RegoMatch: regexp.MustCompile(`^input\.object\.acl_group_list\.?(.*)$`), ColumnSelect: "group_acl->$1", Type: VarTypeJsonbTextArray, }, { - RegoMatch: regexp.MustCompile(`^input\.object\.acl_user_list\.([^.]*)$`), + RegoMatch: regexp.MustCompile(`^input\.object\.acl_user_list\.?(.*)$`), ColumnSelect: "user_acl->$1", Type: VarTypeJsonbTextArray, }, @@ -226,19 +226,42 @@ func processTerm(term *ast.Term) (Term, error) { Value: bool(v), }, nil case ast.Ref: + obj := &termObject{ + base: base, + Variables: []termVariable{}, + } + var idx int // A ref is a set of terms. If the first term is a var, then the // following terms are the path to the value. - if v0, ok := v[0].Value.(ast.Var); ok { - name := trimQuotes(v0.String()) - for _, p := range v[1:] { - name += "." + trimQuotes(p.String()) + var builder strings.Builder + for _, term := range v { + if idx == 0 { + if _, ok := v[0].Value.(ast.Var); !ok { + return nil, xerrors.Errorf("invalid term (%s): ref must start with a var, started with %T", v[0].String(), v[0]) + } } - return &termVariable{ - base: base, - Name: name, - }, nil + + if _, ok := term.Value.(ast.Ref); ok { + // New obj + obj.Variables = append(obj.Variables, termVariable{ + base: base, + Name: builder.String(), + }) + builder.Reset() + idx = 0 + } + if builder.Len() != 0 { + builder.WriteString(".") + } + builder.WriteString(trimQuotes(term.String())) + idx++ } - return nil, xerrors.Errorf("invalid term: ref must start with a var, started with %T", v[0]) + + obj.Variables = append(obj.Variables, termVariable{ + base: base, + Name: builder.String(), + }) + return obj, nil case ast.Var: return &termVariable{ Name: trimQuotes(v.String()), @@ -392,7 +415,7 @@ func (t opInternalMember2) Eval(object Object) bool { } func (t opInternalMember2) SQLString(cfg SQLConfig) string { - if haystack, ok := t.Haystack.(*termVariable); ok { + if haystack, ok := t.Haystack.(*termObject); ok { // This is a special case where the haystack is a jsonb array. // There is a more general way to solve this, but that requires a lot // more code to cover a lot more cases that we do not care about. @@ -439,6 +462,57 @@ func (termString) SQLType(_ SQLConfig) TermType { return VarTypeText } +// termObject is a variable that can be dereferenced. We count some rego objects +// as single variables, eg: input.object.org_owner. In reality, it is a nested +// object. +// In rego, we can dereference the object with the "." operator, which we can +// handle with regex. +// Or we can dereference the object with the "[]", which we can handle with this +// term type. +type termObject struct { + base + Variables []termVariable +} + +func (t termObject) EvalTerm(obj Object) interface{} { + if len(t.Variables) == 0 { + return t.Variables[0].EvalTerm(obj) + } + panic("no nested structures are supported yet") +} + +func (t termObject) SQLType(cfg SQLConfig) TermType { + // Without a full type system, let's just assume the type of the first var + // is the resulting type. This is correct for our use case. + // Solving this more generally requires a full type system, which is + // excessive for our mostly static policy. + return t.Variables[0].SQLType(cfg) +} + +func (t termObject) SQLString(cfg SQLConfig) string { + if len(t.Variables) == 1 { + return t.Variables[0].SQLString(cfg) + } + // Combine the last 2 variables into 1 variable. + end := t.Variables[len(t.Variables)-1] + before := t.Variables[len(t.Variables)-2] + + return termObject{ + base: t.base, + Variables: append( + t.Variables[:len(t.Variables)-2], + termVariable{ + base: base{ + Rego: before.base.Rego + "[" + end.base.Rego + "]", + }, + // Convert the end to SQL string. We evaluate each term + // one at a time. + Name: before.Name + "." + end.SQLString(cfg), + }, + ), + }.SQLString(cfg) +} + type termVariable struct { base Name string diff --git a/coderd/rbac/query_internal_test.go b/coderd/rbac/query_internal_test.go index f7923062b18eb..92d8b91543953 100644 --- a/coderd/rbac/query_internal_test.go +++ b/coderd/rbac/query_internal_test.go @@ -76,4 +76,17 @@ func TestCompileQuery(t *testing.T) { `user_acl->me ? 'read')`, expression.SQLString(DefaultConfig()), "complex") }) + + t.Run("SetDereference", func(t *testing.T) { + t.Parallel() + expression, err := Compile(®o.PartialQueries{ + Queries: []ast.Body{ + ast.MustParseBodyWithOpts(`"*" in input.object.acl_group_list[input.object.org_owner]`, opts), + }, + Support: []*ast.Module{}, + }) + require.NoError(t, err, "compile") + require.Equal(t, `group_acl->organization_id :: text ? '*'`, + expression.SQLString(DefaultConfig()), "set dereference") + }) } From 913fb27c4b070e87bf3cf5f6901d00fe78731b50 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 30 Sep 2022 10:18:50 -0400 Subject: [PATCH 17/18] Add comment --- coderd/rbac/query.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/coderd/rbac/query.go b/coderd/rbac/query.go index 0211a1ff538b0..d8b1a140e9eb0 100644 --- a/coderd/rbac/query.go +++ b/coderd/rbac/query.go @@ -497,6 +497,8 @@ func (t termObject) SQLString(cfg SQLConfig) string { end := t.Variables[len(t.Variables)-1] before := t.Variables[len(t.Variables)-2] + // Recursively solve the SQLString by removing the last nested reference. + // This continues until we have a single variable. return termObject{ base: t.base, Variables: append( From 3e2fbb801831a84c2be6b449987b6e4f1f9a8bea Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 30 Sep 2022 12:19:36 -0400 Subject: [PATCH 18/18] Renaming function call to GetAuthorizedWorkspaces --- coderd/database/custom_queries.go | 8 ++++---- coderd/database/databasefake/databasefake.go | 4 ++-- coderd/workspaces.go | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/coderd/database/custom_queries.go b/coderd/database/custom_queries.go index b8c004ebb0c41..219a6cdb13c7b 100644 --- a/coderd/database/custom_queries.go +++ b/coderd/database/custom_queries.go @@ -11,15 +11,15 @@ import ( ) type customQuerier interface { - AuthorizedGetWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error) + GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error) } -// AuthorizedGetWorkspaces returns all workspaces that the user is authorized to access. +// GetAuthorizedWorkspaces returns all workspaces that the user is authorized to access. // This code is copied from `GetWorkspaces` and adds the authorized filter WHERE // clause. -func (q *sqlQuerier) AuthorizedGetWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error) { +func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error) { // The name comment is for metric tracking - query := fmt.Sprintf("-- name: AuthorizedGetWorkspaces :many\n%s AND %s", getWorkspaces, authorizedFilter.SQLString(rbac.DefaultConfig())) + query := fmt.Sprintf("-- name: GetAuthorizedWorkspaces :many\n%s AND %s", getWorkspaces, authorizedFilter.SQLString(rbac.DefaultConfig())) rows, err := q.db.QueryContext(ctx, query, arg.Deleted, arg.OwnerID, diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index d66cb662a6636..33a33cc8c7496 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -522,11 +522,11 @@ func (q *fakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.U func (q *fakeQuerier) GetWorkspaces(ctx context.Context, arg database.GetWorkspacesParams) ([]database.Workspace, error) { // A nil auth filter means no auth filter. - workspaces, err := q.AuthorizedGetWorkspaces(ctx, arg, nil) + workspaces, err := q.GetAuthorizedWorkspaces(ctx, arg, nil) return workspaces, err } -func (q *fakeQuerier) AuthorizedGetWorkspaces(_ context.Context, arg database.GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]database.Workspace, error) { +func (q *fakeQuerier) GetAuthorizedWorkspaces(_ context.Context, arg database.GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]database.Workspace, error) { q.mutex.RLock() defer q.mutex.RUnlock() diff --git a/coderd/workspaces.go b/coderd/workspaces.go index a20093130aef4..ff834aa6b246b 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -122,7 +122,7 @@ func (api *API) workspaces(rw http.ResponseWriter, r *http.Request) { return } - workspaces, err := api.Database.AuthorizedGetWorkspaces(ctx, filter, sqlFilter) + workspaces, err := api.Database.GetAuthorizedWorkspaces(ctx, filter, sqlFilter) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching workspaces.",