Skip to content

Commit fa4abb2

Browse files
committed
fix(agent): Prevent SSH TTYs from losing command output on exit
Fixes #6656
1 parent 622fc6d commit fa4abb2

File tree

2 files changed

+58
-0
lines changed

2 files changed

+58
-0
lines changed

agent/agent.go

+7
Original file line numberDiff line numberDiff line change
@@ -844,6 +844,7 @@ func (a *agent) init(ctx context.Context) {
844844
_ = session.Exit(MagicSessionErrorCode)
845845
return
846846
}
847+
_ = session.Exit(0)
847848
},
848849
HostSigners: []ssh.Signer{randomSigner},
849850
LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
@@ -1100,6 +1101,7 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
11001101
if err != nil {
11011102
return xerrors.Errorf("start command: %w", err)
11021103
}
1104+
var wg sync.WaitGroup
11031105
defer func() {
11041106
closeErr := ptty.Close()
11051107
if closeErr != nil {
@@ -1108,6 +1110,7 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
11081110
retErr = closeErr
11091111
}
11101112
}
1113+
wg.Wait()
11111114
}()
11121115
go func() {
11131116
for win := range windowSize {
@@ -1117,10 +1120,14 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
11171120
}
11181121
}
11191122
}()
1123+
// We don't add input copy to wait group because
1124+
// it won't return until the session is closed.
11201125
go func() {
11211126
_, _ = io.Copy(ptty.Input(), session)
11221127
}()
1128+
wg.Add(1)
11231129
go func() {
1130+
defer wg.Done()
11241131
_, _ = io.Copy(session, ptty.Output())
11251132
}()
11261133
err = process.Wait()

agent/agent_test.go

+51
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,57 @@ func TestAgent_Session_TTY_Hushlogin(t *testing.T) {
348348
require.NotContains(t, stdout.String(), wantNotMOTD, "should not show motd")
349349
}
350350

351+
func TestAgent_Session_TTY_FastCommandHasOutput(t *testing.T) {
352+
t.Parallel()
353+
if runtime.GOOS == "windows" {
354+
// This might be our implementation, or ConPTY itself.
355+
// It's difficult to find extensive tests for it, so
356+
// it seems like it could be either.
357+
t.Skip("ConPTY appears to be inconsistent on Windows.")
358+
}
359+
360+
// This test is here to prevent regressions where quickly executing
361+
// commands (with TTY) don't flush their output to the SSH session.
362+
//
363+
// See: https://github.com/coder/coder/issues/6656
364+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
365+
defer cancel()
366+
//nolint:dogsled
367+
conn, _, _, _, _ := setupAgent(t, agentsdk.Metadata{}, 0)
368+
sshClient, err := conn.SSHClient(ctx)
369+
require.NoError(t, err)
370+
defer sshClient.Close()
371+
372+
ptty := ptytest.New(t)
373+
374+
var stdout bytes.Buffer
375+
// NOTE(mafredri): Increase iterations to increase chance of failure,
376+
// assuming bug is present.
377+
// Using 1000 iterations is basically a guaranteed failure (but let's
378+
// not increase test times needlessly).
379+
for i := 0; i < 5; i++ {
380+
func() {
381+
stdout.Reset()
382+
383+
session, err := sshClient.NewSession()
384+
require.NoError(t, err)
385+
defer session.Close()
386+
err = session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
387+
require.NoError(t, err)
388+
389+
session.Stdout = &stdout
390+
session.Stderr = ptty.Output()
391+
session.Stdin = ptty.Input()
392+
err = session.Start("echo wazzup")
393+
require.NoError(t, err)
394+
395+
err = session.Wait()
396+
require.NoError(t, err)
397+
require.Contains(t, stdout.String(), "wazzup", "should output greeting")
398+
}()
399+
}
400+
}
401+
351402
//nolint:paralleltest // This test reserves a port.
352403
func TestAgent_TCPLocalForwarding(t *testing.T) {
353404
random, err := net.Listen("tcp", "127.0.0.1:0")

0 commit comments

Comments
 (0)