-
Notifications
You must be signed in to change notification settings - Fork 894
fix: use authenticated urls for pubsub #14261
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 14 commits
b89ff23
88d50b8
8be0524
3be05c2
c7964e4
6ff2ec7
02258ee
9d8681b
85f6cfe
dae69c7
55df464
2019fec
ff2784c
c18963b
1600abf
839a52e
428f3f5
cd61cc7
c946cdd
1b139e4
8c96f6e
a645a76
ed009f7
c8a7c9a
05cbc3c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
package database | ||
|
||
import ( | ||
"database/sql/driver" | ||
|
||
"github.com/lib/pq" | ||
) | ||
|
||
// ConnectorCreator is a driver.Driver that can create a driver.Connector. | ||
type ConnectorCreator interface { | ||
driver.Driver | ||
Connector(name string) (driver.Connector, error) | ||
} | ||
|
||
// DialerConnector is a driver.Connector that can set a pq.Dialer. | ||
type DialerConnector interface { | ||
driver.Connector | ||
Dialer(dialer pq.Dialer) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
package dbtestutil | ||
|
||
import ( | ||
"context" | ||
"database/sql" | ||
"database/sql/driver" | ||
"fmt" | ||
|
||
"github.com/lib/pq" | ||
"golang.org/x/xerrors" | ||
|
||
"github.com/coder/coder/v2/coderd/database" | ||
"github.com/coder/coder/v2/cryptorand" | ||
) | ||
|
||
var ( | ||
_ driver.Driver = &Driver{} | ||
_ database.ConnectorCreator = &Driver{} | ||
_ database.DialerConnector = &Connector{} | ||
) | ||
|
||
type Driver struct { | ||
name string | ||
inner driver.Driver | ||
connections []driver.Conn | ||
listeners map[chan struct{}]chan struct{} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This connections and listeners business is not threadsafe, and so can cause races in our tests. A better design would be to just have a That way you don't need methods like There is some complexity around the |
||
} | ||
|
||
func Register() (*Driver, error) { | ||
db, err := sql.Open("postgres", "") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a roundabout way to get a |
||
if err != nil { | ||
return nil, xerrors.Errorf("failed to open database: %w", err) | ||
} | ||
|
||
su, err := cryptorand.StringCharset(cryptorand.Alpha, 10) | ||
if err != nil { | ||
return nil, xerrors.Errorf("failed to generate random string: %w", err) | ||
} | ||
|
||
d := &Driver{ | ||
name: fmt.Sprintf("postgres-test-%s", su), | ||
inner: db.Driver(), | ||
listeners: make(map[chan struct{}]chan struct{}), | ||
} | ||
|
||
sql.Register(d.name, d) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems like you're doing this so that you can get a If you allow test code to directly instantiate the instrumented connector, then you can get a
That avoids registering the driver and worrying about a unique name. |
||
|
||
return d, nil | ||
} | ||
|
||
func (d *Driver) Open(name string) (driver.Conn, error) { | ||
conn, err := d.inner.Open(name) | ||
if err != nil { | ||
return nil, xerrors.Errorf("failed to open connection: %w", err) | ||
} | ||
|
||
d.AddConnection(conn) | ||
|
||
return conn, nil | ||
} | ||
|
||
func (d *Driver) Connector(name string) (driver.Connector, error) { | ||
return &Connector{ | ||
name: name, | ||
driver: d, | ||
}, nil | ||
} | ||
|
||
func (d *Driver) Name() string { | ||
return d.name | ||
} | ||
|
||
func (d *Driver) AddConnection(conn driver.Conn) { | ||
d.connections = append(d.connections, conn) | ||
for listener := range d.listeners { | ||
d.listeners[listener] <- struct{}{} | ||
} | ||
} | ||
|
||
func (d *Driver) WaitForConnection() { | ||
ch := make(chan struct{}) | ||
defer close(ch) | ||
defer delete(d.listeners, ch) | ||
d.listeners[ch] = ch | ||
<-ch | ||
} | ||
|
||
func (d *Driver) DropConnections() { | ||
for _, conn := range d.connections { | ||
_ = conn.Close() | ||
} | ||
d.connections = nil | ||
} | ||
|
||
type Connector struct { | ||
name string | ||
driver *Driver | ||
dialer pq.Dialer | ||
} | ||
|
||
func (c *Connector) Connect(_ context.Context) (driver.Conn, error) { | ||
if c.dialer != nil { | ||
conn, err := pq.DialOpen(c.dialer, c.name) | ||
if err != nil { | ||
return nil, xerrors.Errorf("failed to dial open connection: %w", err) | ||
} | ||
|
||
c.driver.AddConnection(conn) | ||
|
||
return conn, nil | ||
} | ||
|
||
conn, err := c.driver.Open(c.name) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is circular. You need to use |
||
if err != nil { | ||
return nil, xerrors.Errorf("failed to open connection: %w", err) | ||
} | ||
|
||
return conn, nil | ||
} | ||
|
||
func (c *Connector) Driver() driver.Driver { | ||
return c.driver | ||
} | ||
|
||
func (c *Connector) Dialer(dialer pq.Dialer) { | ||
c.dialer = dialer | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ package pubsub | |
import ( | ||
"context" | ||
"database/sql" | ||
"database/sql/driver" | ||
"errors" | ||
"io" | ||
"net" | ||
|
@@ -15,6 +16,8 @@ import ( | |
"github.com/prometheus/client_golang/prometheus" | ||
"golang.org/x/xerrors" | ||
|
||
"github.com/coder/coder/v2/coderd/database" | ||
|
||
"cdr.dev/slog" | ||
) | ||
|
||
|
@@ -432,9 +435,35 @@ func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error { | |
// pq.defaultDialer uses a zero net.Dialer as well. | ||
d: net.Dialer{}, | ||
} | ||
connector driver.Connector | ||
err error | ||
) | ||
|
||
// Create a custom connector if the database driver supports it. | ||
connectorCreator, ok := p.db.Driver().(database.ConnectorCreator) | ||
if ok { | ||
connector, err = connectorCreator.Connector(connectURL) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This code doesn't get hit in the package unit tests. A good way to test it would be to create a pq Driver wrapper that we can control. I'd like to see a test where we start pubsub with the wrapped driver, do some pub'ing and sub'ing, then kill the connection and verify that the pubsub / pq.Listener reconnects automatically. That would give a nice test of the pq changes you made as well. |
||
if err != nil { | ||
return xerrors.Errorf("create custom connector: %w", err) | ||
} | ||
} else { | ||
// use the default pq connector otherwise | ||
connector, err = pq.NewConnector(connectURL) | ||
if err != nil { | ||
return xerrors.Errorf("create pq connector: %w", err) | ||
} | ||
} | ||
|
||
// Set the dialer if the connector supports it. | ||
dc, ok := connector.(database.DialerConnector) | ||
if !ok { | ||
p.logger.Critical(ctx, "connector does not support setting log dialer, database connection debug logs will be missing") | ||
} else { | ||
dc.Dialer(dialer) | ||
} | ||
|
||
p.pgListener = pqListenerShim{ | ||
Listener: pq.NewDialListener(dialer, connectURL, time.Second, time.Minute, func(t pq.ListenerEventType, err error) { | ||
Listener: pq.NewConnectorListener(connector, connectURL, time.Second, time.Minute, func(t pq.ListenerEventType, err error) { | ||
switch t { | ||
case pq.ListenerEventConnected: | ||
p.logger.Info(ctx, "pubsub connected to postgres") | ||
|
@@ -583,8 +612,8 @@ func (p *PGPubsub) Collect(metrics chan<- prometheus.Metric) { | |
} | ||
|
||
// New creates a new Pubsub implementation using a PostgreSQL connection. | ||
func New(startCtx context.Context, logger slog.Logger, database *sql.DB, connectURL string) (*PGPubsub, error) { | ||
p := newWithoutListener(logger, database) | ||
func New(startCtx context.Context, logger slog.Logger, db *sql.DB, connectURL string) (*PGPubsub, error) { | ||
p := newWithoutListener(logger, db) | ||
if err := p.startListener(startCtx, connectURL); err != nil { | ||
return nil, err | ||
} | ||
|
@@ -594,11 +623,11 @@ func New(startCtx context.Context, logger slog.Logger, database *sql.DB, connect | |
} | ||
|
||
// newWithoutListener creates a new PGPubsub without creating the pqListener. | ||
func newWithoutListener(logger slog.Logger, database *sql.DB) *PGPubsub { | ||
func newWithoutListener(logger slog.Logger, db *sql.DB) *PGPubsub { | ||
return &PGPubsub{ | ||
logger: logger, | ||
listenDone: make(chan struct{}), | ||
db: database, | ||
db: db, | ||
queues: make(map[string]map[uuid.UUID]*msgQueue), | ||
latencyMeasurer: NewLatencyMeasurer(logger.Named("latency-measurer")), | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍