From 6e5decdf5544207f667b3f752f864159851ed820 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Tue, 4 Jun 2024 16:36:29 +0400 Subject: [PATCH] chore: add explicit Wait() to clock.Advance() --- clock/mock.go | 83 +++++++++++++++++++------ coderd/database/pubsub/watchdog_test.go | 12 ++-- 2 files changed, 70 insertions(+), 25 deletions(-) diff --git a/clock/mock.go b/clock/mock.go index b119c53ccf9d6..55e4254ac2d3e 100644 --- a/clock/mock.go +++ b/clock/mock.go @@ -5,6 +5,7 @@ import ( "errors" "slices" "sync" + "testing" "time" ) @@ -141,14 +142,51 @@ func (m *Mock) matchCallLocked(c *Call) { m.mu.Lock() } -// Advance moves the clock forward by d, triggering any timers or tickers. Advance will wait for -// tick functions of tickers created using TickerFunc to complete before returning from -// Advance. If multiple timers or tickers trigger simultaneously, they are all run on separate go -// routines. -func (m *Mock) Advance(d time.Duration) { - m.mu.Lock() - defer m.mu.Unlock() - m.advanceLocked(d) +// AdvanceWaiter is returned from Advance and Set calls and allows you to wait for tick functions of +// tickers created using TickerFunc to complete. If multiple timers or tickers trigger +// simultaneously, they are all run on separate go routines. +type AdvanceWaiter struct { + ch chan struct{} +} + +// Wait for all timers and ticks to complete, or until context expires. +func (w AdvanceWaiter) Wait(ctx context.Context) error { + select { + case <-w.ch: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// MustWait waits for all timers and ticks to complete, and fails the test immediately if the +// context completes first. MustWait must be called from the goroutine running the test or +// benchmark, similar to `t.FailNow()`. +func (w AdvanceWaiter) MustWait(ctx context.Context, t testing.TB) { + select { + case <-w.ch: + return + case <-ctx.Done(): + t.Fatalf("context expired while waiting for clock to advance: %s", ctx.Err()) + } +} + +// Done returns a channel that is closed when all timers and ticks complete. +func (w AdvanceWaiter) Done() <-chan struct{} { + return w.ch +} + +// Advance moves the clock forward by d, triggering any timers or tickers. The returned value can +// be used to wait for all timers and ticks to complete. +func (m *Mock) Advance(d time.Duration) AdvanceWaiter { + w := AdvanceWaiter{ch: make(chan struct{})} + go func() { + defer close(w.ch) + m.mu.Lock() + defer m.mu.Unlock() + m.advanceLocked(d) + }() + return w } func (m *Mock) advanceLocked(d time.Duration) { @@ -194,19 +232,24 @@ func (m *Mock) advanceLocked(d time.Duration) { // Set the time to t. If the time is after the current mocked time, then this is equivalent to // Advance() with the difference. You may only Set the time earlier than the current time before // starting tickers and timers (e.g. at the start of your test case). -func (m *Mock) Set(t time.Time) { - m.mu.Lock() - defer m.mu.Unlock() - if t.Before(m.cur) { - // past - if !m.nextTime.IsZero() { - panic("Set mock clock to the past after timers/tickers started") +func (m *Mock) Set(t time.Time) AdvanceWaiter { + w := AdvanceWaiter{ch: make(chan struct{})} + go func() { + defer close(w.ch) + m.mu.Lock() + defer m.mu.Unlock() + if t.Before(m.cur) { + // past + if !m.nextTime.IsZero() { + panic("Set mock clock to the past after timers/tickers started") + } + m.cur = t + return } - m.cur = t - return - } - // future, just advance as normal. - m.advanceLocked(t.Sub(m.cur)) + // future, just advance as normal. + m.advanceLocked(t.Sub(m.cur)) + }() + return w } // Trapper allows the creation of Traps diff --git a/coderd/database/pubsub/watchdog_test.go b/coderd/database/pubsub/watchdog_test.go index 8d695447e91cf..62d51c8ecaaee 100644 --- a/coderd/database/pubsub/watchdog_test.go +++ b/coderd/database/pubsub/watchdog_test.go @@ -42,10 +42,11 @@ func TestWatchdog_NoTimeout(t *testing.T) { // 5 min / 15 sec = 20, so do 21 ticks for i := 0; i < 21; i++ { - mClock.Advance(15 * time.Second) + mClock.Advance(15*time.Second).MustWait(ctx, t) p := testutil.RequireRecvCtx(ctx, t, fPS.pubs) require.Equal(t, pubsub.EventPubsubWatchdog, p) - mClock.Advance(30 * time.Millisecond) // reasonable round-trip + mClock.Advance(30*time.Millisecond). // reasonable round-trip + MustWait(ctx, t) // forward the beat sub.listener(ctx, []byte{}) // we shouldn't time out @@ -95,10 +96,11 @@ func TestWatchdog_Timeout(t *testing.T) { // 5 min / 15 sec = 20, so do 19 ticks without timing out for i := 0; i < 19; i++ { - mClock.Advance(15 * time.Second) + mClock.Advance(15*time.Second).MustWait(ctx, t) p := testutil.RequireRecvCtx(ctx, t, fPS.pubs) require.Equal(t, pubsub.EventPubsubWatchdog, p) - mClock.Advance(30 * time.Millisecond) // reasonable round-trip + mClock.Advance(30*time.Millisecond). // reasonable round-trip + MustWait(ctx, t) // we DO NOT forward the heartbeat // we shouldn't time out select { @@ -108,7 +110,7 @@ func TestWatchdog_Timeout(t *testing.T) { // OK! } } - mClock.Advance(15 * time.Second) + mClock.Advance(15*time.Second).MustWait(ctx, t) p := testutil.RequireRecvCtx(ctx, t, fPS.pubs) require.Equal(t, pubsub.EventPubsubWatchdog, p) testutil.RequireRecvCtx(ctx, t, uut.Timeout())