Skip to content

fix: pubsub ordering #7404

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 37 additions & 16 deletions coderd/database/pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,17 @@ type Pubsub interface {

// Pubsub implementation using PostgreSQL.
type pgPubsub struct {
ctx context.Context
pgListener *pq.Listener
db *sql.DB
mut sync.Mutex
listeners map[string]map[uuid.UUID]Listener
listeners map[string]map[uuid.UUID]chan<- []byte
}

// messageBufferSize is the maximum number of unhandled messages we will buffer
// for a subscriber before dropping messages. Subscribe allocates one channel of
// this capacity per subscription, and listenReceive discards a notification for
// a subscriber whose channel is already full rather than blocking.
const messageBufferSize = 2048

// Subscribe calls the listener when an event matching the name is received.
func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
p.mut.Lock()
Expand All @@ -45,25 +50,22 @@ func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), er
return nil, xerrors.Errorf("listen: %w", err)
}

var eventListeners map[uuid.UUID]Listener
var eventListeners map[uuid.UUID]chan<- []byte
var ok bool
if eventListeners, ok = p.listeners[event]; !ok {
eventListeners = map[uuid.UUID]Listener{}
eventListeners = make(map[uuid.UUID]chan<- []byte)
p.listeners[event] = eventListeners
}

var id uuid.UUID
for {
id = uuid.New()
if _, ok = eventListeners[id]; !ok {
break
}
}

eventListeners[id] = listener
ctx, cancelCallbacks := context.WithCancel(p.ctx)
messages := make(chan []byte, messageBufferSize)
go messagesToListener(ctx, messages, listener)
id := uuid.New()
eventListeners[id] = messages
return func() {
p.mut.Lock()
defer p.mut.Unlock()
cancelCallbacks()
listeners := p.listeners[event]
delete(listeners, id)

Expand Down Expand Up @@ -109,11 +111,11 @@ func (p *pgPubsub) listen(ctx context.Context) {
if notif == nil {
continue
}
p.listenReceive(ctx, notif)
p.listenReceive(notif)
}
}

func (p *pgPubsub) listenReceive(ctx context.Context, notif *pq.Notification) {
func (p *pgPubsub) listenReceive(notif *pq.Notification) {
p.mut.Lock()
defer p.mut.Unlock()
listeners, ok := p.listeners[notif.Channel]
Expand All @@ -122,7 +124,14 @@ func (p *pgPubsub) listenReceive(ctx context.Context, notif *pq.Notification) {
}
extra := []byte(notif.Extra)
for _, listener := range listeners {
go listener(ctx, extra)
select {
case listener <- extra:
// ok!
default:
// bad news, we dropped the event because the listener isn't
// keeping up
// TODO (spike): figure out a way to communicate this to the Listener
}
}
}

Expand Down Expand Up @@ -150,11 +159,23 @@ func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub
return nil, ctx.Err()
}
pgPubsub := &pgPubsub{
ctx: ctx,
db: database,
pgListener: listener,
listeners: make(map[string]map[uuid.UUID]Listener),
listeners: make(map[string]map[uuid.UUID]chan<- []byte),
}
go pgPubsub.listen(ctx)

return pgPubsub, nil
}

// messagesToListener drains messages and hands each payload to listener, one
// at a time, on the calling goroutine. Because a single goroutine performs
// every call, the listener sees payloads in the order they were received from
// the channel. The function returns once ctx is done.
func messagesToListener(ctx context.Context, messages <-chan []byte, listener Listener) {
	done := ctx.Done()
	for {
		select {
		case payload := <-messages:
			// Invoke synchronously so the next payload is not picked up until
			// this one has been handled.
			listener(ctx, payload)
		case <-done:
			return
		}
	}
}
46 changes: 46 additions & 0 deletions coderd/database/pubsub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@ package database_test
import (
"context"
"database/sql"
"fmt"
"math/rand"
"testing"
"time"

"github.com/coder/coder/testutil"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -67,3 +72,44 @@ func TestPubsub(t *testing.T) {
cancelFunc()
})
}

// TestPubsub_ordering checks that a subscriber receives messages in the order
// they were published, even when its handler takes a different amount of time
// for each message.
func TestPubsub_ordering(t *testing.T) {
	t.Parallel()

	const (
		event        = "test"
		messageCount = 100
	)

	ctx, cancelCtx := context.WithCancel(context.Background())
	defer cancelCtx()

	connectionURL, closePg, err := postgres.Open()
	require.NoError(t, err)
	defer closePg()

	db, err := sql.Open("postgres", connectionURL)
	require.NoError(t, err)
	defer db.Close()

	pubsub, err := database.NewPubsub(ctx, db, connectionURL)
	require.NoError(t, err)
	defer pubsub.Close()

	// Buffered for every message so the handler never blocks on the send.
	received := make(chan []byte, messageCount)
	unsubscribe, err := pubsub.Subscribe(event, func(_ context.Context, message []byte) {
		// Sleep for a random duration to simulate handlers whose processing
		// time varies from message to message.
		// nolint: gosec
		delay := time.Duration(rand.Intn(100)) * time.Millisecond
		time.Sleep(delay)
		received <- message
	})
	require.NoError(t, err)
	defer unsubscribe()

	for i := 0; i < messageCount; i++ {
		assert.NoError(t, pubsub.Publish(event, []byte(fmt.Sprintf("%d", i))))
	}

	// The i-th message received must be the i-th message published.
	for i := 0; i < messageCount; i++ {
		want := fmt.Sprintf("%d", i)
		select {
		case got := <-received:
			assert.Equal(t, want, string(got))
		case <-time.After(testutil.WaitShort):
			t.Fatalf("timed out waiting for message %d", i)
		}
	}
}