Skip to content

Commit a6a4489

Browse files
authored
fix: pubsub ordering (#7404)
* fix: pubsub sends messages in order Signed-off-by: Spike Curtis <spike@coder.com> * Drop messages rather than block Signed-off-by: Spike Curtis <spike@coder.com> --------- Signed-off-by: Spike Curtis <spike@coder.com>
1 parent 667d9a7 commit a6a4489

File tree

2 files changed

+83
-16
lines changed

2 files changed

+83
-16
lines changed

coderd/database/pubsub.go

+37-16
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,17 @@ type Pubsub interface {
2525

2626
// Pubsub implementation using PostgreSQL.
2727
type pgPubsub struct {
28+
ctx context.Context
2829
pgListener *pq.Listener
2930
db *sql.DB
3031
mut sync.Mutex
31-
listeners map[string]map[uuid.UUID]Listener
32+
listeners map[string]map[uuid.UUID]chan<- []byte
3233
}
3334

35+
// messageBufferSize is the maximum number of unhandled messages we will buffer
36+
// for a subscriber before dropping messages.
37+
const messageBufferSize = 2048
38+
3439
// Subscribe calls the listener when an event matching the name is received.
3540
func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
3641
p.mut.Lock()
@@ -45,25 +50,22 @@ func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), er
4550
return nil, xerrors.Errorf("listen: %w", err)
4651
}
4752

48-
var eventListeners map[uuid.UUID]Listener
53+
var eventListeners map[uuid.UUID]chan<- []byte
4954
var ok bool
5055
if eventListeners, ok = p.listeners[event]; !ok {
51-
eventListeners = map[uuid.UUID]Listener{}
56+
eventListeners = make(map[uuid.UUID]chan<- []byte)
5257
p.listeners[event] = eventListeners
5358
}
5459

55-
var id uuid.UUID
56-
for {
57-
id = uuid.New()
58-
if _, ok = eventListeners[id]; !ok {
59-
break
60-
}
61-
}
62-
63-
eventListeners[id] = listener
60+
ctx, cancelCallbacks := context.WithCancel(p.ctx)
61+
messages := make(chan []byte, messageBufferSize)
62+
go messagesToListener(ctx, messages, listener)
63+
id := uuid.New()
64+
eventListeners[id] = messages
6465
return func() {
6566
p.mut.Lock()
6667
defer p.mut.Unlock()
68+
cancelCallbacks()
6769
listeners := p.listeners[event]
6870
delete(listeners, id)
6971

@@ -109,11 +111,11 @@ func (p *pgPubsub) listen(ctx context.Context) {
109111
if notif == nil {
110112
continue
111113
}
112-
p.listenReceive(ctx, notif)
114+
p.listenReceive(notif)
113115
}
114116
}
115117

116-
func (p *pgPubsub) listenReceive(ctx context.Context, notif *pq.Notification) {
118+
func (p *pgPubsub) listenReceive(notif *pq.Notification) {
117119
p.mut.Lock()
118120
defer p.mut.Unlock()
119121
listeners, ok := p.listeners[notif.Channel]
@@ -122,7 +124,14 @@ func (p *pgPubsub) listenReceive(ctx context.Context, notif *pq.Notification) {
122124
}
123125
extra := []byte(notif.Extra)
124126
for _, listener := range listeners {
125-
go listener(ctx, extra)
127+
select {
128+
case listener <- extra:
129+
// ok!
130+
default:
131+
// bad news, we dropped the event because the listener isn't
132+
// keeping up
133+
// TODO (spike): figure out a way to communicate this to the Listener
134+
}
126135
}
127136
}
128137

@@ -150,11 +159,23 @@ func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub
150159
return nil, ctx.Err()
151160
}
152161
pgPubsub := &pgPubsub{
162+
ctx: ctx,
153163
db: database,
154164
pgListener: listener,
155-
listeners: make(map[string]map[uuid.UUID]Listener),
165+
listeners: make(map[string]map[uuid.UUID]chan<- []byte),
156166
}
157167
go pgPubsub.listen(ctx)
158168

159169
return pgPubsub, nil
160170
}
171+
172+
func messagesToListener(ctx context.Context, messages <-chan []byte, listener Listener) {
173+
for {
174+
select {
175+
case <-ctx.Done():
176+
return
177+
case m := <-messages:
178+
listener(ctx, m)
179+
}
180+
}
181+
}

coderd/database/pubsub_test.go

+46
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,12 @@ package database_test
55
import (
66
"context"
77
"database/sql"
8+
"fmt"
9+
"math/rand"
810
"testing"
11+
"time"
12+
13+
"github.com/coder/coder/testutil"
914

1015
"github.com/stretchr/testify/assert"
1116
"github.com/stretchr/testify/require"
@@ -67,3 +72,44 @@ func TestPubsub(t *testing.T) {
6772
cancelFunc()
6873
})
6974
}
75+
76+
func TestPubsub_ordering(t *testing.T) {
77+
t.Parallel()
78+
79+
ctx, cancelFunc := context.WithCancel(context.Background())
80+
defer cancelFunc()
81+
82+
connectionURL, closePg, err := postgres.Open()
83+
require.NoError(t, err)
84+
defer closePg()
85+
db, err := sql.Open("postgres", connectionURL)
86+
require.NoError(t, err)
87+
defer db.Close()
88+
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
89+
require.NoError(t, err)
90+
defer pubsub.Close()
91+
event := "test"
92+
messageChannel := make(chan []byte, 100)
93+
cancelFunc, err = pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
94+
// sleep a random amount of time to simulate handlers taking different amount of time
95+
// to process, depending on the message
96+
// nolint: gosec
97+
n := rand.Intn(100)
98+
time.Sleep(time.Duration(n) * time.Millisecond)
99+
messageChannel <- message
100+
})
101+
require.NoError(t, err)
102+
defer cancelFunc()
103+
for i := 0; i < 100; i++ {
104+
err = pubsub.Publish(event, []byte(fmt.Sprintf("%d", i)))
105+
assert.NoError(t, err)
106+
}
107+
for i := 0; i < 100; i++ {
108+
select {
109+
case <-time.After(testutil.WaitShort):
110+
t.Fatalf("timed out waiting for message %d", i)
111+
case message := <-messageChannel:
112+
assert.Equal(t, fmt.Sprintf("%d", i), string(message))
113+
}
114+
}
115+
}

0 commit comments

Comments
 (0)