Skip to content

fix(coderd/database): improve pubsub closure and context cancellation #7993

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions coderd/database/pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ func (q *msgQueue) dropped() {
// Pubsub implementation using PostgreSQL.
type pgPubsub struct {
ctx context.Context
cancel context.CancelFunc
listenDone chan struct{}
pgListener *pq.Listener
db *sql.DB
mut sync.Mutex
Expand Down Expand Up @@ -228,7 +230,7 @@ func (p *pgPubsub) Publish(event string, message []byte) error {
// This is safe because we are calling pq.QuoteLiteral. pg_notify doesn't
// support the first parameter being a prepared statement.
//nolint:gosec
_, err := p.db.ExecContext(context.Background(), `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message)
_, err := p.db.ExecContext(p.ctx, `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message)
if err != nil {
return xerrors.Errorf("exec pg_notify: %w", err)
}
Expand All @@ -237,19 +239,24 @@ func (p *pgPubsub) Publish(event string, message []byte) error {

// Close closes the pubsub instance.
func (p *pgPubsub) Close() error {
return p.pgListener.Close()
p.cancel()
err := p.pgListener.Close()
<-p.listenDone
return err
}

// listen begins receiving messages on the pq listener.
func (p *pgPubsub) listen(ctx context.Context) {
func (p *pgPubsub) listen() {
defer close(p.listenDone)
defer p.pgListener.Close()

var (
notif *pq.Notification
ok bool
)
defer p.pgListener.Close()
for {
select {
case <-ctx.Done():
case <-p.ctx.Done():
return
case notif, ok = <-p.pgListener.Notify:
if !ok {
Expand Down Expand Up @@ -292,7 +299,7 @@ func (p *pgPubsub) recordReconnect() {
func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub, error) {
// Creates a new listener using pq.
errCh := make(chan error)
listener := pq.NewListener(connectURL, time.Second, time.Minute, func(event pq.ListenerEventType, err error) {
listener := pq.NewListener(connectURL, time.Second, time.Minute, func(_ pq.ListenerEventType, err error) {
// This callback gets events whenever the connection state changes.
// Don't send if the errChannel has already been closed.
select {
Expand All @@ -306,18 +313,25 @@ func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub
select {
case err := <-errCh:
if err != nil {
_ = listener.Close()
return nil, xerrors.Errorf("create pq listener: %w", err)
}
case <-ctx.Done():
_ = listener.Close()
return nil, ctx.Err()
}

// Start a new context that will be canceled when the pubsub is closed.
ctx, cancel := context.WithCancel(context.Background())
pgPubsub := &pgPubsub{
ctx: ctx,
cancel: cancel,
listenDone: make(chan struct{}),
db: database,
pgListener: listener,
queues: make(map[string]map[uuid.UUID]*msgQueue),
}
go pgPubsub.listen(ctx)
go pgPubsub.listen()

return pgPubsub, nil
}
89 changes: 87 additions & 2 deletions coderd/database/pubsub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ func TestPubsub(t *testing.T) {
event := "test"
data := "testing"
messageChannel := make(chan []byte)
cancelFunc, err = pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
unsub, err := pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
messageChannel <- message
})
require.NoError(t, err)
defer cancelFunc()
defer unsub()
go func() {
err = pubsub.Publish(event, []byte(data))
assert.NoError(t, err)
Expand All @@ -72,6 +72,91 @@ func TestPubsub(t *testing.T) {
defer pubsub.Close()
cancelFunc()
})

t.Run("NotClosedOnCancelContext", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
connectionURL, closePg, err := postgres.Open()
require.NoError(t, err)
defer closePg()
db, err := sql.Open("postgres", connectionURL)
require.NoError(t, err)
defer db.Close()
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
require.NoError(t, err)
defer pubsub.Close()

// Provided context must only be active during NewPubsub, not after.
cancel()

event := "test"
data := "testing"
messageChannel := make(chan []byte)
unsub, err := pubsub.Subscribe(event, func(_ context.Context, message []byte) {
messageChannel <- message
})
require.NoError(t, err)
defer unsub()
go func() {
err = pubsub.Publish(event, []byte(data))
assert.NoError(t, err)
}()
message := <-messageChannel
assert.Equal(t, string(message), data)
})

t.Run("ClosePropagatesContextCancellationToSubscription", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
connectionURL, closePg, err := postgres.Open()
require.NoError(t, err)
defer closePg()
db, err := sql.Open("postgres", connectionURL)
require.NoError(t, err)
defer db.Close()
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
require.NoError(t, err)
defer pubsub.Close()

event := "test"
done := make(chan struct{})
called := make(chan struct{})
unsub, err := pubsub.Subscribe(event, func(subCtx context.Context, _ []byte) {
defer close(done)
select {
case <-subCtx.Done():
assert.Fail(t, "context should not be canceled")
default:
}
close(called)
select {
case <-subCtx.Done():
case <-ctx.Done():
assert.Fail(t, "timeout waiting for sub context to be canceled")
}
})
require.NoError(t, err)
defer unsub()

go func() {
err := pubsub.Publish(event, nil)
assert.NoError(t, err)
}()

select {
case <-called:
case <-ctx.Done():
require.Fail(t, "timeout waiting for handler to be called")
}
err = pubsub.Close()
require.NoError(t, err)

select {
case <-done:
case <-ctx.Done():
require.Fail(t, "timeout waiting for handler to finish")
}
})
}

func TestPubsub_ordering(t *testing.T) {
Expand Down