Skip to content

Commit 61be4df

Browse files
authored
fix: improve exit codes for agent/agentssh and cli/ssh (#10850)
1 parent dbdcad0 commit 61be4df

File tree

4 files changed

+78
-8
lines changed

4 files changed

+78
-8
lines changed

agent/agentssh/agentssh.go

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,29 @@ func (s *Server) sessionHandler(session ssh.Session) {
237237
err := s.sessionStart(logger, session, extraEnv)
238238
var exitError *exec.ExitError
239239
if xerrors.As(err, &exitError) {
240-
logger.Info(ctx, "ssh session returned", slog.Error(exitError))
241-
_ = session.Exit(exitError.ExitCode())
240+
code := exitError.ExitCode()
241+
if code == -1 {
242+
// If we return -1 here, it will be transmitted as an
243+
// uint32(4294967295). This exit code is nonsense, so
244+
// instead we return 255 (same as OpenSSH). This is
245+
// also the same exit code that the shell returns for
246+
// -1.
247+
//
248+
// For signals, we could consider sending 128+signal
249+
// instead (however, OpenSSH doesn't seem to do this).
250+
code = 255
251+
}
252+
logger.Info(ctx, "ssh session returned",
253+
slog.Error(exitError),
254+
slog.F("process_exit_code", exitError.ExitCode()),
255+
slog.F("exit_code", code),
256+
)
257+
258+
// TODO(mafredri): For signal exit, there's also an "exit-signal"
259+
// request (session.Exit sends "exit-status"), however, since it's
260+
// not implemented on the session interface and not used by
261+
// OpenSSH, we'll leave it for now.
262+
_ = session.Exit(code)
242263
return
243264
}
244265
if err != nil {

agent/agentssh/agentssh_test.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,13 @@ func TestNewServer_Signal(t *testing.T) {
227227
require.NoError(t, sc.Err())
228228

229229
err = sess.Wait()
230-
require.Error(t, err)
230+
exitErr := &ssh.ExitError{}
231+
require.ErrorAs(t, err, &exitErr)
232+
wantCode := 255
233+
if runtime.GOOS == "windows" {
234+
wantCode = 1
235+
}
236+
require.Equal(t, wantCode, exitErr.ExitStatus())
231237
})
232238
t.Run("PTY", func(t *testing.T) {
233239
t.Parallel()
@@ -300,7 +306,13 @@ func TestNewServer_Signal(t *testing.T) {
300306
require.NoError(t, sc.Err())
301307

302308
err = sess.Wait()
303-
require.Error(t, err)
309+
exitErr := &ssh.ExitError{}
310+
require.ErrorAs(t, err, &exitErr)
311+
wantCode := 255
312+
if runtime.GOOS == "windows" {
313+
wantCode = 1
314+
}
315+
require.Equal(t, wantCode, exitErr.ExitStatus())
304316
})
305317
}
306318

cli/root.go

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,14 +136,22 @@ func (r *RootCmd) RunMain(subcommands []*clibase.Cmd) {
136136
}
137137
err = cmd.Invoke().WithOS().Run()
138138
if err != nil {
139+
code := 1
140+
var exitErr *exitError
141+
if errors.As(err, &exitErr) {
142+
code = exitErr.code
143+
err = exitErr.err
144+
}
139145
if errors.Is(err, cliui.Canceled) {
140146
//nolint:revive
141-
os.Exit(1)
147+
os.Exit(code)
142148
}
143149
f := prettyErrorFormatter{w: os.Stderr, verbose: r.verbose}
144-
f.format(err)
150+
if err != nil {
151+
f.format(err)
152+
}
145153
//nolint:revive
146-
os.Exit(1)
154+
os.Exit(code)
147155
}
148156
}
149157

@@ -953,6 +961,30 @@ func DumpHandler(ctx context.Context) {
953961
}
954962
}
955963

964+
type exitError struct {
965+
code int
966+
err error
967+
}
968+
969+
var _ error = (*exitError)(nil)
970+
971+
func (e *exitError) Error() string {
972+
if e.err != nil {
973+
return fmt.Sprintf("exit code %d: %v", e.code, e.err)
974+
}
975+
return fmt.Sprintf("exit code %d", e.code)
976+
}
977+
978+
func (e *exitError) Unwrap() error {
979+
return e.err
980+
}
981+
982+
// ExitError returns an error that will cause the CLI to exit with the given
983+
// exit code. If err is non-nil, it will be wrapped by the returned error.
984+
func ExitError(code int, err error) error {
985+
return &exitError{code: code, err: err}
986+
}
987+
956988
// IiConnectionErr is a convenience function for checking if the source of an
957989
// error is due to a 'connection refused', 'no such host', etc.
958990
func isConnectionError(err error) bool {

cli/ssh.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,11 +379,16 @@ func (r *RootCmd) ssh() *clibase.Cmd {
379379

380380
err = sshSession.Wait()
381381
if err != nil {
382+
if exitErr := (&gossh.ExitError{}); errors.As(err, &exitErr) {
383+
// Clear the error since it's not useful beyond
384+
// reporting status.
385+
return ExitError(exitErr.ExitStatus(), nil)
386+
}
382387
// If the connection drops unexpectedly, we get an
383388
// ExitMissingError but no other error details, so try to at
384389
// least give the user a better message
385390
if errors.Is(err, &gossh.ExitMissingError{}) {
386-
return xerrors.New("SSH connection ended unexpectedly")
391+
return ExitError(255, xerrors.New("SSH connection ended unexpectedly"))
387392
}
388393
return xerrors.Errorf("session ended: %w", err)
389394
}

0 commit comments

Comments
 (0)