Skip to content

Commit c1efec1

Browse files
committed
Also check querycontext arguments
1 parent f1ccd71 commit c1efec1

File tree

1 file changed

+94
-29
lines changed

1 file changed

+94
-29
lines changed

coderd/database/gentest/modelqueries_test.go

+94-29
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package gentest
22

33
import (
4+
"fmt"
45
"go/ast"
56
"go/parser"
67
"go/token"
@@ -26,51 +27,65 @@ func TestCustomQueriesSyncedRowScan(t *testing.T) {
2627
"GetWorkspaces": "GetAuthorizedWorkspaces",
2728
"GetUsers": "GetAuthorizedUsers",
2829
}
29-
parsedRowScanArgs := make(map[string][]ast.Expr)
3030

3131
// Scan custom
3232
var custom []string
3333
for _, fn := range funcsToTrack {
3434
custom = append(custom, fn)
3535
}
36-
fset := token.NewFileSet()
37-
f, err := parser.ParseFile(fset, "../modelqueries.go", nil, parser.SkipObjectResolution)
38-
require.NoError(t, err)
3936

40-
for _, decl := range f.Decls {
41-
if fn, ok := decl.(*ast.FuncDecl); ok {
42-
if slices.Contains(custom, fn.Name.Name) {
43-
parsedRowScanArgs[fn.Name.String()] = pullRowScanArgs(fn)
44-
}
45-
}
46-
}
47-
48-
fset = token.NewFileSet()
49-
f, err = parser.ParseFile(fset, "../queries.sql.go", nil, parser.SkipObjectResolution)
50-
require.NoError(t, err)
51-
52-
for _, decl := range f.Decls {
53-
if fn, ok := decl.(*ast.FuncDecl); ok {
54-
if _, ok := funcsToTrack[fn.Name.Name]; ok {
55-
parsedRowScanArgs[fn.Name.String()] = pullRowScanArgs(fn)
56-
}
57-
}
37+
customFns := parseFile(t, "../modelqueries.go", func(name string) bool {
38+
return slices.Contains(custom, name)
39+
})
40+
generatedFns := parseFile(t, "../queries.sql.go", func(name string) bool {
41+
_, ok := funcsToTrack[name]
42+
return ok
43+
})
44+
merged := customFns
45+
for k, v := range generatedFns {
46+
merged[k] = v
5847
}
5948

6049
for a, b := range funcsToTrack {
61-
if !compareArgs(t, a, b, parsedRowScanArgs[a], parsedRowScanArgs[b]) {
50+
if !compareFns(t, a, b, merged[a], merged[b]) {
6251
defer func() {
6352
// Run this at the end so the suggested fix is the last thing printed.
64-
t.Errorf("The functions %q and %q need to have identical 'rows.Scan()' in their function bodies. "+
53+
t.Errorf("The functions %q and %q need to have identical 'rows.Scan()' "+
54+
"and 'db.QueryContext()' arguments in their function bodies. "+
6555
"Make sure to copy the function body from the autogenerated %q body. "+
66-
"Specifically the parameters for 'rows.Scan()'.", a, b, a)
56+
"Specifically the parameters for 'rows.Scan()' and 'db.QueryContext()'.", a, b, a)
6757
}()
6858
}
6959
}
7060

7161
}
7262

73-
func compareArgs(t *testing.T, aName, bName string, a, b []ast.Expr) bool {
63+
type parsedFunc struct {
64+
RowScanArgs []ast.Expr
65+
QueryArgs []ast.Expr
66+
}
67+
68+
func parseFile(t *testing.T, filename string, trackFunc func(name string) bool) map[string]*parsedFunc {
69+
fset := token.NewFileSet()
70+
f, err := parser.ParseFile(fset, filename, nil, parser.SkipObjectResolution)
71+
require.NoErrorf(t, err, "failed to parse file %q", filename)
72+
73+
parsed := make(map[string]*parsedFunc)
74+
for _, decl := range f.Decls {
75+
if fn, ok := decl.(*ast.FuncDecl); ok {
76+
if trackFunc(fn.Name.Name) {
77+
parsed[fn.Name.String()] = &parsedFunc{
78+
RowScanArgs: pullRowScanArgs(fn),
79+
QueryArgs: pullQueryArgs(fn),
80+
}
81+
}
82+
}
83+
}
84+
85+
return parsed
86+
}
87+
88+
func compareFns(t *testing.T, aName, bName string, a, b *parsedFunc) bool {
7489
if a == nil {
7590
t.Errorf("The function %q is missing", aName)
7691
return false
@@ -79,7 +94,18 @@ func compareArgs(t *testing.T, aName, bName string, a, b []ast.Expr) bool {
7994
t.Errorf("The function %q is missing", bName)
8095
return false
8196
}
82-
return assert.Equal(t, argList(t, a), argList(t, b), "mismatched args for %s and %s", aName, bName)
97+
r := compareArgs(t, "rows.Scan() arguments", aName, bName, a.RowScanArgs, b.RowScanArgs)
98+
if len(a.QueryArgs) > 2 && len(b.QueryArgs) > 2 {
99+
// This is because the actual query param name is different. One uses the
100+
// const, the other uses a variable that is a mutation of the original query.
101+
a.QueryArgs[1] = b.QueryArgs[1]
102+
}
103+
q := compareArgs(t, "db.QueryContext() arguments", aName, bName, a.QueryArgs, b.QueryArgs)
104+
return r && q
105+
}
106+
107+
func compareArgs(t *testing.T, argType string, aName, bName string, a, b []ast.Expr) bool {
108+
return assert.Equal(t, argList(t, a), argList(t, b), "mismatched %s for %s and %s", argType, aName, bName)
83109
}
84110

85111
func argList(t *testing.T, args []ast.Expr) []string {
@@ -91,12 +117,51 @@ func argList(t *testing.T, args []ast.Expr) []string {
91117

92118
var argNames []string
93119
for _, arg := range args {
94-
sel := arg.(*ast.UnaryExpr).X.(*ast.SelectorExpr).Sel
95-
argNames = append(argNames, sel.Name)
120+
argname := "unknown"
121+
// This is "&i.Arg" style stuff
122+
if unary, ok := arg.(*ast.UnaryExpr); ok {
123+
argname = unary.X.(*ast.SelectorExpr).Sel.Name
124+
}
125+
if ident, ok := arg.(*ast.Ident); ok {
126+
argname = ident.Name
127+
}
128+
if sel, ok := arg.(*ast.SelectorExpr); ok {
129+
argname = sel.Sel.Name
130+
}
131+
if call, ok := arg.(*ast.CallExpr); ok {
132+
// Eh, this is pg.Array style stuff. Do a best effort.
133+
argname = fmt.Sprintf("call(%d)", len(call.Args))
134+
if fnCall, ok := call.Fun.(*ast.SelectorExpr); ok {
135+
argname = fmt.Sprintf("%s(%d)", fnCall.Sel.Name, len(call.Args))
136+
}
137+
}
138+
139+
if argname == "unknown" {
140+
t.Errorf("Unknown arg, cannot parse: %T", arg)
141+
}
142+
argNames = append(argNames, argname)
96143
}
97144
return argNames
98145
}
99146

147+
func pullQueryArgs(fn *ast.FuncDecl) []ast.Expr {
148+
for _, exp := range fn.Body.List {
149+
// find "rows, err :="
150+
if assign, ok := exp.(*ast.AssignStmt); ok {
151+
if len(assign.Lhs) == 2 {
152+
if id, ok := assign.Lhs[0].(*ast.Ident); ok && id.Name == "rows" {
153+
// This is rows, err :=
154+
query := assign.Rhs[0].(*ast.CallExpr)
155+
if qSel, ok := query.Fun.(*ast.SelectorExpr); ok && qSel.Sel.Name == "QueryContext" {
156+
return query.Args
157+
}
158+
}
159+
}
160+
}
161+
}
162+
return nil
163+
}
164+
100165
func pullRowScanArgs(fn *ast.FuncDecl) []ast.Expr {
101166
for _, exp := range fn.Body.List {
102167
if forStmt, ok := exp.(*ast.ForStmt); ok {

0 commit comments

Comments
 (0)