Skip to content

Commit 5a91aaa

Browse files
committed
Implementation; need linux tests
Signed-off-by: Spike Curtis <spike@coder.com>
1 parent a1853f2 commit 5a91aaa

File tree

3 files changed

+211
-40
lines changed

3 files changed

+211
-40
lines changed

coderd/database/pubsub.go

Lines changed: 165 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,155 @@ import (
1515
// Listener represents a pubsub handler.
1616
type Listener func(ctx context.Context, message []byte)
1717

18+
// ListenerWithErr represents a pubsub handler that can also receive error
19+
// indications
20+
type ListenerWithErr func(ctx context.Context, message []byte, err error)
21+
22+
// ErrDroppedMessages is sent to ListenerWithErr if messages are dropped or
23+
// might have been dropped.
24+
var ErrDroppedMessages = xerrors.New("dropped messages")
25+
1826
// Pubsub is a generic interface for broadcasting and receiving messages.
1927
// Implementors should assume high-availability with the backing implementation.
2028
type Pubsub interface {
2129
Subscribe(event string, listener Listener) (cancel func(), err error)
30+
SubscribeWithErr(event string, listener ListenerWithErr) (cancel func(), err error)
2231
Publish(event string, message []byte) error
2332
Close() error
2433
}
2534

35+
// msgOrErr either contains a message or an error
36+
type msgOrErr struct {
37+
msg []byte
38+
err error
39+
}
40+
41+
// msgQueue implements a fixed length queue with the ability to replace elements
42+
// after they are queued (but before they are dequeued).
43+
//
44+
// The purpose of this data structure is to build something that works a bit
45+
// like a golang channel, but if the queue is full, then we can replace the
46+
// last element with an error so that the subscriber can get notified that some
47+
// messages were dropped, all without blocking.
48+
type msgQueue struct {
49+
ctx context.Context
50+
cond *sync.Cond
51+
q [messageBufferSize]msgOrErr
52+
front int
53+
back int
54+
closed bool
55+
l Listener
56+
le ListenerWithErr
57+
}
58+
59+
func newMsgQueue(ctx context.Context, l Listener, le ListenerWithErr) *msgQueue {
60+
if l == nil && le == nil {
61+
panic("l or le must be non-nil")
62+
}
63+
q := &msgQueue{
64+
ctx: ctx,
65+
cond: sync.NewCond(&sync.Mutex{}),
66+
l: l,
67+
le: le,
68+
}
69+
go q.run()
70+
return q
71+
}
72+
73+
func (q *msgQueue) run() {
74+
for {
75+
// wait until there is something on the queue or we are closed
76+
q.cond.L.Lock()
77+
for q.front == q.back && !q.closed {
78+
q.cond.Wait()
79+
}
80+
if q.closed {
81+
q.cond.L.Unlock()
82+
return
83+
}
84+
item := q.q[q.front]
85+
q.front = (q.front + 1) % messageBufferSize
86+
q.cond.L.Unlock()
87+
88+
// process item without holding lock
89+
if item.err == nil {
90+
// real message
91+
if q.l != nil {
92+
q.l(q.ctx, item.msg)
93+
continue
94+
}
95+
if q.le != nil {
96+
q.le(q.ctx, item.msg, nil)
97+
continue
98+
}
99+
// unhittable
100+
continue
101+
}
102+
// if the listener wants errors, send it.
103+
if q.le != nil {
104+
q.le(q.ctx, nil, item.err)
105+
}
106+
}
107+
}
108+
109+
func (q *msgQueue) enqueue(msg []byte) {
110+
q.cond.L.Lock()
111+
defer q.cond.L.Unlock()
112+
113+
next := (q.back + 1) % messageBufferSize
114+
if next == q.front {
115+
// queue is full, so we're going to drop the msg we got called with.
116+
// We also need to record that messages are being dropped, which we
117+
// do at the last message in the queue. This potentially makes us
118+
// lose 2 messages instead of one, but it's more important at this
119+
// point to warn the subscriber that they're losing messages so they
120+
// can do something about it.
121+
q.q[q.back].msg = nil
122+
q.q[q.back].err = ErrDroppedMessages
123+
return
124+
}
125+
// queue is not full, insert the message
126+
q.back = next
127+
q.q[next].msg = msg
128+
q.q[next].err = nil
129+
q.cond.Broadcast()
130+
}
131+
132+
func (q *msgQueue) close() {
133+
q.cond.L.Lock()
134+
defer q.cond.L.Unlock()
135+
defer q.cond.Broadcast()
136+
q.closed = true
137+
}
138+
139+
// dropped records an error in the queue that messages might have been dropped
140+
func (q *msgQueue) dropped() {
141+
q.cond.L.Lock()
142+
defer q.cond.L.Unlock()
143+
144+
next := (q.back + 1) % messageBufferSize
145+
if next == q.front {
146+
// queue is full, but we need to record that messages are being dropped,
147+
// which we do at the last message in the queue. This potentially drops
148+
// another message, but it's more important for the subscriber to know.
149+
q.q[q.back].msg = nil
150+
q.q[q.back].err = ErrDroppedMessages
151+
return
152+
}
153+
// queue is not full, insert the error
154+
q.back = next
155+
q.q[next].msg = nil
156+
q.q[next].err = ErrDroppedMessages
157+
q.cond.Broadcast()
158+
}
159+
26160
// Pubsub implementation using PostgreSQL.
27161
type pgPubsub struct {
28162
ctx context.Context
29163
pgListener *pq.Listener
30164
db *sql.DB
31165
mut sync.Mutex
32-
listeners map[string]map[uuid.UUID]chan<- []byte
166+
queues map[string]map[uuid.UUID]*msgQueue
33167
}
34168

35169
// messageBufferSize is the maximum number of unhandled messages we will buffer
@@ -38,6 +172,14 @@ const messageBufferSize = 2048
38172

39173
// Subscribe calls the listener when an event matching the name is received.
40174
func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
175+
return p.subscribe(event, newMsgQueue(p.ctx, listener, nil))
176+
}
177+
178+
func (p *pgPubsub) SubscribeWithErr(event string, listener ListenerWithErr) (cancel func(), err error) {
179+
return p.subscribe(event, newMsgQueue(p.ctx, nil, listener))
180+
}
181+
182+
func (p *pgPubsub) subscribe(event string, newQ *msgQueue) (cancel func(), err error) {
41183
p.mut.Lock()
42184
defer p.mut.Unlock()
43185

@@ -50,23 +192,20 @@ func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), er
50192
return nil, xerrors.Errorf("listen: %w", err)
51193
}
52194

53-
var eventListeners map[uuid.UUID]chan<- []byte
195+
var eventQs map[uuid.UUID]*msgQueue
54196
var ok bool
55-
if eventListeners, ok = p.listeners[event]; !ok {
56-
eventListeners = make(map[uuid.UUID]chan<- []byte)
57-
p.listeners[event] = eventListeners
197+
if eventQs, ok = p.queues[event]; !ok {
198+
eventQs = make(map[uuid.UUID]*msgQueue)
199+
p.queues[event] = eventQs
58200
}
59-
60-
ctx, cancelCallbacks := context.WithCancel(p.ctx)
61-
messages := make(chan []byte, messageBufferSize)
62-
go messagesToListener(ctx, messages, listener)
63201
id := uuid.New()
64-
eventListeners[id] = messages
202+
eventQs[id] = newQ
65203
return func() {
66204
p.mut.Lock()
67205
defer p.mut.Unlock()
68-
cancelCallbacks()
69-
listeners := p.listeners[event]
206+
listeners := p.queues[event]
207+
q := listeners[id]
208+
q.close()
70209
delete(listeners, id)
71210

72211
if len(listeners) == 0 {
@@ -109,7 +248,7 @@ func (p *pgPubsub) listen(ctx context.Context) {
109248
}
110249
// A nil notification can be dispatched on reconnect.
111250
if notif == nil {
112-
continue
251+
p.recordReconnect()
113252
}
114253
p.listenReceive(notif)
115254
}
@@ -118,19 +257,22 @@ func (p *pgPubsub) listen(ctx context.Context) {
118257
func (p *pgPubsub) listenReceive(notif *pq.Notification) {
119258
p.mut.Lock()
120259
defer p.mut.Unlock()
121-
listeners, ok := p.listeners[notif.Channel]
260+
queues, ok := p.queues[notif.Channel]
122261
if !ok {
123262
return
124263
}
125264
extra := []byte(notif.Extra)
126-
for _, listener := range listeners {
127-
select {
128-
case listener <- extra:
129-
// ok!
130-
default:
131-
// bad news, we dropped the event because the listener isn't
132-
// keeping up
133-
// TODO (spike): figure out a way to communicate this to the Listener
265+
for _, q := range queues {
266+
q.enqueue(extra)
267+
}
268+
}
269+
270+
func (p *pgPubsub) recordReconnect() {
271+
p.mut.Lock()
272+
defer p.mut.Unlock()
273+
for _, listeners := range p.queues {
274+
for _, q := range listeners {
275+
q.dropped()
134276
}
135277
}
136278
}
@@ -162,20 +304,9 @@ func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub
162304
ctx: ctx,
163305
db: database,
164306
pgListener: listener,
165-
listeners: make(map[string]map[uuid.UUID]chan<- []byte),
307+
queues: make(map[string]map[uuid.UUID]*msgQueue),
166308
}
167309
go pgPubsub.listen(ctx)
168310

169311
return pgPubsub, nil
170312
}
171-
172-
func messagesToListener(ctx context.Context, messages <-chan []byte, listener Listener) {
173-
for {
174-
select {
175-
case <-ctx.Done():
176-
return
177-
case m := <-messages:
178-
listener(ctx, m)
179-
}
180-
}
181-
}

coderd/database/pubsub_memory.go

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,34 @@ import (
77
"github.com/google/uuid"
88
)
99

10+
// genericListener is either a Listener or ListenerWithErr
11+
type genericListener struct {
12+
l Listener
13+
le ListenerWithErr
14+
}
15+
1016
// memoryPubsub is an in-memory Pubsub implementation.
1117
type memoryPubsub struct {
1218
mut sync.RWMutex
13-
listeners map[string]map[uuid.UUID]Listener
19+
listeners map[string]map[uuid.UUID]genericListener
1420
}
1521

1622
func (m *memoryPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
23+
return m.subscribe(event, genericListener{l: listener})
24+
}
25+
26+
func (m *memoryPubsub) SubscribeWithErr(event string, listener ListenerWithErr) (cancel func(), err error) {
27+
return m.subscribe(event, genericListener{le: listener})
28+
}
29+
30+
func (m *memoryPubsub) subscribe(event string, listener genericListener) (cancel func(), err error) {
1731
m.mut.Lock()
1832
defer m.mut.Unlock()
1933

20-
var listeners map[uuid.UUID]Listener
34+
var listeners map[uuid.UUID]genericListener
2135
var ok bool
2236
if listeners, ok = m.listeners[event]; !ok {
23-
listeners = map[uuid.UUID]Listener{}
37+
listeners = map[uuid.UUID]genericListener{}
2438
m.listeners[event] = listeners
2539
}
2640
var id uuid.UUID
@@ -52,7 +66,12 @@ func (m *memoryPubsub) Publish(event string, message []byte) error {
5266
listener := listener
5367
go func() {
5468
defer wg.Done()
55-
listener(context.Background(), message)
69+
if listener.l != nil {
70+
listener.l(context.Background(), message)
71+
}
72+
if listener.le != nil {
73+
listener.le(context.Background(), message, nil)
74+
}
5675
}()
5776
}
5877
wg.Wait()
@@ -66,6 +85,6 @@ func (*memoryPubsub) Close() error {
6685

6786
func NewPubsubInMemory() Pubsub {
6887
return &memoryPubsub{
69-
listeners: make(map[string]map[uuid.UUID]Listener),
88+
listeners: make(map[string]map[uuid.UUID]genericListener),
7089
}
7190
}

coderd/database/pubsub_memory_test.go

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import (
1313
func TestPubsubMemory(t *testing.T) {
1414
t.Parallel()
1515

16-
t.Run("Memory", func(t *testing.T) {
16+
t.Run("Legacy", func(t *testing.T) {
1717
t.Parallel()
1818

1919
pubsub := database.NewPubsubInMemory()
@@ -32,4 +32,25 @@ func TestPubsubMemory(t *testing.T) {
3232
message := <-messageChannel
3333
assert.Equal(t, string(message), data)
3434
})
35+
36+
t.Run("WithErr", func(t *testing.T) {
37+
t.Parallel()
38+
39+
pubsub := database.NewPubsubInMemory()
40+
event := "test"
41+
data := "testing"
42+
messageChannel := make(chan []byte)
43+
cancelFunc, err := pubsub.SubscribeWithErr(event, func(ctx context.Context, message []byte, err error) {
44+
assert.NoError(t, err) // memory pubsub never sends errors.
45+
messageChannel <- message
46+
})
47+
require.NoError(t, err)
48+
defer cancelFunc()
49+
go func() {
50+
err = pubsub.Publish(event, []byte(data))
51+
assert.NoError(t, err)
52+
}()
53+
message := <-messageChannel
54+
assert.Equal(t, string(message), data)
55+
})
3556
}

0 commit comments

Comments
 (0)