diff --git a/coderd/database/pubsub/pubsub.go b/coderd/database/pubsub/pubsub.go
index c0130e3deac04..6823dc0188ef3 100644
--- a/coderd/database/pubsub/pubsub.go
+++ b/coderd/database/pubsub/pubsub.go
@@ -11,7 +11,6 @@ import (
 	"sync/atomic"
 	"time"
 
-	"github.com/google/uuid"
 	"github.com/lib/pq"
 	"github.com/prometheus/client_golang/prometheus"
 	"golang.org/x/xerrors"
@@ -188,6 +187,19 @@ func (l pqListenerShim) NotifyChan() <-chan *pq.Notification {
 	return l.Notify
 }
 
+type queueSet struct {
+	m map[*msgQueue]struct{}
+	// unlistenInProgress will be non-nil if another goroutine is unlistening for the event this
+	// queueSet corresponds to. If non-nil, that goroutine will close the channel when it is done.
+	unlistenInProgress chan struct{}
+}
+
+func newQueueSet() *queueSet {
+	return &queueSet{
+		m: make(map[*msgQueue]struct{}),
+	}
+}
+
 // PGPubsub is a pubsub implementation using PostgreSQL.
 type PGPubsub struct {
 	logger     slog.Logger
@@ -196,7 +208,7 @@ type PGPubsub struct {
 	db         *sql.DB
 
 	qMu    sync.Mutex
-	queues map[string]map[uuid.UUID]*msgQueue
+	queues map[string]*queueSet
 
 	// making the close state its own mutex domain simplifies closing logic so
 	// that we don't have to hold the qMu --- which could block processing
@@ -243,6 +255,48 @@ func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(),
 		}
 	}()
 
+	var (
+		unlistenInProgress <-chan struct{}
+		// MUST hold the p.qMu lock to manipulate this!
+		qs *queueSet
+	)
+	func() {
+		p.qMu.Lock()
+		defer p.qMu.Unlock()
+
+		var ok bool
+		if qs, ok = p.queues[event]; !ok {
+			qs = newQueueSet()
+			p.queues[event] = qs
+		}
+		qs.m[newQ] = struct{}{}
+		unlistenInProgress = qs.unlistenInProgress
+	}()
+	// NOTE there cannot be any `return` statements between here and the next +-+, otherwise the
+	// assumptions the defer makes could be violated
+	if unlistenInProgress != nil {
+		// We have to wait here because we don't want our `Listen` call to happen before the other
+		// goroutine calls `Unlisten`. That would result in this subscription not getting any
+		// events. c.f. https://github.com/coder/coder/issues/15312
+		p.logger.Debug(context.Background(), "waiting for Unlisten in progress", slog.F("event", event))
+		<-unlistenInProgress
+		p.logger.Debug(context.Background(), "unlistening complete", slog.F("event", event))
+	}
+	// +-+ (see above)
+	defer func() {
+		if err != nil {
+			p.qMu.Lock()
+			defer p.qMu.Unlock()
+			delete(qs.m, newQ)
+			if len(qs.m) == 0 {
+				// we know that newQ was in the queueSet since we last unlocked, so there cannot
+				// have been any _new_ goroutines trying to Unlisten(). Therefore, if the queueSet
+				// is now empty, it's safe to delete.
+				delete(p.queues, event)
+			}
+		}
+	}()
+
 	// The pgListener waits for the response to `LISTEN` on a mainloop that also dispatches
 	// notifies. We need to avoid holding the mutex while this happens, since holding the mutex
 	// blocks reading notifications and can deadlock the pgListener.
@@ -258,32 +312,40 @@ func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(),
 	if err != nil {
 		return nil, xerrors.Errorf("listen: %w", err)
 	}
-	p.qMu.Lock()
-	defer p.qMu.Unlock()
-	var eventQs map[uuid.UUID]*msgQueue
-	var ok bool
-	if eventQs, ok = p.queues[event]; !ok {
-		eventQs = make(map[uuid.UUID]*msgQueue)
-		p.queues[event] = eventQs
-	}
-	id := uuid.New()
-	eventQs[id] = newQ
 	return func() {
-		p.qMu.Lock()
-		listeners := p.queues[event]
-		q := listeners[id]
-		q.close()
-		delete(listeners, id)
-		if len(listeners) == 0 {
-			delete(p.queues, event)
-		}
-		listenerCount := len(listeners)
-		p.qMu.Unlock()
-		// as above, we must not hold the lock while calling into pgListener
+		var unlistening chan struct{}
+		func() {
+			p.qMu.Lock()
+			defer p.qMu.Unlock()
+			newQ.close()
+			qSet, ok := p.queues[event]
+			if !ok {
+				p.logger.Critical(context.Background(), "event was removed before cancel", slog.F("event", event))
+				return
+			}
+			delete(qSet.m, newQ)
+			if len(qSet.m) == 0 {
+				unlistening = make(chan struct{})
+				qSet.unlistenInProgress = unlistening
+			}
+		}()
-		if listenerCount == 0 {
+		// as above, we must not hold the lock while calling into pgListener
+		if unlistening != nil {
 			uErr := p.pgListener.Unlisten(event)
+			close(unlistening)
+			// we can now delete the queueSet if it is empty.
+			func() {
+				p.qMu.Lock()
+				defer p.qMu.Unlock()
+				qSet, ok := p.queues[event]
+				if ok && len(qSet.m) == 0 {
+					p.logger.Debug(context.Background(), "removing queueSet", slog.F("event", event))
+					delete(p.queues, event)
+				}
+			}()
+
 			p.closeMu.Lock()
 			defer p.closeMu.Unlock()
 			if uErr != nil && !p.closedListener {
@@ -361,12 +423,12 @@ func (p *PGPubsub) listenReceive(notif *pq.Notification) {
 	p.qMu.Lock()
 	defer p.qMu.Unlock()
 
-	queues, ok := p.queues[notif.Channel]
+	qSet, ok := p.queues[notif.Channel]
 	if !ok {
 		return
 	}
 	extra := []byte(notif.Extra)
-	for _, q := range queues {
+	for q := range qSet.m {
 		q.enqueue(extra)
 	}
 }
@@ -374,8 +436,8 @@ func (p *PGPubsub) listenReceive(notif *pq.Notification) {
 func (p *PGPubsub) recordReconnect() {
 	p.qMu.Lock()
 	defer p.qMu.Unlock()
-	for _, listeners := range p.queues {
-		for _, q := range listeners {
+	for _, qSet := range p.queues {
+		for q := range qSet.m {
 			q.dropped()
 		}
 	}
@@ -590,8 +652,8 @@ func (p *PGPubsub) Collect(metrics chan<- prometheus.Metric) {
 	p.qMu.Lock()
 	events := len(p.queues)
 	subs := 0
-	for _, subscriberMap := range p.queues {
-		subs += len(subscriberMap)
+	for _, qSet := range p.queues {
+		subs += len(qSet.m)
 	}
 	p.qMu.Unlock()
 	metrics <- prometheus.MustNewConstMetric(currentSubscribersDesc, prometheus.GaugeValue, float64(subs))
@@ -629,7 +691,7 @@ func newWithoutListener(logger slog.Logger, db *sql.DB) *PGPubsub {
 		logger:          logger,
 		listenDone:      make(chan struct{}),
 		db:              db,
-		queues:          make(map[string]map[uuid.UUID]*msgQueue),
+		queues:          make(map[string]*queueSet),
 		latencyMeasurer: NewLatencyMeasurer(logger.Named("latency-measurer")),
 		publishesTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
diff --git a/coderd/database/pubsub/pubsub_internal_test.go b/coderd/database/pubsub/pubsub_internal_test.go
index 2587357153ee8..df54ca5498f34 100644
--- a/coderd/database/pubsub/pubsub_internal_test.go
+++ b/coderd/database/pubsub/pubsub_internal_test.go
@@ -178,6 +178,60 @@ func TestPubSub_DoesntBlockNotify(t *testing.T) {
 	require.NoError(t, err)
 }
 
+// TestPubSub_DoesntRaceListenUnlisten tests for regressions of
+// https://github.com/coder/coder/issues/15312
+func TestPubSub_DoesntRaceListenUnlisten(t *testing.T) {
+	t.Parallel()
+	ctx := testutil.Context(t, testutil.WaitShort)
+	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+
+	uut := newWithoutListener(logger, nil)
+	fListener := newFakePqListener()
+	uut.pgListener = fListener
+	go uut.listen()
+
+	noopListener := func(_ context.Context, _ []byte) {}
+
+	const numEvents = 500
+	events := make([]string, numEvents)
+	cancels := make([]func(), numEvents)
+	for i := range events {
+		var err error
+		events[i] = fmt.Sprintf("event-%d", i)
+		cancels[i], err = uut.Subscribe(events[i], noopListener)
+		require.NoError(t, err)
+	}
+	start := make(chan struct{})
+	done := make(chan struct{})
+	finalCancels := make([]func(), numEvents)
+	for i := range events {
+		event := events[i]
+		cancel := cancels[i]
+		go func() {
+			<-start
+			var err error
+			// subscribe again
+			finalCancels[i], err = uut.Subscribe(event, noopListener)
+			assert.NoError(t, err)
+			done <- struct{}{}
+		}()
+		go func() {
+			<-start
+			cancel()
+			done <- struct{}{}
+		}()
+	}
+	close(start)
+	for range numEvents * 2 {
+		_ = testutil.RequireRecvCtx(ctx, t, done)
+	}
+	for i := range events {
+		fListener.requireIsListening(t, events[i])
+		finalCancels[i]()
+	}
+	require.Len(t, uut.queues, 0)
+}
+
 const (
 	numNotifications = 5
 	testMessage      = "birds of a feather"
@@ -255,3 +309,11 @@ func newFakePqListener() *fakePqListener {
 		notify: make(chan *pq.Notification),
 	}
 }
+
+func (f *fakePqListener) requireIsListening(t testing.TB, s string) {
+	t.Helper()
+	f.mu.Lock()
+	defer f.mu.Unlock()
+	_, ok := f.channels[s]
+	require.True(t, ok, "should be listening for '%s', but isn't", s)
+}
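
For reviewers unfamiliar with the handshake this patch introduces, here is a minimal, self-contained sketch of the Listen/Unlisten ordering it enforces. This is an illustration only: the registry, fakeListener, and subscriber names are invented for the example and stand in for PGPubsub, pq.Listener, and msgQueue. The key idea mirrors queueSet.unlistenInProgress above: the goroutine that empties the set publishes a channel before releasing the lock and calling Unlisten, and any new subscriber that observes that channel must wait for it to be closed before issuing its own Listen, so the server can never process the new LISTEN ahead of the in-flight UNLISTEN (the race in https://github.com/coder/coder/issues/15312).

package main

import (
	"fmt"
	"sync"
)

// fakeListener stands in for pq.Listener: Listen/Unlisten are round-trips
// to the server and must never be called while holding the registry mutex.
type fakeListener struct {
	mu       sync.Mutex
	channels map[string]struct{}
}

func (f *fakeListener) Listen(event string) {
	f.mu.Lock()
	defer f.mu.Unlock()
	f.channels[event] = struct{}{}
}

func (f *fakeListener) Unlisten(event string) {
	f.mu.Lock()
	defer f.mu.Unlock()
	delete(f.channels, event)
}

type subscriber struct{}

// queueSet mirrors the type added in the diff: a set of subscribers plus a
// channel that is non-nil while some goroutine is mid-Unlisten for the event.
type queueSet struct {
	m                  map[*subscriber]struct{}
	unlistenInProgress chan struct{}
}

type registry struct {
	mu       sync.Mutex
	listener *fakeListener
	queues   map[string]*queueSet
}

func (r *registry) subscribe(event string) (cancel func()) {
	s := &subscriber{}
	r.mu.Lock()
	qs, ok := r.queues[event]
	if !ok {
		qs = &queueSet{m: make(map[*subscriber]struct{})}
		r.queues[event] = qs
	}
	qs.m[s] = struct{}{}
	inProgress := qs.unlistenInProgress
	r.mu.Unlock()

	if inProgress != nil {
		// Another goroutine is (or was) unlistening this event. Wait so our
		// Listen is ordered strictly after its Unlisten; receiving from an
		// already-closed channel returns immediately, so stale channels are harmless.
		<-inProgress
	}
	r.listener.Listen(event) // outside the mutex, as in the real code
	return func() { r.unsubscribe(event, s) }
}

func (r *registry) unsubscribe(event string, s *subscriber) {
	var unlistening chan struct{}
	r.mu.Lock()
	qs := r.queues[event]
	delete(qs.m, s)
	if len(qs.m) == 0 {
		// Last subscriber: announce the pending Unlisten before releasing
		// the lock, so racing subscribers know to wait.
		unlistening = make(chan struct{})
		qs.unlistenInProgress = unlistening
	}
	r.mu.Unlock()

	if unlistening != nil {
		r.listener.Unlisten(event) // outside the mutex
		close(unlistening)
		// Delete the queueSet only if it is still empty; a waiting
		// subscriber may have repopulated it in the meantime.
		r.mu.Lock()
		if qs, ok := r.queues[event]; ok && len(qs.m) == 0 {
			delete(r.queues, event)
		}
		r.mu.Unlock()
	}
}

func main() {
	r := &registry{
		listener: &fakeListener{channels: make(map[string]struct{})},
		queues:   make(map[string]*queueSet),
	}
	cancelA := r.subscribe("workspace_update")
	cancelB := r.subscribe("workspace_update") // second subscriber, same event
	cancelA()
	cancelB()
	fmt.Println("channels still listened:", len(r.listener.channels)) // 0
}

Note the same design choice as the patch: the empty queueSet is deleted only after Unlisten returns and only after re-checking that it is still empty under the lock, because a subscriber that arrived during the Unlisten has already re-added itself to the set and must not have the set (or its wait channel) yanked out from under it.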