Skip to content

fix: stop holding Pubsub mutex while calling pq.Listener #12518

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 104 additions & 85 deletions coderd/database/pubsub/pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"errors"
"io"
"net"
"sync"
"time"
Expand Down Expand Up @@ -164,16 +165,36 @@ func (q *msgQueue) dropped() {
q.cond.Broadcast()
}

// pqListener is an interface that represents a *pq.Listener for testing
type pqListener interface {
io.Closer
Listen(string) error
Unlisten(string) error
NotifyChan() <-chan *pq.Notification
}

type pqListenerShim struct {
*pq.Listener
}

func (l pqListenerShim) NotifyChan() <-chan *pq.Notification {
return l.Notify
}

// PGPubsub is a pubsub implementation using PostgreSQL.
type PGPubsub struct {
ctx context.Context
cancel context.CancelFunc
logger slog.Logger
listenDone chan struct{}
pgListener *pq.Listener
db *sql.DB
mut sync.Mutex
queues map[string]map[uuid.UUID]*msgQueue
logger slog.Logger
listenDone chan struct{}
pgListener pqListener
db *sql.DB

qMu sync.Mutex
queues map[string]map[uuid.UUID]*msgQueue

// making the close state its own mutex domain simplifies closing logic so
// that we don't have to hold the qMu --- which could block processing
// notifications while the pqListener is closing.
closeMu sync.Mutex
closedListener bool
closeListenerErr error

Expand All @@ -192,16 +213,14 @@ const BufferSize = 2048

// Subscribe calls the listener when an event matching the name is received.
func (p *PGPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
return p.subscribeQueue(event, newMsgQueue(p.ctx, listener, nil))
return p.subscribeQueue(event, newMsgQueue(context.Background(), listener, nil))
}

func (p *PGPubsub) SubscribeWithErr(event string, listener ListenerWithErr) (cancel func(), err error) {
return p.subscribeQueue(event, newMsgQueue(p.ctx, nil, listener))
return p.subscribeQueue(event, newMsgQueue(context.Background(), nil, listener))
}

func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(), err error) {
p.mut.Lock()
defer p.mut.Unlock()
defer func() {
if err != nil {
// if we hit an error, we need to close the queue so we don't
Expand All @@ -213,9 +232,13 @@ func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(),
}
}()

// The pgListener waits for the response to `LISTEN` on a mainloop that also dispatches
// notifies. We need to avoid holding the mutex while this happens, since holding the mutex
// blocks reading notifications and can deadlock the pgListener.
// c.f. https://github.com/coder/coder/issues/11950
err = p.pgListener.Listen(event)
if err == nil {
p.logger.Debug(p.ctx, "started listening to event channel", slog.F("event", event))
p.logger.Debug(context.Background(), "started listening to event channel", slog.F("event", event))
}
if errors.Is(err, pq.ErrChannelAlreadyOpen) {
// It's ok if it's already open!
Expand All @@ -224,6 +247,8 @@ func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(),
if err != nil {
return nil, xerrors.Errorf("listen: %w", err)
}
p.qMu.Lock()
defer p.qMu.Unlock()

var eventQs map[uuid.UUID]*msgQueue
var ok bool
Expand All @@ -234,30 +259,36 @@ func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(),
id := uuid.New()
eventQs[id] = newQ
return func() {
p.mut.Lock()
defer p.mut.Unlock()
p.qMu.Lock()
listeners := p.queues[event]
q := listeners[id]
q.close()
delete(listeners, id)
if len(listeners) == 0 {
delete(p.queues, event)
}
p.qMu.Unlock()
// as above, we must not hold the lock while calling into pgListener

if len(listeners) == 0 {
uErr := p.pgListener.Unlisten(event)
p.closeMu.Lock()
defer p.closeMu.Unlock()
if uErr != nil && !p.closedListener {
p.logger.Warn(p.ctx, "failed to unlisten", slog.Error(uErr), slog.F("event", event))
p.logger.Warn(context.Background(), "failed to unlisten", slog.Error(uErr), slog.F("event", event))
} else {
p.logger.Debug(p.ctx, "stopped listening to event channel", slog.F("event", event))
p.logger.Debug(context.Background(), "stopped listening to event channel", slog.F("event", event))
}
}
}, nil
}

func (p *PGPubsub) Publish(event string, message []byte) error {
p.logger.Debug(p.ctx, "publish", slog.F("event", event), slog.F("message_len", len(message)))
p.logger.Debug(context.Background(), "publish", slog.F("event", event), slog.F("message_len", len(message)))
// This is safe because we are calling pq.QuoteLiteral. pg_notify doesn't
// support the first parameter being a prepared statement.
//nolint:gosec
_, err := p.db.ExecContext(p.ctx, `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message)
_, err := p.db.ExecContext(context.Background(), `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message)
if err != nil {
p.publishesTotal.WithLabelValues("false").Inc()
return xerrors.Errorf("exec pg_notify: %w", err)
Expand All @@ -269,53 +300,38 @@ func (p *PGPubsub) Publish(event string, message []byte) error {

// Close closes the pubsub instance.
func (p *PGPubsub) Close() error {
p.logger.Info(p.ctx, "pubsub is closing")
p.cancel()
p.logger.Info(context.Background(), "pubsub is closing")
err := p.closeListener()
<-p.listenDone
p.logger.Debug(p.ctx, "pubsub closed")
p.logger.Debug(context.Background(), "pubsub closed")
return err
}

// closeListener closes the pgListener, unless it has already been closed.
func (p *PGPubsub) closeListener() error {
p.mut.Lock()
defer p.mut.Unlock()
p.closeMu.Lock()
defer p.closeMu.Unlock()
if p.closedListener {
return p.closeListenerErr
}
p.closeListenerErr = p.pgListener.Close()
p.closedListener = true
p.closeListenerErr = p.pgListener.Close()

return p.closeListenerErr
}

// listen begins receiving messages on the pq listener.
func (p *PGPubsub) listen() {
defer func() {
p.logger.Info(p.ctx, "pubsub listen stopped receiving notify")
cErr := p.closeListener()
if cErr != nil {
p.logger.Error(p.ctx, "failed to close listener")
}
p.logger.Info(context.Background(), "pubsub listen stopped receiving notify")
close(p.listenDone)
}()

var (
notif *pq.Notification
ok bool
)
for {
select {
case <-p.ctx.Done():
return
case notif, ok = <-p.pgListener.Notify:
if !ok {
return
}
}
notify := p.pgListener.NotifyChan()
for notif := range notify {
// A nil notification can be dispatched on reconnect.
if notif == nil {
p.logger.Debug(p.ctx, "notifying subscribers of a reconnection")
p.logger.Debug(context.Background(), "notifying subscribers of a reconnection")
p.recordReconnect()
continue
}
Expand All @@ -331,8 +347,8 @@ func (p *PGPubsub) listenReceive(notif *pq.Notification) {
p.messagesTotal.WithLabelValues(sizeLabel).Inc()
p.receivedBytesTotal.Add(float64(len(notif.Extra)))

p.mut.Lock()
defer p.mut.Unlock()
p.qMu.Lock()
defer p.qMu.Unlock()
queues, ok := p.queues[notif.Channel]
if !ok {
return
Expand All @@ -344,8 +360,8 @@ func (p *PGPubsub) listenReceive(notif *pq.Notification) {
}

func (p *PGPubsub) recordReconnect() {
p.mut.Lock()
defer p.mut.Unlock()
p.qMu.Lock()
defer p.qMu.Unlock()
for _, listeners := range p.queues {
for _, q := range listeners {
q.dropped()
Expand Down Expand Up @@ -409,30 +425,32 @@ func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error {
d: net.Dialer{},
}
)
p.pgListener = pq.NewDialListener(dialer, connectURL, time.Second, time.Minute, func(t pq.ListenerEventType, err error) {
switch t {
case pq.ListenerEventConnected:
p.logger.Info(ctx, "pubsub connected to postgres")
p.connected.Set(1.0)
case pq.ListenerEventDisconnected:
p.logger.Error(ctx, "pubsub disconnected from postgres", slog.Error(err))
p.connected.Set(0)
case pq.ListenerEventReconnected:
p.logger.Info(ctx, "pubsub reconnected to postgres")
p.connected.Set(1)
case pq.ListenerEventConnectionAttemptFailed:
p.logger.Error(ctx, "pubsub failed to connect to postgres", slog.Error(err))
}
// This callback gets events whenever the connection state changes.
// Don't send if the errChannel has already been closed.
select {
case <-errCh:
return
default:
errCh <- err
close(errCh)
}
})
p.pgListener = pqListenerShim{
Listener: pq.NewDialListener(dialer, connectURL, time.Second, time.Minute, func(t pq.ListenerEventType, err error) {
switch t {
case pq.ListenerEventConnected:
p.logger.Info(ctx, "pubsub connected to postgres")
p.connected.Set(1.0)
case pq.ListenerEventDisconnected:
p.logger.Error(ctx, "pubsub disconnected from postgres", slog.Error(err))
p.connected.Set(0)
case pq.ListenerEventReconnected:
p.logger.Info(ctx, "pubsub reconnected to postgres")
p.connected.Set(1)
case pq.ListenerEventConnectionAttemptFailed:
p.logger.Error(ctx, "pubsub failed to connect to postgres", slog.Error(err))
}
// This callback gets events whenever the connection state changes.
// Don't send if the errChannel has already been closed.
select {
case <-errCh:
return
default:
errCh <- err
close(errCh)
}
}),
}
select {
case err := <-errCh:
if err != nil {
Expand Down Expand Up @@ -501,24 +519,31 @@ func (p *PGPubsub) Collect(metrics chan<- prometheus.Metric) {
p.connected.Collect(metrics)

// implicit metrics
p.mut.Lock()
p.qMu.Lock()
events := len(p.queues)
subs := 0
for _, subscriberMap := range p.queues {
subs += len(subscriberMap)
}
p.mut.Unlock()
p.qMu.Unlock()
metrics <- prometheus.MustNewConstMetric(currentSubscribersDesc, prometheus.GaugeValue, float64(subs))
metrics <- prometheus.MustNewConstMetric(currentEventsDesc, prometheus.GaugeValue, float64(events))
}

// New creates a new Pubsub implementation using a PostgreSQL connection.
func New(startCtx context.Context, logger slog.Logger, database *sql.DB, connectURL string) (*PGPubsub, error) {
// Start a new context that will be canceled when the pubsub is closed.
ctx, cancel := context.WithCancel(context.Background())
p := &PGPubsub{
ctx: ctx,
cancel: cancel,
p := newWithoutListener(logger, database)
if err := p.startListener(startCtx, connectURL); err != nil {
return nil, err
}
go p.listen()
logger.Info(startCtx, "pubsub has started")
return p, nil
}

// newWithoutListener creates a new PGPubsub without creating the pqListener.
func newWithoutListener(logger slog.Logger, database *sql.DB) *PGPubsub {
return &PGPubsub{
logger: logger,
listenDone: make(chan struct{}),
db: database,
Expand Down Expand Up @@ -567,10 +592,4 @@ func New(startCtx context.Context, logger slog.Logger, database *sql.DB, connect
Help: "Whether we are connected (1) or not connected (0) to postgres",
}),
}
if err := p.startListener(startCtx, connectURL); err != nil {
return nil, err
}
go p.listen()
logger.Info(ctx, "pubsub has started")
return p, nil
}
Loading