Skip to content

Commit f1ccd71

Browse files
committed
chore: unit test to enforce authorized queries match args
1 parent 0cd4842 commit f1ccd71

File tree

1 file changed

+116
-0
lines changed

1 file changed

+116
-0
lines changed
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
package gentest
2+
3+
import (
4+
"go/ast"
5+
"go/parser"
6+
"go/token"
7+
"testing"
8+
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
"golang.org/x/exp/slices"
12+
)
13+
14+
// TestCustomQueriesSynced makes sure the manual custom queries in modelqueries.go
15+
// are synced with the autogenerated queries.sql.go. This should probably be
16+
// autogenerated, but it's not atm and this is easy to throw in to elevate a better
17+
// error message.
18+
//
19+
// If this breaks, and is hard to fix, you can t.Skip() it. It is not a critical
20+
// test. Ping @Emyrk to fix it again.
21+
func TestCustomQueriesSyncedRowScan(t *testing.T) {
22+
t.Parallel()
23+
24+
funcsToTrack := map[string]string{
25+
"GetTemplates": "GetAuthorizedTemplates",
26+
"GetWorkspaces": "GetAuthorizedWorkspaces",
27+
"GetUsers": "GetAuthorizedUsers",
28+
}
29+
parsedRowScanArgs := make(map[string][]ast.Expr)
30+
31+
// Scan custom
32+
var custom []string
33+
for _, fn := range funcsToTrack {
34+
custom = append(custom, fn)
35+
}
36+
fset := token.NewFileSet()
37+
f, err := parser.ParseFile(fset, "../modelqueries.go", nil, parser.SkipObjectResolution)
38+
require.NoError(t, err)
39+
40+
for _, decl := range f.Decls {
41+
if fn, ok := decl.(*ast.FuncDecl); ok {
42+
if slices.Contains(custom, fn.Name.Name) {
43+
parsedRowScanArgs[fn.Name.String()] = pullRowScanArgs(fn)
44+
}
45+
}
46+
}
47+
48+
fset = token.NewFileSet()
49+
f, err = parser.ParseFile(fset, "../queries.sql.go", nil, parser.SkipObjectResolution)
50+
require.NoError(t, err)
51+
52+
for _, decl := range f.Decls {
53+
if fn, ok := decl.(*ast.FuncDecl); ok {
54+
if _, ok := funcsToTrack[fn.Name.Name]; ok {
55+
parsedRowScanArgs[fn.Name.String()] = pullRowScanArgs(fn)
56+
}
57+
}
58+
}
59+
60+
for a, b := range funcsToTrack {
61+
if !compareArgs(t, a, b, parsedRowScanArgs[a], parsedRowScanArgs[b]) {
62+
defer func() {
63+
// Run this at the end so the suggested fix is the last thing printed.
64+
t.Errorf("The functions %q and %q need to have identical 'rows.Scan()' in their function bodies. "+
65+
"Make sure to copy the function body from the autogenerated %q body. "+
66+
"Specifically the parameters for 'rows.Scan()'.", a, b, a)
67+
}()
68+
}
69+
}
70+
71+
}
72+
73+
func compareArgs(t *testing.T, aName, bName string, a, b []ast.Expr) bool {
74+
if a == nil {
75+
t.Errorf("The function %q is missing", aName)
76+
return false
77+
}
78+
if b == nil {
79+
t.Errorf("The function %q is missing", bName)
80+
return false
81+
}
82+
return assert.Equal(t, argList(t, a), argList(t, b), "mismatched args for %s and %s", aName, bName)
83+
}
84+
85+
func argList(t *testing.T, args []ast.Expr) []string {
86+
defer func() {
87+
if r := recover(); r != nil {
88+
t.Errorf("Recovered in f reading arg names: %s", r)
89+
}
90+
}()
91+
92+
var argNames []string
93+
for _, arg := range args {
94+
sel := arg.(*ast.UnaryExpr).X.(*ast.SelectorExpr).Sel
95+
argNames = append(argNames, sel.Name)
96+
}
97+
return argNames
98+
}
99+
100+
func pullRowScanArgs(fn *ast.FuncDecl) []ast.Expr {
101+
for _, exp := range fn.Body.List {
102+
if forStmt, ok := exp.(*ast.ForStmt); ok {
103+
// This came from the debugger window and tracking it down.
104+
rowScan := (forStmt.Body.
105+
// Second statement in the for loop is the if statement
106+
// with rows.can
107+
List[1].(*ast.IfStmt).
108+
// This is the err := rows.Scan()
109+
Init.(*ast.AssignStmt).
110+
// Rhs is the row.Scan part
111+
Rhs)[0].(*ast.CallExpr)
112+
return rowScan.Args
113+
}
114+
}
115+
return nil
116+
}

0 commit comments

Comments
 (0)