From ff526cc5a8812fd766ac09346825b048e7544ef6 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Thu, 1 Feb 2024 09:12:19 +0400 Subject: [PATCH] feat: add metrics to PGPubsub --- cli/server.go | 6 +- coderd/database/pubsub/pubsub.go | 192 ++++++++-- coderd/database/pubsub/pubsub_linux_test.go | 350 +++++++++++++++++ coderd/database/pubsub/pubsub_test.go | 394 +++++--------------- 4 files changed, 615 insertions(+), 327 deletions(-) create mode 100644 coderd/database/pubsub/pubsub_linux_test.go diff --git a/cli/server.go b/cli/server.go index 1df5f49855909..fe53c5a09309b 100644 --- a/cli/server.go +++ b/cli/server.go @@ -673,10 +673,14 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. }() options.Database = database.New(sqlDB) - options.Pubsub, err = pubsub.New(ctx, logger.Named("pubsub"), sqlDB, dbURL) + ps, err := pubsub.New(ctx, logger.Named("pubsub"), sqlDB, dbURL) if err != nil { return xerrors.Errorf("create pubsub: %w", err) } + options.Pubsub = ps + if options.DeploymentValues.Prometheus.Enable { + options.PrometheusRegistry.MustRegister(ps) + } defer options.Pubsub.Close() } diff --git a/coderd/database/pubsub/pubsub.go b/coderd/database/pubsub/pubsub.go index d70b5f5f9ce9a..6bab8d279bdc1 100644 --- a/coderd/database/pubsub/pubsub.go +++ b/coderd/database/pubsub/pubsub.go @@ -9,6 +9,7 @@ import ( "github.com/google/uuid" "github.com/lib/pq" + "github.com/prometheus/client_golang/prometheus" "golang.org/x/xerrors" "cdr.dev/slog" @@ -162,8 +163,8 @@ func (q *msgQueue) dropped() { q.cond.Broadcast() } -// Pubsub implementation using PostgreSQL. -type pgPubsub struct { +// PGPubsub is a pubsub implementation using PostgreSQL. +type PGPubsub struct { ctx context.Context cancel context.CancelFunc logger slog.Logger @@ -174,6 +175,14 @@ type pgPubsub struct { queues map[string]map[uuid.UUID]*msgQueue closedListener bool closeListenerErr error + + publishesTotal *prometheus.CounterVec + subscribesTotal *prometheus.CounterVec + messagesTotal *prometheus.CounterVec + publishedBytesTotal prometheus.Counter + receivedBytesTotal prometheus.Counter + disconnectionsTotal prometheus.Counter + connected prometheus.Gauge } // BufferSize is the maximum number of unhandled messages we will buffer @@ -181,15 +190,15 @@ type pgPubsub struct { const BufferSize = 2048 // Subscribe calls the listener when an event matching the name is received. -func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) { +func (p *PGPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) { return p.subscribeQueue(event, newMsgQueue(p.ctx, listener, nil)) } -func (p *pgPubsub) SubscribeWithErr(event string, listener ListenerWithErr) (cancel func(), err error) { +func (p *PGPubsub) SubscribeWithErr(event string, listener ListenerWithErr) (cancel func(), err error) { return p.subscribeQueue(event, newMsgQueue(p.ctx, nil, listener)) } -func (p *pgPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(), err error) { +func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(), err error) { p.mut.Lock() defer p.mut.Unlock() defer func() { @@ -197,6 +206,9 @@ func (p *pgPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(), // if we hit an error, we need to close the queue so we don't // leak its goroutine. 
newQ.close() + p.subscribesTotal.WithLabelValues("false").Inc() + } else { + p.subscribesTotal.WithLabelValues("true").Inc() } }() @@ -239,20 +251,23 @@ func (p *pgPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(), }, nil } -func (p *pgPubsub) Publish(event string, message []byte) error { +func (p *PGPubsub) Publish(event string, message []byte) error { p.logger.Debug(p.ctx, "publish", slog.F("event", event), slog.F("message_len", len(message))) // This is safe because we are calling pq.QuoteLiteral. pg_notify doesn't // support the first parameter being a prepared statement. //nolint:gosec _, err := p.db.ExecContext(p.ctx, `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message) if err != nil { + p.publishesTotal.WithLabelValues("false").Inc() return xerrors.Errorf("exec pg_notify: %w", err) } + p.publishesTotal.WithLabelValues("true").Inc() + p.publishedBytesTotal.Add(float64(len(message))) return nil } // Close closes the pubsub instance. -func (p *pgPubsub) Close() error { +func (p *PGPubsub) Close() error { p.logger.Info(p.ctx, "pubsub is closing") p.cancel() err := p.closeListener() @@ -262,7 +277,7 @@ func (p *pgPubsub) Close() error { } // closeListener closes the pgListener, unless it has already been closed. -func (p *pgPubsub) closeListener() error { +func (p *PGPubsub) closeListener() error { p.mut.Lock() defer p.mut.Unlock() if p.closedListener { @@ -274,7 +289,7 @@ func (p *pgPubsub) closeListener() error { } // listen begins receiving messages on the pq listener. -func (p *pgPubsub) listen() { +func (p *PGPubsub) listen() { defer func() { p.logger.Info(p.ctx, "pubsub listen stopped receiving notify") cErr := p.closeListener() @@ -307,7 +322,14 @@ func (p *pgPubsub) listen() { } } -func (p *pgPubsub) listenReceive(notif *pq.Notification) { +func (p *PGPubsub) listenReceive(notif *pq.Notification) { + sizeLabel := messageSizeNormal + if len(notif.Extra) >= colossalThreshold { + sizeLabel = messageSizeColossal + } + p.messagesTotal.WithLabelValues(sizeLabel).Inc() + p.receivedBytesTotal.Add(float64(len(notif.Extra))) + p.mut.Lock() defer p.mut.Unlock() queues, ok := p.queues[notif.Channel] @@ -320,7 +342,7 @@ func (p *pgPubsub) listenReceive(notif *pq.Notification) { } } -func (p *pgPubsub) recordReconnect() { +func (p *PGPubsub) recordReconnect() { p.mut.Lock() defer p.mut.Unlock() for _, listeners := range p.queues { @@ -330,20 +352,23 @@ func (p *pgPubsub) recordReconnect() { } } -// New creates a new Pubsub implementation using a PostgreSQL connection. -func New(ctx context.Context, logger slog.Logger, database *sql.DB, connectURL string) (Pubsub, error) { +func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error { + p.connected.Set(0) // Creates a new listener using pq. 
errCh := make(chan error) - listener := pq.NewListener(connectURL, time.Second, time.Minute, func(t pq.ListenerEventType, err error) { + p.pgListener = pq.NewListener(connectURL, time.Second, time.Minute, func(t pq.ListenerEventType, err error) { switch t { case pq.ListenerEventConnected: - logger.Info(ctx, "pubsub connected to postgres") + p.logger.Info(ctx, "pubsub connected to postgres") + p.connected.Set(1.0) case pq.ListenerEventDisconnected: - logger.Error(ctx, "pubsub disconnected from postgres", slog.Error(err)) + p.logger.Error(ctx, "pubsub disconnected from postgres", slog.Error(err)) + p.connected.Set(0) case pq.ListenerEventReconnected: - logger.Info(ctx, "pubsub reconnected to postgres") + p.logger.Info(ctx, "pubsub reconnected to postgres") + p.connected.Set(1) case pq.ListenerEventConnectionAttemptFailed: - logger.Error(ctx, "pubsub failed to connect to postgres", slog.Error(err)) + p.logger.Error(ctx, "pubsub failed to connect to postgres", slog.Error(err)) } // This callback gets events whenever the connection state changes. // Don't send if the errChannel has already been closed. @@ -358,26 +383,141 @@ func New(ctx context.Context, logger slog.Logger, database *sql.DB, connectURL s select { case err := <-errCh: if err != nil { - _ = listener.Close() - return nil, xerrors.Errorf("create pq listener: %w", err) + _ = p.pgListener.Close() + return xerrors.Errorf("create pq listener: %w", err) } case <-ctx.Done(): - _ = listener.Close() - return nil, ctx.Err() + _ = p.pgListener.Close() + return ctx.Err() } + return nil +} +// these are the metrics we compute implicitly from our existing data structures +var ( + currentSubscribersDesc = prometheus.NewDesc( + "coder_pubsub_current_subscribers", + "The current number of active pubsub subscribers", + nil, nil, + ) + currentEventsDesc = prometheus.NewDesc( + "coder_pubsub_current_events", + "The current number of pubsub event channels listened for", + nil, nil, + ) +) + +// We'll track messages as size "normal" and "colossal", where the +// latter are messages larger than 7600 bytes, or 95% of the postgres +// notify limit. If we see a lot of colossal packets that's an indication that +// we might be trying to send too much data over the pubsub and are in danger of +// failing to publish. +const ( + colossalThreshold = 7600 + messageSizeNormal = "normal" + messageSizeColossal = "colossal" +) + +// Describe implements, along with Collect, the prometheus.Collector interface +// for metrics. 
+func (p *PGPubsub) Describe(descs chan<- *prometheus.Desc) { + // explicit metrics + p.publishesTotal.Describe(descs) + p.subscribesTotal.Describe(descs) + p.messagesTotal.Describe(descs) + p.publishedBytesTotal.Describe(descs) + p.receivedBytesTotal.Describe(descs) + p.disconnectionsTotal.Describe(descs) + p.connected.Describe(descs) + + // implicit metrics + descs <- currentSubscribersDesc + descs <- currentEventsDesc +} + +// Collect implements, along with Describe, the prometheus.Collector interface +// for metrics +func (p *PGPubsub) Collect(metrics chan<- prometheus.Metric) { + // explicit metrics + p.publishesTotal.Collect(metrics) + p.subscribesTotal.Collect(metrics) + p.messagesTotal.Collect(metrics) + p.publishedBytesTotal.Collect(metrics) + p.receivedBytesTotal.Collect(metrics) + p.disconnectionsTotal.Collect(metrics) + p.connected.Collect(metrics) + + // implicit metrics + p.mut.Lock() + events := len(p.queues) + subs := 0 + for _, subscriberMap := range p.queues { + subs += len(subscriberMap) + } + p.mut.Unlock() + metrics <- prometheus.MustNewConstMetric(currentSubscribersDesc, prometheus.GaugeValue, float64(subs)) + metrics <- prometheus.MustNewConstMetric(currentEventsDesc, prometheus.GaugeValue, float64(events)) +} + +// New creates a new Pubsub implementation using a PostgreSQL connection. +func New(startCtx context.Context, logger slog.Logger, database *sql.DB, connectURL string) (*PGPubsub, error) { // Start a new context that will be canceled when the pubsub is closed. ctx, cancel := context.WithCancel(context.Background()) - pgPubsub := &pgPubsub{ + p := &PGPubsub{ ctx: ctx, cancel: cancel, logger: logger, listenDone: make(chan struct{}), db: database, - pgListener: listener, queues: make(map[string]map[uuid.UUID]*msgQueue), + + publishesTotal: prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: "coder", + Subsystem: "pubsub", + Name: "publishes_total", + Help: "Total number of calls to Publish", + }, []string{"success"}), + subscribesTotal: prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: "coder", + Subsystem: "pubsub", + Name: "subscribes_total", + Help: "Total number of calls to Subscribe/SubscribeWithErr", + }, []string{"success"}), + messagesTotal: prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: "coder", + Subsystem: "pubsub", + Name: "messages_total", + Help: "Total number of messages received from postgres", + }, []string{"size"}), + publishedBytesTotal: prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "coder", + Subsystem: "pubsub", + Name: "published_bytes_total", + Help: "Total number of bytes successfully published across all publishes", + }), + receivedBytesTotal: prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "coder", + Subsystem: "pubsub", + Name: "received_bytes_total", + Help: "Total number of bytes received across all messages", + }), + disconnectionsTotal: prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "coder", + Subsystem: "pubsub", + Name: "disconnections_total", + Help: "Total number of times we disconnected unexpectedly from postgres", + }), + connected: prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: "coder", + Subsystem: "pubsub", + Name: "connected", + Help: "Whether we are connected (1) or not connected (0) to postgres", + }), + } + if err := p.startListener(startCtx, connectURL); err != nil { + return nil, err } - go pgPubsub.listen() + go p.listen() logger.Info(ctx, "pubsub has started") - return pgPubsub, nil + return p, nil } diff --git 
a/coderd/database/pubsub/pubsub_linux_test.go b/coderd/database/pubsub/pubsub_linux_test.go new file mode 100644 index 0000000000000..c25af429a5d78 --- /dev/null +++ b/coderd/database/pubsub/pubsub_linux_test.go @@ -0,0 +1,350 @@ +//go:build linux + +package pubsub_test + +import ( + "context" + "database/sql" + "fmt" + "math/rand" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database/postgres" + "github.com/coder/coder/v2/coderd/database/pubsub" + "github.com/coder/coder/v2/testutil" +) + +// nolint:tparallel,paralleltest +func TestPubsub(t *testing.T) { + t.Parallel() + + if testing.Short() { + t.SkipNow() + return + } + + t.Run("Postgres", func(t *testing.T) { + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + + connectionURL, closePg, err := postgres.Open() + require.NoError(t, err) + defer closePg() + db, err := sql.Open("postgres", connectionURL) + require.NoError(t, err) + defer db.Close() + pubsub, err := pubsub.New(ctx, logger, db, connectionURL) + require.NoError(t, err) + defer pubsub.Close() + event := "test" + data := "testing" + messageChannel := make(chan []byte) + unsub, err := pubsub.Subscribe(event, func(ctx context.Context, message []byte) { + messageChannel <- message + }) + require.NoError(t, err) + defer unsub() + go func() { + err = pubsub.Publish(event, []byte(data)) + assert.NoError(t, err) + }() + message := <-messageChannel + assert.Equal(t, string(message), data) + }) + + t.Run("PostgresCloseCancel", func(t *testing.T) { + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + connectionURL, closePg, err := postgres.Open() + require.NoError(t, err) + defer closePg() + db, err := sql.Open("postgres", connectionURL) + require.NoError(t, err) + defer db.Close() + pubsub, err := pubsub.New(ctx, logger, db, connectionURL) + require.NoError(t, err) + defer pubsub.Close() + cancelFunc() + }) + + t.Run("NotClosedOnCancelContext", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + connectionURL, closePg, err := postgres.Open() + require.NoError(t, err) + defer closePg() + db, err := sql.Open("postgres", connectionURL) + require.NoError(t, err) + defer db.Close() + pubsub, err := pubsub.New(ctx, logger, db, connectionURL) + require.NoError(t, err) + defer pubsub.Close() + + // Provided context must only be active during NewPubsub, not after. 
+ cancel() + + event := "test" + data := "testing" + messageChannel := make(chan []byte) + unsub, err := pubsub.Subscribe(event, func(_ context.Context, message []byte) { + messageChannel <- message + }) + require.NoError(t, err) + defer unsub() + go func() { + err = pubsub.Publish(event, []byte(data)) + assert.NoError(t, err) + }() + message := <-messageChannel + assert.Equal(t, string(message), data) + }) + + t.Run("ClosePropagatesContextCancellationToSubscription", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + connectionURL, closePg, err := postgres.Open() + require.NoError(t, err) + defer closePg() + db, err := sql.Open("postgres", connectionURL) + require.NoError(t, err) + defer db.Close() + pubsub, err := pubsub.New(ctx, logger, db, connectionURL) + require.NoError(t, err) + defer pubsub.Close() + + event := "test" + done := make(chan struct{}) + called := make(chan struct{}) + unsub, err := pubsub.Subscribe(event, func(subCtx context.Context, _ []byte) { + defer close(done) + select { + case <-subCtx.Done(): + assert.Fail(t, "context should not be canceled") + default: + } + close(called) + select { + case <-subCtx.Done(): + case <-ctx.Done(): + assert.Fail(t, "timeout waiting for sub context to be canceled") + } + }) + require.NoError(t, err) + defer unsub() + + go func() { + err := pubsub.Publish(event, nil) + assert.NoError(t, err) + }() + + select { + case <-called: + case <-ctx.Done(): + require.Fail(t, "timeout waiting for handler to be called") + } + err = pubsub.Close() + require.NoError(t, err) + + select { + case <-done: + case <-ctx.Done(): + require.Fail(t, "timeout waiting for handler to finish") + } + }) +} + +func TestPubsub_ordering(t *testing.T) { + t.Parallel() + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + + connectionURL, closePg, err := postgres.Open() + require.NoError(t, err) + defer closePg() + db, err := sql.Open("postgres", connectionURL) + require.NoError(t, err) + defer db.Close() + ps, err := pubsub.New(ctx, logger, db, connectionURL) + require.NoError(t, err) + defer ps.Close() + event := "test" + messageChannel := make(chan []byte, 100) + cancelSub, err := ps.Subscribe(event, func(ctx context.Context, message []byte) { + // sleep a random amount of time to simulate handlers taking different amount of time + // to process, depending on the message + // nolint: gosec + n := rand.Intn(100) + time.Sleep(time.Duration(n) * time.Millisecond) + messageChannel <- message + }) + require.NoError(t, err) + defer cancelSub() + for i := 0; i < 100; i++ { + err = ps.Publish(event, []byte(fmt.Sprintf("%d", i))) + assert.NoError(t, err) + } + for i := 0; i < 100; i++ { + select { + case <-time.After(testutil.WaitShort): + t.Fatalf("timed out waiting for message %d", i) + case message := <-messageChannel: + assert.Equal(t, fmt.Sprintf("%d", i), string(message)) + } + } +} + +// disconnectTestPort is the hardcoded port for TestPubsub_Disconnect. In this test we need to be able to stop Postgres +// and restart it on the same port. If we use an ephemeral port, there is a chance the OS will reallocate before we +// start back up. The downside is that if the test crashes and leaves the container up, subsequent test runs will fail +// until we manually kill the container. 
+const disconnectTestPort = 26892 + +// nolint: paralleltest +func TestPubsub_Disconnect(t *testing.T) { + // we always use a Docker container for this test, even in CI, since we need to be able to kill + // postgres and bring it back on the same port. + connectionURL, closePg, err := postgres.OpenContainerized(disconnectTestPort) + require.NoError(t, err) + defer closePg() + db, err := sql.Open("postgres", connectionURL) + require.NoError(t, err) + defer db.Close() + + ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + defer cancelFunc() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + ps, err := pubsub.New(ctx, logger, db, connectionURL) + require.NoError(t, err) + defer ps.Close() + event := "test" + + // buffer responses so that when the test completes, goroutines don't get blocked & leak + errors := make(chan error, pubsub.BufferSize) + messages := make(chan string, pubsub.BufferSize) + readOne := func() (m string, e error) { + t.Helper() + select { + case <-ctx.Done(): + t.Fatal("timed out") + case m = <-messages: + // OK + } + select { + case <-ctx.Done(): + t.Fatal("timed out") + case e = <-errors: + // OK + } + return m, e + } + + cancelSub, err := ps.SubscribeWithErr(event, func(ctx context.Context, msg []byte, err error) { + messages <- string(msg) + errors <- err + }) + require.NoError(t, err) + defer cancelSub() + + for i := 0; i < 100; i++ { + err = ps.Publish(event, []byte(fmt.Sprintf("%d", i))) + require.NoError(t, err) + } + // make sure we're getting at least one message. + m, err := readOne() + require.NoError(t, err) + require.Equal(t, "0", m) + + closePg() + // write some more messages until we hit an error + j := 100 + for { + select { + case <-ctx.Done(): + t.Fatal("timed out") + default: + // ok + } + err = ps.Publish(event, []byte(fmt.Sprintf("%d", j))) + j++ + if err != nil { + break + } + time.Sleep(testutil.IntervalFast) + } + + // restart postgres on the same port --- since we only use LISTEN/NOTIFY it doesn't + // matter that the new postgres doesn't have any persisted state from before. + _, closeNewPg, err := postgres.OpenContainerized(disconnectTestPort) + require.NoError(t, err) + defer closeNewPg() + + // now write messages until we DON'T hit an error -- pubsub is back up. + for { + select { + case <-ctx.Done(): + t.Fatal("timed out") + default: + // ok + } + err = ps.Publish(event, []byte(fmt.Sprintf("%d", j))) + if err == nil { + break + } + j++ + time.Sleep(testutil.IntervalFast) + } + // any message k or higher comes from after the restart. + k := j + // exceeding the buffer invalidates the test because this causes us to drop messages for reasons other than DB + // reconnect + require.Less(t, k, pubsub.BufferSize, "exceeded buffer") + + // We don't know how quickly the pubsub will reconnect, so continue to send messages with increasing numbers. As + // soon as we see k or higher we know we're getting messages after the restart. 
+ go func() { + for { + select { + case <-ctx.Done(): + return + default: + // ok + } + _ = ps.Publish(event, []byte(fmt.Sprintf("%d", j))) + j++ + time.Sleep(testutil.IntervalFast) + } + }() + + gotDroppedErr := false + for { + m, err := readOne() + if xerrors.Is(err, pubsub.ErrDroppedMessages) { + gotDroppedErr = true + continue + } + require.NoError(t, err, "should only get ErrDroppedMessages") + l, err := strconv.Atoi(m) + require.NoError(t, err) + if l >= k { + // exceeding the buffer invalidates the test because this causes us to drop messages for reasons other than + // DB reconnect + require.Less(t, l, pubsub.BufferSize, "exceeded buffer") + break + } + } + require.True(t, gotDroppedErr) +} diff --git a/coderd/database/pubsub/pubsub_test.go b/coderd/database/pubsub/pubsub_test.go index c25af429a5d78..204d7f55a1c68 100644 --- a/coderd/database/pubsub/pubsub_test.go +++ b/coderd/database/pubsub/pubsub_test.go @@ -1,350 +1,144 @@ -//go:build linux - package pubsub_test import ( "context" "database/sql" - "fmt" - "math/rand" - "strconv" "testing" - "time" + "github.com/prometheus/client_golang/prometheus" + dto "github.com/prometheus/client_model/go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/xerrors" "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/postgres" "github.com/coder/coder/v2/coderd/database/pubsub" "github.com/coder/coder/v2/testutil" ) -// nolint:tparallel,paralleltest -func TestPubsub(t *testing.T) { +func TestPGPubsub_Metrics(t *testing.T) { t.Parallel() - - if testing.Short() { - t.SkipNow() - return + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") } - t.Run("Postgres", func(t *testing.T) { - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - - connectionURL, closePg, err := postgres.Open() - require.NoError(t, err) - defer closePg() - db, err := sql.Open("postgres", connectionURL) - require.NoError(t, err) - defer db.Close() - pubsub, err := pubsub.New(ctx, logger, db, connectionURL) - require.NoError(t, err) - defer pubsub.Close() - event := "test" - data := "testing" - messageChannel := make(chan []byte) - unsub, err := pubsub.Subscribe(event, func(ctx context.Context, message []byte) { - messageChannel <- message - }) - require.NoError(t, err) - defer unsub() - go func() { - err = pubsub.Publish(event, []byte(data)) - assert.NoError(t, err) - }() - message := <-messageChannel - assert.Equal(t, string(message), data) - }) - - t.Run("PostgresCloseCancel", func(t *testing.T) { - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - connectionURL, closePg, err := postgres.Open() - require.NoError(t, err) - defer closePg() - db, err := sql.Open("postgres", connectionURL) - require.NoError(t, err) - defer db.Close() - pubsub, err := pubsub.New(ctx, logger, db, connectionURL) - require.NoError(t, err) - defer pubsub.Close() - cancelFunc() - }) - - t.Run("NotClosedOnCancelContext", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - connectionURL, closePg, err := postgres.Open() - require.NoError(t, err) - defer closePg() - db, err := sql.Open("postgres", connectionURL) - require.NoError(t, err) - defer db.Close() - pubsub, err 
:= pubsub.New(ctx, logger, db, connectionURL) - require.NoError(t, err) - defer pubsub.Close() - - // Provided context must only be active during NewPubsub, not after. - cancel() - - event := "test" - data := "testing" - messageChannel := make(chan []byte) - unsub, err := pubsub.Subscribe(event, func(_ context.Context, message []byte) { - messageChannel <- message - }) - require.NoError(t, err) - defer unsub() - go func() { - err = pubsub.Publish(event, []byte(data)) - assert.NoError(t, err) - }() - message := <-messageChannel - assert.Equal(t, string(message), data) - }) - - t.Run("ClosePropagatesContextCancellationToSubscription", func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - connectionURL, closePg, err := postgres.Open() - require.NoError(t, err) - defer closePg() - db, err := sql.Open("postgres", connectionURL) - require.NoError(t, err) - defer db.Close() - pubsub, err := pubsub.New(ctx, logger, db, connectionURL) - require.NoError(t, err) - defer pubsub.Close() - - event := "test" - done := make(chan struct{}) - called := make(chan struct{}) - unsub, err := pubsub.Subscribe(event, func(subCtx context.Context, _ []byte) { - defer close(done) - select { - case <-subCtx.Done(): - assert.Fail(t, "context should not be canceled") - default: - } - close(called) - select { - case <-subCtx.Done(): - case <-ctx.Done(): - assert.Fail(t, "timeout waiting for sub context to be canceled") - } - }) - require.NoError(t, err) - defer unsub() - - go func() { - err := pubsub.Publish(event, nil) - assert.NoError(t, err) - }() - - select { - case <-called: - case <-ctx.Done(): - require.Fail(t, "timeout waiting for handler to be called") - } - err = pubsub.Close() - require.NoError(t, err) - - select { - case <-done: - case <-ctx.Done(): - require.Fail(t, "timeout waiting for handler to finish") - } - }) -} - -func TestPubsub_ordering(t *testing.T) { - t.Parallel() - - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - connectionURL, closePg, err := postgres.Open() require.NoError(t, err) defer closePg() db, err := sql.Open("postgres", connectionURL) require.NoError(t, err) defer db.Close() - ps, err := pubsub.New(ctx, logger, db, connectionURL) - require.NoError(t, err) - defer ps.Close() - event := "test" - messageChannel := make(chan []byte, 100) - cancelSub, err := ps.Subscribe(event, func(ctx context.Context, message []byte) { - // sleep a random amount of time to simulate handlers taking different amount of time - // to process, depending on the message - // nolint: gosec - n := rand.Intn(100) - time.Sleep(time.Duration(n) * time.Millisecond) - messageChannel <- message - }) - require.NoError(t, err) - defer cancelSub() - for i := 0; i < 100; i++ { - err = ps.Publish(event, []byte(fmt.Sprintf("%d", i))) - assert.NoError(t, err) - } - for i := 0; i < 100; i++ { - select { - case <-time.After(testutil.WaitShort): - t.Fatalf("timed out waiting for message %d", i) - case message := <-messageChannel: - assert.Equal(t, fmt.Sprintf("%d", i), string(message)) - } - } -} + registry := prometheus.NewRegistry() + ctx := testutil.Context(t, testutil.WaitLong) -// disconnectTestPort is the hardcoded port for TestPubsub_Disconnect. In this test we need to be able to stop Postgres -// and restart it on the same port. 
If we use an ephemeral port, there is a chance the OS will reallocate before we -// start back up. The downside is that if the test crashes and leaves the container up, subsequent test runs will fail -// until we manually kill the container. -const disconnectTestPort = 26892 - -// nolint: paralleltest -func TestPubsub_Disconnect(t *testing.T) { - // we always use a Docker container for this test, even in CI, since we need to be able to kill - // postgres and bring it back on the same port. - connectionURL, closePg, err := postgres.OpenContainerized(disconnectTestPort) + uut, err := pubsub.New(ctx, logger, db, connectionURL) require.NoError(t, err) - defer closePg() - db, err := sql.Open("postgres", connectionURL) - require.NoError(t, err) - defer db.Close() + defer uut.Close() - ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitSuperLong) - defer cancelFunc() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) - ps, err := pubsub.New(ctx, logger, db, connectionURL) + err = registry.Register(uut) require.NoError(t, err) - defer ps.Close() - event := "test" - // buffer responses so that when the test completes, goroutines don't get blocked & leak - errors := make(chan error, pubsub.BufferSize) - messages := make(chan string, pubsub.BufferSize) - readOne := func() (m string, e error) { - t.Helper() - select { - case <-ctx.Done(): - t.Fatal("timed out") - case m = <-messages: - // OK - } - select { - case <-ctx.Done(): - t.Fatal("timed out") - case e = <-errors: - // OK - } - return m, e - } + metrics, err := registry.Gather() + require.NoError(t, err) + requireGaugeValue(t, metrics, 0, "coder_pubsub_current_events") + requireGaugeValue(t, metrics, 0, "coder_pubsub_current_subscribers") - cancelSub, err := ps.SubscribeWithErr(event, func(ctx context.Context, msg []byte, err error) { - messages <- string(msg) - errors <- err + event := "test" + data := "testing" + messageChannel := make(chan []byte) + unsub0, err := uut.Subscribe(event, func(ctx context.Context, message []byte) { + messageChannel <- message }) require.NoError(t, err) - defer cancelSub() + defer unsub0() + go func() { + err = uut.Publish(event, []byte(data)) + assert.NoError(t, err) + }() + _ = testutil.RequireRecvCtx(ctx, t, messageChannel) - for i := 0; i < 100; i++ { - err = ps.Publish(event, []byte(fmt.Sprintf("%d", i))) - require.NoError(t, err) - } - // make sure we're getting at least one message. 
- m, err := readOne() + metrics, err = registry.Gather() require.NoError(t, err) - require.Equal(t, "0", m) - - closePg() - // write some more messages until we hit an error - j := 100 - for { - select { - case <-ctx.Done(): - t.Fatal("timed out") - default: - // ok - } - err = ps.Publish(event, []byte(fmt.Sprintf("%d", j))) - j++ - if err != nil { - break - } - time.Sleep(testutil.IntervalFast) + requireGaugeValue(t, metrics, 1, "coder_pubsub_current_events") + requireGaugeValue(t, metrics, 1, "coder_pubsub_current_subscribers") + requireGaugeValue(t, metrics, 1, "coder_pubsub_connected") + requireCounterValue(t, metrics, 1, "coder_pubsub_publishes_total", "true") + requireCounterValue(t, metrics, 1, "coder_pubsub_subscribes_total", "true") + requireCounterValue(t, metrics, 1, "coder_pubsub_messages_total", "normal") + requireCounterValue(t, metrics, 7, "coder_pubsub_received_bytes_total") + requireCounterValue(t, metrics, 7, "coder_pubsub_published_bytes_total") + + colossalData := make([]byte, 7600) + for i := range colossalData { + colossalData[i] = 'q' } + unsub1, err := uut.Subscribe(event, func(ctx context.Context, message []byte) { + messageChannel <- message + }) + require.NoError(t, err) + defer unsub1() + go func() { + err = uut.Publish(event, colossalData) + assert.NoError(t, err) + }() + // should get 2 messages because we have 2 subs + _ = testutil.RequireRecvCtx(ctx, t, messageChannel) + _ = testutil.RequireRecvCtx(ctx, t, messageChannel) - // restart postgres on the same port --- since we only use LISTEN/NOTIFY it doesn't - // matter that the new postgres doesn't have any persisted state from before. - _, closeNewPg, err := postgres.OpenContainerized(disconnectTestPort) + metrics, err = registry.Gather() require.NoError(t, err) - defer closeNewPg() + requireGaugeValue(t, metrics, 1, "coder_pubsub_current_events") + requireGaugeValue(t, metrics, 2, "coder_pubsub_current_subscribers") + requireGaugeValue(t, metrics, 1, "coder_pubsub_connected") + requireCounterValue(t, metrics, 2, "coder_pubsub_publishes_total", "true") + requireCounterValue(t, metrics, 2, "coder_pubsub_subscribes_total", "true") + requireCounterValue(t, metrics, 1, "coder_pubsub_messages_total", "normal") + requireCounterValue(t, metrics, 1, "coder_pubsub_messages_total", "colossal") + requireCounterValue(t, metrics, 7607, "coder_pubsub_received_bytes_total") + requireCounterValue(t, metrics, 7607, "coder_pubsub_published_bytes_total") +} - // now write messages until we DON'T hit an error -- pubsub is back up. - for { - select { - case <-ctx.Done(): - t.Fatal("timed out") - default: - // ok - } - err = ps.Publish(event, []byte(fmt.Sprintf("%d", j))) - if err == nil { - break +func requireGaugeValue(t testing.TB, metrics []*dto.MetricFamily, value float64, name string, label ...string) { + t.Helper() + for _, family := range metrics { + if family.GetName() != name { + continue } - j++ - time.Sleep(testutil.IntervalFast) - } - // any message k or higher comes from after the restart. - k := j - // exceeding the buffer invalidates the test because this causes us to drop messages for reasons other than DB - // reconnect - require.Less(t, k, pubsub.BufferSize, "exceeded buffer") - - // We don't know how quickly the pubsub will reconnect, so continue to send messages with increasing numbers. As - // soon as we see k or higher we know we're getting messages after the restart. 
- go func() { - for { - select { - case <-ctx.Done(): - return - default: - // ok + ms := family.GetMetric() + for _, m := range ms { + require.Equal(t, len(label), len(m.GetLabel())) + for i, lv := range label { + if lv != m.GetLabel()[i].GetValue() { + continue + } } - _ = ps.Publish(event, []byte(fmt.Sprintf("%d", j))) - j++ - time.Sleep(testutil.IntervalFast) + require.Equal(t, value, m.GetGauge().GetValue()) + return } - }() + } + t.Fatal("didn't find metric") +} - gotDroppedErr := false - for { - m, err := readOne() - if xerrors.Is(err, pubsub.ErrDroppedMessages) { - gotDroppedErr = true +func requireCounterValue(t testing.TB, metrics []*dto.MetricFamily, value float64, name string, label ...string) { + t.Helper() + for _, family := range metrics { + if family.GetName() != name { continue } - require.NoError(t, err, "should only get ErrDroppedMessages") - l, err := strconv.Atoi(m) - require.NoError(t, err) - if l >= k { - // exceeding the buffer invalidates the test because this causes us to drop messages for reasons other than - // DB reconnect - require.Less(t, l, pubsub.BufferSize, "exceeded buffer") - break + ms := family.GetMetric() + for _, m := range ms { + require.Equal(t, len(label), len(m.GetLabel())) + for i, lv := range label { + if lv != m.GetLabel()[i].GetValue() { + continue + } + } + require.Equal(t, value, m.GetCounter().GetValue()) + return } } - require.True(t, gotDroppedErr) + t.Fatal("didn't find metric") }
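
Note on the metrics pattern used above: the patch mixes two styles. Counters and gauges such as publishes_total, published_bytes_total, and connected are updated inline as events happen, while coder_pubsub_current_subscribers and coder_pubsub_current_events are "implicit" gauges computed from the existing p.queues map at scrape time via the prometheus.Collector interface (Describe/Collect plus MustNewConstMetric). The following is a minimal, self-contained sketch of that same pattern, not code from this patch; the names exampleCollector, example_publishes_total, and example_current_subscribers are illustrative only.

package main

import (
	"fmt"
	"sync"

	"github.com/prometheus/client_golang/prometheus"
)

// currentSubscribersDesc describes a gauge that is never stored anywhere; its
// value is derived from in-memory state each time the registry is scraped.
var currentSubscribersDesc = prometheus.NewDesc(
	"example_current_subscribers",
	"The current number of subscribers, computed at collection time",
	nil, nil,
)

// exampleCollector holds one explicit metric (a counter vector updated inline)
// and the data structure from which the implicit gauge is derived.
type exampleCollector struct {
	mu     sync.Mutex
	queues map[string]int // event name -> subscriber count, standing in for p.queues

	publishesTotal *prometheus.CounterVec
}

func newExampleCollector() *exampleCollector {
	return &exampleCollector{
		queues: map[string]int{},
		publishesTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
			Name: "example_publishes_total",
			Help: "Total number of publishes, labeled by success",
		}, []string{"success"}),
	}
}

// Describe forwards the descriptors of the explicit metric and sends the
// descriptor of the implicit gauge, satisfying half of prometheus.Collector.
func (c *exampleCollector) Describe(descs chan<- *prometheus.Desc) {
	c.publishesTotal.Describe(descs)
	descs <- currentSubscribersDesc
}

// Collect emits the explicit metric as-is, then computes the implicit gauge
// under the mutex and emits it as a const metric.
func (c *exampleCollector) Collect(metrics chan<- prometheus.Metric) {
	c.publishesTotal.Collect(metrics)

	c.mu.Lock()
	subs := 0
	for _, n := range c.queues {
		subs += n
	}
	c.mu.Unlock()
	metrics <- prometheus.MustNewConstMetric(currentSubscribersDesc, prometheus.GaugeValue, float64(subs))
}

func main() {
	c := newExampleCollector()
	c.queues["workspace_updates"] = 3
	c.publishesTotal.WithLabelValues("true").Inc()

	// Registering the collector is analogous to
	// options.PrometheusRegistry.MustRegister(ps) in cli/server.go above.
	reg := prometheus.NewRegistry()
	reg.MustRegister(c)

	families, err := reg.Gather()
	if err != nil {
		panic(err)
	}
	for _, mf := range families {
		fmt.Println(mf)
	}
}

Deriving the subscriber and event counts at scrape time avoids keeping separate gauges in lockstep with every Subscribe/unsubscribe path, at the cost of briefly holding the pubsub mutex during each collection, which is the trade-off PGPubsub.Collect makes above.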