fix: use authenticated urls for pubsub #14261

Merged 25 commits on Aug 26, 2024
add customer driver tests
f0ssel committed Aug 21, 2024
commit dae69c715d750a4c1b9bb7bb421582bd52f4f419
8 changes: 4 additions & 4 deletions coderd/database/connector.go
@@ -1,19 +1,19 @@
package database

import (
"context"
"database/sql/driver"

"github.com/lib/pq"
)

// ConnectorCreator can create a driver.Connector.
// ConnectorCreator is a driver.Driver that can create a driver.Connector.
type ConnectorCreator interface {
driver.Driver
Connector(name string) (driver.Connector, error)
}

// DialerConnector can create a driver.Connector and set a pq.Dialer.
// DialerConnector is a driver.Connector that can set a pq.Dialer.
type DialerConnector interface {
Contributor

👍

Connect(context.Context) (driver.Conn, error)
driver.Connector
Dialer(dialer pq.Dialer)
}
128 changes: 128 additions & 0 deletions coderd/database/dbtestutil/driver.go
@@ -0,0 +1,128 @@
package dbtestutil

import (
"context"
"database/sql"
"database/sql/driver"
"fmt"

"github.com/lib/pq"
"golang.org/x/xerrors"

"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/cryptorand"
)

var (
_ driver.Driver = &Driver{}
_ database.ConnectorCreator = &Driver{}
_ database.DialerConnector = &Connector{}
)

type Driver struct {
name string
inner driver.Driver
connections []driver.Conn
listeners map[chan struct{}]chan struct{}
Contributor

This connections and listeners business is not threadsafe, and so can cause races in our tests.

A better design would be to just have a chan driver.Conn that we pass each new connection thru. Tests will have to know how many connections they expect, and can read from the channel. For the PubSub, it's the listener we want to test, which should use a connection when it first connects, then another one when it reconnects.

That way you don't need methods like AddConnection, WaitForConnection, or DropConnections -- the test code reads from the channel to wait for the connection, and can directly close the connection when it wants to interrupt it.

There is some complexity around the sql.DB, which we use for publishing and has a pool of connections. I suggest you sidestep that complexity by just using a second PubSub for publishing with a regular pq driver.
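
A minimal sketch of the channel-based design suggested above, not the code from this PR. `ChanConnector`, `fakeConn`, `fakeConnector`, and `connectAndDrop` are hypothetical names; the fakes stand in for a real pq connection so the example runs without a database. Each new connection is handed to the test over a channel, so there is no shared slice or map to race on.

```go
package main

import (
	"context"
	"database/sql/driver"
	"errors"
	"fmt"
)

// fakeConn is a stand-in for a real driver.Conn (for example one from lib/pq).
type fakeConn struct{ closed bool }

func (*fakeConn) Prepare(string) (driver.Stmt, error) { return nil, errors.New("not implemented") }
func (c *fakeConn) Close() error                      { c.closed = true; return nil }
func (*fakeConn) Begin() (driver.Tx, error)           { return nil, errors.New("not implemented") }

// fakeConnector is a stand-in for the real connector being wrapped.
type fakeConnector struct{}

func (fakeConnector) Connect(context.Context) (driver.Conn, error) { return &fakeConn{}, nil }
func (fakeConnector) Driver() driver.Driver                        { return nil }

// ChanConnector (hypothetical) wraps a driver.Connector and passes every new
// connection through Conns. The channel is the only shared state.
type ChanConnector struct {
	inner driver.Connector
	Conns chan driver.Conn
}

func (c *ChanConnector) Connect(ctx context.Context) (driver.Conn, error) {
	conn, err := c.inner.Connect(ctx)
	if err != nil {
		return nil, err
	}
	select {
	case c.Conns <- conn:
		return conn, nil
	case <-ctx.Done():
		// Nobody read the connection before the context ended; clean up.
		_ = conn.Close()
		return nil, ctx.Err()
	}
}

func (c *ChanConnector) Driver() driver.Driver { return c.inner.Driver() }

// connectAndDrop opens n connections, then plays the role of the test: it
// reads each one from the channel and closes it directly.
func connectAndDrop(n int) (int, error) {
	c := &ChanConnector{inner: fakeConnector{}, Conns: make(chan driver.Conn, n)}
	ctx := context.Background()
	for i := 0; i < n; i++ {
		if _, err := c.Connect(ctx); err != nil {
			return 0, err
		}
	}
	closed := 0
	for i := 0; i < n; i++ {
		conn := <-c.Conns
		if err := conn.Close(); err != nil {
			return closed, err
		}
		closed++
	}
	return closed, nil
}

func main() {
	closed, err := connectAndDrop(2)
	fmt.Println(closed, err)
}
```

With this shape the test knows how many connections to expect (one at first connect, one more after a reconnect) and needs no `AddConnection`, `WaitForConnection`, or `DropConnections` helpers.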

}

func Register() (*Driver, error) {
db, err := sql.Open("postgres", "")
Contributor

This is a roundabout way to get a pq driver, which is just pq.Driver{}

if err != nil {
return nil, xerrors.Errorf("failed to open database: %w", err)
}

su, err := cryptorand.StringCharset(cryptorand.Alpha, 10)
if err != nil {
return nil, xerrors.Errorf("failed to generate random string: %w", err)
}

d := &Driver{
name: fmt.Sprintf("postgres-test-%s", su),
inner: db.Driver(),
listeners: make(map[chan struct{}]chan struct{}),
}

sql.Register(d.name, d)
Contributor

It seems like you're doing this so that you can get a sql.DB to pass to the PubSub with this instrumented driver. But, registering it and using sql.Open() with the name is more complex than it needs to be.

If you allow test code to directly instantiate the instrumented connector, then you can get a sql.DB as:

db := sql.OpenDB(connector)

That avoids registering the driver and worrying about a unique name.
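
As a rough illustration of that point: the standard library function that builds a `*sql.DB` from a `driver.Connector` is `sql.OpenDB`, and it needs no `sql.Register` call or driver name. The sketch below uses hypothetical stubs (`stubConn`, `stubDriver`, `countingConnector`, `openWithoutRegistering`) in place of the instrumented connector so that it runs without Postgres.

```go
package main

import (
	"context"
	"database/sql"
	"database/sql/driver"
	"errors"
	"fmt"
	"sync/atomic"
)

// stubConn is a placeholder connection; a real test would return a pq connection.
type stubConn struct{}

func (stubConn) Prepare(string) (driver.Stmt, error) { return nil, errors.New("not implemented") }
func (stubConn) Close() error                        { return nil }
func (stubConn) Begin() (driver.Tx, error)           { return nil, errors.New("not implemented") }

// stubDriver exists only so Connector.Driver has something to return.
type stubDriver struct{}

func (stubDriver) Open(string) (driver.Conn, error) { return nil, errors.New("use the connector") }

// countingConnector (hypothetical) counts how often database/sql asks for a connection.
type countingConnector struct{ connects atomic.Int64 }

func (c *countingConnector) Connect(context.Context) (driver.Conn, error) {
	c.connects.Add(1)
	return stubConn{}, nil
}

func (*countingConnector) Driver() driver.Driver { return stubDriver{} }

// openWithoutRegistering builds a *sql.DB straight from the connector and
// pings it once, returning the number of Connect calls that resulted.
func openWithoutRegistering() (int64, error) {
	c := &countingConnector{}
	db := sql.OpenDB(c) // no sql.Register, no unique driver name
	defer db.Close()
	if err := db.PingContext(context.Background()); err != nil {
		return 0, err
	}
	return c.connects.Load(), nil
}

func main() {
	n, err := openWithoutRegistering()
	fmt.Println(n, err)
}
```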


return d, nil
}

func (d *Driver) Open(name string) (driver.Conn, error) {
conn, err := d.inner.Open(name)
if err != nil {
return nil, xerrors.Errorf("failed to open connection: %w", err)
}

d.AddConnection(conn)

return conn, nil
}

func (d *Driver) Connector(name string) (driver.Connector, error) {
return &Connector{
name: name,
driver: d,
}, nil
}

func (d *Driver) Name() string {
return d.name
}

func (d *Driver) AddConnection(conn driver.Conn) {
d.connections = append(d.connections, conn)

for listener := range d.listeners {
d.listeners[listener] <- struct{}{}
}

}

func (d *Driver) WaitForConnection() {
ch := make(chan struct{})
d.listeners[ch] = ch
<-ch
delete(d.listeners, ch)
}

func (d *Driver) DropConnections() {
for _, conn := range d.connections {
_ = conn.Close()
}
d.connections = nil
}

type Connector struct {
name string
driver *Driver
dialer pq.Dialer
}

func (c *Connector) Connect(_ context.Context) (driver.Conn, error) {
if c.dialer != nil {
conn, err := pq.DialOpen(c.dialer, c.name)
if err != nil {
return nil, xerrors.Errorf("failed to dial open connection: %w", err)
}

c.driver.AddConnection(conn)

return conn, nil
}

conn, err := c.driver.Open(c.name)
Contributor

This is circular. Driver.Open() creates a new Connector and calls into it. Fortunately, the pubsub creates a dialer, but anyone else who uses this might not and will overflow their stack.

You need to use pq.Driver{} directly here.

if err != nil {
return nil, xerrors.Errorf("failed to open connection: %w", err)
}

return conn, nil
}

func (c *Connector) Driver() driver.Driver {
return c.driver
}

func (c *Connector) Dialer(dialer pq.Dialer) {
c.dialer = dialer
}
5 changes: 4 additions & 1 deletion coderd/database/pubsub/pubsub.go
@@ -455,7 +455,10 @@ func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error {
}

// Set the dialer if the connector supports it.
if dc, ok := connector.(database.DialerConnector); ok {
dc, ok := connector.(database.DialerConnector)
if !ok {
p.logger.Critical(ctx, "connector does not support setting log dialer, database connection debug logs will be missing")
} else {
dc.Dialer(dialer)
}

79 changes: 78 additions & 1 deletion coderd/database/pubsub/pubsub_test.go
@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"testing"
"time"

"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
@@ -86,7 +87,7 @@ func TestPGPubsub_Metrics(t *testing.T) {
for i := range colossalData {
colossalData[i] = 'q'
}
unsub1, err := uut.Subscribe(event, func(ctx context.Context, message []byte) {
unsub1, err := uut.Subscribe(event, func(_ context.Context, message []byte) {
messageChannel <- message
})
require.NoError(t, err)
@@ -119,3 +120,79 @@ func TestPGPubsub_Metrics(t *testing.T) {
!testutil.PromCounterGathered(t, metrics, "coder_pubsub_latency_measure_errs_total")
}, testutil.WaitShort, testutil.IntervalFast)
}

func TestPGPubsubDriver(t *testing.T) {
Contributor

This is a good test, but the design of the instrumented driver is racy -- see comments on dbtestutil/driver.go for how to address.

t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}

ctx := testutil.Context(t, testutil.WaitLong)
logger := slogtest.Make(t, &slogtest.Options{
IgnoreErrors: true,
}).Leveled(slog.LevelDebug)

connectionURL, closePg, err := dbtestutil.Open()
require.NoError(t, err)
defer closePg()

// wrap the pg driver with one we can control
d, err := dbtestutil.Register()
require.NoError(t, err)

db, err := sql.Open(d.Name(), connectionURL)
require.NoError(t, err)
defer db.Close()

ps, err := pubsub.New(ctx, logger, db, connectionURL)
require.NoError(t, err)
defer ps.Close()

// test that we can publish and subscribe
gotChan := make(chan struct{})
defer close(gotChan)
subCancel, err := ps.Subscribe("test", func(_ context.Context, _ []byte) {
gotChan <- struct{}{}
})
require.NoError(t, err)
defer subCancel()

err = ps.Publish("test", []byte("hello"))
require.NoError(t, err)

select {
case <-gotChan:
case <-ctx.Done():
t.Fatal("timeout waiting for message")
}
Contributor

This construction comes up a lot, so we have

_ = testutil.RequireRecvCtx(ctx, t, gotChan)

for this purpose.
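
For readers unfamiliar with the pattern, here is a sketch of what such a helper boils down to. This is not the testutil implementation: `recvCtx` and `demo` are hypothetical names, and the sketch returns an error instead of failing a `testing.TB` so that it runs on its own.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// recvCtx (hypothetical) receives one value from c, or gives up when ctx ends.
// It replaces the hand-written select { case <-ch: ... case <-ctx.Done(): ... }.
func recvCtx[T any](ctx context.Context, c <-chan T) (T, error) {
	var zero T
	select {
	case v, ok := <-c:
		if !ok {
			return zero, errors.New("channel closed")
		}
		return v, nil
	case <-ctx.Done():
		return zero, ctx.Err()
	}
}

// demo receives one buffered value, then shows the timeout path.
func demo() (string, bool) {
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()

	ready := make(chan string, 1)
	ready <- "hello"

	got, err := recvCtx(ctx, ready)
	if err != nil {
		return "", false
	}
	// Nothing else is sent, so the second receive ends when ctx expires.
	_, err = recvCtx(ctx, ready)
	return got, errors.Is(err, context.DeadlineExceeded)
}

func main() {
	got, timedOut := demo()
	fmt.Println(got, timedOut)
}
```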


reconnectChan := make(chan struct{})
go func() {
d.WaitForConnection()
// wait a bit to make sure the pubsub has reestablished its connection
// if we don't wait, the publish may be dropped because the pubsub hasn't initialized yet.
time.Sleep(1 * time.Second)
reconnectChan <- struct{}{}
}()

// drop the underlying connection being used by the pubsub
// the pq.Listener should reconnect and repopulate its listeners
// so old subscriptions should still work
d.DropConnections()

select {
case <-reconnectChan:
case <-ctx.Done():
t.Fatal("timeout waiting for reconnect")
}
Contributor

_ = testutil.RequireRecvCtx(ctx, t, reconnectChan)


// ensure our old subscription still fires
err = ps.Publish("test", []byte("hello-again"))
require.NoError(t, err)

select {
case <-gotChan:
case <-ctx.Done():
t.Fatal("timeout waiting for message after reconnect")
}
Contributor

_ = testutil.RequireRecvCtx(ctx, t, gotChan)

}