@@ -53,6 +53,8 @@ type Field struct {
53
53
Name string
54
54
Type pyType
55
55
Comment string
56
+ // EmbedFields contains the embedded fields that require scanning.
57
+ EmbedFields []Field
56
58
}
57
59
58
60
type Struct struct {
@@ -105,14 +107,42 @@ func (v QueryValue) RowNode(rowVar string) *pyast.Node {
105
107
call := & pyast.Call {
106
108
Func : v .Annotation (),
107
109
}
108
- for i , f := range v .Struct .Fields {
109
- call .Keywords = append (call .Keywords , & pyast.Keyword {
110
- Arg : f .Name ,
111
- Value : subscriptNode (
110
+ rowIndex := 0 // We need to keep track of the index in the row variable.
111
+ for _ , f := range v .Struct .Fields {
112
+
113
+ var valueNode * pyast.Node
114
+ // Check if we are using sqlc.embed, if so we need to create a new object.
115
+ if len (f .EmbedFields ) > 0 {
116
+ // We keep this separate so we can easily add all arguments.
117
+ embed_call := & pyast.Call {Func : f .Type .Annotation ()}
118
+
119
+ // Now add all field Initializers for the embedded model that index into the original row.
120
+ for i , embedField := range f .EmbedFields {
121
+ embed_call .Keywords = append (embed_call .Keywords , & pyast.Keyword {
122
+ Arg : embedField .Name ,
123
+ Value : subscriptNode (
124
+ rowVar ,
125
+ constantInt (rowIndex + i ),
126
+ ),
127
+ })
128
+ }
129
+
130
+ valueNode = & pyast.Node {
131
+ Node : & pyast.Node_Call {
132
+ Call : embed_call ,
133
+ },
134
+ }
135
+
136
+ rowIndex += len (f .EmbedFields )
137
+ } else {
138
+ valueNode = subscriptNode (
112
139
rowVar ,
113
- constantInt (i ),
114
- ),
115
- })
140
+ constantInt (rowIndex ),
141
+ )
142
+ rowIndex ++
143
+ }
144
+
145
+ call .Keywords = append (call .Keywords , & pyast.Keyword {Arg : f .Name , Value : valueNode })
116
146
}
117
147
return & pyast.Node {
118
148
Node : & pyast.Node_Call {
@@ -319,6 +349,47 @@ func paramName(p *plugin.Parameter) string {
319
349
type pyColumn struct {
320
350
id int32
321
351
* plugin.Column
352
+ embed * pyEmbed
353
+ }
354
+
355
+ type pyEmbed struct {
356
+ modelType string
357
+ modelName string
358
+ fields []Field
359
+ }
360
+
361
+ // Taken from https://github.com/sqlc-dev/sqlc/blob/8c59fbb9938a0bad3d9971fc2c10ea1f83cc1d0b/internal/codegen/golang/result.go#L123-L126
362
+ // look through all the structs and attempt to find a matching one to embed
363
+ // We need the name of the struct and its field names.
364
+ func newGoEmbed (embed * plugin.Identifier , structs []Struct , defaultSchema string ) * pyEmbed {
365
+ if embed == nil {
366
+ return nil
367
+ }
368
+
369
+ for _ , s := range structs {
370
+ embedSchema := defaultSchema
371
+ if embed .Schema != "" {
372
+ embedSchema = embed .Schema
373
+ }
374
+
375
+ // compare the other attributes
376
+ if embed .Catalog != s .Table .Catalog || embed .Name != s .Table .Name || embedSchema != s .Table .Schema {
377
+ continue
378
+ }
379
+
380
+ fields := make ([]Field , len (s .Fields ))
381
+ for i , f := range s .Fields {
382
+ fields [i ] = f
383
+ }
384
+
385
+ return & pyEmbed {
386
+ modelType : s .Name ,
387
+ modelName : s .Name ,
388
+ fields : fields ,
389
+ }
390
+ }
391
+
392
+ return nil
322
393
}
323
394
324
395
func columnsToStruct (req * plugin.GenerateRequest , name string , columns []pyColumn ) * Struct {
@@ -342,10 +413,22 @@ func columnsToStruct(req *plugin.GenerateRequest, name string, columns []pyColum
342
413
if suffix > 0 {
343
414
fieldName = fmt .Sprintf ("%s_%d" , fieldName , suffix )
344
415
}
345
- gs .Fields = append (gs .Fields , Field {
416
+
417
+ f := Field {
346
418
Name : fieldName ,
347
419
Type : makePyType (req , c .Column ),
348
- })
420
+ }
421
+
422
+ if c .embed != nil {
423
+ f .Type = pyType {
424
+ InnerType : "models." + modelName (c .embed .modelType , req .Settings ),
425
+ IsArray : false ,
426
+ IsNull : false ,
427
+ }
428
+ f .EmbedFields = c .embed .fields
429
+ }
430
+
431
+ gs .Fields = append (gs .Fields , f )
349
432
seen [colName ]++
350
433
}
351
434
return & gs
@@ -459,6 +542,7 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([
459
542
columns = append (columns , pyColumn {
460
543
id : int32 (i ),
461
544
Column : c ,
545
+ embed : newGoEmbed (c .EmbedTable , structs , req .Catalog .DefaultSchema ),
462
546
})
463
547
}
464
548
gs = columnsToStruct (req , query .Name + "Row" , columns )
0 commit comments