|
5 | 5 | "errors"
|
6 | 6 | "slices"
|
7 | 7 | "sync"
|
| 8 | + "testing" |
8 | 9 | "time"
|
9 | 10 | )
|
10 | 11 |
|
@@ -141,14 +142,51 @@ func (m *Mock) matchCallLocked(c *Call) {
|
141 | 142 | m.mu.Lock()
|
142 | 143 | }
|
143 | 144 |
|
144 |
| -// Advance moves the clock forward by d, triggering any timers or tickers. Advance will wait for |
145 |
| -// tick functions of tickers created using TickerFunc to complete before returning from |
146 |
| -// Advance. If multiple timers or tickers trigger simultaneously, they are all run on separate go |
147 |
| -// routines. |
148 |
| -func (m *Mock) Advance(d time.Duration) { |
149 |
| - m.mu.Lock() |
150 |
| - defer m.mu.Unlock() |
151 |
| - m.advanceLocked(d) |
| 145 | +// AdvanceWaiter is returned from Advance and Set calls and allows you to wait for tick functions of |
| 146 | +// tickers created using TickerFunc to complete. If multiple timers or tickers trigger |
| 147 | +// simultaneously, they are all run on separate go routines. |
| 148 | +type AdvanceWaiter struct { |
| 149 | + ch chan struct{} |
| 150 | +} |
| 151 | + |
| 152 | +// Wait for all timers and ticks to complete, or until context expires. |
| 153 | +func (w AdvanceWaiter) Wait(ctx context.Context) error { |
| 154 | + select { |
| 155 | + case <-w.ch: |
| 156 | + return nil |
| 157 | + case <-ctx.Done(): |
| 158 | + return ctx.Err() |
| 159 | + } |
| 160 | +} |
| 161 | + |
| 162 | +// MustWait waits for all timers and ticks to complete, and fails the test immediately if the |
| 163 | +// context completes first. MustWait must be called from the goroutine running the test or |
| 164 | +// benchmark, similar to `t.FailNow()`. |
| 165 | +func (w AdvanceWaiter) MustWait(ctx context.Context, t testing.TB) { |
| 166 | + select { |
| 167 | + case <-w.ch: |
| 168 | + return |
| 169 | + case <-ctx.Done(): |
| 170 | + t.Fatalf("context expired while waiting for clock to advance: %s", ctx.Err()) |
| 171 | + } |
| 172 | +} |
| 173 | + |
| 174 | +// Done returns a channel that is closed when all timers and ticks complete. |
| 175 | +func (w AdvanceWaiter) Done() <-chan struct{} { |
| 176 | + return w.ch |
| 177 | +} |
| 178 | + |
| 179 | +// Advance moves the clock forward by d, triggering any timers or tickers. The returned value can |
| 180 | +// be used to wait for all timers and ticks to complete. |
| 181 | +func (m *Mock) Advance(d time.Duration) AdvanceWaiter { |
| 182 | + w := AdvanceWaiter{ch: make(chan struct{})} |
| 183 | + go func() { |
| 184 | + defer close(w.ch) |
| 185 | + m.mu.Lock() |
| 186 | + defer m.mu.Unlock() |
| 187 | + m.advanceLocked(d) |
| 188 | + }() |
| 189 | + return w |
152 | 190 | }
|
153 | 191 |
|
154 | 192 | func (m *Mock) advanceLocked(d time.Duration) {
|
@@ -194,19 +232,24 @@ func (m *Mock) advanceLocked(d time.Duration) {
|
194 | 232 | // Set the time to t. If the time is after the current mocked time, then this is equivalent to
|
195 | 233 | // Advance() with the difference. You may only Set the time earlier than the current time before
|
196 | 234 | // starting tickers and timers (e.g. at the start of your test case).
|
197 |
| -func (m *Mock) Set(t time.Time) { |
198 |
| - m.mu.Lock() |
199 |
| - defer m.mu.Unlock() |
200 |
| - if t.Before(m.cur) { |
201 |
| - // past |
202 |
| - if !m.nextTime.IsZero() { |
203 |
| - panic("Set mock clock to the past after timers/tickers started") |
| 235 | +func (m *Mock) Set(t time.Time) AdvanceWaiter { |
| 236 | + w := AdvanceWaiter{ch: make(chan struct{})} |
| 237 | + go func() { |
| 238 | + defer close(w.ch) |
| 239 | + m.mu.Lock() |
| 240 | + defer m.mu.Unlock() |
| 241 | + if t.Before(m.cur) { |
| 242 | + // past |
| 243 | + if !m.nextTime.IsZero() { |
| 244 | + panic("Set mock clock to the past after timers/tickers started") |
| 245 | + } |
| 246 | + m.cur = t |
| 247 | + return |
204 | 248 | }
|
205 |
| - m.cur = t |
206 |
| - return |
207 |
| - } |
208 |
| - // future, just advance as normal. |
209 |
| - m.advanceLocked(t.Sub(m.cur)) |
| 249 | + // future, just advance as normal. |
| 250 | + m.advanceLocked(t.Sub(m.cur)) |
| 251 | + }() |
| 252 | + return w |
210 | 253 | }
|
211 | 254 |
|
212 | 255 | // Trapper allows the creation of Traps
|
|
0 commit comments