Skip to content

Commit 68a3154

Browse files
committed
Reuse structs (wip)
1 parent 8218707 commit 68a3154

File tree

5 files changed

+309
-38
lines changed

5 files changed

+309
-38
lines changed

internal/codegen/golang/field.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,19 @@ type Field struct {
2020
EmbedFields []Field
2121
}
2222

23+
// Match returns true if the name and the type of the 2 fields are equal.
24+
func (gf Field) Match(other Field) bool {
25+
if gf.Name != other.Name {
26+
return false
27+
}
28+
29+
if gf.Type != other.Type {
30+
return false
31+
}
32+
33+
return true
34+
}
35+
2336
func (gf Field) Tag() string {
2437
return TagsToString(gf.Tags)
2538
}

internal/codegen/golang/field_test.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package golang
2+
3+
import (
4+
"testing"
5+
)
6+
7+
func TestField_Match(t *testing.T) {
8+
t.Parallel()
9+
10+
field := Field{
11+
Name: "Name",
12+
Type: "string",
13+
}
14+
15+
tests := []struct {
16+
name string
17+
field Field
18+
want bool
19+
}{
20+
{
21+
name: "match",
22+
field: Field{
23+
Name: "Name",
24+
Type: "string",
25+
},
26+
want: true,
27+
},
28+
{
29+
name: "name mismatch",
30+
field: Field{
31+
Name: "OtherName",
32+
Type: "string",
33+
},
34+
},
35+
{
36+
name: "type mismatch",
37+
field: Field{
38+
Name: "Name",
39+
Type: "int",
40+
},
41+
},
42+
}
43+
44+
for _, tt := range tests {
45+
tt := tt
46+
47+
t.Run(tt.name, func(t *testing.T) {
48+
t.Parallel()
49+
50+
if got := field.Match(tt.field); got != tt.want {
51+
t.Errorf("Match() = %v, want %v", got, tt.want)
52+
}
53+
})
54+
}
55+
}

internal/codegen/golang/result.go

Lines changed: 45 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,21 @@ func argName(name string) string {
181181
return out
182182
}
183183

184-
func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error) {
184+
func lookups(req *plugin.CodeGenRequest, other Struct, ls ...Structs) (Struct, bool) {
185+
for _, s := range ls {
186+
if exists, found := s.Lookup(req, other); found {
187+
return exists, true
188+
}
189+
}
190+
191+
return other, false
192+
}
193+
194+
func buildQueries(req *plugin.CodeGenRequest, tableStructs Structs) ([]Query, error) {
195+
var queryStructs Structs
196+
185197
qs := make([]Query, 0, len(req.Queries))
198+
186199
for _, query := range req.Queries {
187200
if query.Name == "" {
188201
continue
@@ -233,8 +246,15 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error)
233246
if err != nil {
234247
return nil, err
235248
}
249+
250+
found := false
251+
252+
if req.Settings.Go.ReuseStructs {
253+
*s, found = lookups(req, *s, tableStructs, queryStructs)
254+
}
255+
236256
gq.Arg = QueryValue{
237-
Emit: true,
257+
Emit: !found,
238258
Name: "arg",
239259
Struct: s,
240260
SQLDriver: sqlpkg,
@@ -259,47 +279,34 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error)
259279
SQLDriver: sqlpkg,
260280
}
261281
} else if putOutColumns(query) {
262-
var gs *Struct
263-
var emit bool
282+
var columns []goColumn
283+
for i, c := range query.Columns {
284+
columns = append(columns, goColumn{
285+
id: i,
286+
Column: c,
287+
embed: newGoEmbed(c.EmbedTable, tableStructs, req.Catalog.DefaultSchema),
288+
})
289+
}
264290

265-
for _, s := range structs {
266-
if len(s.Fields) != len(query.Columns) {
267-
continue
268-
}
269-
same := true
270-
for i, f := range s.Fields {
271-
c := query.Columns[i]
272-
sameName := f.Name == StructName(columnName(c, i), req.Settings)
273-
sameType := f.Type == goType(req, c)
274-
sameTable := sdk.SameTableName(c.Table, s.Table, req.Catalog.DefaultSchema)
275-
if !sameName || !sameType || !sameTable {
276-
same = false
277-
}
278-
}
279-
if same {
280-
gs = &s
281-
break
282-
}
291+
gs, err := columnsToStruct(req, gq.MethodName+"Row", columns, true)
292+
if err != nil {
293+
return nil, err
283294
}
284295

285-
if gs == nil {
286-
var columns []goColumn
287-
for i, c := range query.Columns {
288-
columns = append(columns, goColumn{
289-
id: i,
290-
Column: c,
291-
embed: newGoEmbed(c.EmbedTable, structs, req.Catalog.DefaultSchema),
292-
})
293-
}
294-
var err error
295-
gs, err = columnsToStruct(req, gq.MethodName+"Row", columns, true)
296-
if err != nil {
297-
return nil, err
298-
}
299-
emit = true
296+
found := false
297+
298+
*gs, found = tableStructs.Lookup(req, *gs)
299+
300+
if !found && req.Settings.Go.ReuseStructs {
301+
*gs, found = queryStructs.Lookup(req, *gs)
300302
}
303+
304+
if !found {
305+
queryStructs = append(queryStructs, *gs)
306+
}
307+
301308
gq.Ret = QueryValue{
302-
Emit: emit,
309+
Emit: !found,
303310
Name: "i",
304311
Struct: gs,
305312
SQLDriver: sqlpkg,

internal/codegen/golang/struct.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"unicode"
66
"unicode/utf8"
77

8+
"github.com/sqlc-dev/sqlc/internal/codegen/sdk"
89
"github.com/sqlc-dev/sqlc/internal/plugin"
910
)
1011

@@ -15,6 +16,42 @@ type Struct struct {
1516
Comment string
1617
}
1718

19+
func (s Struct) Match(req *plugin.CodeGenRequest, other Struct) bool {
20+
if len(s.Fields) != len(other.Fields) {
21+
return false
22+
}
23+
24+
for i, f := range s.Fields {
25+
of := other.Fields[i]
26+
27+
if !f.Match(of) {
28+
return false
29+
}
30+
31+
if s.Table != nil && !sdk.SameTableName(of.Column.Table, s.Table, req.Catalog.DefaultSchema) {
32+
return false
33+
}
34+
}
35+
36+
return true
37+
}
38+
39+
type Structs []Struct
40+
41+
// Lookup search for a matching Struct in slice.
42+
//
43+
// - if found, returns the matching Struct and true
44+
// - else returns the given Struct and false
45+
func (s Structs) Lookup(req *plugin.CodeGenRequest, other Struct) (Struct, bool) {
46+
for _, exists := range s {
47+
if exists.Match(req, other) {
48+
return exists, true
49+
}
50+
}
51+
52+
return other, false
53+
}
54+
1855
func StructName(name string, settings *plugin.Settings) string {
1956
if rename := settings.Rename[name]; rename != "" {
2057
return rename
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
package golang
2+
3+
import (
4+
"reflect"
5+
"testing"
6+
7+
"github.com/sqlc-dev/sqlc/internal/plugin"
8+
)
9+
10+
func TestStruct_Match(t *testing.T) {
11+
t.Parallel()
12+
13+
req := &plugin.CodeGenRequest{Catalog: &plugin.Catalog{}}
14+
tableId := &plugin.Identifier{Name: "Table"}
15+
tableStruct := Struct{
16+
Table: tableId,
17+
Name: "Name",
18+
Fields: []Field{
19+
{Name: "Field", Type: "string"},
20+
},
21+
}
22+
23+
tests := []struct {
24+
name string
25+
str Struct
26+
other Struct
27+
want bool
28+
}{
29+
{
30+
name: "match",
31+
str: tableStruct,
32+
other: Struct{
33+
Table: tableId,
34+
Name: "Name",
35+
Fields: []Field{
36+
{Name: "Field", Type: "string", Column: &plugin.Column{Table: tableId}},
37+
},
38+
},
39+
want: true,
40+
},
41+
{
42+
name: "table mismatch",
43+
str: tableStruct,
44+
other: Struct{
45+
Table: tableId,
46+
Name: "Name",
47+
Fields: []Field{
48+
{Name: "Field", Type: "string", Column: &plugin.Column{Table: &plugin.Identifier{Name: "OtherTable"}}},
49+
},
50+
},
51+
},
52+
{
53+
name: "other table nil",
54+
str: tableStruct,
55+
other: Struct{
56+
Table: tableId,
57+
Name: "Name",
58+
Fields: []Field{
59+
{Name: "Field", Type: "string", Column: &plugin.Column{}},
60+
},
61+
},
62+
},
63+
{
64+
name: "field count mismatch",
65+
str: tableStruct,
66+
other: Struct{
67+
Table: tableId,
68+
Name: "Name",
69+
Fields: []Field{
70+
{Name: "Field1", Type: "string"},
71+
{Name: "Field2", Type: "string"},
72+
},
73+
},
74+
},
75+
{
76+
name: "field mismatch",
77+
str: tableStruct,
78+
other: Struct{
79+
Table: tableId,
80+
Name: "Name",
81+
Fields: []Field{
82+
{Name: "OtherField", Type: "string", Column: &plugin.Column{Table: tableId}},
83+
},
84+
},
85+
},
86+
}
87+
88+
for _, tt := range tests {
89+
tt := tt
90+
91+
t.Run(tt.name, func(t *testing.T) {
92+
t.Parallel()
93+
94+
if got := tt.str.Match(req, tt.other); got != tt.want {
95+
t.Errorf("Match() = %v, want %v", got, tt.want)
96+
}
97+
})
98+
}
99+
}
100+
101+
func TestStructs_Lookup(t *testing.T) {
102+
t.Parallel()
103+
104+
req := &plugin.CodeGenRequest{Catalog: &plugin.Catalog{}}
105+
str := Struct{
106+
Fields: []Field{
107+
{Name: "Field", Type: "string"},
108+
},
109+
}
110+
other := Struct{
111+
Fields: []Field{
112+
{Name: "OtherField", Type: "string"},
113+
},
114+
Comment: "OtherStruct",
115+
}
116+
structs := Structs{str}
117+
118+
tests := []struct {
119+
name string
120+
other Struct
121+
want Struct
122+
wantFound bool
123+
}{
124+
{
125+
name: "found",
126+
other: Struct{
127+
Fields: []Field{
128+
{Name: "Field", Type: "string"},
129+
},
130+
Comment: "Matching Struct",
131+
},
132+
want: str,
133+
wantFound: true,
134+
},
135+
{
136+
name: "not found",
137+
other: other,
138+
want: other,
139+
},
140+
}
141+
142+
for _, tt := range tests {
143+
tt := tt
144+
145+
t.Run(tt.name, func(t *testing.T) {
146+
t.Parallel()
147+
148+
got, found := structs.Lookup(req, tt.other)
149+
150+
if !reflect.DeepEqual(got, tt.want) {
151+
t.Errorf("Lookup() got = %v, want %v", got, tt.want)
152+
}
153+
154+
if found != tt.wantFound {
155+
t.Errorf("Lookup() found = %v, want %v", found, tt.wantFound)
156+
}
157+
})
158+
}
159+
}

0 commit comments

Comments
 (0)