Skip to content

Commit b60d811

Browse files
committed
Return proper exit code on ssh with TTY
Signed-off-by: Spike Curtis <spike@coder.com>
1 parent 6377f17 commit b60d811

11 files changed

+185
-28
lines changed

agent/agent.go

+37-8
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ const (
4343
ProtocolReconnectingPTY = "reconnecting-pty"
4444
ProtocolSSH = "ssh"
4545
ProtocolDial = "dial"
46+
47+
// MagicSessionErrorCode indicates that something went wrong with the session, rather than the
48+
// command just returning a nonzero exit code, and is chosen as an arbitrary, high number
49+
// unlikely to shadow other exit codes, which are typically 1, 2, 3, etc.
50+
MagicSessionErrorCode = 229
4651
)
4752

4853
type Options struct {
@@ -273,9 +278,17 @@ func (a *agent) init(ctx context.Context) {
273278
},
274279
Handler: func(session ssh.Session) {
275280
err := a.handleSSHSession(session)
281+
var exitError *exec.ExitError
282+
if xerrors.As(err, &exitError) {
283+
a.logger.Debug(ctx, "ssh session returned", slog.Error(exitError))
284+
_ = session.Exit(exitError.ExitCode())
285+
return
286+
}
276287
if err != nil {
277288
a.logger.Warn(ctx, "ssh session failed", slog.Error(err))
278-
_ = session.Exit(1)
289+
// This exit code is designed to be unlikely to be confused for a legit exit code
290+
// from the process.
291+
_ = session.Exit(MagicSessionErrorCode)
279292
return
280293
}
281294
},
@@ -403,7 +416,7 @@ func (a *agent) createCommand(ctx context.Context, rawCommand string, env []stri
403416
return cmd, nil
404417
}
405418

406-
func (a *agent) handleSSHSession(session ssh.Session) error {
419+
func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
407420
cmd, err := a.createCommand(session.Context(), session.RawCommand(), session.Environ())
408421
if err != nil {
409422
return err
@@ -426,14 +439,24 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
426439
if err != nil {
427440
return xerrors.Errorf("start command: %w", err)
428441
}
442+
defer func() {
443+
closeErr := ptty.Close()
444+
if closeErr != nil {
445+
a.logger.Warn(context.Background(), "failed to close tty",
446+
slog.Error(closeErr))
447+
if retErr == nil {
448+
retErr = closeErr
449+
}
450+
}
451+
}()
429452
err = ptty.Resize(uint16(sshPty.Window.Height), uint16(sshPty.Window.Width))
430453
if err != nil {
431454
return xerrors.Errorf("resize ptty: %w", err)
432455
}
433456
go func() {
434457
for win := range windowSize {
435-
err = ptty.Resize(uint16(win.Height), uint16(win.Width))
436-
if err != nil {
458+
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width))
459+
if resizeErr != nil {
437460
a.logger.Warn(context.Background(), "failed to resize tty", slog.Error(err))
438461
}
439462
}
@@ -444,9 +467,15 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
444467
go func() {
445468
_, _ = io.Copy(session, ptty.Output())
446469
}()
447-
_, _ = process.Wait()
448-
_ = ptty.Close()
449-
return nil
470+
err = process.Wait()
471+
var exitErr *exec.ExitError
472+
// ExitErrors just mean the command we run returned a non-zero exit code, which is normal
473+
// and not something to be concerned about. But, if it's something else, we should log it.
474+
if err != nil && !xerrors.As(err, &exitErr) {
475+
a.logger.Warn(context.Background(), "wait error",
476+
slog.Error(err))
477+
}
478+
return err
450479
}
451480

452481
cmd.Stdout = session
@@ -549,7 +578,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn ne
549578
go func() {
550579
// If the process dies randomly, we should
551580
// close the pty.
552-
_, _ = process.Wait()
581+
_ = process.Wait()
553582
rpty.Close()
554583
}()
555584
go func() {

agent/agent_test.go

+26-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ import (
1616
"testing"
1717
"time"
1818

19+
"golang.org/x/xerrors"
20+
1921
scp "github.com/bramvdbogaerde/go-scp"
2022
"github.com/google/uuid"
2123
"github.com/pion/udp"
@@ -69,7 +71,7 @@ func TestAgent(t *testing.T) {
6971
require.True(t, strings.HasSuffix(strings.TrimSpace(string(output)), "gitssh --"))
7072
})
7173

72-
t.Run("SessionTTY", func(t *testing.T) {
74+
t.Run("SessionTTYShell", func(t *testing.T) {
7375
t.Parallel()
7476
if runtime.GOOS == "windows" {
7577
// This might be our implementation, or ConPTY itself.
@@ -103,6 +105,29 @@ func TestAgent(t *testing.T) {
103105
require.NoError(t, err)
104106
})
105107

108+
t.Run("SessionTTYExitCode", func(t *testing.T) {
109+
t.Parallel()
110+
session := setupSSHSession(t, agent.Metadata{})
111+
command := "areallynotrealcommand"
112+
err := session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
113+
require.NoError(t, err)
114+
ptty := ptytest.New(t)
115+
require.NoError(t, err)
116+
session.Stdout = ptty.Output()
117+
session.Stderr = ptty.Output()
118+
session.Stdin = ptty.Input()
119+
err = session.Start(command)
120+
require.NoError(t, err)
121+
err = session.Wait()
122+
exitErr := &ssh.ExitError{}
123+
require.True(t, xerrors.As(err, &exitErr))
124+
if runtime.GOOS == "windows" {
125+
assert.Equal(t, 1, exitErr.ExitStatus())
126+
} else {
127+
assert.Equal(t, 127, exitErr.ExitStatus())
128+
}
129+
})
130+
106131
t.Run("LocalForwarding", func(t *testing.T) {
107132
t.Parallel()
108133
random, err := net.Listen("tcp", "127.0.0.1:0")

pty/pty.go

+10
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,16 @@ type PTY interface {
2929
Resize(height uint16, width uint16) error
3030
}
3131

32+
// Process represents a process running in a PTY
33+
type Process interface {
34+
35+
// Wait for the command to complete. Returned error is as for exec.Cmd.Wait()
36+
Wait() error
37+
38+
// Kill the command process. Returned error is as for os.Process.Kill()
39+
Kill() error
40+
}
41+
3242
// WithFlags represents a PTY whose flags can be inspected, in particular
3343
// to determine whether local echo is enabled.
3444
type WithFlags interface {

pty/pty_other.go

+29
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ package pty
55

66
import (
77
"os"
8+
"os/exec"
9+
"runtime"
810
"sync"
911

1012
"github.com/creack/pty"
@@ -27,6 +29,15 @@ type otherPty struct {
2729
pty, tty *os.File
2830
}
2931

32+
type otherProcess struct {
33+
pty *os.File
34+
cmd *exec.Cmd
35+
36+
// cmdDone protects access to cmdErr: anything reading cmdErr should read from cmdDone first.
37+
cmdDone chan any
38+
cmdErr error
39+
}
40+
3041
func (p *otherPty) Input() ReadWriter {
3142
return ReadWriter{
3243
Reader: p.tty,
@@ -66,3 +77,21 @@ func (p *otherPty) Close() error {
6677
}
6778
return nil
6879
}
80+
81+
func (p *otherProcess) Wait() error {
82+
<-p.cmdDone
83+
return p.cmdErr
84+
}
85+
86+
func (p *otherProcess) Kill() error {
87+
return p.cmd.Process.Kill()
88+
}
89+
90+
func (p *otherProcess) waitInternal() {
91+
// The GC can garbage collect the TTY FD before the command
92+
// has finished running. See:
93+
// https://github.com/creack/pty/issues/127#issuecomment-932764012
94+
p.cmdErr = p.cmd.Wait()
95+
runtime.KeepAlive(p.pty)
96+
close(p.cmdDone)
97+
}

pty/pty_windows.go

+30
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package pty
55

66
import (
77
"os"
8+
"os/exec"
89
"sync"
910
"unsafe"
1011

@@ -66,6 +67,13 @@ type ptyWindows struct {
6667
closed bool
6768
}
6869

70+
type windowsProcess struct {
71+
// cmdDone protects access to cmdErr: anything reading cmdErr should read from cmdDone first.
72+
cmdDone chan any
73+
cmdErr error
74+
proc *os.Process
75+
}
76+
6977
func (p *ptyWindows) Output() ReadWriter {
7078
return ReadWriter{
7179
Reader: p.outputRead,
@@ -111,3 +119,25 @@ func (p *ptyWindows) Close() error {
111119

112120
return nil
113121
}
122+
123+
func (p *windowsProcess) waitInternal() {
124+
defer close(p.cmdDone)
125+
state, err := p.proc.Wait()
126+
if err != nil {
127+
p.cmdErr = err
128+
return
129+
}
130+
if !state.Success() {
131+
p.cmdErr = &exec.ExitError{ProcessState: state}
132+
return
133+
}
134+
}
135+
136+
func (p *windowsProcess) Wait() error {
137+
<-p.cmdDone
138+
return p.cmdErr
139+
}
140+
141+
func (p *windowsProcess) Kill() error {
142+
return p.proc.Kill()
143+
}

pty/ptytest/ptytest.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"bytes"
66
"context"
77
"io"
8-
"os"
98
"os/exec"
109
"runtime"
1110
"strings"
@@ -27,7 +26,7 @@ func New(t *testing.T) *PTY {
2726
return create(t, ptty, "cmd")
2827
}
2928

30-
func Start(t *testing.T, cmd *exec.Cmd) (*PTY, *os.Process) {
29+
func Start(t *testing.T, cmd *exec.Cmd) (*PTY, pty.Process) {
3130
ptty, ps, err := pty.Start(cmd)
3231
require.NoError(t, err)
3332
return create(t, ptty, cmd.Args[0]), ps

pty/start.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
package pty
22

33
import (
4-
"os"
54
"os/exec"
65
)
76

8-
func Start(cmd *exec.Cmd) (PTY, *os.Process, error) {
7+
// Start the command in a TTY. The calling code must not use cmd after passing it to the PTY, and
8+
// instead rely on the returned Process to manage the command/process.
9+
func Start(cmd *exec.Cmd) (PTY, Process, error) {
910
return startPty(cmd)
1011
}

pty/start_other.go

+8-10
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
package pty
55

66
import (
7-
"os"
87
"os/exec"
98
"runtime"
109
"strings"
@@ -14,7 +13,7 @@ import (
1413
"golang.org/x/xerrors"
1514
)
1615

17-
func startPty(cmd *exec.Cmd) (PTY, *os.Process, error) {
16+
func startPty(cmd *exec.Cmd) (PTY, Process, error) {
1817
ptty, tty, err := pty.Open()
1918
if err != nil {
2019
return nil, nil, xerrors.Errorf("open: %w", err)
@@ -37,16 +36,15 @@ func startPty(cmd *exec.Cmd) (PTY, *os.Process, error) {
3736
}
3837
return nil, nil, xerrors.Errorf("start: %w", err)
3938
}
40-
go func() {
41-
// The GC can garbage collect the TTY FD before the command
42-
// has finished running. See:
43-
// https://github.com/creack/pty/issues/127#issuecomment-932764012
44-
_ = cmd.Wait()
45-
runtime.KeepAlive(ptty)
46-
}()
4739
oPty := &otherPty{
4840
pty: ptty,
4941
tty: tty,
5042
}
51-
return oPty, cmd.Process, nil
43+
oProcess := &otherProcess{
44+
pty: ptty,
45+
cmd: cmd,
46+
cmdDone: make(chan any),
47+
}
48+
go oProcess.waitInternal()
49+
return oPty, oProcess, nil
5250
}

pty/start_other_test.go

+18-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ import (
77
"os/exec"
88
"testing"
99

10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
"golang.org/x/xerrors"
13+
1014
"go.uber.org/goleak"
1115

1216
"github.com/coder/coder/pty/ptytest"
@@ -20,7 +24,20 @@ func TestStart(t *testing.T) {
2024
t.Parallel()
2125
t.Run("Echo", func(t *testing.T) {
2226
t.Parallel()
23-
pty, _ := ptytest.Start(t, exec.Command("echo", "test"))
27+
pty, ps := ptytest.Start(t, exec.Command("echo", "test"))
2428
pty.ExpectMatch("test")
29+
err := ps.Wait()
30+
require.NoError(t, err)
31+
})
32+
33+
t.Run("Kill", func(t *testing.T) {
34+
t.Parallel()
35+
_, ps := ptytest.Start(t, exec.Command("sleep", "30"))
36+
err := ps.Kill()
37+
assert.NoError(t, err)
38+
err = ps.Wait()
39+
var exitErr *exec.ExitError
40+
require.True(t, xerrors.As(err, &exitErr))
41+
assert.Equal(t, -1, exitErr.ExitCode())
2542
})
2643
}

pty/start_windows.go

+7-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import (
1616

1717
// Allocates a PTY and starts the specified command attached to it.
1818
// See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session#creating-the-hosted-process
19-
func startPty(cmd *exec.Cmd) (PTY, *os.Process, error) {
19+
func startPty(cmd *exec.Cmd) (PTY, Process, error) {
2020
fullPath, err := exec.LookPath(cmd.Path)
2121
if err != nil {
2222
return nil, nil, err
@@ -83,7 +83,12 @@ func startPty(cmd *exec.Cmd) (PTY, *os.Process, error) {
8383
if err != nil {
8484
return nil, nil, xerrors.Errorf("find process %d: %w", processInfo.ProcessId, err)
8585
}
86-
return pty, process, nil
86+
wp := &windowsProcess{
87+
cmdDone: make(chan any),
88+
proc: process,
89+
}
90+
go wp.waitInternal()
91+
return pty, wp, nil
8792
}
8893

8994
// Taken from: https://github.com/microsoft/hcsshim/blob/7fbdca16f91de8792371ba22b7305bf4ca84170a/internal/exec/exec.go#L476

0 commit comments

Comments
 (0)