Skip to content

Commit 449d140

Browse files
committed
feat: added type annotations
1 parent f239082 commit 449d140

File tree

4 files changed

+43
-8
lines changed

4 files changed

+43
-8
lines changed

internal/ast/ast.pb.go

Lines changed: 10 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/gen.go

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -702,12 +702,23 @@ func pydanticNode(name string) *pyast.ClassDef {
702702
}
703703
}
704704

705-
func fieldNode(f Field) *pyast.Node {
705+
func fieldNode(f Field, defaultNone bool) *pyast.Node {
706+
// TODO: Current AST is showing limitation as annotated assign does not support a value :'(, manually edited :'(
707+
var value *pyast.Node = nil
708+
if defaultNone && f.Type.IsNull {
709+
value = &pyast.Node{
710+
Node: &pyast.Node_Name{
711+
Name: &pyast.Name{Id: "None"},
712+
},
713+
}
714+
}
715+
706716
return &pyast.Node{
707717
Node: &pyast.Node_AnnAssign{
708718
AnnAssign: &pyast.AnnAssign{
709719
Target: &pyast.Name{Id: f.Name},
710720
Annotation: f.Type.Annotation(),
721+
Value: value,
711722
Comment: f.Comment,
712723
},
713724
},
@@ -825,7 +836,7 @@ func buildModelsTree(ctx *pyTmplCtx, i *importer) *pyast.Node {
825836
})
826837
}
827838
for _, f := range m.Fields {
828-
def.Body = append(def.Body, fieldNode(f))
839+
def.Body = append(def.Body, fieldNode(f, false))
829840
}
830841
mod.Body = append(mod.Body, &pyast.Node{
831842
Node: &pyast.Node_ClassDef{
@@ -904,6 +915,8 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
904915
}
905916
queryText := fmt.Sprintf("-- name: %s \\\\%s\n%s\n", q.MethodName, q.Cmd, q.SQL)
906917
mod.Body = append(mod.Body, assignNode(q.ConstantName, poet.Constant(queryText)))
918+
919+
// Generate params structures
907920
for _, arg := range q.Args {
908921
if arg.EmitStruct() {
909922
var def *pyast.ClassDef
@@ -912,8 +925,18 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
912925
} else {
913926
def = dataclassNode(arg.Struct.Name)
914927
}
915-
for _, f := range arg.Struct.Fields {
916-
def.Body = append(def.Body, fieldNode(f))
928+
929+
// We need a copy as we want to make sure that nullable params are at the end of the dataclass
930+
fields := make([]Field, len(arg.Struct.Fields))
931+
copy(fields, arg.Struct.Fields)
932+
933+
// Place all nullable fields at the end and try to keep the original order as much as possible
934+
sort.SliceStable(fields, func(i int, j int) bool {
935+
return (fields[j].Type.IsNull && fields[i].Type.IsNull != fields[j].Type.IsNull) || i < j
936+
})
937+
938+
for _, f := range fields {
939+
def.Body = append(def.Body, fieldNode(f, true))
917940
}
918941
mod.Body = append(mod.Body, poet.Node(def))
919942
}
@@ -926,7 +949,7 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
926949
def = dataclassNode(q.Ret.Struct.Name)
927950
}
928951
for _, f := range q.Ret.Struct.Fields {
929-
def.Body = append(def.Body, fieldNode(f))
952+
def.Body = append(def.Body, fieldNode(f, false))
930953
}
931954
mod.Body = append(mod.Body, poet.Node(def))
932955
}

internal/postgresql_type.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ func postgresType(req *plugin.CodeGenRequest, col *plugin.Column) string {
2222
case "json", "jsonb":
2323
return "Any"
2424
case "bytea", "blob", "pg_catalog.bytea":
25-
return "memoryview"
25+
return "bytes"
2626
case "date":
2727
return "datetime.date"
2828
case "pg_catalog.time", "pg_catalog.timetz":

internal/printer/printer.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,10 @@ func (w *writer) printAnnAssign(aa *ast.AnnAssign, indent int32) {
140140
w.printName(aa.Target, indent)
141141
w.print(": ")
142142
w.printNode(aa.Annotation, indent)
143+
if aa.Value != nil {
144+
w.print(" = ")
145+
w.printNode(aa.Value, indent)
146+
}
143147
}
144148

145149
func (w *writer) printArg(a *ast.Arg, indent int32) {

0 commit comments

Comments
 (0)