Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add comments
  • Loading branch information
coadler committed Dec 1, 2023
commit 307b99938bd23ee5f4899ace894d6fff5a113cef
22 changes: 14 additions & 8 deletions coderd/database/migrations/txnmigrator.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package migrations

import (
"bytes"
"context"
"database/sql"
"fmt"
Expand All @@ -18,6 +17,10 @@ const (
migrationsTableName = "schema_migrations"
)

// pgTxnDriver is a Postgres migration driver that runs all migrations in a
// single transaction. This is done to prevent users from being locked out of
// their deployment if a migration fails, since the schema will simply revert
// back to the previous version.
type pgTxnDriver struct {
ctx context.Context
db *sql.DB
Expand Down Expand Up @@ -59,14 +62,11 @@ func (d *pgTxnDriver) Unlock() error {
return nil
}

// Run applies a migration to the database. migration is guaranteed to be not nil.
func (d *pgTxnDriver) Run(migration io.Reader) error {
migr, err := io.ReadAll(migration)
if err != nil {
return xerrors.Errorf("read migration: %w", err)
}
migr = bytes.ReplaceAll(migr, []byte("BEGIN;"), []byte{})
migr = bytes.ReplaceAll(migr, []byte("COMMIT;"), []byte{})
err = d.runStatement(migr)
if err != nil {
return xerrors.Errorf("run statement: %w", err)
Expand All @@ -81,11 +81,12 @@ func (d *pgTxnDriver) runStatement(statement []byte) error {
return nil
}
if _, err := d.tx.ExecContext(ctx, query); err != nil {
if pgErr, ok := err.(*pq.Error); ok { //nolint
var pgErr *pq.Error
if xerrors.As(err, &pgErr) {
var line uint
message := fmt.Sprintf("migration failed: %s", pgErr.Message)
if pgErr.Detail != "" {
message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
message += ", " + pgErr.Detail
}
return database.Error{OrigErr: err, Err: message, Query: statement, Line: line}
}
Expand All @@ -112,9 +113,13 @@ func (d *pgTxnDriver) SetVersion(version int, dirty bool) error {
}

func (d *pgTxnDriver) Version() (version int, dirty bool, err error) {
// If the transaction is valid (we hold the exclusive lock), use the txn for
// the query.
var q interface {
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
} = d.tx
// If we don't hold the lock just use the database. This only happens in the
// `Stepper` function and is only used in tests.
if d.tx == nil {
q = d.db
}
Expand All @@ -126,8 +131,9 @@ func (d *pgTxnDriver) Version() (version int, dirty bool, err error) {
return database.NilVersion, false, nil

case err != nil:
if e, ok := err.(*pq.Error); ok { //nolint
if e.Code.Name() == "undefined_table" {
var pgErr *pq.Error
if xerrors.As(err, &pgErr) {
if pgErr.Code.Name() == "undefined_table" {
return database.NilVersion, false, nil
}
}
Expand Down