Skip to content

Commit 9c3fd5d

Browse files
authored
chore: add explicit Wait() to clock.Advance() (#13464)
1 parent 42324b3 commit 9c3fd5d

File tree

2 files changed

+70
-25
lines changed

2 files changed

+70
-25
lines changed

clock/mock.go

+63-20
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"errors"
66
"slices"
77
"sync"
8+
"testing"
89
"time"
910
)
1011

@@ -141,14 +142,51 @@ func (m *Mock) matchCallLocked(c *Call) {
141142
m.mu.Lock()
142143
}
143144

144-
// Advance moves the clock forward by d, triggering any timers or tickers. Advance will wait for
145-
// tick functions of tickers created using TickerFunc to complete before returning from
146-
// Advance. If multiple timers or tickers trigger simultaneously, they are all run on separate go
147-
// routines.
148-
func (m *Mock) Advance(d time.Duration) {
149-
m.mu.Lock()
150-
defer m.mu.Unlock()
151-
m.advanceLocked(d)
145+
// AdvanceWaiter is returned from Advance and Set calls and allows you to wait for tick functions of
146+
// tickers created using TickerFunc to complete. If multiple timers or tickers trigger
147+
// simultaneously, they are all run on separate go routines.
148+
type AdvanceWaiter struct {
149+
ch chan struct{}
150+
}
151+
152+
// Wait for all timers and ticks to complete, or until context expires.
153+
func (w AdvanceWaiter) Wait(ctx context.Context) error {
154+
select {
155+
case <-w.ch:
156+
return nil
157+
case <-ctx.Done():
158+
return ctx.Err()
159+
}
160+
}
161+
162+
// MustWait waits for all timers and ticks to complete, and fails the test immediately if the
163+
// context completes first. MustWait must be called from the goroutine running the test or
164+
// benchmark, similar to `t.FailNow()`.
165+
func (w AdvanceWaiter) MustWait(ctx context.Context, t testing.TB) {
166+
select {
167+
case <-w.ch:
168+
return
169+
case <-ctx.Done():
170+
t.Fatalf("context expired while waiting for clock to advance: %s", ctx.Err())
171+
}
172+
}
173+
174+
// Done returns a channel that is closed when all timers and ticks complete.
175+
func (w AdvanceWaiter) Done() <-chan struct{} {
176+
return w.ch
177+
}
178+
179+
// Advance moves the clock forward by d, triggering any timers or tickers. The returned value can
180+
// be used to wait for all timers and ticks to complete.
181+
func (m *Mock) Advance(d time.Duration) AdvanceWaiter {
182+
w := AdvanceWaiter{ch: make(chan struct{})}
183+
go func() {
184+
defer close(w.ch)
185+
m.mu.Lock()
186+
defer m.mu.Unlock()
187+
m.advanceLocked(d)
188+
}()
189+
return w
152190
}
153191

154192
func (m *Mock) advanceLocked(d time.Duration) {
@@ -194,19 +232,24 @@ func (m *Mock) advanceLocked(d time.Duration) {
194232
// Set the time to t. If the time is after the current mocked time, then this is equivalent to
195233
// Advance() with the difference. You may only Set the time earlier than the current time before
196234
// starting tickers and timers (e.g. at the start of your test case).
197-
func (m *Mock) Set(t time.Time) {
198-
m.mu.Lock()
199-
defer m.mu.Unlock()
200-
if t.Before(m.cur) {
201-
// past
202-
if !m.nextTime.IsZero() {
203-
panic("Set mock clock to the past after timers/tickers started")
235+
func (m *Mock) Set(t time.Time) AdvanceWaiter {
236+
w := AdvanceWaiter{ch: make(chan struct{})}
237+
go func() {
238+
defer close(w.ch)
239+
m.mu.Lock()
240+
defer m.mu.Unlock()
241+
if t.Before(m.cur) {
242+
// past
243+
if !m.nextTime.IsZero() {
244+
panic("Set mock clock to the past after timers/tickers started")
245+
}
246+
m.cur = t
247+
return
204248
}
205-
m.cur = t
206-
return
207-
}
208-
// future, just advance as normal.
209-
m.advanceLocked(t.Sub(m.cur))
249+
// future, just advance as normal.
250+
m.advanceLocked(t.Sub(m.cur))
251+
}()
252+
return w
210253
}
211254

212255
// Trapper allows the creation of Traps

coderd/database/pubsub/watchdog_test.go

+7-5
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,11 @@ func TestWatchdog_NoTimeout(t *testing.T) {
4242

4343
// 5 min / 15 sec = 20, so do 21 ticks
4444
for i := 0; i < 21; i++ {
45-
mClock.Advance(15 * time.Second)
45+
mClock.Advance(15*time.Second).MustWait(ctx, t)
4646
p := testutil.RequireRecvCtx(ctx, t, fPS.pubs)
4747
require.Equal(t, pubsub.EventPubsubWatchdog, p)
48-
mClock.Advance(30 * time.Millisecond) // reasonable round-trip
48+
mClock.Advance(30*time.Millisecond). // reasonable round-trip
49+
MustWait(ctx, t)
4950
// forward the beat
5051
sub.listener(ctx, []byte{})
5152
// we shouldn't time out
@@ -95,10 +96,11 @@ func TestWatchdog_Timeout(t *testing.T) {
9596

9697
// 5 min / 15 sec = 20, so do 19 ticks without timing out
9798
for i := 0; i < 19; i++ {
98-
mClock.Advance(15 * time.Second)
99+
mClock.Advance(15*time.Second).MustWait(ctx, t)
99100
p := testutil.RequireRecvCtx(ctx, t, fPS.pubs)
100101
require.Equal(t, pubsub.EventPubsubWatchdog, p)
101-
mClock.Advance(30 * time.Millisecond) // reasonable round-trip
102+
mClock.Advance(30*time.Millisecond). // reasonable round-trip
103+
MustWait(ctx, t)
102104
// we DO NOT forward the heartbeat
103105
// we shouldn't time out
104106
select {
@@ -108,7 +110,7 @@ func TestWatchdog_Timeout(t *testing.T) {
108110
// OK!
109111
}
110112
}
111-
mClock.Advance(15 * time.Second)
113+
mClock.Advance(15*time.Second).MustWait(ctx, t)
112114
p := testutil.RequireRecvCtx(ctx, t, fPS.pubs)
113115
require.Equal(t, pubsub.EventPubsubWatchdog, p)
114116
testutil.RequireRecvCtx(ctx, t, uut.Timeout())

0 commit comments

Comments
 (0)