1
1
package dbtestutil
2
2
3
3
import (
4
+ "bytes"
4
5
"context"
5
6
"database/sql"
6
7
"fmt"
7
8
"net/url"
8
9
"os"
10
+ "os/exec"
11
+ "path/filepath"
12
+ "regexp"
9
13
"strings"
10
14
"testing"
11
15
12
16
"github.com/stretchr/testify/require"
17
+ "golang.org/x/xerrors"
13
18
14
19
"github.com/coder/coder/v2/coderd/database"
15
20
"github.com/coder/coder/v2/coderd/database/dbfake"
@@ -24,6 +29,7 @@ func WillUsePostgres() bool {
24
29
25
30
type options struct {
26
31
fixedTimezone string
32
+ dumpOnFailure bool
27
33
}
28
34
29
35
type Option func (* options )
@@ -35,6 +41,13 @@ func WithTimezone(tz string) Option {
35
41
}
36
42
}
37
43
44
+ // WithDumpOnFailure will dump the entire database on test failure.
45
+ func WithDumpOnFailure () Option {
46
+ return func (o * options ) {
47
+ o .dumpOnFailure = true
48
+ }
49
+ }
50
+
38
51
func NewDB (t testing.TB , opts ... Option ) (database.Store , pubsub.Pubsub ) {
39
52
t .Helper ()
40
53
@@ -74,6 +87,9 @@ func NewDB(t testing.TB, opts ...Option) (database.Store, pubsub.Pubsub) {
74
87
t .Cleanup (func () {
75
88
_ = sqlDB .Close ()
76
89
})
90
+ if o .dumpOnFailure {
91
+ t .Cleanup (func () { DumpOnFailure (t , connectionURL ) })
92
+ }
77
93
db = database .New (sqlDB )
78
94
79
95
ps , err = pubsub .New (context .Background (), sqlDB , connectionURL )
@@ -110,3 +126,85 @@ func dbNameFromConnectionURL(t testing.TB, connectionURL string) string {
110
126
require .NoError (t , err )
111
127
return strings .TrimPrefix (u .Path , "/" )
112
128
}
129
+
130
+ func DumpOnFailure (t testing.TB , connectionURL string ) {
131
+ if ! t .Failed () {
132
+ return
133
+ }
134
+ cwd , err := filepath .Abs ("." )
135
+ if err != nil {
136
+ t .Errorf ("dump on failure: cannot determine current working directory" )
137
+ return
138
+ }
139
+ snakeCaseName := regexp .MustCompile ("[^a-zA-Z0-9-_]+" ).ReplaceAllString (t .Name (), "_" )
140
+ outPath := filepath .Join (cwd , snakeCaseName + ".test.sql" )
141
+ dump , err := pgDump (connectionURL )
142
+ if err != nil {
143
+ t .Errorf ("dump on failure: failed to run pg_dump" )
144
+ return
145
+ }
146
+ if err := os .WriteFile (outPath , filterDump (dump ), 0o600 ); err != nil {
147
+ t .Errorf ("dump on failure: failed to write: %w" , err )
148
+ return
149
+ }
150
+ t .Logf ("Dumped database to %q due to failed test. I hope you find what you're looking for!" , outPath )
151
+ }
152
+
153
+ // pgDump runs pg_dump against dbURL and returns the output.
154
+ func pgDump (dbURL string ) ([]byte , error ) {
155
+ if _ , err := exec .LookPath ("pg_dump" ); err != nil {
156
+ return nil , xerrors .Errorf ("could not find pg_dump in path: %w" , err )
157
+ }
158
+ cmdArgs := []string {
159
+ "pg_dump" ,
160
+ dbURL ,
161
+ "--data-only" ,
162
+ "--no-comments" ,
163
+ "--no-privileges" ,
164
+ "--no-publication" ,
165
+ "--no-security-labels" ,
166
+ "--no-subscriptions" ,
167
+ "--no-tablespaces" ,
168
+ // "--no-unlogged-table-data", // some tables are unlogged and may contain data of interest
169
+ "--no-owner" ,
170
+ "--exclude-table=schema_migrations" ,
171
+ }
172
+ cmd := exec .Command (cmdArgs [0 ], cmdArgs [1 :]... ) // nolint:gosec
173
+ cmd .Env = []string {
174
+ // "PGTZ=UTC", // This is probably not going to be useful if tz has been changed.
175
+ "PGCLIENTENCODINDG=UTF8" ,
176
+ "PGDATABASE=" , // we should always specify the database name in the connection string
177
+ }
178
+ var stdout bytes.Buffer
179
+ cmd .Stdout = & stdout
180
+ if err := cmd .Run (); err != nil {
181
+ return nil , xerrors .Errorf ("exec pg_dump: %w" , err )
182
+ }
183
+ return stdout .Bytes (), nil
184
+ }
185
+
186
+ func filterDump (dump []byte ) []byte {
187
+ lines := bytes .Split (dump , []byte {'\n' })
188
+ var buf bytes.Buffer
189
+ for _ , line := range lines {
190
+ // Skip blank lines
191
+ if len (line ) == 0 {
192
+ continue
193
+ }
194
+ // Skip comments
195
+ if bytes .HasPrefix (line , []byte ("--" )) {
196
+ continue
197
+ }
198
+ // Skip SELECT or SET statements
199
+ if bytes .HasPrefix (line , []byte ("SELECT" )) {
200
+ continue
201
+ }
202
+ if bytes .HasPrefix (line , []byte ("SET" )) {
203
+ continue
204
+ }
205
+
206
+ buf .Write (line )
207
+ buf .WriteRune ('\n' )
208
+ }
209
+ return buf .Bytes ()
210
+ }
0 commit comments