Skip to content

Commit 64d2863

Browse files
committed
feat(analyzer): Implement parameter type annotations
A first step towards implementing #2800
1 parent ce95162 commit 64d2863

File tree

12 files changed

+206
-26
lines changed

12 files changed

+206
-26
lines changed

internal/compiler/analyze.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55

66
analyzer "github.com/sqlc-dev/sqlc/internal/analysis"
77
"github.com/sqlc-dev/sqlc/internal/config"
8+
"github.com/sqlc-dev/sqlc/internal/metadata"
89
"github.com/sqlc-dev/sqlc/internal/source"
910
"github.com/sqlc-dev/sqlc/internal/sql/ast"
1011
"github.com/sqlc-dev/sqlc/internal/sql/named"
@@ -106,15 +107,15 @@ func combineAnalysis(prev *analysis, a *analyzer.Analysis) *analysis {
106107
return prev
107108
}
108109

109-
func (c *Compiler) analyzeQuery(raw *ast.RawStmt, query string) (*analysis, error) {
110-
return c._analyzeQuery(raw, query, true)
110+
func (c *Compiler) analyzeQuery(raw *ast.RawStmt, query string, paramAnnotations map[string]metadata.ParamMetadata) (*analysis, error) {
111+
return c._analyzeQuery(raw, query, paramAnnotations, true)
111112
}
112113

113-
func (c *Compiler) inferQuery(raw *ast.RawStmt, query string) (*analysis, error) {
114-
return c._analyzeQuery(raw, query, false)
114+
func (c *Compiler) inferQuery(raw *ast.RawStmt, query string, paramAnnotations map[string]metadata.ParamMetadata) (*analysis, error) {
115+
return c._analyzeQuery(raw, query, paramAnnotations, false)
115116
}
116117

117-
func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) (*analysis, error) {
118+
func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, paramAnnotations map[string]metadata.ParamMetadata, failfast bool) (*analysis, error) {
118119
errors := make([]error, 0)
119120
check := func(err error) error {
120121
if failfast {
@@ -173,7 +174,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
173174
return nil, err
174175
}
175176

176-
params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams, embeds)
177+
params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams, embeds, paramAnnotations)
177178
if err := check(err); err != nil {
178179
return nil, err
179180
}

internal/compiler/parse.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,
7272

7373
var anlys *analysis
7474
if c.analyzer != nil {
75-
inference, _ := c.inferQuery(raw, rawSQL)
75+
inference, _ := c.inferQuery(raw, rawSQL, md.Params)
7676
if inference == nil {
7777
inference = &analysis{}
7878
}
@@ -100,7 +100,7 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,
100100
// FOOTGUN: combineAnalysis mutates inference
101101
anlys = combineAnalysis(inference, result)
102102
} else {
103-
anlys, err = c.analyzeQuery(raw, rawSQL)
103+
anlys, err = c.analyzeQuery(raw, rawSQL, md.Params)
104104
if err != nil {
105105
return nil, err
106106
}

internal/compiler/resolve.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"log/slog"
66
"strconv"
77

8+
"github.com/sqlc-dev/sqlc/internal/metadata"
89
"github.com/sqlc-dev/sqlc/internal/sql/ast"
910
"github.com/sqlc-dev/sqlc/internal/sql/astutils"
1011
"github.com/sqlc-dev/sqlc/internal/sql/catalog"
@@ -21,7 +22,7 @@ func dataType(n *ast.TypeName) string {
2122
}
2223
}
2324

24-
func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet, embeds rewrite.EmbedSet) ([]Parameter, error) {
25+
func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet, embeds rewrite.EmbedSet, paramAnnotations map[string]metadata.ParamMetadata) ([]Parameter, error) {
2526
c := comp.catalog
2627

2728
aliasMap := map[string]*ast.TableName{}
@@ -619,5 +620,18 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
619620
})
620621
}
621622
}
623+
624+
// Override the inferrerd type and nullability of annotated named params
625+
for i, param := range a {
626+
if param.Column.IsNamedParam {
627+
if md, ok := paramAnnotations[param.Column.Name]; ok {
628+
a[i].Column.DataType = md.DatabaseType
629+
if md.ForceNotNull != nil {
630+
a[i].Column.NotNull = *md.ForceNotNull
631+
}
632+
}
633+
}
634+
}
635+
622636
return a, nil
623637
}

internal/endtoend/testdata/param_type_annotations/db/db.go

Lines changed: 31 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/param_type_annotations/db/models.go

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/param_type_annotations/db/query.sql.go

Lines changed: 74 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
-- name: TestSqlcArg :one
2+
-- @param foo text
3+
SELECT * FROM test WHERE id = sqlc.arg('foo');
4+
5+
-- name: TestAt :one
6+
-- @param foo integer
7+
SELECT * FROM test WHERE name = @foo;
8+
9+
-- name: TestForceNotNull :one
10+
-- @param foo! jsonb
11+
SELECT * FROM test WHERE name = @foo;
12+
13+
-- name: TestForceNullable :one
14+
-- @param foo? uuid
15+
SELECT * FROM test WHERE id = @foo;
16+
17+
-- name: TestGibberish :one
18+
-- @param foo? uuid sdfagyi
19+
SELECT * FROM test WHERE id = @foo;
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
CREATE TABLE test (id INTEGER NOT NULL, name TEXT);
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
version: "2"
2+
sql:
3+
- schema: schema.sql
4+
queries: query.sql
5+
engine: postgresql
6+
gen:
7+
go:
8+
out: db

internal/metadata/meta.go

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ type Metadata struct {
1515
Name string
1616
Cmd string
1717
Comments []string
18-
Params map[string]string
18+
Params map[string]ParamMetadata
1919
Flags map[string]bool
2020

2121
Filename string
@@ -34,6 +34,12 @@ const (
3434
CmdBatchOne = ":batchone"
3535
)
3636

37+
type ParamMetadata struct {
38+
DatabaseType string
39+
// unset => nil, "!" => true, "?" => false
40+
ForceNotNull *bool
41+
}
42+
3743
// A query name must be a valid Go identifier
3844
//
3945
// https://golang.org/ref/spec#Identifiers
@@ -113,8 +119,8 @@ func ParseQueryNameAndType(t string, commentStyle CommentSyntax) (string, string
113119
return "", "", nil
114120
}
115121

116-
func ParseParamsAndFlags(comments []string) (map[string]string, map[string]bool, error) {
117-
params := make(map[string]string)
122+
func ParseParamsAndFlags(comments []string) (map[string]ParamMetadata, map[string]bool, error) {
123+
params := make(map[string]ParamMetadata)
118124
flags := make(map[string]bool)
119125

120126
for _, line := range comments {
@@ -137,7 +143,25 @@ func ParseParamsAndFlags(comments []string) (map[string]string, map[string]bool,
137143
paramToken := s.Text()
138144
rest = append(rest, paramToken)
139145
}
140-
params[name] = strings.Join(rest, " ")
146+
var hasSuffix, suffixValue bool
147+
switch name[len(name)-1] {
148+
case '!':
149+
name = name[:len(name)-1]
150+
hasSuffix = true
151+
suffixValue = true
152+
case '?':
153+
name = name[:len(name)-1]
154+
hasSuffix = true
155+
suffixValue = false
156+
}
157+
var forceNotNull *bool
158+
if hasSuffix {
159+
forceNotNull = &suffixValue
160+
}
161+
params[name] = ParamMetadata{
162+
DatabaseType: strings.Join(rest, " "),
163+
ForceNotNull: forceNotNull,
164+
}
141165
default:
142166
flags[token] = true
143167
}

internal/metadata/meta_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ func TestParseQueryParams(t *testing.T) {
8787
t.Errorf("expected param not found")
8888
}
8989

90-
if pt != "UUID" {
90+
if pt.DatabaseType != "UUID" {
9191
t.Error("unexpected param metadata:", pt)
9292
}
9393

internal/sql/validate/param_ref.go

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,32 +10,26 @@ import (
1010
)
1111

1212
func ParamRef(n ast.Node) (map[int]bool, bool, error) {
13-
var allrefs []*ast.ParamRef
14-
var dollar bool
15-
var nodollar bool
13+
seen := map[int]bool{}
14+
var dollar, nodollar bool
1615
// Find all parameter references
1716
astutils.Walk(astutils.VisitorFunc(func(node ast.Node) {
1817
switch n := node.(type) {
1918
case *ast.ParamRef:
20-
ref := node.(*ast.ParamRef)
21-
if ref.Dollar {
19+
if n.Dollar {
2220
dollar = true
2321
} else {
2422
nodollar = true
2523
}
26-
allrefs = append(allrefs, n)
24+
if n.Number > 0 {
25+
seen[n.Number] = true
26+
}
2727
}
2828
}), n)
2929
if dollar && nodollar {
3030
return nil, false, errors.New("can not mix $1 format with ? format")
3131
}
3232

33-
seen := map[int]bool{}
34-
for _, r := range allrefs {
35-
if r.Number > 0 {
36-
seen[r.Number] = true
37-
}
38-
}
3933
for i := 1; i <= len(seen); i += 1 {
4034
if _, ok := seen[i]; !ok {
4135
return seen, !nodollar, &sqlerr.Error{

0 commit comments

Comments
 (0)