diff --git a/agent/agent.go b/agent/agent.go index e614a28c8905c..33f3c5d16d493 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -844,6 +844,7 @@ func (a *agent) init(ctx context.Context) { _ = session.Exit(MagicSessionErrorCode) return } + _ = session.Exit(0) }, HostSigners: []ssh.Signer{randomSigner}, LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool { @@ -1100,7 +1101,9 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) { if err != nil { return xerrors.Errorf("start command: %w", err) } + var wg sync.WaitGroup defer func() { + defer wg.Wait() closeErr := ptty.Close() if closeErr != nil { a.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr)) @@ -1117,10 +1120,16 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) { } } }() + // We don't add input copy to wait group because + // it won't return until the session is closed. go func() { _, _ = io.Copy(ptty.Input(), session) }() + wg.Add(1) go func() { + // Ensure data is flushed to session on command exit, if we + // close the session too soon, we might lose data. + defer wg.Done() _, _ = io.Copy(session, ptty.Output()) }() err = process.Wait() diff --git a/agent/agent_test.go b/agent/agent_test.go index d163d7bf60dd2..c40980f5dc647 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -348,6 +348,57 @@ func TestAgent_Session_TTY_Hushlogin(t *testing.T) { require.NotContains(t, stdout.String(), wantNotMOTD, "should not show motd") } +func TestAgent_Session_TTY_FastCommandHasOutput(t *testing.T) { + t.Parallel() + if runtime.GOOS == "windows" { + // This might be our implementation, or ConPTY itself. + // It's difficult to find extensive tests for it, so + // it seems like it could be either. + t.Skip("ConPTY appears to be inconsistent on Windows.") + } + + // This test is here to prevent regressions where quickly executing + // commands (with TTY) don't flush their output to the SSH session. + // + // See: https://github.com/coder/coder/issues/6656 + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + //nolint:dogsled + conn, _, _, _, _ := setupAgent(t, agentsdk.Metadata{}, 0) + sshClient, err := conn.SSHClient(ctx) + require.NoError(t, err) + defer sshClient.Close() + + ptty := ptytest.New(t) + + var stdout bytes.Buffer + // NOTE(mafredri): Increase iterations to increase chance of failure, + // assuming bug is present. + // Using 1000 iterations is basically a guaranteed failure (but let's + // not increase test times needlessly). + for i := 0; i < 5; i++ { + func() { + stdout.Reset() + + session, err := sshClient.NewSession() + require.NoError(t, err) + defer session.Close() + err = session.RequestPty("xterm", 128, 128, ssh.TerminalModes{}) + require.NoError(t, err) + + session.Stdout = &stdout + session.Stderr = ptty.Output() + session.Stdin = ptty.Input() + err = session.Start("echo wazzup") + require.NoError(t, err) + + err = session.Wait() + require.NoError(t, err) + require.Contains(t, stdout.String(), "wazzup", "should output greeting") + }() + } +} + //nolint:paralleltest // This test reserves a port. func TestAgent_TCPLocalForwarding(t *testing.T) { random, err := net.Listen("tcp", "127.0.0.1:0")