Skip to content

Commit dae69c7

Browse files
committed
add customer driver tests
1 parent 85f6cfe commit dae69c7

File tree

4 files changed

+214
-6
lines changed

4 files changed

+214
-6
lines changed

coderd/database/connector.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
package database
22

33
import (
4-
"context"
54
"database/sql/driver"
65

76
"github.com/lib/pq"
87
)
98

10-
// ConnectorCreator can create a driver.Connector.
9+
// ConnectorCreator is a driver.Driver that can create a driver.Connector.
1110
type ConnectorCreator interface {
11+
driver.Driver
1212
Connector(name string) (driver.Connector, error)
1313
}
1414

15-
// DialerConnector can create a driver.Connector and set a pq.Dialer.
15+
// DialerConnector is a driver.Connector that can set a pq.Dialer.
1616
type DialerConnector interface {
17-
Connect(context.Context) (driver.Conn, error)
17+
driver.Connector
1818
Dialer(dialer pq.Dialer)
1919
}

coderd/database/dbtestutil/driver.go

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
package dbtestutil
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"database/sql/driver"
7+
"fmt"
8+
9+
"github.com/lib/pq"
10+
"golang.org/x/xerrors"
11+
12+
"github.com/coder/coder/v2/coderd/database"
13+
"github.com/coder/coder/v2/cryptorand"
14+
)
15+
16+
var (
17+
_ driver.Driver = &Driver{}
18+
_ database.ConnectorCreator = &Driver{}
19+
_ database.DialerConnector = &Connector{}
20+
)
21+
22+
type Driver struct {
23+
name string
24+
inner driver.Driver
25+
connections []driver.Conn
26+
listeners map[chan struct{}]chan struct{}
27+
}
28+
29+
func Register() (*Driver, error) {
30+
db, err := sql.Open("postgres", "")
31+
if err != nil {
32+
return nil, xerrors.Errorf("failed to open database: %w", err)
33+
}
34+
35+
su, err := cryptorand.StringCharset(cryptorand.Alpha, 10)
36+
if err != nil {
37+
return nil, xerrors.Errorf("failed to generate random string: %w", err)
38+
}
39+
40+
d := &Driver{
41+
name: fmt.Sprintf("postgres-test-%s", su),
42+
inner: db.Driver(),
43+
listeners: make(map[chan struct{}]chan struct{}),
44+
}
45+
46+
sql.Register(d.name, d)
47+
48+
return d, nil
49+
}
50+
51+
func (d *Driver) Open(name string) (driver.Conn, error) {
52+
conn, err := d.inner.Open(name)
53+
if err != nil {
54+
return nil, xerrors.Errorf("failed to open connection: %w", err)
55+
}
56+
57+
d.AddConnection(conn)
58+
59+
return conn, nil
60+
}
61+
62+
func (d *Driver) Connector(name string) (driver.Connector, error) {
63+
return &Connector{
64+
name: name,
65+
driver: d,
66+
}, nil
67+
}
68+
69+
func (d *Driver) Name() string {
70+
return d.name
71+
}
72+
73+
func (d *Driver) AddConnection(conn driver.Conn) {
74+
d.connections = append(d.connections, conn)
75+
76+
for listener := range d.listeners {
77+
d.listeners[listener] <- struct{}{}
78+
}
79+
80+
}
81+
82+
func (d *Driver) WaitForConnection() {
83+
ch := make(chan struct{})
84+
d.listeners[ch] = ch
85+
<-ch
86+
delete(d.listeners, ch)
87+
}
88+
89+
func (d *Driver) DropConnections() {
90+
for _, conn := range d.connections {
91+
_ = conn.Close()
92+
}
93+
d.connections = nil
94+
}
95+
96+
type Connector struct {
97+
name string
98+
driver *Driver
99+
dialer pq.Dialer
100+
}
101+
102+
func (c *Connector) Connect(_ context.Context) (driver.Conn, error) {
103+
if c.dialer != nil {
104+
conn, err := pq.DialOpen(c.dialer, c.name)
105+
if err != nil {
106+
return nil, xerrors.Errorf("failed to dial open connection: %w", err)
107+
}
108+
109+
c.driver.AddConnection(conn)
110+
111+
return conn, nil
112+
}
113+
114+
conn, err := c.driver.Open(c.name)
115+
if err != nil {
116+
return nil, xerrors.Errorf("failed to open connection: %w", err)
117+
}
118+
119+
return conn, nil
120+
}
121+
122+
func (c *Connector) Driver() driver.Driver {
123+
return c.driver
124+
}
125+
126+
func (c *Connector) Dialer(dialer pq.Dialer) {
127+
c.dialer = dialer
128+
}

coderd/database/pubsub/pubsub.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,10 @@ func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error {
455455
}
456456

457457
// Set the dialer if the connector supports it.
458-
if dc, ok := connector.(database.DialerConnector); ok {
458+
dc, ok := connector.(database.DialerConnector)
459+
if !ok {
460+
p.logger.Critical(ctx, "connector does not support setting log dialer, database connection debug logs will be missing")
461+
} else {
459462
dc.Dialer(dialer)
460463
}
461464

coderd/database/pubsub/pubsub_test.go

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"database/sql"
66
"testing"
7+
"time"
78

89
"github.com/prometheus/client_golang/prometheus"
910
"github.com/stretchr/testify/assert"
@@ -86,7 +87,7 @@ func TestPGPubsub_Metrics(t *testing.T) {
8687
for i := range colossalData {
8788
colossalData[i] = 'q'
8889
}
89-
unsub1, err := uut.Subscribe(event, func(ctx context.Context, message []byte) {
90+
unsub1, err := uut.Subscribe(event, func(_ context.Context, message []byte) {
9091
messageChannel <- message
9192
})
9293
require.NoError(t, err)
@@ -119,3 +120,79 @@ func TestPGPubsub_Metrics(t *testing.T) {
119120
!testutil.PromCounterGathered(t, metrics, "coder_pubsub_latency_measure_errs_total")
120121
}, testutil.WaitShort, testutil.IntervalFast)
121122
}
123+
124+
func TestPGPubsubDriver(t *testing.T) {
125+
t.Parallel()
126+
if !dbtestutil.WillUsePostgres() {
127+
t.Skip("test only with postgres")
128+
}
129+
130+
ctx := testutil.Context(t, testutil.WaitLong)
131+
logger := slogtest.Make(t, &slogtest.Options{
132+
IgnoreErrors: true,
133+
}).Leveled(slog.LevelDebug)
134+
135+
connectionURL, closePg, err := dbtestutil.Open()
136+
require.NoError(t, err)
137+
defer closePg()
138+
139+
// wrap the pg driver with one we can control
140+
d, err := dbtestutil.Register()
141+
require.NoError(t, err)
142+
143+
db, err := sql.Open(d.Name(), connectionURL)
144+
require.NoError(t, err)
145+
defer db.Close()
146+
147+
ps, err := pubsub.New(ctx, logger, db, connectionURL)
148+
require.NoError(t, err)
149+
defer ps.Close()
150+
151+
// test that we can publish and subscribe
152+
gotChan := make(chan struct{})
153+
defer close(gotChan)
154+
subCancel, err := ps.Subscribe("test", func(_ context.Context, _ []byte) {
155+
gotChan <- struct{}{}
156+
})
157+
require.NoError(t, err)
158+
defer subCancel()
159+
160+
err = ps.Publish("test", []byte("hello"))
161+
require.NoError(t, err)
162+
163+
select {
164+
case <-gotChan:
165+
case <-ctx.Done():
166+
t.Fatal("timeout waiting for message")
167+
}
168+
169+
reconnectChan := make(chan struct{})
170+
go func() {
171+
d.WaitForConnection()
172+
// wait a bit to make sure the pubsub has reestablished it's connection
173+
// if we don't wait, the publish may be dropped because the pubsub hasn't initialized yet.
174+
time.Sleep(1 * time.Second)
175+
reconnectChan <- struct{}{}
176+
}()
177+
178+
// drop the underlying connection being used by the pubsub
179+
// the pq.Listener should reconnect and repopulate it's listeners
180+
// so old subscriptions should still work
181+
d.DropConnections()
182+
183+
select {
184+
case <-reconnectChan:
185+
case <-ctx.Done():
186+
t.Fatal("timeout waiting for reconnect")
187+
}
188+
189+
// ensure our old subscription still fires
190+
err = ps.Publish("test", []byte("hello-again"))
191+
require.NoError(t, err)
192+
193+
select {
194+
case <-gotChan:
195+
case <-ctx.Done():
196+
t.Fatal("timeout waiting for message after reconnect")
197+
}
198+
}

0 commit comments

Comments
 (0)