diff --git a/coderd/database/dbtestutil/db.go b/coderd/database/dbtestutil/db.go index a32d7b31245f0..65e1afecc7948 100644 --- a/coderd/database/dbtestutil/db.go +++ b/coderd/database/dbtestutil/db.go @@ -31,6 +31,7 @@ func WillUsePostgres() bool { type options struct { fixedTimezone string dumpOnFailure bool + returnSQLDB func(*sql.DB) } type Option func(*options) @@ -49,6 +50,27 @@ func WithDumpOnFailure() Option { } } +func withReturnSQLDB(f func(*sql.DB)) Option { + return func(o *options) { + o.returnSQLDB = f + } +} + +func NewDBWithSQLDB(t testing.TB, opts ...Option) (database.Store, pubsub.Pubsub, *sql.DB) { + t.Helper() + + if !WillUsePostgres() { + t.Fatal("cannot use NewDBWithSQLDB without PostgreSQL, consider adding `if !dbtestutil.WillUsePostgres() { t.Skip() }` to this test") + } + + var sqlDB *sql.DB + opts = append(opts, withReturnSQLDB(func(db *sql.DB) { + sqlDB = db + })) + db, ps := NewDB(t, opts...) + return db, ps, sqlDB +} + func NewDB(t testing.TB, opts ...Option) (database.Store, pubsub.Pubsub) { t.Helper() @@ -88,6 +110,9 @@ func NewDB(t testing.TB, opts ...Option) (database.Store, pubsub.Pubsub) { t.Cleanup(func() { _ = sqlDB.Close() }) + if o.returnSQLDB != nil { + o.returnSQLDB(sqlDB) + } if o.dumpOnFailure { t.Cleanup(func() { DumpOnFailure(t, connectionURL) }) }