diff --git a/internal/gen.go b/internal/gen.go index ebe34b0..b9b4c2a 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -207,6 +207,8 @@ func pyInnerType(req *plugin.CodeGenRequest, col *plugin.Column) string { switch req.Settings.Engine { case "postgresql": return postgresType(req, col) + case "mysql": + return mysqlType(req, col) default: log.Println("unsupported engine type") return "Any" @@ -376,6 +378,15 @@ func sqlalchemySQL(s, engine string) string { s = strings.ReplaceAll(s, ":", `\\:`) if engine == "postgresql" { return postgresPlaceholderRegexp.ReplaceAllString(s, ":p$1") + } else if engine == "mysql" { + // All "?" in string s in string s are replaced with ":p1", ":p2", ... in that order + parts := strings.Split(s, "?") + for i := range parts { + if i != 0 { + parts[i] = fmt.Sprintf(":p%d%s", i, parts[i]) + } + } + return strings.Join(parts, "") } return s } diff --git a/internal/mysql_type.go b/internal/mysql_type.go new file mode 100644 index 0000000..7bad4fa --- /dev/null +++ b/internal/mysql_type.go @@ -0,0 +1,72 @@ +package python + +import ( + "log" + + "buf.build/gen/go/sqlc/sqlc/protocolbuffers/go/protos/plugin" + "github.com/sqlc-dev/sqlc-go/sdk" +) + +func mysqlType(req *plugin.CodeGenRequest, col *plugin.Column) string { + columnType := sdk.DataType(col.Type) + + switch columnType { + + case "varchar", "text", "char", "tinytext", "mediumtext", "longtext": + return "str" + + case "tinyint": + if col.Length == 1 { + return "bool" + } else { + return "int" + } + + case "int", "integer", "smallint", "mediumint", "year": + return "int" + + case "bigint": + return "int" + + case "blob", "binary", "varbinary", "tinyblob", "mediumblob", "longblob": + // TODO: Proper blob support + return "Any" + + case "double", "double precision", "real", "float": + return "float" + + case "decimal", "dec", "fixed": + return "string" + + case "enum": + // TODO: Proper Enum support + return "string" + + case "date", "timestamp", "datetime", "time": + return "datetime.date" + + case "boolean", "bool": + return "bool" + + case "json": + return "Any" + + case "any": + return "Any" + + default: + for _, schema := range req.Catalog.Schemas { + for _, enum := range schema.Enums { + if columnType == enum.Name { + if schema.Name == req.Catalog.DefaultSchema { + return "models." + modelName(enum.Name, req.Settings) + } + return "models." + modelName(schema.Name+"_"+enum.Name, req.Settings) + } + } + } + log.Printf("Unknown MySQL type: %s\n", columnType) + return "Any" + + } +}