@@ -229,13 +229,21 @@ type agent struct {
229
229
// We track two contexts and their associated cancel functions: "graceful", which is Done when it is time
230
230
// to start shutting down gracefully, and "hard", which is Done when it is time to shut
231
231
// everything down (regardless of whether graceful shutdown completed).
232
- gracefulCtx context.Context
233
- gracefulCancel context.CancelFunc
234
- hardCtx context.Context
235
- hardCancel context.CancelFunc
236
- closeWaitGroup sync.WaitGroup
232
+ gracefulCtx context.Context
233
+ gracefulCancel context.CancelFunc
234
+ hardCtx context.Context
235
+ hardCancel context.CancelFunc
236
+
237
+ // closeMutex protects the following:
237
238
closeMutex sync.Mutex
239
+ closeWaitGroup sync.WaitGroup
238
240
coordDisconnected chan struct {}
241
+ closing bool
242
+ // Note that once network is set to non-nil it is never modified, and the same holds for statsReporter. Routines
243
+ // that run after createOrUpdateNetwork and check the networkOK checkpoint therefore do not need to hold the lock to use them.
244
+ network * tailnet.Conn
245
+ statsReporter * statsReporter
246
+ // end fields protected by closeMutex
239
247
240
248
environmentVariables map [string ]string
241
249
@@ -259,9 +267,7 @@ type agent struct {
259
267
reportConnectionsMu sync.Mutex
260
268
reportConnections []* proto.ReportConnectionRequest
261
269
262
- network * tailnet.Conn
263
- statsReporter * statsReporter
264
- logSender * agentsdk.LogSender
270
+ logSender * agentsdk.LogSender
265
271
266
272
prometheusRegistry * prometheus.Registry
267
273
// metrics are prometheus registered metrics that will be collected and
@@ -274,6 +280,8 @@ type agent struct {
274
280
}
275
281
276
282
// TailnetConn returns the agent's current network field. The read is done
// while holding closeMutex, the lock under which the field is assigned.
// The result is nil if network has not been assigned.
func (a * agent ) TailnetConn () * tailnet.Conn {
283
+ a .closeMutex .Lock ()
284
+ defer a .closeMutex .Unlock ()
277
285
return a .network
278
286
}
279
287
@@ -1205,15 +1213,15 @@ func (a *agent) createOrUpdateNetwork(manifestOK, networkOK *checkpoint) func(co
1205
1213
}
1206
1214
a .closeMutex .Lock ()
1207
1215
// Re-check if agent was closed while initializing the network.
1208
- closed := a .isClosed ()
1209
- if ! closed {
1216
+ closing := a .closing
1217
+ if ! closing {
1210
1218
a .network = network
1211
1219
a .statsReporter = newStatsReporter (a .logger , network , a )
1212
1220
}
1213
1221
a .closeMutex .Unlock ()
1214
- if closed {
1222
+ if closing {
1215
1223
_ = network .Close ()
1216
- return xerrors .New ("agent is closed " )
1224
+ return xerrors .New ("agent is closing " )
1217
1225
}
1218
1226
} else {
1219
1227
// Update the wireguard IPs if the agent ID changed.
@@ -1328,8 +1336,8 @@ func (*agent) wireguardAddresses(agentID uuid.UUID) []netip.Prefix {
1328
1336
func (a * agent ) trackGoroutine (fn func ()) error {
1329
1337
a .closeMutex .Lock ()
1330
1338
defer a .closeMutex .Unlock ()
1331
- if a .isClosed () {
1332
- return xerrors .New ("track conn goroutine: agent is closed " )
1339
+ if a .closing {
1340
+ return xerrors .New ("track conn goroutine: agent is closing " )
1333
1341
}
1334
1342
a .closeWaitGroup .Add (1 )
1335
1343
go func () {
@@ -1547,7 +1555,7 @@ func (a *agent) runCoordinator(ctx context.Context, tClient tailnetproto.DRPCTai
1547
1555
func (a * agent ) setCoordDisconnected () chan struct {} {
1548
1556
a .closeMutex .Lock ()
1549
1557
defer a .closeMutex .Unlock ()
1550
- if a .isClosed () {
1558
+ if a .closing {
1551
1559
return nil
1552
1560
}
1553
1561
disconnected := make (chan struct {})
@@ -1772,7 +1780,10 @@ func (a *agent) HTTPDebug() http.Handler {
1772
1780
1773
1781
func (a * agent ) Close () error {
1774
1782
a .closeMutex .Lock ()
1775
- defer a .closeMutex .Unlock ()
1783
+ network := a .network
1784
+ coordDisconnected := a .coordDisconnected
1785
+ a .closing = true
1786
+ a .closeMutex .Unlock ()
1776
1787
if a .isClosed () {
1777
1788
return nil
1778
1789
}
@@ -1849,7 +1860,7 @@ lifecycleWaitLoop:
1849
1860
select {
1850
1861
case <- a .hardCtx .Done ():
1851
1862
a .logger .Warn (context .Background (), "timed out waiting for Coordinator RPC disconnect" )
1852
- case <- a . coordDisconnected :
1863
+ case <- coordDisconnected :
1853
1864
a .logger .Debug (context .Background (), "coordinator RPC disconnected" )
1854
1865
}
1855
1866
@@ -1860,8 +1871,8 @@ lifecycleWaitLoop:
1860
1871
}
1861
1872
1862
1873
a .hardCancel ()
1863
- if a . network != nil {
1864
- _ = a . network .Close ()
1874
+ if network != nil {
1875
+ _ = network .Close ()
1865
1876
}
1866
1877
a .closeWaitGroup .Wait ()
1867
1878
0 commit comments