diff --git a/coderd/database/pubsub.go b/coderd/database/pubsub.go
index 72549d28691b9..917a93951fa10 100644
--- a/coderd/database/pubsub.go
+++ b/coderd/database/pubsub.go
@@ -25,12 +25,17 @@ type Pubsub interface {
 
 // Pubsub implementation using PostgreSQL.
 type pgPubsub struct {
+	ctx        context.Context
 	pgListener *pq.Listener
 	db         *sql.DB
 	mut        sync.Mutex
-	listeners  map[string]map[uuid.UUID]Listener
+	listeners  map[string]map[uuid.UUID]chan<- []byte
 }
 
+// messageBufferSize is the maximum number of unhandled messages we will buffer
+// for a subscriber before dropping messages.
+const messageBufferSize = 2048
+
 // Subscribe calls the listener when an event matching the name is received.
 func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
 	p.mut.Lock()
@@ -45,25 +50,22 @@ func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), er
 		return nil, xerrors.Errorf("listen: %w", err)
 	}
 
-	var eventListeners map[uuid.UUID]Listener
+	var eventListeners map[uuid.UUID]chan<- []byte
 	var ok bool
 	if eventListeners, ok = p.listeners[event]; !ok {
-		eventListeners = map[uuid.UUID]Listener{}
+		eventListeners = make(map[uuid.UUID]chan<- []byte)
 		p.listeners[event] = eventListeners
 	}
 
-	var id uuid.UUID
-	for {
-		id = uuid.New()
-		if _, ok = eventListeners[id]; !ok {
-			break
-		}
-	}
-
-	eventListeners[id] = listener
+	ctx, cancelCallbacks := context.WithCancel(p.ctx)
+	messages := make(chan []byte, messageBufferSize)
+	go messagesToListener(ctx, messages, listener)
+	id := uuid.New()
+	eventListeners[id] = messages
 	return func() {
 		p.mut.Lock()
 		defer p.mut.Unlock()
+		cancelCallbacks()
 		listeners := p.listeners[event]
 		delete(listeners, id)
 
@@ -109,11 +111,11 @@ func (p *pgPubsub) listen(ctx context.Context) {
 		if notif == nil {
 			continue
 		}
-		p.listenReceive(ctx, notif)
+		p.listenReceive(notif)
 	}
 }
 
-func (p *pgPubsub) listenReceive(ctx context.Context, notif *pq.Notification) {
+func (p *pgPubsub) listenReceive(notif *pq.Notification) {
 	p.mut.Lock()
 	defer p.mut.Unlock()
 	listeners, ok := p.listeners[notif.Channel]
@@ -122,7 +124,14 @@ func (p *pgPubsub) listenReceive(ctx context.Context, notif *pq.Notification) {
 	}
 	extra := []byte(notif.Extra)
 	for _, listener := range listeners {
-		go listener(ctx, extra)
+		select {
+		case listener <- extra:
+			// ok!
+		default:
+			// bad news, we dropped the event because the listener isn't
+			// keeping up
+			// TODO (spike): figure out a way to communicate this to the Listener
+		}
 	}
 }
 
@@ -150,11 +159,26 @@ func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub
 		return nil, ctx.Err()
 	}
 	pgPubsub := &pgPubsub{
+		ctx:        ctx,
 		db:         database,
 		pgListener: listener,
-		listeners:  make(map[string]map[uuid.UUID]Listener),
+		listeners:  make(map[string]map[uuid.UUID]chan<- []byte),
 	}
 	go pgPubsub.listen(ctx)
 
 	return pgPubsub, nil
 }
+
+// messagesToListener passes each message received on messages to listener,
+// one call at a time from this single goroutine, so a subscriber's callbacks
+// never run concurrently with each other. It returns once ctx is done.
+func messagesToListener(ctx context.Context, messages <-chan []byte, listener Listener) {
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case m := <-messages:
+			listener(ctx, m)
+		}
+	}
+}
diff --git a/coderd/database/pubsub_test.go b/coderd/database/pubsub_test.go
index c1377b69aa4ae..13d3c5723fb29 100644
--- a/coderd/database/pubsub_test.go
+++ b/coderd/database/pubsub_test.go
@@ -5,7 +5,12 @@ package database_test
 import (
 	"context"
 	"database/sql"
+	"fmt"
+	"math/rand"
 	"testing"
+	"time"
+
+	"github.com/coder/coder/testutil"
 
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
@@ -67,3 +72,46 @@ func TestPubsub(t *testing.T) {
 		cancelFunc()
 	})
 }
+
+// TestPubsub_ordering publishes 100 numbered messages and asserts that a
+// subscriber whose handler sleeps a random time still sees them in order.
+func TestPubsub_ordering(t *testing.T) {
+	t.Parallel()
+
+	ctx, cancelFunc := context.WithCancel(context.Background())
+	defer cancelFunc()
+
+	connectionURL, closePg, err := postgres.Open()
+	require.NoError(t, err)
+	defer closePg()
+	db, err := sql.Open("postgres", connectionURL)
+	require.NoError(t, err)
+	defer db.Close()
+	pubsub, err := database.NewPubsub(ctx, db, connectionURL)
+	require.NoError(t, err)
+	defer pubsub.Close()
+	event := "test"
+	messageChannel := make(chan []byte, 100)
+	cancelSub, err := pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
+		// sleep a random amount of time to simulate handlers taking different amounts of time
+		// to process, depending on the message
+		//nolint:gosec // non-cryptographic randomness is fine for a test delay
+		n := rand.Intn(100)
+		time.Sleep(time.Duration(n) * time.Millisecond)
+		messageChannel <- message
+	})
+	require.NoError(t, err)
+	defer cancelSub()
+	for i := 0; i < 100; i++ {
+		err = pubsub.Publish(event, []byte(fmt.Sprintf("%d", i)))
+		assert.NoError(t, err)
+	}
+	for i := 0; i < 100; i++ {
+		select {
+		case <-time.After(testutil.WaitShort):
+			t.Fatalf("timed out waiting for message %d", i)
+		case message := <-messageChannel:
+			assert.Equal(t, fmt.Sprintf("%d", i), string(message))
+		}
+	}
+}