Skip to content

Commit d37a61c

Browse files
committed
fix(coderd/provisionerdserver): fix test flake in TestHeartbeat
1 parent 8bc91b4 commit d37a61c

File tree

2 files changed

+23
-14
lines changed

2 files changed

+23
-14
lines changed

coderd/provisionerdserver/provisionerdserver.go

+14-4
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ type Options struct {
7272
// The default function just calls UpdateProvisionerDaemonLastSeenAt.
7373
// This is mainly used for testing.
7474
HeartbeatFn func(context.Context) error
75+
76+
// HeartbeatDone is used for testing.
77+
HeartbeatDone chan struct{}
7578
}
7679

7780
type server struct {
@@ -183,6 +186,9 @@ func NewServer(
183186
if options.HeartbeatInterval == 0 {
184187
options.HeartbeatInterval = DefaultHeartbeatInterval
185188
}
189+
if options.HeartbeatDone == nil {
190+
options.HeartbeatDone = make(chan struct{})
191+
}
186192

187193
s := &server{
188194
lifecycleCtx: lifecycleCtx,
@@ -213,7 +219,7 @@ func NewServer(
213219
s.heartbeatFn = s.defaultHeartbeat
214220
}
215221

216-
go s.heartbeatLoop()
222+
go s.heartbeatLoop(options.HeartbeatDone)
217223
return s, nil
218224
}
219225

@@ -227,17 +233,21 @@ func (s *server) timeNow() time.Time {
227233
}
228234

229235
// heartbeatLoop runs heartbeatOnce at the interval specified by HeartbeatInterval
230-
// until the lifecycle context is canceled.
231-
func (s *server) heartbeatLoop() {
236+
// until the lifecycle context is canceled. Done is closed on exit.
237+
func (s *server) heartbeatLoop(hbDone chan<- struct{}) {
232238
tick := time.NewTicker(time.Nanosecond)
233-
defer tick.Stop()
239+
defer func() {
240+
close(hbDone)
241+
}()
234242
for {
235243
select {
236244
case <-s.lifecycleCtx.Done():
237245
s.Logger.Debug(s.lifecycleCtx, "heartbeat loop canceled")
246+
tick.Stop()
238247
return
239248
case <-tick.C:
240249
if s.lifecycleCtx.Err() != nil {
250+
tick.Stop()
241251
return
242252
}
243253
start := s.timeNow()

coderd/provisionerdserver/provisionerdserver_test.go

+9-10
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ func TestHeartbeat(t *testing.T) {
104104
ctx, cancel := context.WithCancel(context.Background())
105105
t.Cleanup(cancel)
106106
heartbeatChan := make(chan struct{})
107+
heartbeatDone := make(chan struct{})
107108
heartbeatFn := func(hbCtx context.Context) error {
108109
t.Logf("heartbeat")
109110
select {
@@ -117,6 +118,7 @@ func TestHeartbeat(t *testing.T) {
117118
//nolint:dogsled // 。:゚૮ ˶ˆ ﻌ ˆ˶ ა ゚:。
118119
_, _, _, _ = setup(t, false, &overrides{
119120
ctx: ctx,
121+
heartbeatDone: heartbeatDone,
120122
heartbeatFn: heartbeatFn,
121123
heartbeatInterval: testutil.IntervalFast,
122124
})
@@ -125,17 +127,9 @@ func TestHeartbeat(t *testing.T) {
125127
require.True(t, ok, "first heartbeat not received")
126128
_, ok = <-heartbeatChan
127129
require.True(t, ok, "second heartbeat not received")
130+
// Cancel the context. This should cause heartbeatDone to be closed.
128131
cancel()
129-
// Close the channel to ensure we don't receive any more heartbeats.
130-
// The test will fail if we do.
131-
defer func() {
132-
if r := recover(); r != nil {
133-
t.Fatalf("heartbeat received after cancel: %v", r)
134-
}
135-
}()
136-
137-
close(heartbeatChan)
138-
<-time.After(testutil.IntervalMedium)
132+
<-heartbeatDone
139133
}
140134

141135
func TestAcquireJob(t *testing.T) {
@@ -1727,6 +1721,7 @@ type overrides struct {
17271721
acquireJobLongPollDuration time.Duration
17281722
heartbeatFn func(ctx context.Context) error
17291723
heartbeatInterval time.Duration
1724+
heartbeatDone chan struct{}
17301725
}
17311726

17321727
func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisionerDaemonServer, database.Store, pubsub.Pubsub, database.ProvisionerDaemon) {
@@ -1751,6 +1746,9 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi
17511746
if ov.heartbeatInterval == 0 {
17521747
ov.heartbeatInterval = testutil.IntervalMedium
17531748
}
1749+
if ov.heartbeatDone == nil {
1750+
ov.heartbeatDone = make(chan struct{})
1751+
}
17541752
if ov.deploymentValues != nil {
17551753
deploymentValues = ov.deploymentValues
17561754
}
@@ -1815,6 +1813,7 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi
18151813
AcquireJobLongPollDur: pollDur,
18161814
HeartbeatInterval: ov.heartbeatInterval,
18171815
HeartbeatFn: ov.heartbeatFn,
1816+
HeartbeatDone: ov.heartbeatDone,
18181817
},
18191818
)
18201819
require.NoError(t, err)

0 commit comments

Comments
 (0)