1
1
package gentest
2
2
3
3
import (
4
+ "fmt"
4
5
"go/ast"
5
6
"go/parser"
6
7
"go/token"
@@ -26,51 +27,65 @@ func TestCustomQueriesSyncedRowScan(t *testing.T) {
26
27
"GetWorkspaces" : "GetAuthorizedWorkspaces" ,
27
28
"GetUsers" : "GetAuthorizedUsers" ,
28
29
}
29
- parsedRowScanArgs := make (map [string ][]ast.Expr )
30
30
31
31
// Scan custom
32
32
var custom []string
33
33
for _ , fn := range funcsToTrack {
34
34
custom = append (custom , fn )
35
35
}
36
- fset := token .NewFileSet ()
37
- f , err := parser .ParseFile (fset , "../modelqueries.go" , nil , parser .SkipObjectResolution )
38
- require .NoError (t , err )
39
36
40
- for _ , decl := range f .Decls {
41
- if fn , ok := decl .(* ast.FuncDecl ); ok {
42
- if slices .Contains (custom , fn .Name .Name ) {
43
- parsedRowScanArgs [fn .Name .String ()] = pullRowScanArgs (fn )
44
- }
45
- }
46
- }
47
-
48
- fset = token .NewFileSet ()
49
- f , err = parser .ParseFile (fset , "../queries.sql.go" , nil , parser .SkipObjectResolution )
50
- require .NoError (t , err )
51
-
52
- for _ , decl := range f .Decls {
53
- if fn , ok := decl .(* ast.FuncDecl ); ok {
54
- if _ , ok := funcsToTrack [fn .Name .Name ]; ok {
55
- parsedRowScanArgs [fn .Name .String ()] = pullRowScanArgs (fn )
56
- }
57
- }
37
+ customFns := parseFile (t , "../modelqueries.go" , func (name string ) bool {
38
+ return slices .Contains (custom , name )
39
+ })
40
+ generatedFns := parseFile (t , "../queries.sql.go" , func (name string ) bool {
41
+ _ , ok := funcsToTrack [name ]
42
+ return ok
43
+ })
44
+ merged := customFns
45
+ for k , v := range generatedFns {
46
+ merged [k ] = v
58
47
}
59
48
60
49
for a , b := range funcsToTrack {
61
- if ! compareArgs (t , a , b , parsedRowScanArgs [a ], parsedRowScanArgs [b ]) {
50
+ if ! compareFns (t , a , b , merged [a ], merged [b ]) {
62
51
defer func () {
63
52
// Run this at the end so the suggested fix is the last thing printed.
64
- t .Errorf ("The functions %q and %q need to have identical 'rows.Scan()' in their function bodies. " +
53
+ t .Errorf ("The functions %q and %q need to have identical 'rows.Scan()' " +
54
+ "and 'db.QueryContext()' arguments in their function bodies. " +
65
55
"Make sure to copy the function body from the autogenerated %q body. " +
66
- "Specifically the parameters for 'rows.Scan()'." , a , b , a )
56
+ "Specifically the parameters for 'rows.Scan()' and 'db.QueryContext()' ." , a , b , a )
67
57
}()
68
58
}
69
59
}
70
60
71
61
}
72
62
73
- func compareArgs (t * testing.T , aName , bName string , a , b []ast.Expr ) bool {
63
+ type parsedFunc struct {
64
+ RowScanArgs []ast.Expr
65
+ QueryArgs []ast.Expr
66
+ }
67
+
68
+ func parseFile (t * testing.T , filename string , trackFunc func (name string ) bool ) map [string ]* parsedFunc {
69
+ fset := token .NewFileSet ()
70
+ f , err := parser .ParseFile (fset , filename , nil , parser .SkipObjectResolution )
71
+ require .NoErrorf (t , err , "failed to parse file %q" , filename )
72
+
73
+ parsed := make (map [string ]* parsedFunc )
74
+ for _ , decl := range f .Decls {
75
+ if fn , ok := decl .(* ast.FuncDecl ); ok {
76
+ if trackFunc (fn .Name .Name ) {
77
+ parsed [fn .Name .String ()] = & parsedFunc {
78
+ RowScanArgs : pullRowScanArgs (fn ),
79
+ QueryArgs : pullQueryArgs (fn ),
80
+ }
81
+ }
82
+ }
83
+ }
84
+
85
+ return parsed
86
+ }
87
+
88
+ func compareFns (t * testing.T , aName , bName string , a , b * parsedFunc ) bool {
74
89
if a == nil {
75
90
t .Errorf ("The function %q is missing" , aName )
76
91
return false
@@ -79,7 +94,18 @@ func compareArgs(t *testing.T, aName, bName string, a, b []ast.Expr) bool {
79
94
t .Errorf ("The function %q is missing" , bName )
80
95
return false
81
96
}
82
- return assert .Equal (t , argList (t , a ), argList (t , b ), "mismatched args for %s and %s" , aName , bName )
97
+ r := compareArgs (t , "rows.Scan() arguments" , aName , bName , a .RowScanArgs , b .RowScanArgs )
98
+ if len (a .QueryArgs ) > 2 && len (b .QueryArgs ) > 2 {
99
+ // This is because the actual query param name is different. One uses the
100
+ // const, the other uses a variable that is a mutation of the original query.
101
+ a .QueryArgs [1 ] = b .QueryArgs [1 ]
102
+ }
103
+ q := compareArgs (t , "db.QueryContext() arguments" , aName , bName , a .QueryArgs , b .QueryArgs )
104
+ return r && q
105
+ }
106
+
107
+ func compareArgs (t * testing.T , argType string , aName , bName string , a , b []ast.Expr ) bool {
108
+ return assert .Equal (t , argList (t , a ), argList (t , b ), "mismatched %s for %s and %s" , argType , aName , bName )
83
109
}
84
110
85
111
func argList (t * testing.T , args []ast.Expr ) []string {
@@ -91,12 +117,51 @@ func argList(t *testing.T, args []ast.Expr) []string {
91
117
92
118
var argNames []string
93
119
for _ , arg := range args {
94
- sel := arg .(* ast.UnaryExpr ).X .(* ast.SelectorExpr ).Sel
95
- argNames = append (argNames , sel .Name )
120
+ argname := "unknown"
121
+ // This is "&i.Arg" style stuff
122
+ if unary , ok := arg .(* ast.UnaryExpr ); ok {
123
+ argname = unary .X .(* ast.SelectorExpr ).Sel .Name
124
+ }
125
+ if ident , ok := arg .(* ast.Ident ); ok {
126
+ argname = ident .Name
127
+ }
128
+ if sel , ok := arg .(* ast.SelectorExpr ); ok {
129
+ argname = sel .Sel .Name
130
+ }
131
+ if call , ok := arg .(* ast.CallExpr ); ok {
132
+ // Eh, this is pg.Array style stuff. Do a best effort.
133
+ argname = fmt .Sprintf ("call(%d)" , len (call .Args ))
134
+ if fnCall , ok := call .Fun .(* ast.SelectorExpr ); ok {
135
+ argname = fmt .Sprintf ("%s(%d)" , fnCall .Sel .Name , len (call .Args ))
136
+ }
137
+ }
138
+
139
+ if argname == "unknown" {
140
+ t .Errorf ("Unknown arg, cannot parse: %T" , arg )
141
+ }
142
+ argNames = append (argNames , argname )
96
143
}
97
144
return argNames
98
145
}
99
146
147
+ func pullQueryArgs (fn * ast.FuncDecl ) []ast.Expr {
148
+ for _ , exp := range fn .Body .List {
149
+ // find "rows, err :="
150
+ if assign , ok := exp .(* ast.AssignStmt ); ok {
151
+ if len (assign .Lhs ) == 2 {
152
+ if id , ok := assign .Lhs [0 ].(* ast.Ident ); ok && id .Name == "rows" {
153
+ // This is rows, err :=
154
+ query := assign .Rhs [0 ].(* ast.CallExpr )
155
+ if qSel , ok := query .Fun .(* ast.SelectorExpr ); ok && qSel .Sel .Name == "QueryContext" {
156
+ return query .Args
157
+ }
158
+ }
159
+ }
160
+ }
161
+ }
162
+ return nil
163
+ }
164
+
100
165
func pullRowScanArgs (fn * ast.FuncDecl ) []ast.Expr {
101
166
for _ , exp := range fn .Body .List {
102
167
if forStmt , ok := exp .(* ast.ForStmt ); ok {
0 commit comments