Skip to content

Commit 79b2e92

Browse files
committed
fix: stop holding Pubsub mutex while calling pq.Listener
1 parent bae0a74 commit 79b2e92

File tree

3 files changed

+221
-139
lines changed

3 files changed

+221
-139
lines changed

coderd/database/pubsub/pubsub.go

+104-85
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"database/sql"
66
"errors"
7+
"io"
78
"net"
89
"sync"
910
"time"
@@ -164,16 +165,36 @@ func (q *msgQueue) dropped() {
164165
q.cond.Broadcast()
165166
}
166167

168+
// pqListener is an interface that represents a *pq.Listener for testing
169+
type pqListener interface {
170+
io.Closer
171+
Listen(string) error
172+
Unlisten(string) error
173+
NotifyChan() <-chan *pq.Notification
174+
}
175+
176+
type pqListenerShim struct {
177+
*pq.Listener
178+
}
179+
180+
func (l pqListenerShim) NotifyChan() <-chan *pq.Notification {
181+
return l.Notify
182+
}
183+
167184
// PGPubsub is a pubsub implementation using PostgreSQL.
168185
type PGPubsub struct {
169-
ctx context.Context
170-
cancel context.CancelFunc
171-
logger slog.Logger
172-
listenDone chan struct{}
173-
pgListener *pq.Listener
174-
db *sql.DB
175-
mut sync.Mutex
176-
queues map[string]map[uuid.UUID]*msgQueue
186+
logger slog.Logger
187+
listenDone chan struct{}
188+
pgListener pqListener
189+
db *sql.DB
190+
191+
qMu sync.Mutex
192+
queues map[string]map[uuid.UUID]*msgQueue
193+
194+
// making the close state its own mutex domain simplifies closing logic so
195+
// that we don't have to hold the qMu --- which could block processing
196+
// notifications while the pqListener is closing.
197+
closeMu sync.Mutex
177198
closedListener bool
178199
closeListenerErr error
179200

@@ -192,16 +213,14 @@ const BufferSize = 2048
192213

193214
// Subscribe calls the listener when an event matching the name is received.
194215
func (p *PGPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
195-
return p.subscribeQueue(event, newMsgQueue(p.ctx, listener, nil))
216+
return p.subscribeQueue(event, newMsgQueue(context.Background(), listener, nil))
196217
}
197218

198219
func (p *PGPubsub) SubscribeWithErr(event string, listener ListenerWithErr) (cancel func(), err error) {
199-
return p.subscribeQueue(event, newMsgQueue(p.ctx, nil, listener))
220+
return p.subscribeQueue(event, newMsgQueue(context.Background(), nil, listener))
200221
}
201222

202223
func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(), err error) {
203-
p.mut.Lock()
204-
defer p.mut.Unlock()
205224
defer func() {
206225
if err != nil {
207226
// if we hit an error, we need to close the queue so we don't
@@ -213,9 +232,13 @@ func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(),
213232
}
214233
}()
215234

235+
// The pgListener waits for the response to `LISTEN` on a mainloop that also dispatches
236+
// notifies. We need to avoid holding the mutex while this happens, since holding the mutex
237+
// blocks reading notifications and can deadlock the pgListener.
238+
// c.f. https://github.com/coder/coder/issues/11950
216239
err = p.pgListener.Listen(event)
217240
if err == nil {
218-
p.logger.Debug(p.ctx, "started listening to event channel", slog.F("event", event))
241+
p.logger.Debug(context.Background(), "started listening to event channel", slog.F("event", event))
219242
}
220243
if errors.Is(err, pq.ErrChannelAlreadyOpen) {
221244
// It's ok if it's already open!
@@ -224,6 +247,8 @@ func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(),
224247
if err != nil {
225248
return nil, xerrors.Errorf("listen: %w", err)
226249
}
250+
p.qMu.Lock()
251+
defer p.qMu.Unlock()
227252

228253
var eventQs map[uuid.UUID]*msgQueue
229254
var ok bool
@@ -234,30 +259,36 @@ func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(),
234259
id := uuid.New()
235260
eventQs[id] = newQ
236261
return func() {
237-
p.mut.Lock()
238-
defer p.mut.Unlock()
262+
p.qMu.Lock()
239263
listeners := p.queues[event]
240264
q := listeners[id]
241265
q.close()
242266
delete(listeners, id)
267+
if len(listeners) == 0 {
268+
delete(p.queues, event)
269+
}
270+
p.qMu.Unlock()
271+
// as above, we must not hold the lock while calling into pgListener
243272

244273
if len(listeners) == 0 {
245274
uErr := p.pgListener.Unlisten(event)
275+
p.closeMu.Lock()
276+
defer p.closeMu.Unlock()
246277
if uErr != nil && !p.closedListener {
247-
p.logger.Warn(p.ctx, "failed to unlisten", slog.Error(uErr), slog.F("event", event))
278+
p.logger.Warn(context.Background(), "failed to unlisten", slog.Error(uErr), slog.F("event", event))
248279
} else {
249-
p.logger.Debug(p.ctx, "stopped listening to event channel", slog.F("event", event))
280+
p.logger.Debug(context.Background(), "stopped listening to event channel", slog.F("event", event))
250281
}
251282
}
252283
}, nil
253284
}
254285

255286
func (p *PGPubsub) Publish(event string, message []byte) error {
256-
p.logger.Debug(p.ctx, "publish", slog.F("event", event), slog.F("message_len", len(message)))
287+
p.logger.Debug(context.Background(), "publish", slog.F("event", event), slog.F("message_len", len(message)))
257288
// This is safe because we are calling pq.QuoteLiteral. pg_notify doesn't
258289
// support the first parameter being a prepared statement.
259290
//nolint:gosec
260-
_, err := p.db.ExecContext(p.ctx, `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message)
291+
_, err := p.db.ExecContext(context.Background(), `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message)
261292
if err != nil {
262293
p.publishesTotal.WithLabelValues("false").Inc()
263294
return xerrors.Errorf("exec pg_notify: %w", err)
@@ -269,53 +300,38 @@ func (p *PGPubsub) Publish(event string, message []byte) error {
269300

270301
// Close closes the pubsub instance.
271302
func (p *PGPubsub) Close() error {
272-
p.logger.Info(p.ctx, "pubsub is closing")
273-
p.cancel()
303+
p.logger.Info(context.Background(), "pubsub is closing")
274304
err := p.closeListener()
275305
<-p.listenDone
276-
p.logger.Debug(p.ctx, "pubsub closed")
306+
p.logger.Debug(context.Background(), "pubsub closed")
277307
return err
278308
}
279309

280310
// closeListener closes the pgListener, unless it has already been closed.
281311
func (p *PGPubsub) closeListener() error {
282-
p.mut.Lock()
283-
defer p.mut.Unlock()
312+
p.closeMu.Lock()
313+
defer p.closeMu.Unlock()
284314
if p.closedListener {
285315
return p.closeListenerErr
286316
}
287-
p.closeListenerErr = p.pgListener.Close()
288317
p.closedListener = true
318+
p.closeListenerErr = p.pgListener.Close()
319+
289320
return p.closeListenerErr
290321
}
291322

292323
// listen begins receiving messages on the pq listener.
293324
func (p *PGPubsub) listen() {
294325
defer func() {
295-
p.logger.Info(p.ctx, "pubsub listen stopped receiving notify")
296-
cErr := p.closeListener()
297-
if cErr != nil {
298-
p.logger.Error(p.ctx, "failed to close listener")
299-
}
326+
p.logger.Info(context.Background(), "pubsub listen stopped receiving notify")
300327
close(p.listenDone)
301328
}()
302329

303-
var (
304-
notif *pq.Notification
305-
ok bool
306-
)
307-
for {
308-
select {
309-
case <-p.ctx.Done():
310-
return
311-
case notif, ok = <-p.pgListener.Notify:
312-
if !ok {
313-
return
314-
}
315-
}
330+
notify := p.pgListener.NotifyChan()
331+
for notif := range notify {
316332
// A nil notification can be dispatched on reconnect.
317333
if notif == nil {
318-
p.logger.Debug(p.ctx, "notifying subscribers of a reconnection")
334+
p.logger.Debug(context.Background(), "notifying subscribers of a reconnection")
319335
p.recordReconnect()
320336
continue
321337
}
@@ -331,8 +347,8 @@ func (p *PGPubsub) listenReceive(notif *pq.Notification) {
331347
p.messagesTotal.WithLabelValues(sizeLabel).Inc()
332348
p.receivedBytesTotal.Add(float64(len(notif.Extra)))
333349

334-
p.mut.Lock()
335-
defer p.mut.Unlock()
350+
p.qMu.Lock()
351+
defer p.qMu.Unlock()
336352
queues, ok := p.queues[notif.Channel]
337353
if !ok {
338354
return
@@ -344,8 +360,8 @@ func (p *PGPubsub) listenReceive(notif *pq.Notification) {
344360
}
345361

346362
func (p *PGPubsub) recordReconnect() {
347-
p.mut.Lock()
348-
defer p.mut.Unlock()
363+
p.qMu.Lock()
364+
defer p.qMu.Unlock()
349365
for _, listeners := range p.queues {
350366
for _, q := range listeners {
351367
q.dropped()
@@ -409,30 +425,32 @@ func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error {
409425
d: net.Dialer{},
410426
}
411427
)
412-
p.pgListener = pq.NewDialListener(dialer, connectURL, time.Second, time.Minute, func(t pq.ListenerEventType, err error) {
413-
switch t {
414-
case pq.ListenerEventConnected:
415-
p.logger.Info(ctx, "pubsub connected to postgres")
416-
p.connected.Set(1.0)
417-
case pq.ListenerEventDisconnected:
418-
p.logger.Error(ctx, "pubsub disconnected from postgres", slog.Error(err))
419-
p.connected.Set(0)
420-
case pq.ListenerEventReconnected:
421-
p.logger.Info(ctx, "pubsub reconnected to postgres")
422-
p.connected.Set(1)
423-
case pq.ListenerEventConnectionAttemptFailed:
424-
p.logger.Error(ctx, "pubsub failed to connect to postgres", slog.Error(err))
425-
}
426-
// This callback gets events whenever the connection state changes.
427-
// Don't send if the errChannel has already been closed.
428-
select {
429-
case <-errCh:
430-
return
431-
default:
432-
errCh <- err
433-
close(errCh)
434-
}
435-
})
428+
p.pgListener = pqListenerShim{
429+
Listener: pq.NewDialListener(dialer, connectURL, time.Second, time.Minute, func(t pq.ListenerEventType, err error) {
430+
switch t {
431+
case pq.ListenerEventConnected:
432+
p.logger.Info(ctx, "pubsub connected to postgres")
433+
p.connected.Set(1.0)
434+
case pq.ListenerEventDisconnected:
435+
p.logger.Error(ctx, "pubsub disconnected from postgres", slog.Error(err))
436+
p.connected.Set(0)
437+
case pq.ListenerEventReconnected:
438+
p.logger.Info(ctx, "pubsub reconnected to postgres")
439+
p.connected.Set(1)
440+
case pq.ListenerEventConnectionAttemptFailed:
441+
p.logger.Error(ctx, "pubsub failed to connect to postgres", slog.Error(err))
442+
}
443+
// This callback gets events whenever the connection state changes.
444+
// Don't send if the errChannel has already been closed.
445+
select {
446+
case <-errCh:
447+
return
448+
default:
449+
errCh <- err
450+
close(errCh)
451+
}
452+
}),
453+
}
436454
select {
437455
case err := <-errCh:
438456
if err != nil {
@@ -501,24 +519,31 @@ func (p *PGPubsub) Collect(metrics chan<- prometheus.Metric) {
501519
p.connected.Collect(metrics)
502520

503521
// implicit metrics
504-
p.mut.Lock()
522+
p.qMu.Lock()
505523
events := len(p.queues)
506524
subs := 0
507525
for _, subscriberMap := range p.queues {
508526
subs += len(subscriberMap)
509527
}
510-
p.mut.Unlock()
528+
p.qMu.Unlock()
511529
metrics <- prometheus.MustNewConstMetric(currentSubscribersDesc, prometheus.GaugeValue, float64(subs))
512530
metrics <- prometheus.MustNewConstMetric(currentEventsDesc, prometheus.GaugeValue, float64(events))
513531
}
514532

515533
// New creates a new Pubsub implementation using a PostgreSQL connection.
516534
func New(startCtx context.Context, logger slog.Logger, database *sql.DB, connectURL string) (*PGPubsub, error) {
517-
// Start a new context that will be canceled when the pubsub is closed.
518-
ctx, cancel := context.WithCancel(context.Background())
519-
p := &PGPubsub{
520-
ctx: ctx,
521-
cancel: cancel,
535+
p := newWithoutListener(logger, database)
536+
if err := p.startListener(startCtx, connectURL); err != nil {
537+
return nil, err
538+
}
539+
go p.listen()
540+
logger.Info(startCtx, "pubsub has started")
541+
return p, nil
542+
}
543+
544+
// newWithoutListener creates a new PGPubsub without creating the pqListener.
545+
func newWithoutListener(logger slog.Logger, database *sql.DB) *PGPubsub {
546+
return &PGPubsub{
522547
logger: logger,
523548
listenDone: make(chan struct{}),
524549
db: database,
@@ -567,10 +592,4 @@ func New(startCtx context.Context, logger slog.Logger, database *sql.DB, connect
567592
Help: "Whether we are connected (1) or not connected (0) to postgres",
568593
}),
569594
}
570-
if err := p.startListener(startCtx, connectURL); err != nil {
571-
return nil, err
572-
}
573-
go p.listen()
574-
logger.Info(ctx, "pubsub has started")
575-
return p, nil
576595
}

0 commit comments

Comments
 (0)