diff --git a/cli/resetpassword.go b/cli/resetpassword.go new file mode 100644 index 0000000000000..ec23a2fc89bed --- /dev/null +++ b/cli/resetpassword.go @@ -0,0 +1,90 @@ +package cli + +import ( + "database/sql" + + "github.com/spf13/cobra" + "golang.org/x/xerrors" + + "github.com/coder/coder/cli/cliflag" + "github.com/coder/coder/cli/cliui" + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/userpassword" +) + +func resetPassword() *cobra.Command { + var ( + postgresURL string + ) + + root := &cobra.Command{ + Use: "reset-password ", + Short: "Reset a user's password by directly updating the database", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + username := args[0] + + sqlDB, err := sql.Open("postgres", postgresURL) + if err != nil { + return xerrors.Errorf("dial postgres: %w", err) + } + defer sqlDB.Close() + err = sqlDB.Ping() + if err != nil { + return xerrors.Errorf("ping postgres: %w", err) + } + + err = database.EnsureClean(sqlDB) + if err != nil { + return xerrors.Errorf("database needs migration: %w", err) + } + db := database.New(sqlDB) + + user, err := db.GetUserByEmailOrUsername(cmd.Context(), database.GetUserByEmailOrUsernameParams{ + Username: username, + }) + if err != nil { + return xerrors.Errorf("retrieving user: %w", err) + } + + password, err := cliui.Prompt(cmd, cliui.PromptOptions{ + Text: "Enter new " + cliui.Styles.Field.Render("password") + ":", + Secret: true, + Validate: cliui.ValidateNotEmpty, + }) + if err != nil { + return xerrors.Errorf("password prompt: %w", err) + } + confirmedPassword, err := cliui.Prompt(cmd, cliui.PromptOptions{ + Text: "Confirm " + cliui.Styles.Field.Render("password") + ":", + Secret: true, + Validate: cliui.ValidateNotEmpty, + }) + if err != nil { + return xerrors.Errorf("confirm password prompt: %w", err) + } + if password != confirmedPassword { + return xerrors.New("Passwords do not match") + } + + hashedPassword, err := userpassword.Hash(password) + if err != nil { + return xerrors.Errorf("hash password: %w", err) + } + + err = db.UpdateUserHashedPassword(cmd.Context(), database.UpdateUserHashedPasswordParams{ + ID: user.ID, + HashedPassword: []byte(hashedPassword), + }) + if err != nil { + return xerrors.Errorf("updating password: %w", err) + } + + return nil + }, + } + + cliflag.StringVarP(root.Flags(), &postgresURL, "postgres-url", "", "CODER_PG_CONNECTION_URL", "", "URL of a PostgreSQL database to connect to") + + return root +} diff --git a/cli/resetpassword_test.go b/cli/resetpassword_test.go new file mode 100644 index 0000000000000..eafa097e3e842 --- /dev/null +++ b/cli/resetpassword_test.go @@ -0,0 +1,106 @@ +package cli_test + +import ( + "context" + "net/url" + "runtime" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/cli/clitest" + "github.com/coder/coder/coderd/database/postgres" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/pty/ptytest" +) + +func TestResetPassword(t *testing.T) { + t.Parallel() + + if runtime.GOOS != "linux" || testing.Short() { + // Skip on non-Linux because it spawns a PostgreSQL instance. + t.SkipNow() + } + + const email = "some@one.com" + const username = "example" + const oldPassword = "password" + const newPassword = "password2" + + // start postgres and coder server processes + + connectionURL, closeFunc, err := postgres.Open() + require.NoError(t, err) + defer closeFunc() + ctx, cancelFunc := context.WithCancel(context.Background()) + serverDone := make(chan struct{}) + serverCmd, cfg := clitest.New(t, "server", "--address", ":0", "--postgres-url", connectionURL) + go func() { + defer close(serverDone) + err = serverCmd.ExecuteContext(ctx) + require.ErrorIs(t, err, context.Canceled) + }() + var client *codersdk.Client + require.Eventually(t, func() bool { + rawURL, err := cfg.URL().Read() + if err != nil { + return false + } + accessURL, err := url.Parse(rawURL) + require.NoError(t, err) + client = codersdk.New(accessURL) + return true + }, 15*time.Second, 25*time.Millisecond) + _, err = client.CreateFirstUser(ctx, codersdk.CreateFirstUserRequest{ + Email: email, + Username: username, + Password: oldPassword, + OrganizationName: "example", + }) + require.NoError(t, err) + + // reset the password + + resetCmd, cmdCfg := clitest.New(t, "reset-password", "--postgres-url", connectionURL, username) + clitest.SetupConfig(t, client, cmdCfg) + cmdDone := make(chan struct{}) + pty := ptytest.New(t) + resetCmd.SetIn(pty.Input()) + resetCmd.SetOut(pty.Output()) + go func() { + defer close(cmdDone) + err = resetCmd.Execute() + require.NoError(t, err) + }() + + matches := []struct { + output string + input string + }{ + {"Enter new", newPassword}, + {"Confirm", newPassword}, + } + for _, match := range matches { + pty.ExpectMatch(match.output) + pty.WriteLine(match.input) + } + <-cmdDone + + // now try logging in + + _, err = client.LoginWithPassword(ctx, codersdk.LoginWithPasswordRequest{ + Email: email, + Password: oldPassword, + }) + require.Error(t, err) + + _, err = client.LoginWithPassword(ctx, codersdk.LoginWithPasswordRequest{ + Email: email, + Password: newPassword, + }) + require.NoError(t, err) + + cancelFunc() + <-serverDone +} diff --git a/cli/root.go b/cli/root.go index 598c1eeefd502..909b18db9a5c8 100644 --- a/cli/root.go +++ b/cli/root.go @@ -61,6 +61,7 @@ func Root() *cobra.Command { list(), login(), publickey(), + resetPassword(), server(), show(), start(), diff --git a/coderd/database/migrate.go b/coderd/database/migrate.go index fc3816016a1f1..25ba7a44d6143 100644 --- a/coderd/database/migrate.go +++ b/coderd/database/migrate.go @@ -4,9 +4,11 @@ import ( "database/sql" "embed" "errors" + "os" "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database/postgres" + "github.com/golang-migrate/migrate/v4/source" "github.com/golang-migrate/migrate/v4/source/iofs" "golang.org/x/xerrors" ) @@ -14,28 +16,28 @@ import ( //go:embed migrations/*.sql var migrations embed.FS -func migrateSetup(db *sql.DB) (*migrate.Migrate, error) { +func migrateSetup(db *sql.DB) (source.Driver, *migrate.Migrate, error) { sourceDriver, err := iofs.New(migrations, "migrations") if err != nil { - return nil, xerrors.Errorf("create iofs: %w", err) + return nil, nil, xerrors.Errorf("create iofs: %w", err) } dbDriver, err := postgres.WithInstance(db, &postgres.Config{}) if err != nil { - return nil, xerrors.Errorf("wrap postgres connection: %w", err) + return nil, nil, xerrors.Errorf("wrap postgres connection: %w", err) } m, err := migrate.NewWithInstance("", sourceDriver, "", dbDriver) if err != nil { - return nil, xerrors.Errorf("new migrate instance: %w", err) + return nil, nil, xerrors.Errorf("new migrate instance: %w", err) } - return m, nil + return sourceDriver, m, nil } // MigrateUp runs SQL migrations to ensure the database schema is up-to-date. func MigrateUp(db *sql.DB) error { - m, err := migrateSetup(db) + _, m, err := migrateSetup(db) if err != nil { return xerrors.Errorf("migrate setup: %w", err) } @@ -55,7 +57,7 @@ func MigrateUp(db *sql.DB) error { // MigrateDown runs all down SQL migrations. func MigrateDown(db *sql.DB) error { - m, err := migrateSetup(db) + _, m, err := migrateSetup(db) if err != nil { return xerrors.Errorf("migrate setup: %w", err) } @@ -72,3 +74,68 @@ func MigrateDown(db *sql.DB) error { return nil } + +// EnsureClean checks whether all migrations for the current version have been +// applied, without making any changes to the database. If not, returns a +// non-nil error. +func EnsureClean(db *sql.DB) error { + sourceDriver, m, err := migrateSetup(db) + if err != nil { + return xerrors.Errorf("migrate setup: %w", err) + } + + version, dirty, err := m.Version() + if err != nil { + return xerrors.Errorf("get migration version: %w", err) + } + + if dirty { + return xerrors.Errorf("database has not been cleanly migrated") + } + + // Verify that the database's migration version is "current" by checking + // that a migration with that version exists, but there is no next version. + err = CheckLatestVersion(sourceDriver, version) + if err != nil { + return xerrors.Errorf("database needs migration: %w", err) + } + + return nil +} + +// Returns nil if currentVersion corresponds to the latest available migration, +// otherwise an error explaining why not. +func CheckLatestVersion(sourceDriver source.Driver, currentVersion uint) error { + // This is ugly, but seems like the only way to do it with the public + // interfaces provided by golang-migrate. + + // Check that there is no later version + nextVersion, err := sourceDriver.Next(currentVersion) + if err == nil { + return xerrors.Errorf("current version is %d, but later version %d exists", currentVersion, nextVersion) + } + if !errors.Is(err, os.ErrNotExist) { + return xerrors.Errorf("get next migration after %d: %w", currentVersion, err) + } + + // Once we reach this point, we know that either currentVersion doesn't + // exist, or it has no successor (the return value from + // sourceDriver.Next() is the same in either case). So we need to check + // that either it's the first version, or it has a predecessor. + + firstVersion, err := sourceDriver.First() + if err != nil { + // the total number of migrations should be non-zero, so this must be + // an actual error, not just a missing file + return xerrors.Errorf("get first migration: %w", err) + } + if firstVersion == currentVersion { + return nil + } + + _, err = sourceDriver.Prev(currentVersion) + if err != nil { + return xerrors.Errorf("get previous migration: %w", err) + } + return nil +} diff --git a/coderd/database/migrate_test.go b/coderd/database/migrate_test.go index 8bd9c41f0c443..a8f739c275c36 100644 --- a/coderd/database/migrate_test.go +++ b/coderd/database/migrate_test.go @@ -4,8 +4,11 @@ package database_test import ( "database/sql" + "fmt" "testing" + "github.com/golang-migrate/migrate/v4/source" + "github.com/golang-migrate/migrate/v4/source/stub" "github.com/stretchr/testify/require" "go.uber.org/goleak" @@ -75,3 +78,54 @@ func testSQLDB(t testing.TB) *sql.DB { return db } + +// paralleltest linter doesn't correctly handle table-driven tests (https://github.com/kunwardeep/paralleltest/issues/8) +// nolint:paralleltest +func TestCheckLatestVersion(t *testing.T) { + t.Parallel() + + type test struct { + currentVersion uint + existingVersions []uint + expectedResult string + } + + tests := []test{ + // successful cases + {1, []uint{1}, ""}, + {3, []uint{1, 2, 3}, ""}, + {3, []uint{1, 3}, ""}, + + // failure cases + {1, []uint{1, 2}, "current version is 1, but later version 2 exists"}, + {2, []uint{1, 2, 3}, "current version is 2, but later version 3 exists"}, + {4, []uint{1, 2, 3}, "get previous migration: prev for version 4 : file does not exist"}, + {4, []uint{1, 2, 3, 5}, "get previous migration: prev for version 4 : file does not exist"}, + } + + for i, tc := range tests { + i, tc := i, tc + t.Run(fmt.Sprintf("entry %d", i), func(t *testing.T) { + t.Parallel() + + driver, _ := stub.WithInstance(nil, &stub.Config{}) + stub, ok := driver.(*stub.Stub) + require.True(t, ok) + for _, version := range tc.existingVersions { + stub.Migrations.Append(&source.Migration{ + Version: version, + Identifier: "", + Direction: source.Up, + Raw: "", + }) + } + + err := database.CheckLatestVersion(driver, tc.currentVersion) + var errMessage string + if err != nil { + errMessage = err.Error() + } + require.Equal(t, tc.expectedResult, errMessage) + }) + } +}