diff --git a/agent/agent.go b/agent/agent.go
index ea2fae6d430f6..3cbdafa301401 100644
--- a/agent/agent.go
+++ b/agent/agent.go
@@ -1134,7 +1134,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
 		rpty.Wait()
 		a.reconnectingPTYs.Delete(msg.ID)
 	}); err != nil {
-		rpty.Close(err.Error())
+		rpty.Close(err)
 		return xerrors.Errorf("start routine: %w", err)
 	}
 
diff --git a/agent/reconnectingpty/buffered.go b/agent/reconnectingpty/buffered.go
index 93241ada29687..47d74595472a5 100644
--- a/agent/reconnectingpty/buffered.go
+++ b/agent/reconnectingpty/buffered.go
@@ -51,7 +51,7 @@ func newBuffered(ctx context.Context, cmd *pty.Cmd, options *Options, logger slo
 	// Default to buffer 64KiB.
 	circularBuffer, err := circbuf.NewBuffer(64 << 10)
 	if err != nil {
-		rpty.state.setState(StateDone, xerrors.Errorf("generate screen id: %w", err))
+		rpty.state.setState(StateDone, xerrors.Errorf("create circular buffer: %w", err))
 		return rpty
 	}
 	rpty.circularBuffer = circularBuffer
@@ -63,7 +63,7 @@ func newBuffered(ctx context.Context, cmd *pty.Cmd, options *Options, logger slo
 	cmdWithEnv.Dir = rpty.command.Dir
 	ptty, process, err := pty.Start(cmdWithEnv)
 	if err != nil {
-		rpty.state.setState(StateDone, xerrors.Errorf("generate screen id: %w", err))
+		rpty.state.setState(StateDone, xerrors.Errorf("start pty: %w", err))
 		return rpty
 	}
 	rpty.ptty = ptty
@@ -92,7 +92,7 @@ func newBuffered(ctx context.Context, cmd *pty.Cmd, options *Options, logger slo
 			// not found for example).
 			// TODO: Should we check the process's exit code in case the command was
 			// invalid?
-			rpty.Close("unable to read pty output, command might have exited")
+			rpty.Close(nil)
 			break
 		}
 		part := buffer[:read]
@@ -126,7 +126,7 @@ func newBuffered(ctx context.Context, cmd *pty.Cmd, options *Options, logger slo
 // or the reconnecting pty closes the pty will be shut down.
 func (rpty *bufferedReconnectingPTY) lifecycle(ctx context.Context, logger slog.Logger) {
 	rpty.timer = time.AfterFunc(attachTimeout, func() {
-		rpty.Close("reconnecting pty timeout")
+		rpty.Close(xerrors.New("reconnecting pty timeout"))
 	})
 
 	logger.Debug(ctx, "reconnecting pty ready")
@@ -136,7 +136,7 @@ func (rpty *bufferedReconnectingPTY) lifecycle(ctx context.Context, logger slog.
 	if state < StateClosing {
 		// If we have not closed yet then the context is what unblocked us (which
 		// means the agent is shutting down) so move into the closing phase.
-		rpty.Close(reasonErr.Error())
+		rpty.Close(reasonErr)
 	}
 	rpty.timer.Stop()
 
@@ -168,7 +168,7 @@ func (rpty *bufferedReconnectingPTY) lifecycle(ctx context.Context, logger slog.
 	}
 
 	logger.Info(ctx, "closed reconnecting pty")
-	rpty.state.setState(StateDone, xerrors.Errorf("reconnecting pty closed: %w", reasonErr))
+	rpty.state.setState(StateDone, reasonErr)
 }
 
 func (rpty *bufferedReconnectingPTY) Attach(ctx context.Context, connID string, conn net.Conn, height, width uint16, logger slog.Logger) error {
@@ -178,7 +178,7 @@ func (rpty *bufferedReconnectingPTY) Attach(ctx context.Context, connID string,
 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()
 
-	err := rpty.doAttach(ctx, connID, conn, height, width, logger)
+	err := rpty.doAttach(connID, conn)
 	if err != nil {
 		return err
 	}
@@ -189,15 +189,30 @@ func (rpty *bufferedReconnectingPTY) Attach(ctx context.Context, connID string,
 		delete(rpty.activeConns, connID)
 	}()
 
+	state, err := rpty.state.waitForStateOrContext(ctx, StateReady)
+	if state != StateReady {
+		return err
+	}
+
+	go heartbeat(ctx, rpty.timer, rpty.timeout)
+
+	// Resize the PTY to initial height + width.
+	err = rpty.ptty.Resize(height, width)
+	if err != nil {
+		// We can continue after this, it's not fatal!
+		logger.Warn(ctx, "reconnecting PTY initial resize failed, but will continue", slog.Error(err))
+		rpty.metrics.WithLabelValues("resize").Add(1)
+	}
+
 	// Pipe conn -> pty and block. pty -> conn is handled in newBuffered().
 	readConnLoop(ctx, conn, rpty.ptty, rpty.metrics, logger)
 
 	return nil
 }
 
-// doAttach adds the connection to the map, replays the buffer, and starts the
-// heartbeat. It exists separately only so we can defer the mutex unlock which
-// is not possible in Attach since it blocks.
-func (rpty *bufferedReconnectingPTY) doAttach(ctx context.Context, connID string, conn net.Conn, height, width uint16, logger slog.Logger) error {
+// doAttach adds the connection to the map and replays the buffer. It exists
+// separately only for convenience to defer the mutex unlock which is not
+// possible in Attach since it blocks.
+func (rpty *bufferedReconnectingPTY) doAttach(connID string, conn net.Conn) error {
 	rpty.state.cond.L.Lock()
 	defer rpty.state.cond.L.Unlock()
@@ -211,21 +226,6 @@ func (rpty *bufferedReconnectingPTY) doAttach(ctx context.Context, connID string
 		return xerrors.Errorf("write buffer to conn: %w", err)
 	}
 
-	state, err := rpty.state.waitForStateOrContextLocked(ctx, StateReady)
-	if state != StateReady {
-		return xerrors.Errorf("reconnecting pty ready wait: %w", err)
-	}
-
-	go heartbeat(ctx, rpty.timer, rpty.timeout)
-
-	// Resize the PTY to initial height + width.
-	err = rpty.ptty.Resize(height, width)
-	if err != nil {
-		// We can continue after this, it's not fatal!
-		logger.Warn(ctx, "reconnecting PTY initial resize failed, but will continue", slog.Error(err))
-		rpty.metrics.WithLabelValues("resize").Add(1)
-	}
-
 	rpty.activeConns[connID] = conn
 
 	return nil
@@ -235,7 +235,7 @@ func (rpty *bufferedReconnectingPTY) Wait() {
 	_, _ = rpty.state.waitForState(StateClosing)
 }
 
-func (rpty *bufferedReconnectingPTY) Close(reason string) {
+func (rpty *bufferedReconnectingPTY) Close(err error) {
 	// The closing state change will be handled by the lifecycle.
-	rpty.state.setState(StateClosing, xerrors.Errorf("reconnecting pty closing: %s", reason))
+	rpty.state.setState(StateClosing, err)
 }
diff --git a/agent/reconnectingpty/reconnectingpty.go b/agent/reconnectingpty/reconnectingpty.go
index e3dbb9024b063..60f347c81ea72 100644
--- a/agent/reconnectingpty/reconnectingpty.go
+++ b/agent/reconnectingpty/reconnectingpty.go
@@ -48,7 +48,7 @@ type ReconnectingPTY interface {
 	// still be exiting.
 	Wait()
 	// Close kills the reconnecting pty process.
-	Close(reason string)
+	Close(err error)
 }
 
 // New sets up a new reconnecting pty that wraps the provided command. Any
@@ -171,12 +171,7 @@ func (s *ptyState) waitForState(state State) (State, error) {
 func (s *ptyState) waitForStateOrContext(ctx context.Context, state State) (State, error) {
 	s.cond.L.Lock()
 	defer s.cond.L.Unlock()
-	return s.waitForStateOrContextLocked(ctx, state)
-}
 
-// waitForStateOrContextLocked is the same as waitForStateOrContext except it
-// assumes the caller has already locked cond.
-func (s *ptyState) waitForStateOrContextLocked(ctx context.Context, state State) (State, error) {
 	nevermind := make(chan struct{})
 	defer close(nevermind)
 	go func() {
diff --git a/agent/reconnectingpty/screen.go b/agent/reconnectingpty/screen.go
index 0203154f83335..94854a8b8bf81 100644
--- a/agent/reconnectingpty/screen.go
+++ b/agent/reconnectingpty/screen.go
@@ -124,7 +124,7 @@ func newScreen(ctx context.Context, cmd *pty.Cmd, options *Options, logger slog.
 // the reconnecting pty will be closed.
 func (rpty *screenReconnectingPTY) lifecycle(ctx context.Context, logger slog.Logger) {
 	rpty.timer = time.AfterFunc(attachTimeout, func() {
-		rpty.Close("reconnecting pty timeout")
+		rpty.Close(xerrors.New("reconnecting pty timeout"))
 	})
 
 	logger.Debug(ctx, "reconnecting pty ready")
@@ -134,7 +134,7 @@ func (rpty *screenReconnectingPTY) lifecycle(ctx context.Context, logger slog.Lo
 	if state < StateClosing {
 		// If we have not closed yet then the context is what unblocked us (which
 		// means the agent is shutting down) so move into the closing phase.
-		rpty.Close(reasonErr.Error())
+		rpty.Close(reasonErr)
 	}
 	rpty.timer.Stop()
 
@@ -145,7 +145,7 @@ func (rpty *screenReconnectingPTY) lifecycle(ctx context.Context, logger slog.Lo
 	}
 
 	logger.Info(ctx, "closed reconnecting pty")
-	rpty.state.setState(StateDone, xerrors.Errorf("reconnecting pty closed: %w", reasonErr))
+	rpty.state.setState(StateDone, reasonErr)
 }
 
 func (rpty *screenReconnectingPTY) Attach(ctx context.Context, _ string, conn net.Conn, height, width uint16, logger slog.Logger) error {
@@ -157,7 +157,7 @@ func (rpty *screenReconnectingPTY) Attach(ctx context.Context, _ string, conn ne
 
 	state, err := rpty.state.waitForStateOrContext(ctx, StateReady)
 	if state != StateReady {
-		return xerrors.Errorf("reconnecting pty ready wait: %w", err)
+		return err
 	}
 
 	go heartbeat(ctx, rpty.timer, rpty.timeout)
@@ -382,7 +382,7 @@ func (rpty *screenReconnectingPTY) Wait() {
 	_, _ = rpty.state.waitForState(StateClosing)
 }
 
-func (rpty *screenReconnectingPTY) Close(reason string) {
+func (rpty *screenReconnectingPTY) Close(err error) {
 	// The closing state change will be handled by the lifecycle.
-	rpty.state.setState(StateClosing, xerrors.Errorf("reconnecting pty closing: %s", reason))
+	rpty.state.setState(StateClosing, err)
 }
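
Note on the waitForStateOrContext simplification above: with the locked variant gone, everything rides on one sync.Cond pattern, which is easy to get wrong. Below is a minimal, self-contained sketch of that pattern, assuming simplified State constants and the standard errors package in place of xerrors; the real ptyState in reconnectingpty.go carries more detail, so treat this as an illustration, not the package's code.

package main

import (
	"context"
	"errors"
	"fmt"
	"sync"
	"time"
)

// State transitions monotonically: Starting -> Ready -> Closing -> Done.
type State int

const (
	StateStarting State = iota
	StateReady
	StateClosing
	StateDone
)

// ptyState pairs the current state with the error (possibly nil) that
// explains the most recent transition, guarded by a condition variable.
type ptyState struct {
	cond  *sync.Cond
	state State
	err   error
}

func newPtyState() *ptyState {
	return &ptyState{cond: sync.NewCond(&sync.Mutex{})}
}

// setState only ever moves the state forward and wakes all waiters.
func (s *ptyState) setState(state State, err error) {
	s.cond.L.Lock()
	defer s.cond.L.Unlock()
	if state <= s.state {
		return
	}
	s.state, s.err = state, err
	s.cond.Broadcast()
}

// waitForStateOrContext blocks until the state reaches at least the
// target or ctx is canceled. sync.Cond cannot select on a context, so a
// helper goroutine broadcasts on cancellation; closing nevermind stops
// that helper once the wait is over.
func (s *ptyState) waitForStateOrContext(ctx context.Context, state State) (State, error) {
	s.cond.L.Lock()
	defer s.cond.L.Unlock()

	nevermind := make(chan struct{})
	defer close(nevermind)
	go func() {
		select {
		case <-ctx.Done():
			// Wake waiters so they can observe the canceled context.
			s.cond.Broadcast()
		case <-nevermind:
		}
	}()

	for ctx.Err() == nil && s.state < state {
		s.cond.Wait()
	}
	if ctx.Err() != nil {
		return s.state, ctx.Err()
	}
	return s.state, s.err
}

func main() {
	s := newPtyState()
	// Simulate the attach timeout closing the pty before it becomes ready.
	go s.setState(StateClosing, errors.New("reconnecting pty timeout"))

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	state, err := s.waitForStateOrContext(ctx, StateReady)
	fmt.Println(state, err) // 2 reconnecting pty timeout
}

Because the constants are ordered, comparisons like state < StateClosing in the lifecycle functions stay valid, and storing the error directly means Close(nil), as used when the pty output reader exits, records a clean shutdown instead of a synthetic error string.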