From 6f45bafb467ce594557aa873b9a7ae2cbbb14cd9 Mon Sep 17 00:00:00 2001
From: Spike Curtis
Date: Thu, 21 Nov 2024 16:55:48 +0400
Subject: [PATCH] fix: fix panic while tearing down reconnecting PTY

---
 agent/agent.go                  | 155 +++----------------------
 agent/reconnectingpty/server.go | 191 ++++++++++++++++++++++++++++++++
 2 files changed, 208 insertions(+), 138 deletions(-)
 create mode 100644 agent/reconnectingpty/server.go

diff --git a/agent/agent.go b/agent/agent.go
index 6d27802d1291a..2d5b9a663202e 100644
--- a/agent/agent.go
+++ b/agent/agent.go
@@ -3,12 +3,10 @@ package agent
 import (
 	"bytes"
 	"context"
-	"encoding/binary"
 	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
-	"net"
 	"net/http"
 	"net/netip"
 	"os"
@@ -216,8 +214,8 @@ type agent struct {
 	portCacheDuration time.Duration
 	subsystems        []codersdk.AgentSubsystem
 
-	reconnectingPTYs       sync.Map
 	reconnectingPTYTimeout time.Duration
+	reconnectingPTYServer  *reconnectingpty.Server
 
 	// we track 2 contexts and associated cancel functions: "graceful" which is Done when it is time
 	// to start gracefully shutting down and "hard" which is Done when it is time to close
@@ -252,8 +250,6 @@ type agent struct {
 	statsReporter *statsReporter
 	logSender     *agentsdk.LogSender
 
-	connCountReconnectingPTY atomic.Int64
-
 	prometheusRegistry *prometheus.Registry
 	// metrics are prometheus registered metrics that will be collected and
 	// labeled in Coder with the agent + workspace.
@@ -297,6 +293,13 @@ func (a *agent) init() {
 	// Register runner metrics. If the prom registry is nil, the metrics
 	// will not report anywhere.
 	a.scriptRunner.RegisterMetrics(a.prometheusRegistry)
+
+	a.reconnectingPTYServer = reconnectingpty.NewServer(
+		a.logger.Named("reconnecting-pty"),
+		a.sshServer,
+		a.metrics.connectionsTotal, a.metrics.reconnectingPTYErrors,
+		a.reconnectingPTYTimeout,
+	)
 
 	go a.runLoop()
 }
@@ -1181,55 +1184,12 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t
 		}
 	}()
 	if err = a.trackGoroutine(func() {
-		logger := a.logger.Named("reconnecting-pty")
-		var wg sync.WaitGroup
-		for {
-			conn, err := reconnectingPTYListener.Accept()
-			if err != nil {
-				if !a.isClosed() {
-					logger.Debug(ctx, "accept pty failed", slog.Error(err))
-				}
-				break
-			}
-			clog := logger.With(
-				slog.F("remote", conn.RemoteAddr().String()),
-				slog.F("local", conn.LocalAddr().String()))
-			clog.Info(ctx, "accepted conn")
-			wg.Add(1)
-			closed := make(chan struct{})
-			go func() {
-				select {
-				case <-closed:
-				case <-a.hardCtx.Done():
-					_ = conn.Close()
-				}
-				wg.Done()
-			}()
-			go func() {
-				defer close(closed)
-				// This cannot use a JSON decoder, since that can
-				// buffer additional data that is required for the PTY.
-				rawLen := make([]byte, 2)
-				_, err = conn.Read(rawLen)
-				if err != nil {
-					return
-				}
-				length := binary.LittleEndian.Uint16(rawLen)
-				data := make([]byte, length)
-				_, err = conn.Read(data)
-				if err != nil {
-					return
-				}
-				var msg workspacesdk.AgentReconnectingPTYInit
-				err = json.Unmarshal(data, &msg)
-				if err != nil {
-					logger.Warn(ctx, "failed to unmarshal init", slog.F("raw", data))
-					return
-				}
-				_ = a.handleReconnectingPTY(ctx, clog, msg, conn)
-			}()
+		rPTYServeErr := a.reconnectingPTYServer.Serve(a.gracefulCtx, a.hardCtx, reconnectingPTYListener)
+		if rPTYServeErr != nil &&
+			a.gracefulCtx.Err() == nil &&
+			!strings.Contains(rPTYServeErr.Error(), "use of closed network connection") {
+			a.logger.Error(ctx, "error serving reconnecting PTY", slog.Error(rPTYServeErr))
 		}
-		wg.Wait()
 	}); err != nil {
 		return nil, err
 	}
@@ -1308,9 +1268,9 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t
 			_ = server.Close()
 		}()
-		err := server.Serve(apiListener)
-		if err != nil && !xerrors.Is(err, http.ErrServerClosed) && !strings.Contains(err.Error(), "use of closed network connection") {
-			a.logger.Critical(ctx, "serve HTTP API server", slog.Error(err))
+		apiServErr := server.Serve(apiListener)
+		if apiServErr != nil && !xerrors.Is(apiServErr, http.ErrServerClosed) && !strings.Contains(apiServErr.Error(), "use of closed network connection") {
+			a.logger.Critical(ctx, "serve HTTP API server", slog.Error(apiServErr))
 		}
 	}); err != nil {
 		return nil, err
 	}
@@ -1394,87 +1354,6 @@ func (a *agent) runDERPMapSubscriber(ctx context.Context, tClient tailnetproto.D
 	}
 }
 
-func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, msg workspacesdk.AgentReconnectingPTYInit, conn net.Conn) (retErr error) {
-	defer conn.Close()
-	a.metrics.connectionsTotal.Add(1)
-
-	a.connCountReconnectingPTY.Add(1)
-	defer a.connCountReconnectingPTY.Add(-1)
-
-	connectionID := uuid.NewString()
-	connLogger := logger.With(slog.F("message_id", msg.ID), slog.F("connection_id", connectionID))
-	connLogger.Debug(ctx, "starting handler")
-
-	defer func() {
-		if err := retErr; err != nil {
-			a.closeMutex.Lock()
-			closed := a.isClosed()
-			a.closeMutex.Unlock()
-
-			// If the agent is closed, we don't want to
-			// log this as an error since it's expected.
-			if closed {
-				connLogger.Info(ctx, "reconnecting pty failed with attach error (agent closed)", slog.Error(err))
-			} else {
-				connLogger.Error(ctx, "reconnecting pty failed with attach error", slog.Error(err))
-			}
-		}
-		connLogger.Info(ctx, "reconnecting pty connection closed")
-	}()
-
-	var rpty reconnectingpty.ReconnectingPTY
-	sendConnected := make(chan reconnectingpty.ReconnectingPTY, 1)
-	// On store, reserve this ID to prevent multiple concurrent new connections.
-	waitReady, ok := a.reconnectingPTYs.LoadOrStore(msg.ID, sendConnected)
-	if ok {
-		close(sendConnected) // Unused.
-		connLogger.Debug(ctx, "connecting to existing reconnecting pty")
-		c, ok := waitReady.(chan reconnectingpty.ReconnectingPTY)
-		if !ok {
-			return xerrors.Errorf("found invalid type in reconnecting pty map: %T", waitReady)
-		}
-		rpty, ok = <-c
-		if !ok || rpty == nil {
-			return xerrors.Errorf("reconnecting pty closed before connection")
-		}
-		c <- rpty // Put it back for the next reconnect.
-	} else {
-		connLogger.Debug(ctx, "creating new reconnecting pty")
-
-		connected := false
-		defer func() {
-			if !connected && retErr != nil {
-				a.reconnectingPTYs.Delete(msg.ID)
-				close(sendConnected)
-			}
-		}()
-
-		// Empty command will default to the users shell!
-		cmd, err := a.sshServer.CreateCommand(ctx, msg.Command, nil)
-		if err != nil {
-			a.metrics.reconnectingPTYErrors.WithLabelValues("create_command").Add(1)
-			return xerrors.Errorf("create command: %w", err)
-		}
-
-		rpty = reconnectingpty.New(ctx, cmd, &reconnectingpty.Options{
-			Timeout: a.reconnectingPTYTimeout,
-			Metrics: a.metrics.reconnectingPTYErrors,
-		}, logger.With(slog.F("message_id", msg.ID)))
-
-		if err = a.trackGoroutine(func() {
-			rpty.Wait()
-			a.reconnectingPTYs.Delete(msg.ID)
-		}); err != nil {
-			rpty.Close(err)
-			return xerrors.Errorf("start routine: %w", err)
-		}
-
-		connected = true
-		sendConnected <- rpty
-	}
-	return rpty.Attach(ctx, connectionID, conn, msg.Height, msg.Width, connLogger)
-}
-
 // Collect collects additional stats from the agent
 func (a *agent) Collect(ctx context.Context, networkStats map[netlogtype.Connection]netlogtype.Counts) *proto.Stats {
 	a.logger.Debug(context.Background(), "computing stats report")
@@ -1496,7 +1375,7 @@ func (a *agent) Collect(ctx context.Context, networkStats map[netlogtype.Connect
 	stats.SessionCountVscode = sshStats.VSCode
 	stats.SessionCountJetbrains = sshStats.JetBrains
 
-	stats.SessionCountReconnectingPty = a.connCountReconnectingPTY.Load()
+	stats.SessionCountReconnectingPty = a.reconnectingPTYServer.ConnCount()
 
 	// Compute the median connection latency!
 	a.logger.Debug(ctx, "starting peer latency measurement for stats")
diff --git a/agent/reconnectingpty/server.go b/agent/reconnectingpty/server.go
new file mode 100644
index 0000000000000..052a88e52b0b4
--- /dev/null
+++ b/agent/reconnectingpty/server.go
@@ -0,0 +1,191 @@
+package reconnectingpty
+
+import (
+	"context"
+	"encoding/binary"
+	"encoding/json"
+	"net"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/google/uuid"
+	"github.com/prometheus/client_golang/prometheus"
+	"golang.org/x/xerrors"
+
+	"cdr.dev/slog"
+	"github.com/coder/coder/v2/agent/agentssh"
+	"github.com/coder/coder/v2/codersdk/workspacesdk"
+)
+
+type Server struct {
+	logger           slog.Logger
+	connectionsTotal prometheus.Counter
+	errorsTotal      *prometheus.CounterVec
+	commandCreator   *agentssh.Server
+	connCount        atomic.Int64
+	reconnectingPTYs sync.Map
+	timeout          time.Duration
+}
+
+// NewServer returns a new ReconnectingPTY server
+func NewServer(logger slog.Logger, commandCreator *agentssh.Server,
+	connectionsTotal prometheus.Counter, errorsTotal *prometheus.CounterVec,
+	timeout time.Duration,
+) *Server {
+	return &Server{
+		logger:           logger,
+		commandCreator:   commandCreator,
+		connectionsTotal: connectionsTotal,
+		errorsTotal:      errorsTotal,
+		timeout:          timeout,
+	}
+}
+
+func (s *Server) Serve(ctx, hardCtx context.Context, l net.Listener) (retErr error) {
+	var wg sync.WaitGroup
+	for {
+		if ctx.Err() != nil {
+			break
+		}
+		conn, err := l.Accept()
+		if err != nil {
+			s.logger.Debug(ctx, "accept pty failed", slog.Error(err))
+			retErr = err
+			break
+		}
+		clog := s.logger.With(
+			slog.F("remote", conn.RemoteAddr().String()),
+			slog.F("local", conn.LocalAddr().String()))
+		clog.Info(ctx, "accepted conn")
+		wg.Add(1)
+		closed := make(chan struct{})
+		go func() {
+			select {
+			case <-closed:
+			case <-hardCtx.Done():
+				_ = conn.Close()
+			}
+			wg.Done()
+		}()
+		wg.Add(1)
+		go func() {
+			defer close(closed)
+			defer wg.Done()
+			_ = s.handleConn(ctx, clog, conn)
+		}()
+	}
+	wg.Wait()
+	return retErr
+}
+
+func (s *Server) ConnCount() int64 {
+	return s.connCount.Load()
+}
+
+func (s *Server) handleConn(ctx context.Context, logger slog.Logger, conn net.Conn) (retErr error) {
+	defer conn.Close()
+	s.connectionsTotal.Add(1)
+	s.connCount.Add(1)
+	defer s.connCount.Add(-1)
+
+	// This cannot use a JSON decoder, since that can
+	// buffer additional data that is required for the PTY.
+	rawLen := make([]byte, 2)
+	_, err := conn.Read(rawLen)
+	if err != nil {
+		// logging at info since a single incident isn't too worrying (the client could just have
+		// hung up), but if we get a lot of these we'd want to investigate.
+		logger.Info(ctx, "failed to read AgentReconnectingPTYInit length", slog.Error(err))
+		return nil
+	}
+	length := binary.LittleEndian.Uint16(rawLen)
+	data := make([]byte, length)
+	_, err = conn.Read(data)
+	if err != nil {
+		// logging at info since a single incident isn't too worrying (the client could just have
+		// hung up), but if we get a lot of these we'd want to investigate.
+		logger.Info(ctx, "failed to read AgentReconnectingPTYInit", slog.Error(err))
+		return nil
+	}
+	var msg workspacesdk.AgentReconnectingPTYInit
+	err = json.Unmarshal(data, &msg)
+	if err != nil {
+		logger.Warn(ctx, "failed to unmarshal init", slog.F("raw", data))
+		return nil
+	}
+
+	connectionID := uuid.NewString()
+	connLogger := logger.With(slog.F("message_id", msg.ID), slog.F("connection_id", connectionID))
+	connLogger.Debug(ctx, "starting handler")
+
+	defer func() {
+		if err := retErr; err != nil {
+			// If the context is done, we don't want to log this as an error since it's expected.
+			if ctx.Err() != nil {
+				connLogger.Info(ctx, "reconnecting pty failed with attach error (agent closed)", slog.Error(err))
+			} else {
+				connLogger.Error(ctx, "reconnecting pty failed with attach error", slog.Error(err))
+			}
+		}
+		connLogger.Info(ctx, "reconnecting pty connection closed")
+	}()
+
+	var rpty ReconnectingPTY
+	sendConnected := make(chan ReconnectingPTY, 1)
+	// On store, reserve this ID to prevent multiple concurrent new connections.
+	waitReady, ok := s.reconnectingPTYs.LoadOrStore(msg.ID, sendConnected)
+	if ok {
+		close(sendConnected) // Unused.
+		connLogger.Debug(ctx, "connecting to existing reconnecting pty")
+		c, ok := waitReady.(chan ReconnectingPTY)
+		if !ok {
+			return xerrors.Errorf("found invalid type in reconnecting pty map: %T", waitReady)
+		}
+		rpty, ok = <-c
+		if !ok || rpty == nil {
+			return xerrors.Errorf("reconnecting pty closed before connection")
+		}
+		c <- rpty // Put it back for the next reconnect.
+	} else {
+		connLogger.Debug(ctx, "creating new reconnecting pty")
+
+		connected := false
+		defer func() {
+			if !connected && retErr != nil {
+				s.reconnectingPTYs.Delete(msg.ID)
+				close(sendConnected)
+			}
+		}()
+
+		// Empty command will default to the users shell!
+		cmd, err := s.commandCreator.CreateCommand(ctx, msg.Command, nil)
+		if err != nil {
+			s.errorsTotal.WithLabelValues("create_command").Add(1)
+			return xerrors.Errorf("create command: %w", err)
+		}
+
+		rpty = New(ctx, cmd, &Options{
+			Timeout: s.timeout,
+			Metrics: s.errorsTotal,
+		}, logger.With(slog.F("message_id", msg.ID)))
+
+		done := make(chan struct{})
+		go func() {
+			select {
+			case <-done:
+			case <-ctx.Done():
+				rpty.Close(ctx.Err())
+			}
+		}()
+
+		go func() {
+			rpty.Wait()
+			s.reconnectingPTYs.Delete(msg.ID)
+		}()
+
+		connected = true
+		sendConnected <- rpty
+	}
+	return rpty.Attach(ctx, connectionID, conn, msg.Height, msg.Width, connLogger)
+}
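
Note (illustrative sketch, not part of the patch): handleConn above expects the client to frame the init message as a little-endian uint16 length prefix followed by the JSON-encoded workspacesdk.AgentReconnectingPTYInit, written before any PTY input. A minimal client-side helper for that framing could look like the following; the package name ptyclient and the function writeReconnectingPTYInit are hypothetical and exist only for this example.

package ptyclient

import (
	"encoding/binary"
	"encoding/json"
	"net"

	"github.com/coder/coder/v2/codersdk/workspacesdk"
)

// writeReconnectingPTYInit frames the init message the way the server's
// handleConn reads it: two bytes of little-endian length, then exactly that
// many bytes of JSON.
func writeReconnectingPTYInit(conn net.Conn, msg workspacesdk.AgentReconnectingPTYInit) error {
	payload, err := json.Marshal(msg)
	if err != nil {
		return err
	}
	frame := make([]byte, 2, 2+len(payload))
	binary.LittleEndian.PutUint16(frame, uint16(len(payload)))
	frame = append(frame, payload...)
	// A single Write keeps the length prefix and payload together; the server
	// does not buffer past the init message, so nothing else may precede it.
	_, err = conn.Write(frame)
	return err
}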