@@ -104,6 +104,7 @@ func TestHeartbeat(t *testing.T) {
104
104
ctx , cancel := context .WithCancel (context .Background ())
105
105
t .Cleanup (cancel )
106
106
heartbeatChan := make (chan struct {})
107
+ heartbeatDone := make (chan struct {})
107
108
heartbeatFn := func (hbCtx context.Context ) error {
108
109
t .Logf ("heartbeat" )
109
110
select {
@@ -117,6 +118,7 @@ func TestHeartbeat(t *testing.T) {
117
118
//nolint:dogsled // 。:゚૮ ˶ˆ ﻌ ˆ˶ ა ゚:。
118
119
_ , _ , _ , _ = setup (t , false , & overrides {
119
120
ctx : ctx ,
121
+ heartbeatDone : heartbeatDone ,
120
122
heartbeatFn : heartbeatFn ,
121
123
heartbeatInterval : testutil .IntervalFast ,
122
124
})
@@ -125,17 +127,9 @@ func TestHeartbeat(t *testing.T) {
125
127
require .True (t , ok , "first heartbeat not received" )
126
128
_ , ok = <- heartbeatChan
127
129
require .True (t , ok , "second heartbeat not received" )
130
+ // Cancel the context. This should cause heartbeatDone to be closed.
128
131
cancel ()
129
- // Close the channel to ensure we don't receive any more heartbeats.
130
- // The test will fail if we do.
131
- defer func () {
132
- if r := recover (); r != nil {
133
- t .Fatalf ("heartbeat received after cancel: %v" , r )
134
- }
135
- }()
136
-
137
- close (heartbeatChan )
138
- <- time .After (testutil .IntervalMedium )
132
+ <- heartbeatDone
139
133
}
140
134
141
135
func TestAcquireJob (t * testing.T ) {
@@ -1727,6 +1721,7 @@ type overrides struct {
1727
1721
acquireJobLongPollDuration time.Duration
1728
1722
heartbeatFn func (ctx context.Context ) error
1729
1723
heartbeatInterval time.Duration
1724
+ heartbeatDone chan struct {}
1730
1725
}
1731
1726
1732
1727
func setup (t * testing.T , ignoreLogErrors bool , ov * overrides ) (proto.DRPCProvisionerDaemonServer , database.Store , pubsub.Pubsub , database.ProvisionerDaemon ) {
@@ -1751,6 +1746,9 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi
1751
1746
if ov .heartbeatInterval == 0 {
1752
1747
ov .heartbeatInterval = testutil .IntervalMedium
1753
1748
}
1749
+ if ov .heartbeatDone == nil {
1750
+ ov .heartbeatDone = make (chan struct {})
1751
+ }
1754
1752
if ov .deploymentValues != nil {
1755
1753
deploymentValues = ov .deploymentValues
1756
1754
}
@@ -1815,6 +1813,7 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi
1815
1813
AcquireJobLongPollDur : pollDur ,
1816
1814
HeartbeatInterval : ov .heartbeatInterval ,
1817
1815
HeartbeatFn : ov .heartbeatFn ,
1816
+ HeartbeatDone : ov .heartbeatDone ,
1818
1817
},
1819
1818
)
1820
1819
require .NoError (t , err )
0 commit comments