Skip to content

Commit af45e64

Browse files
authored
chore(coderd/database/gen): improve generated fake stub (#8088)
* chore(coderd/database/gen): generate arg validation where applicable * fix(coderd/database/gen): support pointers and slices as return types
1 parent f444100 commit af45e64

File tree

1 file changed

+100
-14
lines changed

1 file changed

+100
-14
lines changed

coderd/database/gen/fake/main.go

Lines changed: 100 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
package main
22

33
import (
4+
"fmt"
45
"go/format"
56
"go/token"
67
"log"
78
"os"
9+
"path"
810

911
"github.com/dave/dst"
1012
"github.com/dave/dst/decorator"
@@ -65,6 +67,76 @@ func run() error {
6567
}
6668

6769
for _, fn := range funcs {
70+
var bodyStmts []dst.Stmt
71+
if len(fn.Func.Params.List) == 2 && fn.Func.Params.List[1].Names[0].Name == "arg" {
72+
/*
73+
err := validateDatabaseType(arg)
74+
if err != nil {
75+
return database.User{}, err
76+
}
77+
*/
78+
bodyStmts = append(bodyStmts, &dst.AssignStmt{
79+
Lhs: []dst.Expr{dst.NewIdent("err")},
80+
Tok: token.DEFINE,
81+
Rhs: []dst.Expr{
82+
&dst.CallExpr{
83+
Fun: &dst.Ident{
84+
Name: "validateDatabaseType",
85+
},
86+
Args: []dst.Expr{dst.NewIdent("arg")},
87+
},
88+
},
89+
})
90+
returnStmt := &dst.ReturnStmt{
91+
Results: []dst.Expr{}, // Filled below.
92+
}
93+
bodyStmts = append(bodyStmts, &dst.IfStmt{
94+
Cond: &dst.BinaryExpr{
95+
X: dst.NewIdent("err"),
96+
Op: token.NEQ,
97+
Y: dst.NewIdent("nil"),
98+
},
99+
Body: &dst.BlockStmt{
100+
List: []dst.Stmt{
101+
returnStmt,
102+
},
103+
},
104+
Decs: dst.IfStmtDecorations{
105+
NodeDecs: dst.NodeDecs{
106+
After: dst.EmptyLine,
107+
},
108+
},
109+
})
110+
for _, r := range fn.Func.Results.List {
111+
switch typ := r.Type.(type) {
112+
case *dst.StarExpr, *dst.ArrayType:
113+
returnStmt.Results = append(returnStmt.Results, dst.NewIdent("nil"))
114+
case *dst.Ident:
115+
if typ.Path != "" {
116+
returnStmt.Results = append(returnStmt.Results, dst.NewIdent(fmt.Sprintf("%s.%s{}", path.Base(typ.Path), typ.Name)))
117+
} else {
118+
switch typ.Name {
119+
case "uint8", "uint16", "uint32", "uint64", "uint", "uintptr",
120+
"int8", "int16", "int32", "int64", "int",
121+
"byte", "rune",
122+
"float32", "float64",
123+
"complex64", "complex128":
124+
returnStmt.Results = append(returnStmt.Results, dst.NewIdent("0"))
125+
case "string":
126+
returnStmt.Results = append(returnStmt.Results, dst.NewIdent("\"\""))
127+
case "bool":
128+
returnStmt.Results = append(returnStmt.Results, dst.NewIdent("false"))
129+
case "error":
130+
returnStmt.Results = append(returnStmt.Results, dst.NewIdent("err"))
131+
default:
132+
panic(fmt.Sprintf("unknown ident: %#v", r.Type))
133+
}
134+
}
135+
default:
136+
panic(fmt.Sprintf("unknown return type: %T", r.Type))
137+
}
138+
}
139+
}
68140
decl, ok := declByName[fn.Name]
69141
if !ok {
70142
// Not implemented!
@@ -90,21 +162,19 @@ func run() error {
90162
},
91163
},
92164
Body: &dst.BlockStmt{
93-
List: []dst.Stmt{
94-
&dst.ExprStmt{
95-
X: &dst.CallExpr{
96-
Fun: &dst.Ident{
97-
Name: "panic",
98-
},
99-
Args: []dst.Expr{
100-
&dst.BasicLit{
101-
Kind: token.STRING,
102-
Value: "\"Not implemented\"",
103-
},
165+
List: append(bodyStmts, &dst.ExprStmt{
166+
X: &dst.CallExpr{
167+
Fun: &dst.Ident{
168+
Name: "panic",
169+
},
170+
Args: []dst.Expr{
171+
&dst.BasicLit{
172+
Kind: token.STRING,
173+
Value: "\"Not implemented\"",
104174
},
105175
},
106176
},
107-
},
177+
}),
108178
},
109179
}
110180
}
@@ -178,9 +248,25 @@ func readStoreInterface() ([]storeMethod, error) {
178248
if t == nil {
179249
continue
180250
}
251+
var (
252+
ident *dst.Ident
253+
ok bool
254+
)
181255
for _, f := range t.List {
182-
ident, ok := f.Type.(*dst.Ident)
183-
if !ok {
256+
switch typ := f.Type.(type) {
257+
case *dst.StarExpr:
258+
ident, ok = typ.X.(*dst.Ident)
259+
if !ok {
260+
continue
261+
}
262+
case *dst.ArrayType:
263+
ident, ok = typ.Elt.(*dst.Ident)
264+
if !ok {
265+
continue
266+
}
267+
case *dst.Ident:
268+
ident = typ
269+
default:
184270
continue
185271
}
186272
if !ident.IsExported() {

0 commit comments

Comments
 (0)