Skip to content

chore(python): Delete template-based codegen #1345

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 15, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 2 additions & 218 deletions internal/codegen/python/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,6 @@ type pyType struct {
IsNull bool
}

func (t pyType) String() string {
v := t.InnerType
if t.IsArray {
v = fmt.Sprintf("List[%s]", v)
}
if t.IsNull {
v = fmt.Sprintf("Optional[%s]", v)
}
return v
}

func (t pyType) Annotation() *pyast.Node {
ann := poet.Name(t.InnerType)
if t.IsArray {
Expand Down Expand Up @@ -105,40 +94,6 @@ func (v QueryValue) isEmpty() bool {
return v.Typ == (pyType{}) && v.Name == "" && v.Struct == nil
}

func (v QueryValue) Pair() string {
if v.isEmpty() {
return ""
}
return v.Name + ": " + v.Type()
}

func (v QueryValue) Type() string {
if v.Typ != (pyType{}) {
return v.Typ.String()
}
if v.Struct != nil {
if v.Emit {
return v.Struct.Name
} else {
return "models." + v.Struct.Name
}
}
panic("no type for QueryValue: " + v.Name)
}

func (v QueryValue) StructRowParser(rowVar string, indentCount int) string {
if !v.IsStruct() {
panic("StructRowParse called on non-struct QueryValue")
}
indent := strings.Repeat(" ", indentCount+4)
params := make([]string, 0, len(v.Struct.Fields))
for i, f := range v.Struct.Fields {
params = append(params, fmt.Sprintf("%s%s=%s[%v],", indent, f.Name, rowVar, i))
}
indent = strings.Repeat(" ", indentCount)
return v.Type() + "(\n" + strings.Join(params, "\n") + "\n" + indent + ")"
}

func (v QueryValue) RowNode(rowVar string) *pyast.Node {
if !v.IsStruct() {
return subscriptNode(
Expand Down Expand Up @@ -178,21 +133,6 @@ type Query struct {
Args []QueryValue
}

func (q Query) ArgPairs() string {
// A single struct arg does not need to be passed as a keyword argument
if len(q.Args) == 1 && q.Args[0].IsStruct() {
return ", " + q.Args[0].Pair()
}
argPairs := make([]string, 0, len(q.Args))
for _, a := range q.Args {
argPairs = append(argPairs, a.Pair())
}
if len(argPairs) == 0 {
return ""
}
return ", *, " + strings.Join(argPairs, ", ")
}

func (q Query) AddArgs(args *pyast.Arguments) {
// A single struct arg does not need to be passed as a keyword argument
if len(q.Args) == 1 && q.Args[0].IsStruct() {
Expand All @@ -210,32 +150,6 @@ func (q Query) AddArgs(args *pyast.Arguments) {
}
}

func (q Query) ArgDict() string {
params := make([]string, 0, len(q.Args))
i := 1
for _, a := range q.Args {
if a.isEmpty() {
continue
}
if a.IsStruct() {
for _, f := range a.Struct.Fields {
params = append(params, fmt.Sprintf("\"p%v\": %s", i, a.Name+"."+f.Name))
i++
}
} else {
params = append(params, fmt.Sprintf("\"p%v\": %s", i, a.Name))
i++
}
}
if len(params) == 0 {
return ""
}
if len(params) < 4 {
return ", {" + strings.Join(params, ", ") + "}"
}
return ", {\n " + strings.Join(params, ",\n ") + ",\n }"
}

func (q Query) ArgDictNode() *pyast.Node {
dict := &pyast.Dict{}
i := 1
Expand Down Expand Up @@ -914,15 +828,15 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
mod.Body = append(mod.Body, assignNode(q.ConstantName, poet.Constant(queryText)))
for _, arg := range q.Args {
if arg.EmitStruct() {
def := dataclassNode(arg.Type())
def := dataclassNode(arg.Struct.Name)
for _, f := range arg.Struct.Fields {
def.Body = append(def.Body, fieldNode(f))
}
mod.Body = append(mod.Body, poet.Node(def))
}
}
if q.Ret.EmitStruct() {
def := dataclassNode(q.Ret.Type())
def := dataclassNode(q.Ret.Struct.Name)
for _, f := range q.Ret.Struct.Fields {
def.Body = append(def.Body, fieldNode(f))
}
Expand Down Expand Up @@ -1118,136 +1032,6 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
return poet.Node(mod)
}

var queriesTmpl = `
{{- define "dataclassParse"}}

{{end}}
# Code generated by sqlc. DO NOT EDIT.
{{- range imports .SourceName}}
{{.}}
{{- end}}

{{range .Queries}}
{{- if $.OutputQuery .SourceName}}
{{.ConstantName}} = """-- name: {{.MethodName}} \\{{.Cmd}}
{{.SQL}}
"""
{{range .Args}}
{{- if .EmitStruct}}

@dataclasses.dataclass()
class {{.Type}}: {{- range .Struct.Fields}}
{{.Name}}: {{.Type}}
{{- end}}
{{end}}{{end}}
{{- if .Ret.EmitStruct}}

@dataclasses.dataclass()
class {{.Ret.Type}}: {{- range .Ret.Struct.Fields}}
{{.Name}}: {{.Type}}
{{- end}}
{{end}}
{{end}}
{{- end}}

{{- if .EmitSync}}
class Querier:
def __init__(self, conn: sqlalchemy.engine.Connection):
self._conn = conn
{{range .Queries}}
{{- if $.OutputQuery .SourceName}}
{{- if eq .Cmd ":one"}}
def {{.MethodName}}(self{{.ArgPairs}}) -> Optional[{{.Ret.Type}}]:
row = self._conn.execute(sqlalchemy.text({{.ConstantName}}){{.ArgDict}}).first()
if row is None:
return None
{{- if .Ret.IsStruct}}
return {{.Ret.StructRowParser "row" 8}}
{{- else}}
return row[0]
{{- end}}
{{end}}

{{- if eq .Cmd ":many"}}
def {{.MethodName}}(self{{.ArgPairs}}) -> Iterator[{{.Ret.Type}}]:
result = self._conn.execute(sqlalchemy.text({{.ConstantName}}){{.ArgDict}})
for row in result:
{{- if .Ret.IsStruct}}
yield {{.Ret.StructRowParser "row" 12}}
{{- else}}
yield row[0]
{{- end}}
{{end}}

{{- if eq .Cmd ":exec"}}
def {{.MethodName}}(self{{.ArgPairs}}) -> None:
self._conn.execute(sqlalchemy.text({{.ConstantName}}){{.ArgDict}})
{{end}}

{{- if eq .Cmd ":execrows"}}
def {{.MethodName}}(self{{.ArgPairs}}) -> int:
result = self._conn.execute(sqlalchemy.text({{.ConstantName}}){{.ArgDict}})
return result.rowcount
{{end}}

{{- if eq .Cmd ":execresult"}}
def {{.MethodName}}(self{{.ArgPairs}}) -> sqlalchemy.engine.Result:
return self._conn.execute(sqlalchemy.text({{.ConstantName}}){{.ArgDict}})
{{end}}
{{- end}}
{{- end}}
{{- end}}

{{- if .EmitAsync}}

class AsyncQuerier:
def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection):
self._conn = conn
{{range .Queries}}
{{- if $.OutputQuery .SourceName}}
{{- if eq .Cmd ":one"}}
async def {{.MethodName}}(self{{.ArgPairs}}) -> Optional[{{.Ret.Type}}]:
row = (await self._conn.execute(sqlalchemy.text({{.ConstantName}}){{.ArgDict}})).first()
if row is None:
return None
{{- if .Ret.IsStruct}}
return {{.Ret.StructRowParser "row" 8}}
{{- else}}
return row[0]
{{- end}}
{{end}}

{{- if eq .Cmd ":many"}}
async def {{.MethodName}}(self{{.ArgPairs}}) -> AsyncIterator[{{.Ret.Type}}]:
result = await self._conn.stream(sqlalchemy.text({{.ConstantName}}){{.ArgDict}})
async for row in result:
{{- if .Ret.IsStruct}}
yield {{.Ret.StructRowParser "row" 12}}
{{- else}}
yield row[0]
{{- end}}
{{end}}

{{- if eq .Cmd ":exec"}}
async def {{.MethodName}}(self{{.ArgPairs}}) -> None:
await self._conn.execute(sqlalchemy.text({{.ConstantName}}){{.ArgDict}})
{{end}}

{{- if eq .Cmd ":execrows"}}
async def {{.MethodName}}(self{{.ArgPairs}}) -> int:
result = await self._conn.execute(sqlalchemy.text({{.ConstantName}}){{.ArgDict}})
return result.rowcount
{{end}}

{{- if eq .Cmd ":execresult"}}
async def {{.MethodName}}(self{{.ArgPairs}}) -> sqlalchemy.engine.Result:
return await self._conn.execute(sqlalchemy.text({{.ConstantName}}){{.ArgDict}})
{{end}}
{{- end}}
{{- end}}
{{- end}}
`

type pyTmplCtx struct {
Models []Struct
Queries []Query
Expand Down