Skip to content

Commit 3f6096b

Browse files
authored
chore: unit test to enforce authorized queries match args (#11211)
* chore: unit test to enforce authorized queries match args * Also check querycontext arguments
1 parent 7924bb2 commit 3f6096b

File tree

1 file changed

+182
-0
lines changed

1 file changed

+182
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
package gentest_test
2+
3+
import (
4+
"fmt"
5+
"go/ast"
6+
"go/parser"
7+
"go/token"
8+
"testing"
9+
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
"golang.org/x/exp/slices"
13+
)
14+
15+
// TestCustomQueriesSynced makes sure the manual custom queries in modelqueries.go
16+
// are synced with the autogenerated queries.sql.go. This should probably be
17+
// autogenerated, but it's not atm and this is easy to throw in to elevate a better
18+
// error message.
19+
//
20+
// If this breaks, and is hard to fix, you can t.Skip() it. It is not a critical
21+
// test. Ping @Emyrk to fix it again.
22+
func TestCustomQueriesSyncedRowScan(t *testing.T) {
23+
t.Parallel()
24+
25+
funcsToTrack := map[string]string{
26+
"GetTemplatesWithFilter": "GetAuthorizedTemplates",
27+
"GetWorkspaces": "GetAuthorizedWorkspaces",
28+
"GetUsers": "GetAuthorizedUsers",
29+
}
30+
31+
// Scan custom
32+
var custom []string
33+
for _, fn := range funcsToTrack {
34+
custom = append(custom, fn)
35+
}
36+
37+
customFns := parseFile(t, "../modelqueries.go", func(name string) bool {
38+
return slices.Contains(custom, name)
39+
})
40+
generatedFns := parseFile(t, "../queries.sql.go", func(name string) bool {
41+
_, ok := funcsToTrack[name]
42+
return ok
43+
})
44+
merged := customFns
45+
for k, v := range generatedFns {
46+
merged[k] = v
47+
}
48+
49+
for a, b := range funcsToTrack {
50+
a, b := a, b
51+
if !compareFns(t, a, b, merged[a], merged[b]) {
52+
//nolint:revive
53+
defer func() {
54+
// Run this at the end so the suggested fix is the last thing printed.
55+
t.Errorf("The functions %q and %q need to have identical 'rows.Scan()' "+
56+
"and 'db.QueryContext()' arguments in their function bodies. "+
57+
"Make sure to copy the function body from the autogenerated %q body. "+
58+
"Specifically the parameters for 'rows.Scan()' and 'db.QueryContext()'.", a, b, a)
59+
}()
60+
}
61+
}
62+
}
63+
64+
type parsedFunc struct {
65+
RowScanArgs []ast.Expr
66+
QueryArgs []ast.Expr
67+
}
68+
69+
func parseFile(t *testing.T, filename string, trackFunc func(name string) bool) map[string]*parsedFunc {
70+
fset := token.NewFileSet()
71+
f, err := parser.ParseFile(fset, filename, nil, parser.SkipObjectResolution)
72+
require.NoErrorf(t, err, "failed to parse file %q", filename)
73+
74+
parsed := make(map[string]*parsedFunc)
75+
for _, decl := range f.Decls {
76+
if fn, ok := decl.(*ast.FuncDecl); ok {
77+
if trackFunc(fn.Name.Name) {
78+
parsed[fn.Name.String()] = &parsedFunc{
79+
RowScanArgs: pullRowScanArgs(fn),
80+
QueryArgs: pullQueryArgs(fn),
81+
}
82+
}
83+
}
84+
}
85+
86+
return parsed
87+
}
88+
89+
func compareFns(t *testing.T, aName, bName string, a, b *parsedFunc) bool {
90+
if a == nil {
91+
t.Errorf("The function %q is missing", aName)
92+
return false
93+
}
94+
if b == nil {
95+
t.Errorf("The function %q is missing", bName)
96+
return false
97+
}
98+
r := compareArgs(t, "rows.Scan() arguments", aName, bName, a.RowScanArgs, b.RowScanArgs)
99+
if len(a.QueryArgs) > 2 && len(b.QueryArgs) > 2 {
100+
// This is because the actual query param name is different. One uses the
101+
// const, the other uses a variable that is a mutation of the original query.
102+
a.QueryArgs[1] = b.QueryArgs[1]
103+
}
104+
q := compareArgs(t, "db.QueryContext() arguments", aName, bName, a.QueryArgs, b.QueryArgs)
105+
return r && q
106+
}
107+
108+
func compareArgs(t *testing.T, argType string, aName, bName string, a, b []ast.Expr) bool {
109+
return assert.Equal(t, argList(t, a), argList(t, b), "mismatched %s for %s and %s", argType, aName, bName)
110+
}
111+
112+
func argList(t *testing.T, args []ast.Expr) []string {
113+
defer func() {
114+
if r := recover(); r != nil {
115+
t.Errorf("Recovered in f reading arg names: %s", r)
116+
}
117+
}()
118+
119+
var argNames []string
120+
for _, arg := range args {
121+
argname := "unknown"
122+
// This is "&i.Arg" style stuff
123+
if unary, ok := arg.(*ast.UnaryExpr); ok {
124+
argname = unary.X.(*ast.SelectorExpr).Sel.Name
125+
}
126+
if ident, ok := arg.(*ast.Ident); ok {
127+
argname = ident.Name
128+
}
129+
if sel, ok := arg.(*ast.SelectorExpr); ok {
130+
argname = sel.Sel.Name
131+
}
132+
if call, ok := arg.(*ast.CallExpr); ok {
133+
// Eh, this is pg.Array style stuff. Do a best effort.
134+
argname = fmt.Sprintf("call(%d)", len(call.Args))
135+
if fnCall, ok := call.Fun.(*ast.SelectorExpr); ok {
136+
argname = fmt.Sprintf("%s(%d)", fnCall.Sel.Name, len(call.Args))
137+
}
138+
}
139+
140+
if argname == "unknown" {
141+
t.Errorf("Unknown arg, cannot parse: %T", arg)
142+
}
143+
argNames = append(argNames, argname)
144+
}
145+
return argNames
146+
}
147+
148+
func pullQueryArgs(fn *ast.FuncDecl) []ast.Expr {
149+
for _, exp := range fn.Body.List {
150+
// find "rows, err :="
151+
if assign, ok := exp.(*ast.AssignStmt); ok {
152+
if len(assign.Lhs) == 2 {
153+
if id, ok := assign.Lhs[0].(*ast.Ident); ok && id.Name == "rows" {
154+
// This is rows, err :=
155+
query := assign.Rhs[0].(*ast.CallExpr)
156+
if qSel, ok := query.Fun.(*ast.SelectorExpr); ok && qSel.Sel.Name == "QueryContext" {
157+
return query.Args
158+
}
159+
}
160+
}
161+
}
162+
}
163+
return nil
164+
}
165+
166+
func pullRowScanArgs(fn *ast.FuncDecl) []ast.Expr {
167+
for _, exp := range fn.Body.List {
168+
if forStmt, ok := exp.(*ast.ForStmt); ok {
169+
// This came from the debugger window and tracking it down.
170+
rowScan := (forStmt.Body.
171+
// Second statement in the for loop is the if statement
172+
// with rows.can
173+
List[1].(*ast.IfStmt).
174+
// This is the err := rows.Scan()
175+
Init.(*ast.AssignStmt).
176+
// Rhs is the row.Scan part
177+
Rhs)[0].(*ast.CallExpr)
178+
return rowScan.Args
179+
}
180+
}
181+
return nil
182+
}

0 commit comments

Comments
 (0)