diff --git a/coderd/database/pubsub.go b/coderd/database/pubsub.go index 1995cd7203510..6a6d1f2f07751 100644 --- a/coderd/database/pubsub.go +++ b/coderd/database/pubsub.go @@ -163,6 +163,8 @@ func (q *msgQueue) dropped() { // Pubsub implementation using PostgreSQL. type pgPubsub struct { ctx context.Context + cancel context.CancelFunc + listenDone chan struct{} pgListener *pq.Listener db *sql.DB mut sync.Mutex @@ -228,7 +230,7 @@ func (p *pgPubsub) Publish(event string, message []byte) error { // This is safe because we are calling pq.QuoteLiteral. pg_notify doesn't // support the first parameter being a prepared statement. //nolint:gosec - _, err := p.db.ExecContext(context.Background(), `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message) + _, err := p.db.ExecContext(p.ctx, `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message) if err != nil { return xerrors.Errorf("exec pg_notify: %w", err) } @@ -237,19 +239,24 @@ func (p *pgPubsub) Publish(event string, message []byte) error { // Close closes the pubsub instance. func (p *pgPubsub) Close() error { - return p.pgListener.Close() + p.cancel() + err := p.pgListener.Close() + <-p.listenDone + return err } // listen begins receiving messages on the pq listener. -func (p *pgPubsub) listen(ctx context.Context) { +func (p *pgPubsub) listen() { + defer close(p.listenDone) + defer p.pgListener.Close() + var ( notif *pq.Notification ok bool ) - defer p.pgListener.Close() for { select { - case <-ctx.Done(): + case <-p.ctx.Done(): return case notif, ok = <-p.pgListener.Notify: if !ok { @@ -292,7 +299,7 @@ func (p *pgPubsub) recordReconnect() { func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub, error) { // Creates a new listener using pq. errCh := make(chan error) - listener := pq.NewListener(connectURL, time.Second, time.Minute, func(event pq.ListenerEventType, err error) { + listener := pq.NewListener(connectURL, time.Second, time.Minute, func(_ pq.ListenerEventType, err error) { // This callback gets events whenever the connection state changes. // Don't send if the errChannel has already been closed. select { @@ -306,18 +313,25 @@ func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub select { case err := <-errCh: if err != nil { + _ = listener.Close() return nil, xerrors.Errorf("create pq listener: %w", err) } case <-ctx.Done(): + _ = listener.Close() return nil, ctx.Err() } + + // Start a new context that will be canceled when the pubsub is closed. + ctx, cancel := context.WithCancel(context.Background()) pgPubsub := &pgPubsub{ ctx: ctx, + cancel: cancel, + listenDone: make(chan struct{}), db: database, pgListener: listener, queues: make(map[string]map[uuid.UUID]*msgQueue), } - go pgPubsub.listen(ctx) + go pgPubsub.listen() return pgPubsub, nil } diff --git a/coderd/database/pubsub_test.go b/coderd/database/pubsub_test.go index e30767cb02085..60fb1821af55d 100644 --- a/coderd/database/pubsub_test.go +++ b/coderd/database/pubsub_test.go @@ -45,11 +45,11 @@ func TestPubsub(t *testing.T) { event := "test" data := "testing" messageChannel := make(chan []byte) - cancelFunc, err = pubsub.Subscribe(event, func(ctx context.Context, message []byte) { + unsub, err := pubsub.Subscribe(event, func(ctx context.Context, message []byte) { messageChannel <- message }) require.NoError(t, err) - defer cancelFunc() + defer unsub() go func() { err = pubsub.Publish(event, []byte(data)) assert.NoError(t, err) @@ -72,6 +72,91 @@ func TestPubsub(t *testing.T) { defer pubsub.Close() cancelFunc() }) + + t.Run("NotClosedOnCancelContext", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + connectionURL, closePg, err := postgres.Open() + require.NoError(t, err) + defer closePg() + db, err := sql.Open("postgres", connectionURL) + require.NoError(t, err) + defer db.Close() + pubsub, err := database.NewPubsub(ctx, db, connectionURL) + require.NoError(t, err) + defer pubsub.Close() + + // Provided context must only be active during NewPubsub, not after. + cancel() + + event := "test" + data := "testing" + messageChannel := make(chan []byte) + unsub, err := pubsub.Subscribe(event, func(_ context.Context, message []byte) { + messageChannel <- message + }) + require.NoError(t, err) + defer unsub() + go func() { + err = pubsub.Publish(event, []byte(data)) + assert.NoError(t, err) + }() + message := <-messageChannel + assert.Equal(t, string(message), data) + }) + + t.Run("ClosePropagatesContextCancellationToSubscription", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + connectionURL, closePg, err := postgres.Open() + require.NoError(t, err) + defer closePg() + db, err := sql.Open("postgres", connectionURL) + require.NoError(t, err) + defer db.Close() + pubsub, err := database.NewPubsub(ctx, db, connectionURL) + require.NoError(t, err) + defer pubsub.Close() + + event := "test" + done := make(chan struct{}) + called := make(chan struct{}) + unsub, err := pubsub.Subscribe(event, func(subCtx context.Context, _ []byte) { + defer close(done) + select { + case <-subCtx.Done(): + assert.Fail(t, "context should not be canceled") + default: + } + close(called) + select { + case <-subCtx.Done(): + case <-ctx.Done(): + assert.Fail(t, "timeout waiting for sub context to be canceled") + } + }) + require.NoError(t, err) + defer unsub() + + go func() { + err := pubsub.Publish(event, nil) + assert.NoError(t, err) + }() + + select { + case <-called: + case <-ctx.Done(): + require.Fail(t, "timeout waiting for handler to be called") + } + err = pubsub.Close() + require.NoError(t, err) + + select { + case <-done: + case <-ctx.Done(): + require.Fail(t, "timeout waiting for handler to finish") + } + }) } func TestPubsub_ordering(t *testing.T) {