diff --git a/Makefile b/Makefile
index d8ac60b..dbea94f 100644
--- a/Makefile
+++ b/Makefile
@@ -9,10 +9,10 @@ test: bin/sqlc-gen-python.wasm
 all: bin/sqlc-gen-python bin/sqlc-gen-python.wasm
 
 bin/sqlc-gen-python: bin go.mod go.sum $(wildcard **/*.go)
-	cd plugin && go build -o ../bin/sqlc-gen-python ./main.go
+	cd plugin/sqlc-gen-python && go build -o ../../bin/sqlc-gen-python ./main.go
 
 bin/sqlc-gen-python.wasm: bin/sqlc-gen-python
-	cd plugin && GOOS=wasip1 GOARCH=wasm go build -o ../bin/sqlc-gen-python.wasm main.go
+	cd plugin/sqlc-gen-python && GOOS=wasip1 GOARCH=wasm go build -o ../../bin/sqlc-gen-python.wasm main.go
 
 bin:
 	mkdir -p bin
diff --git a/README.md b/README.md
index d7e0d75..7fe41ad 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,9 @@
+# sqlc-gen-python for asyncpg
+
+## Build
+
+    make bin/sqlc-gen-python
+
 ## Usage
 
 ```yaml
diff --git a/internal/config.go b/internal/config.go
index 009cb04..2a9bcdb 100644
--- a/internal/config.go
+++ b/internal/config.go
@@ -2,11 +2,20 @@ package python
 
 type Config struct {
 	EmitExactTableNames bool `json:"emit_exact_table_names"`
-	EmitSyncQuerier bool `json:"emit_sync_querier"`
-	EmitAsyncQuerier bool `json:"emit_async_querier"`
 	Package string `json:"package"`
 	Out string `json:"out"`
 	EmitPydanticModels bool `json:"emit_pydantic_models"`
 	QueryParameterLimit *int32 `json:"query_parameter_limit"`
 	InflectionExcludeTableNames []string `json:"inflection_exclude_table_names"`
+	TablePrefix string `json:"table_prefix"`
+	// When a query uses a table with RLS-enforced fields, it must filter on
+	// those fields via query parameters. Not covered:
+	// - associated tables
+	// - sqlc.embed()
+	// - json_agg(tbl.*)
+	RLSEnforcedFields []string `json:"rls_enforced_fields"`
+	// Merge queries defined in different files into a single output query file.
+	MergeQueryFiles bool `json:"merge_query_files"`
 }
+
+const MODELS_FILENAME = "db_models"
diff --git a/internal/gen.go b/internal/gen.go
index f81c53b..9ce877b 100644
--- a/internal/gen.go
+++ b/internal/gen.go
@@ -38,12 +38,15 @@ type pyType struct {
 	IsNull bool
 }
 
-func (t pyType) Annotation() *pyast.Node {
+func (t pyType) Annotation(isFuncSignature bool) *pyast.Node {
 	ann := poet.Name(t.InnerType)
 	if t.IsArray {
 		ann = subscriptNode("List", ann)
 	}
-	if t.IsNull {
+	if t.IsNull && isFuncSignature {
+		ann = optionalKeywordNode("Optional", ann)
+	}
+	if t.IsNull && !isFuncSignature {
 		ann = subscriptNode("Optional", ann)
 	}
 	return ann
@@ -53,6 +56,8 @@ type Field struct {
 	Name string
 	Type pyType
 	Comment string
+	// EmbedFields contains the embedded fields that require scanning.
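+	// They are populated only for columns produced by sqlc.embed(), so the
+	// generated row constructor can rebuild the embedded model from
+	// consecutive row values.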
+ EmbedFields []Field } type Struct struct { @@ -69,15 +74,16 @@ type QueryValue struct { Typ pyType } +// Annotation in function signature func (v QueryValue) Annotation() *pyast.Node { if v.Typ != (pyType{}) { - return v.Typ.Annotation() + return v.Typ.Annotation(true) } if v.Struct != nil { if v.Emit { return poet.Name(v.Struct.Name) } else { - return typeRefNode("models", v.Struct.Name) + return typeRefNode(MODELS_FILENAME, v.Struct.Name) } } panic("no type for QueryValue: " + v.Name) @@ -105,14 +111,47 @@ func (v QueryValue) RowNode(rowVar string) *pyast.Node { call := &pyast.Call{ Func: v.Annotation(), } - for i, f := range v.Struct.Fields { + var idx int + for _, f := range v.Struct.Fields { + var val *pyast.Node + if len(f.EmbedFields) > 0 { + var embedFields []*pyast.Keyword + for _, embed := range f.EmbedFields { + embedFields = append(embedFields, &pyast.Keyword{ + Arg: embed.Name, + Value: subscriptNode(rowVar, constantInt(idx)), + }) + idx++ + } + val = &pyast.Node{ + Node: &pyast.Node_Compare{ + Compare: &pyast.Compare{ + Left: &pyast.Node{ + Node: &pyast.Node_Call{ + Call: &pyast.Call{ + Func: poet.Name(f.Type.InnerType), + Keywords: embedFields, + }, + }, + }, + Ops: []*pyast.Node{ + poet.Name(fmt.Sprintf("if row[%d] else", idx-len(f.EmbedFields))), + }, + Comparators: []*pyast.Node{ + poet.Constant(nil), + }, + }, + }, + } + } else { + val = subscriptNode(rowVar, constantInt(idx)) + idx++ + } call.Keywords = append(call.Keywords, &pyast.Keyword{ - Arg: f.Name, - Value: subscriptNode( - rowVar, - constantInt(i), - ), + Arg: f.Name, + Value: val, }) + } return &pyast.Node{ Node: &pyast.Node_Call{ @@ -143,12 +182,41 @@ func (q Query) AddArgs(args *pyast.Arguments) { }) return } + var optionalArgs []*pyast.Arg for _, a := range q.Args { + if a.Typ.IsNull { + optionalArgs = append(optionalArgs, &pyast.Arg{ + Arg: a.Name, + Annotation: a.Annotation(), + }) + continue + } args.KwOnlyArgs = append(args.KwOnlyArgs, &pyast.Arg{ Arg: a.Name, Annotation: a.Annotation(), }) } + args.KwOnlyArgs = append(args.KwOnlyArgs, optionalArgs...) +} + +func (q Query) ArgNodes() []*pyast.Node { + args := []*pyast.Node{} + i := 1 + for _, a := range q.Args { + if a.isEmpty() { + continue + } + if a.IsStruct() { + for _, f := range a.Struct.Fields { + args = append(args, typeRefNode(a.Name, f.Name)) + i++ + } + } else { + args = append(args, poet.Name(a.Name)) + i++ + } + } + return args } func (q Query) ArgDictNode() *pyast.Node { @@ -279,6 +347,9 @@ func buildModels(conf Config, req *plugin.GenerateRequest) []Struct { Exclusions: conf.InflectionExcludeTableNames, }) } + if conf.TablePrefix != "" { + structName = conf.TablePrefix + strings.ToUpper(structName[:1]) + structName[1:] + } s := Struct{ Table: plugin.Identifier{Schema: schema.Name, Name: table.Rel.Name}, Name: modelName(structName, req.Settings), @@ -319,6 +390,46 @@ func paramName(p *plugin.Parameter) string { type pyColumn struct { id int32 *plugin.Column + embed *pyEmbed +} + +type pyEmbed struct { + modelType string + modelName string + fields []Field +} + +// look through all the structs and attempt to find a matching one to embed +// We need the name of the struct and its field names. 
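+// For example, a column produced by sqlc.embed(authors) carries an EmbedTable
+// identifier; when it matches the generated struct for that table, the struct's
+// name and fields are reused for row scanning (authors is a hypothetical table,
+// shown only for illustration).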
+func newPyEmbed(embed *plugin.Identifier, structs []Struct, defaultSchema string) *pyEmbed { + if embed == nil { + return nil + } + + for _, s := range structs { + embedSchema := defaultSchema + if embed.Schema != "" { + embedSchema = embed.Schema + } + + // compare the other attributes + if embed.Catalog != s.Table.Catalog || embed.Name != s.Table.Name || embedSchema != s.Table.Schema { + continue + } + + fields := make([]Field, len(s.Fields)) + for i, f := range s.Fields { + fields[i] = f + } + + return &pyEmbed{ + modelType: s.Name, + modelName: s.Name, + fields: fields, + } + } + + return nil } func columnsToStruct(req *plugin.GenerateRequest, name string, columns []pyColumn) *Struct { @@ -330,6 +441,12 @@ func columnsToStruct(req *plugin.GenerateRequest, name string, columns []pyColum for i, c := range columns { colName := columnName(c.Column, i) fieldName := colName + + // override col with expected model name + if c.embed != nil { + colName = c.embed.modelName + } + // Track suffixes by the ID of the column, so that columns referring to // the same numbered parameter can be reused. var suffix int32 @@ -342,28 +459,38 @@ func columnsToStruct(req *plugin.GenerateRequest, name string, columns []pyColum if suffix > 0 { fieldName = fmt.Sprintf("%s_%d", fieldName, suffix) } - gs.Fields = append(gs.Fields, Field{ - Name: fieldName, - Type: makePyType(req, c.Column), - }) + f := Field{Name: fieldName} + if c.embed == nil { + f.Type = makePyType(req, c.Column) + } else { + f.Type = pyType{ + InnerType: MODELS_FILENAME + "." + c.embed.modelType, + IsArray: c.IsArray, + IsNull: !c.NotNull, + } + f.EmbedFields = c.embed.fields + } + gs.Fields = append(gs.Fields, f) seen[colName]++ } return &gs } -var postgresPlaceholderRegexp = regexp.MustCompile(`\B\$(\d+)\b`) - -// Sqlalchemy uses ":name" for placeholders, so "$N" is converted to ":pN" -// This also means ":" has special meaning to sqlalchemy, so it must be escaped. 
-func sqlalchemySQL(s, engine string) string { - s = strings.ReplaceAll(s, ":", `\\:`) - if engine == "postgresql" { - return postgresPlaceholderRegexp.ReplaceAllString(s, ":p$1") +func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([]Query, error) { + rlsFieldsByTable := make(map[string][]string) + if len(conf.RLSEnforcedFields) > 0 { + for i := range structs { + tableName := structs[i].Table.Name + for _, f := range structs[i].Fields { + for _, enforced := range conf.RLSEnforcedFields { + if f.Name == enforced { + rlsFieldsByTable[tableName] = append(rlsFieldsByTable[tableName], f.Name) + } + } + } + } } - return s -} -func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([]Query, error) { qs := make([]Query, 0, len(req.Queries)) for _, query := range req.Queries { if query.Name == "" { @@ -384,7 +511,7 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([ MethodName: methodName, FieldName: sdk.LowerTitle(query.Name) + "Stmt", ConstantName: strings.ToUpper(methodName), - SQL: sqlalchemySQL(query.Text, req.Settings.Engine), + SQL: query.Text, SourceName: query.Filename, } @@ -395,9 +522,22 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([ if qpl < 0 { return nil, errors.New("invalid query parameter limit") } + enforcedFields := make(map[string]bool) + // log.Printf("%v\n\n", query) + for _, c := range query.Columns { + // log.Printf("%v\n\n", c) + if fields, ok := rlsFieldsByTable[c.GetTable().GetName()]; ok { + for _, f := range fields { + enforcedFields[f] = false + } + } + } if len(query.Params) > qpl || qpl == 0 { var cols []pyColumn for _, p := range query.Params { + if _, ok := enforcedFields[p.GetColumn().GetName()]; ok { + enforcedFields[p.Column.Name] = true + } cols = append(cols, pyColumn{ id: p.Number, Column: p.Column, @@ -411,6 +551,9 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([ } else { args := make([]QueryValue, 0, len(query.Params)) for _, p := range query.Params { + if _, ok := enforcedFields[p.GetColumn().GetName()]; ok { + enforcedFields[p.Column.Name] = true + } args = append(args, QueryValue{ Name: paramName(p), Typ: makePyType(req, p.Column), @@ -418,8 +561,12 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([ } gq.Args = args } - - if len(query.Columns) == 1 { + for field, is_enforced := range enforcedFields { + if !is_enforced { + return nil, fmt.Errorf("RLS field %s is not filtered in query %s", field, query.Name) + } + } + if len(query.Columns) == 1 && query.Columns[0].EmbedTable == nil { c := query.Columns[0] gq.Ret = QueryValue{ Name: columnName(c, 0), @@ -459,6 +606,7 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([ columns = append(columns, pyColumn{ id: int32(i), Column: c, + embed: newPyEmbed(c.EmbedTable, structs, req.Catalog.DefaultSchema), }) } gs = columnsToStruct(req, query.Name+"Row", columns) @@ -474,6 +622,7 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([ qs = append(qs, gq) } sort.Slice(qs, func(i, j int) bool { return qs[i].MethodName < qs[j].MethodName }) + // return nil, errors.New("debug") return qs, nil } @@ -566,6 +715,31 @@ func subscriptNode(value string, slice *pyast.Node) *pyast.Node { } } +func optionalKeywordNode(value string, slice *pyast.Node) *pyast.Node { + v := &pyast.Node{ + Node: &pyast.Node_Subscript{ + Subscript: &pyast.Subscript{ + Value: &pyast.Name{Id: value}, + Slice: 
slice, + }, + }, + } + return &pyast.Node{ + Node: &pyast.Node_Keyword{ + Keyword: &pyast.Keyword{ + Arg: string(pyprint.Print(v, pyprint.Options{}).Python), + Value: &pyast.Node{ + Node: &pyast.Node_Constant{ + Constant: &pyast.Constant{ + Value: &pyast.Constant_None{None: true}, + }, + }, + }, + }, + }, + } +} + func dataclassNode(name string) *pyast.ClassDef { return &pyast.ClassDef{ Name: name, @@ -606,7 +780,7 @@ func fieldNode(f Field) *pyast.Node { Node: &pyast.Node_AnnAssign{ AnnAssign: &pyast.AnnAssign{ Target: &pyast.Name{Id: f.Name}, - Annotation: f.Type.Annotation(), + Annotation: f.Type.Annotation(false), Comment: f.Comment, }, }, @@ -621,22 +795,9 @@ func typeRefNode(base string, parts ...string) *pyast.Node { return n } -func connMethodNode(method, name string, arg *pyast.Node) *pyast.Node { - args := []*pyast.Node{ - { - Node: &pyast.Node_Call{ - Call: &pyast.Call{ - Func: typeRefNode("sqlalchemy", "text"), - Args: []*pyast.Node{ - poet.Name(name), - }, - }, - }, - }, - } - if arg != nil { - args = append(args, arg) - } +func connMethodNode(method, name string, params ...*pyast.Node) *pyast.Node { + args := []*pyast.Node{poet.Name(name)} + args = append(args, params...) return &pyast.Node{ Node: &pyast.Node_Call{ Call: &pyast.Call{ @@ -789,7 +950,7 @@ func asyncQuerierClassDef() *pyast.ClassDef { }, { Arg: "conn", - Annotation: typeRefNode("sqlalchemy", "ext", "asyncio", "AsyncConnection"), + Annotation: typeRefNode("asyncpg", "pool", "PoolConnectionProxy"), }, }, }, @@ -825,7 +986,7 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node { ImportFrom: &pyast.ImportFrom{ Module: ctx.C.Package, Names: []*pyast.Node{ - poet.Alias("models"), + poet.Alias(MODELS_FILENAME), }, }, }, @@ -869,190 +1030,84 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node { } } - if ctx.C.EmitSyncQuerier { - cls := querierClassDef() - for _, q := range ctx.Queries { - if !ctx.OutputQuery(q.SourceName) { - continue - } - f := &pyast.FunctionDef{ - Name: q.MethodName, - Args: &pyast.Arguments{ - Args: []*pyast.Arg{ - { - Arg: "self", - }, + cls := asyncQuerierClassDef() + for _, q := range ctx.Queries { + if !ctx.OutputQuery(q.SourceName) { + continue + } + f := &pyast.AsyncFunctionDef{ + Name: q.MethodName, + Args: &pyast.Arguments{ + Args: []*pyast.Arg{ + { + Arg: "self", }, }, - } - - q.AddArgs(f.Args) - exec := connMethodNode("execute", q.ConstantName, q.ArgDictNode()) - - switch q.Cmd { - case ":one": - f.Body = append(f.Body, - assignNode("row", poet.Node( - &pyast.Call{ - Func: poet.Attribute(exec, "first"), - }, - )), - poet.Node( - &pyast.If{ - Test: poet.Node( - &pyast.Compare{ - Left: poet.Name("row"), - Ops: []*pyast.Node{ - poet.Is(), - }, - Comparators: []*pyast.Node{ - poet.Constant(nil), - }, - }, - ), - Body: []*pyast.Node{ - poet.Return( - poet.Constant(nil), - ), - }, - }, - ), - poet.Return(q.Ret.RowNode("row")), - ) - f.Returns = subscriptNode("Optional", q.Ret.Annotation()) - case ":many": - f.Body = append(f.Body, - assignNode("result", exec), - poet.Node( - &pyast.For{ - Target: poet.Name("row"), - Iter: poet.Name("result"), - Body: []*pyast.Node{ - poet.Expr( - poet.Yield( - q.Ret.RowNode("row"), - ), - ), - }, - }, - ), - ) - f.Returns = subscriptNode("Iterator", q.Ret.Annotation()) - case ":exec": - f.Body = append(f.Body, exec) - f.Returns = poet.Constant(nil) - case ":execrows": - f.Body = append(f.Body, - assignNode("result", exec), - poet.Return(poet.Attribute(poet.Name("result"), "rowcount")), - ) - f.Returns = 
poet.Name("int") - case ":execresult": - f.Body = append(f.Body, - poet.Return(exec), - ) - f.Returns = typeRefNode("sqlalchemy", "engine", "Result") - default: - panic("unknown cmd " + q.Cmd) - } - - cls.Body = append(cls.Body, poet.Node(f)) + }, } - mod.Body = append(mod.Body, poet.Node(cls)) - } - if ctx.C.EmitAsyncQuerier { - cls := asyncQuerierClassDef() - for _, q := range ctx.Queries { - if !ctx.OutputQuery(q.SourceName) { - continue - } - f := &pyast.AsyncFunctionDef{ - Name: q.MethodName, - Args: &pyast.Arguments{ - Args: []*pyast.Arg{ - { - Arg: "self", - }, - }, - }, - } + q.AddArgs(f.Args) - q.AddArgs(f.Args) - exec := connMethodNode("execute", q.ConstantName, q.ArgDictNode()) + switch q.Cmd { + case ":one": + fetchrow := connMethodNode("fetchrow", q.ConstantName, q.ArgNodes()...) + f.Body = append(f.Body, assignNode("row", poet.Await(fetchrow))) - switch q.Cmd { - case ":one": - f.Body = append(f.Body, - assignNode("row", poet.Node( - &pyast.Call{ - Func: poet.Attribute(poet.Await(exec), "first"), - }, - )), - poet.Node( - &pyast.If{ - Test: poet.Node( - &pyast.Compare{ - Left: poet.Name("row"), - Ops: []*pyast.Node{ - poet.Is(), - }, - Comparators: []*pyast.Node{ - poet.Constant(nil), - }, + if isAlwaysReturningInsert(q.SQL) { + f.Returns = q.Ret.Annotation() + } else { + f.Body = append(f.Body, poet.Node( + &pyast.If{ + Test: poet.Node( + &pyast.Compare{ + Left: poet.Name("row"), + Ops: []*pyast.Node{ + poet.Is(), }, - ), - Body: []*pyast.Node{ - poet.Return( + Comparators: []*pyast.Node{ poet.Constant(nil), - ), + }, }, + ), + Body: []*pyast.Node{ + poet.Return( + poet.Constant(nil), + ), }, - ), - poet.Return(q.Ret.RowNode("row")), - ) + }, + )) f.Returns = subscriptNode("Optional", q.Ret.Annotation()) - case ":many": - stream := connMethodNode("stream", q.ConstantName, q.ArgDictNode()) - f.Body = append(f.Body, - assignNode("result", poet.Await(stream)), - poet.Node( - &pyast.AsyncFor{ - Target: poet.Name("row"), - Iter: poet.Name("result"), - Body: []*pyast.Node{ - poet.Expr( - poet.Yield( - q.Ret.RowNode("row"), - ), + } + f.Body = append(f.Body, poet.Return(q.Ret.RowNode("row"))) + case ":many": + cursor := connMethodNode("cursor", q.ConstantName, q.ArgNodes()...) + f.Body = append(f.Body, + poet.Node( + &pyast.AsyncFor{ + Target: poet.Name("row"), + Iter: cursor, + Body: []*pyast.Node{ + poet.Expr( + poet.Yield( + q.Ret.RowNode("row"), ), - }, + ), }, - ), - ) - f.Returns = subscriptNode("AsyncIterator", q.Ret.Annotation()) - case ":exec": - f.Body = append(f.Body, poet.Await(exec)) - f.Returns = poet.Constant(nil) - case ":execrows": - f.Body = append(f.Body, - assignNode("result", poet.Await(exec)), - poet.Return(poet.Attribute(poet.Name("result"), "rowcount")), - ) - f.Returns = poet.Name("int") - case ":execresult": - f.Body = append(f.Body, - poet.Return(poet.Await(exec)), - ) - f.Returns = typeRefNode("sqlalchemy", "engine", "Result") - default: - panic("unknown cmd " + q.Cmd) - } - - cls.Body = append(cls.Body, poet.Node(f)) + }, + ), + ) + f.Returns = subscriptNode("AsyncIterator", q.Ret.Annotation()) + case ":exec": + exec := connMethodNode("execute", q.ConstantName, q.ArgNodes()...) 
+ f.Body = append(f.Body, poet.Await(exec)) + f.Returns = poet.Constant(nil) + default: + panic("unknown cmd " + q.Cmd) } - mod.Body = append(mod.Body, poet.Node(cls)) + + cls.Body = append(cls.Body, poet.Node(f)) } + mod.Body = append(mod.Body, poet.Node(cls)) return poet.Node(mod) } @@ -1067,6 +1122,9 @@ type pyTmplCtx struct { } func (t *pyTmplCtx) OutputQuery(sourceName string) bool { + if t.C.MergeQueryFiles { + return true + } return t.SourceName == sourceName } @@ -1106,12 +1164,16 @@ func Generate(_ context.Context, req *plugin.GenerateRequest) (*plugin.GenerateR output := map[string]string{} result := pyprint.Print(buildModelsTree(&tctx, i), pyprint.Options{}) - tctx.SourceName = "models.py" - output["models.py"] = string(result.Python) + tctx.SourceName = MODELS_FILENAME + ".py" + output[MODELS_FILENAME+".py"] = string(result.Python) files := map[string]struct{}{} - for _, q := range queries { - files[q.SourceName] = struct{}{} + if i.C.MergeQueryFiles { + files["db_queries.sql"] = struct{}{} + } else { + for _, q := range queries { + files[q.SourceName] = struct{}{} + } } for source := range files { @@ -1136,3 +1198,18 @@ func Generate(_ context.Context, req *plugin.GenerateRequest) (*plugin.GenerateR return &resp, nil } + +func isAlwaysReturningInsert(sql string) bool { + var hasInsert, hasWhere, hasReturning bool + for _, w := range strings.Fields(sql) { + switch strings.ToUpper(w) { + case "INSERT": + hasInsert = true + case "WHERE": + hasWhere = true + case "RETURNING": + hasReturning = true + } + } + return hasInsert && hasReturning && !hasWhere +} diff --git a/internal/imports.go b/internal/imports.go index b88c58c..f5011b7 100644 --- a/internal/imports.go +++ b/internal/imports.go @@ -69,7 +69,7 @@ func queryValueUses(name string, qv QueryValue) bool { } func (i *importer) Imports(fileName string) []string { - if fileName == "models.py" { + if fileName == MODELS_FILENAME+".py" { return i.modelImports() } return i.queryImports(fileName) @@ -113,7 +113,7 @@ func (i *importer) modelImports() []string { func (i *importer) queryImportSpecs(fileName string) (map[string]importSpec, map[string]importSpec) { queryUses := func(name string) bool { for _, q := range i.Queries { - if q.SourceName != fileName { + if !i.C.MergeQueryFiles && q.SourceName != fileName { continue } if queryValueUses(name, q.Ret) { @@ -131,10 +131,7 @@ func (i *importer) queryImportSpecs(fileName string) (map[string]importSpec, map std := stdImports(queryUses) pkg := make(map[string]importSpec) - pkg["sqlalchemy"] = importSpec{Module: "sqlalchemy"} - if i.C.EmitAsyncQuerier { - pkg["sqlalchemy.ext.asyncio"] = importSpec{Module: "sqlalchemy.ext.asyncio"} - } + pkg["asyncpg"] = importSpec{Module: "asyncpg"} queryValueModelImports := func(qv QueryValue) { if qv.IsStruct() && qv.EmitStruct() { @@ -147,19 +144,14 @@ func (i *importer) queryImportSpecs(fileName string) (map[string]importSpec, map } for _, q := range i.Queries { - if q.SourceName != fileName { + if !i.C.MergeQueryFiles && q.SourceName != fileName { continue } if q.Cmd == ":one" { std["typing.Optional"] = importSpec{Module: "typing", Name: "Optional"} } if q.Cmd == ":many" { - if i.C.EmitSyncQuerier { - std["typing.Iterator"] = importSpec{Module: "typing", Name: "Iterator"} - } - if i.C.EmitAsyncQuerier { - std["typing.AsyncIterator"] = importSpec{Module: "typing", Name: "AsyncIterator"} - } + std["typing.AsyncIterator"] = importSpec{Module: "typing", Name: "AsyncIterator"} } queryValueModelImports(q.Ret) for _, qv := range q.Args { @@ -173,7 +165,7 
@@ func (i *importer) queryImportSpecs(fileName string) (map[string]importSpec, map func (i *importer) queryImports(fileName string) []string { std, pkg := i.queryImportSpecs(fileName) - modelImportStr := fmt.Sprintf("from %s import models", i.C.Package) + modelImportStr := fmt.Sprintf("from %s import %s", i.C.Package, MODELS_FILENAME) importLines := []string{ buildImportBlock(std), diff --git a/plugin/main.go b/plugin/sqlc-gen-python/main.go similarity index 100% rename from plugin/main.go rename to plugin/sqlc-gen-python/main.go
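For reference, a minimal `sqlc.yaml` sketch that exercises the options this change introduces. Only the option keys are taken from the `Config` struct above; the plugin name, file paths, `sha256`, table prefix, and the `tenant_id` column are placeholders, not values from this diff.

```yaml
version: "2"
plugins:
  - name: py
    wasm:
      url: file://./bin/sqlc-gen-python.wasm
      sha256: "..."
sql:
  - engine: postgresql
    schema: schema.sql
    queries: queries.sql
    codegen:
      - plugin: py
        out: src/db
        options:
          package: db
          emit_pydantic_models: true
          table_prefix: Db
          rls_enforced_fields:
            - tenant_id
          merge_query_files: true
```

With `merge_query_files: true`, queries from every source file are emitted into one query module alongside `db_models.py`; with `rls_enforced_fields` set, generation fails for any query that touches a table containing `tenant_id` but does not filter on it.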