Skip to content

Commit 9597fe3

Browse files
kevinvalkewhauser
authored andcommitted
feat: support sqlc.embed
1 parent 245f052 commit 9597fe3

File tree

1 file changed

+93
-9
lines changed

1 file changed

+93
-9
lines changed

internal/gen.go

Lines changed: 93 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ type Field struct {
5353
Name string
5454
Type pyType
5555
Comment string
56+
// EmbedFields contains the embedded fields that require scanning.
57+
EmbedFields []Field
5658
}
5759

5860
type Struct struct {
@@ -105,14 +107,42 @@ func (v QueryValue) RowNode(rowVar string) *pyast.Node {
105107
call := &pyast.Call{
106108
Func: v.Annotation(),
107109
}
108-
for i, f := range v.Struct.Fields {
109-
call.Keywords = append(call.Keywords, &pyast.Keyword{
110-
Arg: f.Name,
111-
Value: subscriptNode(
110+
rowIndex := 0 // We need to keep track of the index in the row variable.
111+
for _, f := range v.Struct.Fields {
112+
113+
var valueNode *pyast.Node
114+
// Check if we are using sqlc.embed, if so we need to create a new object.
115+
if len(f.EmbedFields) > 0 {
116+
// We keep this separate so we can easily add all arguments.
117+
embed_call := &pyast.Call{Func: f.Type.Annotation()}
118+
119+
// Now add all field Initializers for the embedded model that index into the original row.
120+
for i, embedField := range f.EmbedFields {
121+
embed_call.Keywords = append(embed_call.Keywords, &pyast.Keyword{
122+
Arg: embedField.Name,
123+
Value: subscriptNode(
124+
rowVar,
125+
constantInt(rowIndex+i),
126+
),
127+
})
128+
}
129+
130+
valueNode = &pyast.Node{
131+
Node: &pyast.Node_Call{
132+
Call: embed_call,
133+
},
134+
}
135+
136+
rowIndex += len(f.EmbedFields)
137+
} else {
138+
valueNode = subscriptNode(
112139
rowVar,
113-
constantInt(i),
114-
),
115-
})
140+
constantInt(rowIndex),
141+
)
142+
rowIndex++
143+
}
144+
145+
call.Keywords = append(call.Keywords, &pyast.Keyword{Arg: f.Name, Value: valueNode})
116146
}
117147
return &pyast.Node{
118148
Node: &pyast.Node_Call{
@@ -319,6 +349,47 @@ func paramName(p *plugin.Parameter) string {
319349
type pyColumn struct {
320350
id int32
321351
*plugin.Column
352+
embed *pyEmbed
353+
}
354+
355+
type pyEmbed struct {
356+
modelType string
357+
modelName string
358+
fields []Field
359+
}
360+
361+
// Taken from https://github.com/sqlc-dev/sqlc/blob/8c59fbb9938a0bad3d9971fc2c10ea1f83cc1d0b/internal/codegen/golang/result.go#L123-L126
362+
// look through all the structs and attempt to find a matching one to embed
363+
// We need the name of the struct and its field names.
364+
func newGoEmbed(embed *plugin.Identifier, structs []Struct, defaultSchema string) *pyEmbed {
365+
if embed == nil {
366+
return nil
367+
}
368+
369+
for _, s := range structs {
370+
embedSchema := defaultSchema
371+
if embed.Schema != "" {
372+
embedSchema = embed.Schema
373+
}
374+
375+
// compare the other attributes
376+
if embed.Catalog != s.Table.Catalog || embed.Name != s.Table.Name || embedSchema != s.Table.Schema {
377+
continue
378+
}
379+
380+
fields := make([]Field, len(s.Fields))
381+
for i, f := range s.Fields {
382+
fields[i] = f
383+
}
384+
385+
return &pyEmbed{
386+
modelType: s.Name,
387+
modelName: s.Name,
388+
fields: fields,
389+
}
390+
}
391+
392+
return nil
322393
}
323394

324395
func columnsToStruct(req *plugin.GenerateRequest, name string, columns []pyColumn) *Struct {
@@ -342,10 +413,22 @@ func columnsToStruct(req *plugin.GenerateRequest, name string, columns []pyColum
342413
if suffix > 0 {
343414
fieldName = fmt.Sprintf("%s_%d", fieldName, suffix)
344415
}
345-
gs.Fields = append(gs.Fields, Field{
416+
417+
f := Field{
346418
Name: fieldName,
347419
Type: makePyType(req, c.Column),
348-
})
420+
}
421+
422+
if c.embed != nil {
423+
f.Type = pyType{
424+
InnerType: "models." + modelName(c.embed.modelType, req.Settings),
425+
IsArray: false,
426+
IsNull: false,
427+
}
428+
f.EmbedFields = c.embed.fields
429+
}
430+
431+
gs.Fields = append(gs.Fields, f)
349432
seen[colName]++
350433
}
351434
return &gs
@@ -459,6 +542,7 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([
459542
columns = append(columns, pyColumn{
460543
id: int32(i),
461544
Column: c,
545+
embed: newGoEmbed(c.EmbedTable, structs, req.Catalog.DefaultSchema),
462546
})
463547
}
464548
gs = columnsToStruct(req, query.Name+"Row", columns)

0 commit comments

Comments
 (0)