Skip to content

Commit 300dc4c

Browse files
committed
fixup! Add unit tests for equality and membership
1 parent 6c89f1d commit 300dc4c

File tree

4 files changed

+121
-11
lines changed

4 files changed

+121
-11
lines changed

coderd/rbac/regosql/doc.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
// Package regosql converts rego queries into SQL WHERE clauses. This is so
2+
// the rego queries can be used to filter the results of a SQL query.
3+
package regosql

coderd/rbac/regosql/sqltypes/doc.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// Package sqltypes contains the types used to convert rego queries into SQL.
2+
// The rego ast is converted into these types to better control the SQL
3+
// generation. It allows writing the SQL generation for types in an easier to
4+
// read way.
5+
package sqltypes

coderd/rbac/regosql/sqltypes/equality_test.go

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,36 +45,54 @@ func TestEquality(t *testing.T) {
4545
Equality: sqltypes.Equality(false,
4646
sqltypes.Bool(true),
4747
sqltypes.Equality(true,
48-
sqltypes.String("foo"),
49-
sqltypes.String("bar"),
48+
sqltypes.Equality(true,
49+
sqltypes.String("foo"),
50+
sqltypes.String("bar"),
51+
),
52+
sqltypes.Bool(false),
5053
),
5154
),
52-
ExpectedSQL: "true = ('foo' != 'bar')",
55+
ExpectedSQL: "true = (('foo' != 'bar') != false)",
5356
},
5457
{
55-
Name: "String=Equality",
58+
Name: "Equality=Equality",
5659
Equality: sqltypes.Equality(false,
57-
sqltypes.Bool(true),
60+
sqltypes.Equality(true,
61+
sqltypes.Bool(true),
62+
sqltypes.Bool(false),
63+
),
5864
sqltypes.Equality(false,
5965
sqltypes.String("foo"),
6066
sqltypes.String("foo"),
6167
),
6268
),
63-
ExpectedSQL: "true = ('foo' = 'foo')",
69+
ExpectedSQL: "(true != false) = ('foo' = 'foo')",
6470
},
6571
{
66-
Name: "Equality=Equality",
72+
Name: "Membership=Membership",
6773
Equality: sqltypes.Equality(false,
6874
sqltypes.Equality(true,
69-
sqltypes.Bool(true),
75+
sqltypes.MemberOf(
76+
sqltypes.String("foo"),
77+
must(sqltypes.Array("",
78+
sqltypes.String("foo"),
79+
sqltypes.String("bar"),
80+
)),
81+
),
7082
sqltypes.Bool(false),
7183
),
7284
sqltypes.Equality(false,
73-
sqltypes.String("foo"),
74-
sqltypes.String("foo"),
85+
sqltypes.Bool(true),
86+
sqltypes.MemberOf(
87+
sqltypes.Number("", "2"),
88+
must(sqltypes.Array("",
89+
sqltypes.Number("", "5"),
90+
sqltypes.Number("", "2"),
91+
)),
92+
),
7593
),
7694
),
77-
ExpectedSQL: "(true != false) = ('foo' = 'foo')",
95+
ExpectedSQL: "(('foo' = ANY(ARRAY ['foo','bar'])) != false) = (true = (2 = ANY(ARRAY [5,2])))",
7896
},
7997
}
8098

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
package sqltypes_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/coder/coder/coderd/rbac/regosql/sqltypes"
7+
"github.com/stretchr/testify/require"
8+
)
9+
10+
func TestMembership(t *testing.T) {
11+
t.Parallel()
12+
testCases := []struct {
13+
Name string
14+
Membership sqltypes.Node
15+
ExpectedSQL string
16+
ExpectedErrors int
17+
}{
18+
{
19+
Name: "StringArray",
20+
Membership: sqltypes.MemberOf(
21+
sqltypes.String("foo"),
22+
must(sqltypes.Array("",
23+
sqltypes.String("bar"),
24+
sqltypes.String("buzz"),
25+
)),
26+
),
27+
ExpectedSQL: "'foo' = ANY(ARRAY ['bar','buzz'])",
28+
},
29+
{
30+
Name: "NumberArray",
31+
Membership: sqltypes.MemberOf(
32+
sqltypes.Number("", "5"),
33+
must(sqltypes.Array("",
34+
sqltypes.Number("", "2"),
35+
sqltypes.Number("", "5"),
36+
)),
37+
),
38+
ExpectedSQL: "5 = ANY(ARRAY [2,5])",
39+
},
40+
{
41+
Name: "BoolArray",
42+
Membership: sqltypes.MemberOf(
43+
sqltypes.Bool(true),
44+
must(sqltypes.Array("",
45+
sqltypes.Bool(false),
46+
sqltypes.Bool(true),
47+
)),
48+
),
49+
ExpectedSQL: "true = ANY(ARRAY [false,true])",
50+
},
51+
52+
// Errors
53+
{
54+
Name: "Unsupported",
55+
Membership: sqltypes.MemberOf(
56+
sqltypes.Bool(true),
57+
sqltypes.Bool(true),
58+
),
59+
ExpectedErrors: 1,
60+
},
61+
}
62+
63+
for _, tc := range testCases {
64+
tc := tc
65+
t.Run(tc.Name, func(t *testing.T) {
66+
t.Parallel()
67+
68+
gen := sqltypes.NewSQLGenerator()
69+
found := tc.Membership.SQLString(gen)
70+
if tc.ExpectedErrors > 0 {
71+
require.Equal(t, tc.ExpectedErrors, len(gen.Errors()), "expected AstNumber of errors")
72+
} else {
73+
require.Equal(t, tc.ExpectedSQL, found, "expected sql")
74+
}
75+
})
76+
}
77+
}
78+
79+
func must[V any](v V, err error) V {
80+
if err != nil {
81+
panic(err)
82+
}
83+
return v
84+
}

0 commit comments

Comments
 (0)