diff --git a/Makefile b/Makefile index 084e8bb77e5f0..fe6e527fe712e 100644 --- a/Makefile +++ b/Makefile @@ -765,7 +765,7 @@ sqlc-vet: test-postgres-docker test-postgres: test-postgres-docker # The postgres test is prone to failure, so we limit parallelism for # more consistent execution. - $(GIT_FLAGS) DB=ci DB_FROM=$(shell go run scripts/migrate-ci/main.go) gotestsum \ + $(GIT_FLAGS) DB=ci gotestsum \ --junitfile="gotests.xml" \ --jsonfile="gotests.json" \ --packages="./..." -- \ diff --git a/cli/resetpassword_test.go b/cli/resetpassword_test.go index 0cd90f5b4cd00..de712874f3f07 100644 --- a/cli/resetpassword_test.go +++ b/cli/resetpassword_test.go @@ -32,9 +32,8 @@ func TestResetPassword(t *testing.T) { const newPassword = "MyNewPassword!" // start postgres and coder server processes - connectionURL, closeFunc, err := dbtestutil.Open() + connectionURL, err := dbtestutil.Open(t) require.NoError(t, err) - defer closeFunc() ctx, cancelFunc := context.WithCancel(context.Background()) serverDone := make(chan struct{}) serverinv, cfg := clitest.New(t, diff --git a/cli/server_createadminuser_test.go b/cli/server_createadminuser_test.go index 17c02b6548c09..7660d71e89d99 100644 --- a/cli/server_createadminuser_test.go +++ b/cli/server_createadminuser_test.go @@ -85,9 +85,8 @@ func TestServerCreateAdminUser(t *testing.T) { // Skip on non-Linux because it spawns a PostgreSQL instance. t.SkipNow() } - connectionURL, closeFunc, err := dbtestutil.Open() + connectionURL, err := dbtestutil.Open(t) require.NoError(t, err) - defer closeFunc() sqlDB, err := sql.Open("postgres", connectionURL) require.NoError(t, err) @@ -151,9 +150,8 @@ func TestServerCreateAdminUser(t *testing.T) { // Skip on non-Linux because it spawns a PostgreSQL instance. t.SkipNow() } - connectionURL, closeFunc, err := dbtestutil.Open() + connectionURL, err := dbtestutil.Open(t) require.NoError(t, err) - defer closeFunc() ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) defer cancel() @@ -185,9 +183,8 @@ func TestServerCreateAdminUser(t *testing.T) { // Skip on non-Linux because it spawns a PostgreSQL instance. t.SkipNow() } - connectionURL, closeFunc, err := dbtestutil.Open() + connectionURL, err := dbtestutil.Open(t) require.NoError(t, err) - defer closeFunc() ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) defer cancel() @@ -225,9 +222,8 @@ func TestServerCreateAdminUser(t *testing.T) { // Skip on non-Linux because it spawns a PostgreSQL instance. t.SkipNow() } - connectionURL, closeFunc, err := dbtestutil.Open() + connectionURL, err := dbtestutil.Open(t) require.NoError(t, err) - defer closeFunc() ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() diff --git a/cli/server_test.go b/cli/server_test.go index ad6a98038c7bb..83a7f7171c6f5 100644 --- a/cli/server_test.go +++ b/cli/server_test.go @@ -1598,9 +1598,8 @@ func TestServer_Production(t *testing.T) { // Skip on non-Linux because it spawns a PostgreSQL instance. t.SkipNow() } - connectionURL, closeFunc, err := dbtestutil.Open() + connectionURL, err := dbtestutil.Open(t) require.NoError(t, err) - defer closeFunc() // Postgres + race detector + CI = slow. ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitSuperLong*3) @@ -1803,9 +1802,8 @@ func TestConnectToPostgres(t *testing.T) { log := slogtest.Make(t, nil) - dbURL, closeFunc, err := dbtestutil.Open() + dbURL, err := dbtestutil.Open(t) require.NoError(t, err) - t.Cleanup(closeFunc) sqlDB, err := cli.ConnectToPostgres(ctx, log, "postgres", dbURL) require.NoError(t, err) diff --git a/coderd/database/db_test.go b/coderd/database/db_test.go index a6df18fcbb8c8..b4580527c843a 100644 --- a/coderd/database/db_test.go +++ b/coderd/database/db_test.go @@ -87,9 +87,8 @@ func TestNestedInTx(t *testing.T) { func testSQLDB(t testing.TB) *sql.DB { t.Helper() - connection, closeFn, err := dbtestutil.Open() + connection, err := dbtestutil.Open(t) require.NoError(t, err) - t.Cleanup(closeFn) db, err := sql.Open("postgres", connection) require.NoError(t, err) diff --git a/coderd/database/dbtestutil/db.go b/coderd/database/dbtestutil/db.go index 327d880f69648..966efd8386643 100644 --- a/coderd/database/dbtestutil/db.go +++ b/coderd/database/dbtestutil/db.go @@ -95,21 +95,17 @@ func NewDB(t testing.TB, opts ...Option) (database.Store, pubsub.Pubsub) { opt(&o) } - db := dbmem.New() - ps := pubsub.NewInMemory() + var db database.Store + var ps pubsub.Pubsub if WillUsePostgres() { connectionURL := os.Getenv("CODER_PG_CONNECTION_URL") if connectionURL == "" && o.url != "" { connectionURL = o.url } if connectionURL == "" { - var ( - err error - closePg func() - ) - connectionURL, closePg, err = Open() + var err error + connectionURL, err = Open(t) require.NoError(t, err) - t.Cleanup(closePg) } if o.fixedTimezone == "" { @@ -142,6 +138,9 @@ func NewDB(t testing.TB, opts ...Option) (database.Store, pubsub.Pubsub) { t.Cleanup(func() { _ = ps.Close() }) + } else { + db = dbmem.New() + ps = pubsub.NewInMemory() } return db, ps diff --git a/coderd/database/dbtestutil/postgres.go b/coderd/database/dbtestutil/postgres.go index 3a559778b6968..a58ffb570763f 100644 --- a/coderd/database/dbtestutil/postgres.go +++ b/coderd/database/dbtestutil/postgres.go @@ -1,134 +1,498 @@ package dbtestutil import ( + "context" + "crypto/sha256" "database/sql" + "encoding/hex" + "errors" "fmt" + "net" "os" + "path/filepath" "strconv" + "strings" + "sync" "time" "github.com/cenkalti/backoff/v4" + "github.com/gofrs/flock" "github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3/docker" "golang.org/x/xerrors" "github.com/coder/coder/v2/coderd/database/migrations" "github.com/coder/coder/v2/cryptorand" + "github.com/coder/retry" ) -// Open creates a new PostgreSQL database instance. With DB_FROM environment variable set, it clones a database -// from the provided template. With the environment variable unset, it creates a new Docker container running postgres. -func Open() (string, func(), error) { - if os.Getenv("DB_FROM") != "" { - // In CI, creating a Docker container for each test is slow. - // This expects a PostgreSQL instance with the hardcoded credentials - // available. - dbURL := "postgres://postgres:postgres@127.0.0.1:5432/postgres?sslmode=disable" - db, err := sql.Open("postgres", dbURL) +type ConnectionParams struct { + Username string + Password string + Host string + Port string + DBName string +} + +func (p ConnectionParams) DSN() string { + return fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable", p.Username, p.Password, p.Host, p.Port, p.DBName) +} + +// These variables are global because all tests share them. +var ( + connectionParamsInitOnce sync.Once + defaultConnectionParams ConnectionParams + errDefaultConnectionParamsInit error +) + +// initDefaultConnection initializes the default postgres connection parameters. +// It first checks if the database is running at localhost:5432. If it is, it will +// use that database. If it's not, it will start a new container and use that. +func initDefaultConnection(t TBSubset) error { + params := ConnectionParams{ + Username: "postgres", + Password: "postgres", + Host: "127.0.0.1", + Port: "5432", + DBName: "postgres", + } + dsn := params.DSN() + db, dbErr := sql.Open("postgres", dsn) + if dbErr == nil { + dbErr = db.Ping() + if closeErr := db.Close(); closeErr != nil { + return xerrors.Errorf("close db: %w", closeErr) + } + } + shouldOpenContainer := false + if dbErr != nil { + errSubstrings := []string{ + "connection refused", // this happens on Linux when there's nothing listening on the port + "No connection could be made", // like above but Windows + } + errString := dbErr.Error() + for _, errSubstring := range errSubstrings { + if strings.Contains(errString, errSubstring) { + shouldOpenContainer = true + break + } + } + } + if dbErr != nil && shouldOpenContainer { + // If there's no database running on the default port, we'll start a + // postgres container. We won't be cleaning it up so it can be reused + // by subsequent tests. It'll keep on running until the user terminates + // it manually. + container, _, err := openContainer(t, DBContainerOptions{ + Name: "coder-test-postgres", + Port: 5432, + }) if err != nil { - return "", nil, xerrors.Errorf("connect to ci postgres: %w", err) + return xerrors.Errorf("open container: %w", err) } + params.Host = container.Host + params.Port = container.Port + dsn = params.DSN() - defer db.Close() + // Retry connecting for at most 10 seconds. + // The fact that openContainer succeeded does not + // mean that port forwarding is ready. + for r := retry.New(100*time.Millisecond, 10*time.Second); r.Wait(context.Background()); { + db, connErr := sql.Open("postgres", dsn) + if connErr == nil { + connErr = db.Ping() + if closeErr := db.Close(); closeErr != nil { + return xerrors.Errorf("close db, container: %w", closeErr) + } + } + if connErr == nil { + break + } + } + } else if dbErr != nil { + return xerrors.Errorf("open postgres connection: %w", dbErr) + } + defaultConnectionParams = params + return nil +} + +type OpenOptions struct { + DBFrom *string +} + +type OpenOption func(*OpenOptions) + +// WithDBFrom sets the template database to use when creating a new database. +// Overrides the DB_FROM environment variable. +func WithDBFrom(dbFrom string) OpenOption { + return func(o *OpenOptions) { + o.DBFrom = &dbFrom + } +} + +// TBSubset is a subset of the testing.TB interface. +// It allows to use dbtestutil.Open outside of tests. +type TBSubset interface { + Cleanup(func()) + Helper() + Logf(format string, args ...any) +} + +// Open creates a new PostgreSQL database instance. +// If there's a database running at localhost:5432, it will use that. +// Otherwise, it will start a new postgres container. +func Open(t TBSubset, opts ...OpenOption) (string, error) { + t.Helper() - dbName, err := cryptorand.StringCharset(cryptorand.Lower, 10) + connectionParamsInitOnce.Do(func() { + errDefaultConnectionParamsInit = initDefaultConnection(t) + }) + if errDefaultConnectionParamsInit != nil { + return "", xerrors.Errorf("init default connection params: %w", errDefaultConnectionParamsInit) + } + + openOptions := OpenOptions{} + for _, opt := range opts { + opt(&openOptions) + } + + var ( + username = defaultConnectionParams.Username + password = defaultConnectionParams.Password + host = defaultConnectionParams.Host + port = defaultConnectionParams.Port + ) + + // Use a time-based prefix to make it easier to find the database + // when debugging. + now := time.Now().Format("test_2006_01_02_15_04_05") + dbSuffix, err := cryptorand.StringCharset(cryptorand.Lower, 10) + if err != nil { + return "", xerrors.Errorf("generate db suffix: %w", err) + } + dbName := now + "_" + dbSuffix + + // if empty createDatabaseFromTemplate will create a new template db + templateDBName := os.Getenv("DB_FROM") + if openOptions.DBFrom != nil { + templateDBName = *openOptions.DBFrom + } + if err = createDatabaseFromTemplate(t, defaultConnectionParams, dbName, templateDBName); err != nil { + return "", xerrors.Errorf("create database: %w", err) + } + + t.Cleanup(func() { + cleanupDbURL := defaultConnectionParams.DSN() + cleanupConn, err := sql.Open("postgres", cleanupDbURL) if err != nil { - return "", nil, xerrors.Errorf("generate db name: %w", err) + t.Logf("cleanup database %q: failed to connect to postgres: %s\n", dbName, err.Error()) + return } - - dbName = "ci" + dbName - _, err = db.Exec("CREATE DATABASE " + dbName + " WITH TEMPLATE " + os.Getenv("DB_FROM")) + defer func() { + if err := cleanupConn.Close(); err != nil { + t.Logf("cleanup database %q: failed to close connection: %s\n", dbName, err.Error()) + } + }() + _, err = cleanupConn.Exec("DROP DATABASE " + dbName + ";") if err != nil { - return "", nil, xerrors.Errorf("create db with template: %w", err) + t.Logf("failed to clean up database %q: %s\n", dbName, err.Error()) + return } + }) - dsn := "postgres://postgres:postgres@127.0.0.1:5432/" + dbName + "?sslmode=disable" - // Normally this would get cleaned up by removing the container but if we - // reuse the same container for multiple tests we run the risk of filling - // up our disk. Avoid this! - cleanup := func() { - cleanupConn, err := sql.Open("postgres", dbURL) - if err != nil { - _, _ = fmt.Fprintf(os.Stderr, "cleanup database %q: failed to connect to postgres: %s\n", dbName, err.Error()) - } - defer cleanupConn.Close() - _, err = cleanupConn.Exec("DROP DATABASE " + dbName + ";") - if err != nil { - _, _ = fmt.Fprintf(os.Stderr, "failed to clean up database %q: %s\n", dbName, err.Error()) + dsn := ConnectionParams{ + Username: username, + Password: password, + Host: host, + Port: port, + DBName: dbName, + }.DSN() + return dsn, nil +} + +// createDatabaseFromTemplate creates a new database from a template database. +// If templateDBName is empty, it will create a new template database based on +// the current migrations, and name it "tpl_". Or if it's +// already been created, it will use that. +func createDatabaseFromTemplate(t TBSubset, connParams ConnectionParams, newDBName string, templateDBName string) error { + t.Helper() + + dbURL := connParams.DSN() + db, err := sql.Open("postgres", dbURL) + if err != nil { + return xerrors.Errorf("connect to postgres: %w", err) + } + defer func() { + if err := db.Close(); err != nil { + t.Logf("create database from template: failed to close connection: %s\n", err.Error()) + } + }() + + emptyTemplateDBName := templateDBName == "" + if emptyTemplateDBName { + templateDBName = fmt.Sprintf("tpl_%s", migrations.GetMigrationsHash()[:32]) + } + _, err = db.Exec("CREATE DATABASE " + newDBName + " WITH TEMPLATE " + templateDBName) + if err == nil { + // Template database already exists and we successfully created the new database. + return nil + } + tplDbDoesNotExistOccurred := strings.Contains(err.Error(), "template database") && strings.Contains(err.Error(), "does not exist") + if (tplDbDoesNotExistOccurred && !emptyTemplateDBName) || !tplDbDoesNotExistOccurred { + // First and case: user passed a templateDBName that doesn't exist. + // Second and case: some other error. + return xerrors.Errorf("create db with template: %w", err) + } + if !emptyTemplateDBName { + // sanity check + panic("templateDBName is not empty. there's a bug in the code above") + } + // The templateDBName is empty, so we need to create the template database. + // We will use a tx to obtain a lock, so another test or process doesn't race with us. + tx, err := db.BeginTx(context.Background(), nil) + if err != nil { + return xerrors.Errorf("begin tx: %w", err) + } + defer func() { + err := tx.Rollback() + if err != nil && !errors.Is(err, sql.ErrTxDone) { + t.Logf("create database from template: failed to rollback tx: %s\n", err.Error()) + } + }() + // 2137 is an arbitrary number. We just need a lock that is unique to creating + // the template database. + _, err = tx.Exec("SELECT pg_advisory_xact_lock(2137)") + if err != nil { + return xerrors.Errorf("acquire lock: %w", err) + } + + // Someone else might have created the template db while we were waiting. + tplDbExistsRes, err := tx.Query("SELECT 1 FROM pg_database WHERE datname = $1", templateDBName) + if err != nil { + return xerrors.Errorf("check if db exists: %w", err) + } + tplDbAlreadyExists := tplDbExistsRes.Next() + if err := tplDbExistsRes.Close(); err != nil { + return xerrors.Errorf("close tpl db exists res: %w", err) + } + if !tplDbAlreadyExists { + // We will use a temporary template database to avoid race conditions. We will + // rename it to the real template database name after we're sure it was fully + // initialized. + // It's dropped here to ensure that if a previous run of this function failed + // midway, we don't encounter issues with the temporary database still existing. + tmpTemplateDBName := "tmp_" + templateDBName + // We're using db instead of tx here because you can't run `DROP DATABASE` inside + // a transaction. + if _, err := db.Exec("DROP DATABASE IF EXISTS " + tmpTemplateDBName); err != nil { + return xerrors.Errorf("drop tmp template db: %w", err) + } + if _, err := db.Exec("CREATE DATABASE " + tmpTemplateDBName); err != nil { + return xerrors.Errorf("create tmp template db: %w", err) + } + tplDbURL := ConnectionParams{ + Username: connParams.Username, + Password: connParams.Password, + Host: connParams.Host, + Port: connParams.Port, + DBName: tmpTemplateDBName, + }.DSN() + tplDb, err := sql.Open("postgres", tplDbURL) + if err != nil { + return xerrors.Errorf("connect to template db: %w", err) + } + defer func() { + if err := tplDb.Close(); err != nil { + t.Logf("create database from template: failed to close template db: %s\n", err.Error()) } + }() + if err := migrations.Up(tplDb); err != nil { + return xerrors.Errorf("migrate template db: %w", err) } - return dsn, cleanup, nil + if err := tplDb.Close(); err != nil { + return xerrors.Errorf("close template db: %w", err) + } + if _, err := db.Exec("ALTER DATABASE " + tmpTemplateDBName + " RENAME TO " + templateDBName); err != nil { + return xerrors.Errorf("rename tmp template db: %w", err) + } + } + + // Try to create the database again now that a template exists. + if _, err = db.Exec("CREATE DATABASE " + newDBName + " WITH TEMPLATE " + templateDBName); err != nil { + return xerrors.Errorf("create db with template after migrations: %w", err) } - return OpenContainerized(0) + if err = tx.Commit(); err != nil { + return xerrors.Errorf("commit tx: %w", err) + } + return nil } -// OpenContainerized creates a new PostgreSQL server using a Docker container. If port is nonzero, forward host traffic -// to that port to the database. If port is zero, allocate a free port from the OS. -func OpenContainerized(port int) (string, func(), error) { +type DBContainerOptions struct { + Port int + Name string +} + +type container struct { + Resource *dockertest.Resource + Pool *dockertest.Pool + Host string + Port string +} + +// OpenContainer creates a new PostgreSQL server using a Docker container. If port is nonzero, forward host traffic +// to that port to the database. If port is zero, allocate a free port from the OS. +// If name is set, we'll ensure that only one container is started with that name. If it's already running, we'll use that. +// Otherwise, we'll start a new container. +func openContainer(t TBSubset, opts DBContainerOptions) (container, func(), error) { + if opts.Name != "" { + // We only want to start the container once per unique name, + // so we take an inter-process lock to avoid concurrent test runs + // racing with us. + nameHash := sha256.Sum256([]byte(opts.Name)) + nameHashStr := hex.EncodeToString(nameHash[:]) + lock := flock.New(filepath.Join(os.TempDir(), "coder-postgres-container-"+nameHashStr[:8])) + if err := lock.Lock(); err != nil { + return container{}, nil, xerrors.Errorf("lock: %w", err) + } + defer func() { + err := lock.Unlock() + if err != nil { + t.Logf("create database from template: failed to unlock: %s\n", err.Error()) + } + }() + } + pool, err := dockertest.NewPool("") if err != nil { - return "", nil, xerrors.Errorf("create pool: %w", err) + return container{}, nil, xerrors.Errorf("create pool: %w", err) + } + + var resource *dockertest.Resource + var tempDir string + if opts.Name != "" { + // If the container already exists, we'll use it. + resource, _ = pool.ContainerByName(opts.Name) + } + if resource == nil { + tempDir, err = os.MkdirTemp(os.TempDir(), "postgres") + if err != nil { + return container{}, nil, xerrors.Errorf("create tempdir: %w", err) + } + runOptions := dockertest.RunOptions{ + Repository: "gcr.io/coder-dev-1/postgres", + Tag: "13", + Env: []string{ + "POSTGRES_PASSWORD=postgres", + "POSTGRES_USER=postgres", + "POSTGRES_DB=postgres", + // The location for temporary database files! + "PGDATA=/tmp", + "listen_addresses = '*'", + }, + PortBindings: map[docker.Port][]docker.PortBinding{ + "5432/tcp": {{ + // Manually specifying a host IP tells Docker just to use an IPV4 address. + // If we don't do this, we hit a fun bug: + // https://github.com/moby/moby/issues/42442 + // where the ipv4 and ipv6 ports might be _different_ and collide with other running docker containers. + HostIP: "0.0.0.0", + HostPort: strconv.FormatInt(int64(opts.Port), 10), + }}, + }, + Mounts: []string{ + // The postgres image has a VOLUME parameter in it's image. + // If we don't mount at this point, Docker will allocate a + // volume for this directory. + // + // This isn't used anyways, since we override PGDATA. + fmt.Sprintf("%s:/var/lib/postgresql/data", tempDir), + }, + Cmd: []string{"-c", "max_connections=1000"}, + } + if opts.Name != "" { + runOptions.Name = opts.Name + } + resource, err = pool.RunWithOptions(&runOptions, func(config *docker.HostConfig) { + // set AutoRemove to true so that stopped container goes away by itself + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + config.Tmpfs = map[string]string{ + "/tmp": "rw", + } + }) + if err != nil { + return container{}, nil, xerrors.Errorf("could not start resource: %w", err) + } } - tempDir, err := os.MkdirTemp(os.TempDir(), "postgres") + hostAndPort := resource.GetHostPort("5432/tcp") + host, port, err := net.SplitHostPort(hostAndPort) if err != nil { - return "", nil, xerrors.Errorf("create tempdir: %w", err) - } - - resource, err := pool.RunWithOptions(&dockertest.RunOptions{ - Repository: "gcr.io/coder-dev-1/postgres", - Tag: "13", - Env: []string{ - "POSTGRES_PASSWORD=postgres", - "POSTGRES_USER=postgres", - "POSTGRES_DB=postgres", - // The location for temporary database files! - "PGDATA=/tmp", - "listen_addresses = '*'", - }, - PortBindings: map[docker.Port][]docker.PortBinding{ - "5432/tcp": {{ - // Manually specifying a host IP tells Docker just to use an IPV4 address. - // If we don't do this, we hit a fun bug: - // https://github.com/moby/moby/issues/42442 - // where the ipv4 and ipv6 ports might be _different_ and collide with other running docker containers. - HostIP: "0.0.0.0", - HostPort: strconv.FormatInt(int64(port), 10), - }}, - }, - Mounts: []string{ - // The postgres image has a VOLUME parameter in it's image. - // If we don't mount at this point, Docker will allocate a - // volume for this directory. - // - // This isn't used anyways, since we override PGDATA. - fmt.Sprintf("%s:/var/lib/postgresql/data", tempDir), - }, - }, func(config *docker.HostConfig) { - // set AutoRemove to true so that stopped container goes away by itself - config.AutoRemove = true - config.RestartPolicy = docker.RestartPolicy{Name: "no"} - }) + return container{}, nil, xerrors.Errorf("split host and port: %w", err) + } + + for r := retry.New(50*time.Millisecond, 15*time.Second); r.Wait(context.Background()); { + stdout := &strings.Builder{} + stderr := &strings.Builder{} + _, err = resource.Exec([]string{"pg_isready", "-h", "127.0.0.1"}, dockertest.ExecOptions{ + StdOut: stdout, + StdErr: stderr, + }) + if err == nil { + break + } + } if err != nil { - return "", nil, xerrors.Errorf("could not start resource: %w", err) + return container{}, nil, xerrors.Errorf("pg_isready: %w", err) } - hostAndPort := resource.GetHostPort("5432/tcp") - dbURL := fmt.Sprintf("postgres://postgres:postgres@%s/postgres?sslmode=disable", hostAndPort) + return container{ + Host: host, + Port: port, + Resource: resource, + Pool: pool, + }, func() { + _ = pool.Purge(resource) + if tempDir != "" { + _ = os.RemoveAll(tempDir) + } + }, nil +} + +// OpenContainerized creates a new PostgreSQL server using a Docker container. If port is nonzero, forward host traffic +// to that port to the database. If port is zero, allocate a free port from the OS. +// The user is responsible for calling the returned cleanup function. +func OpenContainerized(t TBSubset, opts DBContainerOptions) (string, func(), error) { + container, containerCleanup, err := openContainer(t, opts) + defer func() { + if err != nil { + containerCleanup() + } + }() + if err != nil { + return "", nil, xerrors.Errorf("open container: %w", err) + } + dbURL := ConnectionParams{ + Username: "postgres", + Password: "postgres", + Host: container.Host, + Port: container.Port, + DBName: "postgres", + }.DSN() // Docker should hard-kill the container after 120 seconds. - err = resource.Expire(120) + err = container.Resource.Expire(120) if err != nil { return "", nil, xerrors.Errorf("expire resource: %w", err) } - pool.MaxWait = 120 * time.Second + container.Pool.MaxWait = 120 * time.Second // Record the error that occurs during the retry. // The 'pool' pkg hardcodes a deadline error devoid // of any useful context. var retryErr error - err = pool.Retry(func() error { + err = container.Pool.Retry(func() error { db, err := sql.Open("postgres", dbURL) if err != nil { retryErr = xerrors.Errorf("open postgres: %w", err) @@ -155,8 +519,5 @@ func OpenContainerized(port int) (string, func(), error) { return "", nil, retryErr } - return dbURL, func() { - _ = pool.Purge(resource) - _ = os.RemoveAll(tempDir) - }, nil + return dbURL, containerCleanup, nil } diff --git a/coderd/database/dbtestutil/postgres_test.go b/coderd/database/dbtestutil/postgres_test.go index ec500d824a9ba..9cae9411289ad 100644 --- a/coderd/database/dbtestutil/postgres_test.go +++ b/coderd/database/dbtestutil/postgres_test.go @@ -11,25 +11,19 @@ import ( "go.uber.org/goleak" "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/database/migrations" ) func TestMain(m *testing.M) { goleak.VerifyTestMain(m) } -// nolint:paralleltest -func TestPostgres(t *testing.T) { - // postgres.Open() seems to be creating race conditions when run in parallel. - // t.Parallel() +func TestOpen(t *testing.T) { + t.Parallel() - if testing.Short() { - t.SkipNow() - return - } - - connect, closePg, err := dbtestutil.Open() + connect, err := dbtestutil.Open(t) require.NoError(t, err) - defer closePg() + db, err := sql.Open("postgres", connect) require.NoError(t, err) err = db.Ping() @@ -37,3 +31,74 @@ func TestPostgres(t *testing.T) { err = db.Close() require.NoError(t, err) } + +func TestOpen_InvalidDBFrom(t *testing.T) { + t.Parallel() + + _, err := dbtestutil.Open(t, dbtestutil.WithDBFrom("__invalid__")) + require.Error(t, err) + require.ErrorContains(t, err, "template database") + require.ErrorContains(t, err, "does not exist") +} + +func TestOpen_ValidDBFrom(t *testing.T) { + t.Parallel() + + // first check if we can create a new template db + dsn, err := dbtestutil.Open(t, dbtestutil.WithDBFrom("")) + require.NoError(t, err) + + db, err := sql.Open("postgres", dsn) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, db.Close()) + }) + + err = db.Ping() + require.NoError(t, err) + + templateDBName := "tpl_" + migrations.GetMigrationsHash()[:32] + tplDbExistsRes, err := db.Query("SELECT 1 FROM pg_database WHERE datname = $1", templateDBName) + if err != nil { + require.NoError(t, err) + } + require.True(t, tplDbExistsRes.Next()) + require.NoError(t, tplDbExistsRes.Close()) + + // now populate the db with some data and use it as a new template db + // to verify that dbtestutil.Open respects WithDBFrom + _, err = db.Exec("CREATE TABLE my_wonderful_table (id serial PRIMARY KEY, name text)") + require.NoError(t, err) + _, err = db.Exec("INSERT INTO my_wonderful_table (name) VALUES ('test')") + require.NoError(t, err) + + rows, err := db.Query("SELECT current_database()") + require.NoError(t, err) + require.True(t, rows.Next()) + var freshTemplateDBName string + require.NoError(t, rows.Scan(&freshTemplateDBName)) + require.NoError(t, rows.Close()) + require.NoError(t, db.Close()) + + for i := 0; i < 10; i++ { + db, err := sql.Open("postgres", dsn) + require.NoError(t, err) + require.NoError(t, db.Ping()) + require.NoError(t, db.Close()) + } + + // now create a new db from the template db + newDsn, err := dbtestutil.Open(t, dbtestutil.WithDBFrom(freshTemplateDBName)) + require.NoError(t, err) + + newDb, err := sql.Open("postgres", newDsn) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, newDb.Close()) + }) + + rows, err = newDb.Query("SELECT 1 FROM my_wonderful_table WHERE name = 'test'") + require.NoError(t, err) + require.True(t, rows.Next()) + require.NoError(t, rows.Close()) +} diff --git a/coderd/database/gen/dump/main.go b/coderd/database/gen/dump/main.go index f563e1142619e..e3e80c528144e 100644 --- a/coderd/database/gen/dump/main.go +++ b/coderd/database/gen/dump/main.go @@ -2,6 +2,7 @@ package main import ( "database/sql" + "fmt" "os" "path/filepath" "runtime" @@ -12,12 +13,34 @@ import ( var preamble = []byte("-- Code generated by 'make coderd/database/generate'. DO NOT EDIT.") +type mockTB struct { + cleanup []func() +} + +func (t *mockTB) Cleanup(f func()) { + t.cleanup = append(t.cleanup, f) +} + +func (*mockTB) Helper() { + // noop +} + +func (*mockTB) Logf(format string, args ...any) { + _, _ = fmt.Printf(format, args...) +} + func main() { - connection, closeFn, err := dbtestutil.Open() + t := &mockTB{} + defer func() { + for _, f := range t.cleanup { + f() + } + }() + + connection, err := dbtestutil.Open(t) if err != nil { panic(err) } - defer closeFn() db, err := sql.Open("postgres", connection) if err != nil { diff --git a/coderd/database/migrations/migrate.go b/coderd/database/migrations/migrate.go index 213408bbadd8c..c6c1b5740f873 100644 --- a/coderd/database/migrations/migrate.go +++ b/coderd/database/migrations/migrate.go @@ -2,11 +2,16 @@ package migrations import ( "context" + "crypto/sha256" "database/sql" "embed" "errors" + "fmt" "io/fs" "os" + "sort" + "strings" + "sync" "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/source" @@ -17,6 +22,56 @@ import ( //go:embed *.sql var migrations embed.FS +var ( + migrationsHash string + migrationsHashOnce sync.Once +) + +// A migrations hash is a sha256 hash of the contents and names +// of the migrations sorted by filename. +func calculateMigrationsHash(migrationsFs embed.FS) (string, error) { + files, err := migrationsFs.ReadDir(".") + if err != nil { + return "", xerrors.Errorf("read migrations directory: %w", err) + } + sortedFiles := make([]fs.DirEntry, len(files)) + copy(sortedFiles, files) + sort.Slice(sortedFiles, func(i, j int) bool { + return sortedFiles[i].Name() < sortedFiles[j].Name() + }) + + var builder strings.Builder + for _, file := range sortedFiles { + if _, err := builder.WriteString(file.Name()); err != nil { + return "", xerrors.Errorf("write migration file name %q: %w", file.Name(), err) + } + content, err := migrationsFs.ReadFile(file.Name()) + if err != nil { + return "", xerrors.Errorf("read migration file %q: %w", file.Name(), err) + } + if _, err := builder.Write(content); err != nil { + return "", xerrors.Errorf("write migration file content %q: %w", file.Name(), err) + } + } + + hash := sha256.New() + if _, err := hash.Write([]byte(builder.String())); err != nil { + return "", xerrors.Errorf("write to hash: %w", err) + } + return fmt.Sprintf("%x", hash.Sum(nil)), nil +} + +func GetMigrationsHash() string { + migrationsHashOnce.Do(func() { + hash, err := calculateMigrationsHash(migrations) + if err != nil { + panic(err) + } + migrationsHash = hash + }) + return migrationsHash +} + func setup(db *sql.DB, migs fs.FS) (source.Driver, *migrate.Migrate, error) { if migs == nil { migs = migrations diff --git a/coderd/database/migrations/migrate_test.go b/coderd/database/migrations/migrate_test.go index 51e7fcc86cb03..c64c2436da18d 100644 --- a/coderd/database/migrations/migrate_test.go +++ b/coderd/database/migrations/migrate_test.go @@ -95,9 +95,8 @@ func TestMigrate(t *testing.T) { func testSQLDB(t testing.TB) *sql.DB { t.Helper() - connection, closeFn, err := dbtestutil.Open() + connection, err := dbtestutil.Open(t) require.NoError(t, err) - t.Cleanup(closeFn) db, err := sql.Open("postgres", connection) require.NoError(t, err) diff --git a/coderd/database/pubsub/pubsub_linux_test.go b/coderd/database/pubsub/pubsub_linux_test.go index f208af921b441..819de0a71ba52 100644 --- a/coderd/database/pubsub/pubsub_linux_test.go +++ b/coderd/database/pubsub/pubsub_linux_test.go @@ -40,9 +40,8 @@ func TestPubsub(t *testing.T) { defer cancelFunc() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - connectionURL, closePg, err := dbtestutil.Open() + connectionURL, err := dbtestutil.Open(t) require.NoError(t, err) - defer closePg() db, err := sql.Open("postgres", connectionURL) require.NoError(t, err) defer db.Close() @@ -69,9 +68,8 @@ func TestPubsub(t *testing.T) { ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - connectionURL, closePg, err := dbtestutil.Open() + connectionURL, err := dbtestutil.Open(t) require.NoError(t, err) - defer closePg() db, err := sql.Open("postgres", connectionURL) require.NoError(t, err) defer db.Close() @@ -85,9 +83,8 @@ func TestPubsub(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - connectionURL, closePg, err := dbtestutil.Open() + connectionURL, err := dbtestutil.Open(t) require.NoError(t, err) - defer closePg() db, err := sql.Open("postgres", connectionURL) require.NoError(t, err) defer db.Close() @@ -122,9 +119,8 @@ func TestPubsub_ordering(t *testing.T) { defer cancelFunc() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - connectionURL, closePg, err := dbtestutil.Open() + connectionURL, err := dbtestutil.Open(t) require.NoError(t, err) - defer closePg() db, err := sql.Open("postgres", connectionURL) require.NoError(t, err) defer db.Close() @@ -167,7 +163,7 @@ const disconnectTestPort = 26892 func TestPubsub_Disconnect(t *testing.T) { // we always use a Docker container for this test, even in CI, since we need to be able to kill // postgres and bring it back on the same port. - connectionURL, closePg, err := dbtestutil.OpenContainerized(disconnectTestPort) + connectionURL, closePg, err := dbtestutil.OpenContainerized(t, dbtestutil.DBContainerOptions{Port: disconnectTestPort}) require.NoError(t, err) defer closePg() db, err := sql.Open("postgres", connectionURL) @@ -238,7 +234,7 @@ func TestPubsub_Disconnect(t *testing.T) { // restart postgres on the same port --- since we only use LISTEN/NOTIFY it doesn't // matter that the new postgres doesn't have any persisted state from before. - _, closeNewPg, err := dbtestutil.OpenContainerized(disconnectTestPort) + _, closeNewPg, err := dbtestutil.OpenContainerized(t, dbtestutil.DBContainerOptions{Port: disconnectTestPort}) require.NoError(t, err) defer closeNewPg() @@ -305,7 +301,7 @@ func TestMeasureLatency(t *testing.T) { newPubsub := func() (pubsub.Pubsub, func()) { ctx, cancel := context.WithCancel(context.Background()) logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - connectionURL, closePg, err := dbtestutil.Open() + connectionURL, err := dbtestutil.Open(t) require.NoError(t, err) db, err := sql.Open("postgres", connectionURL) require.NoError(t, err) @@ -315,7 +311,6 @@ func TestMeasureLatency(t *testing.T) { return ps, func() { _ = ps.Close() _ = db.Close() - closePg() cancel() } } diff --git a/coderd/database/pubsub/pubsub_test.go b/coderd/database/pubsub/pubsub_test.go index 21b4b1d54c171..6b8181ea7d834 100644 --- a/coderd/database/pubsub/pubsub_test.go +++ b/coderd/database/pubsub/pubsub_test.go @@ -24,9 +24,8 @@ func TestPGPubsub_Metrics(t *testing.T) { } logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - connectionURL, closePg, err := dbtestutil.Open() + connectionURL, err := dbtestutil.Open(t) require.NoError(t, err) - defer closePg() db, err := sql.Open("postgres", connectionURL) require.NoError(t, err) defer db.Close() @@ -132,9 +131,8 @@ func TestPGPubsubDriver(t *testing.T) { IgnoreErrors: true, }).Leveled(slog.LevelDebug) - connectionURL, closePg, err := dbtestutil.Open() + connectionURL, err := dbtestutil.Open(t) require.NoError(t, err) - defer closePg() // use a separate subber and pubber so we can keep track of listener connections db, err := sql.Open("postgres", connectionURL) diff --git a/enterprise/cli/server_dbcrypt_test.go b/enterprise/cli/server_dbcrypt_test.go index b1767889d9c33..070f172bcbe7b 100644 --- a/enterprise/cli/server_dbcrypt_test.go +++ b/enterprise/cli/server_dbcrypt_test.go @@ -32,9 +32,8 @@ func TestServerDBCrypt(t *testing.T) { t.Cleanup(cancel) // Setup a postgres database. - connectionURL, closePg, err := dbtestutil.Open() + connectionURL, err := dbtestutil.Open(t) require.NoError(t, err) - t.Cleanup(closePg) t.Cleanup(func() { dbtestutil.DumpOnFailure(t, connectionURL) }) sqlDB, err := sql.Open("postgres", connectionURL) diff --git a/enterprise/tailnet/pgcoord_test.go b/enterprise/tailnet/pgcoord_test.go index c0d122aa74992..49248e636f04b 100644 --- a/enterprise/tailnet/pgcoord_test.go +++ b/enterprise/tailnet/pgcoord_test.go @@ -798,9 +798,8 @@ func TestPGCoordinatorDual_FailedHeartbeat(t *testing.T) { t.Skip("test only with postgres") } - dburl, closeFn, err := dbtestutil.Open() + dburl, err := dbtestutil.Open(t) require.NoError(t, err) - t.Cleanup(closeFn) store1, ps1, sdb1 := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithURL(dburl)) defer sdb1.Close()