Skip to content
This repository was archived by the owner on Jul 10, 2024. It is now read-only.

Commit e2269bc

Browse files
Merge pull request #2 from nickjackson/feature/sqlc-embed
Feature/sqlc embed
2 parents 9e13aa7 + 14923ff commit e2269bc

File tree

32 files changed

+10999
-4929
lines changed

32 files changed

+10999
-4929
lines changed

internal/cmd/shim.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ func pluginGoCode(s config.SQLGo) *plugin.GoCode {
8686
EmitResultStructPointers: s.EmitResultStructPointers,
8787
EmitParamsStructPointers: s.EmitParamsStructPointers,
8888
EmitMethodsWithDbArgument: s.EmitMethodsWithDBArgument,
89-
EmitPointersForNullTypes: s.EmitPointersForNullTypes,
89+
EmitPointersForNullTypes: s.EmitPointersForNullTypes,
9090
EmitEnumValidMethod: s.EmitEnumValidMethod,
9191
EmitAllEnumValues: s.EmitAllEnumValues,
9292
JsonTagsCaseStyle: s.JSONTagsCaseStyle,
@@ -263,6 +263,14 @@ func pluginQueryColumn(c *compiler.Column) *plugin.Column {
263263
}
264264
}
265265

266+
if c.EmbedTable != nil {
267+
out.EmbedTable = &plugin.Identifier{
268+
Catalog: c.EmbedTable.Catalog,
269+
Schema: c.EmbedTable.Schema,
270+
Name: c.EmbedTable.Name,
271+
}
272+
}
273+
266274
return out
267275
}
268276

internal/codegen/golang/field.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ type Field struct {
1414
Type string
1515
Tags map[string]string
1616
Comment string
17+
18+
// EmbedFields contains the embedded fields that reuqire scanning.
19+
EmbedFields []string
1720
}
1821

1922
func (gf Field) Tag() string {

internal/codegen/golang/query.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,15 @@ func (v QueryValue) Scan() string {
135135
}
136136
} else {
137137
for _, f := range v.Struct.Fields {
138+
139+
// append any embedded fields
140+
if len(f.EmbedFields) > 0 {
141+
for _, embed := range f.EmbedFields {
142+
out = append(out, "&"+v.Name+"."+f.Name+"."+embed)
143+
}
144+
continue
145+
}
146+
138147
if strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && !v.SQLDriver.IsPGX() {
139148
out = append(out, "pq.Array(&"+v.Name+"."+f.Name+")")
140149
} else {

internal/codegen/golang/result.go

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,46 @@ func buildStructs(req *plugin.CodeGenRequest) []Struct {
103103
type goColumn struct {
104104
id int
105105
*plugin.Column
106+
embed *goEmbed
107+
}
108+
109+
type goEmbed struct {
110+
modelType string
111+
modelName string
112+
fields []string
113+
}
114+
115+
// look through all the structs and attempt to find a matching one to embed
116+
// We need the name of the struct and its field names.
117+
func newGoEmbed(embed *plugin.Identifier, structs []Struct) *goEmbed {
118+
if embed == nil {
119+
return nil
120+
}
121+
122+
for _, s := range structs {
123+
embedSchema := "public"
124+
if embed.Schema != "" {
125+
embedSchema = embed.Schema
126+
}
127+
128+
// compare the other attributes
129+
if embed.Catalog != s.Table.Catalog || embed.Name != s.Table.Name || embedSchema != s.Table.Schema {
130+
continue
131+
}
132+
133+
fields := make([]string, len(s.Fields))
134+
for i, f := range s.Fields {
135+
fields[i] = f.Name
136+
}
137+
138+
return &goEmbed{
139+
modelType: s.Name,
140+
modelName: s.Name,
141+
fields: fields,
142+
}
143+
}
144+
145+
return nil
106146
}
107147

108148
func columnName(c *plugin.Column, pos int) string {
@@ -190,7 +230,7 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error)
190230
}
191231
}
192232

193-
if len(query.Columns) == 1 {
233+
if len(query.Columns) == 1 && query.Columns[0].EmbedTable == nil {
194234
c := query.Columns[0]
195235
name := columnName(c, 0)
196236
if c.IsFuncCall {
@@ -231,6 +271,7 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error)
231271
columns = append(columns, goColumn{
232272
id: i,
233273
Column: c,
274+
embed: newGoEmbed(c.EmbedTable, structs),
234275
})
235276
}
236277
var err error
@@ -284,6 +325,13 @@ func columnsToStruct(req *plugin.CodeGenRequest, name string, columns []goColumn
284325
for i, c := range columns {
285326
colName := columnName(c.Column, i)
286327
tagName := colName
328+
329+
// overide col/tag with expected model name
330+
if c.embed != nil {
331+
colName = c.embed.modelName
332+
tagName = SetCaseStyle(colName, "snake")
333+
}
334+
287335
fieldName := StructName(colName, req.Settings)
288336
baseFieldName := fieldName
289337
// Track suffixes by the ID of the column, so that columns referring to the same numbered parameter can be
@@ -306,12 +354,19 @@ func columnsToStruct(req *plugin.CodeGenRequest, name string, columns []goColumn
306354
if req.Settings.Go.EmitJsonTags {
307355
tags["json"] = JSONTagName(tagName, req.Settings)
308356
}
309-
gs.Fields = append(gs.Fields, Field{
357+
f := Field{
310358
Name: fieldName,
311359
DBName: colName,
312-
Type: goType(req, c.Column),
313360
Tags: tags,
314-
})
361+
}
362+
if c.embed == nil {
363+
f.Type = goType(req, c.Column)
364+
} else {
365+
f.Type = c.embed.modelType
366+
f.EmbedFields = c.embed.fields
367+
}
368+
369+
gs.Fields = append(gs.Fields, f)
315370
if _, found := seen[baseFieldName]; !found {
316371
seen[baseFieldName] = []int{i}
317372
} else {

internal/compiler/expand.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,16 @@ func (c *Compiler) expandStmt(qc *QueryCatalog, raw *ast.RawStmt, node ast.Node)
132132
for _, p := range parts {
133133
old = append(old, c.quoteIdent(p))
134134
}
135+
oldString := strings.Join(old, ".")
136+
137+
// use the sqlc.embed string instead
138+
if embed, ok := qc.embeds.Find(ref); ok {
139+
oldString = embed.Orig()
140+
}
141+
135142
edits = append(edits, source.Edit{
136143
Location: res.Location - raw.StmtLocation,
137-
Old: strings.Join(old, "."),
144+
Old: oldString,
138145
New: strings.Join(cols, ", "),
139146
})
140147
}

internal/compiler/output_columns.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import (
1414

1515
// OutputColumns determines which columns a statement will output
1616
func (c *Compiler) OutputColumns(stmt ast.Node) ([]*catalog.Column, error) {
17-
qc, err := buildQueryCatalog(c.catalog, stmt)
17+
qc, err := buildQueryCatalog(c.catalog, stmt, nil)
1818
if err != nil {
1919
return nil, err
2020
}
@@ -178,6 +178,16 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {
178178

179179
case *ast.ColumnRef:
180180
if hasStarRef(n) {
181+
182+
// add a column with a reference to an embedded table
183+
if embed, ok := qc.embeds.Find(n); ok {
184+
cols = append(cols, &Column{
185+
Name: embed.Table.Name,
186+
EmbedTable: embed.Table,
187+
})
188+
continue
189+
}
190+
181191
// TODO: This code is copied in func expand()
182192
for _, t := range tables {
183193
scope := astutils.Join(n.Fields, ".")
@@ -495,6 +505,7 @@ func outputColumnRefs(res *ast.ResTarget, tables []*Table, node *ast.ColumnRef)
495505
NotNull: c.NotNull,
496506
IsArray: c.IsArray,
497507
Length: c.Length,
508+
EmbedTable: c.EmbedTable,
498509
})
499510
}
500511
}

internal/compiler/parse.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,14 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,
8383
} else {
8484
sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number })
8585
}
86-
qc, err := buildQueryCatalog(c.catalog, raw.Stmt)
86+
87+
raw, embeds := rewrite.Embeds(raw)
88+
qc, err := buildQueryCatalog(c.catalog, raw.Stmt, embeds)
8789
if err != nil {
8890
return nil, err
8991
}
9092

91-
params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams)
93+
params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams, embeds)
9294
if err != nil {
9395
return nil, err
9496
}

internal/compiler/query.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ type Column struct {
2929
Table *ast.TableName
3030
TableAlias string
3131
Type *ast.TypeName
32+
EmbedTable *ast.TableName
3233

3334
skipTableRequiredCheck bool
3435
}

internal/compiler/query_catalog.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,16 @@ import (
55

66
"github.com/kyleconroy/sqlc/internal/sql/ast"
77
"github.com/kyleconroy/sqlc/internal/sql/catalog"
8+
"github.com/kyleconroy/sqlc/internal/sql/rewrite"
89
)
910

1011
type QueryCatalog struct {
1112
catalog *catalog.Catalog
1213
ctes map[string]*Table
14+
embeds rewrite.EmbedSet
1315
}
1416

15-
func buildQueryCatalog(c *catalog.Catalog, node ast.Node) (*QueryCatalog, error) {
17+
func buildQueryCatalog(c *catalog.Catalog, node ast.Node, embeds rewrite.EmbedSet) (*QueryCatalog, error) {
1618
var with *ast.WithClause
1719
switch n := node.(type) {
1820
case *ast.DeleteStmt:
@@ -26,7 +28,7 @@ func buildQueryCatalog(c *catalog.Catalog, node ast.Node) (*QueryCatalog, error)
2628
default:
2729
with = nil
2830
}
29-
qc := &QueryCatalog{catalog: c, ctes: map[string]*Table{}}
31+
qc := &QueryCatalog{catalog: c, ctes: map[string]*Table{}, embeds: embeds}
3032
if with != nil {
3133
for _, item := range with.Ctes.Items {
3234
if cte, ok := item.(*ast.CommonTableExpr); ok {

internal/compiler/resolve.go

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"github.com/kyleconroy/sqlc/internal/sql/astutils"
99
"github.com/kyleconroy/sqlc/internal/sql/catalog"
1010
"github.com/kyleconroy/sqlc/internal/sql/named"
11+
"github.com/kyleconroy/sqlc/internal/sql/rewrite"
1112
"github.com/kyleconroy/sqlc/internal/sql/sqlerr"
1213
)
1314

@@ -19,7 +20,7 @@ func dataType(n *ast.TypeName) string {
1920
}
2021
}
2122

22-
func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet) ([]Parameter, error) {
23+
func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet, embeds rewrite.EmbedSet) ([]Parameter, error) {
2324
c := comp.catalog
2425

2526
aliasMap := map[string]*ast.TableName{}
@@ -76,6 +77,22 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
7677
}
7778
}
7879

80+
// resolve a table for an embed
81+
for _, embed := range embeds {
82+
table, err := c.GetTable(embed.Table)
83+
if err == nil {
84+
embed.Table = table.Rel
85+
continue
86+
}
87+
88+
if alias, ok := aliasMap[embed.Table.Name]; ok {
89+
embed.Table = alias
90+
continue
91+
}
92+
93+
return nil, fmt.Errorf("unable to resolve table with %q: %w", embed.Orig(), err)
94+
}
95+
7996
var a []Parameter
8097
for _, ref := range args {
8198
switch n := ref.parent.(type) {

0 commit comments

Comments
 (0)