Skip to content

Commit 891bbda

Browse files
authored
fix(agent): More protection for lost output of SSH PTY commands (#6833)
Fixes #6656 (part 2)
1 parent 1645281 commit 891bbda

File tree

2 files changed

+77
-2
lines changed

2 files changed

+77
-2
lines changed

agent/agent.go

+16-1
Original file line numberDiff line numberDiff line change
@@ -1125,13 +1125,28 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
11251125
go func() {
11261126
_, _ = io.Copy(ptty.Input(), session)
11271127
}()
1128+
1129+
// In low parallelism scenarios, the command may exit and we may close
1130+
// the pty before the output copy has started. This can result in the
1131+
// output being lost. To avoid this, we wait for the output copy to
1132+
// start before waiting for the command to exit. This ensures that the
1133+
// output copy goroutine will be scheduled before calling close on the
1134+
// pty. There is still a risk of data loss if a command produces a lot
1135+
// of output, see TestAgent_Session_TTY_HugeOutputIsNotLost (skipped).
1136+
outputCopyStarted := make(chan struct{})
1137+
ptyOutput := func() io.Reader {
1138+
defer close(outputCopyStarted)
1139+
return ptty.Output()
1140+
}
11281141
wg.Add(1)
11291142
go func() {
11301143
// Ensure data is flushed to session on command exit, if we
11311144
// close the session too soon, we might lose data.
11321145
defer wg.Done()
1133-
_, _ = io.Copy(session, ptty.Output())
1146+
_, _ = io.Copy(session, ptyOutput())
11341147
}()
1148+
<-outputCopyStarted
1149+
11351150
err = process.Wait()
11361151
var exitErr *exec.ExitError
11371152
// ExitErrors just mean the command we run returned a non-zero exit code, which is normal

agent/agent_test.go

+61-1
Original file line numberDiff line numberDiff line change
@@ -373,9 +373,12 @@ func TestAgent_Session_TTY_FastCommandHasOutput(t *testing.T) {
373373

374374
var stdout bytes.Buffer
375375
// NOTE(mafredri): Increase iterations to increase chance of failure,
376-
// assuming bug is present.
376+
// assuming bug is present. Limiting GOMAXPROCS further
377+
// increases the chance of failure.
377378
// Using 1000 iterations is basically a guaranteed failure (but let's
378379
// not increase test times needlessly).
380+
// Limit GOMAXPROCS (e.g. `export GOMAXPROCS=1`) to further increase
381+
// chance of failure. Also -race helps.
379382
for i := 0; i < 5; i++ {
380383
func() {
381384
stdout.Reset()
@@ -399,6 +402,63 @@ func TestAgent_Session_TTY_FastCommandHasOutput(t *testing.T) {
399402
}
400403
}
401404

405+
func TestAgent_Session_TTY_HugeOutputIsNotLost(t *testing.T) {
406+
t.Parallel()
407+
if runtime.GOOS == "windows" {
408+
// This might be our implementation, or ConPTY itself.
409+
// It's difficult to find extensive tests for it, so
410+
// it seems like it could be either.
411+
t.Skip("ConPTY appears to be inconsistent on Windows.")
412+
}
413+
t.Skip("This test proves we have a bug where parts of large output on a PTY can be lost after the command exits, skipped to avoid test failures.")
414+
415+
// This test is here to prevent prove we have a bug where quickly executing
416+
// commands (with TTY) don't flush their output to the SSH session. This is
417+
// due to the pty being closed before all the output has been copied, but
418+
// protecting against this requires a non-trivial rewrite of the output
419+
// processing (or figuring out a way to put the pty in a mode where this
420+
// does not happen).
421+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
422+
defer cancel()
423+
//nolint:dogsled
424+
conn, _, _, _, _ := setupAgent(t, agentsdk.Metadata{}, 0)
425+
sshClient, err := conn.SSHClient(ctx)
426+
require.NoError(t, err)
427+
defer sshClient.Close()
428+
429+
ptty := ptytest.New(t)
430+
431+
var stdout bytes.Buffer
432+
// NOTE(mafredri): Increase iterations to increase chance of failure,
433+
// assuming bug is present.
434+
// Using 10 iterations is basically a guaranteed failure (but let's
435+
// not increase test times needlessly). Run with -race and do not
436+
// limit parallelism (`export GOMAXPROCS=10`) to increase the chance
437+
// of failure.
438+
for i := 0; i < 1; i++ {
439+
func() {
440+
stdout.Reset()
441+
442+
session, err := sshClient.NewSession()
443+
require.NoError(t, err)
444+
defer session.Close()
445+
err = session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
446+
require.NoError(t, err)
447+
448+
session.Stdout = &stdout
449+
session.Stderr = ptty.Output()
450+
session.Stdin = ptty.Input()
451+
want := strings.Repeat("wazzup", 1024+1) // ~6KB, +1 because 1024 is a common buffer size.
452+
err = session.Start("echo " + want)
453+
require.NoError(t, err)
454+
455+
err = session.Wait()
456+
require.NoError(t, err)
457+
require.Contains(t, stdout.String(), want, "should output entire greeting")
458+
}()
459+
}
460+
}
461+
402462
//nolint:paralleltest // This test reserves a port.
403463
func TestAgent_TCPLocalForwarding(t *testing.T) {
404464
random, err := net.Listen("tcp", "127.0.0.1:0")

0 commit comments

Comments
 (0)