diff --git a/cmd/coder/shell.go b/cmd/coder/shell.go index 9abc828f..e54c39bf 100644 --- a/cmd/coder/shell.go +++ b/cmd/coder/shell.go @@ -8,11 +8,13 @@ import ( "time" "github.com/spf13/pflag" - "go.coder.com/cli" - "go.coder.com/flog" "golang.org/x/crypto/ssh/terminal" "golang.org/x/sys/unix" "golang.org/x/time/rate" + "golang.org/x/xerrors" + + "go.coder.com/cli" + "go.coder.com/flog" client "cdr.dev/coder-cli/internal/entclient" "cdr.dev/coder-cli/wush" @@ -30,14 +32,20 @@ func (cmd *shellCmd) Spec() cli.CommandSpec { } } -func enableTerminal(fd int) { - _, err := terminal.MakeRaw(fd) +func enableTerminal(fd int) (restore func(), err error) { + state, err := terminal.MakeRaw(fd) if err != nil { - flog.Fatal("make raw term: %v", err) + return restore, xerrors.Errorf("make raw term: %w", err) } + return func() { + err := terminal.Restore(fd, state) + if err != nil { + flog.Error("restore term state: %v", err) + } + }, nil } -func (cmd *shellCmd) sendResizeEvents(termfd int, client *wush.Client) { +func sendResizeEvents(termfd int, client *wush.Client) { sigs := make(chan os.Signal, 16) signal.Notify(sigs, unix.SIGWINCH) @@ -83,6 +91,14 @@ func (cmd *shellCmd) Run(fl *pflag.FlagSet) { args = []string{"-c", "exec $(getent passwd $(whoami) | awk -F: '{ print $7 }')"} } + exitCode, err := runCommand(envName, command, args) + if err != nil { + flog.Fatal("run command: %v", err) + } + os.Exit(exitCode) +} + +func runCommand(envName string, command string, args []string) (int, error) { var ( entClient = requireAuth() env = findEnv(entClient, envName) @@ -92,7 +108,11 @@ func (cmd *shellCmd) Run(fl *pflag.FlagSet) { tty := terminal.IsTerminal(termfd) if tty { - enableTerminal(termfd) + restore, err := enableTerminal(termfd) + if err != nil { + return -1, err + } + defer restore() } conn, err := entClient.DialWush( @@ -102,16 +122,16 @@ func (cmd *shellCmd) Run(fl *pflag.FlagSet) { Stdin: true, }, command, args...) if err != nil { - flog.Fatal("dial wush: %v", err) + return -1, err } ctx := context.Background() wc := wush.NewClient(ctx, conn) if tty { - go cmd.sendResizeEvents(termfd, wc) + go sendResizeEvents(termfd, wc) } - go func(){ + go func() { defer wc.Stdin.Close() io.Copy(wc.Stdin, os.Stdin) }() @@ -120,7 +140,8 @@ func (cmd *shellCmd) Run(fl *pflag.FlagSet) { exitCode, err := wc.Wait() if err != nil { - flog.Fatal("wush error: %v", err) + return -1, err } - os.Exit(int(exitCode)) + + return int(exitCode), nil }