fix: catch missed output with reconnecting PTY #9094

Merged 1 commit on Aug 14, 2023
agent/agent.go (2 changes: 1 addition & 1 deletion)
@@ -1134,7 +1134,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
 		rpty.Wait()
 		a.reconnectingPTYs.Delete(msg.ID)
 	}); err != nil {
-		rpty.Close(err.Error())
+		rpty.Close(err)
 		return xerrors.Errorf("start routine: %w", err)
 	}

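Note on the signature this call site relies on (defined in agent/reconnectingpty below): Close now takes an error instead of a string, and elsewhere in this PR it is also called with nil when the pty output simply ends. A minimal, self-contained sketch of the two call styles; fakeRPTY is an illustrative stand-in, not a type from this repository:

```go
package main

import (
	"errors"
	"fmt"
)

// fakeRPTY is a stand-in that only records the reason handed to Close, so the
// two call styles can be compared. It is not a type from this repository.
type fakeRPTY struct{ closeReason error }

func (r *fakeRPTY) Close(err error) { r.closeReason = err }

func main() {
	rpty := &fakeRPTY{}

	// Abnormal close: pass the typed error through instead of flattening it
	// to a string with err.Error(), so the cause survives as an error value.
	startErr := errors.New("start routine: no pty available")
	rpty.Close(startErr)
	fmt.Println(errors.Is(rpty.closeReason, startErr)) // true

	// Normal close, e.g. the pty output ended because the command exited.
	rpty.Close(nil)
	fmt.Println(rpty.closeReason) // <nil>
}
```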
agent/reconnectingpty/buffered.go (56 changes: 28 additions & 28 deletions)
@@ -51,7 +51,7 @@ func newBuffered(ctx context.Context, cmd *pty.Cmd, options *Options, logger slo
 	// Default to buffer 64KiB.
 	circularBuffer, err := circbuf.NewBuffer(64 << 10)
 	if err != nil {
-		rpty.state.setState(StateDone, xerrors.Errorf("generate screen id: %w", err))
+		rpty.state.setState(StateDone, xerrors.Errorf("create circular buffer: %w", err))
 		return rpty
 	}
 	rpty.circularBuffer = circularBuffer
@@ -63,7 +63,7 @@ func newBuffered(ctx context.Context, cmd *pty.Cmd, options *Options, logger slo
 	cmdWithEnv.Dir = rpty.command.Dir
 	ptty, process, err := pty.Start(cmdWithEnv)
 	if err != nil {
-		rpty.state.setState(StateDone, xerrors.Errorf("generate screen id: %w", err))
+		rpty.state.setState(StateDone, xerrors.Errorf("start pty: %w", err))
 		return rpty
 	}
 	rpty.ptty = ptty
@@ -92,7 +92,7 @@ func newBuffered(ctx context.Context, cmd *pty.Cmd, options *Options, logger slo
 			// not found for example).
 			// TODO: Should we check the process's exit code in case the command was
 			// invalid?
-			rpty.Close("unable to read pty output, command might have exited")
+			rpty.Close(nil)
 			break
 		}
 		part := buffer[:read]
@@ -126,7 +126,7 @@ func newBuffered(ctx context.Context, cmd *pty.Cmd, options *Options, logger slo
 // or the reconnecting pty closes the pty will be shut down.
 func (rpty *bufferedReconnectingPTY) lifecycle(ctx context.Context, logger slog.Logger) {
 	rpty.timer = time.AfterFunc(attachTimeout, func() {
-		rpty.Close("reconnecting pty timeout")
+		rpty.Close(xerrors.New("reconnecting pty timeout"))
 	})
 
 	logger.Debug(ctx, "reconnecting pty ready")
@@ -136,7 +136,7 @@ func (rpty *bufferedReconnectingPTY) lifecycle(ctx context.Context, logger slog.
 	if state < StateClosing {
 		// If we have not closed yet then the context is what unblocked us (which
 		// means the agent is shutting down) so move into the closing phase.
-		rpty.Close(reasonErr.Error())
+		rpty.Close(reasonErr)
 	}
 	rpty.timer.Stop()
 
@@ -168,7 +168,7 @@ func (rpty *bufferedReconnectingPTY) lifecycle(ctx context.Context, logger slog.
 	}
 
 	logger.Info(ctx, "closed reconnecting pty")
-	rpty.state.setState(StateDone, xerrors.Errorf("reconnecting pty closed: %w", reasonErr))
+	rpty.state.setState(StateDone, reasonErr)
 }
 
 func (rpty *bufferedReconnectingPTY) Attach(ctx context.Context, connID string, conn net.Conn, height, width uint16, logger slog.Logger) error {
@@ -178,7 +178,7 @@ func (rpty *bufferedReconnectingPTY) Attach(ctx context.Context, connID string,
 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()
 
-	err := rpty.doAttach(ctx, connID, conn, height, width, logger)
+	err := rpty.doAttach(connID, conn)
 	if err != nil {
 		return err
 	}
@@ -189,15 +189,30 @@ func (rpty *bufferedReconnectingPTY) Attach(ctx context.Context, connID string,
 		delete(rpty.activeConns, connID)
 	}()
 
+	state, err := rpty.state.waitForStateOrContext(ctx, StateReady)
+	if state != StateReady {
+		return err
+	}
+
+	go heartbeat(ctx, rpty.timer, rpty.timeout)
+
+	// Resize the PTY to initial height + width.
+	err = rpty.ptty.Resize(height, width)
+	if err != nil {
+		// We can continue after this, it's not fatal!
+		logger.Warn(ctx, "reconnecting PTY initial resize failed, but will continue", slog.Error(err))
+		rpty.metrics.WithLabelValues("resize").Add(1)
+	}
+
 	// Pipe conn -> pty and block. pty -> conn is handled in newBuffered().
 	readConnLoop(ctx, conn, rpty.ptty, rpty.metrics, logger)
 	return nil
 }
 
-// doAttach adds the connection to the map, replays the buffer, and starts the
-// heartbeat. It exists separately only so we can defer the mutex unlock which
-// is not possible in Attach since it blocks.
-func (rpty *bufferedReconnectingPTY) doAttach(ctx context.Context, connID string, conn net.Conn, height, width uint16, logger slog.Logger) error {
+// doAttach adds the connection to the map and replays the buffer. It exists
+// separately only for convenience to defer the mutex unlock which is not
+// possible in Attach since it blocks.
+func (rpty *bufferedReconnectingPTY) doAttach(connID string, conn net.Conn) error {
 	rpty.state.cond.L.Lock()
 	defer rpty.state.cond.L.Unlock()
 
@@ -211,21 +226,6 @@ func (rpty *bufferedReconnectingPTY) doAttach(ctx context.Context, connID string
 		return xerrors.Errorf("write buffer to conn: %w", err)
 	}
 
-	state, err := rpty.state.waitForStateOrContextLocked(ctx, StateReady)
-	if state != StateReady {
-		return xerrors.Errorf("reconnecting pty ready wait: %w", err)
-	}
-
-	go heartbeat(ctx, rpty.timer, rpty.timeout)
-
-	// Resize the PTY to initial height + width.
-	err = rpty.ptty.Resize(height, width)
-	if err != nil {
-		// We can continue after this, it's not fatal!
-		logger.Warn(ctx, "reconnecting PTY initial resize failed, but will continue", slog.Error(err))
-		rpty.metrics.WithLabelValues("resize").Add(1)
-	}
-
 	rpty.activeConns[connID] = conn
 
 	return nil
@@ -235,7 +235,7 @@ func (rpty *bufferedReconnectingPTY) Wait() {
 	_, _ = rpty.state.waitForState(StateClosing)
 }
 
-func (rpty *bufferedReconnectingPTY) Close(reason string) {
+func (rpty *bufferedReconnectingPTY) Close(error error) {
 	// The closing state change will be handled by the lifecycle.
-	rpty.state.setState(StateClosing, xerrors.Errorf("reconnecting pty closing: %s", reason))
+	rpty.state.setState(StateClosing, error)
 }
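Net effect of the buffered.go changes: doAttach now only registers the connection and replays the scrollback while holding the state lock, and everything that can block (the ready-state wait, the heartbeat, the initial resize, and the conn -> pty copy) runs in Attach afterward. A self-contained sketch of that split, using stand-in types instead of the real pty, circbuf, and metrics plumbing, so the ordering is easier to see:

```go
package main

import (
	"bytes"
	"context"
	"fmt"
	"sync"
)

// fakePTY stands in for the real pty; only Resize and a scrollback buffer are
// needed for this sketch.
type fakePTY struct{ scrollback []byte }

func (p *fakePTY) Resize(height, width uint16) error { return nil }

type reconnectingPTY struct {
	mu          sync.Mutex
	activeConns map[string]*bytes.Buffer
	ptty        *fakePTY
}

// doAttach registers the connection and replays the buffered output. It holds
// the lock for its whole body, which is why it must never block and why it is
// split out of Attach.
func (r *reconnectingPTY) doAttach(connID string, conn *bytes.Buffer) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	if _, err := conn.Write(r.ptty.scrollback); err != nil {
		return fmt.Errorf("write buffer to conn: %w", err)
	}
	r.activeConns[connID] = conn
	return nil
}

// Attach does the blocking work after doAttach returns: in the real code this
// is the ready-state wait, the heartbeat, the initial resize, and the
// conn -> pty copy loop. Only the resize and a stand-in for the copy loop are
// shown here.
func (r *reconnectingPTY) Attach(ctx context.Context, connID string, conn *bytes.Buffer, height, width uint16) error {
	if err := r.doAttach(connID, conn); err != nil {
		return err
	}
	defer func() {
		r.mu.Lock()
		delete(r.activeConns, connID)
		r.mu.Unlock()
	}()
	if err := r.ptty.Resize(height, width); err != nil {
		fmt.Println("initial resize failed, continuing:", err) // non-fatal, as in the PR
	}
	<-ctx.Done() // stands in for the blocking conn -> pty copy
	return nil
}

func main() {
	r := &reconnectingPTY{
		activeConns: map[string]*bytes.Buffer{},
		ptty:        &fakePTY{scrollback: []byte("replayed output\n")},
	}
	ctx, cancel := context.WithCancel(context.Background())
	cancel() // end the demo attach immediately
	conn := &bytes.Buffer{}
	_ = r.Attach(ctx, "conn-1", conn, 24, 80)
	fmt.Print(conn.String())
}
```

The point of the split is unchanged from the comment in the diff: doAttach exists so the mutex unlock can be deferred, which would not be safe in Attach because Attach blocks.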
agent/reconnectingpty/reconnectingpty.go (7 changes: 1 addition & 6 deletions)
@@ -48,7 +48,7 @@ type ReconnectingPTY interface {
 	// still be exiting.
 	Wait()
 	// Close kills the reconnecting pty process.
-	Close(reason string)
+	Close(err error)
 }
 
 // New sets up a new reconnecting pty that wraps the provided command. Any
@@ -171,12 +171,7 @@ func (s *ptyState) waitForState(state State) (State, error) {
 func (s *ptyState) waitForStateOrContext(ctx context.Context, state State) (State, error) {
 	s.cond.L.Lock()
 	defer s.cond.L.Unlock()
-	return s.waitForStateOrContextLocked(ctx, state)
-}
-
-// waitForStateOrContextLocked is the same as waitForStateOrContext except it
-// assumes the caller has already locked cond.
-func (s *ptyState) waitForStateOrContextLocked(ctx context.Context, state State) (State, error) {
 	nevermind := make(chan struct{})
 	defer close(nevermind)
 	go func() {
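With waitForStateOrContextLocked removed, waitForStateOrContext is the only context-aware wait left, and callers in this PR now return its error unwrapped. Its shape, a goroutine that broadcasts on the condition variable when the context ends plus a nevermind channel that stops that goroutine once the wait returns, can be shown with a self-contained sketch; the state names mirror the package, but the rest is a simplified stand-in rather than the coder/coder implementation:

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

type State int

const (
	StateStarting State = iota
	StateReady
	StateClosing
	StateDone
)

// ptyState is a simplified stand-in for the package's state machine.
type ptyState struct {
	cond  *sync.Cond
	state State
	err   error
}

func newPtyState() *ptyState { return &ptyState{cond: sync.NewCond(&sync.Mutex{})} }

func (s *ptyState) setState(state State, err error) {
	s.cond.L.Lock()
	defer s.cond.L.Unlock()
	if state <= s.state {
		return // states only move forward
	}
	s.state, s.err = state, err
	s.cond.Broadcast()
}

// waitForStateOrContext blocks until the target state is reached or ctx ends.
func (s *ptyState) waitForStateOrContext(ctx context.Context, state State) (State, error) {
	s.cond.L.Lock()
	defer s.cond.L.Unlock()

	nevermind := make(chan struct{})
	defer close(nevermind)
	go func() {
		select {
		case <-ctx.Done():
			s.cond.Broadcast() // wake the waiter so it can see the canceled context
		case <-nevermind: // the wait below already returned
		}
	}()

	for ctx.Err() == nil && s.state < state {
		s.cond.Wait()
	}
	if ctx.Err() != nil {
		return s.state, ctx.Err()
	}
	return s.state, s.err
}

func main() {
	s := newPtyState()
	go func() {
		time.Sleep(50 * time.Millisecond)
		s.setState(StateReady, nil)
	}()
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	state, err := s.waitForStateOrContext(ctx, StateReady)
	fmt.Println(state == StateReady, err) // true <nil>
}
```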
agent/reconnectingpty/screen.go (12 changes: 6 additions & 6 deletions)
@@ -124,7 +124,7 @@ func newScreen(ctx context.Context, cmd *pty.Cmd, options *Options, logger slog.
 // the reconnecting pty will be closed.
 func (rpty *screenReconnectingPTY) lifecycle(ctx context.Context, logger slog.Logger) {
 	rpty.timer = time.AfterFunc(attachTimeout, func() {
-		rpty.Close("reconnecting pty timeout")
+		rpty.Close(xerrors.New("reconnecting pty timeout"))
 	})
 
 	logger.Debug(ctx, "reconnecting pty ready")
@@ -134,7 +134,7 @@ func (rpty *screenReconnectingPTY) lifecycle(ctx context.Context, logger slog.Lo
 	if state < StateClosing {
 		// If we have not closed yet then the context is what unblocked us (which
 		// means the agent is shutting down) so move into the closing phase.
-		rpty.Close(reasonErr.Error())
+		rpty.Close(reasonErr)
 	}
 	rpty.timer.Stop()
 
@@ -145,7 +145,7 @@ func (rpty *screenReconnectingPTY) lifecycle(ctx context.Context, logger slog.Lo
 	}
 
 	logger.Info(ctx, "closed reconnecting pty")
-	rpty.state.setState(StateDone, xerrors.Errorf("reconnecting pty closed: %w", reasonErr))
+	rpty.state.setState(StateDone, reasonErr)
 }
 
 func (rpty *screenReconnectingPTY) Attach(ctx context.Context, _ string, conn net.Conn, height, width uint16, logger slog.Logger) error {
@@ -157,7 +157,7 @@ func (rpty *screenReconnectingPTY) Attach(ctx context.Context, _ string, conn ne
 
 	state, err := rpty.state.waitForStateOrContext(ctx, StateReady)
 	if state != StateReady {
-		return xerrors.Errorf("reconnecting pty ready wait: %w", err)
+		return err
 	}
 
 	go heartbeat(ctx, rpty.timer, rpty.timeout)
@@ -382,7 +382,7 @@ func (rpty *screenReconnectingPTY) Wait() {
 	_, _ = rpty.state.waitForState(StateClosing)
 }
 
-func (rpty *screenReconnectingPTY) Close(reason string) {
+func (rpty *screenReconnectingPTY) Close(err error) {
 	// The closing state change will be handled by the lifecycle.
-	rpty.state.setState(StateClosing, xerrors.Errorf("reconnecting pty closing: %s", reason))
+	rpty.state.setState(StateClosing, err)
 }
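Both the buffered and screen implementations start `go heartbeat(ctx, rpty.timer, rpty.timeout)` after a successful attach; the helper itself is outside this diff. A rough, self-contained sketch of the idea, under the assumption that it simply keeps resetting the inactivity timer while a connection is attached (the ticker interval here is made up for the demo):

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// heartbeat keeps pushing back the "no connections" timer while the context
// (tied to an attached connection) is alive. Sketch only: the real helper is
// not shown in this diff, and the ticker interval here is arbitrary.
func heartbeat(ctx context.Context, timer *time.Timer, timeout time.Duration) {
	ticker := time.NewTicker(timeout / 2)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return // connection detached; let the timer run out
		case <-ticker.C:
			timer.Reset(timeout)
		}
	}
}

func main() {
	timeout := 200 * time.Millisecond
	timer := time.AfterFunc(timeout, func() { fmt.Println("reconnecting pty timeout") })

	ctx, cancel := context.WithCancel(context.Background())
	go heartbeat(ctx, timer, timeout)

	time.Sleep(500 * time.Millisecond) // heartbeats keep resetting the timer
	cancel()                           // detach
	time.Sleep(400 * time.Millisecond) // now the timer fires
}
```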