diff --git a/agent/agent.go b/agent/agent.go index 33f3c5d16d493..1efb3e88f3dbe 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -1125,13 +1125,28 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) { go func() { _, _ = io.Copy(ptty.Input(), session) }() + + // In low parallelism scenarios, the command may exit and we may close + // the pty before the output copy has started. This can result in the + // output being lost. To avoid this, we wait for the output copy to + // start before waiting for the command to exit. This ensures that the + // output copy goroutine will be scheduled before calling close on the + // pty. There is still a risk of data loss if a command produces a lot + // of output, see TestAgent_Session_TTY_HugeOutputIsNotLost (skipped). + outputCopyStarted := make(chan struct{}) + ptyOutput := func() io.Reader { + defer close(outputCopyStarted) + return ptty.Output() + } wg.Add(1) go func() { // Ensure data is flushed to session on command exit, if we // close the session too soon, we might lose data. defer wg.Done() - _, _ = io.Copy(session, ptty.Output()) + _, _ = io.Copy(session, ptyOutput()) }() + <-outputCopyStarted + err = process.Wait() var exitErr *exec.ExitError // ExitErrors just mean the command we run returned a non-zero exit code, which is normal diff --git a/agent/agent_test.go b/agent/agent_test.go index c40980f5dc647..10ccbe51242a8 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -373,9 +373,12 @@ func TestAgent_Session_TTY_FastCommandHasOutput(t *testing.T) { var stdout bytes.Buffer // NOTE(mafredri): Increase iterations to increase chance of failure, - // assuming bug is present. + // assuming bug is present. Limiting GOMAXPROCS further + // increases the chance of failure. // Using 1000 iterations is basically a guaranteed failure (but let's // not increase test times needlessly). + // Limit GOMAXPROCS (e.g. `export GOMAXPROCS=1`) to further increase + // chance of failure. Also -race helps. 
for i := 0; i < 5; i++ { func() { stdout.Reset() @@ -399,6 +402,63 @@ func TestAgent_Session_TTY_FastCommandHasOutput(t *testing.T) { } } +func TestAgent_Session_TTY_HugeOutputIsNotLost(t *testing.T) { + t.Parallel() + if runtime.GOOS == "windows" { + // This might be our implementation, or ConPTY itself. + // It's difficult to find extensive tests for it, so + // it seems like it could be either. + t.Skip("ConPTY appears to be inconsistent on Windows.") + } + t.Skip("This test proves we have a bug where parts of large output on a PTY can be lost after the command exits, skipped to avoid test failures.") + + // This test is here to prove we have a bug where quickly executing + // commands (with TTY) don't flush their output to the SSH session. This is + // due to the pty being closed before all the output has been copied, but + // protecting against this requires a non-trivial rewrite of the output + // processing (or figuring out a way to put the pty in a mode where this + // does not happen). + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + //nolint:dogsled + conn, _, _, _, _ := setupAgent(t, agentsdk.Metadata{}, 0) + sshClient, err := conn.SSHClient(ctx) + require.NoError(t, err) + defer sshClient.Close() + + ptty := ptytest.New(t) + + var stdout bytes.Buffer + // NOTE(mafredri): Increase iterations to increase chance of failure, + // assuming bug is present. + // Using 10 iterations is basically a guaranteed failure (but let's + // not increase test times needlessly). Run with -race and do not + // limit parallelism (`export GOMAXPROCS=10`) to increase the chance + // of failure. 
+ for i := 0; i < 1; i++ { + func() { + stdout.Reset() + + session, err := sshClient.NewSession() + require.NoError(t, err) + defer session.Close() + err = session.RequestPty("xterm", 128, 128, ssh.TerminalModes{}) + require.NoError(t, err) + + session.Stdout = &stdout + session.Stderr = ptty.Output() + session.Stdin = ptty.Input() + want := strings.Repeat("wazzup", 1024+1) // ~6KB, +1 because 1024 is a common buffer size. + err = session.Start("echo " + want) + require.NoError(t, err) + + err = session.Wait() + require.NoError(t, err) + require.Contains(t, stdout.String(), want, "should output entire greeting") + }() + } +} + //nolint:paralleltest // This test reserves a port. func TestAgent_TCPLocalForwarding(t *testing.T) { random, err := net.Listen("tcp", "127.0.0.1:0")