-
Notifications
You must be signed in to change notification settings - Fork 894
fix: use authenticated urls for pubsub #14261
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
b89ff23
88d50b8
8be0524
3be05c2
c7964e4
6ff2ec7
02258ee
9d8681b
85f6cfe
dae69c7
55df464
2019fec
ff2784c
c18963b
1600abf
839a52e
428f3f5
cd61cc7
c946cdd
1b139e4
8c96f6e
a645a76
ed009f7
c8a7c9a
05cbc3c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,19 +1,19 @@ | ||
package database | ||
|
||
import ( | ||
"context" | ||
"database/sql/driver" | ||
|
||
"github.com/lib/pq" | ||
) | ||
|
||
// ConnectorCreator can create a driver.Connector. | ||
// ConnectorCreator is a driver.Driver that can create a driver.Connector. | ||
type ConnectorCreator interface { | ||
driver.Driver | ||
Connector(name string) (driver.Connector, error) | ||
} | ||
|
||
// DialerConnector can create a driver.Connector and set a pq.Dialer. | ||
// DialerConnector is a driver.Connector that can set a pq.Dialer. | ||
type DialerConnector interface { | ||
Connect(context.Context) (driver.Conn, error) | ||
driver.Connector | ||
Dialer(dialer pq.Dialer) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
package dbtestutil | ||
|
||
import ( | ||
"context" | ||
"database/sql" | ||
"database/sql/driver" | ||
"fmt" | ||
|
||
"github.com/lib/pq" | ||
"golang.org/x/xerrors" | ||
|
||
"github.com/coder/coder/v2/coderd/database" | ||
"github.com/coder/coder/v2/cryptorand" | ||
) | ||
|
||
var ( | ||
_ driver.Driver = &Driver{} | ||
_ database.ConnectorCreator = &Driver{} | ||
_ database.DialerConnector = &Connector{} | ||
) | ||
|
||
type Driver struct { | ||
name string | ||
inner driver.Driver | ||
connections []driver.Conn | ||
listeners map[chan struct{}]chan struct{} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This connections and listeners business is not threadsafe, and so can cause races in our tests. A better design would be to just have a That way you don't need methods like There is some complexity around the |
||
} | ||
|
||
func Register() (*Driver, error) { | ||
db, err := sql.Open("postgres", "") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a roundabout way to get a |
||
if err != nil { | ||
return nil, xerrors.Errorf("failed to open database: %w", err) | ||
} | ||
|
||
su, err := cryptorand.StringCharset(cryptorand.Alpha, 10) | ||
if err != nil { | ||
return nil, xerrors.Errorf("failed to generate random string: %w", err) | ||
} | ||
|
||
d := &Driver{ | ||
name: fmt.Sprintf("postgres-test-%s", su), | ||
inner: db.Driver(), | ||
listeners: make(map[chan struct{}]chan struct{}), | ||
} | ||
|
||
sql.Register(d.name, d) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems like you're doing this so that you can get a If you allow test code to directly instantiate the instrumented connector, then you can get a
That avoids registering the driver and worrying about a unique name. |
||
|
||
return d, nil | ||
} | ||
|
||
func (d *Driver) Open(name string) (driver.Conn, error) { | ||
conn, err := d.inner.Open(name) | ||
if err != nil { | ||
return nil, xerrors.Errorf("failed to open connection: %w", err) | ||
} | ||
|
||
d.AddConnection(conn) | ||
|
||
return conn, nil | ||
} | ||
|
||
func (d *Driver) Connector(name string) (driver.Connector, error) { | ||
return &Connector{ | ||
name: name, | ||
driver: d, | ||
}, nil | ||
} | ||
|
||
func (d *Driver) Name() string { | ||
return d.name | ||
} | ||
|
||
func (d *Driver) AddConnection(conn driver.Conn) { | ||
d.connections = append(d.connections, conn) | ||
|
||
for listener := range d.listeners { | ||
d.listeners[listener] <- struct{}{} | ||
} | ||
|
||
} | ||
|
||
func (d *Driver) WaitForConnection() { | ||
ch := make(chan struct{}) | ||
d.listeners[ch] = ch | ||
<-ch | ||
delete(d.listeners, ch) | ||
} | ||
|
||
func (d *Driver) DropConnections() { | ||
for _, conn := range d.connections { | ||
_ = conn.Close() | ||
} | ||
d.connections = nil | ||
} | ||
|
||
type Connector struct { | ||
name string | ||
driver *Driver | ||
dialer pq.Dialer | ||
} | ||
|
||
func (c *Connector) Connect(_ context.Context) (driver.Conn, error) { | ||
if c.dialer != nil { | ||
conn, err := pq.DialOpen(c.dialer, c.name) | ||
if err != nil { | ||
return nil, xerrors.Errorf("failed to dial open connection: %w", err) | ||
} | ||
|
||
c.driver.AddConnection(conn) | ||
|
||
return conn, nil | ||
} | ||
|
||
conn, err := c.driver.Open(c.name) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is circular. You need to use |
||
if err != nil { | ||
return nil, xerrors.Errorf("failed to open connection: %w", err) | ||
} | ||
|
||
return conn, nil | ||
} | ||
|
||
func (c *Connector) Driver() driver.Driver { | ||
return c.driver | ||
} | ||
|
||
func (c *Connector) Dialer(dialer pq.Dialer) { | ||
c.dialer = dialer | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ import ( | |
"context" | ||
"database/sql" | ||
"testing" | ||
"time" | ||
|
||
"github.com/prometheus/client_golang/prometheus" | ||
"github.com/stretchr/testify/assert" | ||
|
@@ -86,7 +87,7 @@ func TestPGPubsub_Metrics(t *testing.T) { | |
for i := range colossalData { | ||
colossalData[i] = 'q' | ||
} | ||
unsub1, err := uut.Subscribe(event, func(ctx context.Context, message []byte) { | ||
unsub1, err := uut.Subscribe(event, func(_ context.Context, message []byte) { | ||
messageChannel <- message | ||
}) | ||
require.NoError(t, err) | ||
|
@@ -119,3 +120,79 @@ func TestPGPubsub_Metrics(t *testing.T) { | |
!testutil.PromCounterGathered(t, metrics, "coder_pubsub_latency_measure_errs_total") | ||
}, testutil.WaitShort, testutil.IntervalFast) | ||
} | ||
|
||
func TestPGPubsubDriver(t *testing.T) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a good test, but the design of the instrumented driver is racy -- see comments on |
||
t.Parallel() | ||
if !dbtestutil.WillUsePostgres() { | ||
t.Skip("test only with postgres") | ||
} | ||
|
||
ctx := testutil.Context(t, testutil.WaitLong) | ||
logger := slogtest.Make(t, &slogtest.Options{ | ||
IgnoreErrors: true, | ||
}).Leveled(slog.LevelDebug) | ||
|
||
connectionURL, closePg, err := dbtestutil.Open() | ||
require.NoError(t, err) | ||
defer closePg() | ||
|
||
// wrap the pg driver with one we can control | ||
d, err := dbtestutil.Register() | ||
require.NoError(t, err) | ||
|
||
db, err := sql.Open(d.Name(), connectionURL) | ||
require.NoError(t, err) | ||
defer db.Close() | ||
|
||
ps, err := pubsub.New(ctx, logger, db, connectionURL) | ||
require.NoError(t, err) | ||
defer ps.Close() | ||
|
||
// test that we can publish and subscribe | ||
gotChan := make(chan struct{}) | ||
defer close(gotChan) | ||
subCancel, err := ps.Subscribe("test", func(_ context.Context, _ []byte) { | ||
gotChan <- struct{}{} | ||
}) | ||
require.NoError(t, err) | ||
defer subCancel() | ||
|
||
err = ps.Publish("test", []byte("hello")) | ||
require.NoError(t, err) | ||
|
||
select { | ||
case <-gotChan: | ||
case <-ctx.Done(): | ||
t.Fatal("timeout waiting for message") | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This construction comes up a lot, so we have
for this purpose. |
||
|
||
reconnectChan := make(chan struct{}) | ||
go func() { | ||
d.WaitForConnection() | ||
// wait a bit to make sure the pubsub has reestablished it's connection | ||
// if we don't wait, the publish may be dropped because the pubsub hasn't initialized yet. | ||
time.Sleep(1 * time.Second) | ||
reconnectChan <- struct{}{} | ||
}() | ||
|
||
// drop the underlying connection being used by the pubsub | ||
// the pq.Listener should reconnect and repopulate it's listeners | ||
// so old subscriptions should still work | ||
d.DropConnections() | ||
|
||
select { | ||
case <-reconnectChan: | ||
case <-ctx.Done(): | ||
t.Fatal("timeout waiting for reconnect") | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
// ensure our old subscription still fires | ||
err = ps.Publish("test", []byte("hello-again")) | ||
require.NoError(t, err) | ||
|
||
select { | ||
case <-gotChan: | ||
case <-ctx.Done(): | ||
t.Fatal("timeout waiting for message after reconnect") | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍