From 9b731d0c58e0a5422a39da0a3cfd09d10c20e9ff Mon Sep 17 00:00:00 2001 From: Andrew Benton Date: Wed, 19 Jul 2023 18:11:33 -0400 Subject: [PATCH] fix(vet): report an error when a query is unpreparable, close prepared statement connection --- internal/cmd/vet.go | 42 +++++++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/internal/cmd/vet.go b/internal/cmd/vet.go index 1233614340..b35e050ce9 100644 --- a/internal/cmd/vet.go +++ b/internal/cmd/vet.go @@ -34,6 +34,17 @@ var ErrFailedChecks = errors.New("failed checks") const RuleDbPrepare = "sqlc/db-prepare" const QueryFlagSqlcVetDisable = "@sqlc-vet-disable" +type emptyProgram struct { +} + +func (e *emptyProgram) Eval(any) (ref.Val, *cel.EvalDetails, error) { + return nil, nil, fmt.Errorf("unimplemented") +} + +func (e *emptyProgram) ContextEval(ctx context.Context, a any) (ref.Val, *cel.EvalDetails, error) { + return e.Eval(a) +} + func NewCmdVet() *cobra.Command { return &cobra.Command{ Use: "vet", @@ -53,17 +64,6 @@ func NewCmdVet() *cobra.Command { } } -type emptyProgram struct { -} - -func (e *emptyProgram) Eval(any) (ref.Val, *cel.EvalDetails, error) { - return nil, nil, fmt.Errorf("unimplemented") -} - -func (e *emptyProgram) ContextEval(ctx context.Context, a any) (ref.Val, *cel.EvalDetails, error) { - return e.Eval(a) -} - func Vet(ctx context.Context, e Env, dir, filename string, stderr io.Writer) error { configPath, conf, err := readConfig(stderr, dir, filename) if err != nil { @@ -100,7 +100,7 @@ func Vet(ctx context.Context, e Env, dir, filename string, stderr io.Writer) err } checks := map[string]cel.Program{ - RuleDbPrepare: &emptyProgram{}, + RuleDbPrepare: &emptyProgram{}, // Keep this to trigger the name conflict error below } msgs := map[string]string{} @@ -198,7 +198,8 @@ type dbPreparer struct { } func (p *dbPreparer) Prepare(ctx context.Context, name, query string) error { - _, err := p.db.PrepareContext(ctx, query) + s, err := p.db.PrepareContext(ctx, query) + s.Close() return err } @@ -316,12 +317,15 @@ func (c *checker) checkSQL(ctx context.Context, s config.SQL) error { continue } original := result.Queries[i] - if prepareable(s, original.RawStmt) { - name := fmt.Sprintf("sqlc_vet_%d_%d", time.Now().Unix(), i) - if err := prep.Prepare(ctx, name, query.Text); err != nil { - fmt.Fprintf(c.Stderr, "%s: %s: %s: error preparing query: %s\n", query.Filename, q.Name, name, err) - errored = true - } + if !prepareable(s, original.RawStmt) { + fmt.Fprintf(c.Stderr, "%s: %s: %s: error preparing query: %s\n", query.Filename, q.Name, name, "query type is unpreparable") + errored = true + continue + } + name := fmt.Sprintf("sqlc_vet_%d_%d", time.Now().Unix(), i) + if err := prep.Prepare(ctx, name, query.Text); err != nil { + fmt.Fprintf(c.Stderr, "%s: %s: %s: error preparing query: %s\n", query.Filename, q.Name, name, err) + errored = true } continue }