Skip to content

Commit c2dcd56

Browse files
ovadbarBaroukh Ovadiakyleconroy
authored
Allow for mixed parameters types ($1 or ?) and sqlc.arg() (#1072)
* Allow for mixing parameter styles Co-authored-by: Baroukh Ovadia <bovadia@dyl.com> Co-authored-by: Kyle Conroy <kyle@conroy.org>
1 parent 6fdb7e1 commit c2dcd56

File tree

21 files changed

+379
-46
lines changed

21 files changed

+379
-46
lines changed

internal/compiler/parse.go

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"sort"
77
"strings"
88

9+
"github.com/kyleconroy/sqlc/internal/config"
910
"github.com/kyleconroy/sqlc/internal/debug"
1011
"github.com/kyleconroy/sqlc/internal/metadata"
1112
"github.com/kyleconroy/sqlc/internal/opts"
@@ -37,7 +38,8 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,
3738
if err := validate.ParamStyle(stmt); err != nil {
3839
return nil, err
3940
}
40-
if err := validate.ParamRef(stmt); err != nil {
41+
numbers, dollar, err := validate.ParamRef(stmt)
42+
if err != nil {
4143
return nil, err
4244
}
4345
raw, ok := stmt.(*ast.RawStmt)
@@ -75,7 +77,7 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,
7577
return nil, err
7678
}
7779

78-
raw, namedParams, edits := rewrite.NamedParameters(c.conf.Engine, raw)
80+
raw, namedParams, edits := rewrite.NamedParameters(c.conf.Engine, raw, numbers, dollar)
7981
rvs := rangeVars(raw.Stmt)
8082
refs := findParameters(raw.Stmt)
8183
if o.UsePositionalParameters {
@@ -84,8 +86,12 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,
8486
return nil, err
8587
}
8688
} else {
87-
refs = uniqueParamRefs(refs)
88-
sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number })
89+
refs = uniqueParamRefs(refs, dollar)
90+
if c.conf.Engine == config.EngineMySQL || !dollar {
91+
sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Location < refs[j].ref.Location })
92+
} else {
93+
sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number })
94+
}
8995
}
9096
qc, err := buildQueryCatalog(c.catalog, raw.Stmt)
9197
if err != nil {
@@ -122,7 +128,6 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,
122128
if err != nil {
123129
return nil, err
124130
}
125-
126131
return &Query{
127132
Cmd: cmd,
128133
Comments: comments,
@@ -145,13 +150,27 @@ func rangeVars(root ast.Node) []*ast.RangeVar {
145150
return vars
146151
}
147152

148-
func uniqueParamRefs(in []paramRef) []paramRef {
149-
m := make(map[int]struct{}, len(in))
153+
func uniqueParamRefs(in []paramRef, dollar bool) []paramRef {
154+
m := make(map[int]bool, len(in))
150155
o := make([]paramRef, 0, len(in))
151156
for _, v := range in {
152-
if _, ok := m[v.ref.Number]; !ok {
153-
m[v.ref.Number] = struct{}{}
154-
o = append(o, v)
157+
if !m[v.ref.Number] {
158+
m[v.ref.Number] = true
159+
if v.ref.Number != 0 {
160+
o = append(o, v)
161+
}
162+
}
163+
}
164+
if !dollar {
165+
start := 1
166+
for _, v := range in {
167+
if v.ref.Number == 0 {
168+
for m[start] {
169+
start++
170+
}
171+
v.ref.Number = start
172+
o = append(o, v)
173+
}
155174
}
156175
}
157176
return o

internal/endtoend/testdata/invalid_params/pgx/stderr.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
query.sql:4:1: could not determine data type of parameter $1
33
query.sql:7:1: could not determine data type of parameter $2
44
query.sql:10:8: column "foo" does not exist
5-
query.sql:13:1: query mixes positional parameters ($1) and named parameters (sqlc.arg or @arg)
5+
query.sql:13:1: could not determine data type of parameter $2

internal/endtoend/testdata/invalid_params/stdlib/stderr.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
query.sql:4:1: could not determine data type of parameter $1
33
query.sql:7:1: could not determine data type of parameter $2
44
query.sql:10:8: column "foo" does not exist
5-
query.sql:13:1: query mixes positional parameters ($1) and named parameters (sqlc.arg or @arg)
5+
query.sql:13:1: could not determine data type of parameter $2

internal/endtoend/testdata/mix_param_types/mysql/go/db.go

Lines changed: 29 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/mix_param_types/mysql/go/models.go

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/mix_param_types/mysql/go/test.sql.go

Lines changed: 57 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
{
2+
"version": "1",
3+
"packages": [
4+
{
5+
"path": "go",
6+
"name": "querytest",
7+
"schema": "test.sql",
8+
"queries": "test.sql",
9+
"engine": "mysql"
10+
}
11+
]
12+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
CREATE TABLE bar (
2+
id serial not null,
3+
name text not null,
4+
phone text not null
5+
);
6+
7+
-- name: CountOne :one
8+
SELECT count(1) FROM bar WHERE id = sqlc.arg(id) AND name <> ?;
9+
10+
-- name: CountTwo :one
11+
SELECT count(1) FROM bar WHERE id = ? AND name <> sqlc.arg(name);
12+
13+
-- name: CountThree :one
14+
SELECT count(1) FROM bar WHERE id > ? AND phone <> sqlc.arg(phone) AND name <> ?;

internal/endtoend/testdata/mix_param_types/postgresql/go/db.go

Lines changed: 29 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/mix_param_types/postgresql/go/models.go

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/mix_param_types/postgresql/go/test.sql.go

Lines changed: 75 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
{
2+
"version": "1",
3+
"packages": [
4+
{
5+
"path": "go",
6+
"name": "querytest",
7+
"schema": "test.sql",
8+
"queries": "test.sql",
9+
"engine": "postgresql"
10+
}
11+
]
12+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
CREATE TABLE bar (
2+
id serial not null,
3+
name text not null,
4+
phone text not null
5+
);
6+
7+
-- name: CountOne :one
8+
SELECT count(1) FROM bar WHERE id = sqlc.arg(id) AND name <> $1 LIMIT sqlc.arg('limit');
9+
10+
-- name: CountTwo :one
11+
SELECT count(1) FROM bar WHERE id = $1 AND name <> sqlc.arg(name);
12+
13+
-- name: CountThree :one
14+
SELECT count(1) FROM bar WHERE id > $2 AND phone <> sqlc.arg(phone) AND name <> $1;
15+
16+
-- name: CountFour :one
17+
SELECT count(1) FROM bar WHERE id > ? AND phone <> sqlc.arg(phone) AND name <> ?;
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# package querytest
22
query.sql:7:1: function "sqlc.argh" does not exist
33
query.sql:10:45: expected 1 parameter to sqlc.arg; got 2
4-
query.sql:13:45: expected parameter to sqlc.arg to be string or reference; got *ast.FuncCall
5-
query.sql:16:1: query mixes positional parameters ($1) and named parameters (sqlc.arg or @arg)
4+
query.sql:13:54: Invalid argument to sqlc.arg()
5+
query.sql:16:54: Invalid argument to sqlc.arg()
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# package querytest
22
query.sql:7:1: function "sqlc.argh" does not exist
33
query.sql:10:45: expected 1 parameter to sqlc.arg; got 2
4-
query.sql:13:45: expected parameter to sqlc.arg to be string or reference; got *ast.FuncCall
5-
query.sql:16:1: query mixes positional parameters ($1) and named parameters (sqlc.arg or @arg)
4+
query.sql:13:54: Invalid argument to sqlc.arg()
5+
query.sql:16:54: Invalid argument to sqlc.arg()

0 commit comments

Comments
 (0)