Skip to content

Commit fa64155

Browse files
authored
fix: Improve agent connection tracking when agent is closed (coder#5253)
1 parent 81c3948 commit fa64155

File tree

1 file changed

+77
-33
lines changed

1 file changed

+77
-33
lines changed

agent/agent.go

Lines changed: 77 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -231,13 +231,27 @@ func (a *agent) run(ctx context.Context) error {
231231
return nil
232232
}
233233

234-
func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*tailnet.Conn, error) {
234+
// trackConnGoroutine runs fn in a new goroutine that is counted on
// a.connCloseWait: the counter is incremented before the goroutine starts
// and decremented when fn returns.
//
// The closed check and the increment are both performed while holding
// a.closeMutex, so no new goroutine is registered once isClosed reports true.
// NOTE(review): presumably Close marks the agent closed under the same mutex
// before waiting on connCloseWait; confirm against Close.
//
// It returns an error, without ever calling fn, if the agent is already
// closed. In that case the caller remains responsible for releasing any
// resource fn would have cleaned up.
func (a *agent) trackConnGoroutine(fn func()) error {
235+
a.closeMutex.Lock()
236+
defer a.closeMutex.Unlock()
237+
if a.isClosed() {
238+
return xerrors.New("track conn goroutine: agent is closed")
239+
}
240+
// Register before starting the goroutine, while the mutex is still held.
a.connCloseWait.Add(1)
241+
go func() {
242+
// Deferred so the counter is released even if fn panics.
defer a.connCloseWait.Done()
243+
fn()
244+
}()
245+
return nil
246+
}
247+
248+
func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (network *tailnet.Conn, err error) {
235249
a.closeMutex.Lock()
236250
if a.isClosed() {
237251
a.closeMutex.Unlock()
238252
return nil, xerrors.New("closed")
239253
}
240-
network, err := tailnet.NewConn(&tailnet.Options{
254+
network, err = tailnet.NewConn(&tailnet.Options{
241255
Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.TailnetIP, 128)},
242256
DERPMap: derpMap,
243257
Logger: a.logger.Named("tailnet"),
@@ -247,31 +261,45 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t
247261
a.closeMutex.Unlock()
248262
return nil, xerrors.Errorf("create tailnet: %w", err)
249263
}
264+
defer func() {
265+
if err != nil {
266+
network.Close()
267+
}
268+
}()
250269
a.network = network
251-
a.connCloseWait.Add(4)
252270
a.closeMutex.Unlock()
253271

254272
sshListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetSSHPort))
255273
if err != nil {
256274
return nil, xerrors.Errorf("listen on the ssh port: %w", err)
257275
}
258-
go func() {
259-
defer a.connCloseWait.Done()
276+
defer func() {
277+
if err != nil {
278+
_ = sshListener.Close()
279+
}
280+
}()
281+
if err = a.trackConnGoroutine(func() {
260282
for {
261283
conn, err := sshListener.Accept()
262284
if err != nil {
263285
return
264286
}
265287
go a.sshServer.HandleConn(conn)
266288
}
267-
}()
289+
}); err != nil {
290+
return nil, err
291+
}
268292

269293
reconnectingPTYListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetReconnectingPTYPort))
270294
if err != nil {
271295
return nil, xerrors.Errorf("listen for reconnecting pty: %w", err)
272296
}
273-
go func() {
274-
defer a.connCloseWait.Done()
297+
defer func() {
298+
if err != nil {
299+
_ = reconnectingPTYListener.Close()
300+
}
301+
}()
302+
if err = a.trackConnGoroutine(func() {
275303
for {
276304
conn, err := reconnectingPTYListener.Accept()
277305
if err != nil {
@@ -298,36 +326,48 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t
298326
}
299327
go a.handleReconnectingPTY(ctx, msg, conn)
300328
}
301-
}()
329+
}); err != nil {
330+
return nil, err
331+
}
302332

303333
speedtestListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetSpeedtestPort))
304334
if err != nil {
305335
return nil, xerrors.Errorf("listen for speedtest: %w", err)
306336
}
307-
go func() {
308-
defer a.connCloseWait.Done()
337+
defer func() {
338+
if err != nil {
339+
_ = speedtestListener.Close()
340+
}
341+
}()
342+
if err = a.trackConnGoroutine(func() {
309343
for {
310344
conn, err := speedtestListener.Accept()
311345
if err != nil {
312346
a.logger.Debug(ctx, "speedtest listener failed", slog.Error(err))
313347
return
314348
}
315-
a.closeMutex.Lock()
316-
a.connCloseWait.Add(1)
317-
a.closeMutex.Unlock()
318-
go func() {
319-
defer a.connCloseWait.Done()
349+
if err = a.trackConnGoroutine(func() {
320350
_ = speedtest.ServeConn(conn)
321-
}()
351+
}); err != nil {
352+
a.logger.Debug(ctx, "speedtest listener failed", slog.Error(err))
353+
_ = conn.Close()
354+
return
355+
}
322356
}
323-
}()
357+
}); err != nil {
358+
return nil, err
359+
}
324360

325361
statisticsListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetStatisticsPort))
326362
if err != nil {
327363
return nil, xerrors.Errorf("listen for statistics: %w", err)
328364
}
329-
go func() {
330-
defer a.connCloseWait.Done()
365+
defer func() {
366+
if err != nil {
367+
_ = statisticsListener.Close()
368+
}
369+
}()
370+
if err = a.trackConnGoroutine(func() {
331371
defer statisticsListener.Close()
332372
server := &http.Server{
333373
Handler: a.statisticsHandler(),
@@ -341,11 +381,13 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t
341381
_ = server.Close()
342382
}()
343383

344-
err = server.Serve(statisticsListener)
384+
err := server.Serve(statisticsListener)
345385
if err != nil && !xerrors.Is(err, http.ErrServerClosed) && !strings.Contains(err.Error(), "use of closed network connection") {
346386
a.logger.Critical(ctx, "serve statistics HTTP server", slog.Error(err))
347387
}
348-
}()
388+
}); err != nil {
389+
return nil, err
390+
}
349391

350392
return network, nil
351393
}
@@ -527,12 +569,15 @@ func (a *agent) init(ctx context.Context) {
527569
a.logger.Error(ctx, "report stats", slog.Error(err))
528570
return
529571
}
530-
a.connCloseWait.Add(1)
531-
go func() {
532-
defer a.connCloseWait.Done()
572+
573+
if err = a.trackConnGoroutine(func() {
533574
<-a.closed
534-
cl.Close()
535-
}()
575+
_ = cl.Close()
576+
}); err != nil {
577+
a.logger.Error(ctx, "report stats goroutine", slog.Error(err))
578+
_ = cl.Close()
579+
return
580+
}
536581
}
537582

538583
func convertAgentStats(counts map[netlogtype.Connection]netlogtype.Counts) *codersdk.AgentStats {
@@ -787,9 +832,6 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec
787832
return
788833
}
789834

790-
a.closeMutex.Lock()
791-
a.connCloseWait.Add(1)
792-
a.closeMutex.Unlock()
793835
ctx, cancelFunc := context.WithCancel(ctx)
794836
rpty = &reconnectingPTY{
795837
activeConns: map[string]net.Conn{
@@ -818,7 +860,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec
818860
_ = process.Wait()
819861
rpty.Close()
820862
}()
821-
go func() {
863+
if err = a.trackConnGoroutine(func() {
822864
buffer := make([]byte, 1024)
823865
for {
824866
read, err := rpty.ptty.Output().Read(buffer)
@@ -846,8 +888,10 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec
846888
_ = process.Kill()
847889
rpty.Close()
848890
a.reconnectingPTYs.Delete(msg.ID)
849-
a.connCloseWait.Done()
850-
}()
891+
}); err != nil {
892+
a.logger.Error(ctx, "start reconnecting pty routine", slog.F("id", msg.ID), slog.Error(err))
893+
return
894+
}
851895
}
852896
// Resize the PTY to initial height + width.
853897
err := rpty.ptty.Resize(msg.Height, msg.Width)

0 commit comments

Comments
 (0)