diff --git a/clock/clock.go b/clock/clock.go
new file mode 100644
index 0000000000000..44fdeb5716463
--- /dev/null
+++ b/clock/clock.go
@@ -0,0 +1,28 @@
+// Package clock is a library for testing time related code. It exports an interface Clock that
+// mimics the standard library time package functions. In production, an implementation that calls
+// through to the standard library is used. In testing, a Mock clock is used to precisely control and
+// intercept time functions.
+package clock
+
+import (
+	"context"
+	"time"
+)
+
+// Clock is the set of time functions that code under test calls instead of the time package. The
+// optional tags on each method are matched against Traps when the implementation is a Mock.
+type Clock interface {
+	// TickerFunc is a convenience function that calls f on the interval d until either the given
+	// context expires or f returns an error. Callers may call Wait() on the returned Waiter to
+	// wait until this happens and obtain the error.
+	TickerFunc(ctx context.Context, d time.Duration, f func() error, tags ...string) Waiter
+	// NewTimer creates a new Timer that will send the current time on its channel after at least
+	// duration d.
+	NewTimer(d time.Duration, tags ...string) *Timer
+}
+
+// Waiter can be waited on for an error.
+type Waiter interface {
+	// Wait blocks until the underlying operation has finished and returns its error.
+	Wait(tags ...string) error
+}
diff --git a/clock/mock.go b/clock/mock.go
new file mode 100644
index 0000000000000..b119c53ccf9d6
--- /dev/null
+++ b/clock/mock.go
@@ -0,0 +1,471 @@
+package clock
+
+import (
+	"context"
+	"errors"
+	"slices"
+	"sync"
+	"time"
+)
+
+// Mock is the testing implementation of Clock. It tracks a time that monotonically increases
+// during a test, triggering any timers or tickers automatically.
+type Mock struct {
+	mu sync.Mutex
+
+	// cur is the current time
+	cur time.Time
+	// advancing is true when we are in the process of advancing the clock. We don't support
+	// multiple goroutines doing this at once.
+	advancing bool
+
+	// all holds every scheduled timer and ticker.
+	all []event
+	// nextTime and nextEvents cache the earliest scheduled time and the events due at that time.
+	// Both are rebuilt by recomputeNextLocked; nextTime is zero when nothing is scheduled.
+	nextTime   time.Time
+	nextEvents []event
+	// traps are the active Traps that calls are matched against.
+	traps []*Trap
+}
+
+// event is something scheduled on the Mock: next reports when it is due and fire runs it.
+type event interface {
+	next() time.Time
+	fire(t time.Time)
+}
+
+// TickerFunc registers a ticker that calls f every d of mocked time until ctx is done or f returns
+// an error. It panics if d is not positive, as time.NewTicker does; with a zero interval the
+// ticker would never move past the current time and Advance would loop forever.
+func (m *Mock) TickerFunc(ctx context.Context, d time.Duration, f func() error, tags ...string) Waiter {
+	if d <= 0 {
+		panic("TickerFunc called with negative or zero duration")
+	}
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	c := newCall(clockFunctionTickerFunc, tags, withDuration(d))
+	m.matchCallLocked(c)
+	defer close(c.complete)
+	t := &mockTickerFunc{
+		ctx:  ctx,
+		d:    d,
+		f:    f,
+		nxt:  m.cur.Add(d),
+		mock: m,
+		cond: sync.NewCond(&m.mu),
+	}
+	m.all = append(m.all, t)
+	m.recomputeNextLocked()
+	go t.waitForCtx()
+	return t
+}
+
+// NewTimer creates a Timer that fires once the mocked time has advanced by d. It panics if d is
+// negative.
+func (m *Mock) NewTimer(d time.Duration, tags ...string) *Timer {
+	if d < 0 {
+		panic("duration must be positive or zero")
+	}
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	c := newCall(clockFunctionNewTimer, tags, withDuration(d))
+	defer close(c.complete)
+	m.matchCallLocked(c)
+	ch := make(chan time.Time, 1)
+	t := &Timer{
+		C:    ch,
+		c:    ch,
+		nxt:  m.cur.Add(d),
+		mock: m,
+	}
+	m.addTimerLocked(t)
+	return t
+}
+
+func (m *Mock) addTimerLocked(t *Timer) {
+	m.all = append(m.all, t)
+	m.recomputeNextLocked()
+}
+
+// recomputeNextLocked rebuilds nextTime and nextEvents from all.
+func (m *Mock) recomputeNextLocked() {
+	var best time.Time
+	var events []event
+	for _, e := range m.all {
+		if best.IsZero() || e.next().Before(best) {
+			best = e.next()
+			events = []event{e}
+			continue
+		}
+		if e.next().Equal(best) {
+			events = append(events, e)
+			continue
+		}
+	}
+	m.nextTime = best
+	m.nextEvents = events
+}
+
+func (m *Mock) removeTimer(t *Timer) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	m.removeTimerLocked(t)
+}
+
+func (m *Mock) removeTimerLocked(t *Timer) {
+	defer m.recomputeNextLocked()
+	t.stopped = true
+	var e event = t
+	for i := range m.all {
+		if m.all[i] == e {
+			m.all = append(m.all[:i], m.all[i+1:]...)
+			return
+		}
+	}
+}
+
+func (m *Mock) removeTickerFuncLocked(ct *mockTickerFunc) {
+	defer m.recomputeNextLocked()
+	var e event = ct
+	for i := range m.all {
+		if m.all[i] == e {
+			m.all = append(m.all[:i], m.all[i+1:]...)
+			return
+		}
+	}
+}
+
+// matchCallLocked hands c to every Trap that matches it and blocks until each has released it.
+// m.mu is dropped while waiting so that the test can keep using the Mock in the meantime.
+func (m *Mock) matchCallLocked(c *Call) {
+	var traps []*Trap
+	for _, t := range m.traps {
+		if t.matches(c) {
+			traps = append(traps, t)
+		}
+	}
+	if len(traps) == 0 {
+		return
+	}
+	c.releases.Add(len(traps))
+	m.mu.Unlock()
+	for _, t := range traps {
+		go t.catch(c)
+	}
+	c.releases.Wait()
+	m.mu.Lock()
+}
+
+// Advance moves the clock forward by d, triggering any timers or tickers. Advance will wait for
+// tick functions of tickers created using TickerFunc to complete before returning from
+// Advance. If multiple timers or tickers trigger simultaneously, they are all run on separate go
+// routines.
+func (m *Mock) Advance(d time.Duration) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	m.advanceLocked(d)
+}
+
+func (m *Mock) advanceLocked(d time.Duration) {
+	if m.advancing {
+		panic("multiple simultaneous calls to Advance not supported")
+	}
+	m.advancing = true
+	defer func() {
+		m.advancing = false
+	}()
+
+	fin := m.cur.Add(d)
+	for {
+		// nextTime.IsZero implies no events scheduled
+		if m.nextTime.IsZero() || m.nextTime.After(fin) {
+			m.cur = fin
+			return
+		}
+
+		if m.nextTime.After(m.cur) {
+			m.cur = m.nextTime
+		}
+
+		wg := sync.WaitGroup{}
+		for i := range m.nextEvents {
+			e := m.nextEvents[i]
+			t := m.cur
+			wg.Add(1)
+			go func() {
+				e.fire(t)
+				wg.Done()
+			}()
+		}
+		// release the lock and let the events resolve. This allows them to call back into the
+		// Mock to query the time or set new timers. Each event should remove or reschedule
+		// itself from nextEvents.
+		m.mu.Unlock()
+		wg.Wait()
+		m.mu.Lock()
+	}
+}
+
+// Set the time to t. If the time is after the current mocked time, then this is equivalent to
+// Advance() with the difference. You may only Set the time earlier than the current time before
+// starting tickers and timers (e.g. at the start of your test case).
+func (m *Mock) Set(t time.Time) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	if t.Before(m.cur) {
+		// past
+		if !m.nextTime.IsZero() {
+			panic("Set mock clock to the past after timers/tickers started")
+		}
+		m.cur = t
+		return
+	}
+	// future, just advance as normal.
+	m.advanceLocked(t.Sub(m.cur))
+}
+
+// Trapper allows the creation of Traps
+type Trapper struct {
+	// mock is the underlying Mock. This is a thin wrapper around Mock so that
+	// we can have our interface look like mClock.Trap().NewTimer("foo")
+	mock *Mock
+}
+
+func (t Trapper) NewTimer(tags ...string) *Trap {
+	return t.mock.newTrap(clockFunctionNewTimer, tags)
+}
+
+func (t Trapper) TimerStop(tags ...string) *Trap {
+	return t.mock.newTrap(clockFunctionTimerStop, tags)
+}
+
+func (t Trapper) TimerReset(tags ...string) *Trap {
+	return t.mock.newTrap(clockFunctionTimerReset, tags)
+}
+
+func (t Trapper) TickerFunc(tags ...string) *Trap {
+	return t.mock.newTrap(clockFunctionTickerFunc, tags)
+}
+
+func (t Trapper) TickerFuncWait(tags ...string) *Trap {
+	return t.mock.newTrap(clockFunctionTickerFuncWait, tags)
+}
+
+// Trap returns a Trapper for creating Traps on this Mock.
+func (m *Mock) Trap() Trapper {
+	return Trapper{m}
+}
+
+func (m *Mock) newTrap(fn clockFunction, tags []string) *Trap {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	tr := &Trap{
+		fn:    fn,
+		tags:  tags,
+		mock:  m,
+		calls: make(chan *Call),
+		done:  make(chan struct{}),
+	}
+	m.traps = append(m.traps, tr)
+	return tr
+}
+
+// NewMock creates a new Mock with the time set to midnight UTC on Jan 1, 2024.
+// You may re-set the time earlier than this, but only before timers or tickers
+// are created.
+func NewMock() *Mock {
+	cur, err := time.Parse(time.RFC3339, "2024-01-01T00:00:00Z")
+	if err != nil {
+		panic(err)
+	}
+	return &Mock{
+		cur: cur,
+	}
+}
+
+var _ Clock = &Mock{}
+
+type mockTickerFunc struct {
+	ctx  context.Context
+	d    time.Duration
+	f    func() error
+	nxt  time.Time
+	mock *Mock
+
+	// cond is a condition Locked on the main Mock.mu
+	cond *sync.Cond
+	// done is true when the ticker exits
+	done bool
+	// err holds the error when the ticker exits
+	err error
+}
+
+func (m *mockTickerFunc) next() time.Time {
+	return m.nxt
+}
+
+func (m *mockTickerFunc) fire(t time.Time) {
+	m.mock.mu.Lock()
+	defer m.mock.mu.Unlock()
+	if m.done {
+		return
+	}
+	if !m.nxt.Equal(t) {
+		panic("mockTickerFunc fired at wrong time")
+	}
+	m.nxt = m.nxt.Add(m.d)
+	m.mock.recomputeNextLocked()
+
+	// call f without holding the lock so that it can call back into the Mock
+	m.mock.mu.Unlock()
+	err := m.f()
+	m.mock.mu.Lock()
+	if err != nil {
+		m.exitLocked(err)
+	}
+}
+
+func (m *mockTickerFunc) exitLocked(err error) {
+	if m.done {
+		return
+	}
+	m.done = true
+	m.err = err
+	m.mock.removeTickerFuncLocked(m)
+	m.cond.Broadcast()
+}
+
+func (m *mockTickerFunc) waitForCtx() {
+	<-m.ctx.Done()
+	m.mock.mu.Lock()
+	defer m.mock.mu.Unlock()
+	m.exitLocked(m.ctx.Err())
+}
+
+func (m *mockTickerFunc) Wait(tags ...string) error {
+	m.mock.mu.Lock()
+	defer m.mock.mu.Unlock()
+	c := newCall(clockFunctionTickerFuncWait, tags)
+	m.mock.matchCallLocked(c)
+	defer close(c.complete)
+	for !m.done {
+		m.cond.Wait()
+	}
+	return m.err
+}
+
+var _ Waiter = &mockTickerFunc{}
+
+type clockFunction int
+
+const (
+	clockFunctionNewTimer clockFunction = iota
+	clockFunctionTimerStop
+	clockFunctionTimerReset
+	clockFunctionTickerFunc
+	clockFunctionTickerFuncWait
+)
+
+type callArg func(c *Call)
+
+// Call represents a call to the Mock that matched a Trap.
+type Call struct {
+	Time     time.Time
+	Duration time.Duration
+	Tags     []string
+
+	fn       clockFunction
+	releases sync.WaitGroup
+	complete chan struct{}
+}
+
+// Release signals that the Trap is done with the call and blocks until the call has completed.
+func (c *Call) Release() {
+	c.releases.Done()
+	<-c.complete
+}
+
+// nolint: unused // it will be soon
+func withTime(t time.Time) callArg {
+	return func(c *Call) {
+		c.Time = t
+	}
+}
+
+func withDuration(d time.Duration) callArg {
+	return func(c *Call) {
+		c.Duration = d
+	}
+}
+
+func newCall(fn clockFunction, tags []string, args ...callArg) *Call {
+	c := &Call{
+		fn:       fn,
+		Tags:     tags,
+		complete: make(chan struct{}),
+	}
+	for _, a := range args {
+		a(c)
+	}
+	return c
+}
+
+// Trap intercepts calls to one clock function, restricted to calls that carry all of the Trap's
+// tags.
+type Trap struct {
+	fn    clockFunction
+	tags  []string
+	mock  *Mock
+	calls chan *Call
+	done  chan struct{}
+}
+
+func (t *Trap) catch(c *Call) {
+	select {
+	case t.calls <- c:
+	case <-t.done:
+		c.Release()
+	}
+}
+
+func (t *Trap) matches(c *Call) bool {
+	if t.fn != c.fn {
+		return false
+	}
+	for _, tag := range t.tags {
+		if !slices.Contains(c.Tags, tag) {
+			return false
+		}
+	}
+	return true
+}
+
+// Close removes the Trap from the Mock and releases any calls still waiting on it.
+func (t *Trap) Close() {
+	t.mock.mu.Lock()
+	defer t.mock.mu.Unlock()
+	for i, tr := range t.mock.traps {
+		if t == tr {
+			t.mock.traps = append(t.mock.traps[:i], t.mock.traps[i+1:]...)
+			// the slice was just modified in place; stop ranging over the stale copy
+			break
+		}
+	}
+	close(t.done)
+}
+
+// ErrTrapClosed is returned by Trap.Wait when the Trap is closed.
+var ErrTrapClosed = errors.New("trap closed")
+
+// Wait blocks until a matching call is trapped, the Trap is closed, or ctx is done.
+func (t *Trap) Wait(ctx context.Context) (*Call, error) {
+	select {
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	case <-t.done:
+		return nil, ErrTrapClosed
+	case c := <-t.calls:
+		return c, nil
+	}
+}
diff --git a/clock/real.go b/clock/real.go
new file mode 100644
index 0000000000000..d632cb4943c4d
--- /dev/null
+++ b/clock/real.go
@@ -0,0 +1,67 @@
+package clock
+
+import (
+	"context"
+	"time"
+)
+
+type realClock struct{}
+
+// NewReal returns a Clock that calls through to the standard library time package.
+func NewReal() Clock {
+	return realClock{}
+}
+
+func (realClock) TickerFunc(ctx context.Context, d time.Duration, f func() error, _ ...string) Waiter {
+	ct := &realContextTicker{
+		ctx:  ctx,
+		tkr:  time.NewTicker(d),
+		f:    f,
+		done: make(chan struct{}),
+	}
+	go ct.run()
+	return ct
+}
+
+type realContextTicker struct {
+	ctx context.Context
+	tkr *time.Ticker
+	f   func() error
+	// err is the reason run exited. It is written once, before done is closed, and only read
+	// after done is closed.
+	err error
+	// done is closed when run exits.
+	done chan struct{}
+}
+
+// Wait blocks until the ticker has stopped and returns the reason. It may be called any number
+// of times, like the Waiter returned by the Mock.
+func (t *realContextTicker) Wait(_ ...string) error {
+	<-t.done
+	return t.err
+}
+
+func (t *realContextTicker) run() {
+	defer close(t.done)
+	defer t.tkr.Stop()
+	for {
+		select {
+		case <-t.ctx.Done():
+			t.err = t.ctx.Err()
+			return
+		case <-t.tkr.C:
+			err := t.f()
+			if err != nil {
+				t.err = err
+				return
+			}
+		}
+	}
+}
+
+func (realClock) NewTimer(d time.Duration, _ ...string) *Timer {
+	rt := time.NewTimer(d)
+	return &Timer{C: rt.C, timer: rt}
+}
+
+var _ Clock = realClock{}
diff --git a/clock/timer.go b/clock/timer.go
new file mode 100644
index 0000000000000..bf31ab18a6764
--- /dev/null
+++ b/clock/timer.go
@@ -0,0 +1,74 @@
+package clock
+
+import "time"
+
+// Timer mirrors time.Timer. It is backed by a real time.Timer when created by the real clock and
+// by a Mock when created by Mock.NewTimer.
+type Timer struct {
+	C <-chan time.Time
+	//nolint: revive
+	c       chan time.Time
+	timer   *time.Timer // realtime impl, if set
+	nxt     time.Time   // next tick time
+	mock    *Mock       // mock clock, if set
+	fn      func()      // AfterFunc function, if set
+	stopped bool        // True if stopped, false if running
+}
+
+func (t *Timer) fire(tt time.Time) {
+	if !tt.Equal(t.nxt) {
+		panic("mock timer fired at wrong time")
+	}
+	t.mock.removeTimer(t)
+	t.c <- tt
+	if t.fn != nil {
+		t.fn()
+	}
+}
+
+func (t *Timer) next() time.Time {
+	return t.nxt
+}
+
+// Stop prevents the Timer from firing. It returns true if the call stops the timer and false if
+// the timer has already fired or been stopped.
+func (t *Timer) Stop(tags ...string) bool {
+	if t.timer != nil {
+		return t.timer.Stop()
+	}
+	t.mock.mu.Lock()
+	defer t.mock.mu.Unlock()
+	c := newCall(clockFunctionTimerStop, tags)
+	t.mock.matchCallLocked(c)
+	defer close(c.complete)
+	result := !t.stopped
+	t.mock.removeTimerLocked(t)
+	return result
+}
+
+// Reset changes the Timer to fire after duration d. It returns true if the timer had been active
+// and false if it had already fired or been stopped. On a mock timer it panics if d is negative
+// and discards any tick that has not been received yet.
+func (t *Timer) Reset(d time.Duration, tags ...string) bool {
+	if t.timer != nil {
+		return t.timer.Reset(d)
+	}
+	if d < 0 {
+		panic("duration must be positive or zero")
+	}
+	t.mock.mu.Lock()
+	defer t.mock.mu.Unlock()
+	c := newCall(clockFunctionTimerReset, tags, withDuration(d))
+	t.mock.matchCallLocked(c)
+	defer close(c.complete)
+	result := !t.stopped
+	t.mock.removeTimerLocked(t)
+	t.stopped = false
+	t.nxt = t.mock.cur.Add(d)
+	select {
+	case <-t.c:
+	default:
+	}
+	t.mock.addTimerLocked(t)
+	return result
+}
diff --git a/coderd/database/pubsub/watchdog.go b/coderd/database/pubsub/watchdog.go
index 687129fc5bcc2..df54019bb49b2 100644
--- a/coderd/database/pubsub/watchdog.go
+++ b/coderd/database/pubsub/watchdog.go
@@ -7,9 +7,8 @@ import (
 	"sync"
 	"time"
 
-	"github.com/benbjohnson/clock"
-
 	"cdr.dev/slog"
+	"github.com/coder/coder/v2/clock"
 )
 
 const (
@@ -36,7 +35,7 @@ type Watchdog struct {
 }
 
 func NewWatchdog(ctx context.Context, logger slog.Logger, ps Pubsub) *Watchdog {
-	return NewWatchdogWithClock(ctx, logger, ps, clock.New())
+	return NewWatchdogWithClock(ctx, logger, ps, clock.NewReal())
 }
 
 // NewWatchdogWithClock returns a watchdog with the given clock. Product code should always call NewWatchDog.
@@ -79,32 +78,25 @@ func (w *Watchdog) Timeout() <-chan struct{} {
 
 func (w *Watchdog) publishLoop() {
 	defer w.wg.Done()
-	tkr := w.clock.Ticker(periodHeartbeat)
-	defer tkr.Stop()
-	// immediate publish after starting the ticker. This helps testing so that we can tell from
-	// the outside that the ticker is started.
-	err := w.ps.Publish(EventPubsubWatchdog, []byte{})
-	if err != nil {
-		w.logger.Warn(w.ctx, "failed to publish heartbeat on pubsub watchdog", slog.Error(err))
-	}
-	for {
-		select {
-		case <-w.ctx.Done():
-			w.logger.Debug(w.ctx, "context done; exiting publishLoop")
-			return
-		case <-tkr.C:
-			err := w.ps.Publish(EventPubsubWatchdog, []byte{})
-			if err != nil {
-				w.logger.Warn(w.ctx, "failed to publish heartbeat on pubsub watchdog", slog.Error(err))
-			}
+	tkr := w.clock.TickerFunc(w.ctx, periodHeartbeat, func() error {
+		err := w.ps.Publish(EventPubsubWatchdog, []byte{})
+		if err != nil {
+			w.logger.Warn(w.ctx, "failed to publish heartbeat on pubsub watchdog", slog.Error(err))
+		} else {
+			w.logger.Debug(w.ctx, "published heartbeat on pubsub watchdog")
 		}
-	}
+		// Do not return err: a non-nil return stops the ticker for good, so one failed publish
+		// would end all heartbeats. The loop this replaces logged the failure and kept going.
+		return nil
+	}, "publish")
+	// Wait returns once w.ctx is done; the tick function itself never returns an error.
+	_ = tkr.Wait()
 }
 
 func (w *Watchdog) subscribeMonitor() {
 	defer w.wg.Done()
-	tmr := w.clock.Timer(periodTimeout)
-	defer tmr.Stop()
+	tmr := w.clock.NewTimer(periodTimeout)
+	defer tmr.Stop("subscribe")
 	beats := make(chan struct{})
 	unsub, err := w.ps.Subscribe(EventPubsubWatchdog, func(context.Context, []byte) {
 		w.logger.Debug(w.ctx, "got heartbeat for pubsub watchdog")
diff --git a/coderd/database/pubsub/watchdog_test.go b/coderd/database/pubsub/watchdog_test.go
index ddd5a864e2c66..8d695447e91cf 100644
--- a/coderd/database/pubsub/watchdog_test.go
+++ b/coderd/database/pubsub/watchdog_test.go
@@ -4,36 +4,48 @@ import (
 	"testing"
 	"time"
 
-	"github.com/benbjohnson/clock"
 	"github.com/stretchr/testify/require"
 
 	"cdr.dev/slog"
 	"cdr.dev/slog/sloggers/slogtest"
+	"github.com/coder/coder/v2/clock"
 	"github.com/coder/coder/v2/coderd/database/pubsub"
 	"github.com/coder/coder/v2/testutil"
 )
 
 func TestWatchdog_NoTimeout(t *testing.T) {
 	t.Parallel()
-	ctx := testutil.Context(t, time.Hour)
+	ctx := testutil.Context(t, testutil.WaitShort)
 	mClock := clock.NewMock()
-	start := time.Date(2024, 2, 5, 8, 7, 6, 5, time.UTC)
-	mClock.Set(start)
 	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
 	fPS := newFakePubsub()
+
+	// trap the ticker and timer.Stop() calls
+	pubTrap := mClock.Trap().TickerFunc("publish")
+	defer pubTrap.Close()
+	subTrap := mClock.Trap().TimerStop("subscribe")
+	defer subTrap.Close()
+
 	uut := pubsub.NewWatchdogWithClock(ctx, logger, fPS, mClock)
 
+	// wait for the ticker to be created so that we know it starts from the
+	// right baseline time.
+	pc, err := pubTrap.Wait(ctx)
+	require.NoError(t, err)
+	pc.Release()
+	require.Equal(t, 15*time.Second, pc.Duration)
+
+	// we subscribe after starting the timer, so we know the timer also starts
+	// from the baseline.
 	sub := testutil.RequireRecvCtx(ctx, t, fPS.subs)
 	require.Equal(t, pubsub.EventPubsubWatchdog, sub.event)
-	p := testutil.RequireRecvCtx(ctx, t, fPS.pubs)
-	require.Equal(t, pubsub.EventPubsubWatchdog, p)
 
 	// 5 min / 15 sec = 20, so do 21 ticks
 	for i := 0; i < 21; i++ {
-		mClock.Add(15 * time.Second)
-		p = testutil.RequireRecvCtx(ctx, t, fPS.pubs)
+		mClock.Advance(15 * time.Second)
+		p := testutil.RequireRecvCtx(ctx, t, fPS.pubs)
 		require.Equal(t, pubsub.EventPubsubWatchdog, p)
-		mClock.Add(30 * time.Millisecond) // reasonable round-trip
+		mClock.Advance(30 * time.Millisecond) // reasonable round-trip
 		// forward the beat
 		sub.listener(ctx, []byte{})
 		// we shouldn't time out
@@ -45,7 +57,14 @@ func TestWatchdog_NoTimeout(t *testing.T) {
 		}
 	}
 
-	err := uut.Close()
+	errCh := make(chan error, 1)
+	go func() {
+		errCh <- uut.Close()
+	}()
+	sc, err := subTrap.Wait(ctx) // timer.Stop() called
+	require.NoError(t, err)
+	sc.Release()
+	err = testutil.RequireRecvCtx(ctx, t, errCh)
 	require.NoError(t, err)
 }
 
@@ -53,23 +72,33 @@ func TestWatchdog_Timeout(t *testing.T) {
 	t.Parallel()
 	ctx := testutil.Context(t, testutil.WaitShort)
 	mClock := clock.NewMock()
-	start := time.Date(2024, 2, 5, 8, 7, 6, 5, time.UTC)
-	mClock.Set(start)
 	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
 	fPS := newFakePubsub()
+
+	// trap the ticker and timer calls
+	pubTrap := mClock.Trap().TickerFunc("publish")
+	defer pubTrap.Close()
+
 	uut := pubsub.NewWatchdogWithClock(ctx, logger, fPS, mClock)
 
+	// wait for the ticker to be created so that we know it starts from the
+	// right baseline time.
+	pc, err := pubTrap.Wait(ctx)
+	require.NoError(t, err)
+	pc.Release()
+	require.Equal(t, 15*time.Second, pc.Duration)
+
+	// we subscribe after starting the timer, so we know the timer also starts
+	// from the baseline.
 	sub := testutil.RequireRecvCtx(ctx, t, fPS.subs)
 	require.Equal(t, pubsub.EventPubsubWatchdog, sub.event)
-	p := testutil.RequireRecvCtx(ctx, t, fPS.pubs)
-	require.Equal(t, pubsub.EventPubsubWatchdog, p)
 
 	// 5 min / 15 sec = 20, so do 19 ticks without timing out
 	for i := 0; i < 19; i++ {
-		mClock.Add(15 * time.Second)
-		p = testutil.RequireRecvCtx(ctx, t, fPS.pubs)
+		mClock.Advance(15 * time.Second)
+		p := testutil.RequireRecvCtx(ctx, t, fPS.pubs)
 		require.Equal(t, pubsub.EventPubsubWatchdog, p)
-		mClock.Add(30 * time.Millisecond) // reasonable round-trip
+		mClock.Advance(30 * time.Millisecond) // reasonable round-trip
 		// we DO NOT forward the heartbeat
 		// we shouldn't time out
 		select {
@@ -79,12 +108,12 @@ func TestWatchdog_Timeout(t *testing.T) {
 			// OK!
 		}
 	}
-	mClock.Add(15 * time.Second)
-	p = testutil.RequireRecvCtx(ctx, t, fPS.pubs)
+	mClock.Advance(15 * time.Second)
+	p := testutil.RequireRecvCtx(ctx, t, fPS.pubs)
 	require.Equal(t, pubsub.EventPubsubWatchdog, p)
 	testutil.RequireRecvCtx(ctx, t, uut.Timeout())
 
-	err := uut.Close()
+	err = uut.Close()
 	require.NoError(t, err)
 }
 
@@ -118,7 +147,7 @@ func (f *fakePubsub) Publish(event string, _ []byte) error {
 
 func newFakePubsub() *fakePubsub {
 	return &fakePubsub{
-		pubs: make(chan string),
+		pubs: make(chan string, 1),
 		subs: make(chan subscribe),
 	}
 }