Skip to content

Return proper exit code on ssh with TTY #3192

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 37 additions & 8 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ const (
ProtocolReconnectingPTY = "reconnecting-pty"
ProtocolSSH = "ssh"
ProtocolDial = "dial"

// MagicSessionErrorCode indicates that something went wrong with the session, rather than the
// command just returning a nonzero exit code, and is chosen as an arbitrary, high number
// unlikely to shadow other exit codes, which are typically 1, 2, 3, etc.
MagicSessionErrorCode = 229
)

type Options struct {
Expand Down Expand Up @@ -273,9 +278,17 @@ func (a *agent) init(ctx context.Context) {
},
Handler: func(session ssh.Session) {
err := a.handleSSHSession(session)
var exitError *exec.ExitError
if xerrors.As(err, &exitError) {
a.logger.Debug(ctx, "ssh session returned", slog.Error(exitError))
_ = session.Exit(exitError.ExitCode())
return
}
if err != nil {
a.logger.Warn(ctx, "ssh session failed", slog.Error(err))
_ = session.Exit(1)
// This exit code is designed to be unlikely to be confused for a legit exit code
// from the process.
_ = session.Exit(MagicSessionErrorCode)
return
}
},
Expand Down Expand Up @@ -403,7 +416,7 @@ func (a *agent) createCommand(ctx context.Context, rawCommand string, env []stri
return cmd, nil
}

func (a *agent) handleSSHSession(session ssh.Session) error {
func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
cmd, err := a.createCommand(session.Context(), session.RawCommand(), session.Environ())
if err != nil {
return err
Expand All @@ -426,14 +439,24 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
if err != nil {
return xerrors.Errorf("start command: %w", err)
}
defer func() {
closeErr := ptty.Close()
if closeErr != nil {
a.logger.Warn(context.Background(), "failed to close tty",
slog.Error(closeErr))
if retErr == nil {
retErr = closeErr
}
}
}()
err = ptty.Resize(uint16(sshPty.Window.Height), uint16(sshPty.Window.Width))
if err != nil {
return xerrors.Errorf("resize ptty: %w", err)
}
go func() {
for win := range windowSize {
err = ptty.Resize(uint16(win.Height), uint16(win.Width))
if err != nil {
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width))
if resizeErr != nil {
a.logger.Warn(context.Background(), "failed to resize tty", slog.Error(err))
}
}
Expand All @@ -444,9 +467,15 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
go func() {
_, _ = io.Copy(session, ptty.Output())
}()
_, _ = process.Wait()
_ = ptty.Close()
return nil
err = process.Wait()
var exitErr *exec.ExitError
// ExitErrors just mean the command we run returned a non-zero exit code, which is normal
// and not something to be concerned about. But, if it's something else, we should log it.
if err != nil && !xerrors.As(err, &exitErr) {
a.logger.Warn(context.Background(), "wait error",
slog.Error(err))
}
return err
}

cmd.Stdout = session
Expand Down Expand Up @@ -549,7 +578,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn ne
go func() {
// If the process dies randomly, we should
// close the pty.
_, _ = process.Wait()
_ = process.Wait()
rpty.Close()
}()
go func() {
Expand Down
27 changes: 26 additions & 1 deletion agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import (
"testing"
"time"

"golang.org/x/xerrors"

scp "github.com/bramvdbogaerde/go-scp"
"github.com/google/uuid"
"github.com/pion/udp"
Expand Down Expand Up @@ -69,7 +71,7 @@ func TestAgent(t *testing.T) {
require.True(t, strings.HasSuffix(strings.TrimSpace(string(output)), "gitssh --"))
})

t.Run("SessionTTY", func(t *testing.T) {
t.Run("SessionTTYShell", func(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
// This might be our implementation, or ConPTY itself.
Expand Down Expand Up @@ -103,6 +105,29 @@ func TestAgent(t *testing.T) {
require.NoError(t, err)
})

t.Run("SessionTTYExitCode", func(t *testing.T) {
t.Parallel()
session := setupSSHSession(t, agent.Metadata{})
command := "areallynotrealcommand"
err := session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
require.NoError(t, err)
ptty := ptytest.New(t)
require.NoError(t, err)
session.Stdout = ptty.Output()
session.Stderr = ptty.Output()
session.Stdin = ptty.Input()
err = session.Start(command)
require.NoError(t, err)
err = session.Wait()
exitErr := &ssh.ExitError{}
require.True(t, xerrors.As(err, &exitErr))
if runtime.GOOS == "windows" {
assert.Equal(t, 1, exitErr.ExitStatus())
} else {
assert.Equal(t, 127, exitErr.ExitStatus())
}
})

t.Run("LocalForwarding", func(t *testing.T) {
t.Parallel()
random, err := net.Listen("tcp", "127.0.0.1:0")
Expand Down
10 changes: 10 additions & 0 deletions pty/pty.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ type PTY interface {
Resize(height uint16, width uint16) error
}

// Process represents a process running in a PTY
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could be helpful to explain why this abstraction is necessary.

type Process interface {

// Wait for the command to complete. Returned error is as for exec.Cmd.Wait()
Wait() error

// Kill the command process. Returned error is as for os.Process.Kill()
Kill() error
}

// WithFlags represents a PTY whose flags can be inspected, in particular
// to determine whether local echo is enabled.
type WithFlags interface {
Expand Down
29 changes: 29 additions & 0 deletions pty/pty_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ package pty

import (
"os"
"os/exec"
"runtime"
"sync"

"github.com/creack/pty"
Expand All @@ -27,6 +29,15 @@ type otherPty struct {
pty, tty *os.File
}

type otherProcess struct {
pty *os.File
cmd *exec.Cmd

// cmdDone protects access to cmdErr: anything reading cmdErr should read from cmdDone first.
cmdDone chan any
cmdErr error
}

func (p *otherPty) Input() ReadWriter {
return ReadWriter{
Reader: p.tty,
Expand Down Expand Up @@ -66,3 +77,21 @@ func (p *otherPty) Close() error {
}
return nil
}

func (p *otherProcess) Wait() error {
<-p.cmdDone
return p.cmdErr
}

func (p *otherProcess) Kill() error {
return p.cmd.Process.Kill()
}

func (p *otherProcess) waitInternal() {
// The GC can garbage collect the TTY FD before the command
// has finished running. See:
// https://github.com/creack/pty/issues/127#issuecomment-932764012
p.cmdErr = p.cmd.Wait()
runtime.KeepAlive(p.pty)
close(p.cmdDone)
}
30 changes: 30 additions & 0 deletions pty/pty_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package pty

import (
"os"
"os/exec"
"sync"
"unsafe"

Expand Down Expand Up @@ -66,6 +67,13 @@ type ptyWindows struct {
closed bool
}

type windowsProcess struct {
// cmdDone protects access to cmdErr: anything reading cmdErr should read from cmdDone first.
cmdDone chan any
cmdErr error
proc *os.Process
}

func (p *ptyWindows) Output() ReadWriter {
return ReadWriter{
Reader: p.outputRead,
Expand Down Expand Up @@ -111,3 +119,25 @@ func (p *ptyWindows) Close() error {

return nil
}

func (p *windowsProcess) waitInternal() {
defer close(p.cmdDone)
state, err := p.proc.Wait()
if err != nil {
p.cmdErr = err
return
}
if !state.Success() {
p.cmdErr = &exec.ExitError{ProcessState: state}
return
}
}

func (p *windowsProcess) Wait() error {
<-p.cmdDone
return p.cmdErr
}

func (p *windowsProcess) Kill() error {
return p.proc.Kill()
}
3 changes: 1 addition & 2 deletions pty/ptytest/ptytest.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"bytes"
"context"
"io"
"os"
"os/exec"
"runtime"
"strings"
Expand All @@ -27,7 +26,7 @@ func New(t *testing.T) *PTY {
return create(t, ptty, "cmd")
}

func Start(t *testing.T, cmd *exec.Cmd) (*PTY, *os.Process) {
func Start(t *testing.T, cmd *exec.Cmd) (*PTY, pty.Process) {
ptty, ps, err := pty.Start(cmd)
require.NoError(t, err)
return create(t, ptty, cmd.Args[0]), ps
Expand Down
5 changes: 3 additions & 2 deletions pty/start.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package pty

import (
"os"
"os/exec"
)

func Start(cmd *exec.Cmd) (PTY, *os.Process, error) {
// Start the command in a TTY. The calling code must not use cmd after passing it to the PTY, and
// instead rely on the returned Process to manage the command/process.
func Start(cmd *exec.Cmd) (PTY, Process, error) {
return startPty(cmd)
}
18 changes: 8 additions & 10 deletions pty/start_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package pty

import (
"os"
"os/exec"
"runtime"
"strings"
Expand All @@ -14,7 +13,7 @@ import (
"golang.org/x/xerrors"
)

func startPty(cmd *exec.Cmd) (PTY, *os.Process, error) {
func startPty(cmd *exec.Cmd) (PTY, Process, error) {
ptty, tty, err := pty.Open()
if err != nil {
return nil, nil, xerrors.Errorf("open: %w", err)
Expand All @@ -37,16 +36,15 @@ func startPty(cmd *exec.Cmd) (PTY, *os.Process, error) {
}
return nil, nil, xerrors.Errorf("start: %w", err)
}
go func() {
// The GC can garbage collect the TTY FD before the command
// has finished running. See:
// https://github.com/creack/pty/issues/127#issuecomment-932764012
_ = cmd.Wait()
runtime.KeepAlive(ptty)
}()
oPty := &otherPty{
pty: ptty,
tty: tty,
}
return oPty, cmd.Process, nil
oProcess := &otherProcess{
pty: ptty,
cmd: cmd,
cmdDone: make(chan any),
}
go oProcess.waitInternal()
return oPty, oProcess, nil
}
19 changes: 18 additions & 1 deletion pty/start_other_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ import (
"os/exec"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"

"go.uber.org/goleak"

"github.com/coder/coder/pty/ptytest"
Expand All @@ -20,7 +24,20 @@ func TestStart(t *testing.T) {
t.Parallel()
t.Run("Echo", func(t *testing.T) {
t.Parallel()
pty, _ := ptytest.Start(t, exec.Command("echo", "test"))
pty, ps := ptytest.Start(t, exec.Command("echo", "test"))
pty.ExpectMatch("test")
err := ps.Wait()
require.NoError(t, err)
})

t.Run("Kill", func(t *testing.T) {
t.Parallel()
_, ps := ptytest.Start(t, exec.Command("sleep", "30"))
err := ps.Kill()
assert.NoError(t, err)
err = ps.Wait()
var exitErr *exec.ExitError
require.True(t, xerrors.As(err, &exitErr))
assert.NotEqual(t, 0, exitErr.ExitCode())
})
}
9 changes: 7 additions & 2 deletions pty/start_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (

// Allocates a PTY and starts the specified command attached to it.
// See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session#creating-the-hosted-process
func startPty(cmd *exec.Cmd) (PTY, *os.Process, error) {
func startPty(cmd *exec.Cmd) (PTY, Process, error) {
fullPath, err := exec.LookPath(cmd.Path)
if err != nil {
return nil, nil, err
Expand Down Expand Up @@ -83,7 +83,12 @@ func startPty(cmd *exec.Cmd) (PTY, *os.Process, error) {
if err != nil {
return nil, nil, xerrors.Errorf("find process %d: %w", processInfo.ProcessId, err)
}
return pty, process, nil
wp := &windowsProcess{
cmdDone: make(chan any),
proc: process,
}
go wp.waitInternal()
return pty, wp, nil
}

// Taken from: https://github.com/microsoft/hcsshim/blob/7fbdca16f91de8792371ba22b7305bf4ca84170a/internal/exec/exec.go#L476
Expand Down
Loading