Skip to content

Commit 491bfe0

Browse files
committed
mimic override parse and merge within golang codegen package
1 parent 41343ab commit 491bfe0

File tree

8 files changed

+490
-118
lines changed

8 files changed

+490
-118
lines changed

internal/cmd/generate.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ func codegen(ctx context.Context, combo config.CombinedSettings, sql outPair, re
418418
case sql.Gen.Go != nil:
419419
out = combo.Go.Out
420420
handler = ext.HandleFunc(golang.Generate)
421-
opts, err := json.Marshal(pluginGoOpts(sql.Gen.Go, combo, result))
421+
opts, err := json.Marshal(sql.Gen.Go)
422422
if err != nil {
423423
return "", nil, fmt.Errorf("opts marshal failed: %w", err)
424424
}

internal/cmd/shim.go

Lines changed: 21 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package cmd
22

33
import (
4+
"encoding/json"
45
"strings"
56

67
goopts "github.com/sqlc-dev/sqlc/internal/codegen/golang/opts"
@@ -12,7 +13,7 @@ import (
1213
"github.com/sqlc-dev/sqlc/internal/sql/catalog"
1314
)
1415

15-
func pluginOverride(r *compiler.Result, o config.Override) goopts.Override {
16+
func pluginOverride(r *compiler.Result, o config.Override) *plugin.Override {
1617
var column string
1718
var table plugin.Identifier
1819

@@ -34,26 +35,36 @@ func pluginOverride(r *compiler.Result, o config.Override) goopts.Override {
3435
column = colParts[3]
3536
}
3637
}
37-
return goopts.Override{
38-
CodeType: "", // FIXME
38+
39+
goTypeJSON, err := json.Marshal(pluginGoType(o))
40+
if err != nil {
41+
panic(err)
42+
}
43+
44+
return &plugin.Override{
45+
CodeType: goTypeJSON,
3946
DbType: o.DBType,
4047
Nullable: o.Nullable,
4148
Unsigned: o.Unsigned,
4249
Column: o.Column,
4350
ColumnName: column,
4451
Table: &table,
45-
GoType: pluginGoType(o),
4652
}
4753
}
4854

4955
func pluginSettings(r *compiler.Result, cs config.CombinedSettings) *plugin.Settings {
56+
var overrides []*plugin.Override
57+
for _, o := range cs.Overrides {
58+
overrides = append(overrides, pluginOverride(r, o))
59+
}
5060
return &plugin.Settings{
51-
Version: cs.Global.Version,
52-
Engine: string(cs.Package.Engine),
53-
Schema: []string(cs.Package.Schema),
54-
Queries: []string(cs.Package.Queries),
55-
Rename: cs.Rename,
56-
Codegen: pluginCodegen(cs, cs.Codegen),
61+
Version: cs.Global.Version,
62+
Engine: string(cs.Package.Engine),
63+
Schema: []string(cs.Package.Schema),
64+
Queries: []string(cs.Package.Queries),
65+
Overrides: overrides,
66+
Rename: cs.Rename,
67+
Codegen: pluginCodegen(cs, cs.Codegen),
5768
}
5869
}
5970

@@ -111,46 +122,6 @@ func pluginGoType(o config.Override) *goopts.ParsedGoType {
111122
}
112123
}
113124

114-
func pluginGoOpts(sqlGo *config.SQLGo, cs config.CombinedSettings, r *compiler.Result) *goopts.Options {
115-
var overrides []goopts.Override
116-
for _, o := range cs.Overrides {
117-
overrides = append(overrides, pluginOverride(r, o))
118-
}
119-
return &goopts.Options{
120-
EmitInterface: sqlGo.EmitInterface,
121-
EmitJsonTags: sqlGo.EmitJSONTags,
122-
JsonTagsIdUppercase: sqlGo.JsonTagsIDUppercase,
123-
EmitDbTags: sqlGo.EmitDBTags,
124-
EmitPreparedQueries: sqlGo.EmitPreparedQueries,
125-
EmitExactTableNames: sqlGo.EmitExactTableNames,
126-
EmitEmptySlices: sqlGo.EmitEmptySlices,
127-
EmitExportedQueries: sqlGo.EmitExportedQueries,
128-
EmitResultStructPointers: sqlGo.EmitResultStructPointers,
129-
EmitParamsStructPointers: sqlGo.EmitParamsStructPointers,
130-
EmitMethodsWithDbArgument: sqlGo.EmitMethodsWithDBArgument,
131-
EmitPointersForNullTypes: sqlGo.EmitPointersForNullTypes,
132-
EmitEnumValidMethod: sqlGo.EmitEnumValidMethod,
133-
EmitAllEnumValues: sqlGo.EmitAllEnumValues,
134-
JsonTagsCaseStyle: sqlGo.JSONTagsCaseStyle,
135-
Package: sqlGo.Package,
136-
Out: sqlGo.Out,
137-
Overrides: overrides,
138-
// Rename intentionally omitted
139-
SqlPackage: sqlGo.SQLPackage,
140-
SqlDriver: sqlGo.SQLDriver,
141-
OutputBatchFileName: sqlGo.OutputBatchFileName,
142-
OutputDbFileName: sqlGo.OutputDBFileName,
143-
OutputModelsFileName: sqlGo.OutputModelsFileName,
144-
OutputQuerierFileName: sqlGo.OutputQuerierFileName,
145-
OutputCopyfromFileName: sqlGo.OutputCopyFromFileName,
146-
OutputFilesSuffix: sqlGo.OutputFilesSuffix,
147-
InflectionExcludeTableNames: sqlGo.InflectionExcludeTableNames,
148-
QueryParameterLimit: sqlGo.QueryParameterLimit,
149-
OmitUnusedStructs: sqlGo.OmitUnusedStructs,
150-
BuildTags: sqlGo.BuildTags,
151-
}
152-
}
153-
154125
func pluginCatalog(c *catalog.Catalog) *plugin.Catalog {
155126
var schemas []*plugin.Schema
156127
for _, s := range c.Schemas {
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package opts
2+
3+
import (
4+
"github.com/sqlc-dev/sqlc/internal/codegen/sdk"
5+
"github.com/sqlc-dev/sqlc/internal/plugin"
6+
)
7+
8+
type GlobalOverride struct {
9+
*plugin.Override
10+
11+
GoType *ParsedGoType
12+
}
13+
14+
func (o *GlobalOverride) Convert() *plugin.Override {
15+
return &plugin.Override{
16+
DbType: o.DbType,
17+
Nullable: o.Nullable,
18+
Column: o.Column,
19+
Table: o.Table,
20+
ColumnName: o.ColumnName,
21+
Unsigned: o.Unsigned,
22+
}
23+
}
24+
25+
func (o *GlobalOverride) Matches(n *plugin.Identifier, defaultSchema string) bool {
26+
return sdk.Matches(o.Convert(), n, defaultSchema)
27+
}
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
package opts
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"go/types"
7+
"regexp"
8+
"strings"
9+
10+
"github.com/fatih/structtag"
11+
)
12+
13+
type GoType struct {
14+
Path string `json:"import" yaml:"import"`
15+
Package string `json:"package" yaml:"package"`
16+
Name string `json:"type" yaml:"type"`
17+
Pointer bool `json:"pointer" yaml:"pointer"`
18+
Slice bool `json:"slice" yaml:"slice"`
19+
Spec string
20+
BuiltIn bool
21+
}
22+
23+
type ParsedGoType struct {
24+
ImportPath string `json:"import_path"`
25+
Package string `json:"package"`
26+
TypeName string `json:"type_name"`
27+
BasicType bool `json:"basic_type"`
28+
StructTags map[string]string `json:"struct_tags"`
29+
}
30+
31+
func (o *GoType) UnmarshalJSON(data []byte) error {
32+
var spec string
33+
if err := json.Unmarshal(data, &spec); err == nil {
34+
*o = GoType{Spec: spec}
35+
return nil
36+
}
37+
type alias GoType
38+
var a alias
39+
if err := json.Unmarshal(data, &a); err != nil {
40+
return err
41+
}
42+
*o = GoType(a)
43+
return nil
44+
}
45+
46+
func (o *GoType) UnmarshalYAML(unmarshal func(interface{}) error) error {
47+
var spec string
48+
if err := unmarshal(&spec); err == nil {
49+
*o = GoType{Spec: spec}
50+
return nil
51+
}
52+
type alias GoType
53+
var a alias
54+
if err := unmarshal(&a); err != nil {
55+
return err
56+
}
57+
*o = GoType(a)
58+
return nil
59+
}
60+
61+
var validIdentifier = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
62+
var versionNumber = regexp.MustCompile(`^v[0-9]+$`)
63+
var invalidIdentifier = regexp.MustCompile(`[^a-zA-Z0-9_]`)
64+
65+
func generatePackageID(importPath string) (string, bool) {
66+
parts := strings.Split(importPath, "/")
67+
name := parts[len(parts)-1]
68+
// If the last part of the import path is a valid identifier, assume that's the package name
69+
if versionNumber.MatchString(name) && len(parts) >= 2 {
70+
name = parts[len(parts)-2]
71+
return invalidIdentifier.ReplaceAllString(strings.ToLower(name), "_"), true
72+
}
73+
if validIdentifier.MatchString(name) {
74+
return name, false
75+
}
76+
return invalidIdentifier.ReplaceAllString(strings.ToLower(name), "_"), true
77+
}
78+
79+
// validate GoType
80+
func (gt GoType) Parse() (*ParsedGoType, error) {
81+
var o ParsedGoType
82+
83+
if gt.Spec == "" {
84+
// TODO: Validation
85+
if gt.Path == "" && gt.Package != "" {
86+
return nil, fmt.Errorf("Package override `go_type`: package name requires an import path")
87+
}
88+
var pkg string
89+
var pkgNeedsAlias bool
90+
91+
if gt.Package == "" && gt.Path != "" {
92+
pkg, pkgNeedsAlias = generatePackageID(gt.Path)
93+
if pkgNeedsAlias {
94+
o.Package = pkg
95+
}
96+
} else {
97+
pkg = gt.Package
98+
o.Package = gt.Package
99+
}
100+
101+
o.ImportPath = gt.Path
102+
o.TypeName = gt.Name
103+
o.BasicType = gt.Path == "" && gt.Package == ""
104+
if pkg != "" {
105+
o.TypeName = pkg + "." + o.TypeName
106+
}
107+
if gt.Pointer {
108+
o.TypeName = "*" + o.TypeName
109+
}
110+
if gt.Slice {
111+
o.TypeName = "[]" + o.TypeName
112+
}
113+
return &o, nil
114+
}
115+
116+
input := gt.Spec
117+
lastDot := strings.LastIndex(input, ".")
118+
lastSlash := strings.LastIndex(input, "/")
119+
typename := input
120+
if lastDot == -1 && lastSlash == -1 {
121+
// if the type name has no slash and no dot, validate that the type is a basic Go type
122+
var found bool
123+
for _, typ := range types.Typ {
124+
info := typ.Info()
125+
if info == 0 {
126+
continue
127+
}
128+
if info&types.IsUntyped != 0 {
129+
continue
130+
}
131+
if typename == typ.Name() {
132+
found = true
133+
}
134+
}
135+
if !found {
136+
return nil, fmt.Errorf("Package override `go_type` specifier %q is not a Go basic type e.g. 'string'", input)
137+
}
138+
o.BasicType = true
139+
} else {
140+
// assume the type lives in a Go package
141+
if lastDot == -1 {
142+
return nil, fmt.Errorf("Package override `go_type` specifier %q is not the proper format, expected 'package.type', e.g. 'github.com/segmentio/ksuid.KSUID'", input)
143+
}
144+
typename = input[lastSlash+1:]
145+
// a package name beginning with "go-" will give syntax errors in
146+
// generated code. We should do the right thing and get the actual
147+
// import name, but in lieu of that, stripping the leading "go-" may get
148+
// us what we want.
149+
typename = strings.TrimPrefix(typename, "go-")
150+
typename = strings.TrimSuffix(typename, "-go")
151+
o.ImportPath = input[:lastDot]
152+
}
153+
o.TypeName = typename
154+
isPointer := input[0] == '*'
155+
if isPointer {
156+
o.ImportPath = o.ImportPath[1:]
157+
o.TypeName = "*" + o.TypeName
158+
}
159+
return &o, nil
160+
}
161+
162+
// GoStructTag is a raw Go struct tag.
163+
type GoStructTag string
164+
165+
// Parse parses and validates a GoStructTag.
166+
// The output is in a form convenient for codegen.
167+
//
168+
// Sample valid inputs/outputs:
169+
//
170+
// In Out
171+
// empty string {}
172+
// `a:"b"` {"a": "b"}
173+
// `a:"b" x:"y,z"` {"a": "b", "x": "y,z"}
174+
func (s GoStructTag) Parse() (map[string]string, error) {
175+
m := make(map[string]string)
176+
tags, err := structtag.Parse(string(s))
177+
if err != nil {
178+
return nil, err
179+
}
180+
for _, tag := range tags.Tags() {
181+
m[tag.Key] = tag.Value()
182+
}
183+
return m, nil
184+
}

0 commit comments

Comments
 (0)