Skip to content

Commit d8fce5e

Browse files
committed
chore: Rewrite rbac rego -> SQL clause
Previous code was challenging to read with edge cases
1 parent 5fa3fde commit d8fce5e

File tree

15 files changed

+1555
-0
lines changed

15 files changed

+1555
-0
lines changed

coderd/rbac/regosql/aclGroupVar.go

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
package regosql
2+
3+
import (
4+
"fmt"
5+
6+
"github.com/open-policy-agent/opa/ast"
7+
8+
"github.com/coder/coder/coderd/rbac/regosql/sqltypes"
9+
)
10+
11+
var _ sqltypes.VariableMatcher = ACLGroupVar{}
12+
13+
// ACLGroupVar is also the Node type to reduce the number of types that we need
14+
// to export.
15+
var _ sqltypes.Node = ACLGroupVar{}
16+
17+
// ACLGroupVar is a variable matcher that handles group_acl and user_acl.
18+
// The sql type is a jsonb object with the following structure:
19+
//
20+
// "group_acl": {
21+
// "<group_name>": ["<actions>"]
22+
// }
23+
//
24+
// This is a custom variable matcher as json objects have arbitrary complexity.
25+
type ACLGroupVar struct {
26+
StructSQL string
27+
StructPath []string
28+
// DenyAll is helpful for when we don't care about ACL groups.
29+
// We need to default to denying access.
30+
DenyAll bool
31+
32+
// FieldReference handles referencing the subfields, which could be
33+
// more variables. We pass one in as the global one might not be correctly
34+
// scoped.
35+
FieldReference sqltypes.VariableMatcher
36+
37+
// Instance fields
38+
Source sqltypes.RegoSource
39+
GroupNode sqltypes.Node
40+
}
41+
42+
func ACLGroupMatcher(fieldRefernce sqltypes.VariableMatcher, structSQL string, structPath []string) ACLGroupVar {
43+
return ACLGroupVar{StructSQL: structSQL, StructPath: structPath, FieldReference: fieldRefernce}
44+
}
45+
46+
func (ACLGroupVar) UseAs() sqltypes.Node { return ACLGroupVar{} }
47+
48+
// Disable is a helper to disable the ACL group matching in the SQL generation.
49+
// This is because some tables do not have ACL columns, and in this case we
50+
// do not want to grant access based on columns that do not exist.
51+
// This replaces any clause with "group_acl" or "user_acl" with "false".
52+
func (g *ACLGroupVar) Disable() *ACLGroupVar {
53+
g.DenyAll = true
54+
return g
55+
}
56+
57+
func (g ACLGroupVar) ConvertVariable(rego ast.Ref) (sqltypes.Node, bool) {
58+
// "left" will be a map of group names to actions in rego.
59+
// {
60+
// "all_users": ["read"]
61+
// }
62+
left, err := sqltypes.RegoVarPath(g.StructPath, rego)
63+
if err != nil {
64+
return nil, false
65+
}
66+
67+
aclGrp := ACLGroupVar{
68+
DenyAll: g.DenyAll,
69+
StructSQL: g.StructSQL,
70+
StructPath: g.StructPath,
71+
FieldReference: g.FieldReference,
72+
73+
Source: sqltypes.RegoSource(rego.String()),
74+
}
75+
76+
// We expect 1 more term. Either a ref or a string.
77+
if len(left) != 1 {
78+
return nil, false
79+
}
80+
81+
// If the remaining is a variable, then we need to convert it.
82+
// Assuming we support variable fields.
83+
ref, ok := left[0].Value.(ast.Ref)
84+
if ok && g.FieldReference != nil {
85+
groupNode, ok := g.FieldReference.ConvertVariable(ref)
86+
if ok {
87+
aclGrp.GroupNode = groupNode
88+
return aclGrp, true
89+
}
90+
}
91+
92+
// If it is a string, we assume it is a literal
93+
groupName, ok := left[0].Value.(ast.String)
94+
if ok {
95+
aclGrp.GroupNode = sqltypes.String(string(groupName))
96+
return aclGrp, true
97+
}
98+
99+
// If we have not matched it yet, then it is something we do not recognize.
100+
return nil, false
101+
}
102+
103+
func (g ACLGroupVar) SQLString(cfg *sqltypes.SQLGenerator) string {
104+
if g.DenyAll {
105+
return "false"
106+
}
107+
return fmt.Sprintf("%s->%s", g.StructSQL, g.GroupNode.SQLString(cfg))
108+
}
109+
110+
func (g ACLGroupVar) ContainsSQL(cfg *sqltypes.SQLGenerator, other sqltypes.Node) (string, error) {
111+
if g.DenyAll {
112+
return "false", nil
113+
}
114+
115+
switch other.UseAs().(type) {
116+
// Only supports containing other strings.
117+
case sqltypes.AstString:
118+
return fmt.Sprintf("%s ? %s", g.SQLString(cfg), other.SQLString(cfg)), nil
119+
}
120+
121+
return "", fmt.Errorf("unsupported acl group contains %T", other)
122+
}

0 commit comments

Comments
 (0)