diff --git a/README.md b/README.md
index d9606c0..709ced0 100644
--- a/README.md
+++ b/README.md
@@ -12,7 +12,7 @@ Error handling is omitted for brevity.
 
 ```golang
 conn, _, _ := websocket.Dial(ctx, "ws://remote.exec.addr", nil)
-defer conn.Close(websocket.StatusAbnormalClosure, "terminate process")
+defer conn.Close(websocket.StatusNormalClosure, "normal closure")
 
 execer := wsep.RemoteExecer(conn)
 process, _ := execer.Start(ctx, wsep.Command{
@@ -25,7 +25,6 @@ go io.Copy(os.Stderr, process.Stderr())
 go io.Copy(os.Stdout, process.Stdout())
 
 process.Wait()
-conn.Close(websocket.StatusNormalClosure, "normal closure")
 ```
 
 ### Server
@@ -33,10 +32,9 @@ conn.Close(websocket.StatusNormalClosure, "normal closure")
 ```golang
 func (s server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	conn, _ := websocket.Accept(w, r, nil)
+	defer conn.Close(websocket.StatusNormalClosure, "normal closure")
 
 	wsep.Serve(r.Context(), conn, wsep.LocalExecer{})
-
-	ws.Close(websocket.StatusNormalClosure, "normal closure")
 }
 ```
diff --git a/browser/client.ts b/browser/client.ts
index 009912e..0a5322e 100644
--- a/browser/client.ts
+++ b/browser/client.ts
@@ -14,7 +14,7 @@ export interface Command {
 }
 
 export type ClientHeader =
-  | { type: 'start'; command: Command }
+  | { type: 'start'; id: string; command: Command; cols: number; rows: number }
   | { type: 'stdin' }
   | { type: 'close_stdin' }
   | { type: 'resize'; cols: number; rows: number };
@@ -42,8 +42,14 @@ export const closeStdin = (ws: WebSocket) => {
   ws.send(msg.buffer);
 };
 
-export const startCommand = (ws: WebSocket, command: Command) => {
-  const msg = joinMessage({ type: 'start', command: command });
+export const startCommand = (
+  ws: WebSocket,
+  command: Command,
+  id: string,
+  rows: number,
+  cols: number
+) => {
+  const msg = joinMessage({ type: 'start', command, id, rows, cols });
   ws.send(msg.buffer);
 };
 
diff --git a/ci/alt.sh b/ci/alt.sh
new file mode 100755
index 0000000..b92d60e
--- /dev/null
+++ b/ci/alt.sh
@@ -0,0 +1,36 @@
+#!/usr/bin/env bash
+# Script for testing the alt screen.
+
+# Enter alt screen.
+tput smcup
+
+function display() {
+  # Clear the screen.
+  tput clear
+  # Move cursor to the top left.
+  tput cup 0 0
+  # Display content.
+  echo "ALT SCREEN"
+}
+
+function redraw() {
+  display
+  echo "redrawn"
+}
+
+# Re-display on resize.
+trap 'redraw' WINCH
+
+display
+
+# The trap will not run while waiting for a command so read input in a loop with
+# a timeout.
+while true ; do
+  if read -n 1 -t .1 ; then
+    # Clear the screen.
+    tput clear
+    # Exit alt screen.
+    tput rmcup
+    exit
+  fi
+done
diff --git a/ci/fmt.sh b/ci/fmt.sh
index c3a0c16..a9522dd 100755
--- a/ci/fmt.sh
+++ b/ci/fmt.sh
@@ -1,4 +1,5 @@
-#!/bin/bash
+#!/usr/bin/env bash
+
 echo "Formatting..."
 
 go mod tidy
diff --git a/ci/image/Dockerfile b/ci/image/Dockerfile
index 5b5c0fb..e08d509 100644
--- a/ci/image/Dockerfile
+++ b/ci/image/Dockerfile
@@ -3,6 +3,8 @@ FROM golang:1
 ENV GOFLAGS="-mod=readonly"
 ENV CI=true
 
+RUN apt update && apt install -y screen
+
 RUN go install golang.org/x/tools/cmd/goimports@latest
 RUN go install golang.org/x/lint/golint@latest
 RUN go install github.com/mattn/goveralls@latest
diff --git a/ci/lint.sh b/ci/lint.sh
index 9a9554c..1e06b61 100755
--- a/ci/lint.sh
+++ b/ci/lint.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
 
 echo "Linting..."
 
diff --git a/client.go b/client.go
index 4630438..2df93e7 100644
--- a/client.go
+++ b/client.go
@@ -28,10 +28,13 @@ func RemoteExecer(conn *websocket.Conn) Execer {
 // Command represents an external command to be run
 type Command struct {
 	// ID allows reconnecting commands that have a TTY.
-	ID      string
-	Command string
-	Args    []string
+	ID      string
+	Command string
+	Args    []string
+	// Commands with a TTY also require Rows and Cols.
 	TTY     bool
+	Rows    uint16
+	Cols    uint16
 	Stdin   bool
 	UID     uint32
 	GID     uint32
@@ -103,7 +106,7 @@ type remoteProcess struct {
 	pid      int
 	done     chan struct{}
 	closeErr error
-	exitCode *int
+	exitMsg  *proto.ServerExitCodeHeader
 	readErr  error
 	stdin    io.WriteCloser
 	stdout   pipe
@@ -217,9 +220,16 @@ func (r *remoteProcess) listen(ctx context.Context) {
 	r.closeErr = r.conn.Close(websocket.StatusNormalClosure, "normal closure")
 
 	// If we were in r.conn.Read() we cancel the ctx, the websocket library closes
-	// the websocket before we have a chance to. This is a normal closure.
-	if r.closeErr != nil && strings.Contains(r.closeErr.Error(), "already wrote close") &&
-		r.readErr != nil && strings.Contains(r.readErr.Error(), "context canceled") {
+	// the websocket before we have a chance to. Unfortunately there is a race
+	// in the websocket library: sometimes the close frame has already been
+	// written before we even call r.conn.Close(), and sometimes it gets written
+	// during our call to r.conn.Close(), so we handle both cases when examining
+	// the error that comes back. This is a normal closure, so report nil for the error.
+	readCtxCanceled := r.readErr != nil && strings.Contains(r.readErr.Error(), "context canceled")
+	alreadyClosed := r.closeErr != nil &&
+		(strings.Contains(r.closeErr.Error(), "already wrote close") ||
+			strings.Contains(r.closeErr.Error(), "WebSocket closed"))
+	if alreadyClosed && readCtxCanceled {
 		r.closeErr = nil
 	}
 	close(r.done)
@@ -260,8 +270,7 @@ func (r *remoteProcess) listen(ctx context.Context) {
 				r.readErr = err
 				return
 			}
-
-			r.exitCode = &exitMsg.ExitCode
+			r.exitMsg = &exitMsg
 			return
 		}
 	}
@@ -312,11 +321,10 @@ func (r *remoteProcess) Wait() error {
 	if r.readErr != nil {
 		return r.readErr
 	}
-	// when listen() closes r.done, either there must be a read error
-	// or exitCode is set non-nil, so it's safe to dereference the pointer
-	// here
-	if *r.exitCode != 0 {
-		return ExitError{Code: *r.exitCode}
+	// When listen() closes r.done, there must either be a read error or exitMsg
+	// must be set non-nil, so it is safe to access its members here.
+	if r.exitMsg.ExitCode != 0 {
+		return ExitError{code: r.exitMsg.ExitCode, error: r.exitMsg.Error}
 	}
 	return nil
 }
diff --git a/client_test.go b/client_test.go
index f665057..96c10b0 100644
--- a/client_test.go
+++ b/client_test.go
@@ -51,17 +51,25 @@ func TestRemoteStdin(t *testing.T) {
 	}
 }
 
-func mockConn(ctx context.Context, t *testing.T, options *Options) (*websocket.Conn, *httptest.Server) {
+func mockConn(ctx context.Context, t *testing.T, wsepServer *Server, options *Options) (*websocket.Conn, *httptest.Server) {
 	mockServerHandler := func(w http.ResponseWriter, r *http.Request) {
 		ws, err := websocket.Accept(w, r, nil)
 		if err != nil {
 			w.WriteHeader(http.StatusInternalServerError)
 			return
 		}
-		err = Serve(r.Context(), ws, LocalExecer{}, options)
+		if wsepServer != nil {
+			err = wsepServer.Serve(r.Context(), ws, LocalExecer{}, options)
+		} else {
+			err = Serve(r.Context(), ws, LocalExecer{}, options)
+		}
 		if err != nil {
-			t.Errorf("failed to serve execer: %v", err)
-			ws.Close(websocket.StatusAbnormalClosure, "failed to serve execer")
+			// Max reason string length is 123.
+			errStr := err.Error()
+			if len(errStr) > 123 {
+				errStr = errStr[:123]
+			}
+			ws.Close(websocket.StatusInternalError, errStr)
 			return
 		}
 		ws.Close(websocket.StatusNormalClosure, "normal closure")
@@ -79,7 +87,11 @@ func TestRemoteExec(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
 	defer cancel()
 
-	ws, server := mockConn(ctx, t, nil)
+	wsepServer := NewServer()
+	defer wsepServer.Close()
+	defer assert.Equal(t, "no leaked sessions", 0, wsepServer.SessionCount())
+
+	ws, server := mockConn(ctx, t, wsepServer, nil)
 	defer server.Close()
 
 	execer := RemoteExecer(ws)
@@ -92,14 +104,20 @@ func TestRemoteClose(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
 	defer cancel()
 
-	ws, server := mockConn(ctx, t, nil)
+	wsepServer := NewServer()
+	defer wsepServer.Close()
+	defer assert.Equal(t, "no leaked sessions", 0, wsepServer.SessionCount())
+
+	ws, server := mockConn(ctx, t, wsepServer, nil)
 	defer server.Close()
 
 	execer := RemoteExecer(ws)
 	cmd := Command{
-		Command: "/bin/bash",
+		Command: "sh",
 		TTY:     true,
 		Stdin:   true,
+		Cols:    100,
+		Rows:    100,
 		Env:     []string{"TERM=linux"},
 	}
 
@@ -138,14 +156,20 @@ func TestRemoteCloseNoData(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
 	defer cancel()
 
-	ws, server := mockConn(ctx, t, nil)
+	wsepServer := NewServer()
+	defer wsepServer.Close()
+	defer assert.Equal(t, "no leaked sessions", 0, wsepServer.SessionCount())
+
+	ws, server := mockConn(ctx, t, wsepServer, nil)
 	defer server.Close()
 
 	execer := RemoteExecer(ws)
 	cmd := Command{
-		Command: "/bin/bash",
+		Command: "sh",
 		TTY:     true,
 		Stdin:   true,
+		Cols:    100,
+		Rows:    100,
 		Env:     []string{"TERM=linux"},
 	}
 
@@ -171,14 +195,20 @@ func TestRemoteClosePartialRead(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
 	defer cancel()
 
-	ws, server := mockConn(ctx, t, nil)
+	wsepServer := NewServer()
+	defer wsepServer.Close()
+	defer assert.Equal(t, "no leaked sessions", 0, wsepServer.SessionCount())
+
+	ws, server := mockConn(ctx, t, wsepServer, nil)
 	defer server.Close()
 
 	execer := RemoteExecer(ws)
 	cmd := Command{
-		Command: "/bin/bash",
+		Command: "sh",
 		TTY:     true,
 		Stdin:   true,
+		Cols:    100,
+		Rows:    100,
 		Env:     []string{"TERM=linux"},
 	}
 
@@ -205,7 +235,11 @@ func TestRemoteExecFail(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
 	defer cancel()
 
-	ws, server := mockConn(ctx, t, nil)
+	wsepServer := NewServer()
+	defer wsepServer.Close()
+	defer assert.Equal(t, "no leaked sessions", 0, wsepServer.SessionCount())
+
+	ws, server := mockConn(ctx, t, wsepServer, nil)
 	defer server.Close()
 
 	execer := RemoteExecer(ws)
@@ -223,9 +257,10 @@ func testExecerFail(ctx context.Context, t *testing.T, execer Execer) {
 	go io.Copy(ioutil.Discard, process.Stdout())
 
 	err = process.Wait()
-	code, ok := err.(ExitError)
+	exitErr, ok := err.(ExitError)
 	assert.True(t, "is exit error", ok)
-	assert.True(t, "exit code is nonzero", code.Code != 0)
+	assert.True(t, "exit code is nonzero", exitErr.ExitCode() != 0)
+	assert.Equal(t, "exit error", exitErr.Error(), "exit status 2")
 	assert.Error(t, "wait for process to error", err)
 }
 
@@ -239,7 +274,11 @@ func TestStderrVsStdout(t *testing.T) {
 		stderr bytes.Buffer
 	)
 
-	ws, server := mockConn(ctx, t, nil)
+	wsepServer := NewServer()
+	defer wsepServer.Close()
+	defer assert.Equal(t, "no leaked sessions", 0, wsepServer.SessionCount())
+
+	ws, server := mockConn(ctx, t, wsepServer, nil)
 	defer server.Close()
 
 	execer := RemoteExecer(ws)
diff --git a/dev/client/main.go b/dev/client/main.go
index 0c28fd4..c270cdc 100644
--- a/dev/client/main.go
+++ b/dev/client/main.go
@@ -9,6 +9,7 @@ import (
 	"os"
 	"os/signal"
 	"syscall"
+	"time"
 
 	"cdr.dev/wsep"
 	"github.com/spf13/pflag"
@@ -20,41 +21,48 @@ import (
 )
 
 type notty struct {
+	timeout time.Duration
 }
 
 func (c *notty) Run(fl *pflag.FlagSet) {
-	do(fl, false, "")
+	do(fl, false, "", c.timeout)
 }
 
 func (c *notty) Spec() cli.CommandSpec {
 	return cli.CommandSpec{
 		Name:  "notty",
-		Usage: "[flags]",
+		Usage: "[flags] <command> [args...]",
 		Desc:  `Run a command without tty enabled.`,
 	}
 }
 
+func (c *notty) RegisterFlags(fl *pflag.FlagSet) {
+	fl.DurationVar(&c.timeout, "timeout", 0, "disconnect after the specified timeout")
+}
+
 type tty struct {
-	id string
+	id      string
+	timeout time.Duration
 }
 
 func (c *tty) Run(fl *pflag.FlagSet) {
-	do(fl, true, c.id)
+	do(fl, true, c.id, c.timeout)
 }
 
 func (c *tty) Spec() cli.CommandSpec {
 	return cli.CommandSpec{
 		Name:  "tty",
-		Usage: "[id] [flags]",
+		Usage: "[flags] <command> [args...]",
 		Desc: `Run a command with tty enabled.
 Use the same ID to reconnect.`,
 	}
 }
 
 func (c *tty) RegisterFlags(fl *pflag.FlagSet) {
 	fl.StringVar(&c.id, "id", "", "sets id for reconnection")
+	fl.DurationVar(&c.timeout, "timeout", 0, "disconnect after the specified timeout")
 }
 
-func do(fl *pflag.FlagSet, tty bool, id string) {
+func do(fl *pflag.FlagSet, tty bool, id string, timeout time.Duration) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
@@ -62,7 +70,7 @@ func do(fl *pflag.FlagSet, tty bool, id string, timeout time.Duration) {
 	if err != nil {
 		flog.Fatal("failed to dial remote executor: %v", err)
 	}
-	defer conn.Close(websocket.StatusAbnormalClosure, "terminate process")
+	defer conn.Close(websocket.StatusNormalClosure, "terminate process")
 
 	executor := wsep.RemoteExecer(conn)
 
@@ -73,12 +81,19 @@ func do(fl *pflag.FlagSet, tty bool, id string, timeout time.Duration) {
 	if len(fl.Args()) > 1 {
 		args = fl.Args()[1:]
 	}
+	width, height, err := term.GetSize(int(os.Stdin.Fd()))
+	if err != nil {
+		flog.Fatal("unable to get term size")
+	}
 	process, err := executor.Start(ctx, wsep.Command{
 		ID:      id,
 		Command: fl.Arg(0),
 		Args:    args,
 		TTY:     tty,
 		Stdin:   true,
+		Rows:    uint16(height),
+		Cols:    uint16(width),
+		Env:     []string{"TERM=" + os.Getenv("TERM")},
 	})
 	if err != nil {
 		flog.Fatal("failed to start remote command: %v", err)
@@ -112,6 +127,15 @@ func do(fl *pflag.FlagSet, tty bool, id string, timeout time.Duration) {
 		io.Copy(stdin, os.Stdin)
 	}()
 
+	if timeout != 0 {
+		timer := time.NewTimer(timeout)
+		defer timer.Stop()
+		go func() {
+			<-timer.C
+			conn.Close(websocket.StatusNormalClosure, "normal closure")
+		}()
+	}
+
 	err = process.Wait()
 	if err != nil {
 		flog.Error("process failed: %v", err)
diff --git a/dev/server/main.go b/dev/server/main.go
index 66020d5..1085d29 100644
--- a/dev/server/main.go
+++ b/dev/server/main.go
@@ -2,6 +2,7 @@ package main
 
 import (
 	"net/http"
+	"time"
 
 	"cdr.dev/wsep"
 	"go.coder.com/flog"
@@ -23,10 +24,12 @@ func serve(w http.ResponseWriter, r *http.Request) {
 		w.WriteHeader(http.StatusInternalServerError)
 		return
 	}
-	err = wsep.Serve(r.Context(), ws, wsep.LocalExecer{}, nil)
+	err = wsep.Serve(r.Context(), ws, wsep.LocalExecer{}, &wsep.Options{
+		SessionTimeout: 30 * time.Second,
+	})
 	if err != nil {
 		flog.Error("failed to serve execer: %v", err)
-		ws.Close(websocket.StatusAbnormalClosure, "failed to serve execer")
+		ws.Close(websocket.StatusInternalError, "failed to serve execer")
 		return
 	}
 	ws.Close(websocket.StatusNormalClosure, "normal closure")
diff --git a/exec.go b/exec.go
index e515289..0ba6def 100644
--- a/exec.go
+++ b/exec.go
@@ -2,7 +2,6 @@ package wsep
 
 import (
 	"context"
-	"fmt"
 	"io"
 
 	"cdr.dev/wsep/internal/proto"
@@ -10,11 +9,18 @@ import (
 
 // ExitError is sent when the command terminates.
 type ExitError struct {
-	Code int
+	code  int
+	error string
 }
 
+// ExitCode returns the exit code of the process.
+func (e ExitError) ExitCode() int {
+	return e.code
+}
+
+// Error returns a string describing why the process errored.
 func (e ExitError) Error() string {
-	return fmt.Sprintf("process exited with code %v", e.Code)
+	return e.error
 }
 
 // Process represents a started command.
@@ -32,8 +38,8 @@ type Process interface {
 	Resize(ctx context.Context, rows, cols uint16) error
 	// Wait returns ExitError when the command terminates with a non-zero exit code.
 	Wait() error
-	// Close terminates the process and underlying connection(s).
-	// It must be called otherwise a connection or process may leak.
+	// Close sends a SIGTERM to the process. To force a shutdown, cancel the
+	// context passed into the execer.
 	Close() error
 }
 
@@ -49,6 +55,8 @@ func mapToProtoCmd(c Command) proto.Command {
 		Args:    c.Args,
 		Stdin:   c.Stdin,
 		TTY:     c.TTY,
+		Rows:    c.Rows,
+		Cols:    c.Cols,
 		UID:     c.UID,
 		GID:     c.GID,
 		Env:     c.Env,
@@ -56,12 +64,14 @@ func mapToProtoCmd(c Command) proto.Command {
 	}
 }
 
-func mapToClientCmd(c proto.Command) Command {
-	return Command{
+func mapToClientCmd(c proto.Command) *Command {
+	return &Command{
 		Command: c.Command,
 		Args:    c.Args,
 		Stdin:   c.Stdin,
 		TTY:     c.TTY,
+		Rows:    c.Rows,
+		Cols:    c.Cols,
 		UID:     c.UID,
 		GID:     c.GID,
 		Env:     c.Env,
diff --git a/flake.lock b/flake.lock
new file mode 100644
index 0000000..6066b3b
--- /dev/null
+++ b/flake.lock
@@ -0,0 +1,41 @@
+{
+  "nodes": {
+    "flake-utils": {
+      "locked": {
+        "lastModified": 1659877975,
+        "narHash": "sha256-zllb8aq3YO3h8B/U0/J1WBgAL8EX5yWf5pMj3G0NAmc=",
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "rev": "c0e246b9b83f637f4681389ecabcb2681b4f3af0",
+        "type": "github"
+      },
+      "original": {
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "type": "github"
+      }
+    },
+    "nixpkgs": {
+      "locked": {
+        "lastModified": 1663235518,
+        "narHash": "sha256-q8zLK6rK/CLXEguaPgm9yQJcY0VQtOBhAT9EV2UFK/A=",
+        "owner": "NixOS",
+        "repo": "nixpkgs",
+        "rev": "2277e4c9010b0f27585eb0bed0a86d7cbc079354",
+        "type": "github"
+      },
+      "original": {
+        "id": "nixpkgs",
+        "type": "indirect"
+      }
+    },
+    "root": {
+      "inputs": {
+        "flake-utils": "flake-utils",
+        "nixpkgs": "nixpkgs"
+      }
+    }
+  },
+  "root": "root",
+  "version": 7
+}
diff --git a/flake.nix b/flake.nix
new file mode 100644
index 0000000..1b932a4
--- /dev/null
+++ b/flake.nix
@@ -0,0 +1,19 @@
+{
+  description = "wsep";
+
+  inputs.flake-utils.url = "github:numtide/flake-utils";
+
+  outputs = { self, nixpkgs, flake-utils }:
+    flake-utils.lib.eachDefaultSystem
+      (system:
+        let pkgs = nixpkgs.legacyPackages.${system};
+        in {
+          devShells.default = pkgs.mkShell {
+            buildInputs = with pkgs; [
+              go
+              screen
+            ];
+          };
+        }
+      );
+}
diff --git a/go.mod b/go.mod
index 67ac33f..2d51221 100644
--- a/go.mod
+++ b/go.mod
@@ -4,7 +4,6 @@ go 1.14
 
 require (
 	cdr.dev/slog v1.3.0
-	github.com/armon/circbuf v0.0.0-20190214190532-5111143e8da2
 	github.com/creack/pty v1.1.11
 	github.com/google/go-cmp v0.4.0
 	github.com/google/uuid v1.3.0
diff --git a/go.sum b/go.sum
index 8703f19..c9a9a09 100644
--- a/go.sum
+++ b/go.sum
@@ -30,8 +30,6 @@ github.com/alecthomas/kong v0.2.1-0.20190708041108-0548c6b1afae/go.mod h1:+inYUS
 github.com/alecthomas/kong-hcl v0.1.8-0.20190615233001-b21fea9723c8/go.mod h1:MRgZdU3vrFd05IQ89AxUZ0aYdF39BYoNFa324SodPCA=
 github.com/alecthomas/repr v0.0.0-20180818092828-117648cd9897 h1:p9Sln00KOTlrYkxI1zYWl1QLnEqAqEARBEYa8FQnQcY=
 github.com/alecthomas/repr v0.0.0-20180818092828-117648cd9897/go.mod h1:xTS7Pm1pD1mvyM075QCDSRqH6qRLXylzS24ZTpRiSzQ=
-github.com/armon/circbuf v0.0.0-20190214190532-5111143e8da2 h1:7Ip0wMmLHLRJdrloDxZfhMm0xrLXZS8+COSu2bXmEQs=
-github.com/armon/circbuf v0.0.0-20190214190532-5111143e8da2/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o=
 github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
 github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
 github.com/creack/pty v1.1.11 h1:07n33Z8lZxZ2qwegKbObQohDhXDQxiMMz1NOUGYlesw=
diff --git a/internal/proto/clientmsg.go b/internal/proto/clientmsg.go
index ae05ad5..19c9add 100644
--- a/internal/proto/clientmsg.go
+++ b/internal/proto/clientmsg.go
@@ -28,6 +28,8 @@ type Command struct {
 	Args    []string `json:"args"`
 	Stdin   bool     `json:"stdin"`
 	TTY     bool     `json:"tty"`
+	Rows    uint16   `json:"rows"`
+	Cols    uint16   `json:"cols"`
 	UID     uint32   `json:"uid"`
 	GID     uint32   `json:"gid"`
 	Env     []string `json:"env"`
diff --git a/internal/proto/servermsg.go b/internal/proto/servermsg.go
index 61614fa..2132a26 100644
--- a/internal/proto/servermsg.go
+++ b/internal/proto/servermsg.go
@@ -18,4 +18,5 @@ type ServerPidHeader struct {
 type ServerExitCodeHeader struct {
 	Type     string `json:"type"`
 	ExitCode int    `json:"exit_code"`
+	Error    string `json:"error"`
 }
diff --git a/localexec.go b/localexec.go
index bdb7237..31a1f2c 100644
--- a/localexec.go
+++ b/localexec.go
@@ -3,6 +3,7 @@ package wsep
 import (
 	"io"
 	"os/exec"
+	"syscall"
 
 	"golang.org/x/xerrors"
 )
@@ -29,14 +30,15 @@ func (l *localProcess) Wait() error {
 	err := l.cmd.Wait()
 	if exitErr, ok := err.(*exec.ExitError); ok {
 		return ExitError{
-			Code: exitErr.ExitCode(),
+			code:  exitErr.ExitCode(),
+			error: exitErr.Error(),
 		}
 	}
 	return err
 }
 
 func (l *localProcess) Close() error {
-	return l.cmd.Process.Kill()
+	return l.cmd.Process.Signal(syscall.SIGTERM)
 }
 
 func (l *localProcess) Pid() int {
diff --git a/localexec_test.go b/localexec_test.go
index 8a454c4..af60586 100644
--- a/localexec_test.go
+++ b/localexec_test.go
@@ -12,6 +12,7 @@ import (
 	"time"
 
 	"cdr.dev/slog/sloggers/slogtest/assert"
+	"golang.org/x/sync/errgroup"
 )
 
 func TestLocalExec(t *testing.T) {
@@ -72,7 +73,8 @@ func TestExitCode(t *testing.T) {
 	err = process.Wait()
 	exitErr, ok := err.(ExitError)
 	assert.True(t, "error is ExitError", ok)
-	assert.Equal(t, "exit error", exitErr.Code, 127)
+	assert.Equal(t, "exit error code", exitErr.ExitCode(), 127)
+	assert.Equal(t, "exit error", exitErr.Error(), "exit status 127")
 }
 
 func TestStdin(t *testing.T) {
@@ -139,8 +141,18 @@ func TestStdoutVsStderr(t *testing.T) {
 	})
 	assert.Success(t, "start command", err)
 
-	go io.Copy(&stdout, process.Stdout())
-	go io.Copy(&stderr, process.Stderr())
+	var outputgroup errgroup.Group
+	outputgroup.Go(func() error {
+		_, err := io.Copy(&stdout, process.Stdout())
+		return err
+	})
+	outputgroup.Go(func() error {
+		_, err := io.Copy(&stderr, process.Stderr())
+		return err
+	})
+
+	err = outputgroup.Wait()
+	assert.Success(t, "wait for output to drain", err)
 
 	err = process.Wait()
 	assert.Success(t, "wait for process to complete", err)
diff --git a/localexec_unix.go b/localexec_unix.go
index a6d322c..d118cd4 100644
--- a/localexec_unix.go
+++ b/localexec_unix.go
@@ -61,7 +61,10 @@ func (l LocalExecer) Start(ctx context.Context, c Command) (Process, error) {
 	if c.TTY {
 		// This special WSEP_TTY variable helps debug unexpected TTYs.
 		process.cmd.Env = append(process.cmd.Env, "WSEP_TTY=true")
-		process.tty, err = pty.Start(process.cmd)
+		process.tty, err = pty.StartWithSize(process.cmd, &pty.Winsize{
+			Rows: c.Rows,
+			Cols: c.Cols,
+		})
 		if err != nil {
 			return nil, xerrors.Errorf("start command with pty: %w", err)
 		}
diff --git a/server.go b/server.go
index c9f4045..1274c72 100644
--- a/server.go
+++ b/server.go
@@ -7,12 +7,10 @@ import (
 	"errors"
 	"io"
 	"net"
+	"os/exec"
 	"sync"
 	"time"
 
-	"github.com/armon/circbuf"
-	"github.com/google/uuid"
-
 	"go.coder.com/flog"
 	"golang.org/x/sync/errgroup"
 	"golang.org/x/xerrors"
@@ -21,25 +19,81 @@ import (
 	"cdr.dev/wsep/internal/proto"
 )
 
-var reconnectingProcesses sync.Map
+const (
+	defaultRows = 24
+	defaultCols = 80
+)
 
 // Options allows configuring the server.
 type Options struct {
-	ReconnectingProcessTimeout time.Duration
+	SessionTimeout time.Duration
 }
 
+// _sessions is a global map of sessions that exists for backwards
+// compatibility. Server should be used instead; it maintains its own
+// map.
+var _sessions sync.Map
+
+// _sessionsMutex is a global mutex that exists for backwards compatibility.
+// Server should be used instead; it maintains its own mutex.
+var _sessionsMutex sync.Mutex
+
 // Serve runs the server-side of wsep.
-// The execer may be another wsep connection for chaining.
-// Use LocalExecer for local command execution.
+// Deprecated: Use Server.Serve() instead.
 func Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Options) error {
+	srv := Server{sessions: &_sessions, sessionsMutex: &_sessionsMutex}
+	return srv.Serve(ctx, c, execer, options)
+}
+
+// Server runs the server-side of wsep. The execer may be another wsep
+// connection for chaining. Use LocalExecer for local command execution.
+type Server struct {
+	sessions      *sync.Map
+	sessionsMutex *sync.Mutex
+}
+
+// NewServer returns a new wsep server.
+func NewServer() *Server {
+	return &Server{
+		sessions:      &sync.Map{},
+		sessionsMutex: &sync.Mutex{},
+	}
+}
+
+// SessionCount returns the number of sessions.
+func (srv *Server) SessionCount() int {
+	var i int
+	srv.sessions.Range(func(k, rawSession interface{}) bool {
+		i++
+		return true
+	})
+	return i
+}
+
+// Close closes all sessions.
+func (srv *Server) Close() {
+	srv.sessions.Range(func(k, rawSession interface{}) bool {
+		if s, ok := rawSession.(*Session); ok {
+			s.Close("test cleanup")
+		}
+		return true
+	})
+}
+
+// Serve runs the server-side of wsep. The execer may be another wsep
+// connection for chaining. Use LocalExecer for local command execution. The
+// websocket will not be closed automatically; the caller must call Close() on
+// the websocket (ideally with a reason) once Serve yields.
+func (srv *Server) Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Options) error {
+	// The process will get killed when the connection context ends.
 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()
 
 	if options == nil {
 		options = &Options{}
 	}
-	if options.ReconnectingProcessTimeout == 0 {
-		options.ReconnectingProcessTimeout = 5 * time.Minute
+	if options.SessionTimeout == 0 {
+		options.SessionTimeout = 5 * time.Minute
 	}
 
 	c.SetReadLimit(maxMessageSize)
@@ -48,6 +102,7 @@ func Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Optio
 		process   Process
 		wsNetConn = websocket.NetConn(ctx, c, websocket.MessageBinary)
 	)
+
 	for {
 		if err := ctx.Err(); err != nil {
 			return err
@@ -66,8 +121,8 @@ func Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Optio
 			}
 			return nil
 		}
-		headerByt, bodyByt := proto.SplitMessage(byt)
 
+		headerByt, bodyByt := proto.SplitMessage(byt)
 		err = json.Unmarshal(headerByt, &header)
 		if err != nil {
 			return xerrors.Errorf("unmarshal header: %w", err)
@@ -87,171 +142,49 @@ func Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Optio
 
 			command := mapToClientCmd(header.Command)
 
-			// Only allow TTYs with IDs to be reconnected.
-			if command.TTY && header.ID != "" {
-				// Enforce a consistent format for IDs.
-				_, err := uuid.Parse(header.ID)
-				if err != nil {
-					flog.Error("%s is not a valid uuid: %w", header.ID, err)
+			if command.TTY {
+				// If rows and cols are not provided, default to 24 rows by 80 cols.
+				if command.Rows == 0 {
+					flog.Info("rows not provided, defaulting to 24")
+					command.Rows = defaultRows
 				}
-
-				// Get an existing process or create a new one.
-				var rprocess *reconnectingProcess
-				rawRProcess, ok := reconnectingProcesses.Load(header.ID)
-				if ok {
-					rprocess, ok = rawRProcess.(*reconnectingProcess)
-					if !ok {
-						flog.Error("found invalid type in reconnecting process map for ID %s", header.ID)
-					}
-					process = rprocess.process
-				} else {
-					// The process will be kept alive as long as this context does not
-					// finish (and as long as the process does not exit on its own). This
-					// is a new context since the parent context finishes when the request
-					// ends which would kill the process prematurely.
-					ctx, cancel := context.WithCancel(context.Background())
-
-					// The process will be killed if the provided context ends.
-					process, err = execer.Start(ctx, command)
-					if err != nil {
-						cancel()
-						return err
-					}
-
-					// Default to buffer 64KB.
-					ringBuffer, err := circbuf.NewBuffer(64 * 1024)
-					if err != nil {
-						cancel()
-						return xerrors.Errorf("unable to create ring buffer %w", err)
-					}
-
-					rprocess = &reconnectingProcess{
-						activeConns: make(map[string]net.Conn),
-						process:     process,
-						// Timeouts created with AfterFunc can be reset.
-						timeout:    time.AfterFunc(options.ReconnectingProcessTimeout, cancel),
-						ringBuffer: ringBuffer,
-					}
-					reconnectingProcesses.Store(header.ID, rprocess)
-
-					// If the process exits send the exit code to all listening
-					// connections then close everything.
-					go func() {
-						err = process.Wait()
-						code := 0
-						if exitErr, ok := err.(ExitError); ok {
-							code = exitErr.Code
-						}
-						rprocess.activeConnsMutex.Lock()
-						for _, conn := range rprocess.activeConns {
-							_ = sendExitCode(ctx, code, conn)
-						}
-						rprocess.activeConnsMutex.Unlock()
-						rprocess.Close()
-						reconnectingProcesses.Delete(header.ID)
-					}()
-
-					// Write to the ring buffer and all connections as we receive stdout.
-					go func() {
-						buffer := make([]byte, 32*1024)
-						for {
-							read, err := rprocess.process.Stdout().Read(buffer)
-							if err != nil {
-								// When the process is closed this is triggered.
-								break
-							}
-							part := buffer[:read]
-							_, err = rprocess.ringBuffer.Write(part)
-							if err != nil {
-								flog.Error("reconnecting process %s write buffer: %v", header.ID, err)
-								cancel()
-								break
-							}
-							rprocess.activeConnsMutex.Lock()
-							for _, conn := range rprocess.activeConns {
-								_ = sendOutput(ctx, part, conn)
-							}
-							rprocess.activeConnsMutex.Unlock()
-						}
-					}()
-				}
-
-				err = sendPID(ctx, process.Pid(), wsNetConn)
-				if err != nil {
-					flog.Error("failed to send pid %d", process.Pid())
-				}
-
-				// Write out the initial contents in the ring buffer.
-				err = sendOutput(ctx, rprocess.ringBuffer.Bytes(), wsNetConn)
-				if err != nil {
-					return xerrors.Errorf("write reconnecting process %s buffer: %w", header.ID, err)
+				if command.Cols == 0 {
+					flog.Info("cols not provided, defaulting to 80")
+					command.Cols = defaultCols
 				}
+			}
 
-				// Store this connection on the reconnecting process. All connections
-				// stored on the process will receive the process's stdout.
-				connectionID := uuid.NewString()
-				rprocess.activeConnsMutex.Lock()
-				rprocess.activeConns[connectionID] = wsNetConn
-				rprocess.activeConnsMutex.Unlock()
-
-				// Keep resetting the inactivity timer while this connection is alive.
-				rprocess.timeout.Reset(options.ReconnectingProcessTimeout)
-				heartbeat := time.NewTicker(options.ReconnectingProcessTimeout / 2)
-				defer heartbeat.Stop()
-				go func() {
-					for {
-						select {
-						// Stop looping once this request finishes.
-						case <-ctx.Done():
-							return
-						case <-heartbeat.C:
-						}
-						rprocess.timeout.Reset(options.ReconnectingProcessTimeout)
-					}
-				}()
-
-				// Remove this connection from the process's connection list once the
-				// connection ends so data is no longer sent to it.
-				defer func() {
-					wsNetConn.Close() // REVIEW@asher: Not sure if necessary.
-					rprocess.activeConnsMutex.Lock()
-					delete(rprocess.activeConns, connectionID)
-					rprocess.activeConnsMutex.Unlock()
-				}()
+			// Only TTYs with IDs can be reconnected.
+			if command.TTY && header.ID != "" {
+				process, err = srv.withSession(ctx, header.ID, command, execer, options)
 			} else {
-				process, err = execer.Start(ctx, command)
-				if err != nil {
-					return err
-				}
-
-				err = sendPID(ctx, process.Pid(), wsNetConn)
-				if err != nil {
-					flog.Error("failed to send pid %d", process.Pid())
-				}
+				process, err = execer.Start(ctx, *command)
+			}
+			if err != nil {
+				return err
+			}
 
-				var outputgroup errgroup.Group
-				outputgroup.Go(func() error {
-					return copyWithHeader(process.Stdout(), wsNetConn, proto.Header{Type: proto.TypeStdout})
-				})
-				outputgroup.Go(func() error {
-					return copyWithHeader(process.Stderr(), wsNetConn, proto.Header{Type: proto.TypeStderr})
-				})
-
-				go func() {
-					defer wsNetConn.Close()
-					_ = outputgroup.Wait()
-					err = process.Wait()
-					if exitErr, ok := err.(ExitError); ok {
-						_ = sendExitCode(ctx, exitErr.Code, wsNetConn)
-						return
-					}
-					_ = sendExitCode(ctx, 0, wsNetConn)
-				}()
-
-				defer func() {
-					process.Close()
-				}()
+			err = sendPID(ctx, process.Pid(), wsNetConn)
+			if err != nil {
+				return xerrors.Errorf("failed to send pid %d: %w", process.Pid(), err)
 			}
+
+			var outputgroup errgroup.Group
+			outputgroup.Go(func() error {
+				return copyWithHeader(process.Stdout(), wsNetConn, proto.Header{Type: proto.TypeStdout})
+			})
+			outputgroup.Go(func() error {
+				return copyWithHeader(process.Stderr(), wsNetConn, proto.Header{Type: proto.TypeStderr})
+			})
+
+			go func() {
+				// Wait for the readers to close, which happens when the connection
+				// closes or the process dies.
+				_ = outputgroup.Wait()
+				err := process.Wait()
+				_ = sendExitCode(ctx, err, wsNetConn)
+			}()
+
 		case proto.TypeResize:
 			if process == nil {
 				return errors.New("resize sent before command started")
@@ -283,10 +216,62 @@ func Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Optio
 	}
 }
 
-func sendExitCode(_ context.Context, exitCode int, conn net.Conn) error {
+// withSession runs the command in a session if screen is available.
+func (srv *Server) withSession(ctx context.Context, id string, command *Command, execer Execer, options *Options) (Process, error) {
+	// If screen is not installed, spawn the command normally.
+	_, err := exec.LookPath("screen")
+	if err != nil {
+		flog.Info("`screen` could not be found; session %s will not persist", id)
+		return execer.Start(ctx, *command)
+	}
+
+	var s *Session
+	srv.sessionsMutex.Lock()
+	if rawSession, ok := srv.sessions.Load(id); ok {
+		if s, ok = rawSession.(*Session); !ok {
+			srv.sessionsMutex.Unlock()
+			return nil, xerrors.Errorf("found invalid type in session map for ID %s", id)
+		}
+	}
+
+	// It is possible that the session has closed but the goroutine that waits
+	// for that state and deletes it from the map has not run yet, meaning the
+	// session is still in the map and we grabbed a closing session. Wait for any
+	// pending state changes and if it is closed create a new session instead.
+	if s != nil {
+		state, _ := s.WaitForState(StateReady)
+		if state > StateReady {
+			s = nil
+		}
+	}
+
+	if s == nil {
+		s = NewSession(command, execer, options)
+		srv.sessions.Store(id, s)
+		go func() { // Remove the session from the map once it closes.
+			defer srv.sessions.Delete(id)
+			s.Wait()
+		}()
+	}
+
+	srv.sessionsMutex.Unlock()
+
+	return s.Attach(ctx)
+}
+
+func sendExitCode(_ context.Context, err error, conn net.Conn) error {
+	exitCode := 0
+	errorStr := ""
+	if err != nil {
+		errorStr = err.Error()
+	}
+	if exitErr, ok := err.(ExitError); ok {
+		exitCode = exitErr.ExitCode()
+	}
 	header, err := json.Marshal(proto.ServerExitCodeHeader{
 		Type:     proto.TypeExitCode,
 		ExitCode: exitCode,
+		Error:    errorStr,
 	})
 	if err != nil {
 		return err
@@ -304,15 +288,6 @@ func sendPID(_ context.Context, pid int, conn net.Conn) error {
 	return err
 }
 
-func sendOutput(_ context.Context, data []byte, conn net.Conn) error {
-	header, err := json.Marshal(proto.ServerPidHeader{Type: proto.TypeStdout})
-	if err != nil {
-		return err
-	}
-	_, err = proto.WithHeader(conn, header).Write(data)
-	return err
-}
-
 func copyWithHeader(r io.Reader, w io.Writer, header proto.Header) error {
 	headerByt, err := json.Marshal(header)
 	if err != nil {
@@ -325,24 +300,3 @@ func copyWithHeader(r io.Reader, w io.Writer, header proto.Header) error {
 	}
 	return nil
 }
-
-type reconnectingProcess struct {
-	activeConnsMutex sync.Mutex
-	activeConns      map[string]net.Conn
-
-	ringBuffer *circbuf.Buffer
-	timeout    *time.Timer
-	process    Process
-}
-
-// Close ends all connections to the reconnecting process and clears the ring
-// buffer.
-func (r *reconnectingProcess) Close() {
-	r.activeConnsMutex.Lock()
-	defer r.activeConnsMutex.Unlock()
-	for _, conn := range r.activeConns {
-		_ = conn.Close()
-	}
-	_ = r.process.Close()
-	r.ringBuffer.Reset()
-}
diff --git a/session.go b/session.go
new file mode 100644
index 0000000..ec5c0bd
--- /dev/null
+++ b/session.go
@@ -0,0 +1,382 @@
+package wsep
+
+import (
+	"bufio"
+	"context"
+	"os"
+	"path/filepath"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/google/uuid"
+	"go.coder.com/flog"
+	"golang.org/x/xerrors"
+)
+
+// State represents the current state of the session. States are sequential and
+// will only move forward.
+type State int
+
+const (
+	// StateStarting is the default/start state.
+	StateStarting = iota
+	// StateReady means the session is ready to be attached.
+	StateReady
+	// StateClosing means the session has begun closing. The underlying process
+	// may still be exiting.
+	StateClosing
+	// StateDone means the session has completely shut down and the process has
+	// exited.
+	StateDone
+)
+
+// Session represents a `screen` session.
+type Session struct {
+	// command is the original command used to spawn the session.
+	command *Command
+	// cond broadcasts session changes and any accompanying errors.
+	cond *sync.Cond
+	// configFile is the location of the screen configuration file.
+	configFile string
+	// error holds any error that occurred during a state change. It is not safe
+	// to access outside of cond.L.
+	error error
+	// execer is used to spawn the session and ready commands.
+	execer Execer
+	// id holds the id of the session for both creating and attaching. This is
+	// generated uniquely for each session (rather than using the ID provided by
+	// the client) because without control of the daemon we do not have its PID
+	// and without the PID screen will do partial matching. Enforcing a UUID
+	// should guarantee we match on the right session.
+	id string
+	// mutex prevents concurrent attaches to the session. This is necessary since
+	// screen will happily spawn two separate sessions with the same name if
+	// multiple attaches happen in a close enough interval. We are not able to
+	// control the daemon ourselves to prevent this because the daemon will spawn
+	// with a hardcoded 24x80 size which results in confusing padding above the
+	// prompt once the attach comes in and resizes.
+	mutex sync.Mutex
+	// options holds options for configuring the session.
+	options *Options
+	// socketsDir is the location of the directory where screen should put its
+	// sockets.
+	socketsDir string
+	// state holds the current session state. It is not safe to access this
+	// outside of cond.L.
+	state State
+	// timer will close the session when it expires. The timer will be reset as
+	// long as there are active connections.
+	timer *time.Timer
+}
+
+const attachTimeout = 30 * time.Second
+
+// NewSession sets up a new session. Any errors with starting are returned on
+// Attach(). The session will close itself if nothing is attached for the
+// duration of the session timeout.
+func NewSession(command *Command, execer Execer, options *Options) *Session {
+	tempdir := filepath.Join(os.TempDir(), "coder-screen")
+	s := &Session{
+		command:    command,
+		cond:       sync.NewCond(&sync.Mutex{}),
+		configFile: filepath.Join(tempdir, "config"),
+		execer:     execer,
+		id:         uuid.NewString(),
+		options:    options,
+		state:      StateStarting,
+		socketsDir: filepath.Join(tempdir, "sockets"),
+	}
+	go s.lifecycle()
+	return s
+}
+
+// lifecycle manages the lifecycle of the session.
+func (s *Session) lifecycle() {
+	err := s.ensureSettings()
+	if err != nil {
+		s.setState(StateDone, xerrors.Errorf("ensure settings: %w", err))
+		return
+	}
+
+	// The initial timeout for starting up is set here and will probably be far
+	// shorter than the session timeout in most cases. It should be at least long
+	// enough for the first screen attach to be able to start up the daemon.
+	s.timer = time.AfterFunc(attachTimeout, func() {
+		s.Close("session timeout")
+	})
+
+	s.setState(StateReady, nil)
+
+	// Handle the close event by asking screen to quit the session. We have no
+	// way of knowing when the daemon process dies so the Go side will not get
+	// cleaned up until the timeout if the process gets killed externally (for
+	// example via `exit`).
+	s.WaitForState(StateClosing)
+	s.timer.Stop()
+	// If the command errors because the session is already gone, that is fine.
+	err = s.sendCommand(context.Background(), "quit", []string{"No screen session found"})
+	if err != nil {
+		flog.Error("failed to kill session %s: %v", s.id, err)
+	} else {
+		err = xerrors.New("session is done")
+	}
+	s.setState(StateDone, err)
+}
+
+// sendCommand runs a screen command against a session. If the command fails
+// with an error matching anything in successErrors it will be considered a
+// success state (for example "no session" when quitting). The command will be
+// retried until successful, the timeout is reached, or the context ends (in
+// which case the context error is returned).
+func (s *Session) sendCommand(ctx context.Context, command string, successErrors []string) error {
+	ctx, cancel := context.WithTimeout(ctx, attachTimeout)
+	defer cancel()
+	run := func() (bool, error) {
+		process, err := s.execer.Start(ctx, Command{
+			Command: "screen",
+			Args:    []string{"-S", s.id, "-X", command},
+			UID:     s.command.UID,
+			GID:     s.command.GID,
+			Env:     append(s.command.Env, "SCREENDIR="+s.socketsDir),
+		})
+		if err != nil {
+			return true, err
+		}
+		stdout := captureStdout(process)
+		err = process.Wait()
+		// Try the context error in case it was canceled while we waited.
+		if ctx.Err() != nil {
+			return true, ctx.Err()
+		}
+		if err != nil {
+			details := <-stdout
+			for _, se := range successErrors {
+				if strings.Contains(details, se) {
+					return true, nil
+				}
+			}
+		}
+		// Sometimes a command will fail without any error output whatsoever but
+		// will succeed later so all we can do is keep trying.
+		return err == nil, nil
+	}
+
+	// Run immediately.
+	if done, err := run(); done {
+		return err
+	}
+
+	// Then run on a timer.
+	ticker := time.NewTicker(250 * time.Millisecond)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-ticker.C:
+			if done, err := run(); done {
+				return err
+			}
+		}
+	}
+}
+
+// Attach attaches to the session, waits for the attach to complete, then
+// returns the attached process.
+func (s *Session) Attach(ctx context.Context) (Process, error) {
+	// We need to do this while holding the mutex to ensure another attach does
+	// not come in and spawn a duplicate session.
+	s.mutex.Lock()
+	defer s.mutex.Unlock()
+
+	state, err := s.WaitForState(StateReady)
+	switch state {
+	case StateClosing:
+		return nil, err
+	case StateDone:
+		return nil, err
+	}
+
+	// Abort the heartbeat when the session closes.
+	ctx, cancel := context.WithCancel(ctx)
+	go func() {
+		defer cancel()
+		s.waitForStateOrContext(ctx, StateClosing)
+	}()
+
+	go s.heartbeat(ctx)
+
+	// -S is for setting the session's name.
+	// -x allows attaching to an already attached session.
+	// -RR reattaches to the daemon or creates the session daemon if missing.
+	// -q disables the "New screen..." message that appears for five seconds when
+	//    creating a new session with -RR.
+	// -c is the flag for the config file.
+	process, err := s.execer.Start(ctx, Command{
+		Command:    "screen",
+		Args:       append([]string{"-S", s.id, "-xRRqc", s.configFile, s.command.Command}, s.command.Args...),
+		TTY:        s.command.TTY,
+		Rows:       s.command.Rows,
+		Cols:       s.command.Cols,
+		Stdin:      s.command.Stdin,
+		UID:        s.command.UID,
+		GID:        s.command.GID,
+		Env:        append(s.command.Env, "SCREENDIR="+s.socketsDir),
+		WorkingDir: s.command.WorkingDir,
+	})
+	if err != nil {
+		cancel()
+		return nil, err
+	}
+
+	// Version seems to be the only command without a side effect so use it to
+	// wait for the session to come up.
+	err = s.sendCommand(ctx, "version", nil)
+	if err != nil {
+		cancel()
+		return nil, err
+	}
+
+	return process, err
+}
+
+// heartbeat keeps the session alive while the provided context is not done.
+func (s *Session) heartbeat(ctx context.Context) {
+	// We just connected so reset the timer now in case it is near the end.
+	s.timer.Reset(s.options.SessionTimeout)
+
+	// Reset when the connection closes to ensure the session stays up for the
+	// full timeout.
+	defer s.timer.Reset(s.options.SessionTimeout)
+
+	heartbeat := time.NewTicker(s.options.SessionTimeout / 2)
+	defer heartbeat.Stop()
+
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-heartbeat.C:
+		}
+		// The goroutine that cancels the heartbeat on a close state change might
+		// not run before the next heartbeat, which means the heartbeat would
+		// start the timer again, so check the state before resetting.
+		state, _ := s.WaitForState(StateReady)
+		if state > StateReady {
+			return
+		}
+		s.timer.Reset(s.options.SessionTimeout)
+	}
+}
+
+// Wait waits for the session to close. The underlying process might still be
+// exiting.
+func (s *Session) Wait() {
+	s.WaitForState(StateClosing)
+}
+
+// Close attempts to gracefully kill the session's underlying process, then
+// waits for the process to exit. If the session does not exit in a timely
+// manner it forcefully kills the process.
+func (s *Session) Close(reason string) {
+	s.setState(StateClosing, xerrors.Errorf("session is closing: %s", reason))
+	s.WaitForState(StateDone)
+}
+
+// ensureSettings writes config settings and creates the socket directory.
+func (s *Session) ensureSettings() error {
+	settings := []string{
+		// Tell screen not to handle motion for xterm* terminals which allows
+		// scrolling the terminal via the mouse wheel or scroll bar (by default
+		// screen uses it to cycle through the command history). There does not
+		// seem to be a way to make screen itself scroll on mouse wheel. tmux can
+		// do it but then there is no scroll bar and it kicks you into copy mode
+		// where keys stop working until you exit copy mode which seems like it
+		// could be confusing.
+		"termcapinfo xterm* ti@:te@",
+		// Enable alternate screen emulation, otherwise applications get rendered
+		// in the current window, which wipes out visible output and results in
+		// missing output when scrolling back with the mouse wheel (copy mode
+		// still works since that is screen itself scrolling).
+		"altscreen on",
+		// Remap the control key to C-s since C-a may be used in applications. C-s
+		// cannot actually be used anyway since by default it will pause and C-q to
+		// resume will just kill the browser window. We may not want people using
+		// the control key anyway since it will not be obvious they are in screen
+		// and doing things like switching windows makes mouse wheel scroll wonky
+		// due to the terminal doing the scrolling rather than screen itself (but
+		// again copy mode will work just fine).
+		"escape ^Ss",
+	}
+
+	dir := filepath.Join(os.TempDir(), "coder-screen")
+	config := filepath.Join(dir, "config")
+	socketdir := filepath.Join(dir, "sockets")
+
+	err := os.MkdirAll(socketdir, 0o700)
+	if err != nil {
+		return err
+	}
+
+	return os.WriteFile(config, []byte(strings.Join(settings, "\n")), 0o644)
+}
+
+// setState sets and broadcasts the provided state if it is greater than the
+// current state, along with the error if one has not already been set.
+func (s *Session) setState(state State, err error) {
+	s.cond.L.Lock()
+	defer s.cond.L.Unlock()
+	// Cannot regress states (for example trying to close after the process is
+	// done should leave us in the done state and not the closing state).
+	if state <= s.state {
+		return
+	}
+	// Keep the first error we get.
+	if s.error == nil {
+		s.error = err
+	}
+	s.state = state
+	s.cond.Broadcast()
+}
+
+// WaitForState blocks until the given state or a greater one is reached.
+func (s *Session) WaitForState(state State) (State, error) {
+	s.cond.L.Lock()
+	defer s.cond.L.Unlock()
+	for state > s.state {
+		s.cond.Wait()
+	}
+	return s.state, s.error
+}
+
+// waitForStateOrContext blocks until the state or a greater one is reached or
+// the provided context ends. If the context ends all goroutines will be woken.
+func (s *Session) waitForStateOrContext(ctx context.Context, state State) {
+	go func() {
+		// Wake up when the context ends.
+		defer s.cond.Broadcast()
+		<-ctx.Done()
+	}()
+	s.cond.L.Lock()
+	defer s.cond.L.Unlock()
+	for ctx.Err() == nil && state > s.state {
+		s.cond.Wait()
+	}
+}
+
+// captureStdout captures the first line of stdout. Screen emits errors to
+// stdout so this allows logging extra context beyond the exit code.
+func captureStdout(process Process) <-chan string {
+	stdout := make(chan string, 1)
+	go func() {
+		scanner := bufio.NewScanner(process.Stdout())
+		if scanner.Scan() {
+			stdout <- scanner.Text()
+		} else {
+			stdout <- "no further details"
+		}
+	}()
+	return stdout
+}
diff --git a/tty_test.go b/tty_test.go
index d9651bf..684f280 100644
--- a/tty_test.go
+++ b/tty_test.go
@@ -3,7 +3,9 @@ package wsep
 import (
 	"bufio"
 	"context"
-	"io/ioutil"
+	"fmt"
+	"math/rand"
+	"regexp"
 	"strings"
 	"sync"
 	"testing"
@@ -11,180 +13,263 @@ import (
 	"cdr.dev/slog/sloggers/slogtest/assert"
 	"github.com/google/uuid"
-	"nhooyr.io/websocket"
 )
 
 func TestTTY(t *testing.T) {
 	t.Parallel()
 
-	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
-	defer cancel()
+	// Run some output in a new session.
+	server := newServer(t)
+	ctx, command := newSession(t)
+	command.ID = "" // No ID so we do not start a reconnectable session.
+	process1, _ := connect(ctx, t, command, server, nil, "")
+	expected := writeUnique(t, process1)
+	assert.True(t, "find initial output", checkStdout(t, process1, expected, []string{}))
+
+	// Connect to the same session. There should not be shared output since
+	// these end up being separate sessions due to the lack of an ID.
+	process2, _ := connect(ctx, t, command, server, nil, "")
+	unexpected := expected
+	expected = writeUnique(t, process2)
+	assert.True(t, "find new session output", checkStdout(t, process2, expected, unexpected))
+}
 
-	ws, server := mockConn(ctx, t, nil)
-	defer ws.Close(websocket.StatusInternalError, "")
-	defer server.Close()
+func TestReconnectTTY(t *testing.T) {
+	t.Run("NoSize", func(t *testing.T) {
+		server := newServer(t)
+		ctx, command := newSession(t)
+		command.Rows = 0
+		command.Cols = 0
+		ps1, _ := connect(ctx, t, command, server, nil, "")
+		expected := writeUnique(t, ps1)
+		assert.True(t, "find initial output", checkStdout(t, ps1, expected, []string{}))
+
+		ps2, _ := connect(ctx, t, command, server, nil, "")
+		assert.True(t, "find reconnected output", checkStdout(t, ps2, expected, []string{}))
+	})
 
-	execer := RemoteExecer(ws)
-	testTTY(ctx, t, execer)
-}
+	t.Run("DeprecatedServe", func(t *testing.T) {
+		// Do something in the first session.
+		ctx, command := newSession(t)
+		process1, _ := connect(ctx, t, command, nil, nil, "")
+		expected := writeUnique(t, process1)
+		assert.True(t, "find initial output", checkStdout(t, process1, expected, []string{}))
 
-func testTTY(ctx context.Context, t *testing.T, e Execer) {
-	process, err := e.Start(ctx, Command{
-		Command: "sh",
-		TTY:     true,
-		Stdin:   true,
-	})
-	assert.Success(t, "start sh", err)
-
-	var wg sync.WaitGroup
-	wg.Add(1)
-	go func() {
-		defer wg.Done()
-
-		stdout, err := ioutil.ReadAll(process.Stdout())
-		assert.Success(t, "read stdout", err)
-
-		t.Logf("bash tty stdout = %s", stdout)
-		prompt := string(stdout)
-		assert.True(t, `bash "$" or "#" prompt found`,
-			strings.HasSuffix(prompt, "$ ") || strings.HasSuffix(prompt, "# "))
-	}()
-	wg.Add(1)
-	go func() {
-		defer wg.Done()
-
-		stderr, err := ioutil.ReadAll(process.Stderr())
-		assert.Success(t, "read stderr", err)
-		t.Logf("bash tty stderr = %s", stderr)
-		assert.True(t, "stderr is empty", len(stderr) == 0)
-	}()
-	time.Sleep(3 * time.Second)
-
-	process.Close()
-	wg.Wait()
-}
-
-func TestReconnectTTY(t *testing.T) {
-	t.Parallel()
-
-	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
-	defer cancel()
+		// Connect to the same session. Should see the same output.
+		process2, _ := connect(ctx, t, command, nil, nil, "")
+		assert.True(t, "find reconnected output", checkStdout(t, process2, expected, []string{}))
+	})
+
+	t.Run("NoScreen", func(t *testing.T) {
+		t.Setenv("PATH", "/bin")
+
+		// Run some output in a new session.
+		server := newServer(t)
+		ctx, command := newSession(t)
+		process1, _ := connect(ctx, t, command, server, nil, "")
+		expected := writeUnique(t, process1)
+		assert.True(t, "find initial output", checkStdout(t, process1, expected, []string{}))
+
+		// Connect to the same session. There should not be shared output since
+		// these end up being separate sessions due to the lack of screen.
+		process2, _ := connect(ctx, t, command, server, nil, "")
+		unexpected := expected
+		expected = writeUnique(t, process2)
+		assert.True(t, "find new session output", checkStdout(t, process2, expected, unexpected))
+	})
+
+	t.Run("Regular", func(t *testing.T) {
+		t.Parallel()
+
+		// Run some output in a new session.
+		server := newServer(t)
+		ctx, command := newSession(t)
+		process1, disconnect1 := connect(ctx, t, command, server, nil, "")
+		expected := writeUnique(t, process1)
+		assert.True(t, "find initial output", checkStdout(t, process1, expected, []string{}))
+
+		// Reconnect and sleep; the inactivity timeout should not trigger since we
+		// were not disconnected during the timeout.
+		disconnect1()
+		process2, disconnect2 := connect(ctx, t, command, server, nil, "")
+		time.Sleep(time.Second)
+		expected = append(expected, writeUnique(t, process2)...)
+		assert.True(t, "find reconnected output", checkStdout(t, process2, expected, []string{}))
+
+		// Make a simultaneously active connection.
+		process3, disconnect3 := connect(ctx, t, command, server, &Options{
+			// Divide the time to test that the heartbeat keeps it open through
+			// multiple intervals.
+			SessionTimeout: time.Second / 4,
+		}, "")
+
+		// Disconnect the previous connection and wait for inactivity. The session
+		// should stay up because of the second connection.
+		disconnect2()
+		time.Sleep(time.Second)
+		expected = append(expected, writeUnique(t, process3)...)
+		assert.True(t, "find second connection output", checkStdout(t, process3, expected, []string{}))
+
+		// Disconnect the last connection and wait for inactivity. The next
+		// connection should start a new session so we should only see new output
+		// and not any output from the old session.
+		disconnect3()
+		time.Sleep(time.Second)
+		process4, _ := connect(ctx, t, command, server, nil, "")
+		unexpected := expected
+		expected = writeUnique(t, process4)
+		assert.True(t, "find new session output", checkStdout(t, process4, expected, unexpected))
+	})
+
+	t.Run("Alternate", func(t *testing.T) {
+		t.Parallel()
+
+		// Run an application that enters the alternate screen.
+		server := newServer(t)
+		ctx, command := newSession(t)
+		process1, disconnect1 := connect(ctx, t, command, server, nil, "")
+		write(t, process1, "./ci/alt.sh")
+		assert.True(t, "find alt screen", checkStdout(t, process1, []string{"./ci/alt.sh", "ALT SCREEN"}, []string{}))
+
+		// Reconnect; the application should redraw. We should see only the
+		// application output and not the command that spawned the application.
+		disconnect1()
+		process2, disconnect2 := connect(ctx, t, command, server, nil, "")
+		assert.True(t, "find reconnected alt screen", checkStdout(t, process2, []string{"ALT SCREEN"}, []string{"./ci/alt.sh"}))
+
+		// Exit the application and reconnect. Should now be in a regular shell.
+		write(t, process2, "q")
+		disconnect2()
+		process3, _ := connect(ctx, t, command, server, nil, "")
+		expected := writeUnique(t, process3)
+		assert.True(t, "find shell output", checkStdout(t, process3, expected, []string{}))
+	})
 
-	ws1, server1 := mockConn(ctx, t, &Options{
-		ReconnectingProcessTimeout: time.Second,
+	t.Run("Simultaneous", func(t *testing.T) {
+		t.Parallel()
+
+		server := newServer(t)
+
+		// Try connecting a bunch of sessions at once.
+		var wg sync.WaitGroup
+		for i := 0; i < 10; i++ {
+			wg.Add(1)
+			go func() {
+				defer wg.Done()
+				ctx, command := newSession(t)
+				process1, disconnect1 := connect(ctx, t, command, server, nil, "")
+				expected := writeUnique(t, process1)
+				assert.True(t, "find initial output", checkStdout(t, process1, expected, []string{}))
+
+				n := rand.Intn(1000)
+				time.Sleep(time.Duration(n) * time.Millisecond)
+				disconnect1()
+				process2, _ := connect(ctx, t, command, server, nil, "")
+				expected = append(expected, writeUnique(t, process2)...)
+				assert.True(t, "find reconnected output", checkStdout(t, process2, expected, []string{}))
+			}()
+		}
+		wg.Wait()
 	})
-	defer server1.Close()
+}
+
+// newServer returns a new wsep server.
+func newServer(t *testing.T) *Server {
+	server := NewServer()
+	t.Cleanup(func() {
+		server.Close()
+		assert.Equal(t, "no leaked sessions", 0, server.SessionCount())
+	})
+	return server
+}
+
+// newSession returns a command for starting/attaching to a session with a
+// context for timing out.
+func newSession(t *testing.T) (context.Context, Command) {
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	t.Cleanup(cancel)
 
 	command := Command{
 		ID:      uuid.NewString(),
 		Command: "sh",
 		TTY:     true,
 		Stdin:   true,
+		Cols:    defaultCols,
+		Rows:    defaultRows,
+		Env:     []string{"TERM=xterm"},
 	}
 
-	execer1 := RemoteExecer(ws1)
-	process1, err := execer1.Start(ctx, command)
-	assert.Success(t, "start sh", err)
-
-	// Write some unique output.
-	echoCmd := "echo test:$((1+1))"
-	data := []byte(echoCmd + "\r\n")
-	_, err = process1.Stdin().Write(data)
-	assert.Success(t, "write to stdin", err)
-	expected := []string{echoCmd, "test:2"}
-
-	assert.True(t, "find echo", findEcho(t, process1, expected))
-
-	// Test disconnecting then reconnecting.
-	process1.Close()
-	server1.Close()
+	return ctx, command
+}
 
-	ws2, server2 := mockConn(ctx, t, &Options{
-		ReconnectingProcessTimeout: time.Second,
+// connect connects to a wsep server and runs the provided command.
+func connect(ctx context.Context, t *testing.T, command Command, wsepServer *Server, options *Options, wantErr string) (Process, func()) {
+	if options == nil {
+		options = &Options{SessionTimeout: time.Second}
+	}
+	ws, server := mockConn(ctx, t, wsepServer, options)
+	t.Cleanup(func() {
+		server.Close()
 	})
-	defer server2.Close()
-
-	execer2 := RemoteExecer(ws2)
-	process2, err := execer2.Start(ctx, command)
-	assert.Success(t, "attach sh", err)
-
-	// The inactivity timeout should not have been triggered.
-	time.Sleep(time.Second)
-
-	echoCmd = "echo test:$((2+2))"
-	data = []byte(echoCmd + "\r\n")
-	_, err = process2.Stdin().Write(data)
-	assert.Success(t, "write to stdin", err)
-	expected = append(expected, echoCmd, "test:4")
-	assert.True(t, "find echo", findEcho(t, process2, expected))
-
-	// Test disconnecting while another connection is active.
-	ws3, server3 := mockConn(ctx, t, &Options{
-		// Divide the time to test that the heartbeat keeps it open through multiple
-		// intervals.
-		ReconnectingProcessTimeout: time.Second / 4,
-	})
-	defer server3.Close()
+	process, err := RemoteExecer(ws).Start(ctx, command)
+	if wantErr != "" {
+		assert.True(t, fmt.Sprintf("%s contains %s", err.Error(), wantErr), strings.Contains(err.Error(), wantErr))
+	} else {
+		assert.Success(t, "start sh", err)
+	}
 
-	execer3 := RemoteExecer(ws3)
-	process3, err := execer3.Start(ctx, command)
-	assert.Success(t, "attach sh", err)
+	return process, func() {
+		process.Close()
+		server.Close()
+	}
+}
 
-	process2.Close()
-	server2.Close()
-	time.Sleep(time.Second)
+// writeUnique writes some unique output to the shell process and returns the
+// expected output.
+func writeUnique(t *testing.T, process Process) []string {
+	n := rand.Intn(1000000)
+	echoCmd := fmt.Sprintf("echo test:$((%d+%d))", n, n)
+	write(t, process, echoCmd)
+	return []string{echoCmd, fmt.Sprintf("test:%d", n+n)}
+}
 
-	// This connection should still be up.
-	echoCmd = "echo test:$((3+3))"
-	data = []byte(echoCmd + "\r\n")
-	_, err = process3.Stdin().Write(data)
+// write writes the provided input followed by a newline to the shell process.
+func write(t *testing.T, process Process, input string) {
+	_, err := process.Stdin().Write([]byte(input + "\n"))
 	assert.Success(t, "write to stdin", err)
-	expected = append(expected, echoCmd, "test:6")
-
-	assert.True(t, "find echo", findEcho(t, process3, expected))
-
-	// Close the remaining connection and wait for inactivity.
-	process3.Close()
-	server3.Close()
-	time.Sleep(time.Second)
+}
 
-	// The next connection should start a new process.
-	ws4, server4 := mockConn(ctx, t, &Options{
-		ReconnectingProcessTimeout: time.Second,
-	})
-	defer server4.Close()
+const ansi = "[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))"
 
-	execer4 := RemoteExecer(ws4)
-	process4, err := execer4.Start(ctx, command)
-	assert.Success(t, "attach sh", err)
+var re = regexp.MustCompile(ansi)
 
-	// This time no echo since it is a new process.
-	notExpected := expected
-	echoCmd = "echo done"
-	data = []byte(echoCmd + "\r\n")
-	_, err = process4.Stdin().Write(data)
-	assert.Success(t, "write to stdin", err)
-	expected = []string{echoCmd, "done"}
-
-	assert.True(t, "find echo", findEcho(t, process4, expected, notExpected...))
-	assert.Success(t, "context", ctx.Err())
-}
-
-func findEcho(t *testing.T, process Process, expected []string, notExpected ...string) bool {
+// checkStdout ensures that expected appears in stdout in the given order.
+// Returns false if anything in unexpected appears along the way, and true once
+// everything in expected has been found (or false at EOF).
+func checkStdout(t *testing.T, process Process, expected, unexpected []string) bool {
+	t.Helper()
+	i := 0
+	t.Logf("expected: %s unexpected: %s", expected, unexpected)
 	scanner := bufio.NewScanner(process.Stdout())
-outer:
-	for _, str := range expected {
-		for scanner.Scan() {
-			line := scanner.Text()
-			t.Logf("bash tty stdout = %s", line)
-			for _, bad := range notExpected {
-				if strings.Contains(line, bad) {
-					return false
-				}
-			}
+	for scanner.Scan() {
+		line := scanner.Text()
+		t.Logf("bash tty stdout = %s", re.ReplaceAllString(line, ""))
+		for _, str := range unexpected {
 			if strings.Contains(line, str) {
-				continue outer
+				t.Logf("contains unexpected line %s", line)
+				return false
 			}
 		}
-		return false // Reached the end of output without finding str.
+		if strings.Contains(line, expected[i]) {
+			t.Logf("contains expected line %s", line)
+			i++
+		}
+		if i == len(expected) {
+			t.Logf("got all expected values from stdout")
+			return true
+		}
 	}
-	return true
+	t.Logf("reached end of stdout without seeing all expected values")
+	return false
 }