1
1
package main
2
2
3
3
import (
4
+ "go/format"
5
+ "go/token"
4
6
"log"
5
7
"os"
6
8
7
9
"github.com/dave/dst"
8
10
"github.com/dave/dst/decorator"
11
+ "github.com/dave/dst/decorator/resolver/goast"
12
+ "github.com/dave/dst/decorator/resolver/guess"
9
13
"golang.org/x/xerrors"
10
14
)
11
15
@@ -27,11 +31,13 @@ func run() error {
27
31
}
28
32
declByName := map [string ]* dst.FuncDecl {}
29
33
30
- dbfake , err := os .ReadFile ("../.. /dbfake/dbfake.go" )
34
+ dbfake , err := os .ReadFile ("./dbfake/dbfake.go" )
31
35
if err != nil {
32
36
return xerrors .Errorf ("read dbfake: %w" , err )
33
37
}
34
- f , err := decorator .Parse (dbfake )
38
+
39
+ // Required to preserve imports!
40
+ f , err := decorator .NewDecoratorWithImports (token .NewFileSet (), "dbfake" , goast .New ()).Parse (dbfake )
35
41
if err != nil {
36
42
return xerrors .Errorf ("parse dbfake: %w" , err )
37
43
}
@@ -64,7 +70,13 @@ func run() error {
64
70
// Not implemented!
65
71
decl = & dst.FuncDecl {
66
72
Name : dst .NewIdent (fn .Name ),
67
- Type : fn .Func ,
73
+ Type : & dst.FuncType {
74
+ Func : true ,
75
+ TypeParams : fn .Func .TypeParams ,
76
+ Params : fn .Func .Params ,
77
+ Results : fn .Func .Results ,
78
+ Decs : fn .Func .Decs ,
79
+ },
68
80
Recv : & dst.FieldList {
69
81
List : []* dst.Field {{
70
82
Names : []* dst.Ident {dst .NewIdent ("q" )},
@@ -81,7 +93,7 @@ func run() error {
81
93
Decs : dst.BlockStmtDecorations {
82
94
Lbrace : dst.Decorations {
83
95
"\n " ,
84
- "// Implement me! " ,
96
+ "// Not implemented " ,
85
97
},
86
98
},
87
99
},
@@ -90,18 +102,20 @@ func run() error {
90
102
f .Decls = append (f .Decls , decl )
91
103
}
92
104
93
- file , err := os .OpenFile ("../.. /dbfake/dbfake.go" , os .O_RDWR | os .O_CREATE | os .O_TRUNC , 0755 )
105
+ file , err := os .OpenFile ("./dbfake/dbfake.go" , os .O_RDWR | os .O_CREATE | os .O_TRUNC , 0755 )
94
106
if err != nil {
95
107
return xerrors .Errorf ("open dbfake: %w" , err )
96
108
}
97
109
defer file .Close ()
98
110
99
- err = decorator .Fprint (file , f )
111
+ // Required to preserve imports!
112
+ restorer := decorator .NewRestorerWithImports ("dbfake" , guess .New ())
113
+ restored , err := restorer .RestoreFile (f )
100
114
if err != nil {
101
- return xerrors .Errorf ("write dbfake: %w" , err )
115
+ return xerrors .Errorf ("restore dbfake: %w" , err )
102
116
}
103
-
104
- return nil
117
+ err = format . Node ( file , restorer . Fset , restored )
118
+ return err
105
119
}
106
120
107
121
type storeMethod struct {
@@ -110,7 +124,7 @@ type storeMethod struct {
110
124
}
111
125
112
126
func readStoreInterface () ([]storeMethod , error ) {
113
- querier , err := os .ReadFile ("../.. /querier.go" )
127
+ querier , err := os .ReadFile ("./querier.go" )
114
128
if err != nil {
115
129
return nil , xerrors .Errorf ("read querier: %w" , err )
116
130
}
@@ -150,6 +164,23 @@ func readStoreInterface() ([]storeMethod, error) {
150
164
if ! ok {
151
165
continue
152
166
}
167
+
168
+ for _ , t := range []* dst.FieldList {funcType .Params , funcType .Results } {
169
+ if t == nil {
170
+ continue
171
+ }
172
+ for _ , f := range t .List {
173
+ ident , ok := f .Type .(* dst.Ident )
174
+ if ! ok {
175
+ continue
176
+ }
177
+ if ! ident .IsExported () {
178
+ continue
179
+ }
180
+ ident .Path = "github.com/coder/coder/coderd/database"
181
+ }
182
+ }
183
+
153
184
funcs = append (funcs , storeMethod {
154
185
Name : method .Names [0 ].Name ,
155
186
Func : funcType ,
0 commit comments