@@ -231,13 +231,27 @@ func (a *agent) run(ctx context.Context) error {
231
231
return nil
232
232
}
233
233
234
- func (a * agent ) createTailnet (ctx context.Context , derpMap * tailcfg.DERPMap ) (* tailnet.Conn , error ) {
234
+ func (a * agent ) trackConnGoroutine (fn func ()) error {
235
+ a .closeMutex .Lock ()
236
+ defer a .closeMutex .Unlock ()
237
+ if a .isClosed () {
238
+ return xerrors .New ("track conn goroutine: agent is closed" )
239
+ }
240
+ a .connCloseWait .Add (1 )
241
+ go func () {
242
+ defer a .connCloseWait .Done ()
243
+ fn ()
244
+ }()
245
+ return nil
246
+ }
247
+
248
+ func (a * agent ) createTailnet (ctx context.Context , derpMap * tailcfg.DERPMap ) (network * tailnet.Conn , err error ) {
235
249
a .closeMutex .Lock ()
236
250
if a .isClosed () {
237
251
a .closeMutex .Unlock ()
238
252
return nil , xerrors .New ("closed" )
239
253
}
240
- network , err : = tailnet .NewConn (& tailnet.Options {
254
+ network , err = tailnet .NewConn (& tailnet.Options {
241
255
Addresses : []netip.Prefix {netip .PrefixFrom (codersdk .TailnetIP , 128 )},
242
256
DERPMap : derpMap ,
243
257
Logger : a .logger .Named ("tailnet" ),
@@ -247,31 +261,45 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t
247
261
a .closeMutex .Unlock ()
248
262
return nil , xerrors .Errorf ("create tailnet: %w" , err )
249
263
}
264
+ defer func () {
265
+ if err != nil {
266
+ network .Close ()
267
+ }
268
+ }()
250
269
a .network = network
251
- a .connCloseWait .Add (4 )
252
270
a .closeMutex .Unlock ()
253
271
254
272
sshListener , err := network .Listen ("tcp" , ":" + strconv .Itoa (codersdk .TailnetSSHPort ))
255
273
if err != nil {
256
274
return nil , xerrors .Errorf ("listen on the ssh port: %w" , err )
257
275
}
258
- go func () {
259
- defer a .connCloseWait .Done ()
276
+ defer func () {
277
+ if err != nil {
278
+ _ = sshListener .Close ()
279
+ }
280
+ }()
281
+ if err = a .trackConnGoroutine (func () {
260
282
for {
261
283
conn , err := sshListener .Accept ()
262
284
if err != nil {
263
285
return
264
286
}
265
287
go a .sshServer .HandleConn (conn )
266
288
}
267
- }()
289
+ }); err != nil {
290
+ return nil , err
291
+ }
268
292
269
293
reconnectingPTYListener , err := network .Listen ("tcp" , ":" + strconv .Itoa (codersdk .TailnetReconnectingPTYPort ))
270
294
if err != nil {
271
295
return nil , xerrors .Errorf ("listen for reconnecting pty: %w" , err )
272
296
}
273
- go func () {
274
- defer a .connCloseWait .Done ()
297
+ defer func () {
298
+ if err != nil {
299
+ _ = reconnectingPTYListener .Close ()
300
+ }
301
+ }()
302
+ if err = a .trackConnGoroutine (func () {
275
303
for {
276
304
conn , err := reconnectingPTYListener .Accept ()
277
305
if err != nil {
@@ -298,36 +326,48 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t
298
326
}
299
327
go a .handleReconnectingPTY (ctx , msg , conn )
300
328
}
301
- }()
329
+ }); err != nil {
330
+ return nil , err
331
+ }
302
332
303
333
speedtestListener , err := network .Listen ("tcp" , ":" + strconv .Itoa (codersdk .TailnetSpeedtestPort ))
304
334
if err != nil {
305
335
return nil , xerrors .Errorf ("listen for speedtest: %w" , err )
306
336
}
307
- go func () {
308
- defer a .connCloseWait .Done ()
337
+ defer func () {
338
+ if err != nil {
339
+ _ = speedtestListener .Close ()
340
+ }
341
+ }()
342
+ if err = a .trackConnGoroutine (func () {
309
343
for {
310
344
conn , err := speedtestListener .Accept ()
311
345
if err != nil {
312
346
a .logger .Debug (ctx , "speedtest listener failed" , slog .Error (err ))
313
347
return
314
348
}
315
- a .closeMutex .Lock ()
316
- a .connCloseWait .Add (1 )
317
- a .closeMutex .Unlock ()
318
- go func () {
319
- defer a .connCloseWait .Done ()
349
+ if err = a .trackConnGoroutine (func () {
320
350
_ = speedtest .ServeConn (conn )
321
- }()
351
+ }); err != nil {
352
+ a .logger .Debug (ctx , "speedtest listener failed" , slog .Error (err ))
353
+ _ = conn .Close ()
354
+ return
355
+ }
322
356
}
323
- }()
357
+ }); err != nil {
358
+ return nil , err
359
+ }
324
360
325
361
statisticsListener , err := network .Listen ("tcp" , ":" + strconv .Itoa (codersdk .TailnetStatisticsPort ))
326
362
if err != nil {
327
363
return nil , xerrors .Errorf ("listen for statistics: %w" , err )
328
364
}
329
- go func () {
330
- defer a .connCloseWait .Done ()
365
+ defer func () {
366
+ if err != nil {
367
+ _ = statisticsListener .Close ()
368
+ }
369
+ }()
370
+ if err = a .trackConnGoroutine (func () {
331
371
defer statisticsListener .Close ()
332
372
server := & http.Server {
333
373
Handler : a .statisticsHandler (),
@@ -341,11 +381,13 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t
341
381
_ = server .Close ()
342
382
}()
343
383
344
- err = server .Serve (statisticsListener )
384
+ err : = server .Serve (statisticsListener )
345
385
if err != nil && ! xerrors .Is (err , http .ErrServerClosed ) && ! strings .Contains (err .Error (), "use of closed network connection" ) {
346
386
a .logger .Critical (ctx , "serve statistics HTTP server" , slog .Error (err ))
347
387
}
348
- }()
388
+ }); err != nil {
389
+ return nil , err
390
+ }
349
391
350
392
return network , nil
351
393
}
@@ -527,12 +569,15 @@ func (a *agent) init(ctx context.Context) {
527
569
a .logger .Error (ctx , "report stats" , slog .Error (err ))
528
570
return
529
571
}
530
- a .connCloseWait .Add (1 )
531
- go func () {
532
- defer a .connCloseWait .Done ()
572
+
573
+ if err = a .trackConnGoroutine (func () {
533
574
<- a .closed
534
- cl .Close ()
535
- }()
575
+ _ = cl .Close ()
576
+ }); err != nil {
577
+ a .logger .Error (ctx , "report stats goroutine" , slog .Error (err ))
578
+ _ = cl .Close ()
579
+ return
580
+ }
536
581
}
537
582
538
583
func convertAgentStats (counts map [netlogtype.Connection ]netlogtype.Counts ) * codersdk.AgentStats {
@@ -787,9 +832,6 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec
787
832
return
788
833
}
789
834
790
- a .closeMutex .Lock ()
791
- a .connCloseWait .Add (1 )
792
- a .closeMutex .Unlock ()
793
835
ctx , cancelFunc := context .WithCancel (ctx )
794
836
rpty = & reconnectingPTY {
795
837
activeConns : map [string ]net.Conn {
@@ -818,7 +860,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec
818
860
_ = process .Wait ()
819
861
rpty .Close ()
820
862
}()
821
- go func () {
863
+ if err = a . trackConnGoroutine ( func () {
822
864
buffer := make ([]byte , 1024 )
823
865
for {
824
866
read , err := rpty .ptty .Output ().Read (buffer )
@@ -846,8 +888,10 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec
846
888
_ = process .Kill ()
847
889
rpty .Close ()
848
890
a .reconnectingPTYs .Delete (msg .ID )
849
- a .connCloseWait .Done ()
850
- }()
891
+ }); err != nil {
892
+ a .logger .Error (ctx , "start reconnecting pty routine" , slog .F ("id" , msg .ID ), slog .Error (err ))
893
+ return
894
+ }
851
895
}
852
896
// Resize the PTY to initial height + width.
853
897
err := rpty .ptty .Resize (msg .Height , msg .Width )
0 commit comments