Skip to content

fix(agent): More protection for lost output of SSH PTY commands #6833

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -1125,13 +1125,28 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
go func() {
_, _ = io.Copy(ptty.Input(), session)
}()

// In low parallelism scenarios, the command may exit and we may close
// the pty before the output copy has started. This can result in the
// output being lost. To avoid this, we wait for the output copy to
// start before waiting for the command to exit. This ensures that the
// output copy goroutine will be scheduled before calling close on the
// pty. There is still a risk of data loss if a command produces a lot
// of output, see TestAgent_Session_TTY_HugeOutputIsNotLost (skipped).
outputCopyStarted := make(chan struct{})
ptyOutput := func() io.Reader {
defer close(outputCopyStarted)
return ptty.Output()
}
wg.Add(1)
go func() {
// Ensure data is flushed to session on command exit, if we
// close the session too soon, we might lose data.
defer wg.Done()
_, _ = io.Copy(session, ptty.Output())
_, _ = io.Copy(session, ptyOutput())
}()
<-outputCopyStarted

err = process.Wait()
var exitErr *exec.ExitError
// ExitErrors just mean the command we run returned a non-zero exit code, which is normal
Expand Down
62 changes: 61 additions & 1 deletion agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,9 +373,12 @@ func TestAgent_Session_TTY_FastCommandHasOutput(t *testing.T) {

var stdout bytes.Buffer
// NOTE(mafredri): Increase iterations to increase chance of failure,
// assuming bug is present.
// assuming bug is present. Limiting GOMAXPROCS further
// increases the chance of failure.
// Using 1000 iterations is basically a guaranteed failure (but let's
// not increase test times needlessly).
// Limit GOMAXPROCS (e.g. `export GOMAXPROCS=1`) to further increase
// chance of failure. Also -race helps.
for i := 0; i < 5; i++ {
func() {
stdout.Reset()
Expand All @@ -399,6 +402,63 @@ func TestAgent_Session_TTY_FastCommandHasOutput(t *testing.T) {
}
}

// TestAgent_Session_TTY_HugeOutputIsNotLost runs an `echo` of a ~6KB string
// over an SSH session with a PTY requested and asserts that the captured
// stdout contains the whole string. The test is currently skipped on every
// platform (see the t.Skip calls below) because it demonstrates a known bug.
func TestAgent_Session_TTY_HugeOutputIsNotLost(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
// This might be our implementation, or ConPTY itself.
// It's difficult to find extensive tests for it, so
// it seems like it could be either.
t.Skip("ConPTY appears to be inconsistent on Windows.")
}
// Unconditional skip: everything below only runs if this line is removed.
t.Skip("This test proves we have a bug where parts of large output on a PTY can be lost after the command exits, skipped to avoid test failures.")

// This test is here to prove we have a bug where quickly executing
// commands (with TTY) don't flush their output to the SSH session. This is
// due to the pty being closed before all the output has been copied, but
// protecting against this requires a non-trivial rewrite of the output
// processing (or figuring out a way to put the pty in a mode where this
// does not happen).
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
//nolint:dogsled
conn, _, _, _, _ := setupAgent(t, agentsdk.Metadata{}, 0)
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()

ptty := ptytest.New(t)

// stdout is shared across iterations and reset at the start of each one.
var stdout bytes.Buffer
// NOTE(mafredri): Increase iterations to increase chance of failure,
// assuming bug is present.
// Using 10 iterations is basically a guaranteed failure (but let's
// not increase test times needlessly). Run with -race and do not
// limit parallelism (`export GOMAXPROCS=10`) to increase the chance
// of failure.
for i := 0; i < 1; i++ {
// The closure scopes the deferred session.Close to a single iteration.
func() {
stdout.Reset()

session, err := sshClient.NewSession()
require.NoError(t, err)
defer session.Close()
err = session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
require.NoError(t, err)

// Only stdout is captured for the assertion; stderr and stdin are
// wired to the test pty.
session.Stdout = &stdout
session.Stderr = ptty.Output()
session.Stdin = ptty.Input()
want := strings.Repeat("wazzup", 1024+1) // ~6KB, +1 because 1024 is a common buffer size.
err = session.Start("echo " + want)
require.NoError(t, err)

err = session.Wait()
require.NoError(t, err)
// The full string must be present in the captured output.
require.Contains(t, stdout.String(), want, "should output entire greeting")
}()
}
}

//nolint:paralleltest // This test reserves a port.
func TestAgent_TCPLocalForwarding(t *testing.T) {
random, err := net.Listen("tcp", "127.0.0.1:0")
Expand Down