Skip to content

Commit 1600abf

Browse files
committed
fix tests
1 parent c18963b commit 1600abf

File tree

2 files changed

+57
-116
lines changed

2 files changed

+57
-116
lines changed

coderd/database/dbtestutil/driver.go

Lines changed: 31 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -2,96 +2,19 @@ package dbtestutil
22

33
import (
44
"context"
5-
"database/sql"
5+
66
"database/sql/driver"
7-
"fmt"
87

98
"github.com/lib/pq"
109
"golang.org/x/xerrors"
1110

1211
"github.com/coder/coder/v2/coderd/database"
13-
"github.com/coder/coder/v2/cryptorand"
1412
)
1513

1614
var (
17-
_ driver.Driver = &Driver{}
18-
_ database.ConnectorCreator = &Driver{}
19-
_ database.DialerConnector = &Connector{}
15+
_ database.DialerConnector = &Connector{}
2016
)
2117

22-
type Driver struct {
23-
name string
24-
inner driver.Driver
25-
connections []driver.Conn
26-
listeners map[chan struct{}]chan struct{}
27-
}
28-
29-
func Register() (*Driver, error) {
30-
db, err := sql.Open("postgres", "")
31-
if err != nil {
32-
return nil, xerrors.Errorf("failed to open database: %w", err)
33-
}
34-
35-
su, err := cryptorand.StringCharset(cryptorand.Alpha, 10)
36-
if err != nil {
37-
return nil, xerrors.Errorf("failed to generate random string: %w", err)
38-
}
39-
40-
d := &Driver{
41-
name: fmt.Sprintf("postgres-test-%s", su),
42-
inner: db.Driver(),
43-
listeners: make(map[chan struct{}]chan struct{}),
44-
}
45-
46-
sql.Register(d.name, d)
47-
48-
return d, nil
49-
}
50-
51-
func (d *Driver) Open(name string) (driver.Conn, error) {
52-
conn, err := d.inner.Open(name)
53-
if err != nil {
54-
return nil, xerrors.Errorf("failed to open connection: %w", err)
55-
}
56-
57-
d.AddConnection(conn)
58-
59-
return conn, nil
60-
}
61-
62-
func (d *Driver) Connector(name string) (driver.Connector, error) {
63-
return &Connector{
64-
name: name,
65-
driver: d,
66-
}, nil
67-
}
68-
69-
func (d *Driver) Name() string {
70-
return d.name
71-
}
72-
73-
func (d *Driver) AddConnection(conn driver.Conn) {
74-
d.connections = append(d.connections, conn)
75-
for listener := range d.listeners {
76-
d.listeners[listener] <- struct{}{}
77-
}
78-
}
79-
80-
func (d *Driver) WaitForConnection() {
81-
ch := make(chan struct{})
82-
defer close(ch)
83-
defer delete(d.listeners, ch)
84-
d.listeners[ch] = ch
85-
<-ch
86-
}
87-
88-
func (d *Driver) DropConnections() {
89-
for _, conn := range d.connections {
90-
_ = conn.Close()
91-
}
92-
d.connections = nil
93-
}
94-
9518
type Connector struct {
9619
name string
9720
driver *Driver
@@ -105,7 +28,7 @@ func (c *Connector) Connect(_ context.Context) (driver.Conn, error) {
10528
return nil, xerrors.Errorf("failed to dial open connection: %w", err)
10629
}
10730

108-
c.driver.AddConnection(conn)
31+
c.driver.Connections <- conn
10932

11033
return conn, nil
11134
}
@@ -115,6 +38,8 @@ func (c *Connector) Connect(_ context.Context) (driver.Conn, error) {
11538
return nil, xerrors.Errorf("failed to open connection: %w", err)
11639
}
11740

41+
c.driver.Connections <- conn
42+
11843
return conn, nil
11944
}
12045

@@ -125,3 +50,29 @@ func (c *Connector) Driver() driver.Driver {
12550
func (c *Connector) Dialer(dialer pq.Dialer) {
12651
c.dialer = dialer
12752
}
53+
54+
type Driver struct {
55+
Connections chan driver.Conn
56+
}
57+
58+
func NewDriver() *Driver {
59+
return &Driver{
60+
Connections: make(chan driver.Conn, 1),
61+
}
62+
}
63+
64+
func (d *Driver) Connector(name string) (driver.Connector, error) {
65+
return &Connector{
66+
name: name,
67+
driver: d,
68+
}, nil
69+
}
70+
71+
func (d *Driver) Open(name string) (driver.Conn, error) {
72+
c, err := d.Connector(name)
73+
if err != nil {
74+
return nil, err
75+
}
76+
77+
return c.Connect(context.Background())
78+
}

coderd/database/pubsub/pubsub_test.go

Lines changed: 26 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"context"
55
"database/sql"
66
"testing"
7-
"time"
87

98
"github.com/prometheus/client_golang/prometheus"
109
"github.com/stretchr/testify/assert"
@@ -136,63 +135,54 @@ func TestPGPubsubDriver(t *testing.T) {
136135
require.NoError(t, err)
137136
defer closePg()
138137

139-
// wrap the pg driver with one we can control
140-
d, err := dbtestutil.Register()
138+
// use a separate subber and pubber so we can keep track of listener connections
139+
db, err := sql.Open("postgres", connectionURL)
141140
require.NoError(t, err)
142-
143-
db, err := sql.Open(d.Name(), connectionURL)
141+
pubber, err := pubsub.New(ctx, logger, db, connectionURL)
144142
require.NoError(t, err)
145-
defer db.Close()
146143

147-
ps, err := pubsub.New(ctx, logger, db, connectionURL)
144+
// use a connector that sends us the connections for the subber
145+
subDriver := dbtestutil.NewDriver()
146+
tconn, err := subDriver.Connector(connectionURL)
148147
require.NoError(t, err)
149-
defer ps.Close()
148+
tcdb := sql.OpenDB(tconn)
149+
subber, err := pubsub.New(ctx, logger, tcdb, connectionURL)
150+
require.NoError(t, err)
151+
defer subber.Close()
150152

151153
// test that we can publish and subscribe
152-
gotChan := make(chan struct{})
154+
gotChan := make(chan struct{}, 1)
153155
defer close(gotChan)
154-
subCancel, err := ps.Subscribe("test", func(_ context.Context, _ []byte) {
156+
subCancel, err := subber.Subscribe("test", func(_ context.Context, _ []byte) {
155157
gotChan <- struct{}{}
156158
})
157159
require.NoError(t, err)
158160
defer subCancel()
159161

160-
err = ps.Publish("test", []byte("hello"))
162+
t.Log("publishing message")
163+
// send a message
164+
err = pubber.Publish("test", []byte("hello"))
161165
require.NoError(t, err)
162166

163-
select {
164-
case <-gotChan:
165-
case <-ctx.Done():
166-
t.Fatal("timeout waiting for message")
167-
}
167+
// wait for the message
168+
_ = testutil.RequireRecvCtx(ctx, t, gotChan)
168169

169-
reconnectChan := make(chan struct{})
170-
go func() {
171-
d.WaitForConnection()
172-
// wait a bit to make sure the pubsub has reestablished it's connection
173-
// if we don't wait, the publish may be dropped because the pubsub hasn't initialized yet.
174-
time.Sleep(1 * time.Second)
175-
reconnectChan <- struct{}{}
176-
}()
170+
// read out first connection
171+
firstConn := testutil.RequireRecvCtx(ctx, t, subDriver.Connections)
177172

178173
// drop the underlying connection being used by the pubsub
179174
// the pq.Listener should reconnect and repopulate it's listeners
180175
// so old subscriptions should still work
181-
d.DropConnections()
176+
err = firstConn.Close()
177+
require.NoError(t, err)
182178

183-
select {
184-
case <-reconnectChan:
185-
case <-ctx.Done():
186-
t.Fatal("timeout waiting for reconnect")
187-
}
179+
// wait for the reconnect
180+
_ = testutil.RequireRecvCtx(ctx, t, subDriver.Connections)
188181

189182
// ensure our old subscription still fires
190-
err = ps.Publish("test", []byte("hello-again"))
183+
err = pubber.Publish("test", []byte("hello-again"))
191184
require.NoError(t, err)
192185

193-
select {
194-
case <-gotChan:
195-
case <-ctx.Done():
196-
t.Fatal("timeout waiting for message after reconnect")
197-
}
186+
// wait for the message on the old subscription
187+
_ = testutil.RequireRecvCtx(ctx, t, gotChan)
198188
}

0 commit comments

Comments
 (0)