diff --git a/agent/agent.go b/agent/agent.go index bb40e27aa97e6..822297ad70c3d 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -231,13 +231,27 @@ func (a *agent) run(ctx context.Context) error { return nil } -func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*tailnet.Conn, error) { +func (a *agent) trackConnGoroutine(fn func()) error { + a.closeMutex.Lock() + defer a.closeMutex.Unlock() + if a.isClosed() { + return xerrors.New("track conn goroutine: agent is closed") + } + a.connCloseWait.Add(1) + go func() { + defer a.connCloseWait.Done() + fn() + }() + return nil +} + +func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (network *tailnet.Conn, err error) { a.closeMutex.Lock() if a.isClosed() { a.closeMutex.Unlock() return nil, xerrors.New("closed") } - network, err := tailnet.NewConn(&tailnet.Options{ + network, err = tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.TailnetIP, 128)}, DERPMap: derpMap, Logger: a.logger.Named("tailnet"), @@ -247,16 +261,24 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t a.closeMutex.Unlock() return nil, xerrors.Errorf("create tailnet: %w", err) } + defer func() { + if err != nil { + network.Close() + } + }() a.network = network - a.connCloseWait.Add(4) a.closeMutex.Unlock() sshListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetSSHPort)) if err != nil { return nil, xerrors.Errorf("listen on the ssh port: %w", err) } - go func() { - defer a.connCloseWait.Done() + defer func() { + if err != nil { + _ = sshListener.Close() + } + }() + if err = a.trackConnGoroutine(func() { for { conn, err := sshListener.Accept() if err != nil { @@ -264,14 +286,20 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t } go a.sshServer.HandleConn(conn) } - }() + }); err != nil { + return nil, err + } reconnectingPTYListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetReconnectingPTYPort)) if err != nil { return nil, xerrors.Errorf("listen for reconnecting pty: %w", err) } - go func() { - defer a.connCloseWait.Done() + defer func() { + if err != nil { + _ = reconnectingPTYListener.Close() + } + }() + if err = a.trackConnGoroutine(func() { for { conn, err := reconnectingPTYListener.Accept() if err != nil { @@ -298,36 +326,48 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t } go a.handleReconnectingPTY(ctx, msg, conn) } - }() + }); err != nil { + return nil, err + } speedtestListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetSpeedtestPort)) if err != nil { return nil, xerrors.Errorf("listen for speedtest: %w", err) } - go func() { - defer a.connCloseWait.Done() + defer func() { + if err != nil { + _ = speedtestListener.Close() + } + }() + if err = a.trackConnGoroutine(func() { for { conn, err := speedtestListener.Accept() if err != nil { a.logger.Debug(ctx, "speedtest listener failed", slog.Error(err)) return } - a.closeMutex.Lock() - a.connCloseWait.Add(1) - a.closeMutex.Unlock() - go func() { - defer a.connCloseWait.Done() + if err = a.trackConnGoroutine(func() { _ = speedtest.ServeConn(conn) - }() + }); err != nil { + a.logger.Debug(ctx, "speedtest listener failed", slog.Error(err)) + _ = conn.Close() + return + } } - }() + }); err != nil { + return nil, err + } statisticsListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetStatisticsPort)) if err != nil { return nil, xerrors.Errorf("listen for statistics: %w", err) } - go func() { - defer a.connCloseWait.Done() + defer func() { + if err != nil { + _ = statisticsListener.Close() + } + }() + if err = a.trackConnGoroutine(func() { defer statisticsListener.Close() server := &http.Server{ Handler: a.statisticsHandler(), @@ -341,11 +381,13 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t _ = server.Close() }() - err = server.Serve(statisticsListener) + err := server.Serve(statisticsListener) if err != nil && !xerrors.Is(err, http.ErrServerClosed) && !strings.Contains(err.Error(), "use of closed network connection") { a.logger.Critical(ctx, "serve statistics HTTP server", slog.Error(err)) } - }() + }); err != nil { + return nil, err + } return network, nil } @@ -527,12 +569,15 @@ func (a *agent) init(ctx context.Context) { a.logger.Error(ctx, "report stats", slog.Error(err)) return } - a.connCloseWait.Add(1) - go func() { - defer a.connCloseWait.Done() + + if err = a.trackConnGoroutine(func() { <-a.closed - cl.Close() - }() + _ = cl.Close() + }); err != nil { + a.logger.Error(ctx, "report stats goroutine", slog.Error(err)) + _ = cl.Close() + return + } } func convertAgentStats(counts map[netlogtype.Connection]netlogtype.Counts) *codersdk.AgentStats { @@ -787,9 +832,6 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec return } - a.closeMutex.Lock() - a.connCloseWait.Add(1) - a.closeMutex.Unlock() ctx, cancelFunc := context.WithCancel(ctx) rpty = &reconnectingPTY{ activeConns: map[string]net.Conn{ @@ -818,7 +860,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec _ = process.Wait() rpty.Close() }() - go func() { + if err = a.trackConnGoroutine(func() { buffer := make([]byte, 1024) for { read, err := rpty.ptty.Output().Read(buffer) @@ -846,8 +888,10 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec _ = process.Kill() rpty.Close() a.reconnectingPTYs.Delete(msg.ID) - a.connCloseWait.Done() - }() + }); err != nil { + a.logger.Error(ctx, "start reconnecting pty routine", slog.F("id", msg.ID), slog.Error(err)) + return + } } // Resize the PTY to initial height + width. err := rpty.ptty.Resize(msg.Height, msg.Width)