Skip to content

chore: Rewrite rbac rego -> SQL clause #5138

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Nov 28, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
More linting
  • Loading branch information
Emyrk committed Nov 21, 2022
commit a521616e44f35282d3f9d7d323d6c583528e0fda
2 changes: 1 addition & 1 deletion coderd/coderdtest/authorize.go
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Obje

// CompileToSQL returns a compiled version of the authorizer that will work for
// in memory databases. This fake version will not work against a SQL database.
func (f *fakePreparedAuthorizer) CompileToSQL(_ regosql.ConvertConfig) (string, error) {
func (_ fakePreparedAuthorizer) CompileToSQL(_ regosql.ConvertConfig) (string, error) {
return "", xerrors.New("not implemented")
}

Expand Down
2 changes: 1 addition & 1 deletion coderd/rbac/regosql/aclGroupVar.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (g ACLGroupVar) SQLString(cfg *sqltypes.SQLGenerator) string {
}

func (g ACLGroupVar) ContainsSQL(cfg *sqltypes.SQLGenerator, other sqltypes.Node) (string, error) {
//nolint:singleCaseSwitch
//nolint:gocritic
switch other.UseAs().(type) {
// Only supports containing other strings.
case sqltypes.AstString:
Expand Down
9 changes: 1 addition & 8 deletions coderd/rbac/regosql/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,11 @@ func ConvertRegoAst(cfg ConvertConfig, partial *rego.PartialQueries) (sqltypes.B
return nil, xerrors.Errorf("query %s: %w", q.String(), err)
}

// Each query should result in a boolean expression. If it is not,
// this cannot be converted to SQL.
boolConverted, ok := converted.(sqltypes.BooleanNode)
if !ok {
return nil, xerrors.Errorf("query %s: not a boolean expression", q.String())
}

if i != 0 {
builder.WriteString("\n")
}
builder.WriteString(q.String())
queries = append(queries, boolConverted)
queries = append(queries, converted)
}

// All queries are OR'd together. This means that if any query is true,
Expand Down
19 changes: 2 additions & 17 deletions coderd/rbac/regosql/compile_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package regosql_test

import (
"context"
"testing"

"github.com/open-policy-agent/opa/ast"
Expand All @@ -15,6 +14,8 @@ import (
// TestRegoQueriesNoVariables handles cases without variables. These should be
// very simple and straight forward.
func TestRegoQueries(t *testing.T) {
t.Parallel()

p := func(v string) string {
return "(" + v + ")"
}
Expand Down Expand Up @@ -299,22 +300,6 @@ func partialQueries(t *testing.T, queries ...string) *rego.PartialQueries {
astQueries = append(astQueries, ast.MustParseBodyWithOpts(q, opts))
}

prepareQueries := make([]rego.PreparedEvalQuery, 0, len(queries))
for _, q := range astQueries {
var prepped rego.PreparedEvalQuery
var err error
if q.String() == "" {
prepped, err = rego.New(
rego.Query("true"),
).PrepareForEval(context.Background())
} else {
prepped, err = rego.New(
rego.ParsedQuery(q),
).PrepareForEval(context.Background())
}
require.NoError(t, err, "prepare query")
prepareQueries = append(prepareQueries, prepped)
}
return &rego.PartialQueries{
Queries: astQueries,
Support: []*ast.Module{},
Expand Down
10 changes: 5 additions & 5 deletions coderd/rbac/regosql/sqltypes/alwaysFalse.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func AlwaysFalseNode(n Node) Node {
}

// UseAs uses a type no one supports to always override with false.
func (f alwaysFalse) UseAs() Node { return alwaysFalse{} }
func (alwaysFalse) UseAs() Node { return alwaysFalse{} }
func (f alwaysFalse) ConvertVariable(rego ast.Ref) (Node, bool) {
if f.Matcher != nil {
n, ok := f.Matcher.ConvertVariable(rego)
Expand All @@ -44,18 +44,18 @@ func (f alwaysFalse) ConvertVariable(rego ast.Ref) (Node, bool) {
return nil, false
}

func (f alwaysFalse) SQLString(_ *SQLGenerator) string {
func (alwaysFalse) SQLString(_ *SQLGenerator) string {
return "false"
}

func (f alwaysFalse) ContainsSQL(_ *SQLGenerator, _ Node) (string, error) {
func (alwaysFalse) ContainsSQL(_ *SQLGenerator, _ Node) (string, error) {
return "false", nil
}

func (f alwaysFalse) ContainedInSQL(_ *SQLGenerator, _ Node) (string, error) {
func (alwaysFalse) ContainedInSQL(_ *SQLGenerator, _ Node) (string, error) {
return "false", nil
}

func (f alwaysFalse) EqualsSQLString(_ *SQLGenerator, not bool, _ Node) (string, error) {
func (alwaysFalse) EqualsSQLString(_ *SQLGenerator, not bool, _ Node) (string, error) {
return "false", nil
}
2 changes: 1 addition & 1 deletion coderd/rbac/regosql/sqltypes/array.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (a ASTArray) ContainsSQL(cfg *SQLGenerator, needle Node) (string, error) {
// This condition supports any contains function if the needle type is
// the same as the ASTArray element type.
if reflect.TypeOf(a.MyType().UseAs()) != reflect.TypeOf(needle.UseAs()) {
return "ArrayContainsError", fmt.Errorf("array contains %q: type mismatch (%T, %T)",
return "ArrayContainsError", xerrors.Errorf("array contains %q: type mismatch (%T, %T)",
a.Source, a.MyType(), needle)
}

Expand Down
1 change: 0 additions & 1 deletion coderd/rbac/regosql/sqltypes/bool.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,4 @@ func (b AstBoolean) SQLString(_ *SQLGenerator) string {

func (b AstBoolean) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
return boolEqualsSQLString(cfg, b, not, other)

}
3 changes: 2 additions & 1 deletion coderd/rbac/regosql/sqltypes/equality.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (e equality) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (stri
}

func boolEqualsSQLString(cfg *SQLGenerator, a BooleanNode, not bool, other Node) (string, error) {
//nolint:singleCaseSwitch
//nolint:gocritic
switch other.UseAs().(type) {
case BooleanNode:
bn, ok := other.(BooleanNode)
Expand All @@ -86,6 +86,7 @@ func boolEqualsSQLString(cfg *SQLGenerator, a BooleanNode, not bool, other Node)
return "", xerrors.Errorf("unsupported equality: %T %s %T", a, equalsOp(not), other)
}

// nolint:revive
func equalsOp(not bool) string {
if not {
return "!="
Expand Down
1 change: 0 additions & 1 deletion coderd/rbac/regosql/sqltypes/equality_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,5 +127,4 @@ func TestEquality(t *testing.T) {
}
})
}

}
2 changes: 1 addition & 1 deletion coderd/rbac/regosql/sqltypes/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type invalidNode struct{}

func (invalidNode) UseAs() Node { return invalidNode{} }

func (i invalidNode) SQLString(cfg *SQLGenerator) string {
func (invalidNode) SQLString(cfg *SQLGenerator) string {
cfg.AddError(xerrors.Errorf("invalid node called"))
return "invalid_type"
}
Expand Down
2 changes: 1 addition & 1 deletion coderd/rbac/regosql/sqltypes/number.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (n AstNumber) SQLString(_ *SQLGenerator) string {
}

func (n AstNumber) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
//nolint:singleCaseSwitch
//nolint:gocritic
switch other.UseAs().(type) {
case AstNumber:
return basicSQLEquality(cfg, not, n, other), nil
Expand Down
2 changes: 1 addition & 1 deletion coderd/rbac/regosql/sqltypes/string.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func (s AstString) SQLString(_ *SQLGenerator) string {
}

func (s AstString) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
//nolint:singleCaseSwitch
//nolint:gocritic
switch other.UseAs().(type) {
case AstString:
return basicSQLEquality(cfg, not, s, other), nil
Expand Down
2 changes: 1 addition & 1 deletion coderd/rbac/regosql/sqltypes/variable.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (s astStringVar) SQLString(_ *SQLGenerator) string {
}

func (s astStringVar) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
//nolint:singleCaseSwitch
//nolint:gocritic
switch other.UseAs().(type) {
case AstString:
return basicSQLEquality(cfg, not, s, other), nil
Expand Down