diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index 1829449a850be..b6b916b834784 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -237,8 +237,29 @@ func (s *Server) sessionHandler(session ssh.Session) { err := s.sessionStart(logger, session, extraEnv) var exitError *exec.ExitError if xerrors.As(err, &exitError) { - logger.Info(ctx, "ssh session returned", slog.Error(exitError)) - _ = session.Exit(exitError.ExitCode()) + code := exitError.ExitCode() + if code == -1 { + // If we return -1 here, it will be transmitted as an + // uint32(4294967295). This exit code is nonsense, so + // instead we return 255 (same as OpenSSH). This is + // also the same exit code that the shell returns for + // -1. + // + // For signals, we could consider sending 128+signal + // instead (however, OpenSSH doesn't seem to do this). + code = 255 + } + logger.Info(ctx, "ssh session returned", + slog.Error(exitError), + slog.F("process_exit_code", exitError.ExitCode()), + slog.F("exit_code", code), + ) + + // TODO(mafredri): For signal exit, there's also an "exit-signal" + // request (session.Exit sends "exit-status"), however, since it's + // not implemented on the session interface and not used by + // OpenSSH, we'll leave it for now. + _ = session.Exit(code) return } if err != nil { diff --git a/agent/agentssh/agentssh_test.go b/agent/agentssh/agentssh_test.go index 48be4f5619630..49d07a11bd51e 100644 --- a/agent/agentssh/agentssh_test.go +++ b/agent/agentssh/agentssh_test.go @@ -227,7 +227,13 @@ func TestNewServer_Signal(t *testing.T) { require.NoError(t, sc.Err()) err = sess.Wait() - require.Error(t, err) + exitErr := &ssh.ExitError{} + require.ErrorAs(t, err, &exitErr) + wantCode := 255 + if runtime.GOOS == "windows" { + wantCode = 1 + } + require.Equal(t, wantCode, exitErr.ExitStatus()) }) t.Run("PTY", func(t *testing.T) { t.Parallel() @@ -300,7 +306,13 @@ func TestNewServer_Signal(t *testing.T) { require.NoError(t, sc.Err()) err = sess.Wait() - require.Error(t, err) + exitErr := &ssh.ExitError{} + require.ErrorAs(t, err, &exitErr) + wantCode := 255 + if runtime.GOOS == "windows" { + wantCode = 1 + } + require.Equal(t, wantCode, exitErr.ExitStatus()) }) } diff --git a/cli/root.go b/cli/root.go index 8d645ea5f1d7a..31aa5fec47629 100644 --- a/cli/root.go +++ b/cli/root.go @@ -136,14 +136,22 @@ func (r *RootCmd) RunMain(subcommands []*clibase.Cmd) { } err = cmd.Invoke().WithOS().Run() if err != nil { + code := 1 + var exitErr *exitError + if errors.As(err, &exitErr) { + code = exitErr.code + err = exitErr.err + } if errors.Is(err, cliui.Canceled) { //nolint:revive - os.Exit(1) + os.Exit(code) } f := prettyErrorFormatter{w: os.Stderr, verbose: r.verbose} - f.format(err) + if err != nil { + f.format(err) + } //nolint:revive - os.Exit(1) + os.Exit(code) } } @@ -953,6 +961,30 @@ func DumpHandler(ctx context.Context) { } } +type exitError struct { + code int + err error +} + +var _ error = (*exitError)(nil) + +func (e *exitError) Error() string { + if e.err != nil { + return fmt.Sprintf("exit code %d: %v", e.code, e.err) + } + return fmt.Sprintf("exit code %d", e.code) +} + +func (e *exitError) Unwrap() error { + return e.err +} + +// ExitError returns an error that will cause the CLI to exit with the given +// exit code. If err is non-nil, it will be wrapped by the returned error. +func ExitError(code int, err error) error { + return &exitError{code: code, err: err} +} + // IiConnectionErr is a convenience function for checking if the source of an // error is due to a 'connection refused', 'no such host', etc. func isConnectionError(err error) bool { diff --git a/cli/ssh.go b/cli/ssh.go index 0c4b537949806..c409bf877ddfe 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -379,11 +379,16 @@ func (r *RootCmd) ssh() *clibase.Cmd { err = sshSession.Wait() if err != nil { + if exitErr := (&gossh.ExitError{}); errors.As(err, &exitErr) { + // Clear the error since it's not useful beyond + // reporting status. + return ExitError(exitErr.ExitStatus(), nil) + } // If the connection drops unexpectedly, we get an // ExitMissingError but no other error details, so try to at // least give the user a better message if errors.Is(err, &gossh.ExitMissingError{}) { - return xerrors.New("SSH connection ended unexpectedly") + return ExitError(255, xerrors.New("SSH connection ended unexpectedly")) } return xerrors.Errorf("session ended: %w", err) }