Skip to content

Commit 1f3c4aa

Browse files
committed
chore: add explicit Wait() to clock.Advance()
1 parent e243711 commit 1f3c4aa

File tree

2 files changed

+69
-25
lines changed

2 files changed

+69
-25
lines changed

clock/mock.go

Lines changed: 62 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"errors"
66
"slices"
77
"sync"
8+
"testing"
89
"time"
910
)
1011

@@ -141,14 +142,50 @@ func (m *Mock) matchCallLocked(c *Call) {
141142
m.mu.Lock()
142143
}
143144

144-
// Advance moves the clock forward by d, triggering any timers or tickers. Advance will wait for
145-
// tick functions of tickers created using TickerFunc to complete before returning from
146-
// Advance. If multiple timers or tickers trigger simultaneously, they are all run on separate go
147-
// routines.
148-
func (m *Mock) Advance(d time.Duration) {
149-
m.mu.Lock()
150-
defer m.mu.Unlock()
151-
m.advanceLocked(d)
145+
// AdvanceWaiter is returned from Advance and Set calls and allows you to wait for tick functions of
146+
// tickers created using TickerFunc to complete. If multiple timers or tickers trigger
147+
// simultaneously, they are all run on separate go routines.
148+
type AdvanceWaiter struct {
149+
ch chan struct{}
150+
}
151+
152+
// Wait for all timers and ticks to complete, or until context expires.
153+
func (w AdvanceWaiter) Wait(ctx context.Context) error {
154+
select {
155+
case <-w.ch:
156+
return nil
157+
case <-ctx.Done():
158+
return ctx.Err()
159+
}
160+
}
161+
162+
// MustWait waits for all timers and ticks to complete, and fails the test immediately if the
163+
// context completes first.
164+
func (w AdvanceWaiter) MustWait(ctx context.Context, t testing.TB) {
165+
select {
166+
case <-w.ch:
167+
return
168+
case <-ctx.Done():
169+
t.Fatalf("context expired waiting for Advance: %s", ctx.Err())
170+
}
171+
}
172+
173+
// Done returns a channel that is closed when all timers and ticks complete.
174+
func (w AdvanceWaiter) Done() <-chan struct{} {
175+
return w.ch
176+
}
177+
178+
// Advance moves the clock forward by d, triggering any timers or tickers. The returned value can
179+
// be used to wait for all timers and ticks to complete.
180+
func (m *Mock) Advance(d time.Duration) AdvanceWaiter {
181+
w := AdvanceWaiter{ch: make(chan struct{})}
182+
go func() {
183+
defer close(w.ch)
184+
m.mu.Lock()
185+
defer m.mu.Unlock()
186+
m.advanceLocked(d)
187+
}()
188+
return w
152189
}
153190

154191
func (m *Mock) advanceLocked(d time.Duration) {
@@ -194,19 +231,24 @@ func (m *Mock) advanceLocked(d time.Duration) {
194231
// Set the time to t. If the time is after the current mocked time, then this is equivalent to
195232
// Advance() with the difference. You may only Set the time earlier than the current time before
196233
// starting tickers and timers (e.g. at the start of your test case).
197-
func (m *Mock) Set(t time.Time) {
198-
m.mu.Lock()
199-
defer m.mu.Unlock()
200-
if t.Before(m.cur) {
201-
// past
202-
if !m.nextTime.IsZero() {
203-
panic("Set mock clock to the past after timers/tickers started")
234+
func (m *Mock) Set(t time.Time) AdvanceWaiter {
235+
w := AdvanceWaiter{ch: make(chan struct{})}
236+
go func() {
237+
defer close(w.ch)
238+
m.mu.Lock()
239+
defer m.mu.Unlock()
240+
if t.Before(m.cur) {
241+
// past
242+
if !m.nextTime.IsZero() {
243+
panic("Set mock clock to the past after timers/tickers started")
244+
}
245+
m.cur = t
246+
return
204247
}
205-
m.cur = t
206-
return
207-
}
208-
// future, just advance as normal.
209-
m.advanceLocked(t.Sub(m.cur))
248+
// future, just advance as normal.
249+
m.advanceLocked(t.Sub(m.cur))
250+
}()
251+
return w
210252
}
211253

212254
// Trapper allows the creation of Traps

coderd/database/pubsub/watchdog_test.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,11 @@ func TestWatchdog_NoTimeout(t *testing.T) {
4242

4343
// 5 min / 15 sec = 20, so do 21 ticks
4444
for i := 0; i < 21; i++ {
45-
mClock.Advance(15 * time.Second)
45+
mClock.Advance(15*time.Second).MustWait(ctx, t)
4646
p := testutil.RequireRecvCtx(ctx, t, fPS.pubs)
4747
require.Equal(t, pubsub.EventPubsubWatchdog, p)
48-
mClock.Advance(30 * time.Millisecond) // reasonable round-trip
48+
mClock.Advance(30*time.Millisecond). // reasonable round-trip
49+
MustWait(ctx, t)
4950
// forward the beat
5051
sub.listener(ctx, []byte{})
5152
// we shouldn't time out
@@ -95,10 +96,11 @@ func TestWatchdog_Timeout(t *testing.T) {
9596

9697
// 5 min / 15 sec = 20, so do 19 ticks without timing out
9798
for i := 0; i < 19; i++ {
98-
mClock.Advance(15 * time.Second)
99+
mClock.Advance(15*time.Second).MustWait(ctx, t)
99100
p := testutil.RequireRecvCtx(ctx, t, fPS.pubs)
100101
require.Equal(t, pubsub.EventPubsubWatchdog, p)
101-
mClock.Advance(30 * time.Millisecond) // reasonable round-trip
102+
mClock.Advance(30*time.Millisecond). // reasonable round-trip
103+
MustWait(ctx, t)
102104
// we DO NOT forward the heartbeat
103105
// we shouldn't time out
104106
select {
@@ -108,7 +110,7 @@ func TestWatchdog_Timeout(t *testing.T) {
108110
// OK!
109111
}
110112
}
111-
mClock.Advance(15 * time.Second)
113+
mClock.Advance(15*time.Second).MustWait(ctx, t)
112114
p := testutil.RequireRecvCtx(ctx, t, fPS.pubs)
113115
require.Equal(t, pubsub.EventPubsubWatchdog, p)
114116
testutil.RequireRecvCtx(ctx, t, uut.Timeout())

0 commit comments

Comments
 (0)