1
1
package main
2
2
3
3
import (
4
+ "fmt"
4
5
"go/format"
5
6
"go/token"
6
7
"log"
7
8
"os"
9
+ "path"
8
10
9
11
"github.com/dave/dst"
10
12
"github.com/dave/dst/decorator"
@@ -65,6 +67,76 @@ func run() error {
65
67
}
66
68
67
69
for _ , fn := range funcs {
70
+ var bodyStmts []dst.Stmt
71
+ if len (fn .Func .Params .List ) == 2 && fn .Func .Params .List [1 ].Names [0 ].Name == "arg" {
72
+ /*
73
+ err := validateDatabaseType(arg)
74
+ if err != nil {
75
+ return database.User{}, err
76
+ }
77
+ */
78
+ bodyStmts = append (bodyStmts , & dst.AssignStmt {
79
+ Lhs : []dst.Expr {dst .NewIdent ("err" )},
80
+ Tok : token .DEFINE ,
81
+ Rhs : []dst.Expr {
82
+ & dst.CallExpr {
83
+ Fun : & dst.Ident {
84
+ Name : "validateDatabaseType" ,
85
+ },
86
+ Args : []dst.Expr {dst .NewIdent ("arg" )},
87
+ },
88
+ },
89
+ })
90
+ returnStmt := & dst.ReturnStmt {
91
+ Results : []dst.Expr {}, // Filled below.
92
+ }
93
+ bodyStmts = append (bodyStmts , & dst.IfStmt {
94
+ Cond : & dst.BinaryExpr {
95
+ X : dst .NewIdent ("err" ),
96
+ Op : token .NEQ ,
97
+ Y : dst .NewIdent ("nil" ),
98
+ },
99
+ Body : & dst.BlockStmt {
100
+ List : []dst.Stmt {
101
+ returnStmt ,
102
+ },
103
+ },
104
+ Decs : dst.IfStmtDecorations {
105
+ NodeDecs : dst.NodeDecs {
106
+ After : dst .EmptyLine ,
107
+ },
108
+ },
109
+ })
110
+ for _ , r := range fn .Func .Results .List {
111
+ switch typ := r .Type .(type ) {
112
+ case * dst.StarExpr , * dst.ArrayType :
113
+ returnStmt .Results = append (returnStmt .Results , dst .NewIdent ("nil" ))
114
+ case * dst.Ident :
115
+ if typ .Path != "" {
116
+ returnStmt .Results = append (returnStmt .Results , dst .NewIdent (fmt .Sprintf ("%s.%s{}" , path .Base (typ .Path ), typ .Name )))
117
+ } else {
118
+ switch typ .Name {
119
+ case "uint8" , "uint16" , "uint32" , "uint64" , "uint" , "uintptr" ,
120
+ "int8" , "int16" , "int32" , "int64" , "int" ,
121
+ "byte" , "rune" ,
122
+ "float32" , "float64" ,
123
+ "complex64" , "complex128" :
124
+ returnStmt .Results = append (returnStmt .Results , dst .NewIdent ("0" ))
125
+ case "string" :
126
+ returnStmt .Results = append (returnStmt .Results , dst .NewIdent ("\" \" " ))
127
+ case "bool" :
128
+ returnStmt .Results = append (returnStmt .Results , dst .NewIdent ("false" ))
129
+ case "error" :
130
+ returnStmt .Results = append (returnStmt .Results , dst .NewIdent ("err" ))
131
+ default :
132
+ panic (fmt .Sprintf ("unknown ident: %#v" , r .Type ))
133
+ }
134
+ }
135
+ default :
136
+ panic (fmt .Sprintf ("unknown return type: %T" , r .Type ))
137
+ }
138
+ }
139
+ }
68
140
decl , ok := declByName [fn .Name ]
69
141
if ! ok {
70
142
// Not implemented!
@@ -90,21 +162,19 @@ func run() error {
90
162
},
91
163
},
92
164
Body : & dst.BlockStmt {
93
- List : []dst.Stmt {
94
- & dst.ExprStmt {
95
- X : & dst.CallExpr {
96
- Fun : & dst.Ident {
97
- Name : "panic" ,
98
- },
99
- Args : []dst.Expr {
100
- & dst.BasicLit {
101
- Kind : token .STRING ,
102
- Value : "\" Not implemented\" " ,
103
- },
165
+ List : append (bodyStmts , & dst.ExprStmt {
166
+ X : & dst.CallExpr {
167
+ Fun : & dst.Ident {
168
+ Name : "panic" ,
169
+ },
170
+ Args : []dst.Expr {
171
+ & dst.BasicLit {
172
+ Kind : token .STRING ,
173
+ Value : "\" Not implemented\" " ,
104
174
},
105
175
},
106
176
},
107
- },
177
+ }) ,
108
178
},
109
179
}
110
180
}
@@ -178,9 +248,25 @@ func readStoreInterface() ([]storeMethod, error) {
178
248
if t == nil {
179
249
continue
180
250
}
251
+ var (
252
+ ident * dst.Ident
253
+ ok bool
254
+ )
181
255
for _ , f := range t .List {
182
- ident , ok := f .Type .(* dst.Ident )
183
- if ! ok {
256
+ switch typ := f .Type .(type ) {
257
+ case * dst.StarExpr :
258
+ ident , ok = typ .X .(* dst.Ident )
259
+ if ! ok {
260
+ continue
261
+ }
262
+ case * dst.ArrayType :
263
+ ident , ok = typ .Elt .(* dst.Ident )
264
+ if ! ok {
265
+ continue
266
+ }
267
+ case * dst.Ident :
268
+ ident = typ
269
+ default :
184
270
continue
185
271
}
186
272
if ! ident .IsExported () {
0 commit comments