From 1ac5932452c9eb71e9291631c2b21d3760bc3df0 Mon Sep 17 00:00:00 2001 From: simo7 Date: Fri, 2 Aug 2024 15:44:36 +0200 Subject: [PATCH 01/15] Add table_prefix conf option --- Makefile | 4 ++-- README.md | 4 ++++ internal/config.go | 1 + internal/gen.go | 3 +++ plugin/{ => sqlc-gen-python}/main.go | 0 5 files changed, 10 insertions(+), 2 deletions(-) rename plugin/{ => sqlc-gen-python}/main.go (100%) diff --git a/Makefile b/Makefile index d8ac60b..dbea94f 100644 --- a/Makefile +++ b/Makefile @@ -9,10 +9,10 @@ test: bin/sqlc-gen-python.wasm all: bin/sqlc-gen-python bin/sqlc-gen-python.wasm bin/sqlc-gen-python: bin go.mod go.sum $(wildcard **/*.go) - cd plugin && go build -o ../bin/sqlc-gen-python ./main.go + cd plugin/sqlc-gen-python && go build -o ../../bin/sqlc-gen-python ./main.go bin/sqlc-gen-python.wasm: bin/sqlc-gen-python - cd plugin && GOOS=wasip1 GOARCH=wasm go build -o ../bin/sqlc-gen-python.wasm main.go + cd plugin/sqlc-gen-python && GOOS=wasip1 GOARCH=wasm go build -o ../../bin/sqlc-gen-python.wasm main.go bin: mkdir -p bin diff --git a/README.md b/README.md index d7e0d75..8aaf0c9 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,7 @@ +## Build + +make bin/sqlc-gen-python + ## Usage ```yaml diff --git a/internal/config.go b/internal/config.go index 009cb04..22659d6 100644 --- a/internal/config.go +++ b/internal/config.go @@ -9,4 +9,5 @@ type Config struct { EmitPydanticModels bool `json:"emit_pydantic_models"` QueryParameterLimit *int32 `json:"query_parameter_limit"` InflectionExcludeTableNames []string `json:"inflection_exclude_table_names"` + TablePrefix string `json:"table_prefix"` } diff --git a/internal/gen.go b/internal/gen.go index f81c53b..5c1727b 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -279,6 +279,9 @@ func buildModels(conf Config, req *plugin.GenerateRequest) []Struct { Exclusions: conf.InflectionExcludeTableNames, }) } + if conf.TablePrefix != "" { + structName = conf.TablePrefix + strings.ToUpper(structName[:1]) + structName[1:] + } s := Struct{ Table: plugin.Identifier{Schema: schema.Name, Name: table.Rel.Name}, Name: modelName(structName, req.Settings), diff --git a/plugin/main.go b/plugin/sqlc-gen-python/main.go similarity index 100% rename from plugin/main.go rename to plugin/sqlc-gen-python/main.go From 7776922f6cf95b5f7a102de3d27134c4e8f7c609 Mon Sep 17 00:00:00 2001 From: simo7 Date: Wed, 11 Sep 2024 23:49:58 +0200 Subject: [PATCH 02/15] convert to asyncpg-based --- internal/config.go | 2 - internal/gen.go | 262 +++++++++++--------------------------------- internal/imports.go | 12 +- 3 files changed, 67 insertions(+), 209 deletions(-) diff --git a/internal/config.go b/internal/config.go index 22659d6..899009e 100644 --- a/internal/config.go +++ b/internal/config.go @@ -2,8 +2,6 @@ package python type Config struct { EmitExactTableNames bool `json:"emit_exact_table_names"` - EmitSyncQuerier bool `json:"emit_sync_querier"` - EmitAsyncQuerier bool `json:"emit_async_querier"` Package string `json:"package"` Out string `json:"out"` EmitPydanticModels bool `json:"emit_pydantic_models"` diff --git a/internal/gen.go b/internal/gen.go index 5c1727b..e14ecec 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -354,18 +354,6 @@ func columnsToStruct(req *plugin.GenerateRequest, name string, columns []pyColum return &gs } -var postgresPlaceholderRegexp = regexp.MustCompile(`\B\$(\d+)\b`) - -// Sqlalchemy uses ":name" for placeholders, so "$N" is converted to ":pN" -// This also means ":" has special meaning to sqlalchemy, so it must be escaped. -func sqlalchemySQL(s, engine string) string { - s = strings.ReplaceAll(s, ":", `\\:`) - if engine == "postgresql" { - return postgresPlaceholderRegexp.ReplaceAllString(s, ":p$1") - } - return s -} - func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([]Query, error) { qs := make([]Query, 0, len(req.Queries)) for _, query := range req.Queries { @@ -387,7 +375,7 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([ MethodName: methodName, FieldName: sdk.LowerTitle(query.Name) + "Stmt", ConstantName: strings.ToUpper(methodName), - SQL: sqlalchemySQL(query.Text, req.Settings.Engine), + SQL: query.Text, SourceName: query.Filename, } @@ -625,18 +613,7 @@ func typeRefNode(base string, parts ...string) *pyast.Node { } func connMethodNode(method, name string, arg *pyast.Node) *pyast.Node { - args := []*pyast.Node{ - { - Node: &pyast.Node_Call{ - Call: &pyast.Call{ - Func: typeRefNode("sqlalchemy", "text"), - Args: []*pyast.Node{ - poet.Name(name), - }, - }, - }, - }, - } + args := []*pyast.Node{poet.Name(name)} if arg != nil { args = append(args, arg) } @@ -792,7 +769,7 @@ func asyncQuerierClassDef() *pyast.ClassDef { }, { Arg: "conn", - Annotation: typeRefNode("sqlalchemy", "ext", "asyncio", "AsyncConnection"), + Annotation: typeRefNode("asyncpg", "pool", "PoolConnectionProxy"), }, }, }, @@ -872,190 +849,81 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node { } } - if ctx.C.EmitSyncQuerier { - cls := querierClassDef() - for _, q := range ctx.Queries { - if !ctx.OutputQuery(q.SourceName) { - continue - } - f := &pyast.FunctionDef{ - Name: q.MethodName, - Args: &pyast.Arguments{ - Args: []*pyast.Arg{ - { - Arg: "self", - }, + cls := asyncQuerierClassDef() + for _, q := range ctx.Queries { + if !ctx.OutputQuery(q.SourceName) { + continue + } + f := &pyast.AsyncFunctionDef{ + Name: q.MethodName, + Args: &pyast.Arguments{ + Args: []*pyast.Arg{ + { + Arg: "self", }, }, - } - - q.AddArgs(f.Args) - exec := connMethodNode("execute", q.ConstantName, q.ArgDictNode()) + }, + } - switch q.Cmd { - case ":one": - f.Body = append(f.Body, - assignNode("row", poet.Node( - &pyast.Call{ - Func: poet.Attribute(exec, "first"), - }, - )), - poet.Node( - &pyast.If{ - Test: poet.Node( - &pyast.Compare{ - Left: poet.Name("row"), - Ops: []*pyast.Node{ - poet.Is(), - }, - Comparators: []*pyast.Node{ - poet.Constant(nil), - }, + q.AddArgs(f.Args) + + switch q.Cmd { + case ":one": + fetchrow := connMethodNode("fetchrow", q.ConstantName, q.ArgDictNode()) + f.Body = append(f.Body, + assignNode("row", poet.Await(fetchrow)), + poet.Node( + &pyast.If{ + Test: poet.Node( + &pyast.Compare{ + Left: poet.Name("row"), + Ops: []*pyast.Node{ + poet.Is(), }, - ), - Body: []*pyast.Node{ - poet.Return( + Comparators: []*pyast.Node{ poet.Constant(nil), - ), + }, }, + ), + Body: []*pyast.Node{ + poet.Return( + poet.Constant(nil), + ), }, - ), - poet.Return(q.Ret.RowNode("row")), - ) - f.Returns = subscriptNode("Optional", q.Ret.Annotation()) - case ":many": - f.Body = append(f.Body, - assignNode("result", exec), - poet.Node( - &pyast.For{ - Target: poet.Name("row"), - Iter: poet.Name("result"), - Body: []*pyast.Node{ - poet.Expr( - poet.Yield( - q.Ret.RowNode("row"), - ), + }, + ), + poet.Return(q.Ret.RowNode("row")), + ) + f.Returns = subscriptNode("Optional", q.Ret.Annotation()) + case ":many": + cursor := connMethodNode("cursor", q.ConstantName, q.ArgDictNode()) + f.Body = append(f.Body, + poet.Node( + &pyast.AsyncFor{ + Target: poet.Name("row"), + Iter: cursor, + Body: []*pyast.Node{ + poet.Expr( + poet.Yield( + q.Ret.RowNode("row"), ), - }, - }, - ), - ) - f.Returns = subscriptNode("Iterator", q.Ret.Annotation()) - case ":exec": - f.Body = append(f.Body, exec) - f.Returns = poet.Constant(nil) - case ":execrows": - f.Body = append(f.Body, - assignNode("result", exec), - poet.Return(poet.Attribute(poet.Name("result"), "rowcount")), - ) - f.Returns = poet.Name("int") - case ":execresult": - f.Body = append(f.Body, - poet.Return(exec), - ) - f.Returns = typeRefNode("sqlalchemy", "engine", "Result") - default: - panic("unknown cmd " + q.Cmd) - } - - cls.Body = append(cls.Body, poet.Node(f)) - } - mod.Body = append(mod.Body, poet.Node(cls)) - } - - if ctx.C.EmitAsyncQuerier { - cls := asyncQuerierClassDef() - for _, q := range ctx.Queries { - if !ctx.OutputQuery(q.SourceName) { - continue - } - f := &pyast.AsyncFunctionDef{ - Name: q.MethodName, - Args: &pyast.Arguments{ - Args: []*pyast.Arg{ - { - Arg: "self", + ), }, }, - }, - } - - q.AddArgs(f.Args) + ), + ) + f.Returns = subscriptNode("AsyncIterator", q.Ret.Annotation()) + case ":exec": exec := connMethodNode("execute", q.ConstantName, q.ArgDictNode()) - - switch q.Cmd { - case ":one": - f.Body = append(f.Body, - assignNode("row", poet.Node( - &pyast.Call{ - Func: poet.Attribute(poet.Await(exec), "first"), - }, - )), - poet.Node( - &pyast.If{ - Test: poet.Node( - &pyast.Compare{ - Left: poet.Name("row"), - Ops: []*pyast.Node{ - poet.Is(), - }, - Comparators: []*pyast.Node{ - poet.Constant(nil), - }, - }, - ), - Body: []*pyast.Node{ - poet.Return( - poet.Constant(nil), - ), - }, - }, - ), - poet.Return(q.Ret.RowNode("row")), - ) - f.Returns = subscriptNode("Optional", q.Ret.Annotation()) - case ":many": - stream := connMethodNode("stream", q.ConstantName, q.ArgDictNode()) - f.Body = append(f.Body, - assignNode("result", poet.Await(stream)), - poet.Node( - &pyast.AsyncFor{ - Target: poet.Name("row"), - Iter: poet.Name("result"), - Body: []*pyast.Node{ - poet.Expr( - poet.Yield( - q.Ret.RowNode("row"), - ), - ), - }, - }, - ), - ) - f.Returns = subscriptNode("AsyncIterator", q.Ret.Annotation()) - case ":exec": - f.Body = append(f.Body, poet.Await(exec)) - f.Returns = poet.Constant(nil) - case ":execrows": - f.Body = append(f.Body, - assignNode("result", poet.Await(exec)), - poet.Return(poet.Attribute(poet.Name("result"), "rowcount")), - ) - f.Returns = poet.Name("int") - case ":execresult": - f.Body = append(f.Body, - poet.Return(poet.Await(exec)), - ) - f.Returns = typeRefNode("sqlalchemy", "engine", "Result") - default: - panic("unknown cmd " + q.Cmd) - } - - cls.Body = append(cls.Body, poet.Node(f)) + f.Body = append(f.Body, poet.Await(exec)) + f.Returns = poet.Constant(nil) + default: + panic("unknown cmd " + q.Cmd) } - mod.Body = append(mod.Body, poet.Node(cls)) + + cls.Body = append(cls.Body, poet.Node(f)) } + mod.Body = append(mod.Body, poet.Node(cls)) return poet.Node(mod) } diff --git a/internal/imports.go b/internal/imports.go index b88c58c..423b1a0 100644 --- a/internal/imports.go +++ b/internal/imports.go @@ -131,10 +131,7 @@ func (i *importer) queryImportSpecs(fileName string) (map[string]importSpec, map std := stdImports(queryUses) pkg := make(map[string]importSpec) - pkg["sqlalchemy"] = importSpec{Module: "sqlalchemy"} - if i.C.EmitAsyncQuerier { - pkg["sqlalchemy.ext.asyncio"] = importSpec{Module: "sqlalchemy.ext.asyncio"} - } + pkg["asyncpg"] = importSpec{Module: "asyncpg"} queryValueModelImports := func(qv QueryValue) { if qv.IsStruct() && qv.EmitStruct() { @@ -154,12 +151,7 @@ func (i *importer) queryImportSpecs(fileName string) (map[string]importSpec, map std["typing.Optional"] = importSpec{Module: "typing", Name: "Optional"} } if q.Cmd == ":many" { - if i.C.EmitSyncQuerier { - std["typing.Iterator"] = importSpec{Module: "typing", Name: "Iterator"} - } - if i.C.EmitAsyncQuerier { - std["typing.AsyncIterator"] = importSpec{Module: "typing", Name: "AsyncIterator"} - } + std["typing.AsyncIterator"] = importSpec{Module: "typing", Name: "AsyncIterator"} } queryValueModelImports(q.Ret) for _, qv := range q.Args { From 554bc943a896cbdf390952dfcc03b1fe4abc1166 Mon Sep 17 00:00:00 2001 From: simo7 Date: Thu, 12 Sep 2024 11:14:05 +0200 Subject: [PATCH 03/15] Pass params as args to asyncpg functions --- README.md | 2 ++ internal/gen.go | 32 +++++++++++++++++++++++++------- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 8aaf0c9..7fe41ad 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ +# Sqlc-gen-python for Asyncpg + ## Build make bin/sqlc-gen-python diff --git a/internal/gen.go b/internal/gen.go index e14ecec..2e683bf 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -151,6 +151,26 @@ func (q Query) AddArgs(args *pyast.Arguments) { } } +func (q Query) ArgNodes() []*pyast.Node { + args := []*pyast.Node{} + i := 1 + for _, a := range q.Args { + if a.isEmpty() { + continue + } + if a.IsStruct() { + for _, f := range a.Struct.Fields { + args = append(args, typeRefNode(a.Name, f.Name)) + i++ + } + } else { + args = append(args, poet.Name(a.Name)) + i++ + } + } + return args +} + func (q Query) ArgDictNode() *pyast.Node { dict := &pyast.Dict{} i := 1 @@ -612,11 +632,9 @@ func typeRefNode(base string, parts ...string) *pyast.Node { return n } -func connMethodNode(method, name string, arg *pyast.Node) *pyast.Node { +func connMethodNode(method, name string, params ...*pyast.Node) *pyast.Node { args := []*pyast.Node{poet.Name(name)} - if arg != nil { - args = append(args, arg) - } + args = append(args, params...) return &pyast.Node{ Node: &pyast.Node_Call{ Call: &pyast.Call{ @@ -869,7 +887,7 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node { switch q.Cmd { case ":one": - fetchrow := connMethodNode("fetchrow", q.ConstantName, q.ArgDictNode()) + fetchrow := connMethodNode("fetchrow", q.ConstantName, q.ArgNodes()...) f.Body = append(f.Body, assignNode("row", poet.Await(fetchrow)), poet.Node( @@ -896,7 +914,7 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node { ) f.Returns = subscriptNode("Optional", q.Ret.Annotation()) case ":many": - cursor := connMethodNode("cursor", q.ConstantName, q.ArgDictNode()) + cursor := connMethodNode("cursor", q.ConstantName, q.ArgNodes()...) f.Body = append(f.Body, poet.Node( &pyast.AsyncFor{ @@ -914,7 +932,7 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node { ) f.Returns = subscriptNode("AsyncIterator", q.Ret.Annotation()) case ":exec": - exec := connMethodNode("execute", q.ConstantName, q.ArgDictNode()) + exec := connMethodNode("execute", q.ConstantName, q.ArgNodes()...) f.Body = append(f.Body, poet.Await(exec)) f.Returns = poet.Constant(nil) default: From 746594d78a95abf236de8336dfe556f6193b5094 Mon Sep 17 00:00:00 2001 From: simo7 Date: Thu, 12 Sep 2024 17:04:19 +0200 Subject: [PATCH 04/15] No optional return for fetchrow on insert/upsert --- internal/gen.go | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/internal/gen.go b/internal/gen.go index 2e683bf..70d005a 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -888,9 +888,11 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node { switch q.Cmd { case ":one": fetchrow := connMethodNode("fetchrow", q.ConstantName, q.ArgNodes()...) - f.Body = append(f.Body, - assignNode("row", poet.Await(fetchrow)), - poet.Node( + f.Body = append(f.Body, assignNode("row", poet.Await(fetchrow))) + if strings.HasPrefix(q.ConstantName, "INSERT") || strings.HasPrefix(q.ConstantName, "UPSERT") { + f.Returns = q.Ret.Annotation() + } else { + f.Body = append(f.Body, poet.Node( &pyast.If{ Test: poet.Node( &pyast.Compare{ @@ -909,10 +911,10 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node { ), }, }, - ), - poet.Return(q.Ret.RowNode("row")), - ) - f.Returns = subscriptNode("Optional", q.Ret.Annotation()) + )) + f.Returns = subscriptNode("Optional", q.Ret.Annotation()) + } + f.Body = append(f.Body, poet.Return(q.Ret.RowNode("row"))) case ":many": cursor := connMethodNode("cursor", q.ConstantName, q.ArgNodes()...) f.Body = append(f.Body, From d1a001d7ea4e41eb08850750af612f64449414b4 Mon Sep 17 00:00:00 2001 From: simo7 Date: Fri, 13 Sep 2024 22:31:23 +0200 Subject: [PATCH 05/15] Account for insert/upsert with where clause --- internal/gen.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/internal/gen.go b/internal/gen.go index 70d005a..bd21d83 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -889,7 +889,8 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node { case ":one": fetchrow := connMethodNode("fetchrow", q.ConstantName, q.ArgNodes()...) f.Body = append(f.Body, assignNode("row", poet.Await(fetchrow))) - if strings.HasPrefix(q.ConstantName, "INSERT") || strings.HasPrefix(q.ConstantName, "UPSERT") { + if !strings.Contains(strings.ToUpper(q.SQL), "WHERE ") && + (strings.HasPrefix(q.ConstantName, "INSERT") || strings.HasPrefix(q.ConstantName, "UPSERT")) { f.Returns = q.Ret.Annotation() } else { f.Body = append(f.Body, poet.Node( From 41d4d6ba4ae36239b3892041625b4272cff3cac6 Mon Sep 17 00:00:00 2001 From: simo7 Date: Sat, 14 Sep 2024 01:43:00 +0200 Subject: [PATCH 06/15] Refactor w/ isAlwaysReturningInsert() --- internal/gen.go | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/internal/gen.go b/internal/gen.go index bd21d83..a5d36c9 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -889,8 +889,8 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node { case ":one": fetchrow := connMethodNode("fetchrow", q.ConstantName, q.ArgNodes()...) f.Body = append(f.Body, assignNode("row", poet.Await(fetchrow))) - if !strings.Contains(strings.ToUpper(q.SQL), "WHERE ") && - (strings.HasPrefix(q.ConstantName, "INSERT") || strings.HasPrefix(q.ConstantName, "UPSERT")) { + + if isAlwaysReturningInsert(q.SQL) { f.Returns = q.Ret.Annotation() } else { f.Body = append(f.Body, poet.Node( @@ -1028,3 +1028,18 @@ func Generate(_ context.Context, req *plugin.GenerateRequest) (*plugin.GenerateR return &resp, nil } + +func isAlwaysReturningInsert(sql string) bool { + var hasInsert, hasWhere, hasReturning bool + for _, w := range strings.Fields(sql) { + switch strings.ToUpper(w) { + case "INSERT": + hasInsert = true + case "WHERE": + hasWhere = true + case "RETURNING": + hasReturning = true + } + } + return hasInsert && hasReturning && !hasWhere +} From 21b40269bd66267b69edc7f7c67cee438ee49fb5 Mon Sep 17 00:00:00 2001 From: simo7 Date: Thu, 19 Sep 2024 00:01:48 +0200 Subject: [PATCH 07/15] rm * catch all for args --- internal/printer/printer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/printer/printer.go b/internal/printer/printer.go index 0660c6a..f56ff45 100644 --- a/internal/printer/printer.go +++ b/internal/printer/printer.go @@ -381,7 +381,7 @@ func (w *writer) printFunctionDef(fd *ast.FunctionDef, indent int32) { } } if len(fd.Args.KwOnlyArgs) > 0 { - w.print(", *, ") + w.print(", ") for i, arg := range fd.Args.KwOnlyArgs { w.printArg(arg, indent) if i != len(fd.Args.KwOnlyArgs)-1 { From 4d2625e9b5eec426d8676b9619cb1d6d43e0f762 Mon Sep 17 00:00:00 2001 From: simo7 Date: Fri, 4 Oct 2024 14:12:57 +0200 Subject: [PATCH 08/15] models.py -> db_models.py --- internal/gen.go | 4 ++-- internal/imports.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/gen.go b/internal/gen.go index a5d36c9..2990db5 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -998,8 +998,8 @@ func Generate(_ context.Context, req *plugin.GenerateRequest) (*plugin.GenerateR output := map[string]string{} result := pyprint.Print(buildModelsTree(&tctx, i), pyprint.Options{}) - tctx.SourceName = "models.py" - output["models.py"] = string(result.Python) + tctx.SourceName = "db_models.py" + output["db_models.py"] = string(result.Python) files := map[string]struct{}{} for _, q := range queries { diff --git a/internal/imports.go b/internal/imports.go index 423b1a0..08bc710 100644 --- a/internal/imports.go +++ b/internal/imports.go @@ -69,7 +69,7 @@ func queryValueUses(name string, qv QueryValue) bool { } func (i *importer) Imports(fileName string) []string { - if fileName == "models.py" { + if fileName == "db_models.py" { return i.modelImports() } return i.queryImports(fileName) From 648631bd94c9a9d58ca2f841817fb5466382e7d4 Mon Sep 17 00:00:00 2001 From: simo7 Date: Fri, 4 Oct 2024 18:32:17 +0200 Subject: [PATCH 09/15] revert to keyword-only args --- internal/printer/printer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/printer/printer.go b/internal/printer/printer.go index f56ff45..0660c6a 100644 --- a/internal/printer/printer.go +++ b/internal/printer/printer.go @@ -381,7 +381,7 @@ func (w *writer) printFunctionDef(fd *ast.FunctionDef, indent int32) { } } if len(fd.Args.KwOnlyArgs) > 0 { - w.print(", ") + w.print(", *, ") for i, arg := range fd.Args.KwOnlyArgs { w.printArg(arg, indent) if i != len(fd.Args.KwOnlyArgs)-1 { From abec3c8744d00b1527c3a40134b5df47ae6b5d39 Mon Sep 17 00:00:00 2001 From: simo7 Date: Fri, 4 Oct 2024 18:40:52 +0200 Subject: [PATCH 10/15] Default to None for optional func args --- internal/gen.go | 46 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 42 insertions(+), 4 deletions(-) diff --git a/internal/gen.go b/internal/gen.go index 2990db5..08a27b3 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -38,12 +38,15 @@ type pyType struct { IsNull bool } -func (t pyType) Annotation() *pyast.Node { +func (t pyType) Annotation(isFuncSignature bool) *pyast.Node { ann := poet.Name(t.InnerType) if t.IsArray { ann = subscriptNode("List", ann) } - if t.IsNull { + if t.IsNull && isFuncSignature { + ann = optionalKeywordNode("Optional", ann) + } + if t.IsNull && !isFuncSignature { ann = subscriptNode("Optional", ann) } return ann @@ -69,9 +72,10 @@ type QueryValue struct { Typ pyType } +// Annotation in function signature func (v QueryValue) Annotation() *pyast.Node { if v.Typ != (pyType{}) { - return v.Typ.Annotation() + return v.Typ.Annotation(true) } if v.Struct != nil { if v.Emit { @@ -143,12 +147,21 @@ func (q Query) AddArgs(args *pyast.Arguments) { }) return } + var optionalArgs []*pyast.Arg for _, a := range q.Args { + if a.Typ.IsNull { + optionalArgs = append(optionalArgs, &pyast.Arg{ + Arg: a.Name, + Annotation: a.Annotation(), + }) + continue + } args.KwOnlyArgs = append(args.KwOnlyArgs, &pyast.Arg{ Arg: a.Name, Annotation: a.Annotation(), }) } + args.KwOnlyArgs = append(args.KwOnlyArgs, optionalArgs...) } func (q Query) ArgNodes() []*pyast.Node { @@ -577,6 +590,31 @@ func subscriptNode(value string, slice *pyast.Node) *pyast.Node { } } +func optionalKeywordNode(value string, slice *pyast.Node) *pyast.Node { + v := &pyast.Node{ + Node: &pyast.Node_Subscript{ + Subscript: &pyast.Subscript{ + Value: &pyast.Name{Id: value}, + Slice: slice, + }, + }, + } + return &pyast.Node{ + Node: &pyast.Node_Keyword{ + Keyword: &pyast.Keyword{ + Arg: string(pyprint.Print(v, pyprint.Options{}).Python), + Value: &pyast.Node{ + Node: &pyast.Node_Constant{ + Constant: &pyast.Constant{ + Value: &pyast.Constant_None{None: true}, + }, + }, + }, + }, + }, + } +} + func dataclassNode(name string) *pyast.ClassDef { return &pyast.ClassDef{ Name: name, @@ -617,7 +655,7 @@ func fieldNode(f Field) *pyast.Node { Node: &pyast.Node_AnnAssign{ AnnAssign: &pyast.AnnAssign{ Target: &pyast.Name{Id: f.Name}, - Annotation: f.Type.Annotation(), + Annotation: f.Type.Annotation(false), Comment: f.Comment, }, }, From b376daadd34c421d904c6274897d0f652661ae8f Mon Sep 17 00:00:00 2001 From: simo7 Date: Sat, 5 Oct 2024 13:22:48 +0200 Subject: [PATCH 11/15] RLS enforced fields --- internal/config.go | 3 +++ internal/gen.go | 34 +++++++++++++++++++++++++++++++++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/internal/config.go b/internal/config.go index 899009e..032eb01 100644 --- a/internal/config.go +++ b/internal/config.go @@ -8,4 +8,7 @@ type Config struct { QueryParameterLimit *int32 `json:"query_parameter_limit"` InflectionExcludeTableNames []string `json:"inflection_exclude_table_names"` TablePrefix string `json:"table_prefix"` + // When a query uses a table with RLS enforced fields, it will be required to + // parametrized those fields. Associate tables are not covered! + RLSEnforcedFields []string `json:"rls_enforced_fields"` } diff --git a/internal/gen.go b/internal/gen.go index 08a27b3..5bd1683 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -388,6 +388,20 @@ func columnsToStruct(req *plugin.GenerateRequest, name string, columns []pyColum } func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([]Query, error) { + rlsFieldsByTable := make(map[string][]string) // TODO + if len(conf.RLSEnforcedFields) > 0 { + for i := range structs { + tableName := structs[i].Table.Name + for _, f := range structs[i].Fields { + for _, enforced := range conf.RLSEnforcedFields { + if f.Name == enforced { + rlsFieldsByTable[tableName] = append(rlsFieldsByTable[tableName], f.Name) + } + } + } + } + } + qs := make([]Query, 0, len(req.Queries)) for _, query := range req.Queries { if query.Name == "" { @@ -419,9 +433,20 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([ if qpl < 0 { return nil, errors.New("invalid query parameter limit") } + enforcedFields := make(map[string]bool) + for _, c := range query.Columns { + if fields, ok := rlsFieldsByTable[c.GetTable().GetName()]; ok { + for _, f := range fields { + enforcedFields[f] = false + } + } + } if len(query.Params) > qpl || qpl == 0 { var cols []pyColumn for _, p := range query.Params { + if _, ok := enforcedFields[p.GetColumn().GetName()]; ok { + enforcedFields[p.Column.Name] = true + } cols = append(cols, pyColumn{ id: p.Number, Column: p.Column, @@ -435,6 +460,9 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([ } else { args := make([]QueryValue, 0, len(query.Params)) for _, p := range query.Params { + if _, ok := enforcedFields[p.GetColumn().GetName()]; ok { + enforcedFields[p.Column.Name] = true + } args = append(args, QueryValue{ Name: paramName(p), Typ: makePyType(req, p.Column), @@ -442,7 +470,11 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([ } gq.Args = args } - + for field, is_enforced := range enforcedFields { + if !is_enforced { + return nil, fmt.Errorf("RLS field %s is not filtered in query %s", field, query.Name) + } + } if len(query.Columns) == 1 { c := query.Columns[0] gq.Ret = QueryValue{ From 7b4b3550150ceb3ef16f5866cec9a81fed392fac Mon Sep 17 00:00:00 2001 From: simo7 Date: Sun, 6 Oct 2024 01:58:29 +0200 Subject: [PATCH 12/15] Implement sqlc.embed --- internal/config.go | 2 + internal/gen.go | 116 ++++++++++++++++++++++++++++++++++++++------ internal/imports.go | 4 +- 3 files changed, 104 insertions(+), 18 deletions(-) diff --git a/internal/config.go b/internal/config.go index 032eb01..706c978 100644 --- a/internal/config.go +++ b/internal/config.go @@ -12,3 +12,5 @@ type Config struct { // parametrized those fields. Associate tables are not covered! RLSEnforcedFields []string `json:"rls_enforced_fields"` } + +const MODELS_FILENAME = "db_models" diff --git a/internal/gen.go b/internal/gen.go index 5bd1683..ca82758 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -56,6 +56,8 @@ type Field struct { Name string Type pyType Comment string + // EmbedFields contains the embedded fields that require scanning. + EmbedFields []Field } type Struct struct { @@ -81,7 +83,7 @@ func (v QueryValue) Annotation() *pyast.Node { if v.Emit { return poet.Name(v.Struct.Name) } else { - return typeRefNode("models", v.Struct.Name) + return typeRefNode(MODELS_FILENAME, v.Struct.Name) } } panic("no type for QueryValue: " + v.Name) @@ -109,14 +111,41 @@ func (v QueryValue) RowNode(rowVar string) *pyast.Node { call := &pyast.Call{ Func: v.Annotation(), } - for i, f := range v.Struct.Fields { - call.Keywords = append(call.Keywords, &pyast.Keyword{ - Arg: f.Name, - Value: subscriptNode( + var idx int + for _, f := range v.Struct.Fields { + var val *pyast.Node + if len(f.EmbedFields) > 0 { + var embedFields []*pyast.Keyword + for _, embed := range f.EmbedFields { + embedFields = append(embedFields, &pyast.Keyword{ + Arg: embed.Name, + Value: subscriptNode( + rowVar, + constantInt(idx), + ), + }) + idx++ + } + val = &pyast.Node{ + Node: &pyast.Node_Call{ + Call: &pyast.Call{ + Func: f.Type.Annotation(false), + Keywords: embedFields, + }, + }, + } + } else { + val = subscriptNode( rowVar, - constantInt(i), - ), + constantInt(idx), + ) + idx++ + } + call.Keywords = append(call.Keywords, &pyast.Keyword{ + Arg: f.Name, + Value: val, }) + } return &pyast.Node{ Node: &pyast.Node_Call{ @@ -355,6 +384,46 @@ func paramName(p *plugin.Parameter) string { type pyColumn struct { id int32 *plugin.Column + embed *pyEmbed +} + +type pyEmbed struct { + modelType string + modelName string + fields []Field +} + +// look through all the structs and attempt to find a matching one to embed +// We need the name of the struct and its field names. +func newPyEmbed(embed *plugin.Identifier, structs []Struct, defaultSchema string) *pyEmbed { + if embed == nil { + return nil + } + + for _, s := range structs { + embedSchema := defaultSchema + if embed.Schema != "" { + embedSchema = embed.Schema + } + + // compare the other attributes + if embed.Catalog != s.Table.Catalog || embed.Name != s.Table.Name || embedSchema != s.Table.Schema { + continue + } + + fields := make([]Field, len(s.Fields)) + for i, f := range s.Fields { + fields[i] = f + } + + return &pyEmbed{ + modelType: s.Name, + modelName: s.Name, + fields: fields, + } + } + + return nil } func columnsToStruct(req *plugin.GenerateRequest, name string, columns []pyColumn) *Struct { @@ -366,6 +435,12 @@ func columnsToStruct(req *plugin.GenerateRequest, name string, columns []pyColum for i, c := range columns { colName := columnName(c.Column, i) fieldName := colName + + // override col with expected model name + if c.embed != nil { + colName = c.embed.modelName + } + // Track suffixes by the ID of the column, so that columns referring to // the same numbered parameter can be reused. var suffix int32 @@ -378,17 +453,25 @@ func columnsToStruct(req *plugin.GenerateRequest, name string, columns []pyColum if suffix > 0 { fieldName = fmt.Sprintf("%s_%d", fieldName, suffix) } - gs.Fields = append(gs.Fields, Field{ - Name: fieldName, - Type: makePyType(req, c.Column), - }) + f := Field{Name: fieldName} + if c.embed == nil { + f.Type = makePyType(req, c.Column) + } else { + f.Type = pyType{ + InnerType: MODELS_FILENAME + "." + c.embed.modelType, + IsArray: c.IsArray, + IsNull: false, + } + f.EmbedFields = c.embed.fields + } + gs.Fields = append(gs.Fields, f) seen[colName]++ } return &gs } func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([]Query, error) { - rlsFieldsByTable := make(map[string][]string) // TODO + rlsFieldsByTable := make(map[string][]string) if len(conf.RLSEnforcedFields) > 0 { for i := range structs { tableName := structs[i].Table.Name @@ -475,7 +558,7 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([ return nil, fmt.Errorf("RLS field %s is not filtered in query %s", field, query.Name) } } - if len(query.Columns) == 1 { + if len(query.Columns) == 1 && query.Columns[0].EmbedTable == nil { c := query.Columns[0] gq.Ret = QueryValue{ Name: columnName(c, 0), @@ -515,6 +598,7 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([ columns = append(columns, pyColumn{ id: int32(i), Column: c, + embed: newPyEmbed(c.EmbedTable, structs, req.Catalog.DefaultSchema), }) } gs = columnsToStruct(req, query.Name+"Row", columns) @@ -893,7 +977,7 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node { ImportFrom: &pyast.ImportFrom{ Module: ctx.C.Package, Names: []*pyast.Node{ - poet.Alias("models"), + poet.Alias(MODELS_FILENAME), }, }, }, @@ -1068,8 +1152,8 @@ func Generate(_ context.Context, req *plugin.GenerateRequest) (*plugin.GenerateR output := map[string]string{} result := pyprint.Print(buildModelsTree(&tctx, i), pyprint.Options{}) - tctx.SourceName = "db_models.py" - output["db_models.py"] = string(result.Python) + tctx.SourceName = MODELS_FILENAME + ".py" + output[MODELS_FILENAME+".py"] = string(result.Python) files := map[string]struct{}{} for _, q := range queries { diff --git a/internal/imports.go b/internal/imports.go index 08bc710..b892bd8 100644 --- a/internal/imports.go +++ b/internal/imports.go @@ -69,7 +69,7 @@ func queryValueUses(name string, qv QueryValue) bool { } func (i *importer) Imports(fileName string) []string { - if fileName == "db_models.py" { + if fileName == MODELS_FILENAME+".py" { return i.modelImports() } return i.queryImports(fileName) @@ -165,7 +165,7 @@ func (i *importer) queryImportSpecs(fileName string) (map[string]importSpec, map func (i *importer) queryImports(fileName string) []string { std, pkg := i.queryImportSpecs(fileName) - modelImportStr := fmt.Sprintf("from %s import models", i.C.Package) + modelImportStr := fmt.Sprintf("from %s import %s", i.C.Package, MODELS_FILENAME) importLines := []string{ buildImportBlock(std), From b8217016eb5966cea358a592c41bbf3244e2c0b0 Mon Sep 17 00:00:00 2001 From: simo7 Date: Sun, 6 Oct 2024 03:01:23 +0200 Subject: [PATCH 13/15] Add MergeQueryFiles config opt --- internal/config.go | 2 ++ internal/gen.go | 11 +++++++++-- internal/imports.go | 4 ++-- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/internal/config.go b/internal/config.go index 706c978..ccf5a48 100644 --- a/internal/config.go +++ b/internal/config.go @@ -11,6 +11,8 @@ type Config struct { // When a query uses a table with RLS enforced fields, it will be required to // parametrized those fields. Associate tables are not covered! RLSEnforcedFields []string `json:"rls_enforced_fields"` + // Merge queries defined in different files into one output queries.py file + MergeQueryFiles bool `json:"merge_query_files"` } const MODELS_FILENAME = "db_models" diff --git a/internal/gen.go b/internal/gen.go index ca82758..5877a9c 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -1113,6 +1113,9 @@ type pyTmplCtx struct { } func (t *pyTmplCtx) OutputQuery(sourceName string) bool { + if t.C.MergeQueryFiles { + return true + } return t.SourceName == sourceName } @@ -1156,8 +1159,12 @@ func Generate(_ context.Context, req *plugin.GenerateRequest) (*plugin.GenerateR output[MODELS_FILENAME+".py"] = string(result.Python) files := map[string]struct{}{} - for _, q := range queries { - files[q.SourceName] = struct{}{} + if i.C.MergeQueryFiles { + files["db_queries.sql"] = struct{}{} + } else { + for _, q := range queries { + files[q.SourceName] = struct{}{} + } } for source := range files { diff --git a/internal/imports.go b/internal/imports.go index b892bd8..f5011b7 100644 --- a/internal/imports.go +++ b/internal/imports.go @@ -113,7 +113,7 @@ func (i *importer) modelImports() []string { func (i *importer) queryImportSpecs(fileName string) (map[string]importSpec, map[string]importSpec) { queryUses := func(name string) bool { for _, q := range i.Queries { - if q.SourceName != fileName { + if !i.C.MergeQueryFiles && q.SourceName != fileName { continue } if queryValueUses(name, q.Ret) { @@ -144,7 +144,7 @@ func (i *importer) queryImportSpecs(fileName string) (map[string]importSpec, map } for _, q := range i.Queries { - if q.SourceName != fileName { + if !i.C.MergeQueryFiles && q.SourceName != fileName { continue } if q.Cmd == ":one" { From eaae4615626a3832ab41c10e0d5fad3ec7c48514 Mon Sep 17 00:00:00 2001 From: simo7 Date: Sun, 6 Oct 2024 19:47:38 +0200 Subject: [PATCH 14/15] sqlc.embed: skip object if id is None --- internal/config.go | 5 ++++- internal/gen.go | 32 +++++++++++++++++++------------- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/internal/config.go b/internal/config.go index ccf5a48..2a9bcdb 100644 --- a/internal/config.go +++ b/internal/config.go @@ -9,7 +9,10 @@ type Config struct { InflectionExcludeTableNames []string `json:"inflection_exclude_table_names"` TablePrefix string `json:"table_prefix"` // When a query uses a table with RLS enforced fields, it will be required to - // parametrized those fields. Associate tables are not covered! + // parametrized those fields. Not covered: + // - Associate tables + // - sqlc.embed() + // - json_agg(tbl.*) RLSEnforcedFields []string `json:"rls_enforced_fields"` // Merge queries defined in different files into one output queries.py file MergeQueryFiles bool `json:"merge_query_files"` diff --git a/internal/gen.go b/internal/gen.go index 5877a9c..8454a7d 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -118,27 +118,33 @@ func (v QueryValue) RowNode(rowVar string) *pyast.Node { var embedFields []*pyast.Keyword for _, embed := range f.EmbedFields { embedFields = append(embedFields, &pyast.Keyword{ - Arg: embed.Name, - Value: subscriptNode( - rowVar, - constantInt(idx), - ), + Arg: embed.Name, + Value: subscriptNode(rowVar, constantInt(idx)), }) idx++ } val = &pyast.Node{ - Node: &pyast.Node_Call{ - Call: &pyast.Call{ - Func: f.Type.Annotation(false), - Keywords: embedFields, + Node: &pyast.Node_Compare{ + Compare: &pyast.Compare{ + Left: &pyast.Node{ + Node: &pyast.Node_Call{ + Call: &pyast.Call{ + Func: f.Type.Annotation(false), + Keywords: embedFields, + }, + }, + }, + Ops: []*pyast.Node{ + poet.Name(fmt.Sprintf("if row[%d] else", idx-len(f.EmbedFields))), + }, + Comparators: []*pyast.Node{ + poet.Constant(nil), + }, }, }, } } else { - val = subscriptNode( - rowVar, - constantInt(idx), - ) + val = subscriptNode(rowVar, constantInt(idx)) idx++ } call.Keywords = append(call.Keywords, &pyast.Keyword{ From c505216b420ad2f791b4ce2c7697da385d7c913e Mon Sep 17 00:00:00 2001 From: simo7 Date: Thu, 17 Oct 2024 02:37:45 +0200 Subject: [PATCH 15/15] Allow optional embeds --- internal/gen.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/internal/gen.go b/internal/gen.go index 8454a7d..9ce877b 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -129,7 +129,7 @@ func (v QueryValue) RowNode(rowVar string) *pyast.Node { Left: &pyast.Node{ Node: &pyast.Node_Call{ Call: &pyast.Call{ - Func: f.Type.Annotation(false), + Func: poet.Name(f.Type.InnerType), Keywords: embedFields, }, }, @@ -466,7 +466,7 @@ func columnsToStruct(req *plugin.GenerateRequest, name string, columns []pyColum f.Type = pyType{ InnerType: MODELS_FILENAME + "." + c.embed.modelType, IsArray: c.IsArray, - IsNull: false, + IsNull: !c.NotNull, } f.EmbedFields = c.embed.fields } @@ -523,7 +523,9 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([ return nil, errors.New("invalid query parameter limit") } enforcedFields := make(map[string]bool) + // log.Printf("%v\n\n", query) for _, c := range query.Columns { + // log.Printf("%v\n\n", c) if fields, ok := rlsFieldsByTable[c.GetTable().GetName()]; ok { for _, f := range fields { enforcedFields[f] = false @@ -620,6 +622,7 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([ qs = append(qs, gq) } sort.Slice(qs, func(i, j int) bool { return qs[i].MethodName < qs[j].MethodName }) + // return nil, errors.New("debug") return qs, nil }