Commit 005ea53

fix: fix Listen/Unlisten race on Pubsub (#15315)
Fixes #15312. When we need to `Unlisten()` for an event, instead of immediately removing the event from `p.queues`, we store a channel that signals any goroutines trying to `Subscribe` to the same event once the unlisten is done. On `Subscribe`, if the channel is present, we wait for it before calling `Listen`, which ensures the ordering is correct.
1 parent fbbefa2 commit 005ea53
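
To make the ordering concrete, here is a minimal sketch of the coordination pattern the description outlines. The names (broker, eventSet, subscriber) are illustrative only and not part of this commit; the actual types (PGPubsub, queueSet, msgQueue) are in the pubsub.go diff below.

// Sketch only: hypothetical names, not the commit's code.
package pubsubsketch

import "sync"

type subscriber struct{} // stand-in for a per-subscription message queue

type eventSet struct {
	subs map[*subscriber]struct{}
	// unlistenInProgress is non-nil while another goroutine is still
	// unlistening for this event; that goroutine closes it when done.
	unlistenInProgress chan struct{}
}

type broker struct {
	mu     sync.Mutex
	events map[string]*eventSet
}

func (b *broker) subscribe(event string, s *subscriber) {
	b.mu.Lock()
	es, ok := b.events[event]
	if !ok {
		es = &eventSet{subs: make(map[*subscriber]struct{})}
		b.events[event] = es
	}
	es.subs[s] = struct{}{}
	wait := es.unlistenInProgress
	b.mu.Unlock()

	if wait != nil {
		// An UNLISTEN for this event is still in flight; waiting here orders
		// our LISTEN after it, so the new subscription cannot miss
		// notifications (the race described in #15312).
		<-wait
	}
	// ...now it is safe to issue LISTEN for the event...
}

Note that the channel is waited on outside the lock, so a blocked Subscribe never holds up notification dispatch; the diff below follows the same "don't hold the lock while calling into pgListener" rule.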

File tree

2 files changed: 155 additions & 31 deletions

coderd/database/pubsub/pubsub.go

Lines changed: 93 additions & 31 deletions
@@ -11,7 +11,6 @@ import (
 	"sync/atomic"
 	"time"
 
-	"github.com/google/uuid"
 	"github.com/lib/pq"
 	"github.com/prometheus/client_golang/prometheus"
 	"golang.org/x/xerrors"
@@ -188,6 +187,19 @@ func (l pqListenerShim) NotifyChan() <-chan *pq.Notification {
 	return l.Notify
 }
 
+type queueSet struct {
+	m map[*msgQueue]struct{}
+	// unlistenInProgress will be non-nil if another goroutine is unlistening for the event this
+	// queueSet corresponds to. If non-nil, that goroutine will close the channel when it is done.
+	unlistenInProgress chan struct{}
+}
+
+func newQueueSet() *queueSet {
+	return &queueSet{
+		m: make(map[*msgQueue]struct{}),
+	}
+}
+
 // PGPubsub is a pubsub implementation using PostgreSQL.
 type PGPubsub struct {
 	logger slog.Logger
@@ -196,7 +208,7 @@ type PGPubsub struct {
 	db *sql.DB
 
 	qMu    sync.Mutex
-	queues map[string]map[uuid.UUID]*msgQueue
+	queues map[string]*queueSet
 
 	// making the close state its own mutex domain simplifies closing logic so
 	// that we don't have to hold the qMu --- which could block processing
@@ -243,6 +255,48 @@ func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(),
 		}
 	}()
 
+	var (
+		unlistenInProgress <-chan struct{}
+		// MUST hold the p.qMu lock to manipulate this!
+		qs *queueSet
+	)
+	func() {
+		p.qMu.Lock()
+		defer p.qMu.Unlock()
+
+		var ok bool
+		if qs, ok = p.queues[event]; !ok {
+			qs = newQueueSet()
+			p.queues[event] = qs
+		}
+		qs.m[newQ] = struct{}{}
+		unlistenInProgress = qs.unlistenInProgress
+	}()
+	// NOTE there cannot be any `return` statements between here and the next +-+, otherwise the
+	// assumptions the defer makes could be violated
+	if unlistenInProgress != nil {
+		// We have to wait here because we don't want our `Listen` call to happen before the other
+		// goroutine calls `Unlisten`. That would result in this subscription not getting any
+		// events. c.f. https://github.com/coder/coder/issues/15312
+		p.logger.Debug(context.Background(), "waiting for Unlisten in progress", slog.F("event", event))
+		<-unlistenInProgress
+		p.logger.Debug(context.Background(), "unlistening complete", slog.F("event", event))
+	}
+	// +-+ (see above)
+	defer func() {
+		if err != nil {
+			p.qMu.Lock()
+			defer p.qMu.Unlock()
+			delete(qs.m, newQ)
+			if len(qs.m) == 0 {
+				// we know that newQ was in the queueSet since we last unlocked, so there cannot
+				// have been any _new_ goroutines trying to Unlisten(). Therefore, if the queueSet
+				// is now empty, it's safe to delete.
+				delete(p.queues, event)
+			}
+		}
+	}()
+
 	// The pgListener waits for the response to `LISTEN` on a mainloop that also dispatches
 	// notifies. We need to avoid holding the mutex while this happens, since holding the mutex
 	// blocks reading notifications and can deadlock the pgListener.
@@ -258,32 +312,40 @@ func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(),
 	if err != nil {
 		return nil, xerrors.Errorf("listen: %w", err)
 	}
-	p.qMu.Lock()
-	defer p.qMu.Unlock()
 
-	var eventQs map[uuid.UUID]*msgQueue
-	var ok bool
-	if eventQs, ok = p.queues[event]; !ok {
-		eventQs = make(map[uuid.UUID]*msgQueue)
-		p.queues[event] = eventQs
-	}
-	id := uuid.New()
-	eventQs[id] = newQ
 	return func() {
-		p.qMu.Lock()
-		listeners := p.queues[event]
-		q := listeners[id]
-		q.close()
-		delete(listeners, id)
-		if len(listeners) == 0 {
-			delete(p.queues, event)
-		}
-		listenerCount := len(listeners)
-		p.qMu.Unlock()
-		// as above, we must not hold the lock while calling into pgListener
+		var unlistening chan struct{}
+		func() {
+			p.qMu.Lock()
+			defer p.qMu.Unlock()
+			newQ.close()
+			qSet, ok := p.queues[event]
+			if !ok {
+				p.logger.Critical(context.Background(), "event was removed before cancel", slog.F("event", event))
+				return
+			}
+			delete(qSet.m, newQ)
+			if len(qSet.m) == 0 {
+				unlistening = make(chan struct{})
+				qSet.unlistenInProgress = unlistening
+			}
+		}()
 
-		if listenerCount == 0 {
+		// as above, we must not hold the lock while calling into pgListener
+		if unlistening != nil {
 			uErr := p.pgListener.Unlisten(event)
+			close(unlistening)
+			// we can now delete the queueSet if it is empty.
+			func() {
+				p.qMu.Lock()
+				defer p.qMu.Unlock()
+				qSet, ok := p.queues[event]
+				if ok && len(qSet.m) == 0 {
+					p.logger.Debug(context.Background(), "removing queueSet", slog.F("event", event))
+					delete(p.queues, event)
+				}
+			}()
+
 			p.closeMu.Lock()
 			defer p.closeMu.Unlock()
 			if uErr != nil && !p.closedListener {
@@ -361,21 +423,21 @@ func (p *PGPubsub) listenReceive(notif *pq.Notification) {
 
 	p.qMu.Lock()
 	defer p.qMu.Unlock()
-	queues, ok := p.queues[notif.Channel]
+	qSet, ok := p.queues[notif.Channel]
 	if !ok {
 		return
 	}
 	extra := []byte(notif.Extra)
-	for _, q := range queues {
+	for q := range qSet.m {
 		q.enqueue(extra)
 	}
 }
 
 func (p *PGPubsub) recordReconnect() {
 	p.qMu.Lock()
 	defer p.qMu.Unlock()
-	for _, listeners := range p.queues {
-		for _, q := range listeners {
+	for _, qSet := range p.queues {
+		for q := range qSet.m {
 			q.dropped()
 		}
 	}
@@ -590,8 +652,8 @@ func (p *PGPubsub) Collect(metrics chan<- prometheus.Metric) {
 	p.qMu.Lock()
 	events := len(p.queues)
 	subs := 0
-	for _, subscriberMap := range p.queues {
-		subs += len(subscriberMap)
+	for _, qSet := range p.queues {
+		subs += len(qSet.m)
 	}
 	p.qMu.Unlock()
 	metrics <- prometheus.MustNewConstMetric(currentSubscribersDesc, prometheus.GaugeValue, float64(subs))
@@ -629,7 +691,7 @@ func newWithoutListener(logger slog.Logger, db *sql.DB) *PGPubsub {
 		logger:          logger,
 		listenDone:      make(chan struct{}),
 		db:              db,
-		queues:          make(map[string]map[uuid.UUID]*msgQueue),
+		queues:          make(map[string]*queueSet),
 		latencyMeasurer: NewLatencyMeasurer(logger.Named("latency-measurer")),
 
 		publishesTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
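
For orientation, a caller-side sketch of the Subscribe/cancel pair whose ordering this change fixes (assumed usage, not part of this commit; the event name and handler body are made up, and ps stands for any Pubsub instance):

// Assumed usage, for illustration only.
cancel, err := ps.Subscribe("example-event", func(ctx context.Context, message []byte) {
	// handle the NOTIFY payload
})
if err != nil {
	return xerrors.Errorf("subscribe: %w", err)
}
defer cancel()
// Calling cancel() while another goroutine calls ps.Subscribe("example-event", ...)
// is exactly the Listen/Unlisten race exercised by the new test below.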

coderd/database/pubsub/pubsub_internal_test.go

Lines changed: 62 additions & 0 deletions
@@ -178,6 +178,60 @@ func TestPubSub_DoesntBlockNotify(t *testing.T) {
 	require.NoError(t, err)
 }
 
+// TestPubSub_DoesntRaceListenUnlisten tests for regressions of
+// https://github.com/coder/coder/issues/15312
+func TestPubSub_DoesntRaceListenUnlisten(t *testing.T) {
+	t.Parallel()
+	ctx := testutil.Context(t, testutil.WaitShort)
+	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+
+	uut := newWithoutListener(logger, nil)
+	fListener := newFakePqListener()
+	uut.pgListener = fListener
+	go uut.listen()
+
+	noopListener := func(_ context.Context, _ []byte) {}
+
+	const numEvents = 500
+	events := make([]string, numEvents)
+	cancels := make([]func(), numEvents)
+	for i := range events {
+		var err error
+		events[i] = fmt.Sprintf("event-%d", i)
+		cancels[i], err = uut.Subscribe(events[i], noopListener)
+		require.NoError(t, err)
+	}
+	start := make(chan struct{})
+	done := make(chan struct{})
+	finalCancels := make([]func(), numEvents)
+	for i := range events {
+		event := events[i]
+		cancel := cancels[i]
+		go func() {
+			<-start
+			var err error
+			// subscribe again
+			finalCancels[i], err = uut.Subscribe(event, noopListener)
+			assert.NoError(t, err)
+			done <- struct{}{}
+		}()
+		go func() {
+			<-start
+			cancel()
+			done <- struct{}{}
+		}()
+	}
+	close(start)
+	for range numEvents * 2 {
+		_ = testutil.RequireRecvCtx(ctx, t, done)
+	}
+	for i := range events {
+		fListener.requireIsListening(t, events[i])
+		finalCancels[i]()
+	}
+	require.Len(t, uut.queues, 0)
+}
+
 const (
 	numNotifications = 5
 	testMessage      = "birds of a feather"
@@ -255,3 +309,11 @@ func newFakePqListener() *fakePqListener {
 		notify: make(chan *pq.Notification),
 	}
 }
+
+func (f *fakePqListener) requireIsListening(t testing.TB, s string) {
+	t.Helper()
+	f.mu.Lock()
+	defer f.mu.Unlock()
+	_, ok := f.channels[s]
+	require.True(t, ok, "should be listening for '%s', but isn't", s)
+}
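
A note on reproducing: with standard Go tooling, the new regression test can be run in isolation with the race detector enabled (assumed invocation from the repository root):

go test -race -run TestPubSub_DoesntRaceListenUnlisten ./coderd/database/pubsub/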
