From 5a91aaabec8441c159a827df197acaba57859795 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Mon, 22 May 2023 16:18:48 +0400 Subject: [PATCH 1/4] Implementation; need linux tests Signed-off-by: Spike Curtis --- coderd/database/pubsub.go | 199 +++++++++++++++++++++----- coderd/database/pubsub_memory.go | 29 +++- coderd/database/pubsub_memory_test.go | 23 ++- 3 files changed, 211 insertions(+), 40 deletions(-) diff --git a/coderd/database/pubsub.go b/coderd/database/pubsub.go index 917a93951fa10..93c5ad50e8230 100644 --- a/coderd/database/pubsub.go +++ b/coderd/database/pubsub.go @@ -15,21 +15,155 @@ import ( // Listener represents a pubsub handler. type Listener func(ctx context.Context, message []byte) +// ListenerWithErr represents a pubsub handler that can also receive error +// indications +type ListenerWithErr func(ctx context.Context, message []byte, err error) + +// ErrDroppedMessages is sent to ListenerWithErr if messages are dropped or +// might have been dropped. +var ErrDroppedMessages = xerrors.New("dropped messages") + // Pubsub is a generic interface for broadcasting and receiving messages. // Implementors should assume high-availability with the backing implementation. type Pubsub interface { Subscribe(event string, listener Listener) (cancel func(), err error) + SubscribeWithErr(event string, listener ListenerWithErr) (cancel func(), err error) Publish(event string, message []byte) error Close() error } +// msgOrErr either contains a message or an error +type msgOrErr struct { + msg []byte + err error +} + +// msgQueue implements a fixed length queue with the ability to replace elements +// after they are queued (but before they are dequeued). +// +// The purpose of this data structure is to build something that works a bit +// like a golang channel, but if the queue is full, then we can replace the +// last element with an error so that the subscriber can get notified that some +// messages were dropped, all without blocking. +type msgQueue struct { + ctx context.Context + cond *sync.Cond + q [messageBufferSize]msgOrErr + front int + back int + closed bool + l Listener + le ListenerWithErr +} + +func newMsgQueue(ctx context.Context, l Listener, le ListenerWithErr) *msgQueue { + if l == nil && le == nil { + panic("l or le must be non-nil") + } + q := &msgQueue{ + ctx: ctx, + cond: sync.NewCond(&sync.Mutex{}), + l: l, + le: le, + } + go q.run() + return q +} + +func (q *msgQueue) run() { + for { + // wait until there is something on the queue or we are closed + q.cond.L.Lock() + for q.front == q.back && !q.closed { + q.cond.Wait() + } + if q.closed { + q.cond.L.Unlock() + return + } + item := q.q[q.front] + q.front = (q.front + 1) % messageBufferSize + q.cond.L.Unlock() + + // process item without holding lock + if item.err == nil { + // real message + if q.l != nil { + q.l(q.ctx, item.msg) + continue + } + if q.le != nil { + q.le(q.ctx, item.msg, nil) + continue + } + // unhittable + continue + } + // if the listener wants errors, send it. + if q.le != nil { + q.le(q.ctx, nil, item.err) + } + } +} + +func (q *msgQueue) enqueue(msg []byte) { + q.cond.L.Lock() + defer q.cond.L.Unlock() + + next := (q.back + 1) % messageBufferSize + if next == q.front { + // queue is full, so we're going to drop the msg we got called with. + // We also need to record that messages are being dropped, which we + // do at the last message in the queue. 
This potentially makes us
+		// lose 2 messages instead of one, but it's more important at this
+		// point to warn the subscriber that they're losing messages so they
+		// can do something about it.
+		q.q[q.back].msg = nil
+		q.q[q.back].err = ErrDroppedMessages
+		return
+	}
+	// queue is not full, insert the message
+	q.back = next
+	q.q[next].msg = msg
+	q.q[next].err = nil
+	q.cond.Broadcast()
+}
+
+func (q *msgQueue) close() {
+	q.cond.L.Lock()
+	defer q.cond.L.Unlock()
+	defer q.cond.Broadcast()
+	q.closed = true
+}
+
+// dropped records an error in the queue indicating that messages might have been dropped.
+func (q *msgQueue) dropped() {
+	q.cond.L.Lock()
+	defer q.cond.L.Unlock()
+
+	next := (q.back + 1) % messageBufferSize
+	if next == q.front {
+		// queue is full, but we need to record that messages are being dropped,
+		// which we do at the last message in the queue. This potentially drops
+		// another message, but it's more important for the subscriber to know.
+		q.q[q.back].msg = nil
+		q.q[q.back].err = ErrDroppedMessages
+		return
+	}
+	// queue is not full, insert the error
+	q.back = next
+	q.q[next].msg = nil
+	q.q[next].err = ErrDroppedMessages
+	q.cond.Broadcast()
+}
+
 // Pubsub implementation using PostgreSQL.
 type pgPubsub struct {
 	ctx        context.Context
 	pgListener *pq.Listener
 	db         *sql.DB
 	mut        sync.Mutex
-	listeners  map[string]map[uuid.UUID]chan<- []byte
+	queues     map[string]map[uuid.UUID]*msgQueue
 }
 
 // messageBufferSize is the maximum number of unhandled messages we will buffer
@@ -38,6 +172,14 @@ const messageBufferSize = 2048
 
 // Subscribe calls the listener when an event matching the name is received.
 func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
+	return p.subscribe(event, newMsgQueue(p.ctx, listener, nil))
+}
+
+func (p *pgPubsub) SubscribeWithErr(event string, listener ListenerWithErr) (cancel func(), err error) {
+	return p.subscribe(event, newMsgQueue(p.ctx, nil, listener))
+}
+
+func (p *pgPubsub) subscribe(event string, newQ *msgQueue) (cancel func(), err error) {
 	p.mut.Lock()
 	defer p.mut.Unlock()
 
@@ -50,23 +192,20 @@ func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), er
 		return nil, xerrors.Errorf("listen: %w", err)
 	}
 
-	var eventListeners map[uuid.UUID]chan<- []byte
+	var eventQs map[uuid.UUID]*msgQueue
 	var ok bool
-	if eventListeners, ok = p.listeners[event]; !ok {
-		eventListeners = make(map[uuid.UUID]chan<- []byte)
-		p.listeners[event] = eventListeners
+	if eventQs, ok = p.queues[event]; !ok {
+		eventQs = make(map[uuid.UUID]*msgQueue)
+		p.queues[event] = eventQs
 	}
-
-	ctx, cancelCallbacks := context.WithCancel(p.ctx)
-	messages := make(chan []byte, messageBufferSize)
-	go messagesToListener(ctx, messages, listener)
 	id := uuid.New()
-	eventListeners[id] = messages
+	eventQs[id] = newQ
 	return func() {
 		p.mut.Lock()
 		defer p.mut.Unlock()
-		cancelCallbacks()
-		listeners := p.listeners[event]
+		listeners := p.queues[event]
+		q := listeners[id]
+		q.close()
 		delete(listeners, id)
 
 		if len(listeners) == 0 {
@@ -109,7 +248,7 @@ func (p *pgPubsub) listen(ctx context.Context) {
 		}
 		// A nil notification can be dispatched on reconnect.
if notif == nil { - continue + p.recordReconnect() } p.listenReceive(notif) } @@ -118,19 +257,22 @@ func (p *pgPubsub) listen(ctx context.Context) { func (p *pgPubsub) listenReceive(notif *pq.Notification) { p.mut.Lock() defer p.mut.Unlock() - listeners, ok := p.listeners[notif.Channel] + queues, ok := p.queues[notif.Channel] if !ok { return } extra := []byte(notif.Extra) - for _, listener := range listeners { - select { - case listener <- extra: - // ok! - default: - // bad news, we dropped the event because the listener isn't - // keeping up - // TODO (spike): figure out a way to communicate this to the Listener + for _, q := range queues { + q.enqueue(extra) + } +} + +func (p *pgPubsub) recordReconnect() { + p.mut.Lock() + defer p.mut.Unlock() + for _, listeners := range p.queues { + for _, q := range listeners { + q.dropped() } } } @@ -162,20 +304,9 @@ func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub ctx: ctx, db: database, pgListener: listener, - listeners: make(map[string]map[uuid.UUID]chan<- []byte), + queues: make(map[string]map[uuid.UUID]*msgQueue), } go pgPubsub.listen(ctx) return pgPubsub, nil } - -func messagesToListener(ctx context.Context, messages <-chan []byte, listener Listener) { - for { - select { - case <-ctx.Done(): - return - case m := <-messages: - listener(ctx, m) - } - } -} diff --git a/coderd/database/pubsub_memory.go b/coderd/database/pubsub_memory.go index 4306ec10fb000..f24b34be3264b 100644 --- a/coderd/database/pubsub_memory.go +++ b/coderd/database/pubsub_memory.go @@ -7,20 +7,34 @@ import ( "github.com/google/uuid" ) +// genericListener is either a Listener or ListenerWithErr +type genericListener struct { + l Listener + le ListenerWithErr +} + // memoryPubsub is an in-memory Pubsub implementation. 
type memoryPubsub struct { mut sync.RWMutex - listeners map[string]map[uuid.UUID]Listener + listeners map[string]map[uuid.UUID]genericListener } func (m *memoryPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) { + return m.subscribe(event, genericListener{l: listener}) +} + +func (m *memoryPubsub) SubscribeWithErr(event string, listener ListenerWithErr) (cancel func(), err error) { + return m.subscribe(event, genericListener{le: listener}) +} + +func (m *memoryPubsub) subscribe(event string, listener genericListener) (cancel func(), err error) { m.mut.Lock() defer m.mut.Unlock() - var listeners map[uuid.UUID]Listener + var listeners map[uuid.UUID]genericListener var ok bool if listeners, ok = m.listeners[event]; !ok { - listeners = map[uuid.UUID]Listener{} + listeners = map[uuid.UUID]genericListener{} m.listeners[event] = listeners } var id uuid.UUID @@ -52,7 +66,12 @@ func (m *memoryPubsub) Publish(event string, message []byte) error { listener := listener go func() { defer wg.Done() - listener(context.Background(), message) + if listener.l != nil { + listener.l(context.Background(), message) + } + if listener.le != nil { + listener.le(context.Background(), message, nil) + } }() } wg.Wait() @@ -66,6 +85,6 @@ func (*memoryPubsub) Close() error { func NewPubsubInMemory() Pubsub { return &memoryPubsub{ - listeners: make(map[string]map[uuid.UUID]Listener), + listeners: make(map[string]map[uuid.UUID]genericListener), } } diff --git a/coderd/database/pubsub_memory_test.go b/coderd/database/pubsub_memory_test.go index 1bb23bc31b111..7856880d856c2 100644 --- a/coderd/database/pubsub_memory_test.go +++ b/coderd/database/pubsub_memory_test.go @@ -13,7 +13,7 @@ import ( func TestPubsubMemory(t *testing.T) { t.Parallel() - t.Run("Memory", func(t *testing.T) { + t.Run("Legacy", func(t *testing.T) { t.Parallel() pubsub := database.NewPubsubInMemory() @@ -32,4 +32,25 @@ func TestPubsubMemory(t *testing.T) { message := <-messageChannel assert.Equal(t, string(message), data) }) + + t.Run("WithErr", func(t *testing.T) { + t.Parallel() + + pubsub := database.NewPubsubInMemory() + event := "test" + data := "testing" + messageChannel := make(chan []byte) + cancelFunc, err := pubsub.SubscribeWithErr(event, func(ctx context.Context, message []byte, err error) { + assert.NoError(t, err) // memory pubsub never sends errors. + messageChannel <- message + }) + require.NoError(t, err) + defer cancelFunc() + go func() { + err = pubsub.Publish(event, []byte(data)) + assert.NoError(t, err) + }() + message := <-messageChannel + assert.Equal(t, string(message), data) + }) } From ede14157c94e679ead5c3105f50639169ccca613 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 24 May 2023 07:50:13 +0000 Subject: [PATCH 2/4] Pubsub with errors tests and fixes Signed-off-by: Spike Curtis --- coderd/database/postgres/postgres.go | 22 ++-- coderd/database/pubsub.go | 42 ++++--- coderd/database/pubsub_internal_test.go | 140 +++++++++++++++++++++ coderd/database/pubsub_memory.go | 22 ++-- coderd/database/pubsub_test.go | 155 +++++++++++++++++++++++- 5 files changed, 342 insertions(+), 39 deletions(-) create mode 100644 coderd/database/pubsub_internal_test.go diff --git a/coderd/database/postgres/postgres.go b/coderd/database/postgres/postgres.go index 16003d0de333d..5c1bae645a639 100644 --- a/coderd/database/postgres/postgres.go +++ b/coderd/database/postgres/postgres.go @@ -22,7 +22,8 @@ import ( // Super unlikely, but it happened. 
See: https://github.com/coder/coder/runs/5375197003
 var openPortMutex sync.Mutex
 
-// Open creates a new PostgreSQL server using a Docker container.
+// Open creates a new PostgreSQL database instance. With the DB_FROM environment variable set, it clones a database
+// from the provided template. With the environment variable unset, it creates a new Docker container running postgres.
 func Open() (string, func(), error) {
 	if os.Getenv("DB_FROM") != "" {
 		// In CI, creating a Docker container for each test is slow.
@@ -51,7 +52,12 @@ func Open() (string, func(), error) {
 		// so cleaning up the container will clean up the database.
 		}, nil
 	}
+	return OpenContainerized(0)
+}
 
+// OpenContainerized creates a new PostgreSQL server using a Docker container. If port is nonzero, host traffic
+// on that port is forwarded to the database. If port is zero, a free port is allocated from the OS.
+func OpenContainerized(port int) (string, func(), error) {
 	pool, err := dockertest.NewPool("")
 	if err != nil {
 		return "", nil, xerrors.Errorf("create pool: %w", err)
 	}
@@ -63,12 +69,14 @@ func Open() (string, func(), error) {
 	}
 
 	openPortMutex.Lock()
-	// Pick an explicit port on the host to connect to 5432.
-	// This is necessary so we can configure the port to only use ipv4.
-	port, err := getFreePort()
-	if err != nil {
-		openPortMutex.Unlock()
-		return "", nil, xerrors.Errorf("get free port: %w", err)
+	if port == 0 {
+		// Pick an explicit port on the host to connect to 5432.
+		// This is necessary so we can configure the port to only use ipv4.
+		port, err = getFreePort()
+		if err != nil {
+			openPortMutex.Unlock()
+			return "", nil, xerrors.Errorf("get free port: %w", err)
+		}
 	}
 
 	resource, err := pool.RunWithOptions(&dockertest.RunOptions{
diff --git a/coderd/database/pubsub.go b/coderd/database/pubsub.go
index 93c5ad50e8230..3a2be7d79cd58 100644
--- a/coderd/database/pubsub.go
+++ b/coderd/database/pubsub.go
@@ -48,9 +48,9 @@ type msgOrErr struct {
 type msgQueue struct {
 	ctx    context.Context
 	cond   *sync.Cond
-	q      [messageBufferSize]msgOrErr
+	q      [PubsubBufferSize]msgOrErr
 	front  int
-	back   int
+	size   int
 	closed bool
 	l      Listener
 	le     ListenerWithErr
@@ -74,7 +74,7 @@ func (q *msgQueue) run() {
 	for {
 		// wait until there is something on the queue or we are closed
 		q.cond.L.Lock()
-		for q.front == q.back && !q.closed {
+		for q.size == 0 && !q.closed {
 			q.cond.Wait()
 		}
 		if q.closed {
@@ -82,7 +82,8 @@ func (q *msgQueue) run() {
 			return
 		}
 		item := q.q[q.front]
-		q.front = (q.front + 1) % messageBufferSize
+		q.front = (q.front + 1) % PubsubBufferSize
+		q.size--
 		q.cond.L.Unlock()
 
 		// process item without holding lock
@@ -110,22 +111,23 @@ func (q *msgQueue) enqueue(msg []byte) {
 	q.cond.L.Lock()
 	defer q.cond.L.Unlock()
 
-	next := (q.back + 1) % messageBufferSize
-	if next == q.front {
+	if q.size == PubsubBufferSize {
 		// queue is full, so we're going to drop the msg we got called with.
 		// We also need to record that messages are being dropped, which we
 		// do at the last message in the queue. This potentially makes us
 		// lose 2 messages instead of one, but it's more important at this
 		// point to warn the subscriber that they're losing messages so they
 		// can do something about it.
- q.q[q.back].msg = nil - q.q[q.back].err = ErrDroppedMessages + back := (q.front + PubsubBufferSize - 1) % PubsubBufferSize + q.q[back].msg = nil + q.q[back].err = ErrDroppedMessages return } // queue is not full, insert the message - q.back = next + next := (q.front + q.size) % PubsubBufferSize q.q[next].msg = msg q.q[next].err = nil + q.size++ q.cond.Broadcast() } @@ -141,19 +143,20 @@ func (q *msgQueue) dropped() { q.cond.L.Lock() defer q.cond.L.Unlock() - next := (q.back + 1) % messageBufferSize - if next == q.front { + if q.size == PubsubBufferSize { // queue is full, but we need to record that messages are being dropped, // which we do at the last message in the queue. This potentially drops // another message, but it's more important for the subscriber to know. - q.q[q.back].msg = nil - q.q[q.back].err = ErrDroppedMessages + back := (q.front + PubsubBufferSize - 1) % PubsubBufferSize + q.q[back].msg = nil + q.q[back].err = ErrDroppedMessages return } // queue is not full, insert the error - q.back = next + next := (q.front + q.size) % PubsubBufferSize q.q[next].msg = nil q.q[next].err = ErrDroppedMessages + q.size++ q.cond.Broadcast() } @@ -166,20 +169,20 @@ type pgPubsub struct { queues map[string]map[uuid.UUID]*msgQueue } -// messageBufferSize is the maximum number of unhandled messages we will buffer +// PubsubBufferSize is the maximum number of unhandled messages we will buffer // for a subscriber before dropping messages. -const messageBufferSize = 2048 +const PubsubBufferSize = 2048 // Subscribe calls the listener when an event matching the name is received. func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) { - return p.subscribe(event, newMsgQueue(p.ctx, listener, nil)) + return p.subscribeQueue(event, newMsgQueue(p.ctx, listener, nil)) } func (p *pgPubsub) SubscribeWithErr(event string, listener ListenerWithErr) (cancel func(), err error) { - return p.subscribe(event, newMsgQueue(p.ctx, nil, listener)) + return p.subscribeQueue(event, newMsgQueue(p.ctx, nil, listener)) } -func (p *pgPubsub) subscribe(event string, newQ *msgQueue) (cancel func(), err error) { +func (p *pgPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(), err error) { p.mut.Lock() defer p.mut.Unlock() @@ -249,6 +252,7 @@ func (p *pgPubsub) listen(ctx context.Context) { // A nil notification can be dispatched on reconnect. if notif == nil { p.recordReconnect() + continue } p.listenReceive(notif) } diff --git a/coderd/database/pubsub_internal_test.go b/coderd/database/pubsub_internal_test.go new file mode 100644 index 0000000000000..31c50ce172176 --- /dev/null +++ b/coderd/database/pubsub_internal_test.go @@ -0,0 +1,140 @@ +package database + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/testutil" +) + +func Test_msgQueue_ListenerWithError(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + m := make(chan string) + e := make(chan error) + uut := newMsgQueue(ctx, nil, func(ctx context.Context, msg []byte, err error) { + m <- string(msg) + e <- err + }) + defer uut.close() + + // We're going to enqueue 4 messages and an error in a loop -- that is, a cycle of 5. + // PubsubBufferSize is 2048, which is a power of 2, so a pattern of 5 will not be aligned + // when we wrap around the end of the circular buffer. This tests that we correctly handle + // the wrapping and aren't dequeueing misaligned data. 
cycles := (PubsubBufferSize / 5) * 2 // almost twice around the ring
+	for j := 0; j < cycles; j++ {
+		for i := 0; i < 4; i++ {
+			uut.enqueue([]byte(fmt.Sprintf("%d%d", j, i)))
+		}
+		uut.dropped()
+		for i := 0; i < 4; i++ {
+			select {
+			case <-ctx.Done():
+				t.Fatal("timed out")
+			case msg := <-m:
+				require.Equal(t, fmt.Sprintf("%d%d", j, i), msg)
+			}
+			select {
+			case <-ctx.Done():
+				t.Fatal("timed out")
+			case err := <-e:
+				require.NoError(t, err)
+			}
+		}
+		select {
+		case <-ctx.Done():
+			t.Fatal("timed out")
+		case msg := <-m:
+			require.Equal(t, "", msg)
+		}
+		select {
+		case <-ctx.Done():
+			t.Fatal("timed out")
+		case err := <-e:
+			require.ErrorIs(t, err, ErrDroppedMessages)
+		}
+	}
+}
+
+func Test_msgQueue_Listener(t *testing.T) {
+	t.Parallel()
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
+	defer cancel()
+	m := make(chan string)
+	uut := newMsgQueue(ctx, func(ctx context.Context, msg []byte) {
+		m <- string(msg)
+	}, nil)
+	defer uut.close()
+
+	// We're going to enqueue 4 messages and an error in a loop -- that is, a cycle of 5.
+	// PubsubBufferSize is 2048, which is a power of 2, so a pattern of 5 will not be aligned
+	// when we wrap around the end of the circular buffer. This tests that we correctly handle
+	// the wrapping and aren't dequeueing misaligned data.
+	cycles := (PubsubBufferSize / 5) * 2 // almost twice around the ring
+	for j := 0; j < cycles; j++ {
+		for i := 0; i < 4; i++ {
+			uut.enqueue([]byte(fmt.Sprintf("%d%d", j, i)))
+		}
+		uut.dropped()
+		for i := 0; i < 4; i++ {
+			select {
+			case <-ctx.Done():
+				t.Fatal("timed out")
+			case msg := <-m:
+				require.Equal(t, fmt.Sprintf("%d%d", j, i), msg)
+			}
+		}
+		// Listener skips over errors, so we only read out the 4 real messages.
+	}
+}
+
+func Test_msgQueue_Full(t *testing.T) {
+	t.Parallel()
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
+	defer cancel()
+
+	firstDequeue := make(chan struct{})
+	allowRead := make(chan struct{})
+	n := 0
+	errors := make(chan error)
+	uut := newMsgQueue(ctx, nil, func(ctx context.Context, msg []byte, err error) {
+		if n == 0 {
+			close(firstDequeue)
+		}
+		<-allowRead
+		if err == nil {
+			require.Equal(t, fmt.Sprintf("%d", n), string(msg))
+			n++
+			return
+		}
+		errors <- err
+	})
+	defer uut.close()
+
+	// We send 2 more than the capacity. One extra because the call to the listener blocks, but
+	// only after we've dequeued a message, and another extra because we want to exceed the
+	// capacity, not just reach it.
+	for i := 0; i < PubsubBufferSize+2; i++ {
+		uut.enqueue([]byte(fmt.Sprintf("%d", i)))
+		// ensure the first dequeue has happened before proceeding, so that this function isn't racing
+		// against the goroutine that dequeues items.
+		<-firstDequeue
+	}
+	close(allowRead)
+
+	select {
+	case <-ctx.Done():
+		t.Fatal("timed out")
+	case err := <-errors:
+		require.ErrorIs(t, err, ErrDroppedMessages)
+	}
+	// We sent 2 more than the capacity but only read the capacity's worth: the last message
+	// we sent doesn't get queued, and it also bumps a message out of the queue to make room
+	// for the error, so we read 2 fewer than we sent.
+ require.Equal(t, PubsubBufferSize, n) +} diff --git a/coderd/database/pubsub_memory.go b/coderd/database/pubsub_memory.go index f24b34be3264b..0ab4684c80a3f 100644 --- a/coderd/database/pubsub_memory.go +++ b/coderd/database/pubsub_memory.go @@ -13,6 +13,15 @@ type genericListener struct { le ListenerWithErr } +func (g genericListener) send(ctx context.Context, message []byte) { + if g.l != nil { + g.l(ctx, message) + } + if g.le != nil { + g.le(ctx, message, nil) + } +} + // memoryPubsub is an in-memory Pubsub implementation. type memoryPubsub struct { mut sync.RWMutex @@ -20,14 +29,14 @@ type memoryPubsub struct { } func (m *memoryPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) { - return m.subscribe(event, genericListener{l: listener}) + return m.subscribeGeneric(event, genericListener{l: listener}) } func (m *memoryPubsub) SubscribeWithErr(event string, listener ListenerWithErr) (cancel func(), err error) { - return m.subscribe(event, genericListener{le: listener}) + return m.subscribeGeneric(event, genericListener{le: listener}) } -func (m *memoryPubsub) subscribe(event string, listener genericListener) (cancel func(), err error) { +func (m *memoryPubsub) subscribeGeneric(event string, listener genericListener) (cancel func(), err error) { m.mut.Lock() defer m.mut.Unlock() @@ -66,12 +75,7 @@ func (m *memoryPubsub) Publish(event string, message []byte) error { listener := listener go func() { defer wg.Done() - if listener.l != nil { - listener.l(context.Background(), message) - } - if listener.le != nil { - listener.le(context.Background(), message, nil) - } + listener.send(context.Background(), message) }() } wg.Wait() diff --git a/coderd/database/pubsub_test.go b/coderd/database/pubsub_test.go index 0abd0c8ca1177..64109b2af1fff 100644 --- a/coderd/database/pubsub_test.go +++ b/coderd/database/pubsub_test.go @@ -7,18 +7,26 @@ import ( "database/sql" "fmt" "math/rand" + "net" + "net/url" + "strconv" "testing" "time" - "github.com/coder/coder/testutil" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/goleak" + "golang.org/x/xerrors" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/postgres" + "github.com/coder/coder/testutil" ) +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + // nolint:tparallel,paralleltest func TestPubsub(t *testing.T) { t.Parallel() @@ -90,7 +98,7 @@ func TestPubsub_ordering(t *testing.T) { defer pubsub.Close() event := "test" messageChannel := make(chan []byte, 100) - cancelFunc, err = pubsub.Subscribe(event, func(ctx context.Context, message []byte) { + cancelSub, err := pubsub.Subscribe(event, func(ctx context.Context, message []byte) { // sleep a random amount of time to simulate handlers taking different amount of time // to process, depending on the message // nolint: gosec @@ -99,7 +107,7 @@ func TestPubsub_ordering(t *testing.T) { messageChannel <- message }) require.NoError(t, err) - defer cancelFunc() + defer cancelSub() for i := 0; i < 100; i++ { err = pubsub.Publish(event, []byte(fmt.Sprintf("%d", i))) assert.NoError(t, err) @@ -113,3 +121,142 @@ func TestPubsub_ordering(t *testing.T) { } } } + +func TestPubsub_Disconnect(t *testing.T) { + t.Parallel() + // we always use a Docker container for this test, even in CI, since we need to be able to kill + // postgres and bring it back on the same port. 
+ connectionURL, closePg, err := postgres.OpenContainerized(0) + require.NoError(t, err) + defer closePg() + db, err := sql.Open("postgres", connectionURL) + require.NoError(t, err) + defer db.Close() + + ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + defer cancelFunc() + pubsub, err := database.NewPubsub(ctx, db, connectionURL) + require.NoError(t, err) + defer pubsub.Close() + event := "test" + + errors := make(chan error) + messages := make(chan string) + readOne := func() (m string, e error) { + t.Helper() + select { + case <-ctx.Done(): + t.Fatal("timed out") + case m = <-messages: + // OK + } + select { + case <-ctx.Done(): + t.Fatal("timed out") + case e = <-errors: + // OK + } + return m, e + } + + cancelSub, err := pubsub.SubscribeWithErr(event, func(ctx context.Context, msg []byte, err error) { + messages <- string(msg) + errors <- err + }) + require.NoError(t, err) + defer cancelSub() + + for i := 0; i < 100; i++ { + err = pubsub.Publish(event, []byte(fmt.Sprintf("%d", i))) + require.NoError(t, err) + } + // make sure we're getting at least one message. + m, err := readOne() + require.NoError(t, err) + require.Equal(t, "0", m) + + closePg() + // write some more messages until we hit an error + j := 100 + for { + select { + case <-ctx.Done(): + t.Fatal("timed out") + default: + // ok + } + err = pubsub.Publish(event, []byte(fmt.Sprintf("%d", j))) + j++ + if err != nil { + break + } + time.Sleep(testutil.IntervalFast) + } + + // restart postgres on the same port --- since we only use LISTEN/NOTIFY it doesn't + // matter that the new postgres doesn't have any persisted state from before. + u, err := url.Parse(connectionURL) + require.NoError(t, err) + addr, err := net.ResolveTCPAddr("tcp", u.Host) + require.NoError(t, err) + newURL, closeNewPg, err := postgres.OpenContainerized(addr.Port) + require.NoError(t, err) + require.Equal(t, connectionURL, newURL) + defer closeNewPg() + + // now write messages until we DON'T hit an error -- pubsub is back up. + for { + select { + case <-ctx.Done(): + t.Fatal("timed out") + default: + // ok + } + err = pubsub.Publish(event, []byte(fmt.Sprintf("%d", j))) + if err == nil { + break + } + j++ + time.Sleep(testutil.IntervalFast) + } + // any message k or higher comes from after the restart. + k := j + // exceeding the buffer invalidates the test because this causes us to drop messages for reasons other than DB + // reconnect + require.Less(t, k, database.PubsubBufferSize, "exceeded buffer") + + // We don't know how quickly the pubsub will reconnect, so continue to send messages with increasing numbers. As + // soon as we see k or higher we know we're getting messages after the restart. 
+ go func() { + for { + select { + case <-ctx.Done(): + return + default: + // ok + } + _ = pubsub.Publish(event, []byte(fmt.Sprintf("%d", j))) + j++ + time.Sleep(testutil.IntervalFast) + } + }() + + gotDroppedErr := false + for { + m, err := readOne() + if xerrors.Is(err, database.ErrDroppedMessages) { + gotDroppedErr = true + continue + } + require.NoError(t, err, "should only get ErrDroppedMessages") + l, err := strconv.Atoi(m) + require.NoError(t, err) + if l >= k { + // exceeding the buffer invalidates the test because this causes us to drop messages for reasons other than + // DB reconnect + require.Less(t, l, database.PubsubBufferSize, "exceeded buffer") + break + } + } + require.True(t, gotDroppedErr) +} From bbc8c4af9d9b9738017e27dba066c9bfd062556a Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Thu, 25 May 2023 05:50:30 +0000 Subject: [PATCH 3/4] Deal with test goroutines Signed-off-by: Spike Curtis --- coderd/database/pubsub_test.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/coderd/database/pubsub_test.go b/coderd/database/pubsub_test.go index 64109b2af1fff..9fe8e340bc740 100644 --- a/coderd/database/pubsub_test.go +++ b/coderd/database/pubsub_test.go @@ -140,8 +140,9 @@ func TestPubsub_Disconnect(t *testing.T) { defer pubsub.Close() event := "test" - errors := make(chan error) - messages := make(chan string) + // buffer responses so that when the test completes, goroutines don't get blocked & leak + errors := make(chan error, database.PubsubBufferSize) + messages := make(chan string, database.PubsubBufferSize) readOne := func() (m string, e error) { t.Helper() select { From 8e8d5c51ba58e41386ca71a03b0e18a93d2dab2e Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Thu, 25 May 2023 06:12:34 +0000 Subject: [PATCH 4/4] remove goleak -- lib/pq uses sleeps on notify :-( Signed-off-by: Spike Curtis --- coderd/database/pubsub_test.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/coderd/database/pubsub_test.go b/coderd/database/pubsub_test.go index 9fe8e340bc740..d1241492ff00f 100644 --- a/coderd/database/pubsub_test.go +++ b/coderd/database/pubsub_test.go @@ -15,7 +15,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.uber.org/goleak" "golang.org/x/xerrors" "github.com/coder/coder/coderd/database" @@ -23,10 +22,6 @@ import ( "github.com/coder/coder/testutil" ) -func TestMain(m *testing.M) { - goleak.VerifyTestMain(m) -} - // nolint:tparallel,paralleltest func TestPubsub(t *testing.T) { t.Parallel()
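
Usage sketch (not part of the patches above): a minimal, hypothetical consumer of the new
SubscribeWithErr API introduced by this series. It assumes a reachable PostgreSQL instance;
the connection URL and the "example_event" channel name are illustrative stand-ins, not
values taken from the patches. It shows the recovery pattern the series enables: on
ErrDroppedMessages, resynchronize from the database instead of trusting the event stream.
This is the same behavior TestPubsub_Disconnect exercises end-to-end.

package main

import (
	"context"
	"database/sql"
	"log"
	"time"

	_ "github.com/lib/pq" // postgres driver, as used by coderd
	"golang.org/x/xerrors"

	"github.com/coder/coder/coderd/database"
)

func main() {
	ctx := context.Background()

	// Hypothetical DSN; substitute a real database.
	connectURL := "postgres://coder:postgres@localhost:5432/coder?sslmode=disable"
	db, err := sql.Open("postgres", connectURL)
	if err != nil {
		log.Fatalf("open db: %v", err)
	}
	defer db.Close()

	ps, err := database.NewPubsub(ctx, db, connectURL)
	if err != nil {
		log.Fatalf("create pubsub: %v", err)
	}
	defer ps.Close()

	// Unlike Subscribe, SubscribeWithErr also delivers ErrDroppedMessages when
	// the subscriber falls more than PubsubBufferSize messages behind, or when
	// the underlying LISTEN connection reconnects and may have missed NOTIFYs.
	cancel, err := ps.SubscribeWithErr("example_event", func(_ context.Context, msg []byte, err error) {
		if xerrors.Is(err, database.ErrDroppedMessages) {
			// Events may have been lost; refetch authoritative state from
			// the database here rather than relying on the stream.
			log.Print("dropped messages; resynchronizing")
			return
		}
		log.Printf("event payload: %s", msg)
	})
	if err != nil {
		log.Fatalf("subscribe: %v", err)
	}
	defer cancel()

	if err := ps.Publish("example_event", []byte("hello")); err != nil {
		log.Fatalf("publish: %v", err)
	}
	// Give the listener goroutine a moment to process before exiting.
	time.Sleep(time.Second)
}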