From 005be1dd55b27e8428a05ac3e93bce852f121212 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Fri, 7 Jun 2024 11:33:14 +0400 Subject: [PATCH] chore: change mock clock to allow Advance() within timer/tick functions --- clock/clock.go | 2 + clock/example_test.go | 69 +++++- clock/mock.go | 226 ++++++++++++-------- clock/real.go | 4 + clock/timer.go | 16 +- coderd/database/pubsub/watchdog_test.go | 26 ++- enterprise/tailnet/pgcoord.go | 80 +++---- enterprise/tailnet/pgcoord_internal_test.go | 132 +++++++++--- enterprise/tailnet/pgcoord_test.go | 45 ++-- tailnet/configmaps_internal_test.go | 41 ++-- 10 files changed, 424 insertions(+), 217 deletions(-) diff --git a/clock/clock.go b/clock/clock.go index 516b74e6b117b..5f3b0de105911 100644 --- a/clock/clock.go +++ b/clock/clock.go @@ -26,6 +26,8 @@ type Clock interface { Now(tags ...string) time.Time // Since returns the time elapsed since t. It is shorthand for Clock.Now().Sub(t). Since(t time.Time, tags ...string) time.Duration + // Until returns the duration until t. It is shorthand for t.Sub(Clock.Now()). + Until(t time.Time, tags ...string) time.Duration } // Waiter can be waited on for an error. diff --git a/clock/example_test.go b/clock/example_test.go index 69d6ba4a318ae..de72312d7d036 100644 --- a/clock/example_test.go +++ b/clock/example_test.go @@ -44,7 +44,7 @@ func TestExampleTickerFunc(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - mClock := clock.NewMock() + mClock := clock.NewMock(t) // Because the ticker is started on a goroutine, we can't immediately start // advancing the clock, or we will race with the start of the ticker. If we @@ -76,9 +76,74 @@ func TestExampleTickerFunc(t *testing.T) { } // Now that we know the ticker is started, we can advance the time. 
- mClock.Advance(time.Hour).MustWait(ctx, t) + mClock.Advance(time.Hour).MustWait(ctx) if tks := tc.Ticks(); tks != 1 { t.Fatalf("expected 1 got %d ticks", tks) } } + +type exampleLatencyMeasurer struct { + mu sync.Mutex + lastLatency time.Duration +} + +func newExampleLatencyMeasurer(ctx context.Context, clk clock.Clock) *exampleLatencyMeasurer { + m := &exampleLatencyMeasurer{} + clk.TickerFunc(ctx, 10*time.Second, func() error { + start := clk.Now() + // m.doSomething() + latency := clk.Since(start) + m.mu.Lock() + defer m.mu.Unlock() + m.lastLatency = latency + return nil + }) + return m +} + +func (m *exampleLatencyMeasurer) LastLatency() time.Duration { + m.mu.Lock() + defer m.mu.Unlock() + return m.lastLatency +} + +func TestExampleLatencyMeasurer(t *testing.T) { + t.Parallel() + + // nolint:gocritic // trying to avoid Coder-specific stuff with an eye toward spinning this out + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + mClock := clock.NewMock(t) + trap := mClock.Trap().Since() + defer trap.Close() + + lm := newExampleLatencyMeasurer(ctx, mClock) + + w := mClock.Advance(10 * time.Second) // triggers first tick + c := trap.MustWait(ctx) // call to Since() + mClock.Advance(33 * time.Millisecond) + c.Release() + w.MustWait(ctx) + + if l := lm.LastLatency(); l != 33*time.Millisecond { + t.Fatalf("expected 33ms got %s", l.String()) + } + + // Next tick is in 10s - 33ms, but if we don't want to calculate, we can use: + d, w2 := mClock.AdvanceNext() + c = trap.MustWait(ctx) + mClock.Advance(17 * time.Millisecond) + c.Release() + w2.MustWait(ctx) + + expectedD := 10*time.Second - 33*time.Millisecond + if d != expectedD { + t.Fatalf("expected %s got %s", expectedD.String(), d.String()) + } + + if l := lm.LastLatency(); l != 17*time.Millisecond { + t.Fatalf("expected 17ms got %s", l.String()) + } +} diff --git a/clock/mock.go b/clock/mock.go index 55c8cdcaa3277..6e66206c1614d 100644 --- a/clock/mock.go +++ b/clock/mock.go @@ -3,6 +3,7 @@ package clock import ( "context" "errors" + "fmt" "slices" "sync" "testing" @@ -12,13 +13,11 @@ import ( // Mock is the testing implementation of Clock. It tracks a time that monotonically increases // during a test, triggering any timers or tickers automatically. type Mock struct { + tb testing.TB mu sync.Mutex // cur is the current time cur time.Time - // advancing is true when we are in the process of advancing the clock. We don't support - // multiple goroutines doing this at once. 
- advancing bool all []event nextTime time.Time @@ -77,11 +76,9 @@ func (m *Mock) AfterFunc(d time.Duration, f func(), tags ...string) *Timer { } m.mu.Lock() defer m.mu.Unlock() - m.matchCallLocked(&Call{ - fn: clockFunctionAfterFunc, - Duration: d, - Tags: tags, - }) + c := newCall(clockFunctionAfterFunc, tags, withDuration(d)) + defer close(c.complete) + m.matchCallLocked(c) t := &Timer{ nxt: m.cur.Add(d), fn: f, @@ -94,23 +91,30 @@ func (m *Mock) AfterFunc(d time.Duration, f func(), tags ...string) *Timer { func (m *Mock) Now(tags ...string) time.Time { m.mu.Lock() defer m.mu.Unlock() - m.matchCallLocked(&Call{ - fn: clockFunctionNow, - Tags: tags, - }) + c := newCall(clockFunctionNow, tags) + defer close(c.complete) + m.matchCallLocked(c) return m.cur } func (m *Mock) Since(t time.Time, tags ...string) time.Duration { m.mu.Lock() defer m.mu.Unlock() - m.matchCallLocked(&Call{ - fn: clockFunctionSince, - Tags: tags, - }) + c := newCall(clockFunctionSince, tags, withTime(t)) + defer close(c.complete) + m.matchCallLocked(c) return m.cur.Sub(t) } +func (m *Mock) Until(t time.Time, tags ...string) time.Duration { + m.mu.Lock() + defer m.mu.Unlock() + c := newCall(clockFunctionUntil, tags, withTime(t)) + defer close(c.complete) + m.matchCallLocked(c) + return t.Sub(m.cur) +} + func (m *Mock) addTimerLocked(t *Timer) { m.all = append(m.all, t) m.recomputeNextLocked() @@ -182,14 +186,15 @@ func (m *Mock) matchCallLocked(c *Call) { m.mu.Lock() } -// AdvanceWaiter is returned from Advance and Set calls and allows you to wait for: -// -// - tick functions of tickers created using NewContextTicker -// - functions passed to AfterFunc +// AdvanceWaiter is returned from Advance and Set calls and allows you to wait for ticks and timers +// to complete. In the case of functions passed to AfterFunc or TickerFunc, it waits for the +// functions to return. For other ticks & timers, it just waits for the tick to be delivered to +// the channel. // -// to complete. If multiple timers or tickers trigger simultaneously, they are all run on separate +// If multiple timers or tickers trigger simultaneously, they are all run on separate // go routines. type AdvanceWaiter struct { + tb testing.TB ch chan struct{} } @@ -206,12 +211,13 @@ func (w AdvanceWaiter) Wait(ctx context.Context) error { // MustWait waits for all timers and ticks to complete, and fails the test immediately if the // context completes first. MustWait must be called from the goroutine running the test or // benchmark, similar to `t.FailNow()`. -func (w AdvanceWaiter) MustWait(ctx context.Context, t testing.TB) { +func (w AdvanceWaiter) MustWait(ctx context.Context) { + w.tb.Helper() select { case <-w.ch: return case <-ctx.Done(): - t.Fatalf("context expired while waiting for clock to advance: %s", ctx.Err()) + w.tb.Fatalf("context expired while waiting for clock to advance: %s", ctx.Err()) } } @@ -221,81 +227,112 @@ func (w AdvanceWaiter) Done() <-chan struct{} { } // Advance moves the clock forward by d, triggering any timers or tickers. The returned value can -// be used to wait for all timers and ticks to complete. +// be used to wait for all timers and ticks to complete. Advance sets the clock forward before +// returning, and can only advance up to the next timer or tick event. It will fail the test if you +// attempt to advance beyond. +// +// If you need to advance exactly to the next event, and don't know or don't wish to calculate it, +// consider AdvanceNext(). 
func (m *Mock) Advance(d time.Duration) AdvanceWaiter { - w := AdvanceWaiter{ch: make(chan struct{})} - go func() { - defer close(w.ch) - m.mu.Lock() - defer m.mu.Unlock() - m.advanceLocked(d) - }() - return w -} - -func (m *Mock) advanceLocked(d time.Duration) { - if m.advancing { - panic("multiple simultaneous calls to Advance/Set not supported") - } - m.advancing = true - defer func() { - m.advancing = false - }() - + m.tb.Helper() + w := AdvanceWaiter{tb: m.tb, ch: make(chan struct{})} + m.mu.Lock() fin := m.cur.Add(d) - for { - // nextTime.IsZero implies no events scheduled - if m.nextTime.IsZero() || m.nextTime.After(fin) { - m.cur = fin - return - } - - if m.nextTime.After(m.cur) { - m.cur = m.nextTime - } - - wg := sync.WaitGroup{} - for i := range m.nextEvents { - e := m.nextEvents[i] - t := m.cur - wg.Add(1) - go func() { - e.fire(t) - wg.Done() - }() - } - // release the lock and let the events resolve. This allows them to call back into the - // Mock to query the time or set new timers. Each event should remove or reschedule - // itself from nextEvents. + // nextTime.IsZero implies no events scheduled. + if m.nextTime.IsZero() || fin.Before(m.nextTime) { + m.cur = fin + m.mu.Unlock() + close(w.ch) + return w + } + if fin.After(m.nextTime) { + m.tb.Errorf(fmt.Sprintf("cannot advance %s which is beyond next timer/ticker event in %s", + d.String(), m.nextTime.Sub(m.cur))) m.mu.Unlock() - wg.Wait() - m.mu.Lock() + close(w.ch) + return w } + + m.cur = m.nextTime + go m.advanceLocked(w) + return w +} + +func (m *Mock) advanceLocked(w AdvanceWaiter) { + defer close(w.ch) + wg := sync.WaitGroup{} + for i := range m.nextEvents { + e := m.nextEvents[i] + t := m.cur + wg.Add(1) + go func() { + e.fire(t) + wg.Done() + }() + } + // release the lock and let the events resolve. This allows them to call back into the + // Mock to query the time or set new timers. Each event should remove or reschedule + // itself from nextEvents. + m.mu.Unlock() + wg.Wait() } // Set the time to t. If the time is after the current mocked time, then this is equivalent to // Advance() with the difference. You may only Set the time earlier than the current time before // starting tickers and timers (e.g. at the start of your test case). func (m *Mock) Set(t time.Time) AdvanceWaiter { - w := AdvanceWaiter{ch: make(chan struct{})} - go func() { + m.tb.Helper() + w := AdvanceWaiter{tb: m.tb, ch: make(chan struct{})} + m.mu.Lock() + if t.Before(m.cur) { defer close(w.ch) - m.mu.Lock() defer m.mu.Unlock() - if t.Before(m.cur) { - // past - if !m.nextTime.IsZero() { - panic("Set mock clock to the past after timers/tickers started") - } - m.cur = t - return + // past + if !m.nextTime.IsZero() { + m.tb.Error("Set mock clock to the past after timers/tickers started") } - // future, just advance as normal. - m.advanceLocked(t.Sub(m.cur)) - }() + m.cur = t + return w + } + // future + // nextTime.IsZero implies no events scheduled. + if m.nextTime.IsZero() || t.Before(m.nextTime) { + defer close(w.ch) + defer m.mu.Unlock() + m.cur = t + return w + } + if t.After(m.nextTime) { + defer close(w.ch) + defer m.mu.Unlock() + m.tb.Errorf("cannot Set time to %s which is beyond next timer/ticker event at %s", + t.String(), m.nextTime) + return w + } + + m.cur = m.nextTime + go m.advanceLocked(w) return w } +// AdvanceNext advances the clock to the next timer or tick event. It fails the test if there are +// none scheduled. 
It returns the duration the clock was advanced and a waiter that can be used to +// wait for the timer/tick event(s) to finish. +func (m *Mock) AdvanceNext() (time.Duration, AdvanceWaiter) { + m.mu.Lock() + m.tb.Helper() + w := AdvanceWaiter{tb: m.tb, ch: make(chan struct{})} + if m.nextTime.IsZero() { + defer close(w.ch) + defer m.mu.Unlock() + m.tb.Error("cannot AdvanceNext because there are no timers or tickers running") + } + d := m.nextTime.Sub(m.cur) + m.cur = m.nextTime + go m.advanceLocked(w) + return d, w +} + // Trapper allows the creation of Traps type Trapper struct { // mock is the underlying Mock. This is a thin wrapper around Mock so that @@ -335,6 +372,10 @@ func (t Trapper) Since(tags ...string) *Trap { return t.mock.newTrap(clockFunctionSince, tags) } +func (t Trapper) Until(tags ...string) *Trap { + return t.mock.newTrap(clockFunctionUntil, tags) +} + func (m *Mock) Trap() Trapper { return Trapper{m} } @@ -356,12 +397,13 @@ func (m *Mock) newTrap(fn clockFunction, tags []string) *Trap { // NewMock creates a new Mock with the time set to midnight UTC on Jan 1, 2024. // You may re-set the time earlier than this, but only before timers or tickers // are created. -func NewMock() *Mock { +func NewMock(tb testing.TB) *Mock { cur, err := time.Parse(time.RFC3339, "2024-01-01T00:00:00Z") if err != nil { panic(err) } return &Mock{ + tb: tb, cur: cur, } } @@ -387,15 +429,12 @@ func (m *mockTickerFunc) next() time.Time { return m.nxt } -func (m *mockTickerFunc) fire(t time.Time) { +func (m *mockTickerFunc) fire(_ time.Time) { m.mock.mu.Lock() defer m.mock.mu.Unlock() if m.done { return } - if !m.nxt.Equal(t) { - panic("mockTickerFunc fired at wrong time") - } m.nxt = m.nxt.Add(m.d) m.mock.recomputeNextLocked() @@ -449,6 +488,7 @@ const ( clockFunctionTickerFuncWait clockFunctionNow clockFunctionSince + clockFunctionUntil ) type callArg func(c *Call) @@ -468,7 +508,6 @@ func (c *Call) Release() { <-c.complete } -// nolint: unused // it will be soon func withTime(t time.Time) callArg { return func(c *Call) { c.Time = t @@ -544,3 +583,14 @@ func (t *Trap) Wait(ctx context.Context) (*Call, error) { return c, nil } } + +// MustWait calls Wait() and then if there is an error, immediately fails the +// test via tb.Fatalf() +func (t *Trap) MustWait(ctx context.Context) *Call { + t.mock.tb.Helper() + c, err := t.Wait(ctx) + if err != nil { + t.mock.tb.Fatalf("context expired while waiting for trap: %s", err.Error()) + } + return c +} diff --git a/clock/real.go b/clock/real.go index e31c80616d896..41019571e6aea 100644 --- a/clock/real.go +++ b/clock/real.go @@ -68,4 +68,8 @@ func (realClock) Since(t time.Time, _ ...string) time.Duration { return time.Since(t) } +func (realClock) Until(t time.Time, _ ...string) time.Duration { + return time.Until(t) +} + var _ Clock = realClock{} diff --git a/clock/timer.go b/clock/timer.go index ee1d67485219d..b2175c953f0d5 100644 --- a/clock/timer.go +++ b/clock/timer.go @@ -14,9 +14,6 @@ type Timer struct { } func (t *Timer) fire(tt time.Time) { - if !tt.Equal(t.nxt) { - panic("mock timer fired at wrong time") - } t.mock.removeTimer(t) if t.fn != nil { t.fn() @@ -56,13 +53,20 @@ func (t *Timer) Reset(d time.Duration, tags ...string) bool { t.mock.matchCallLocked(c) defer close(c.complete) result := !t.stopped - t.mock.removeTimerLocked(t) - t.stopped = false - t.nxt = t.mock.cur.Add(d) select { case <-t.c: default: } + if d == 0 { + // zero duration timer means we should immediately re-fire it, rather + // than remove and re-add it. 
+ t.stopped = false + go t.fire(t.mock.cur) + return result + } + t.mock.removeTimerLocked(t) + t.stopped = false + t.nxt = t.mock.cur.Add(d) t.mock.addTimerLocked(t) return result } diff --git a/coderd/database/pubsub/watchdog_test.go b/coderd/database/pubsub/watchdog_test.go index 62d51c8ecaaee..942f9eeb849c4 100644 --- a/coderd/database/pubsub/watchdog_test.go +++ b/coderd/database/pubsub/watchdog_test.go @@ -16,7 +16,7 @@ import ( func TestWatchdog_NoTimeout(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) - mClock := clock.NewMock() + mClock := clock.NewMock(t) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) fPS := newFakePubsub() @@ -42,11 +42,13 @@ func TestWatchdog_NoTimeout(t *testing.T) { // 5 min / 15 sec = 20, so do 21 ticks for i := 0; i < 21; i++ { - mClock.Advance(15*time.Second).MustWait(ctx, t) + d, w := mClock.AdvanceNext() + w.MustWait(ctx) + require.LessOrEqual(t, d, 15*time.Second) p := testutil.RequireRecvCtx(ctx, t, fPS.pubs) require.Equal(t, pubsub.EventPubsubWatchdog, p) - mClock.Advance(30*time.Millisecond). // reasonable round-trip - MustWait(ctx, t) + mClock.Advance(30 * time.Millisecond). // reasonable round-trip + MustWait(ctx) // forward the beat sub.listener(ctx, []byte{}) // we shouldn't time out @@ -72,11 +74,11 @@ func TestWatchdog_NoTimeout(t *testing.T) { func TestWatchdog_Timeout(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) - mClock := clock.NewMock() + mClock := clock.NewMock(t) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) fPS := newFakePubsub() - // trap the ticker and timer calls + // trap the ticker calls pubTrap := mClock.Trap().TickerFunc("publish") defer pubTrap.Close() @@ -96,11 +98,13 @@ func TestWatchdog_Timeout(t *testing.T) { // 5 min / 15 sec = 20, so do 19 ticks without timing out for i := 0; i < 19; i++ { - mClock.Advance(15*time.Second).MustWait(ctx, t) + d, w := mClock.AdvanceNext() + w.MustWait(ctx) + require.LessOrEqual(t, d, 15*time.Second) p := testutil.RequireRecvCtx(ctx, t, fPS.pubs) require.Equal(t, pubsub.EventPubsubWatchdog, p) - mClock.Advance(30*time.Millisecond). // reasonable round-trip - MustWait(ctx, t) + mClock.Advance(30 * time.Millisecond). // reasonable round-trip + MustWait(ctx) // we DO NOT forward the heartbeat // we shouldn't time out select { @@ -110,7 +114,9 @@ func TestWatchdog_Timeout(t *testing.T) { // OK! } } - mClock.Advance(15*time.Second).MustWait(ctx, t) + d, w := mClock.AdvanceNext() + w.MustWait(ctx) + require.LessOrEqual(t, d, 15*time.Second) p := testutil.RequireRecvCtx(ctx, t, fPS.pubs) require.Equal(t, pubsub.EventPubsubWatchdog, p) testutil.RequireRecvCtx(ctx, t, uut.Timeout()) diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 857cdafe94e79..104a649d87839 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -15,6 +15,7 @@ import ( gProto "google.golang.org/protobuf/proto" "cdr.dev/slog" + "github.com/coder/coder/v2/clock" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/pubsub" @@ -115,11 +116,16 @@ var pgCoordSubject = rbac.Subject{ // NewPGCoord creates a high-availability coordinator that stores state in the PostgreSQL database and // receives notifications of updates via the pubsub. 
func NewPGCoord(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, store database.Store) (agpl.Coordinator, error) { - return newPGCoordInternal(ctx, logger, ps, store) + return newPGCoordInternal(ctx, logger, ps, store, clock.NewReal()) +} + +// NewTestPGCoord is only used in testing to pass a clock.Clock in. +func NewTestPGCoord(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, store database.Store, clk clock.Clock) (agpl.Coordinator, error) { + return newPGCoordInternal(ctx, logger, ps, store, clk) } func newPGCoordInternal( - ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, store database.Store, + ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, store database.Store, clk clock.Clock, ) ( *pgCoord, error, ) { @@ -157,7 +163,7 @@ func newPGCoordInternal( handshaker: newHandshaker(ctx, logger, id, ps, rfhCh, fHB), handshakerCh: rfhCh, id: id, - querier: newQuerier(querierCtx, logger, id, ps, store, id, cCh, ccCh, numQuerierWorkers, fHB), + querier: newQuerier(querierCtx, logger, id, ps, store, id, cCh, ccCh, numQuerierWorkers, fHB, clk), closed: make(chan struct{}), } go func() { @@ -817,6 +823,7 @@ func newQuerier(ctx context.Context, closeConnections chan *connIO, numWorkers int, firstHeartbeat chan struct{}, + clk clock.Clock, ) *querier { updates := make(chan hbUpdate) q := &querier{ @@ -828,7 +835,7 @@ func newQuerier(ctx context.Context, newConnections: newConnections, closeConnections: closeConnections, workQ: newWorkQ[querierWorkKey](ctx), - heartbeats: newHeartbeats(ctx, logger, ps, store, self, updates, firstHeartbeat), + heartbeats: newHeartbeats(ctx, logger, ps, store, self, updates, firstHeartbeat, clk), mappers: make(map[mKey]*mapper), updates: updates, healthy: true, // assume we start healthy @@ -1462,12 +1469,12 @@ type heartbeats struct { lock sync.RWMutex coordinators map[uuid.UUID]time.Time - timer *time.Timer + timer *clock.Timer wg sync.WaitGroup - // overwritten in tests, but otherwise constant - cleanupPeriod time.Duration + // for testing + clock clock.Clock } func newHeartbeats( @@ -1475,6 +1482,7 @@ func newHeartbeats( ps pubsub.Pubsub, store database.Store, self uuid.UUID, update chan<- hbUpdate, firstHeartbeat chan<- struct{}, + clk clock.Clock, ) *heartbeats { h := &heartbeats{ ctx: ctx, @@ -1485,7 +1493,7 @@ func newHeartbeats( update: update, firstHeartbeat: firstHeartbeat, coordinators: make(map[uuid.UUID]time.Time), - cleanupPeriod: cleanupPeriod, + clock: clk, } h.wg.Add(3) go h.subscribe() @@ -1576,11 +1584,11 @@ func (h *heartbeats) recvBeat(id uuid.UUID) { _ = agpl.SendCtx(h.ctx, h.update, hbUpdate{filter: filterUpdateUpdated}) }() } - h.coordinators[id] = time.Now() + h.coordinators[id] = h.clock.Now("heartbeats", "recvBeat") if h.timer == nil { // this can only happen for the very first beat - h.timer = time.AfterFunc(MissedHeartbeats*HeartbeatPeriod, h.checkExpiry) + h.timer = h.clock.AfterFunc(MissedHeartbeats*HeartbeatPeriod, h.checkExpiry, "heartbeats", "recvBeat") h.logger.Debug(h.ctx, "set initial heartbeat timeout") return } @@ -1594,24 +1602,30 @@ func (h *heartbeats) resetExpiryTimerWithLock() { oldestTime = t } } - d := time.Until(oldestTime.Add(MissedHeartbeats * HeartbeatPeriod)) + d := h.clock.Until( + oldestTime.Add(MissedHeartbeats*HeartbeatPeriod), + "heartbeats", "resetExpiryTimerWithLock", + ) + if len(h.coordinators) == 0 { + return + } h.logger.Debug(h.ctx, "computed oldest heartbeat", slog.F("oldest", oldestTime), slog.F("time_to_expiry", d)) - // only reschedule if it's in the future. 
- if d > 0 { - h.timer.Reset(d) + if d < 0 { + d = 0 } + h.timer.Reset(d) } func (h *heartbeats) checkExpiry() { h.logger.Debug(h.ctx, "checking heartbeat expiry") h.lock.Lock() defer h.lock.Unlock() - now := time.Now() + now := h.clock.Now() expired := false for id, t := range h.coordinators { lastHB := now.Sub(t) h.logger.Debug(h.ctx, "last heartbeat from coordinator", slog.F("other_coordinator_id", id), slog.F("last_heartbeat", lastHB)) - if lastHB > MissedHeartbeats*HeartbeatPeriod { + if lastHB >= MissedHeartbeats*HeartbeatPeriod { expired = true delete(h.coordinators, id) h.logger.Info(h.ctx, "coordinator failed heartbeat check", slog.F("other_coordinator_id", id), slog.F("last_heartbeat", lastHB)) @@ -1633,17 +1647,12 @@ func (h *heartbeats) sendBeats() { h.sendBeat() close(h.firstHeartbeat) // signal binder it can start writing defer h.sendDelete() - tkr := time.NewTicker(HeartbeatPeriod) - defer tkr.Stop() - for { - select { - case <-h.ctx.Done(): - h.logger.Debug(h.ctx, "ending heartbeats", slog.Error(h.ctx.Err())) - return - case <-tkr.C: - h.sendBeat() - } - } + tkr := h.clock.TickerFunc(h.ctx, HeartbeatPeriod, func() error { + h.sendBeat() + return nil + }, "heartbeats", "sendBeats") + err := tkr.Wait() + h.logger.Debug(h.ctx, "ending heartbeats", slog.Error(err)) } func (h *heartbeats) sendBeat() { @@ -1682,17 +1691,12 @@ func (h *heartbeats) sendDelete() { func (h *heartbeats) cleanupLoop() { defer h.wg.Done() h.cleanup() - tkr := time.NewTicker(h.cleanupPeriod) - defer tkr.Stop() - for { - select { - case <-h.ctx.Done(): - h.logger.Debug(h.ctx, "ending cleanupLoop", slog.Error(h.ctx.Err())) - return - case <-tkr.C: - h.cleanup() - } - } + tkr := h.clock.TickerFunc(h.ctx, cleanupPeriod, func() error { + h.cleanup() + return nil + }, "heartbeats", "cleanupLoop") + err := tkr.Wait() + h.logger.Debug(h.ctx, "ending cleanupLoop", slog.Error(err)) } // cleanup issues a DB command to clean out any old expired coordinators or lost peer state. The diff --git a/enterprise/tailnet/pgcoord_internal_test.go b/enterprise/tailnet/pgcoord_internal_test.go index 4607e6fb2ab2f..5117131c05956 100644 --- a/enterprise/tailnet/pgcoord_internal_test.go +++ b/enterprise/tailnet/pgcoord_internal_test.go @@ -10,6 +10,8 @@ import ( "testing" "time" + "github.com/coder/coder/v2/clock" + "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -34,8 +36,7 @@ import ( // make update-golden-files var UpdateGoldenFiles = flag.Bool("update", false, "update .golden files") -// TestHeartbeats_Cleanup is internal so that we can overwrite the cleanup period and not wait an hour for the timed -// cleanup. 
+// TestHeartbeats_Cleanup tests the cleanup loop func TestHeartbeats_Cleanup(t *testing.T) { t.Parallel() @@ -46,38 +47,82 @@ func TestHeartbeats_Cleanup(t *testing.T) { defer cancel() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - waitForCleanup := make(chan struct{}) - mStore.EXPECT().CleanTailnetCoordinators(gomock.Any()).MinTimes(2).DoAndReturn(func(_ context.Context) error { - <-waitForCleanup - return nil - }) - mStore.EXPECT().CleanTailnetLostPeers(gomock.Any()).MinTimes(2).DoAndReturn(func(_ context.Context) error { - <-waitForCleanup - return nil - }) - mStore.EXPECT().CleanTailnetTunnels(gomock.Any()).MinTimes(2).DoAndReturn(func(_ context.Context) error { - <-waitForCleanup - return nil - }) + mStore.EXPECT().CleanTailnetCoordinators(gomock.Any()).Times(2).Return(nil) + mStore.EXPECT().CleanTailnetLostPeers(gomock.Any()).Times(2).Return(nil) + mStore.EXPECT().CleanTailnetTunnels(gomock.Any()).Times(2).Return(nil) + + mClock := clock.NewMock(t) + trap := mClock.Trap().TickerFunc("heartbeats", "cleanupLoop") + defer trap.Close() uut := &heartbeats{ - ctx: ctx, - logger: logger, - store: mStore, - cleanupPeriod: time.Millisecond, + ctx: ctx, + logger: logger, + store: mStore, + clock: mClock, } uut.wg.Add(1) go uut.cleanupLoop() - for i := 0; i < 6; i++ { - select { - case <-ctx.Done(): - t.Fatal("timeout") - case waitForCleanup <- struct{}{}: - // ok - } + call, err := trap.Wait(ctx) + require.NoError(t, err) + call.Release() + require.Equal(t, cleanupPeriod, call.Duration) + mClock.Advance(cleanupPeriod).MustWait(ctx) +} + +// TestHeartbeats_recvBeat_resetSkew is a regression test for a bug where heartbeats from two +// coordinators slightly skewed from one another could result in one coordinator failing to get +// expired +func TestHeartbeats_recvBeat_resetSkew(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + mClock := clock.NewMock(t) + trap := mClock.Trap().Until("heartbeats", "resetExpiryTimerWithLock") + defer trap.Close() + + uut := heartbeats{ + ctx: ctx, + logger: logger, + clock: mClock, + self: uuid.UUID{1}, + update: make(chan hbUpdate, 4), + coordinators: make(map[uuid.UUID]time.Time), } - close(waitForCleanup) + + coord2 := uuid.UUID{2} + coord3 := uuid.UUID{3} + + uut.listen(ctx, []byte(coord2.String()), nil) + + // coord 3 heartbeat comes very soon after + mClock.Advance(time.Millisecond).MustWait(ctx) + go uut.listen(ctx, []byte(coord3.String()), nil) + trap.MustWait(ctx).Release() + + // both coordinators are present + uut.lock.RLock() + require.Contains(t, uut.coordinators, coord2) + require.Contains(t, uut.coordinators, coord3) + uut.lock.RUnlock() + + // no more heartbeats arrive, and coord2 expires + w := mClock.Advance(MissedHeartbeats*HeartbeatPeriod - time.Millisecond) + // however, several ms pass between expiring 2 and computing the time until 3 expires + c := trap.MustWait(ctx) + mClock.Advance(2 * time.Millisecond).MustWait(ctx) // 3 has now expired _in the past_ + c.Release() + w.MustWait(ctx) + + // expired in the past means we immediately reschedule checkExpiry, so we get another call + trap.MustWait(ctx).Release() + + uut.lock.RLock() + require.NotContains(t, uut.coordinators, coord2) + require.NotContains(t, uut.coordinators, coord3) + uut.lock.RUnlock() } func TestHeartbeats_LostCoordinator_MarkLost(t *testing.T) { @@ -85,25 +130,26 @@ func TestHeartbeats_LostCoordinator_MarkLost(t *testing.T) { ctrl := gomock.NewController(t) mStore := 
dbmock.NewMockStore(ctrl) + mClock := clock.NewMock(t) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) uut := &heartbeats{ - ctx: ctx, - logger: logger, - store: mStore, - cleanupPeriod: time.Millisecond, + ctx: ctx, + logger: logger, + store: mStore, coordinators: map[uuid.UUID]time.Time{ - uuid.New(): time.Now(), + uuid.New(): mClock.Now(), }, + clock: mClock, } mpngs := []mapping{{ peer: uuid.New(), coordinator: uuid.New(), - updatedAt: time.Now(), + updatedAt: mClock.Now(), node: &proto.Node{}, kind: proto.CoordinateResponse_PeerUpdate_NODE, }} @@ -342,11 +388,14 @@ func TestPGCoordinatorUnhealthy(t *testing.T) { ctrl := gomock.NewController(t) mStore := dbmock.NewMockStore(ctrl) ps := pubsub.NewInMemory() + mClock := clock.NewMock(t) + tfTrap := mClock.Trap().TickerFunc("heartbeats", "sendBeats") + defer tfTrap.Close() // after 3 failed heartbeats, the coordinator is unhealthy mStore.EXPECT(). UpsertTailnetCoordinator(gomock.Any(), gomock.Any()). - MinTimes(3). + Times(3). Return(database.TailnetCoordinator{}, xerrors.New("badness")) mStore.EXPECT(). DeleteCoordinator(gomock.Any(), gomock.Any()). @@ -360,9 +409,22 @@ func TestPGCoordinatorUnhealthy(t *testing.T) { mStore.EXPECT().CleanTailnetLostPeers(gomock.Any()).AnyTimes().Return(nil) mStore.EXPECT().CleanTailnetTunnels(gomock.Any()).AnyTimes().Return(nil) - coordinator, err := newPGCoordInternal(ctx, logger, ps, mStore) + coordinator, err := newPGCoordInternal(ctx, logger, ps, mStore, mClock) + require.NoError(t, err) + + expectedPeriod := HeartbeatPeriod + tfCall, err := tfTrap.Wait(ctx) require.NoError(t, err) + tfCall.Release() + require.Equal(t, expectedPeriod, tfCall.Duration) + + // Now that the ticker has started, we can advance 2 more beats to get to 3 + // failed heartbeats + mClock.Advance(HeartbeatPeriod).MustWait(ctx) + mClock.Advance(HeartbeatPeriod).MustWait(ctx) + // The querier is informed async about being unhealthy, so we need to wait + // until it is. 
require.Eventually(t, func() bool { return !coordinator.querier.isHealthy() }, testutil.WaitShort, testutil.IntervalFast) diff --git a/enterprise/tailnet/pgcoord_test.go b/enterprise/tailnet/pgcoord_test.go index 9c363ee700570..c02774adb7245 100644 --- a/enterprise/tailnet/pgcoord_test.go +++ b/enterprise/tailnet/pgcoord_test.go @@ -10,6 +10,8 @@ import ( "testing" "time" + "github.com/coder/coder/v2/clock" + "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -337,7 +339,13 @@ func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) defer cancel() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store) + mClock := clock.NewMock(t) + nowTrap := mClock.Trap().Now("heartbeats", "recvBeat") + defer nowTrap.Close() + afTrap := mClock.Trap().AfterFunc("heartbeats", "recvBeat") + defer afTrap.Close() + + coordinator, err := tailnet.NewTestPGCoord(ctx, logger, ps, store, mClock) require.NoError(t, err) defer coordinator.Close() @@ -360,21 +368,11 @@ func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) { store: store, id: uuid.New(), } - // heatbeat until canceled - ctx2, cancel2 := context.WithCancel(ctx) - go func() { - t := time.NewTicker(tailnet.HeartbeatPeriod) - defer t.Stop() - for { - select { - case <-ctx2.Done(): - return - case <-t.C: - fCoord2.heartbeat() - } - } - }() + fCoord2.heartbeat() + nowTrap.MustWait(ctx).Release() + afTrap.MustWait(ctx).Release() // heartbeat timeout started + fCoord2.agentNode(agent.id, &agpl.Node{PreferredDERP: 12}) assertEventuallyHasDERPs(ctx, t, client, 12) @@ -384,22 +382,31 @@ func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) { store: store, id: uuid.New(), } - start := time.Now() fCoord3.heartbeat() + nowTrap.MustWait(ctx).Release() fCoord3.agentNode(agent.id, &agpl.Node{PreferredDERP: 13}) assertEventuallyHasDERPs(ctx, t, client, 13) + // fCoord2 sends in a second heartbeat, one period later (on time) + fCoord2.heartbeat() + c := nowTrap.MustWait(ctx) + mClock.Advance(tailnet.HeartbeatPeriod).MustWait(ctx) + c.Release() + // when the fCoord3 misses enough heartbeats, the real coordinator should send an update with the // node from fCoord2 for the agent. + mClock.Advance(tailnet.HeartbeatPeriod).MustWait(ctx) + mClock.Advance(tailnet.HeartbeatPeriod).MustWait(ctx) assertEventuallyHasDERPs(ctx, t, client, 12) - assert.Greater(t, time.Since(start), tailnet.HeartbeatPeriod*tailnet.MissedHeartbeats) - // stop fCoord2 heartbeats, which should cause us to revert to the original agent mapping - cancel2() + // one more heartbeat period will result in fCoord2 being expired, which should cause us to + // revert to the original agent mapping + mClock.Advance(tailnet.HeartbeatPeriod).MustWait(ctx) assertEventuallyHasDERPs(ctx, t, client, 10) // send fCoord3 heartbeat, which should trigger us to consider that mapping valid again. fCoord3.heartbeat() + nowTrap.MustWait(ctx).Release() assertEventuallyHasDERPs(ctx, t, client, 13) err = agent.close() diff --git a/tailnet/configmaps_internal_test.go b/tailnet/configmaps_internal_test.go index c658e5fb2f44e..83b15387a9a43 100644 --- a/tailnet/configmaps_internal_test.go +++ b/tailnet/configmaps_internal_test.go @@ -195,7 +195,7 @@ func TestConfigMaps_updatePeers_new_waitForHandshake_neverConfigures(t *testing. 
discoKey := key.NewDisco() uut := newConfigMaps(logger, fEng, nodeID, nodePrivateKey, discoKey.Public()) defer uut.close() - mClock := clock.NewMock() + mClock := clock.NewMock(t) uut.clock = mClock p1ID := uuid.UUID{1} @@ -239,7 +239,7 @@ func TestConfigMaps_updatePeers_new_waitForHandshake_outOfOrder(t *testing.T) { discoKey := key.NewDisco() uut := newConfigMaps(logger, fEng, nodeID, nodePrivateKey, discoKey.Public()) defer uut.close() - mClock := clock.NewMock() + mClock := clock.NewMock(t) uut.clock = mClock p1ID := uuid.UUID{1} @@ -310,7 +310,7 @@ func TestConfigMaps_updatePeers_new_waitForHandshake(t *testing.T) { discoKey := key.NewDisco() uut := newConfigMaps(logger, fEng, nodeID, nodePrivateKey, discoKey.Public()) defer uut.close() - mClock := clock.NewMock() + mClock := clock.NewMock(t) uut.clock = mClock p1ID := uuid.UUID{1} @@ -381,7 +381,7 @@ func TestConfigMaps_updatePeers_new_waitForHandshake_timeout(t *testing.T) { discoKey := key.NewDisco() uut := newConfigMaps(logger, fEng, nodeID, nodePrivateKey, discoKey.Public()) defer uut.close() - mClock := clock.NewMock() + mClock := clock.NewMock(t) uut.clock = mClock p1ID := uuid.UUID{1} @@ -404,7 +404,7 @@ func TestConfigMaps_updatePeers_new_waitForHandshake_timeout(t *testing.T) { } uut.updatePeers(u1) - mClock.Advance(5*time.Second).MustWait(ctx, t) + mClock.Advance(5 * time.Second).MustWait(ctx) // it should now send the peer to the netmap @@ -566,7 +566,7 @@ func TestConfigMaps_updatePeers_lost(t *testing.T) { discoKey := key.NewDisco() uut := newConfigMaps(logger, fEng, nodeID, nodePrivateKey, discoKey.Public()) defer uut.close() - mClock := clock.NewMock() + mClock := clock.NewMock(t) start := mClock.Now() uut.clock = mClock @@ -591,7 +591,7 @@ func TestConfigMaps_updatePeers_lost(t *testing.T) { require.Len(t, r.wg.Peers, 1) _ = testutil.RequireRecvCtx(ctx, t, s1) - mClock.Advance(5*time.Second).MustWait(ctx, t) + mClock.Advance(5 * time.Second).MustWait(ctx) s2 := expectStatusWithHandshake(ctx, t, fEng, p1Node.Key, start) @@ -612,7 +612,8 @@ func TestConfigMaps_updatePeers_lost(t *testing.T) { // latest handshake has advanced by a minute, so we don't remove the peer. lh := start.Add(time.Minute) s3 := expectStatusWithHandshake(ctx, t, fEng, p1Node.Key, lh) - mClock.Advance(lostTimeout).MustWait(ctx, t) + // 5 seconds have already elapsed from above + mClock.Advance(lostTimeout - 5*time.Second).MustWait(ctx) _ = testutil.RequireRecvCtx(ctx, t, s3) select { case <-fEng.setNetworkMap: @@ -624,7 +625,7 @@ func TestConfigMaps_updatePeers_lost(t *testing.T) { // Advance the clock again by a minute, which should trigger the reprogrammed // timeout. 
s4 := expectStatusWithHandshake(ctx, t, fEng, p1Node.Key, lh) - mClock.Advance(time.Minute).MustWait(ctx, t) + mClock.Advance(time.Minute).MustWait(ctx) nm = testutil.RequireRecvCtx(ctx, t, fEng.setNetworkMap) r = testutil.RequireRecvCtx(ctx, t, fEng.reconfig) @@ -650,7 +651,7 @@ func TestConfigMaps_updatePeers_lost_and_found(t *testing.T) { discoKey := key.NewDisco() uut := newConfigMaps(logger, fEng, nodeID, nodePrivateKey, discoKey.Public()) defer uut.close() - mClock := clock.NewMock() + mClock := clock.NewMock(t) start := mClock.Now() uut.clock = mClock @@ -675,7 +676,7 @@ func TestConfigMaps_updatePeers_lost_and_found(t *testing.T) { require.Len(t, r.wg.Peers, 1) _ = testutil.RequireRecvCtx(ctx, t, s1) - mClock.Advance(5*time.Second).MustWait(ctx, t) + mClock.Advance(5 * time.Second).MustWait(ctx) s2 := expectStatusWithHandshake(ctx, t, fEng, p1Node.Key, start) @@ -692,7 +693,7 @@ func TestConfigMaps_updatePeers_lost_and_found(t *testing.T) { // OK! } - mClock.Advance(5*time.Second).MustWait(ctx, t) + mClock.Advance(5 * time.Second).MustWait(ctx) s3 := expectStatusWithHandshake(ctx, t, fEng, p1Node.Key, start) updates[0].Kind = proto.CoordinateResponse_PeerUpdate_NODE @@ -709,7 +710,7 @@ func TestConfigMaps_updatePeers_lost_and_found(t *testing.T) { // When we advance the clock, nothing happens because the timeout was // canceled - mClock.Advance(lostTimeout).MustWait(ctx, t) + mClock.Advance(lostTimeout).MustWait(ctx) select { case <-fEng.setNetworkMap: t.Fatal("should not reprogram") @@ -735,7 +736,7 @@ func TestConfigMaps_setAllPeersLost(t *testing.T) { discoKey := key.NewDisco() uut := newConfigMaps(logger, fEng, nodeID, nodePrivateKey, discoKey.Public()) defer uut.close() - mClock := clock.NewMock() + mClock := clock.NewMock(t) start := mClock.Now() uut.clock = mClock @@ -769,7 +770,7 @@ func TestConfigMaps_setAllPeersLost(t *testing.T) { require.Len(t, r.wg.Peers, 2) _ = testutil.RequireRecvCtx(ctx, t, s1) - mClock.Advance(5*time.Second).MustWait(ctx, t) + mClock.Advance(5 * time.Second).MustWait(ctx) uut.setAllPeersLost() // No reprogramming yet, since we keep the peer around. @@ -780,10 +781,12 @@ func TestConfigMaps_setAllPeersLost(t *testing.T) { // OK! } - // When we advance the clock, even by a few ms, the timeout for peer 2 pops - // because our status only includes a handshake for peer 1 + // When we advance the clock, even by a millisecond, the timeout for peer 2 + // pops because our status only includes a handshake for peer 1 s2 := expectStatusWithHandshake(ctx, t, fEng, p1Node.Key, start) - mClock.Advance(time.Millisecond*10).MustWait(ctx, t) + d, w := mClock.AdvanceNext() + w.MustWait(ctx) + require.LessOrEqual(t, d, time.Millisecond) _ = testutil.RequireRecvCtx(ctx, t, s2) nm = testutil.RequireRecvCtx(ctx, t, fEng.setNetworkMap) @@ -793,7 +796,7 @@ func TestConfigMaps_setAllPeersLost(t *testing.T) { // Finally, advance the clock until after the timeout s3 := expectStatusWithHandshake(ctx, t, fEng, p1Node.Key, start) - mClock.Advance(lostTimeout).MustWait(ctx, t) + mClock.Advance(lostTimeout - d - 5*time.Second).MustWait(ctx) _ = testutil.RequireRecvCtx(ctx, t, s3) nm = testutil.RequireRecvCtx(ctx, t, fEng.setNetworkMap)
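
The pattern this patch converges on (inject a clock.Clock into the component, then drive time explicitly from the test) can be summarized with a short sketch. The examplePinger names below are illustrative only and not part of the change; the clock calls mirror the API shown above: NewMock(t), Trap().TickerFunc(), AdvanceNext(), and MustWait(ctx).

package example_test

import (
	"context"
	"sync/atomic"
	"testing"
	"time"

	"github.com/coder/coder/v2/clock"
)

// examplePinger is a hypothetical component that takes a clock.Clock so tests
// can substitute the mock, mirroring how heartbeats gained a clock field.
type examplePinger struct {
	count atomic.Int64
}

func (p *examplePinger) run(ctx context.Context, clk clock.Clock, period time.Duration) {
	tkr := clk.TickerFunc(ctx, period, func() error {
		p.count.Add(1) // stand-in for sendBeat() / cleanup()
		return nil
	}, "examplePinger")
	_ = tkr.Wait()
}

func TestExamplePinger(t *testing.T) {
	t.Parallel()
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	mClock := clock.NewMock(t)
	trap := mClock.Trap().TickerFunc("examplePinger")
	defer trap.Close()

	p := &examplePinger{}
	go p.run(ctx, mClock, time.Second)

	// Wait for the ticker to be created before advancing, so AdvanceNext has
	// an event to find.
	trap.MustWait(ctx).Release()

	// AdvanceNext moves exactly to the next tick, and the returned waiter's
	// MustWait no longer needs the testing.TB because the mock carries it.
	d, w := mClock.AdvanceNext()
	w.MustWait(ctx)
	if d != time.Second {
		t.Fatalf("expected 1s got %s", d)
	}
	if c := p.count.Load(); c != 1 {
		t.Fatalf("expected 1 tick got %d", c)
	}
}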