diff --git a/cli/server.go b/cli/server.go
index 520f58454db29..e0d17838cde25 100644
--- a/cli/server.go
+++ b/cli/server.go
@@ -655,6 +655,10 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
 				options.OIDCConfig = oc
 			}
 
+			// We'll read from this channel in the select below that tracks shutdown. If it remains
+			// nil, that case of the select will just never fire, but it's important not to have a
+			// "bare" read on this channel.
+			var pubsubWatchdogTimeout <-chan struct{}
 			if vals.InMemoryDatabase {
 				// This is only used for testing.
 				options.Database = dbmem.New()
@@ -683,6 +687,9 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
 					options.PrometheusRegistry.MustRegister(ps)
 				}
 				defer options.Pubsub.Close()
+				psWatchdog := pubsub.NewWatchdog(ctx, logger.Named("pswatch"), ps)
+				pubsubWatchdogTimeout = psWatchdog.Timeout()
+				defer psWatchdog.Close()
 			}
 
 			if options.DeploymentValues.Prometheus.Enable && options.DeploymentValues.Prometheus.CollectDBMetrics {
@@ -1031,6 +1038,8 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
 				_, _ = io.WriteString(inv.Stdout, cliui.Bold("Interrupt caught, gracefully exiting. Use ctrl+\\ to force quit"))
 			case <-tunnelDone:
 				exitErr = xerrors.New("dev tunnel closed unexpectedly")
+			case <-pubsubWatchdogTimeout:
+				exitErr = xerrors.New("pubsub Watchdog timed out")
 			case exitErr = <-errCh:
 			}
 			if exitErr != nil && !xerrors.Is(exitErr, context.Canceled) {
diff --git a/coderd/database/pubsub/watchdog.go b/coderd/database/pubsub/watchdog.go
new file mode 100644
index 0000000000000..2b8d929fa7e38
--- /dev/null
+++ b/coderd/database/pubsub/watchdog.go
@@ -0,0 +1,163 @@
+package pubsub
+
+import (
+	"context"
+	"runtime/pprof"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/benbjohnson/clock"
+
+	"cdr.dev/slog"
+)
+
+const (
+	// EventPubsubWatchdog is the pubsub event on which the watchdog publishes and listens for
+	// heartbeats.
+	EventPubsubWatchdog = "pubsub_watchdog"
+	// periodHeartbeat is the interval between heartbeats published by publishLoop.
+	periodHeartbeat = 15 * time.Second
+	// periodTimeout is the time without receiving a heartbeat (from any publisher) before we
+	// consider the watchdog to have timed out. There is a tradeoff here between avoiding
+	// disruption due to a short-lived issue connecting to the postgres database, and restarting
+	// before the consequences of a non-working pubsub are noticed by end users (e.g. being unable
+	// to connect to their workspaces).
+	periodTimeout = 5 * time.Minute
+)
+
+// Watchdog periodically publishes a heartbeat on a Pubsub and monitors a subscription to the same
+// event. If no heartbeat arrives within periodTimeout, the channel returned by Timeout is closed.
+type Watchdog struct {
+	ctx     context.Context
+	cancel  context.CancelFunc
+	logger  slog.Logger
+	ps      Pubsub
+	wg      sync.WaitGroup
+	timeout chan struct{}
+
+	// for testing
+	clock clock.Clock
+}
+
+// NewWatchdog starts a Watchdog on ps using the clock returned by clock.New(). Close stops it.
+func NewWatchdog(ctx context.Context, logger slog.Logger, ps Pubsub) *Watchdog {
+	return NewWatchdogWithClock(ctx, logger, ps, clock.New())
+}
+
+// NewWatchdogWithClock returns a watchdog with the given clock. Product code should always call
+// NewWatchdog.
+func NewWatchdogWithClock(ctx context.Context, logger slog.Logger, ps Pubsub, c clock.Clock) *Watchdog {
+	ctx, cancel := context.WithCancel(ctx)
+	w := &Watchdog{
+		ctx:     ctx,
+		cancel:  cancel,
+		logger:  logger,
+		ps:      ps,
+		timeout: make(chan struct{}),
+		clock:   c,
+	}
+	// One count per background goroutine; each calls wg.Done on exit so Close can wait for both.
+	w.wg.Add(2)
+	go w.publishLoop()
+	go w.subscribeMonitor()
+	return w
+}
+
+// Close cancels the watchdog's context and blocks until both background goroutines have exited.
+// It always returns nil.
+func (w *Watchdog) Close() error {
+	w.cancel()
+	w.wg.Wait()
+	return nil
+}
+
+// Timeout returns a channel that is closed if the watchdog times out. Note that the Timeout() chan
+// will NOT be closed if the Watchdog is Close'd or its context expires, so it is important to read
+// from the Timeout() chan in a select e.g.
+//
+//	w := NewWatchdog(ctx, logger, ps)
+//	select {
+//	case <-ctx.Done():
+//	case <-w.Timeout():
+//		FreakOut()
+//	}
+func (w *Watchdog) Timeout() <-chan struct{} {
+	return w.timeout
+}
+
+// publishLoop publishes a heartbeat immediately and then once per periodHeartbeat until the
+// watchdog's context is done. Publish errors are logged and do not stop the loop.
+func (w *Watchdog) publishLoop() {
+	defer w.wg.Done()
+	tkr := w.clock.Ticker(periodHeartbeat)
+	defer tkr.Stop()
+	// immediate publish after starting the ticker. This helps testing so that we can tell from
+	// the outside that the ticker is started.
+	err := w.ps.Publish(EventPubsubWatchdog, []byte{})
+	if err != nil {
+		w.logger.Warn(w.ctx, "failed to publish heartbeat on pubsub watchdog", slog.Error(err))
+	}
+	for {
+		select {
+		case <-w.ctx.Done():
+			w.logger.Debug(w.ctx, "context done; exiting publishLoop")
+			return
+		case <-tkr.C:
+			err := w.ps.Publish(EventPubsubWatchdog, []byte{})
+			if err != nil {
+				w.logger.Warn(w.ctx, "failed to publish heartbeat on pubsub watchdog", slog.Error(err))
+			}
+		}
+	}
+}
+
+// subscribeMonitor subscribes to heartbeats and resets the timeout timer each time one arrives.
+// It closes the Timeout() channel if the timer expires or if subscribing fails.
+func (w *Watchdog) subscribeMonitor() {
+	defer w.wg.Done()
+	// Start the timer before subscribing, so that anyone who observes the subscription knows the
+	// timeout is already running (mirrors publishLoop, which starts its ticker before the first
+	// publish). Otherwise a mock clock can be advanced before the timer exists.
+	tmr := w.clock.Timer(periodTimeout)
+	defer tmr.Stop()
+	beats := make(chan struct{})
+	unsub, err := w.ps.Subscribe(EventPubsubWatchdog, func(context.Context, []byte) {
+		w.logger.Debug(w.ctx, "got heartbeat for pubsub watchdog")
+		select {
+		case <-w.ctx.Done():
+		case beats <- struct{}{}:
+		}
+	})
+	if err != nil {
+		w.logger.Critical(w.ctx, "watchdog failed to subscribe", slog.Error(err))
+		close(w.timeout)
+		return
+	}
+	defer unsub()
+	for {
+		select {
+		case <-w.ctx.Done():
+			w.logger.Debug(w.ctx, "context done; exiting subscribeMonitor")
+			return
+		case <-beats:
+			// c.f. https://pkg.go.dev/time#Timer.Reset
+			// Drain without blocking: with Go 1.23+ timer semantics a receive after Stop blocks
+			// forever, while older versions may have left a stale value in the channel.
+			if !tmr.Stop() {
+				select {
+				case <-tmr.C:
+				default:
+				}
+			}
+			tmr.Reset(periodTimeout)
+		case <-tmr.C:
+			// Include a dump of all goroutines to help diagnose why heartbeats stopped.
+			buf := new(strings.Builder)
+			_ = pprof.Lookup("goroutine").WriteTo(buf, 1)
+			w.logger.Critical(w.ctx, "pubsub watchdog timeout", slog.F("goroutines", buf.String()))
+			close(w.timeout)
+			return
+		}
+	}
+}
diff --git a/coderd/database/pubsub/watchdog_test.go b/coderd/database/pubsub/watchdog_test.go
new file mode 100644
index 0000000000000..ddd5a864e2c66
--- /dev/null
+++ b/coderd/database/pubsub/watchdog_test.go
@@ -0,0 +1,131 @@
+package pubsub_test
+
+import (
+	"testing"
+	"time"
+
+	"github.com/benbjohnson/clock"
+	"github.com/stretchr/testify/require"
+
+	"cdr.dev/slog"
+	"cdr.dev/slog/sloggers/slogtest"
+	"github.com/coder/coder/v2/coderd/database/pubsub"
+	"github.com/coder/coder/v2/testutil"
+)
+
+// TestWatchdog_NoTimeout checks that the watchdog does not trip while heartbeats keep being
+// delivered, even once more than 5 minutes of mock time have elapsed.
+func TestWatchdog_NoTimeout(t *testing.T) {
+	t.Parallel()
+	ctx := testutil.Context(t, testutil.WaitShort)
+	mClock := clock.NewMock()
+	start := time.Date(2024, 2, 5, 8, 7, 6, 5, time.UTC)
+	mClock.Set(start)
+	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
+	fPS := newFakePubsub()
+	uut := pubsub.NewWatchdogWithClock(ctx, logger, fPS, mClock)
+
+	sub := testutil.RequireRecvCtx(ctx, t, fPS.subs)
+	require.Equal(t, pubsub.EventPubsubWatchdog, sub.event)
+	p := testutil.RequireRecvCtx(ctx, t, fPS.pubs)
+	require.Equal(t, pubsub.EventPubsubWatchdog, p)
+
+	// 5 min / 15 sec = 20, so do 21 ticks
+	for i := 0; i < 21; i++ {
+		mClock.Add(15 * time.Second)
+		p = testutil.RequireRecvCtx(ctx, t, fPS.pubs)
+		require.Equal(t, pubsub.EventPubsubWatchdog, p)
+		mClock.Add(30 * time.Millisecond) // reasonable round-trip
+		// forward the beat
+		sub.listener(ctx, []byte{})
+		// we shouldn't time out
+		select {
+		case <-uut.Timeout():
+			t.Fatal("watchdog tripped")
+		default:
+			// OK!
+		}
+	}
+
+	err := uut.Close()
+	require.NoError(t, err)
+}
+
+// TestWatchdog_Timeout checks that the watchdog trips once 5 minutes of mock time pass without
+// any heartbeat being delivered to the subscriber.
+func TestWatchdog_Timeout(t *testing.T) {
+	t.Parallel()
+	ctx := testutil.Context(t, testutil.WaitShort)
+	mClock := clock.NewMock()
+	start := time.Date(2024, 2, 5, 8, 7, 6, 5, time.UTC)
+	mClock.Set(start)
+	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
+	fPS := newFakePubsub()
+	uut := pubsub.NewWatchdogWithClock(ctx, logger, fPS, mClock)
+
+	sub := testutil.RequireRecvCtx(ctx, t, fPS.subs)
+	require.Equal(t, pubsub.EventPubsubWatchdog, sub.event)
+	p := testutil.RequireRecvCtx(ctx, t, fPS.pubs)
+	require.Equal(t, pubsub.EventPubsubWatchdog, p)
+
+	// 5 min / 15 sec = 20, so do 19 ticks without timing out
+	for i := 0; i < 19; i++ {
+		mClock.Add(15 * time.Second)
+		p = testutil.RequireRecvCtx(ctx, t, fPS.pubs)
+		require.Equal(t, pubsub.EventPubsubWatchdog, p)
+		mClock.Add(30 * time.Millisecond) // reasonable round-trip
+		// we DO NOT forward the heartbeat
+		// we shouldn't time out
+		select {
+		case <-uut.Timeout():
+			t.Fatal("watchdog tripped")
+		default:
+			// OK!
+		}
+	}
+	mClock.Add(15 * time.Second)
+	p = testutil.RequireRecvCtx(ctx, t, fPS.pubs)
+	require.Equal(t, pubsub.EventPubsubWatchdog, p)
+	testutil.RequireRecvCtx(ctx, t, uut.Timeout())
+
+	err := uut.Close()
+	require.NoError(t, err)
+}
+
+// subscribe records the arguments of a call to fakePubsub.Subscribe.
+type subscribe struct {
+	event    string
+	listener pubsub.Listener
+}
+
+// fakePubsub is a test double whose Subscribe and Publish calls block until the test receives
+// them from the subs and pubs channels.
+type fakePubsub struct {
+	pubs chan string
+	subs chan subscribe
+}
+
+func (f *fakePubsub) Subscribe(event string, listener pubsub.Listener) (func(), error) {
+	f.subs <- subscribe{event, listener}
+	return func() {}, nil
+}
+
+func (*fakePubsub) SubscribeWithErr(string, pubsub.ListenerWithErr) (func(), error) {
+	panic("should not be called")
+}
+
+func (*fakePubsub) Close() error {
+	panic("should not be called")
+}
+
+func (f *fakePubsub) Publish(event string, _ []byte) error {
+	f.pubs <- event
+	return nil
+}
+
+func newFakePubsub() *fakePubsub {
+	return &fakePubsub{
+		pubs: make(chan string),
+		subs: make(chan subscribe),
+	}
+}