Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Do not export extra types
  • Loading branch information
Emyrk committed Sep 27, 2022
commit cb5d5198bf02d569a2a69616adda99438dcd219f
132 changes: 80 additions & 52 deletions coderd/rbac/query.go
Original file line number Diff line number Diff line change
@@ -1,25 +1,41 @@
package rbac

import (
"context"
"fmt"
"strconv"
"strings"

"github.com/open-policy-agent/opa/ast"

"golang.org/x/xerrors"

"github.com/open-policy-agent/opa/rego"
"golang.org/x/xerrors"
)

// Example python: https://github.com/open-policy-agent/contrib/tree/main/data_filter_example
//

func Compile(ctx context.Context, partialQueries *rego.PartialQueries) (Expression, error) {
// Compile will convert a rego query AST into our custom types. The output is
// an AST that can be used to generate SQL.
func Compile(partialQueries *rego.PartialQueries) (Expression, error) {
if len(partialQueries.Support) > 0 {
return nil, xerrors.Errorf("cannot convert support rules, expect 0 found %d", len(partialQueries.Support))
}

// 0 queries means the result is "undefined". This is the same as "false".
if len(partialQueries.Queries) == 0 {
return &termBoolean{
base: base{Rego: "false"},
Value: false,
}, nil
}

// Abort early if any of the "OR"'d expressions are the empty string.
// This is the same as "true".
for _, query := range partialQueries.Queries {
if query.String() == "" {
return &termBoolean{
base: base{Rego: "true"},
Value: true,
}, nil
}
}

result := make([]Expression, 0, len(partialQueries.Queries))
var builder strings.Builder
for i := range partialQueries.Queries {
Expand All @@ -33,14 +49,16 @@ func Compile(ctx context.Context, partialQueries *rego.PartialQueries) (Expressi
}
builder.WriteString(partialQueries.Queries[i].String())
}
return ExpOr{
Base: Base{
return expOr{
base: base{
Rego: builder.String(),
},
Expressions: result,
}, nil
}

// processQuery processes an entire set of expressions and joins them with
// "AND".
func processQuery(query ast.Body) (Expression, error) {
expressions := make([]Expression, 0, len(query))
for _, astExpr := range query {
Expand All @@ -51,8 +69,8 @@ func processQuery(query ast.Body) (Expression, error) {
expressions = append(expressions, expr)
}

return ExpAnd{
Base: Base{
return expAnd{
base: base{
Rego: query.String(),
},
Expressions: expressions,
Expand All @@ -65,15 +83,15 @@ func processExpression(expr *ast.Expr) (Expression, error) {
}

op := expr.Operator().String()
base := Base{Rego: op}
base := base{Rego: op}
switch op {
case "neq", "eq", "equal":
terms, err := processTerms(2, expr.Operands())
if err != nil {
return nil, xerrors.Errorf("invalid '%s' expression: %w", op, err)
}
return &OpEqual{
Base: base,
return &opEqual{
base: base,
Terms: [2]Term{terms[0], terms[1]},
Not: op == "neq",
}, nil
Expand All @@ -82,12 +100,10 @@ func processExpression(expr *ast.Expr) (Expression, error) {
if err != nil {
return nil, xerrors.Errorf("invalid '%s' expression: %w", op, err)
}
return &OpInternalMember2{
Base: base,
return &opInternalMember2{
base: base,
Terms: [2]Term{terms[0], terms[1]},
}, nil

//case "eq", "equal":
default:
return nil, xerrors.Errorf("invalid expression: operator %s not supported", op)
}
Expand All @@ -111,7 +127,7 @@ func processTerms(expected int, terms []*ast.Term) ([]Term, error) {
}

func processTerm(term *ast.Term) (Term, error) {
base := Base{Rego: term.String()}
base := base{Rego: term.String()}
switch v := term.Value.(type) {
case ast.Ref:
// A ref is a set of terms. If the first term is a var, then the
Expand All @@ -121,70 +137,70 @@ func processTerm(term *ast.Term) (Term, error) {
for _, p := range v[1:] {
name += "." + trimQuotes(p.String())
}
return &TermVariable{
Base: base,
return &termVariable{
base: base,
Name: name,
}, nil
} else {
return nil, xerrors.Errorf("invalid term: ref must start with a var, started with %T", v[0])
}
case ast.Var:
return &TermVariable{
return &termVariable{
Name: trimQuotes(v.String()),
Base: base,
base: base,
}, nil
case ast.String:
return &TermString{
return &termString{
Value: trimQuotes(v.String()),
Base: base,
base: base,
}, nil
case ast.Set:
return &TermSet{
return &termSet{
Value: v,
Base: base,
base: base,
}, nil
default:
return nil, xerrors.Errorf("invalid term: %T not supported, %q", v, term.String())
}
}

type Base struct {
type base struct {
// Rego is the original rego string
Rego string
}

func (b Base) RegoString() string {
func (b base) RegoString() string {
return b.Rego
}

// Expression comprises a set of terms, operators, and functions. All
// expressions return a boolean value.
//
// Eg: neq(input.object.org_owner, "")
// Eg: neq(input.object.org_owner, "") AND input.object.org_owner == "foo"
type Expression interface {
RegoString() string
SQLString() string
}

type ExpAnd struct {
Base
type expAnd struct {
base
Expressions []Expression
}

func (t ExpAnd) SQLString() string {
func (t expAnd) SQLString() string {
exprs := make([]string, 0, len(t.Expressions))
for _, expr := range t.Expressions {
exprs = append(exprs, expr.SQLString())
}
return "(" + strings.Join(exprs, " AND ") + ")"
}

type ExpOr struct {
Base
type expOr struct {
base
Expressions []Expression
}

func (t ExpOr) SQLString() string {
func (t expOr) SQLString() string {
exprs := make([]string, 0, len(t.Expressions))
for _, expr := range t.Expressions {
exprs = append(exprs, expr.SQLString())
Expand All @@ -202,27 +218,29 @@ type Operator interface {
SQLString() string
}

type OpEqual struct {
Base
type opEqual struct {
base
Terms [2]Term
// For NotEqual
Not bool
}

func (t OpEqual) SQLString() string {
func (t opEqual) SQLString() string {
op := "="
if t.Not {
op = "!="
}
return fmt.Sprintf("%s %s %s", t.Terms[0].SQLString(), op, t.Terms[1].SQLString())
}

type OpInternalMember2 struct {
Base
// opInternalMember2 is checking if the first term is a member of the second term.
// The second term is a set or list.
type opInternalMember2 struct {
base
Terms [2]Term
}

func (t OpInternalMember2) SQLString() string {
func (t opInternalMember2) SQLString() string {
return fmt.Sprintf("%s = ANY(%s)", t.Terms[0].SQLString(), t.Terms[1].SQLString())
}

Expand All @@ -235,30 +253,31 @@ type Term interface {
RegoString() string
}

type TermString struct {
Base
type termString struct {
base
Value string
}

func (t TermString) SQLString() string {
func (t termString) SQLString() string {
return "'" + t.Value + "'"
}

type TermVariable struct {
Base
type termVariable struct {
base
Name string
}

func (t TermVariable) SQLString() string {
func (t termVariable) SQLString() string {
return t.Name
}

type TermSet struct {
Base
// termSet is a set of unique terms.
type termSet struct {
base
Value ast.Set
}

func (t TermSet) SQLString() string {
func (t termSet) SQLString() string {
values := t.Value.Slice()
elems := make([]string, 0, len(values))
// TODO: Handle different typed terms?
Expand All @@ -273,6 +292,15 @@ func (t TermSet) SQLString() string {
return fmt.Sprintf("ARRAY [%s]", strings.Join(elems, ","))
}

type termBoolean struct {
base
Value bool
}

func (t termBoolean) SQLString() string {
return strconv.FormatBool(t.Value)
}

func trimQuotes(s string) string {
return strings.Trim(s, "\"")
}