Skip to content

fix: use authenticated urls for pubsub #14261

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix tests
  • Loading branch information
f0ssel committed Aug 22, 2024
commit 1600abff933930520fbe39ee2254a360bea8a699
111 changes: 31 additions & 80 deletions coderd/database/dbtestutil/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,96 +2,19 @@ package dbtestutil

import (
"context"
"database/sql"

"database/sql/driver"
"fmt"

"github.com/lib/pq"
"golang.org/x/xerrors"

"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/cryptorand"
)

var (
_ driver.Driver = &Driver{}
_ database.ConnectorCreator = &Driver{}
_ database.DialerConnector = &Connector{}
_ database.DialerConnector = &Connector{}
)

type Driver struct {
name string
inner driver.Driver
connections []driver.Conn
listeners map[chan struct{}]chan struct{}
}

func Register() (*Driver, error) {
db, err := sql.Open("postgres", "")
if err != nil {
return nil, xerrors.Errorf("failed to open database: %w", err)
}

su, err := cryptorand.StringCharset(cryptorand.Alpha, 10)
if err != nil {
return nil, xerrors.Errorf("failed to generate random string: %w", err)
}

d := &Driver{
name: fmt.Sprintf("postgres-test-%s", su),
inner: db.Driver(),
listeners: make(map[chan struct{}]chan struct{}),
}

sql.Register(d.name, d)

return d, nil
}

func (d *Driver) Open(name string) (driver.Conn, error) {
conn, err := d.inner.Open(name)
if err != nil {
return nil, xerrors.Errorf("failed to open connection: %w", err)
}

d.AddConnection(conn)

return conn, nil
}

func (d *Driver) Connector(name string) (driver.Connector, error) {
return &Connector{
name: name,
driver: d,
}, nil
}

func (d *Driver) Name() string {
return d.name
}

func (d *Driver) AddConnection(conn driver.Conn) {
d.connections = append(d.connections, conn)
for listener := range d.listeners {
d.listeners[listener] <- struct{}{}
}
}

func (d *Driver) WaitForConnection() {
ch := make(chan struct{})
defer close(ch)
defer delete(d.listeners, ch)
d.listeners[ch] = ch
<-ch
}

func (d *Driver) DropConnections() {
for _, conn := range d.connections {
_ = conn.Close()
}
d.connections = nil
}

type Connector struct {
name string
driver *Driver
Expand All @@ -105,7 +28,7 @@ func (c *Connector) Connect(_ context.Context) (driver.Conn, error) {
return nil, xerrors.Errorf("failed to dial open connection: %w", err)
}

c.driver.AddConnection(conn)
c.driver.Connections <- conn

return conn, nil
}
Expand All @@ -115,6 +38,8 @@ func (c *Connector) Connect(_ context.Context) (driver.Conn, error) {
return nil, xerrors.Errorf("failed to open connection: %w", err)
}

c.driver.Connections <- conn

return conn, nil
}

Expand All @@ -125,3 +50,29 @@ func (c *Connector) Driver() driver.Driver {
func (c *Connector) Dialer(dialer pq.Dialer) {
c.dialer = dialer
}

type Driver struct {
Connections chan driver.Conn
}

func NewDriver() *Driver {
return &Driver{
Connections: make(chan driver.Conn, 1),
}
}

func (d *Driver) Connector(name string) (driver.Connector, error) {
return &Connector{
name: name,
driver: d,
}, nil
}

func (d *Driver) Open(name string) (driver.Conn, error) {
c, err := d.Connector(name)
if err != nil {
return nil, err
}

return c.Connect(context.Background())
}
62 changes: 26 additions & 36 deletions coderd/database/pubsub/pubsub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"database/sql"
"testing"
"time"

"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -136,63 +135,54 @@ func TestPGPubsubDriver(t *testing.T) {
require.NoError(t, err)
defer closePg()

// wrap the pg driver with one we can control
d, err := dbtestutil.Register()
// use a separate subber and pubber so we can keep track of listener connections
db, err := sql.Open("postgres", connectionURL)
require.NoError(t, err)

db, err := sql.Open(d.Name(), connectionURL)
pubber, err := pubsub.New(ctx, logger, db, connectionURL)
require.NoError(t, err)
defer db.Close()

ps, err := pubsub.New(ctx, logger, db, connectionURL)
// use a connector that sends us the connections for the subber
subDriver := dbtestutil.NewDriver()
tconn, err := subDriver.Connector(connectionURL)
require.NoError(t, err)
defer ps.Close()
tcdb := sql.OpenDB(tconn)
subber, err := pubsub.New(ctx, logger, tcdb, connectionURL)
require.NoError(t, err)
defer subber.Close()

// test that we can publish and subscribe
gotChan := make(chan struct{})
gotChan := make(chan struct{}, 1)
defer close(gotChan)
subCancel, err := ps.Subscribe("test", func(_ context.Context, _ []byte) {
subCancel, err := subber.Subscribe("test", func(_ context.Context, _ []byte) {
gotChan <- struct{}{}
})
require.NoError(t, err)
defer subCancel()

err = ps.Publish("test", []byte("hello"))
t.Log("publishing message")
// send a message
err = pubber.Publish("test", []byte("hello"))
require.NoError(t, err)

select {
case <-gotChan:
case <-ctx.Done():
t.Fatal("timeout waiting for message")
}
// wait for the message
_ = testutil.RequireRecvCtx(ctx, t, gotChan)

reconnectChan := make(chan struct{})
go func() {
d.WaitForConnection()
// wait a bit to make sure the pubsub has reestablished it's connection
// if we don't wait, the publish may be dropped because the pubsub hasn't initialized yet.
time.Sleep(1 * time.Second)
reconnectChan <- struct{}{}
}()
// read out first connection
firstConn := testutil.RequireRecvCtx(ctx, t, subDriver.Connections)

// drop the underlying connection being used by the pubsub
// the pq.Listener should reconnect and repopulate it's listeners
// so old subscriptions should still work
d.DropConnections()
err = firstConn.Close()
require.NoError(t, err)

select {
case <-reconnectChan:
case <-ctx.Done():
t.Fatal("timeout waiting for reconnect")
}
// wait for the reconnect
_ = testutil.RequireRecvCtx(ctx, t, subDriver.Connections)

// ensure our old subscription still fires
err = ps.Publish("test", []byte("hello-again"))
err = pubber.Publish("test", []byte("hello-again"))
require.NoError(t, err)

select {
case <-gotChan:
case <-ctx.Done():
t.Fatal("timeout waiting for message after reconnect")
}
// wait for the message on the old subscription
_ = testutil.RequireRecvCtx(ctx, t, gotChan)
}