From b60d811e39bd48d4fe3910efc28d0e6d6cd8cfbe Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Mon, 25 Jul 2022 13:44:21 -0700 Subject: [PATCH 1/4] Return proper exit code on ssh with TTY Signed-off-by: Spike Curtis --- agent/agent.go | 45 ++++++++++++++++++++++++++++++++------- agent/agent_test.go | 27 ++++++++++++++++++++++- pty/pty.go | 10 +++++++++ pty/pty_other.go | 29 +++++++++++++++++++++++++ pty/pty_windows.go | 30 ++++++++++++++++++++++++++ pty/ptytest/ptytest.go | 3 +-- pty/start.go | 5 +++-- pty/start_other.go | 18 +++++++--------- pty/start_other_test.go | 19 ++++++++++++++++- pty/start_windows.go | 9 ++++++-- pty/start_windows_test.go | 18 ++++++++++++++-- 11 files changed, 185 insertions(+), 28 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index eaca5aa9e2321..4bdd1d9103409 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -43,6 +43,11 @@ const ( ProtocolReconnectingPTY = "reconnecting-pty" ProtocolSSH = "ssh" ProtocolDial = "dial" + + // MagicSessionErrorCode indicates that something went wrong with the session, rather than the + // command just returning a nonzero exit code, and is chosen as an arbitrary, high number + // unlikely to shadow other exit codes, which are typically 1, 2, 3, etc. + MagicSessionErrorCode = 229 ) type Options struct { @@ -273,9 +278,17 @@ func (a *agent) init(ctx context.Context) { }, Handler: func(session ssh.Session) { err := a.handleSSHSession(session) + var exitError *exec.ExitError + if xerrors.As(err, &exitError) { + a.logger.Debug(ctx, "ssh session returned", slog.Error(exitError)) + _ = session.Exit(exitError.ExitCode()) + return + } if err != nil { a.logger.Warn(ctx, "ssh session failed", slog.Error(err)) - _ = session.Exit(1) + // This exit code is designed to be unlikely to be confused for a legit exit code + // from the process. + _ = session.Exit(MagicSessionErrorCode) return } }, @@ -403,7 +416,7 @@ func (a *agent) createCommand(ctx context.Context, rawCommand string, env []stri return cmd, nil } -func (a *agent) handleSSHSession(session ssh.Session) error { +func (a *agent) handleSSHSession(session ssh.Session) (retErr error) { cmd, err := a.createCommand(session.Context(), session.RawCommand(), session.Environ()) if err != nil { return err @@ -426,14 +439,24 @@ func (a *agent) handleSSHSession(session ssh.Session) error { if err != nil { return xerrors.Errorf("start command: %w", err) } + defer func() { + closeErr := ptty.Close() + if closeErr != nil { + a.logger.Warn(context.Background(), "failed to close tty", + slog.Error(closeErr)) + if retErr == nil { + retErr = closeErr + } + } + }() err = ptty.Resize(uint16(sshPty.Window.Height), uint16(sshPty.Window.Width)) if err != nil { return xerrors.Errorf("resize ptty: %w", err) } go func() { for win := range windowSize { - err = ptty.Resize(uint16(win.Height), uint16(win.Width)) - if err != nil { + resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width)) + if resizeErr != nil { a.logger.Warn(context.Background(), "failed to resize tty", slog.Error(err)) } } @@ -444,9 +467,15 @@ func (a *agent) handleSSHSession(session ssh.Session) error { go func() { _, _ = io.Copy(session, ptty.Output()) }() - _, _ = process.Wait() - _ = ptty.Close() - return nil + err = process.Wait() + var exitErr *exec.ExitError + // ExitErrors just mean the command we run returned a non-zero exit code, which is normal + // and not something to be concerned about. But, if it's something else, we should log it. + if err != nil && !xerrors.As(err, &exitErr) { + a.logger.Warn(context.Background(), "wait error", + slog.Error(err)) + } + return err } cmd.Stdout = session @@ -549,7 +578,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn ne go func() { // If the process dies randomly, we should // close the pty. - _, _ = process.Wait() + _ = process.Wait() rpty.Close() }() go func() { diff --git a/agent/agent_test.go b/agent/agent_test.go index 15dfac8ff304e..cefa54d037f31 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -16,6 +16,8 @@ import ( "testing" "time" + "golang.org/x/xerrors" + scp "github.com/bramvdbogaerde/go-scp" "github.com/google/uuid" "github.com/pion/udp" @@ -69,7 +71,7 @@ func TestAgent(t *testing.T) { require.True(t, strings.HasSuffix(strings.TrimSpace(string(output)), "gitssh --")) }) - t.Run("SessionTTY", func(t *testing.T) { + t.Run("SessionTTYShell", func(t *testing.T) { t.Parallel() if runtime.GOOS == "windows" { // This might be our implementation, or ConPTY itself. @@ -103,6 +105,29 @@ func TestAgent(t *testing.T) { require.NoError(t, err) }) + t.Run("SessionTTYExitCode", func(t *testing.T) { + t.Parallel() + session := setupSSHSession(t, agent.Metadata{}) + command := "areallynotrealcommand" + err := session.RequestPty("xterm", 128, 128, ssh.TerminalModes{}) + require.NoError(t, err) + ptty := ptytest.New(t) + require.NoError(t, err) + session.Stdout = ptty.Output() + session.Stderr = ptty.Output() + session.Stdin = ptty.Input() + err = session.Start(command) + require.NoError(t, err) + err = session.Wait() + exitErr := &ssh.ExitError{} + require.True(t, xerrors.As(err, &exitErr)) + if runtime.GOOS == "windows" { + assert.Equal(t, 1, exitErr.ExitStatus()) + } else { + assert.Equal(t, 127, exitErr.ExitStatus()) + } + }) + t.Run("LocalForwarding", func(t *testing.T) { t.Parallel() random, err := net.Listen("tcp", "127.0.0.1:0") diff --git a/pty/pty.go b/pty/pty.go index 7a8fe6c99edb6..a5e7adc145611 100644 --- a/pty/pty.go +++ b/pty/pty.go @@ -29,6 +29,16 @@ type PTY interface { Resize(height uint16, width uint16) error } +// Process represents a process running in a PTY +type Process interface { + + // Wait for the command to complete. Returned error is as for exec.Cmd.Wait() + Wait() error + + // Kill the command process. Returned error is as for os.Process.Kill() + Kill() error +} + // WithFlags represents a PTY whose flags can be inspected, in particular // to determine whether local echo is enabled. type WithFlags interface { diff --git a/pty/pty_other.go b/pty/pty_other.go index 26448c88beea7..23a605f36210b 100644 --- a/pty/pty_other.go +++ b/pty/pty_other.go @@ -5,6 +5,8 @@ package pty import ( "os" + "os/exec" + "runtime" "sync" "github.com/creack/pty" @@ -27,6 +29,15 @@ type otherPty struct { pty, tty *os.File } +type otherProcess struct { + pty *os.File + cmd *exec.Cmd + + // cmdDone protects access to cmdErr: anything reading cmdErr should read from cmdDone first. + cmdDone chan any + cmdErr error +} + func (p *otherPty) Input() ReadWriter { return ReadWriter{ Reader: p.tty, @@ -66,3 +77,21 @@ func (p *otherPty) Close() error { } return nil } + +func (p *otherProcess) Wait() error { + <-p.cmdDone + return p.cmdErr +} + +func (p *otherProcess) Kill() error { + return p.cmd.Process.Kill() +} + +func (p *otherProcess) waitInternal() { + // The GC can garbage collect the TTY FD before the command + // has finished running. See: + // https://github.com/creack/pty/issues/127#issuecomment-932764012 + p.cmdErr = p.cmd.Wait() + runtime.KeepAlive(p.pty) + close(p.cmdDone) +} diff --git a/pty/pty_windows.go b/pty/pty_windows.go index 93e58c4405772..f206921c42b55 100644 --- a/pty/pty_windows.go +++ b/pty/pty_windows.go @@ -5,6 +5,7 @@ package pty import ( "os" + "os/exec" "sync" "unsafe" @@ -66,6 +67,13 @@ type ptyWindows struct { closed bool } +type windowsProcess struct { + // cmdDone protects access to cmdErr: anything reading cmdErr should read from cmdDone first. + cmdDone chan any + cmdErr error + proc *os.Process +} + func (p *ptyWindows) Output() ReadWriter { return ReadWriter{ Reader: p.outputRead, @@ -111,3 +119,25 @@ func (p *ptyWindows) Close() error { return nil } + +func (p *windowsProcess) waitInternal() { + defer close(p.cmdDone) + state, err := p.proc.Wait() + if err != nil { + p.cmdErr = err + return + } + if !state.Success() { + p.cmdErr = &exec.ExitError{ProcessState: state} + return + } +} + +func (p *windowsProcess) Wait() error { + <-p.cmdDone + return p.cmdErr +} + +func (p *windowsProcess) Kill() error { + return p.proc.Kill() +} diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go index 43fbaec1109e2..dc22351416c50 100644 --- a/pty/ptytest/ptytest.go +++ b/pty/ptytest/ptytest.go @@ -5,7 +5,6 @@ import ( "bytes" "context" "io" - "os" "os/exec" "runtime" "strings" @@ -27,7 +26,7 @@ func New(t *testing.T) *PTY { return create(t, ptty, "cmd") } -func Start(t *testing.T, cmd *exec.Cmd) (*PTY, *os.Process) { +func Start(t *testing.T, cmd *exec.Cmd) (*PTY, pty.Process) { ptty, ps, err := pty.Start(cmd) require.NoError(t, err) return create(t, ptty, cmd.Args[0]), ps diff --git a/pty/start.go b/pty/start.go index d0cbcd667d7b7..385eddcd43325 100644 --- a/pty/start.go +++ b/pty/start.go @@ -1,10 +1,11 @@ package pty import ( - "os" "os/exec" ) -func Start(cmd *exec.Cmd) (PTY, *os.Process, error) { +// Start the command in a TTY. The calling code must not use cmd after passing it to the PTY, and +// instead rely on the returned Process to manage the command/process. +func Start(cmd *exec.Cmd) (PTY, Process, error) { return startPty(cmd) } diff --git a/pty/start_other.go b/pty/start_other.go index a38f8d3ce34d0..40fe97a15a696 100644 --- a/pty/start_other.go +++ b/pty/start_other.go @@ -4,7 +4,6 @@ package pty import ( - "os" "os/exec" "runtime" "strings" @@ -14,7 +13,7 @@ import ( "golang.org/x/xerrors" ) -func startPty(cmd *exec.Cmd) (PTY, *os.Process, error) { +func startPty(cmd *exec.Cmd) (PTY, Process, error) { ptty, tty, err := pty.Open() if err != nil { return nil, nil, xerrors.Errorf("open: %w", err) @@ -37,16 +36,15 @@ func startPty(cmd *exec.Cmd) (PTY, *os.Process, error) { } return nil, nil, xerrors.Errorf("start: %w", err) } - go func() { - // The GC can garbage collect the TTY FD before the command - // has finished running. See: - // https://github.com/creack/pty/issues/127#issuecomment-932764012 - _ = cmd.Wait() - runtime.KeepAlive(ptty) - }() oPty := &otherPty{ pty: ptty, tty: tty, } - return oPty, cmd.Process, nil + oProcess := &otherProcess{ + pty: ptty, + cmd: cmd, + cmdDone: make(chan any), + } + go oProcess.waitInternal() + return oPty, oProcess, nil } diff --git a/pty/start_other_test.go b/pty/start_other_test.go index 30c87935bcd69..65894e536d805 100644 --- a/pty/start_other_test.go +++ b/pty/start_other_test.go @@ -7,6 +7,10 @@ import ( "os/exec" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + "go.uber.org/goleak" "github.com/coder/coder/pty/ptytest" @@ -20,7 +24,20 @@ func TestStart(t *testing.T) { t.Parallel() t.Run("Echo", func(t *testing.T) { t.Parallel() - pty, _ := ptytest.Start(t, exec.Command("echo", "test")) + pty, ps := ptytest.Start(t, exec.Command("echo", "test")) pty.ExpectMatch("test") + err := ps.Wait() + require.NoError(t, err) + }) + + t.Run("Kill", func(t *testing.T) { + t.Parallel() + _, ps := ptytest.Start(t, exec.Command("sleep", "30")) + err := ps.Kill() + assert.NoError(t, err) + err = ps.Wait() + var exitErr *exec.ExitError + require.True(t, xerrors.As(err, &exitErr)) + assert.Equal(t, -1, exitErr.ExitCode()) }) } diff --git a/pty/start_windows.go b/pty/start_windows.go index 1019a969aef2c..d638e5cdd1cc2 100644 --- a/pty/start_windows.go +++ b/pty/start_windows.go @@ -16,7 +16,7 @@ import ( // Allocates a PTY and starts the specified command attached to it. // See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session#creating-the-hosted-process -func startPty(cmd *exec.Cmd) (PTY, *os.Process, error) { +func startPty(cmd *exec.Cmd) (PTY, Process, error) { fullPath, err := exec.LookPath(cmd.Path) if err != nil { return nil, nil, err @@ -83,7 +83,12 @@ func startPty(cmd *exec.Cmd) (PTY, *os.Process, error) { if err != nil { return nil, nil, xerrors.Errorf("find process %d: %w", processInfo.ProcessId, err) } - return pty, process, nil + wp := &windowsProcess{ + cmdDone: make(chan any), + proc: process, + } + go wp.waitInternal() + return pty, wp, nil } // Taken from: https://github.com/microsoft/hcsshim/blob/7fbdca16f91de8792371ba22b7305bf4ca84170a/internal/exec/exec.go#L476 diff --git a/pty/start_windows_test.go b/pty/start_windows_test.go index d0398d0dec019..fc8f644d13b0d 100644 --- a/pty/start_windows_test.go +++ b/pty/start_windows_test.go @@ -8,8 +8,10 @@ import ( "testing" "github.com/coder/coder/pty/ptytest" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/goleak" + "golang.org/x/xerrors" ) func TestMain(m *testing.M) { @@ -20,13 +22,25 @@ func TestStart(t *testing.T) { t.Parallel() t.Run("Echo", func(t *testing.T) { t.Parallel() - pty, _ := ptytest.Start(t, exec.Command("cmd.exe", "/c", "echo", "test")) + pty, ps := ptytest.Start(t, exec.Command("cmd.exe", "/c", "echo", "test")) pty.ExpectMatch("test") + err := ps.Wait() + require.NoError(t, err) }) t.Run("Resize", func(t *testing.T) { t.Parallel() - pty, _ := ptytest.Start(t, exec.Command("cmd.exe")) + pty := ptytest.Start(t, exec.Command("cmd.exe")) err := pty.Resize(100, 50) require.NoError(t, err) }) + t.Run("Kill", func(t *testing.T) { + t.Parallel() + pty := ptytest.Start(t, exec.Command("cmd.exe")) + err := ps.Kill() + assert.NoError(t, err) + err = ps.Wait() + var exitErr *exec.ExitError + require.True(t, xerrors.As(err, &exitErr)) + assert.Equal(t, -1, exitErr.ExitCode()) + }) } From ca93544ec7cf76b0bb521d08a625e31bb34f406d Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 27 Jul 2022 10:57:31 -0700 Subject: [PATCH 2/4] Fix Windows tests Signed-off-by: Spike Curtis --- pty/start_windows_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pty/start_windows_test.go b/pty/start_windows_test.go index fc8f644d13b0d..7699ec6da4350 100644 --- a/pty/start_windows_test.go +++ b/pty/start_windows_test.go @@ -29,13 +29,13 @@ func TestStart(t *testing.T) { }) t.Run("Resize", func(t *testing.T) { t.Parallel() - pty := ptytest.Start(t, exec.Command("cmd.exe")) + pty, _ := ptytest.Start(t, exec.Command("cmd.exe")) err := pty.Resize(100, 50) require.NoError(t, err) }) t.Run("Kill", func(t *testing.T) { t.Parallel() - pty := ptytest.Start(t, exec.Command("cmd.exe")) + pty, ps := ptytest.Start(t, exec.Command("cmd.exe")) err := ps.Kill() assert.NoError(t, err) err = ps.Wait() From c378faabbcff70cd710fce65472176e69c550756 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 27 Jul 2022 11:41:57 -0700 Subject: [PATCH 3/4] Fix windows tests Signed-off-by: Spike Curtis --- pty/start_windows_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pty/start_windows_test.go b/pty/start_windows_test.go index 7699ec6da4350..c760d53b2f6a3 100644 --- a/pty/start_windows_test.go +++ b/pty/start_windows_test.go @@ -35,7 +35,7 @@ func TestStart(t *testing.T) { }) t.Run("Kill", func(t *testing.T) { t.Parallel() - pty, ps := ptytest.Start(t, exec.Command("cmd.exe")) + _, ps := ptytest.Start(t, exec.Command("cmd.exe")) err := ps.Kill() assert.NoError(t, err) err = ps.Wait() From ee781bd115a95b75ac5e624723f406bddf9b32ea Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 27 Jul 2022 12:03:26 -0700 Subject: [PATCH 4/4] Check for nonzero exit code on kill Signed-off-by: Spike Curtis --- pty/start_other_test.go | 2 +- pty/start_windows_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pty/start_other_test.go b/pty/start_other_test.go index 65894e536d805..2af8a708b9dca 100644 --- a/pty/start_other_test.go +++ b/pty/start_other_test.go @@ -38,6 +38,6 @@ func TestStart(t *testing.T) { err = ps.Wait() var exitErr *exec.ExitError require.True(t, xerrors.As(err, &exitErr)) - assert.Equal(t, -1, exitErr.ExitCode()) + assert.NotEqual(t, 0, exitErr.ExitCode()) }) } diff --git a/pty/start_windows_test.go b/pty/start_windows_test.go index c760d53b2f6a3..edbbd5dd99c3b 100644 --- a/pty/start_windows_test.go +++ b/pty/start_windows_test.go @@ -41,6 +41,6 @@ func TestStart(t *testing.T) { err = ps.Wait() var exitErr *exec.ExitError require.True(t, xerrors.As(err, &exitErr)) - assert.Equal(t, -1, exitErr.ExitCode()) + assert.NotEqual(t, 0, exitErr.ExitCode()) }) }