From 79b2e92bce33579178395b8a16f6ecf1ac19e206 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Mon, 11 Mar 2024 15:50:40 +0400 Subject: [PATCH] fix: stop holding Pubsub mutex while calling pq.Listener --- coderd/database/pubsub/pubsub.go | 189 ++++++++++-------- .../database/pubsub/pubsub_internal_test.go | 117 +++++++++++ coderd/database/pubsub/pubsub_linux_test.go | 54 ----- 3 files changed, 221 insertions(+), 139 deletions(-) diff --git a/coderd/database/pubsub/pubsub.go b/coderd/database/pubsub/pubsub.go index 33b3c083b66f8..59e5b23c34b00 100644 --- a/coderd/database/pubsub/pubsub.go +++ b/coderd/database/pubsub/pubsub.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "errors" + "io" "net" "sync" "time" @@ -164,16 +165,36 @@ func (q *msgQueue) dropped() { q.cond.Broadcast() } +// pqListener is an interface that represents a *pq.Listener for testing +type pqListener interface { + io.Closer + Listen(string) error + Unlisten(string) error + NotifyChan() <-chan *pq.Notification +} + +type pqListenerShim struct { + *pq.Listener +} + +func (l pqListenerShim) NotifyChan() <-chan *pq.Notification { + return l.Notify +} + // PGPubsub is a pubsub implementation using PostgreSQL. type PGPubsub struct { - ctx context.Context - cancel context.CancelFunc - logger slog.Logger - listenDone chan struct{} - pgListener *pq.Listener - db *sql.DB - mut sync.Mutex - queues map[string]map[uuid.UUID]*msgQueue + logger slog.Logger + listenDone chan struct{} + pgListener pqListener + db *sql.DB + + qMu sync.Mutex + queues map[string]map[uuid.UUID]*msgQueue + + // making the close state its own mutex domain simplifies closing logic so + // that we don't have to hold the qMu --- which could block processing + // notifications while the pqListener is closing. + closeMu sync.Mutex closedListener bool closeListenerErr error @@ -192,16 +213,14 @@ const BufferSize = 2048 // Subscribe calls the listener when an event matching the name is received. func (p *PGPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) { - return p.subscribeQueue(event, newMsgQueue(p.ctx, listener, nil)) + return p.subscribeQueue(event, newMsgQueue(context.Background(), listener, nil)) } func (p *PGPubsub) SubscribeWithErr(event string, listener ListenerWithErr) (cancel func(), err error) { - return p.subscribeQueue(event, newMsgQueue(p.ctx, nil, listener)) + return p.subscribeQueue(event, newMsgQueue(context.Background(), nil, listener)) } func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(), err error) { - p.mut.Lock() - defer p.mut.Unlock() defer func() { if err != nil { // if we hit an error, we need to close the queue so we don't @@ -213,9 +232,13 @@ func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(), } }() + // The pgListener waits for the response to `LISTEN` on a mainloop that also dispatches + // notifies. We need to avoid holding the mutex while this happens, since holding the mutex + // blocks reading notifications and can deadlock the pgListener. + // c.f. https://github.com/coder/coder/issues/11950 err = p.pgListener.Listen(event) if err == nil { - p.logger.Debug(p.ctx, "started listening to event channel", slog.F("event", event)) + p.logger.Debug(context.Background(), "started listening to event channel", slog.F("event", event)) } if errors.Is(err, pq.ErrChannelAlreadyOpen) { // It's ok if it's already open! @@ -224,6 +247,8 @@ func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(), if err != nil { return nil, xerrors.Errorf("listen: %w", err) } + p.qMu.Lock() + defer p.qMu.Unlock() var eventQs map[uuid.UUID]*msgQueue var ok bool @@ -234,30 +259,36 @@ func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(), id := uuid.New() eventQs[id] = newQ return func() { - p.mut.Lock() - defer p.mut.Unlock() + p.qMu.Lock() listeners := p.queues[event] q := listeners[id] q.close() delete(listeners, id) + if len(listeners) == 0 { + delete(p.queues, event) + } + p.qMu.Unlock() + // as above, we must not hold the lock while calling into pgListener if len(listeners) == 0 { uErr := p.pgListener.Unlisten(event) + p.closeMu.Lock() + defer p.closeMu.Unlock() if uErr != nil && !p.closedListener { - p.logger.Warn(p.ctx, "failed to unlisten", slog.Error(uErr), slog.F("event", event)) + p.logger.Warn(context.Background(), "failed to unlisten", slog.Error(uErr), slog.F("event", event)) } else { - p.logger.Debug(p.ctx, "stopped listening to event channel", slog.F("event", event)) + p.logger.Debug(context.Background(), "stopped listening to event channel", slog.F("event", event)) } } }, nil } func (p *PGPubsub) Publish(event string, message []byte) error { - p.logger.Debug(p.ctx, "publish", slog.F("event", event), slog.F("message_len", len(message))) + p.logger.Debug(context.Background(), "publish", slog.F("event", event), slog.F("message_len", len(message))) // This is safe because we are calling pq.QuoteLiteral. pg_notify doesn't // support the first parameter being a prepared statement. //nolint:gosec - _, err := p.db.ExecContext(p.ctx, `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message) + _, err := p.db.ExecContext(context.Background(), `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message) if err != nil { p.publishesTotal.WithLabelValues("false").Inc() return xerrors.Errorf("exec pg_notify: %w", err) @@ -269,53 +300,38 @@ func (p *PGPubsub) Publish(event string, message []byte) error { // Close closes the pubsub instance. func (p *PGPubsub) Close() error { - p.logger.Info(p.ctx, "pubsub is closing") - p.cancel() + p.logger.Info(context.Background(), "pubsub is closing") err := p.closeListener() <-p.listenDone - p.logger.Debug(p.ctx, "pubsub closed") + p.logger.Debug(context.Background(), "pubsub closed") return err } // closeListener closes the pgListener, unless it has already been closed. func (p *PGPubsub) closeListener() error { - p.mut.Lock() - defer p.mut.Unlock() + p.closeMu.Lock() + defer p.closeMu.Unlock() if p.closedListener { return p.closeListenerErr } - p.closeListenerErr = p.pgListener.Close() p.closedListener = true + p.closeListenerErr = p.pgListener.Close() + return p.closeListenerErr } // listen begins receiving messages on the pq listener. func (p *PGPubsub) listen() { defer func() { - p.logger.Info(p.ctx, "pubsub listen stopped receiving notify") - cErr := p.closeListener() - if cErr != nil { - p.logger.Error(p.ctx, "failed to close listener") - } + p.logger.Info(context.Background(), "pubsub listen stopped receiving notify") close(p.listenDone) }() - var ( - notif *pq.Notification - ok bool - ) - for { - select { - case <-p.ctx.Done(): - return - case notif, ok = <-p.pgListener.Notify: - if !ok { - return - } - } + notify := p.pgListener.NotifyChan() + for notif := range notify { // A nil notification can be dispatched on reconnect. if notif == nil { - p.logger.Debug(p.ctx, "notifying subscribers of a reconnection") + p.logger.Debug(context.Background(), "notifying subscribers of a reconnection") p.recordReconnect() continue } @@ -331,8 +347,8 @@ func (p *PGPubsub) listenReceive(notif *pq.Notification) { p.messagesTotal.WithLabelValues(sizeLabel).Inc() p.receivedBytesTotal.Add(float64(len(notif.Extra))) - p.mut.Lock() - defer p.mut.Unlock() + p.qMu.Lock() + defer p.qMu.Unlock() queues, ok := p.queues[notif.Channel] if !ok { return @@ -344,8 +360,8 @@ func (p *PGPubsub) listenReceive(notif *pq.Notification) { } func (p *PGPubsub) recordReconnect() { - p.mut.Lock() - defer p.mut.Unlock() + p.qMu.Lock() + defer p.qMu.Unlock() for _, listeners := range p.queues { for _, q := range listeners { q.dropped() @@ -409,30 +425,32 @@ func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error { d: net.Dialer{}, } ) - p.pgListener = pq.NewDialListener(dialer, connectURL, time.Second, time.Minute, func(t pq.ListenerEventType, err error) { - switch t { - case pq.ListenerEventConnected: - p.logger.Info(ctx, "pubsub connected to postgres") - p.connected.Set(1.0) - case pq.ListenerEventDisconnected: - p.logger.Error(ctx, "pubsub disconnected from postgres", slog.Error(err)) - p.connected.Set(0) - case pq.ListenerEventReconnected: - p.logger.Info(ctx, "pubsub reconnected to postgres") - p.connected.Set(1) - case pq.ListenerEventConnectionAttemptFailed: - p.logger.Error(ctx, "pubsub failed to connect to postgres", slog.Error(err)) - } - // This callback gets events whenever the connection state changes. - // Don't send if the errChannel has already been closed. - select { - case <-errCh: - return - default: - errCh <- err - close(errCh) - } - }) + p.pgListener = pqListenerShim{ + Listener: pq.NewDialListener(dialer, connectURL, time.Second, time.Minute, func(t pq.ListenerEventType, err error) { + switch t { + case pq.ListenerEventConnected: + p.logger.Info(ctx, "pubsub connected to postgres") + p.connected.Set(1.0) + case pq.ListenerEventDisconnected: + p.logger.Error(ctx, "pubsub disconnected from postgres", slog.Error(err)) + p.connected.Set(0) + case pq.ListenerEventReconnected: + p.logger.Info(ctx, "pubsub reconnected to postgres") + p.connected.Set(1) + case pq.ListenerEventConnectionAttemptFailed: + p.logger.Error(ctx, "pubsub failed to connect to postgres", slog.Error(err)) + } + // This callback gets events whenever the connection state changes. + // Don't send if the errChannel has already been closed. + select { + case <-errCh: + return + default: + errCh <- err + close(errCh) + } + }), + } select { case err := <-errCh: if err != nil { @@ -501,24 +519,31 @@ func (p *PGPubsub) Collect(metrics chan<- prometheus.Metric) { p.connected.Collect(metrics) // implicit metrics - p.mut.Lock() + p.qMu.Lock() events := len(p.queues) subs := 0 for _, subscriberMap := range p.queues { subs += len(subscriberMap) } - p.mut.Unlock() + p.qMu.Unlock() metrics <- prometheus.MustNewConstMetric(currentSubscribersDesc, prometheus.GaugeValue, float64(subs)) metrics <- prometheus.MustNewConstMetric(currentEventsDesc, prometheus.GaugeValue, float64(events)) } // New creates a new Pubsub implementation using a PostgreSQL connection. func New(startCtx context.Context, logger slog.Logger, database *sql.DB, connectURL string) (*PGPubsub, error) { - // Start a new context that will be canceled when the pubsub is closed. - ctx, cancel := context.WithCancel(context.Background()) - p := &PGPubsub{ - ctx: ctx, - cancel: cancel, + p := newWithoutListener(logger, database) + if err := p.startListener(startCtx, connectURL); err != nil { + return nil, err + } + go p.listen() + logger.Info(startCtx, "pubsub has started") + return p, nil +} + +// newWithoutListener creates a new PGPubsub without creating the pqListener. +func newWithoutListener(logger slog.Logger, database *sql.DB) *PGPubsub { + return &PGPubsub{ logger: logger, listenDone: make(chan struct{}), db: database, @@ -567,10 +592,4 @@ func New(startCtx context.Context, logger slog.Logger, database *sql.DB, connect Help: "Whether we are connected (1) or not connected (0) to postgres", }), } - if err := p.startListener(startCtx, connectURL); err != nil { - return nil, err - } - go p.listen() - logger.Info(ctx, "pubsub has started") - return p, nil } diff --git a/coderd/database/pubsub/pubsub_internal_test.go b/coderd/database/pubsub/pubsub_internal_test.go index 47dd324fc09df..2587357153ee8 100644 --- a/coderd/database/pubsub/pubsub_internal_test.go +++ b/coderd/database/pubsub/pubsub_internal_test.go @@ -3,10 +3,15 @@ package pubsub import ( "context" "fmt" + "sync" "testing" + "github.com/lib/pq" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/testutil" ) @@ -138,3 +143,115 @@ func Test_msgQueue_Full(t *testing.T) { // for the error, so we read 2 less than we sent. require.Equal(t, BufferSize, n) } + +func TestPubSub_DoesntBlockNotify(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + + uut := newWithoutListener(logger, nil) + fListener := newFakePqListener() + uut.pgListener = fListener + go uut.listen() + + cancels := make(chan func()) + go func() { + subCancel, err := uut.Subscribe("bagels", func(ctx context.Context, message []byte) { + t.Logf("got message: %s", string(message)) + }) + assert.NoError(t, err) + cancels <- subCancel + }() + subCancel := testutil.RequireRecvCtx(ctx, t, cancels) + cancelDone := make(chan struct{}) + go func() { + defer close(cancelDone) + subCancel() + }() + testutil.RequireRecvCtx(ctx, t, cancelDone) + + closeErrs := make(chan error) + go func() { + closeErrs <- uut.Close() + }() + err := testutil.RequireRecvCtx(ctx, t, closeErrs) + require.NoError(t, err) +} + +const ( + numNotifications = 5 + testMessage = "birds of a feather" +) + +// fakePqListener is a fake version of pq.Listener. This test code tests for regressions of +// https://github.com/coder/coder/issues/11950 where pq.Listener deadlocked because we blocked the +// PGPubsub.listen() goroutine while calling other pq.Listener functions. So, all function calls +// into the fakePqListener will send 5 notifications before returning to ensure the listen() +// goroutine is unblocked. +type fakePqListener struct { + mu sync.Mutex + channels map[string]struct{} + notify chan *pq.Notification +} + +func (f *fakePqListener) Close() error { + f.mu.Lock() + defer f.mu.Unlock() + ch := f.getTestChanLocked() + for i := 0; i < numNotifications; i++ { + f.notify <- &pq.Notification{Channel: ch, Extra: testMessage} + } + // note that the realPqListener must only be closed once, so go ahead and + // close the notify unprotected here. If it panics, we have a bug. + close(f.notify) + return nil +} + +func (f *fakePqListener) Listen(s string) error { + f.mu.Lock() + defer f.mu.Unlock() + ch := f.getTestChanLocked() + for i := 0; i < numNotifications; i++ { + f.notify <- &pq.Notification{Channel: ch, Extra: testMessage} + } + if _, ok := f.channels[s]; ok { + return pq.ErrChannelAlreadyOpen + } + f.channels[s] = struct{}{} + return nil +} + +func (f *fakePqListener) Unlisten(s string) error { + f.mu.Lock() + defer f.mu.Unlock() + ch := f.getTestChanLocked() + for i := 0; i < numNotifications; i++ { + f.notify <- &pq.Notification{Channel: ch, Extra: testMessage} + } + if _, ok := f.channels[s]; ok { + delete(f.channels, s) + return nil + } + return pq.ErrChannelNotOpen +} + +func (f *fakePqListener) NotifyChan() <-chan *pq.Notification { + return f.notify +} + +// getTestChanLocked returns the name of a channel we are currently listening for, if there is one. +// Otherwise, it just returns "test". We prefer to send test notifications for channels that appear +// in the tests, but if there are none, just return anything. +func (f *fakePqListener) getTestChanLocked() string { + for c := range f.channels { + return c + } + return "test" +} + +func newFakePqListener() *fakePqListener { + return &fakePqListener{ + channels: make(map[string]struct{}), + notify: make(chan *pq.Notification), + } +} diff --git a/coderd/database/pubsub/pubsub_linux_test.go b/coderd/database/pubsub/pubsub_linux_test.go index c25af429a5d78..d170c896ee391 100644 --- a/coderd/database/pubsub/pubsub_linux_test.go +++ b/coderd/database/pubsub/pubsub_linux_test.go @@ -109,60 +109,6 @@ func TestPubsub(t *testing.T) { message := <-messageChannel assert.Equal(t, string(message), data) }) - - t.Run("ClosePropagatesContextCancellationToSubscription", func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - connectionURL, closePg, err := postgres.Open() - require.NoError(t, err) - defer closePg() - db, err := sql.Open("postgres", connectionURL) - require.NoError(t, err) - defer db.Close() - pubsub, err := pubsub.New(ctx, logger, db, connectionURL) - require.NoError(t, err) - defer pubsub.Close() - - event := "test" - done := make(chan struct{}) - called := make(chan struct{}) - unsub, err := pubsub.Subscribe(event, func(subCtx context.Context, _ []byte) { - defer close(done) - select { - case <-subCtx.Done(): - assert.Fail(t, "context should not be canceled") - default: - } - close(called) - select { - case <-subCtx.Done(): - case <-ctx.Done(): - assert.Fail(t, "timeout waiting for sub context to be canceled") - } - }) - require.NoError(t, err) - defer unsub() - - go func() { - err := pubsub.Publish(event, nil) - assert.NoError(t, err) - }() - - select { - case <-called: - case <-ctx.Done(): - require.Fail(t, "timeout waiting for handler to be called") - } - err = pubsub.Close() - require.NoError(t, err) - - select { - case <-done: - case <-ctx.Done(): - require.Fail(t, "timeout waiting for handler to finish") - } - }) } func TestPubsub_ordering(t *testing.T) {