Skip to content

feat: Add reset-password command #1380

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 12, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
allow non-destructively checking if database needs to be migrated
  • Loading branch information
dwahler committed May 10, 2022
commit b0805727be993de0b255cd8cb7b5cd40cac9e0f5
81 changes: 74 additions & 7 deletions coderd/database/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,40 @@ import (
"database/sql"
"embed"
"errors"
"os"

"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/postgres"
"github.com/golang-migrate/migrate/v4/source"
"github.com/golang-migrate/migrate/v4/source/iofs"
"golang.org/x/xerrors"
)

//go:embed migrations/*.sql
var migrations embed.FS

func migrateSetup(db *sql.DB) (*migrate.Migrate, error) {
func migrateSetup(db *sql.DB) (source.Driver, *migrate.Migrate, error) {
sourceDriver, err := iofs.New(migrations, "migrations")
if err != nil {
return nil, xerrors.Errorf("create iofs: %w", err)
return nil, nil, xerrors.Errorf("create iofs: %w", err)
}

dbDriver, err := postgres.WithInstance(db, &postgres.Config{})
if err != nil {
return nil, xerrors.Errorf("wrap postgres connection: %w", err)
return nil, nil, xerrors.Errorf("wrap postgres connection: %w", err)
}

m, err := migrate.NewWithInstance("", sourceDriver, "", dbDriver)
if err != nil {
return nil, xerrors.Errorf("new migrate instance: %w", err)
return nil, nil, xerrors.Errorf("new migrate instance: %w", err)
}

return m, nil
return sourceDriver, m, nil
}

// MigrateUp runs SQL migrations to ensure the database schema is up-to-date.
func MigrateUp(db *sql.DB) error {
m, err := migrateSetup(db)
_, m, err := migrateSetup(db)
if err != nil {
return xerrors.Errorf("migrate setup: %w", err)
}
Expand All @@ -55,7 +57,7 @@ func MigrateUp(db *sql.DB) error {

// MigrateDown runs all down SQL migrations.
func MigrateDown(db *sql.DB) error {
m, err := migrateSetup(db)
_, m, err := migrateSetup(db)
if err != nil {
return xerrors.Errorf("migrate setup: %w", err)
}
Expand All @@ -72,3 +74,68 @@ func MigrateDown(db *sql.DB) error {

return nil
}

// EnsureClean checks whether all migrations for the current version have been
// applied, without making any changes to the database. If not, returns a
// non-nil error.
func EnsureClean(db *sql.DB) error {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if it helps, but if you want to test 99% of this function, you could write this as e.g. func EnsureClean(db *sql.DB) error { return ensureClean(db, migrateSetup); }.

This would let you swap out the migrateSetup function in your test, if that helps with the global state (considering your comment here #1380 (comment)). It would still require an internal test package but I think that's fine for this use-case.

sourceDriver, m, err := migrateSetup(db)
if err != nil {
return xerrors.Errorf("migrate setup: %w", err)
}

version, dirty, err := m.Version()
if err != nil {
return xerrors.Errorf("get migration version: %w", err)
}

if dirty {
return xerrors.Errorf("database has not been cleanly migrated")
}

// Verify that the database's migration version is "current" by checking
// that a migration with that version exists, but there is no next version.
err = CheckLatestVersion(sourceDriver, version)
if err != nil {
return xerrors.Errorf("database needs migration: %w", err)
}

return nil
}

// Returns nil if currentVersion corresponds to the latest available migration,
// otherwise an error explaining why not.
func CheckLatestVersion(sourceDriver source.Driver, currentVersion uint) error {
Comment on lines +106 to +108
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only thing I'm significantly unhappy about with this commit. I feel like this ought to be a private function, but that doesn't play nicely with the way we put tests in a separate package.

The alternative would be to test the EnsureClean function instead, but that seems difficult because its behavior implicitly depends on the global migrations data, for consistency with the rest of the package.

// This is ugly, but seems like the only way to do it with the public
// interfaces provided by golang-migrate.

// Check that there is no later version
nextVersion, err := sourceDriver.Next(currentVersion)
if err == nil {
return xerrors.Errorf("current version is %d, but later version %d exists", currentVersion, nextVersion)
}
if !errors.Is(err, os.ErrNotExist) {
return xerrors.Errorf("get next migration after %d: %w", currentVersion, err)
}

// Once we reach this point, we know that either currentVersion doesn't
// exist, or it has no successor (the return value from
// sourceDriver.Next() is the same in either case). So we need to check
// that either it's the first version, or it has a predecessor.

firstVersion, err := sourceDriver.First()
if err != nil {
// the total number of migrations should be non-zero, so this must be
// an actual error, not just a missing file
return xerrors.Errorf("get first migration: %w", err)
}
if firstVersion == currentVersion {
return nil
}

_, err = sourceDriver.Prev(currentVersion)
if err != nil {
return xerrors.Errorf("get previous migration: %w", err)
}
return nil
}
49 changes: 49 additions & 0 deletions coderd/database/migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@ package database_test

import (
"database/sql"
"fmt"
"testing"

"github.com/golang-migrate/migrate/v4/source"
"github.com/golang-migrate/migrate/v4/source/stub"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"

Expand Down Expand Up @@ -75,3 +78,49 @@ func testSQLDB(t testing.TB) *sql.DB {

return db
}

func TestCheckLatestVersion(t *testing.T) {
t.Parallel()

type test struct {
currentVersion uint
existingVersions []uint
expectedResult string
}

tests := []test{
// successful cases
{1, []uint{1}, ""},
{3, []uint{1, 2, 3}, ""},
{3, []uint{1, 3}, ""},

// failure cases
{1, []uint{1, 2}, "current version is 1, but later version 2 exists"},
{2, []uint{1, 2, 3}, "current version is 2, but later version 3 exists"},
{4, []uint{1, 2, 3}, "get previous migration: prev for version 4 : file does not exist"},
{4, []uint{1, 2, 3, 5}, "get previous migration: prev for version 4 : file does not exist"},
}

for i, tc := range tests {
t.Run(fmt.Sprintf("entry %d", i), func(t *testing.T) {

driver, _ := stub.WithInstance(nil, &stub.Config{})
migrations := driver.(*stub.Stub).Migrations
for _, version := range tc.existingVersions {
migrations.Append(&source.Migration{
Version: version,
Identifier: "",
Direction: source.Up,
Raw: "",
})
}

err := database.CheckLatestVersion(driver, tc.currentVersion)
var errMessage string
if err != nil {
errMessage = err.Error()
}
require.Equal(t, tc.expectedResult, errMessage)
})
}
}