From 874eff6a939dcca0b10f7d42b6f4a9fb03f3a860 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Fri, 2 Dec 2022 13:00:21 +0000 Subject: [PATCH 1/4] fix: Improve agent connection tracking when agent is closed --- agent/agent.go | 105 +++++++++++++++++++++++++++++++++---------------- 1 file changed, 72 insertions(+), 33 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index bb40e27aa97e6..70b0f9db1f1b7 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -231,13 +231,27 @@ func (a *agent) run(ctx context.Context) error { return nil } -func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*tailnet.Conn, error) { +func (a *agent) trackConnGoroutine(fn func()) error { + a.closeMutex.Lock() + defer a.closeMutex.Unlock() + if a.isClosed() { + return xerrors.New("track conn goroutine: agent is closed") + } + a.connCloseWait.Add(1) + go func() { + defer a.connCloseWait.Done() + fn() + }() + return nil +} + +func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (network *tailnet.Conn, err error) { a.closeMutex.Lock() if a.isClosed() { a.closeMutex.Unlock() return nil, xerrors.New("closed") } - network, err := tailnet.NewConn(&tailnet.Options{ + network, err = tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.TailnetIP, 128)}, DERPMap: derpMap, Logger: a.logger.Named("tailnet"), @@ -248,15 +262,18 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t return nil, xerrors.Errorf("create tailnet: %w", err) } a.network = network - a.connCloseWait.Add(4) a.closeMutex.Unlock() sshListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetSSHPort)) if err != nil { return nil, xerrors.Errorf("listen on the ssh port: %w", err) } - go func() { - defer a.connCloseWait.Done() + defer func() { + if err != nil { + _ = sshListener.Close() + } + }() + if err = a.trackConnGoroutine(func() { for { conn, err := sshListener.Accept() if err != nil { @@ -264,14 +281,20 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t } go a.sshServer.HandleConn(conn) } - }() + }); err != nil { + return nil, err + } reconnectingPTYListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetReconnectingPTYPort)) if err != nil { return nil, xerrors.Errorf("listen for reconnecting pty: %w", err) } - go func() { - defer a.connCloseWait.Done() + defer func() { + if err != nil { + _ = reconnectingPTYListener.Close() + } + }() + if err = a.trackConnGoroutine(func() { for { conn, err := reconnectingPTYListener.Accept() if err != nil { @@ -298,36 +321,48 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t } go a.handleReconnectingPTY(ctx, msg, conn) } - }() + }); err != nil { + return nil, err + } speedtestListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetSpeedtestPort)) if err != nil { return nil, xerrors.Errorf("listen for speedtest: %w", err) } - go func() { - defer a.connCloseWait.Done() + defer func() { + if err != nil { + _ = speedtestListener.Close() + } + }() + if err = a.trackConnGoroutine(func() { for { conn, err := speedtestListener.Accept() if err != nil { a.logger.Debug(ctx, "speedtest listener failed", slog.Error(err)) return } - a.closeMutex.Lock() - a.connCloseWait.Add(1) - a.closeMutex.Unlock() - go func() { - defer a.connCloseWait.Done() + if err = a.trackConnGoroutine(func() { _ = speedtest.ServeConn(conn) - }() + }); err != nil { + a.logger.Debug(ctx, "speedtest listener failed", slog.Error(err)) + _ = conn.Close() + return + } } - }() + }); err != nil { + return nil, err + } statisticsListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetStatisticsPort)) if err != nil { return nil, xerrors.Errorf("listen for statistics: %w", err) } - go func() { - defer a.connCloseWait.Done() + defer func() { + if err != nil { + _ = statisticsListener.Close() + } + }() + if err = a.trackConnGoroutine(func() { defer statisticsListener.Close() server := &http.Server{ Handler: a.statisticsHandler(), @@ -341,11 +376,13 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t _ = server.Close() }() - err = server.Serve(statisticsListener) + err := server.Serve(statisticsListener) if err != nil && !xerrors.Is(err, http.ErrServerClosed) && !strings.Contains(err.Error(), "use of closed network connection") { a.logger.Critical(ctx, "serve statistics HTTP server", slog.Error(err)) } - }() + }); err != nil { + return nil, err + } return network, nil } @@ -527,12 +564,15 @@ func (a *agent) init(ctx context.Context) { a.logger.Error(ctx, "report stats", slog.Error(err)) return } - a.connCloseWait.Add(1) - go func() { - defer a.connCloseWait.Done() + + if err = a.trackConnGoroutine(func() { <-a.closed - cl.Close() - }() + _ = cl.Close() + }); err != nil { + a.logger.Error(ctx, "report stats goroutine", slog.Error(err)) + _ = cl.Close() + return + } } func convertAgentStats(counts map[netlogtype.Connection]netlogtype.Counts) *codersdk.AgentStats { @@ -787,9 +827,6 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec return } - a.closeMutex.Lock() - a.connCloseWait.Add(1) - a.closeMutex.Unlock() ctx, cancelFunc := context.WithCancel(ctx) rpty = &reconnectingPTY{ activeConns: map[string]net.Conn{ @@ -818,7 +855,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec _ = process.Wait() rpty.Close() }() - go func() { + if err = a.trackConnGoroutine(func() { buffer := make([]byte, 1024) for { read, err := rpty.ptty.Output().Read(buffer) @@ -846,8 +883,10 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec _ = process.Kill() rpty.Close() a.reconnectingPTYs.Delete(msg.ID) - a.connCloseWait.Done() - }() + }); err != nil { + a.logger.Error(ctx, "start reconnecting pty routine", slog.F("id", msg.ID), slog.Error(err)) + return + } } // Resize the PTY to initial height + width. err := rpty.ptty.Resize(msg.Height, msg.Width) From 71071d5e399c3bee935038a6b98a3e117236b3d7 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Fri, 2 Dec 2022 13:22:13 +0000 Subject: [PATCH 2/4] Also close tailnet if there was an error --- agent/agent.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/agent/agent.go b/agent/agent.go index 70b0f9db1f1b7..822297ad70c3d 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -261,6 +261,11 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (ne a.closeMutex.Unlock() return nil, xerrors.Errorf("create tailnet: %w", err) } + defer func() { + if err != nil { + network.Close() + } + }() a.network = network a.closeMutex.Unlock() From 6f9b3192399dcd926d05cb3e653d1fa7e0e12faa Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Fri, 2 Dec 2022 13:27:05 +0000 Subject: [PATCH 3/4] Avoid double assign of `a.network` --- agent/agent.go | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index 822297ad70c3d..dd86bbf19e073 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -210,13 +210,10 @@ func (a *agent) run(ctx context.Context) error { a.closeMutex.Unlock() if network == nil { a.logger.Debug(ctx, "creating tailnet") - network, err = a.createTailnet(ctx, metadata.DERPMap) + err = a.createTailnet(ctx, metadata.DERPMap) if err != nil { return xerrors.Errorf("create tailnet: %w", err) } - a.closeMutex.Lock() - a.network = network - a.closeMutex.Unlock() } else { // Update the DERP map! network.SetDERPMap(metadata.DERPMap) @@ -245,9 +242,10 @@ func (a *agent) trackConnGoroutine(fn func()) error { return nil } -func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (network *tailnet.Conn, err error) { +func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (err error) { a.closeMutex.Lock() if a.isClosed() { + a.network = nil a.closeMutex.Unlock() return nil, xerrors.New("closed") } @@ -264,6 +262,9 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (ne defer func() { if err != nil { network.Close() + a.closeMutex.Lock() + a.network = nil + a.closeMutex.Unlock() } }() a.network = network @@ -287,7 +288,7 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (ne go a.sshServer.HandleConn(conn) } }); err != nil { - return nil, err + return err } reconnectingPTYListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetReconnectingPTYPort)) @@ -327,7 +328,7 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (ne go a.handleReconnectingPTY(ctx, msg, conn) } }); err != nil { - return nil, err + return err } speedtestListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetSpeedtestPort)) @@ -355,7 +356,7 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (ne } } }); err != nil { - return nil, err + return err } statisticsListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetStatisticsPort)) @@ -386,7 +387,7 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (ne a.logger.Critical(ctx, "serve statistics HTTP server", slog.Error(err)) } }); err != nil { - return nil, err + return err } return network, nil From a0de5a05527a886bfe5cc9dd110d3203a49d6798 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Fri, 2 Dec 2022 13:31:18 +0000 Subject: [PATCH 4/4] Revert "Avoid double assign of `a.network`" This reverts commit 6f9b3192399dcd926d05cb3e653d1fa7e0e12faa. --- agent/agent.go | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index dd86bbf19e073..822297ad70c3d 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -210,10 +210,13 @@ func (a *agent) run(ctx context.Context) error { a.closeMutex.Unlock() if network == nil { a.logger.Debug(ctx, "creating tailnet") - err = a.createTailnet(ctx, metadata.DERPMap) + network, err = a.createTailnet(ctx, metadata.DERPMap) if err != nil { return xerrors.Errorf("create tailnet: %w", err) } + a.closeMutex.Lock() + a.network = network + a.closeMutex.Unlock() } else { // Update the DERP map! network.SetDERPMap(metadata.DERPMap) @@ -242,10 +245,9 @@ func (a *agent) trackConnGoroutine(fn func()) error { return nil } -func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (err error) { +func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (network *tailnet.Conn, err error) { a.closeMutex.Lock() if a.isClosed() { - a.network = nil a.closeMutex.Unlock() return nil, xerrors.New("closed") } @@ -262,9 +264,6 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (er defer func() { if err != nil { network.Close() - a.closeMutex.Lock() - a.network = nil - a.closeMutex.Unlock() } }() a.network = network @@ -288,7 +287,7 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (er go a.sshServer.HandleConn(conn) } }); err != nil { - return err + return nil, err } reconnectingPTYListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetReconnectingPTYPort)) @@ -328,7 +327,7 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (er go a.handleReconnectingPTY(ctx, msg, conn) } }); err != nil { - return err + return nil, err } speedtestListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetSpeedtestPort)) @@ -356,7 +355,7 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (er } } }); err != nil { - return err + return nil, err } statisticsListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetStatisticsPort)) @@ -387,7 +386,7 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (er a.logger.Critical(ctx, "serve statistics HTTP server", slog.Error(err)) } }); err != nil { - return err + return nil, err } return network, nil