From a6cb39dfcadbca80c556cefcb04a62d250cac326 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Tue, 8 Nov 2022 21:41:21 +0000 Subject: [PATCH 1/4] Handle expected close errors Signed-off-by: Spike Curtis --- client.go | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index 4630438..f395fcf 100644 --- a/client.go +++ b/client.go @@ -217,9 +217,16 @@ func (r *remoteProcess) listen(ctx context.Context) { r.closeErr = r.conn.Close(websocket.StatusNormalClosure, "normal closure") // If we were in r.conn.Read() we cancel the ctx, the websocket library closes - // the websocket before we have a chance to. This is a normal closure. - if r.closeErr != nil && strings.Contains(r.closeErr.Error(), "already wrote close") && - r.readErr != nil && strings.Contains(r.readErr.Error(), "context canceled") { + // the websocket before we have a chance to. Unfortunately there is a race in the + // the websocket library, where sometimes close frame has already been written before + // we even call r.conn.Close(), and sometimes it gets written during our call to + // r.conn.Close(), so we need to handle both those cases in examining the error that comes + // back. This is a normal closure, so report nil for the error. + readCtxCanceled := r.readErr != nil && strings.Contains(r.readErr.Error(), "context canceled") + alreadyClosed := r.closeErr != nil && + (strings.Contains(r.closeErr.Error(), "already wrote close") || + strings.Contains(r.closeErr.Error(), "WebSocket closed")) + if alreadyClosed && readCtxCanceled { r.closeErr = nil } close(r.done) From da995d89bd47d580f97205862a33ed44b12a9996 Mon Sep 17 00:00:00 2001 From: Asher Date: Tue, 22 Nov 2022 10:33:17 -0600 Subject: [PATCH 2/4] Implement reconnections with screen (#25) * Refactor reconnect test to support sub-tests Going to add an alternate screen test next. * Revert "Add reconnecting ptys (#23)" This partially reverts commit 91201718149964a06c6350679d8cda0161edf887. The new method using screen will not share processes which is a fundamental shift so I think it will be easier to start from scratch. Even though we could keep the UUID check I removed it because it seems cool that you could create your own sessions in the terminal then connect to them in the browser (or vice-versa). * Add test for alternate screen The output test waits for EOF; modify that behavior so we can check that certain strings are not displayed *without* waiting for the timeout. This means to be accurate we should always check for output that should exist after the output that should not exist would have shown up. * Add timeout flag to dev client This makes it easier to test reconnects manually. * Add size to initial connection This way you do not need a subsequent resize and we can have the right size from the get-go. * Prevent prompt from rendering twice in tests * Add Nix flake * Propagate process close error * Implement reconnecting TTY with screen * Encapsulate session logic * Localize session map * Consolidate test scaffolding into helpers I think this helps make the tests a bit more concise. * Test many connections at once * Fix errors not propagating through web socket close Since the server closed the socket the caller has no chance to close with the right code and reason. Also abnormal closure is not a valid close code. * Fix test flake in reading output Without waiting for the copy you can sometimes get "file already closed". I guess process.Wait must have some side effect. --- README.md | 6 +- browser/client.ts | 12 +- ci/alt.sh | 36 ++++ ci/fmt.sh | 3 +- ci/image/Dockerfile | 2 + ci/lint.sh | 2 +- client.go | 23 +-- client_test.go | 69 +++++-- dev/client/main.go | 36 +++- dev/server/main.go | 5 +- exec.go | 24 ++- flake.lock | 41 ++++ flake.nix | 19 ++ go.mod | 1 - go.sum | 2 - internal/proto/clientmsg.go | 2 + internal/proto/servermsg.go | 1 + localexec.go | 6 +- localexec_test.go | 18 +- localexec_unix.go | 5 +- server.go | 331 +++++++++++++------------------ session.go | 385 ++++++++++++++++++++++++++++++++++++ tty_test.go | 355 ++++++++++++++++++++------------- 23 files changed, 985 insertions(+), 399 deletions(-) create mode 100755 ci/alt.sh create mode 100644 flake.lock create mode 100644 flake.nix create mode 100644 session.go diff --git a/README.md b/README.md index d9606c0..709ced0 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ Error handling is omitted for brevity. ```golang conn, _, _ := websocket.Dial(ctx, "ws://remote.exec.addr", nil) -defer conn.Close(websocket.StatusAbnormalClosure, "terminate process") +defer conn.Close(websocket.StatusNormalClosure, "normal closure") execer := wsep.RemoteExecer(conn) process, _ := execer.Start(ctx, wsep.Command{ @@ -25,7 +25,6 @@ go io.Copy(os.Stderr, process.Stderr()) go io.Copy(os.Stdout, process.Stdout()) process.Wait() -conn.Close(websocket.StatusNormalClosure, "normal closure") ``` ### Server @@ -33,10 +32,9 @@ conn.Close(websocket.StatusNormalClosure, "normal closure") ```golang func (s server) ServeHTTP(w http.ResponseWriter, r *http.Request) { conn, _ := websocket.Accept(w, r, nil) + defer conn.Close(websocket.StatusNormalClosure, "normal closure") wsep.Serve(r.Context(), conn, wsep.LocalExecer{}) - - ws.Close(websocket.StatusNormalClosure, "normal closure") } ``` diff --git a/browser/client.ts b/browser/client.ts index 009912e..0a5322e 100644 --- a/browser/client.ts +++ b/browser/client.ts @@ -14,7 +14,7 @@ export interface Command { } export type ClientHeader = - | { type: 'start'; command: Command } + | { type: 'start'; id: string; command: Command; cols: number; rows: number; } | { type: 'stdin' } | { type: 'close_stdin' } | { type: 'resize'; cols: number; rows: number }; @@ -42,8 +42,14 @@ export const closeStdin = (ws: WebSocket) => { ws.send(msg.buffer); }; -export const startCommand = (ws: WebSocket, command: Command) => { - const msg = joinMessage({ type: 'start', command: command }); +export const startCommand = ( + ws: WebSocket, + command: Command, + id: string, + rows: number, + cols: number +) => { + const msg = joinMessage({ type: 'start', command, id, rows, cols }); ws.send(msg.buffer); }; diff --git a/ci/alt.sh b/ci/alt.sh new file mode 100755 index 0000000..b92d60e --- /dev/null +++ b/ci/alt.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash +# Script for testing the alt screen. + +# Enter alt screen. +tput smcup + +function display() { + # Clear the screen. + tput clear + # Move cursor to the top left. + tput cup 0 0 + # Display content. + echo "ALT SCREEN" +} + +function redraw() { + display + echo "redrawn" +} + +# Re-display on resize. +trap 'redraw' WINCH + +display + +# The trap will not run while waiting for a command so read input in a loop with +# a timeout. +while true ; do + if read -n 1 -t .1 ; then + # Clear the screen. + tput clear + # Exit alt screen. + tput rmcup + exit + fi +done diff --git a/ci/fmt.sh b/ci/fmt.sh index c3a0c16..a9522dd 100755 --- a/ci/fmt.sh +++ b/ci/fmt.sh @@ -1,4 +1,5 @@ -#!/bin/bash +#!/usr/bin/env bash + echo "Formatting..." go mod tidy diff --git a/ci/image/Dockerfile b/ci/image/Dockerfile index 5b5c0fb..e08d509 100644 --- a/ci/image/Dockerfile +++ b/ci/image/Dockerfile @@ -3,6 +3,8 @@ FROM golang:1 ENV GOFLAGS="-mod=readonly" ENV CI=true +RUN apt update && apt install -y screen + RUN go install golang.org/x/tools/cmd/goimports@latest RUN go install golang.org/x/lint/golint@latest RUN go install github.com/mattn/goveralls@latest diff --git a/ci/lint.sh b/ci/lint.sh index 9a9554c..1e06b61 100755 --- a/ci/lint.sh +++ b/ci/lint.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash echo "Linting..." diff --git a/client.go b/client.go index f395fcf..2df93e7 100644 --- a/client.go +++ b/client.go @@ -28,10 +28,13 @@ func RemoteExecer(conn *websocket.Conn) Execer { // Command represents an external command to be run type Command struct { // ID allows reconnecting commands that have a TTY. - ID string - Command string - Args []string + ID string + Command string + Args []string + // Commands with a TTY also require Rows and Cols. TTY bool + Rows uint16 + Cols uint16 Stdin bool UID uint32 GID uint32 @@ -103,7 +106,7 @@ type remoteProcess struct { pid int done chan struct{} closeErr error - exitCode *int + exitMsg *proto.ServerExitCodeHeader readErr error stdin io.WriteCloser stdout pipe @@ -267,8 +270,7 @@ func (r *remoteProcess) listen(ctx context.Context) { r.readErr = err return } - - r.exitCode = &exitMsg.ExitCode + r.exitMsg = &exitMsg return } } @@ -319,11 +321,10 @@ func (r *remoteProcess) Wait() error { if r.readErr != nil { return r.readErr } - // when listen() closes r.done, either there must be a read error - // or exitCode is set non-nil, so it's safe to dereference the pointer - // here - if *r.exitCode != 0 { - return ExitError{Code: *r.exitCode} + // when listen() closes r.done, either there must be a read error or exitMsg + // is set non-nil, so it's safe to access members here. + if r.exitMsg.ExitCode != 0 { + return ExitError{code: r.exitMsg.ExitCode, error: r.exitMsg.Error} } return nil } diff --git a/client_test.go b/client_test.go index f665057..96c10b0 100644 --- a/client_test.go +++ b/client_test.go @@ -51,17 +51,25 @@ func TestRemoteStdin(t *testing.T) { } } -func mockConn(ctx context.Context, t *testing.T, options *Options) (*websocket.Conn, *httptest.Server) { +func mockConn(ctx context.Context, t *testing.T, wsepServer *Server, options *Options) (*websocket.Conn, *httptest.Server) { mockServerHandler := func(w http.ResponseWriter, r *http.Request) { ws, err := websocket.Accept(w, r, nil) if err != nil { w.WriteHeader(http.StatusInternalServerError) return } - err = Serve(r.Context(), ws, LocalExecer{}, options) + if wsepServer != nil { + err = wsepServer.Serve(r.Context(), ws, LocalExecer{}, options) + } else { + err = Serve(r.Context(), ws, LocalExecer{}, options) + } if err != nil { - t.Errorf("failed to serve execer: %v", err) - ws.Close(websocket.StatusAbnormalClosure, "failed to serve execer") + // Max reason string length is 123. + errStr := err.Error() + if len(errStr) > 123 { + errStr = errStr[:123] + } + ws.Close(websocket.StatusInternalError, errStr) return } ws.Close(websocket.StatusNormalClosure, "normal closure") @@ -79,7 +87,11 @@ func TestRemoteExec(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() - ws, server := mockConn(ctx, t, nil) + wsepServer := NewServer() + defer wsepServer.Close() + defer assert.Equal(t, "no leaked sessions", 0, wsepServer.SessionCount()) + + ws, server := mockConn(ctx, t, wsepServer, nil) defer server.Close() execer := RemoteExecer(ws) @@ -92,14 +104,20 @@ func TestRemoteClose(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() - ws, server := mockConn(ctx, t, nil) + wsepServer := NewServer() + defer wsepServer.Close() + defer assert.Equal(t, "no leaked sessions", 0, wsepServer.SessionCount()) + + ws, server := mockConn(ctx, t, wsepServer, nil) defer server.Close() execer := RemoteExecer(ws) cmd := Command{ - Command: "/bin/bash", + Command: "sh", TTY: true, Stdin: true, + Cols: 100, + Rows: 100, Env: []string{"TERM=linux"}, } @@ -138,14 +156,20 @@ func TestRemoteCloseNoData(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() - ws, server := mockConn(ctx, t, nil) + wsepServer := NewServer() + defer wsepServer.Close() + defer assert.Equal(t, "no leaked sessions", 0, wsepServer.SessionCount()) + + ws, server := mockConn(ctx, t, wsepServer, nil) defer server.Close() execer := RemoteExecer(ws) cmd := Command{ - Command: "/bin/bash", + Command: "sh", TTY: true, Stdin: true, + Cols: 100, + Rows: 100, Env: []string{"TERM=linux"}, } @@ -171,14 +195,20 @@ func TestRemoteClosePartialRead(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() - ws, server := mockConn(ctx, t, nil) + wsepServer := NewServer() + defer wsepServer.Close() + defer assert.Equal(t, "no leaked sessions", 0, wsepServer.SessionCount()) + + ws, server := mockConn(ctx, t, wsepServer, nil) defer server.Close() execer := RemoteExecer(ws) cmd := Command{ - Command: "/bin/bash", + Command: "sh", TTY: true, Stdin: true, + Cols: 100, + Rows: 100, Env: []string{"TERM=linux"}, } @@ -205,7 +235,11 @@ func TestRemoteExecFail(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() - ws, server := mockConn(ctx, t, nil) + wsepServer := NewServer() + defer wsepServer.Close() + defer assert.Equal(t, "no leaked sessions", 0, wsepServer.SessionCount()) + + ws, server := mockConn(ctx, t, wsepServer, nil) defer server.Close() execer := RemoteExecer(ws) @@ -223,9 +257,10 @@ func testExecerFail(ctx context.Context, t *testing.T, execer Execer) { go io.Copy(ioutil.Discard, process.Stdout()) err = process.Wait() - code, ok := err.(ExitError) + exitErr, ok := err.(ExitError) assert.True(t, "is exit error", ok) - assert.True(t, "exit code is nonzero", code.Code != 0) + assert.True(t, "exit code is nonzero", exitErr.ExitCode() != 0) + assert.Equal(t, "exit error", exitErr.Error(), "exit status 2") assert.Error(t, "wait for process to error", err) } @@ -239,7 +274,11 @@ func TestStderrVsStdout(t *testing.T) { stderr bytes.Buffer ) - ws, server := mockConn(ctx, t, nil) + wsepServer := NewServer() + defer wsepServer.Close() + defer assert.Equal(t, "no leaked sessions", 0, wsepServer.SessionCount()) + + ws, server := mockConn(ctx, t, wsepServer, nil) defer server.Close() execer := RemoteExecer(ws) diff --git a/dev/client/main.go b/dev/client/main.go index 0c28fd4..41f1815 100644 --- a/dev/client/main.go +++ b/dev/client/main.go @@ -9,6 +9,7 @@ import ( "os" "os/signal" "syscall" + "time" "cdr.dev/wsep" "github.com/spf13/pflag" @@ -20,41 +21,48 @@ import ( ) type notty struct { + timeout time.Duration } func (c *notty) Run(fl *pflag.FlagSet) { - do(fl, false, "") + do(fl, false, "", c.timeout) } func (c *notty) Spec() cli.CommandSpec { return cli.CommandSpec{ Name: "notty", - Usage: "[flags]", + Usage: "[flags] ", Desc: `Run a command without tty enabled.`, } } +func (c *notty) RegisterFlags(fl *pflag.FlagSet) { + fl.DurationVar(&c.timeout, "timeout", 0, "disconnect after specified timeout") +} + type tty struct { - id string + id string + timeout time.Duration } func (c *tty) Run(fl *pflag.FlagSet) { - do(fl, true, c.id) + do(fl, true, c.id, c.timeout) } func (c *tty) Spec() cli.CommandSpec { return cli.CommandSpec{ Name: "tty", - Usage: "[id] [flags]", + Usage: "[flags] ", Desc: `Run a command with tty enabled. Use the same ID to reconnect.`, } } func (c *tty) RegisterFlags(fl *pflag.FlagSet) { fl.StringVar(&c.id, "id", "", "sets id for reconnection") + fl.DurationVar(&c.timeout, "timeout", 0, "disconnect after the specified timeout") } -func do(fl *pflag.FlagSet, tty bool, id string) { +func do(fl *pflag.FlagSet, tty bool, id string, timeout time.Duration) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -73,12 +81,19 @@ func do(fl *pflag.FlagSet, tty bool, id string) { if len(fl.Args()) > 1 { args = fl.Args()[1:] } + width, height, err := term.GetSize(int(os.Stdin.Fd())) + if err != nil { + flog.Fatal("unable to get term size") + } process, err := executor.Start(ctx, wsep.Command{ ID: id, Command: fl.Arg(0), Args: args, TTY: tty, Stdin: true, + Rows: uint16(height), + Cols: uint16(width), + Env: []string{"TERM=" + os.Getenv("TERM")}, }) if err != nil { flog.Fatal("failed to start remote command: %v", err) @@ -112,6 +127,15 @@ func do(fl *pflag.FlagSet, tty bool, id string) { io.Copy(stdin, os.Stdin) }() + if timeout != 0 { + timer := time.NewTimer(timeout) + defer timer.Stop() + go func() { + <-timer.C + conn.Close(websocket.StatusNormalClosure, "normal closure") + }() + } + err = process.Wait() if err != nil { flog.Error("process failed: %v", err) diff --git a/dev/server/main.go b/dev/server/main.go index 66020d5..fcd94bf 100644 --- a/dev/server/main.go +++ b/dev/server/main.go @@ -2,6 +2,7 @@ package main import ( "net/http" + "time" "cdr.dev/wsep" "go.coder.com/flog" @@ -23,7 +24,9 @@ func serve(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) return } - err = wsep.Serve(r.Context(), ws, wsep.LocalExecer{}, nil) + err = wsep.Serve(r.Context(), ws, wsep.LocalExecer{}, &wsep.Options{ + SessionTimeout: 30 * time.Second, + }) if err != nil { flog.Error("failed to serve execer: %v", err) ws.Close(websocket.StatusAbnormalClosure, "failed to serve execer") diff --git a/exec.go b/exec.go index e515289..0ba6def 100644 --- a/exec.go +++ b/exec.go @@ -2,7 +2,6 @@ package wsep import ( "context" - "fmt" "io" "cdr.dev/wsep/internal/proto" @@ -10,11 +9,18 @@ import ( // ExitError is sent when the command terminates. type ExitError struct { - Code int + code int + error string } +// ExitCode returns the exit code of the process. +func (e ExitError) ExitCode() int { + return e.code +} + +// Error returns a string describing why the process errored. func (e ExitError) Error() string { - return fmt.Sprintf("process exited with code %v", e.Code) + return e.error } // Process represents a started command. @@ -32,8 +38,8 @@ type Process interface { Resize(ctx context.Context, rows, cols uint16) error // Wait returns ExitError when the command terminates with a non-zero exit code. Wait() error - // Close terminates the process and underlying connection(s). - // It must be called otherwise a connection or process may leak. + // Close sends a SIGTERM to the process. To force a shutdown cancel the + // context passed into the execer. Close() error } @@ -49,6 +55,8 @@ func mapToProtoCmd(c Command) proto.Command { Args: c.Args, Stdin: c.Stdin, TTY: c.TTY, + Rows: c.Rows, + Cols: c.Cols, UID: c.UID, GID: c.GID, Env: c.Env, @@ -56,12 +64,14 @@ func mapToProtoCmd(c Command) proto.Command { } } -func mapToClientCmd(c proto.Command) Command { - return Command{ +func mapToClientCmd(c proto.Command) *Command { + return &Command{ Command: c.Command, Args: c.Args, Stdin: c.Stdin, TTY: c.TTY, + Rows: c.Rows, + Cols: c.Cols, UID: c.UID, GID: c.GID, Env: c.Env, diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..6066b3b --- /dev/null +++ b/flake.lock @@ -0,0 +1,41 @@ +{ + "nodes": { + "flake-utils": { + "locked": { + "lastModified": 1659877975, + "narHash": "sha256-zllb8aq3YO3h8B/U0/J1WBgAL8EX5yWf5pMj3G0NAmc=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "c0e246b9b83f637f4681389ecabcb2681b4f3af0", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1663235518, + "narHash": "sha256-q8zLK6rK/CLXEguaPgm9yQJcY0VQtOBhAT9EV2UFK/A=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "2277e4c9010b0f27585eb0bed0a86d7cbc079354", + "type": "github" + }, + "original": { + "id": "nixpkgs", + "type": "indirect" + } + }, + "root": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..1b932a4 --- /dev/null +++ b/flake.nix @@ -0,0 +1,19 @@ +{ + description = "wsep"; + + inputs.flake-utils.url = "github:numtide/flake-utils"; + + outputs = { self, nixpkgs, flake-utils }: + flake-utils.lib.eachDefaultSystem + (system: + let pkgs = nixpkgs.legacyPackages.${system}; + in { + devShells.default = pkgs.mkShell { + buildInputs = with pkgs; [ + go + screen + ]; + }; + } + ); +} diff --git a/go.mod b/go.mod index 67ac33f..2d51221 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.14 require ( cdr.dev/slog v1.3.0 - github.com/armon/circbuf v0.0.0-20190214190532-5111143e8da2 github.com/creack/pty v1.1.11 github.com/google/go-cmp v0.4.0 github.com/google/uuid v1.3.0 diff --git a/go.sum b/go.sum index 8703f19..c9a9a09 100644 --- a/go.sum +++ b/go.sum @@ -30,8 +30,6 @@ github.com/alecthomas/kong v0.2.1-0.20190708041108-0548c6b1afae/go.mod h1:+inYUS github.com/alecthomas/kong-hcl v0.1.8-0.20190615233001-b21fea9723c8/go.mod h1:MRgZdU3vrFd05IQ89AxUZ0aYdF39BYoNFa324SodPCA= github.com/alecthomas/repr v0.0.0-20180818092828-117648cd9897 h1:p9Sln00KOTlrYkxI1zYWl1QLnEqAqEARBEYa8FQnQcY= github.com/alecthomas/repr v0.0.0-20180818092828-117648cd9897/go.mod h1:xTS7Pm1pD1mvyM075QCDSRqH6qRLXylzS24ZTpRiSzQ= -github.com/armon/circbuf v0.0.0-20190214190532-5111143e8da2 h1:7Ip0wMmLHLRJdrloDxZfhMm0xrLXZS8+COSu2bXmEQs= -github.com/armon/circbuf v0.0.0-20190214190532-5111143e8da2/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/creack/pty v1.1.11 h1:07n33Z8lZxZ2qwegKbObQohDhXDQxiMMz1NOUGYlesw= diff --git a/internal/proto/clientmsg.go b/internal/proto/clientmsg.go index ae05ad5..19c9add 100644 --- a/internal/proto/clientmsg.go +++ b/internal/proto/clientmsg.go @@ -28,6 +28,8 @@ type Command struct { Args []string `json:"args"` Stdin bool `json:"stdin"` TTY bool `json:"tty"` + Rows uint16 `json:"rows"` + Cols uint16 `json:"cols"` UID uint32 `json:"uid"` GID uint32 `json:"gid"` Env []string `json:"env"` diff --git a/internal/proto/servermsg.go b/internal/proto/servermsg.go index 61614fa..2132a26 100644 --- a/internal/proto/servermsg.go +++ b/internal/proto/servermsg.go @@ -18,4 +18,5 @@ type ServerPidHeader struct { type ServerExitCodeHeader struct { Type string `json:"type"` ExitCode int `json:"exit_code"` + Error string `json:"error"` } diff --git a/localexec.go b/localexec.go index bdb7237..31a1f2c 100644 --- a/localexec.go +++ b/localexec.go @@ -3,6 +3,7 @@ package wsep import ( "io" "os/exec" + "syscall" "golang.org/x/xerrors" ) @@ -29,14 +30,15 @@ func (l *localProcess) Wait() error { err := l.cmd.Wait() if exitErr, ok := err.(*exec.ExitError); ok { return ExitError{ - Code: exitErr.ExitCode(), + code: exitErr.ExitCode(), + error: exitErr.Error(), } } return err } func (l *localProcess) Close() error { - return l.cmd.Process.Kill() + return l.cmd.Process.Signal(syscall.SIGTERM) } func (l *localProcess) Pid() int { diff --git a/localexec_test.go b/localexec_test.go index 8a454c4..af60586 100644 --- a/localexec_test.go +++ b/localexec_test.go @@ -12,6 +12,7 @@ import ( "time" "cdr.dev/slog/sloggers/slogtest/assert" + "golang.org/x/sync/errgroup" ) func TestLocalExec(t *testing.T) { @@ -72,7 +73,8 @@ func TestExitCode(t *testing.T) { err = process.Wait() exitErr, ok := err.(ExitError) assert.True(t, "error is ExitError", ok) - assert.Equal(t, "exit error", exitErr.Code, 127) + assert.Equal(t, "exit error code", exitErr.ExitCode(), 127) + assert.Equal(t, "exit error", exitErr.Error(), "exit status 127") } func TestStdin(t *testing.T) { @@ -139,8 +141,18 @@ func TestStdoutVsStderr(t *testing.T) { }) assert.Success(t, "start command", err) - go io.Copy(&stdout, process.Stdout()) - go io.Copy(&stderr, process.Stderr()) + var outputgroup errgroup.Group + outputgroup.Go(func() error { + _, err := io.Copy(&stdout, process.Stdout()) + return err + }) + outputgroup.Go(func() error { + _, err := io.Copy(&stderr, process.Stderr()) + return err + }) + + err = outputgroup.Wait() + assert.Success(t, "wait for output to drain", err) err = process.Wait() assert.Success(t, "wait for process to complete", err) diff --git a/localexec_unix.go b/localexec_unix.go index a6d322c..d118cd4 100644 --- a/localexec_unix.go +++ b/localexec_unix.go @@ -61,7 +61,10 @@ func (l LocalExecer) Start(ctx context.Context, c Command) (Process, error) { if c.TTY { // This special WSEP_TTY variable helps debug unexpected TTYs. process.cmd.Env = append(process.cmd.Env, "WSEP_TTY=true") - process.tty, err = pty.Start(process.cmd) + process.tty, err = pty.StartWithSize(process.cmd, &pty.Winsize{ + Rows: c.Rows, + Cols: c.Cols, + }) if err != nil { return nil, xerrors.Errorf("start command with pty: %w", err) } diff --git a/server.go b/server.go index c9f4045..4575a62 100644 --- a/server.go +++ b/server.go @@ -7,12 +7,10 @@ import ( "errors" "io" "net" + "os/exec" "sync" "time" - "github.com/armon/circbuf" - "github.com/google/uuid" - "go.coder.com/flog" "golang.org/x/sync/errgroup" "golang.org/x/xerrors" @@ -21,25 +19,75 @@ import ( "cdr.dev/wsep/internal/proto" ) -var reconnectingProcesses sync.Map - // Options allows configuring the server. type Options struct { - ReconnectingProcessTimeout time.Duration + SessionTimeout time.Duration } +// _sessions is a global map of sessions that exists for backwards +// compatibility. Server should be used instead which locally maintains the +// map. +var _sessions sync.Map + +// _sessionsMutex is a global mutex that exists for backwards compatibility. +// Server should be used instead which locally maintains the mutex. +var _sessionsMutex sync.Mutex + // Serve runs the server-side of wsep. -// The execer may be another wsep connection for chaining. -// Use LocalExecer for local command execution. +// Deprecated: Use Server.Serve() instead. func Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Options) error { + srv := Server{sessions: &_sessions, sessionsMutex: &_sessionsMutex} + return srv.Serve(ctx, c, execer, options) +} + +// Server runs the server-side of wsep. The execer may be another wsep +// connection for chaining. Use LocalExecer for local command execution. +type Server struct { + sessions *sync.Map + sessionsMutex *sync.Mutex +} + +// NewServer returns as new wsep server. +func NewServer() *Server { + return &Server{ + sessions: &sync.Map{}, + sessionsMutex: &sync.Mutex{}, + } +} + +// SessionCount returns the number of sessions. +func (srv *Server) SessionCount() int { + var i int + srv.sessions.Range(func(k, rawSession interface{}) bool { + i++ + return true + }) + return i +} + +// Close closes all sessions. +func (srv *Server) Close() { + srv.sessions.Range(func(k, rawSession interface{}) bool { + if s, ok := rawSession.(*Session); ok { + s.Close() + } + return true + }) +} + +// Serve runs the server-side of wsep. The execer may be another wsep +// connection for chaining. Use LocalExecer for local command execution. The +// web socket will not be closed automatically; the caller must call Close() on +// the web socket (ideally with a reason) once Serve yields. +func (srv *Server) Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Options) error { ctx, cancel := context.WithCancel(ctx) defer cancel() if options == nil { options = &Options{} } - if options.ReconnectingProcessTimeout == 0 { - options.ReconnectingProcessTimeout = 5 * time.Minute + if options.SessionTimeout == 0 { + options.SessionTimeout = 5 * time.Minute } c.SetReadLimit(maxMessageSize) @@ -48,6 +96,7 @@ func Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Optio process Process wsNetConn = websocket.NetConn(ctx, c, websocket.MessageBinary) ) + for { if err := ctx.Err(); err != nil { return err @@ -66,8 +115,8 @@ func Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Optio } return nil } - headerByt, bodyByt := proto.SplitMessage(byt) + headerByt, bodyByt := proto.SplitMessage(byt) err = json.Unmarshal(headerByt, &header) if err != nil { return xerrors.Errorf("unmarshal header: %w", err) @@ -87,171 +136,48 @@ func Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Optio command := mapToClientCmd(header.Command) - // Only allow TTYs with IDs to be reconnected. - if command.TTY && header.ID != "" { - // Enforce a consistent format for IDs. - _, err := uuid.Parse(header.ID) - if err != nil { - flog.Error("%s is not a valid uuid: %w", header.ID, err) - } - - // Get an existing process or create a new one. - var rprocess *reconnectingProcess - rawRProcess, ok := reconnectingProcesses.Load(header.ID) - if ok { - rprocess, ok = rawRProcess.(*reconnectingProcess) - if !ok { - flog.Error("found invalid type in reconnecting process map for ID %s", header.ID) - } - process = rprocess.process - } else { - // The process will be kept alive as long as this context does not - // finish (and as long as the process does not exit on its own). This - // is a new context since the parent context finishes when the request - // ends which would kill the process prematurely. - ctx, cancel := context.WithCancel(context.Background()) - - // The process will be killed if the provided context ends. - process, err = execer.Start(ctx, command) - if err != nil { - cancel() - return err - } - - // Default to buffer 64KB. - ringBuffer, err := circbuf.NewBuffer(64 * 1024) - if err != nil { - cancel() - return xerrors.Errorf("unable to create ring buffer %w", err) - } - - rprocess = &reconnectingProcess{ - activeConns: make(map[string]net.Conn), - process: process, - // Timeouts created with AfterFunc can be reset. - timeout: time.AfterFunc(options.ReconnectingProcessTimeout, cancel), - ringBuffer: ringBuffer, - } - reconnectingProcesses.Store(header.ID, rprocess) - - // If the process exits send the exit code to all listening - // connections then close everything. - go func() { - err = process.Wait() - code := 0 - if exitErr, ok := err.(ExitError); ok { - code = exitErr.Code - } - rprocess.activeConnsMutex.Lock() - for _, conn := range rprocess.activeConns { - _ = sendExitCode(ctx, code, conn) - } - rprocess.activeConnsMutex.Unlock() - rprocess.Close() - reconnectingProcesses.Delete(header.ID) - }() - - // Write to the ring buffer and all connections as we receive stdout. - go func() { - buffer := make([]byte, 32*1024) - for { - read, err := rprocess.process.Stdout().Read(buffer) - if err != nil { - // When the process is closed this is triggered. - break - } - part := buffer[:read] - _, err = rprocess.ringBuffer.Write(part) - if err != nil { - flog.Error("reconnecting process %s write buffer: %v", header.ID, err) - cancel() - break - } - rprocess.activeConnsMutex.Lock() - for _, conn := range rprocess.activeConns { - _ = sendOutput(ctx, part, conn) - } - rprocess.activeConnsMutex.Unlock() - } - }() - } - - err = sendPID(ctx, process.Pid(), wsNetConn) - if err != nil { - flog.Error("failed to send pid %d", process.Pid()) - } - - // Write out the initial contents in the ring buffer. - err = sendOutput(ctx, rprocess.ringBuffer.Bytes(), wsNetConn) - if err != nil { - return xerrors.Errorf("write reconnecting process %s buffer: %w", header.ID, err) + if command.TTY { + // Enforce rows and columns so the TTY will be correctly sized. + if command.Rows == 0 || command.Cols == 0 { + return xerrors.Errorf("rows and cols must be non-zero") } + } - // Store this connection on the reconnecting process. All connections - // stored on the process will receive the process's stdout. - connectionID := uuid.NewString() - rprocess.activeConnsMutex.Lock() - rprocess.activeConns[connectionID] = wsNetConn - rprocess.activeConnsMutex.Unlock() - - // Keep resetting the inactivity timer while this connection is alive. - rprocess.timeout.Reset(options.ReconnectingProcessTimeout) - heartbeat := time.NewTicker(options.ReconnectingProcessTimeout / 2) - defer heartbeat.Stop() - go func() { - for { - select { - // Stop looping once this request finishes. - case <-ctx.Done(): - return - case <-heartbeat.C: - } - rprocess.timeout.Reset(options.ReconnectingProcessTimeout) - } - }() - - // Remove this connection from the process's connection list once the - // connection ends so data is no longer sent to it. - defer func() { - wsNetConn.Close() // REVIEW@asher: Not sure if necessary. - rprocess.activeConnsMutex.Lock() - delete(rprocess.activeConns, connectionID) - rprocess.activeConnsMutex.Unlock() - }() - } else { - process, err = execer.Start(ctx, command) + // Only TTYs with IDs can be reconnected. + if command.TTY && header.ID != "" { + command, err = srv.withSession(ctx, header.ID, command, execer, options) if err != nil { return err } + } - err = sendPID(ctx, process.Pid(), wsNetConn) - if err != nil { - flog.Error("failed to send pid %d", process.Pid()) - } + // The process will get killed when the connection context ends. + process, err = execer.Start(ctx, *command) + if err != nil { + return err + } - var outputgroup errgroup.Group - outputgroup.Go(func() error { - return copyWithHeader(process.Stdout(), wsNetConn, proto.Header{Type: proto.TypeStdout}) - }) - outputgroup.Go(func() error { - return copyWithHeader(process.Stderr(), wsNetConn, proto.Header{Type: proto.TypeStderr}) - }) - - go func() { - defer wsNetConn.Close() - _ = outputgroup.Wait() - err = process.Wait() - if exitErr, ok := err.(ExitError); ok { - _ = sendExitCode(ctx, exitErr.Code, wsNetConn) - return - } - _ = sendExitCode(ctx, 0, wsNetConn) - }() - - defer func() { - process.Close() - }() + err = sendPID(ctx, process.Pid(), wsNetConn) + if err != nil { + return xerrors.Errorf("failed to send pid %d: %w", process.Pid(), err) } + + var outputgroup errgroup.Group + outputgroup.Go(func() error { + return copyWithHeader(process.Stdout(), wsNetConn, proto.Header{Type: proto.TypeStdout}) + }) + outputgroup.Go(func() error { + return copyWithHeader(process.Stderr(), wsNetConn, proto.Header{Type: proto.TypeStderr}) + }) + + go func() { + // Wait for the readers to close which happens when the connection + // closes or the process dies. + _ = outputgroup.Wait() + err := process.Wait() + _ = sendExitCode(ctx, err, wsNetConn) + }() + case proto.TypeResize: if process == nil { return errors.New("resize sent before command started") @@ -283,10 +209,47 @@ func Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Optio } } -func sendExitCode(_ context.Context, exitCode int, conn net.Conn) error { +// withSession wraps the command in a session if screen is available. +func (srv *Server) withSession(ctx context.Context, id string, command *Command, execer Execer, options *Options) (*Command, error) { + // If screen is not installed spawn the command normally. + _, err := exec.LookPath("screen") + if err != nil { + flog.Info("`screen` could not be found; session %s will not persist", id) + return command, nil + } + + var s *Session + srv.sessionsMutex.Lock() + if rawSession, ok := srv.sessions.Load(id); ok { + if s, ok = rawSession.(*Session); !ok { + return nil, xerrors.Errorf("found invalid type in session map for ID %s", id) + } + } else { + s = NewSession(id, command, execer, options) + srv.sessions.Store(id, s) + go func() { // Remove the session from the map once it closes. + defer srv.sessions.Delete(id) + s.Wait() + }() + } + srv.sessionsMutex.Unlock() + + return s.Attach(ctx) +} + +func sendExitCode(_ context.Context, err error, conn net.Conn) error { + exitCode := 0 + errorStr := "" + if err != nil { + errorStr = err.Error() + } + if exitErr, ok := err.(ExitError); ok { + exitCode = exitErr.ExitCode() + } header, err := json.Marshal(proto.ServerExitCodeHeader{ Type: proto.TypeExitCode, ExitCode: exitCode, + Error: errorStr, }) if err != nil { return err @@ -304,15 +267,6 @@ func sendPID(_ context.Context, pid int, conn net.Conn) error { return err } -func sendOutput(_ context.Context, data []byte, conn net.Conn) error { - header, err := json.Marshal(proto.ServerPidHeader{Type: proto.TypeStdout}) - if err != nil { - return err - } - _, err = proto.WithHeader(conn, header).Write(data) - return err -} - func copyWithHeader(r io.Reader, w io.Writer, header proto.Header) error { headerByt, err := json.Marshal(header) if err != nil { @@ -325,24 +279,3 @@ func copyWithHeader(r io.Reader, w io.Writer, header proto.Header) error { } return nil } - -type reconnectingProcess struct { - activeConnsMutex sync.Mutex - activeConns map[string]net.Conn - - ringBuffer *circbuf.Buffer - timeout *time.Timer - process Process -} - -// Close ends all connections to the reconnecting process and clears the ring -// buffer. -func (r *reconnectingProcess) Close() { - r.activeConnsMutex.Lock() - defer r.activeConnsMutex.Unlock() - for _, conn := range r.activeConns { - _ = conn.Close() - } - _ = r.process.Close() - r.ringBuffer.Reset() -} diff --git a/session.go b/session.go new file mode 100644 index 0000000..b77cf8b --- /dev/null +++ b/session.go @@ -0,0 +1,385 @@ +package wsep + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "golang.org/x/xerrors" +) + +// State represents the current state of the session. States are sequential and +// will only move forward. +type State int + +const ( + // StateStarting is the default/start state. + StateStarting = iota + // StateReady means the session is ready to be attached. + StateReady + // StateClosing means the session has begun closing. The underlying process + // may still be exiting. + StateClosing + // StateDone means the session has completely shut down and the process has + // exited. + StateDone +) + +// Session represents a `screen` session. +type Session struct { + // command is the original command used to spawn the session. + command *Command + // cond broadcasts session changes. + cond *sync.Cond + // configFile is the location of the screen configuration file. + configFile string + // error hold any error that occurred during a state change. + error error + // execer is used to spawn the session and ready commands. + execer Execer + // id holds the id of the session for both creating and attaching. + id string + // options holds options for configuring the session. + options *Options + // socketsDir is the location of the directory where screen should put its + // sockets. + socketsDir string + // state holds the current session state. + state State + // timer will close the session when it expires. + timer *time.Timer +} + +// NewSession creates and immediately starts a new session. Any errors with +// starting are returned on Attach(). The session will close itself if nothing +// is attached for the duration of the session timeout. +func NewSession(id string, command *Command, execer Execer, options *Options) *Session { + tempdir := filepath.Join(os.TempDir(), "coder-screen") + s := &Session{ + command: command, + cond: sync.NewCond(&sync.Mutex{}), + configFile: filepath.Join(tempdir, "config"), + execer: execer, + options: options, + state: StateStarting, + socketsDir: filepath.Join(tempdir, "sockets"), + } + go s.lifecycle(id) + return s +} + +// lifecycle manages the lifecycle of the session. +func (s *Session) lifecycle(id string) { + // When this context is done the process is dead. + ctx, cancel := context.WithCancel(context.Background()) + + // Close the session down immediately if there was an error. + process, err := s.start(ctx, id) + if err != nil { + defer cancel() + s.setState(StateDone, xerrors.Errorf("process start: %w", err)) + return + } + + // Close the session after a timeout. Timeouts created with AfterFunc can be + // reset; we will reset it as long as there are active connections. The + // initial timeout for starting up is set here and will probably be less than + // the session timeout in most cases. It should be at least long enough for + // screen to be able to start up. + s.timer = time.AfterFunc(30*time.Second, s.Close) + + // Emit the done event when the process exits. + go func() { + defer cancel() + err := process.Wait() + if err != nil { + err = xerrors.Errorf("process exit: %w", err) + } + s.setState(StateDone, err) + }() + + // Handle the close event by killing the process. + go func() { + s.waitForState(StateClosing) + s.timer.Stop() + // process.Close() will send a SIGTERM allowing screen to clean up. If we + // kill it abruptly the socket will be left behind which is probably + // harmless but it causes screen -list to show a bunch of dead sessions. + // The user would likely not see this since we have a custom socket + // directory but it seems ideal to let screen clean up anyway. + process.Close() + select { + case <-ctx.Done(): + return // Process exited on its own. + case <-time.After(5 * time.Second): + // Still running; cancel the context to forcefully terminate the process. + cancel() + } + }() + + // Wait until the session is ready to receive attaches. + err = s.waitReady(ctx) + if err != nil { + defer cancel() + s.setState(StateClosing, xerrors.Errorf("session wait: %w", err)) + return + } + + // Once the session is ready external callers have until their provided + // timeout to attach something before the session will close itself. + s.timer.Reset(s.options.SessionTimeout) + s.setState(StateReady, nil) +} + +// start starts the session. +func (s *Session) start(ctx context.Context, id string) (Process, error) { + err := s.ensureSettings() + if err != nil { + return nil, err + } + + // -S is for setting the session's name. + // -Dm causes screen to launch a server tied to this process, letting us + // attach to and kill it with the PID of this process (rather than having + // to do something flaky like run `screen -S id -quit`). + // -c is the flag for the config file. + process, err := s.execer.Start(ctx, Command{ + Command: "screen", + Args: append([]string{"-S", id, "-Dmc", s.configFile, s.command.Command}, s.command.Args...), + UID: s.command.UID, + GID: s.command.GID, + Env: append(s.command.Env, "SCREENDIR="+s.socketsDir), + WorkingDir: s.command.WorkingDir, + }) + if err != nil { + return nil, err + } + + // Screen allows targeting sessions via either the session name or + // .. Using the latter form allows us to differentiate + // between sessions with the same name. For example if a session closes due + // to a timeout and a client reconnects around the same time there can be two + // sessions with the same name while the old session is cleaning up. This is + // only a problem while attaching and not creating since the creation command + // used always creates a new session. + s.id = fmt.Sprintf("%d.%s", process.Pid(), id) + return process, nil +} + +// waitReady waits for the session to be ready by running a command against the +// session until it works since sometimes if you attach too quickly after +// spawning screen it will say the session does not exist. If the provided +// context finishes the wait is aborted and the context's error is returned. +func (s *Session) waitReady(ctx context.Context) error { + check := func() (bool, error) { + // The `version` command seems to be the only command without a side effect + // so use it to check whether the session is up. + process, err := s.execer.Start(ctx, Command{ + Command: "screen", + Args: []string{"-S", s.id, "-X", "version"}, + UID: s.command.UID, + GID: s.command.GID, + Env: append(s.command.Env, "SCREENDIR="+s.socketsDir), + }) + if err != nil { + return true, err + } + err = process.Wait() + // Try the context error in case it canceled while we waited. + if ctx.Err() != nil { + return true, ctx.Err() + } + // Session is ready once we executed without an error. + // TODO: Should we specifically check for "no screen to be attached + // matching"? That might be the only error we actually want to retry and + // otherwise we return immediately. But this text may not be stable between + // screen versions or there could be additional errors that can be retried. + return err == nil, nil + } + + // Check immediately. + if done, err := check(); done { + return err + } + + // Then check on a timer. + ticker := time.NewTicker(250 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + if done, err := check(); done { + return err + } + } + } +} + +// Attach waits for the session to become attachable and returns a command that +// can be used to attach to the session. +func (s *Session) Attach(ctx context.Context) (*Command, error) { + state, err := s.waitForState(StateReady) + switch state { + case StateClosing: + if err == nil { + // No error means s.Close() was called, either by external code or via the + // session timeout. + err = xerrors.Errorf("session is closing") + } + return nil, err + case StateDone: + if err == nil { + // No error means the process exited with zero, probably after being + // killed due to a call to s.Close() either externally or via the session + // timeout. + err = xerrors.Errorf("session is done") + } + return nil, err + } + + // Abort the heartbeat when the session closes. + ctx, cancel := context.WithCancel(ctx) + go func() { + defer cancel() + s.waitForStateOrContext(ctx, StateClosing) + }() + + go s.heartbeat(ctx) + + return &Command{ + Command: "screen", + Args: []string{"-S", s.id, "-xc", s.configFile}, + TTY: s.command.TTY, + Rows: s.command.Rows, + Cols: s.command.Cols, + Stdin: s.command.Stdin, + UID: s.command.UID, + GID: s.command.GID, + Env: append(s.command.Env, "SCREENDIR="+s.socketsDir), + WorkingDir: s.command.WorkingDir, + }, nil +} + +// heartbeat keeps the session alive while the provided context is not done. +func (s *Session) heartbeat(ctx context.Context) { + // We just connected so reset the timer now in case it is near the end. + s.timer.Reset(s.options.SessionTimeout) + + // Reset when the connection closes to ensure the session stays up for the + // full timeout. + defer s.timer.Reset(s.options.SessionTimeout) + + heartbeat := time.NewTicker(s.options.SessionTimeout / 2) + defer heartbeat.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-heartbeat.C: + } + s.timer.Reset(s.options.SessionTimeout) + } +} + +// Wait waits for the session to close. The underlying process might still be +// exiting. +func (s *Session) Wait() { + s.waitForState(StateClosing) +} + +// Close attempts to gracefully kill the session's underlying process then waits +// for the process to exit. If the session does not exit in a timely manner it +// forcefully kills the process. +func (s *Session) Close() { + s.setState(StateClosing, nil) + s.waitForState(StateDone) +} + +// ensureSettings writes config settings and creates the socket directory. +func (s *Session) ensureSettings() error { + settings := []string{ + // Tell screen not to handle motion for xterm* terminals which allows + // scrolling the terminal via the mouse wheel or scroll bar (by default + // screen uses it to cycle through the command history). There does not + // seem to be a way to make screen itself scroll on mouse wheel. tmux can + // do it but then there is no scroll bar and it kicks you into copy mode + // where keys stop working until you exit copy mode which seems like it + // could be confusing. + "termcapinfo xterm* ti@:te@", + // Enable alternate screen emulation otherwise applications get rendered in + // the current window which wipes out visible output resulting in missing + // output when scrolling back with the mouse wheel (copy mode still works + // since that is screen itself scrolling). + "altscreen on", + // Remap the control key to C-s since C-a may be used in applications. C-s + // cannot actually be used anyway since by default it will pause and C-q to + // resume will just kill the browser window. We may not want people using + // the control key anyway since it will not be obvious they are in screen + // and doing things like switching windows makes mouse wheel scroll wonky + // due to the terminal doing the scrolling rather than screen itself (but + // again copy mode will work just fine). + "escape ^Ss", + } + + dir := filepath.Join(os.TempDir(), "coder-screen") + config := filepath.Join(dir, "config") + socketdir := filepath.Join(dir, "sockets") + + err := os.MkdirAll(socketdir, 0o700) + if err != nil { + return err + } + + return os.WriteFile(config, []byte(strings.Join(settings, "\n")), 0o644) +} + +// setState sets and broadcasts the provided state if it is greater than the +// current state and the error if one has not already been set. +func (s *Session) setState(state State, err error) { + s.cond.L.Lock() + defer s.cond.L.Unlock() + // Cannot regress states (for example trying to close after the process is + // done should leave us in the done state and not the closing state). + if state <= s.state { + return + } + // Keep the first error we get. + if s.error == nil { + s.error = err + } + s.state = state + s.cond.Broadcast() +} + +// waitForState blocks until the state or a greater one is reached. +func (s *Session) waitForState(state State) (State, error) { + s.cond.L.Lock() + defer s.cond.L.Unlock() + for state > s.state { + s.cond.Wait() + } + return s.state, s.error +} + +// waitForStateOrContext blocks until the state or a greater one is reached or +// the provided context ends. If the context ends all goroutines will be woken. +func (s *Session) waitForStateOrContext(ctx context.Context, state State) { + go func() { + // Wake up when the context ends. + defer s.cond.Broadcast() + <-ctx.Done() + }() + s.cond.L.Lock() + defer s.cond.L.Unlock() + for ctx.Err() == nil && state > s.state { + s.cond.Wait() + } +} diff --git a/tty_test.go b/tty_test.go index d9651bf..026c124 100644 --- a/tty_test.go +++ b/tty_test.go @@ -3,7 +3,8 @@ package wsep import ( "bufio" "context" - "io/ioutil" + "fmt" + "math/rand" "strings" "sync" "testing" @@ -11,180 +12,250 @@ import ( "cdr.dev/slog/sloggers/slogtest/assert" "github.com/google/uuid" - "nhooyr.io/websocket" ) func TestTTY(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() + // Run some output in a new session. + server := newServer(t) + ctx, command := newSession(t) + command.ID = "" // No ID so we do not start a reconnectable session. + process1, _ := connect(ctx, t, command, server, nil, "") + expected := writeUnique(t, process1) + assert.True(t, "find initial output", checkStdout(t, process1, expected, []string{})) + + // Connect to the same session. There should not be shared output since + // these end up being separate sessions due to the lack of an ID. + process2, _ := connect(ctx, t, command, server, nil, "") + unexpected := expected + expected = writeUnique(t, process2) + assert.True(t, "find new session output", checkStdout(t, process2, expected, unexpected)) +} - ws, server := mockConn(ctx, t, nil) - defer ws.Close(websocket.StatusInternalError, "") - defer server.Close() +func TestReconnectTTY(t *testing.T) { + t.Run("NoSize", func(t *testing.T) { + server := newServer(t) + ctx, command := newSession(t) + command.Rows = 0 + connect(ctx, t, command, server, nil, "rows and cols must be non-zero") + }) - execer := RemoteExecer(ws) - testTTY(ctx, t, execer) -} + t.Run("DeprecatedServe", func(t *testing.T) { + // Do something in the first session. + ctx, command := newSession(t) + process1, _ := connect(ctx, t, command, nil, nil, "") + expected := writeUnique(t, process1) + assert.True(t, "find initial output", checkStdout(t, process1, expected, []string{})) -func testTTY(ctx context.Context, t *testing.T, e Execer) { - process, err := e.Start(ctx, Command{ - Command: "sh", - TTY: true, - Stdin: true, + // Connect to the same session. Should see the same output. + process2, _ := connect(ctx, t, command, nil, nil, "") + assert.True(t, "find reconnected output", checkStdout(t, process2, expected, []string{})) }) - assert.Success(t, "start sh", err) - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - - stdout, err := ioutil.ReadAll(process.Stdout()) - assert.Success(t, "read stdout", err) - - t.Logf("bash tty stdout = %s", stdout) - prompt := string(stdout) - assert.True(t, `bash "$" or "#" prompt found`, - strings.HasSuffix(prompt, "$ ") || strings.HasSuffix(prompt, "# ")) - }() - wg.Add(1) - go func() { - defer wg.Done() - - stderr, err := ioutil.ReadAll(process.Stderr()) - assert.Success(t, "read stderr", err) - t.Logf("bash tty stderr = %s", stderr) - assert.True(t, "stderr is empty", len(stderr) == 0) - }() - time.Sleep(3 * time.Second) - - process.Close() - wg.Wait() -} -func TestReconnectTTY(t *testing.T) { - t.Parallel() + t.Run("NoScreen", func(t *testing.T) { + t.Setenv("PATH", "/bin") + + // Run some output in a new session. + server := newServer(t) + ctx, command := newSession(t) + process1, _ := connect(ctx, t, command, server, nil, "") + expected := writeUnique(t, process1) + assert.True(t, "find initial output", checkStdout(t, process1, expected, []string{})) + + // Connect to the same session. There should not be shared output since + // these end up being separate sessions due to the lack of screen. + process2, _ := connect(ctx, t, command, server, nil, "") + unexpected := expected + expected = writeUnique(t, process2) + assert.True(t, "find new session output", checkStdout(t, process2, expected, unexpected)) + }) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() + t.Run("Regular", func(t *testing.T) { + t.Parallel() + + // Run some output in a new session. + server := newServer(t) + ctx, command := newSession(t) + process1, disconnect1 := connect(ctx, t, command, server, nil, "") + expected := writeUnique(t, process1) + assert.True(t, "find initial output", checkStdout(t, process1, expected, []string{})) + + // Reconnect and sleep; the inactivity timeout should not trigger since we + // were not disconnected during the timeout. + disconnect1() + process2, disconnect2 := connect(ctx, t, command, server, nil, "") + time.Sleep(time.Second) + expected = append(expected, writeUnique(t, process2)...) + assert.True(t, "find reconnected output", checkStdout(t, process2, expected, []string{})) + + // Make a simultaneously active connection. + process3, disconnect3 := connect(ctx, t, command, server, &Options{ + // Divide the time to test that the heartbeat keeps it open through + // multiple intervals. + SessionTimeout: time.Second / 4, + }, "") + + // Disconnect the previous connection and wait for inactivity. The session + // should stay up because of the second connection. + disconnect2() + time.Sleep(time.Second) + expected = append(expected, writeUnique(t, process3)...) + assert.True(t, "find second connection output", checkStdout(t, process3, expected, []string{})) + + // Disconnect the last connection and wait for inactivity. The next + // connection should start a new session so we should only see new output + // and not any output from the old session. + disconnect3() + time.Sleep(time.Second) + process4, _ := connect(ctx, t, command, server, nil, "") + unexpected := expected + expected = writeUnique(t, process4) + assert.True(t, "find new session output", checkStdout(t, process4, expected, unexpected)) + }) + + t.Run("Alternate", func(t *testing.T) { + t.Parallel() + + // Run an application that enters the alternate screen. + server := newServer(t) + ctx, command := newSession(t) + process1, disconnect1 := connect(ctx, t, command, server, nil, "") + write(t, process1, "./ci/alt.sh") + assert.True(t, "find alt screen", checkStdout(t, process1, []string{"./ci/alt.sh", "ALT SCREEN"}, []string{})) + + // Reconnect; the application should redraw. We should have only the + // application output and not the command that spawned the application. + disconnect1() + process2, disconnect2 := connect(ctx, t, command, server, nil, "") + assert.True(t, "find reconnected alt screen", checkStdout(t, process2, []string{"ALT SCREEN"}, []string{"./ci/alt.sh"})) + + // Exit the application and reconnect. Should now be in a regular shell. + write(t, process2, "q") + disconnect2() + process3, _ := connect(ctx, t, command, server, nil, "") + expected := writeUnique(t, process3) + assert.True(t, "find shell output", checkStdout(t, process3, expected, []string{})) + }) + + t.Run("Simultaneous", func(t *testing.T) { + t.Parallel() + + server := newServer(t) + + // Try connecting a bunch of sessions at once. + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + ctx, command := newSession(t) + process1, disconnect1 := connect(ctx, t, command, server, nil, "") + expected := writeUnique(t, process1) + assert.True(t, "find initial output", checkStdout(t, process1, expected, []string{})) + + n := rand.Intn(1000) + time.Sleep(time.Duration(n) * time.Millisecond) + disconnect1() + process2, _ := connect(ctx, t, command, server, nil, "") + expected = append(expected, writeUnique(t, process2)...) + assert.True(t, "find reconnected output", checkStdout(t, process2, expected, []string{})) + }() + } + wg.Wait() + }) +} - ws1, server1 := mockConn(ctx, t, &Options{ - ReconnectingProcessTimeout: time.Second, +// newServer returns a new wsep server. +func newServer(t *testing.T) *Server { + server := NewServer() + t.Cleanup(func() { + server.Close() + assert.Equal(t, "no leaked sessions", 0, server.SessionCount()) }) - defer server1.Close() + return server +} + +// newSession returns a command for starting/attaching to a session with a +// context for timing out. +func newSession(t *testing.T) (context.Context, Command) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + t.Cleanup(cancel) command := Command{ ID: uuid.NewString(), Command: "sh", TTY: true, Stdin: true, + Cols: 100, + Rows: 100, + Env: []string{"TERM=xterm"}, } - execer1 := RemoteExecer(ws1) - process1, err := execer1.Start(ctx, command) - assert.Success(t, "start sh", err) - - // Write some unique output. - echoCmd := "echo test:$((1+1))" - data := []byte(echoCmd + "\r\n") - _, err = process1.Stdin().Write(data) - assert.Success(t, "write to stdin", err) - expected := []string{echoCmd, "test:2"} - assert.True(t, "find echo", findEcho(t, process1, expected)) - - // Test disconnecting then reconnecting. - process1.Close() - server1.Close() - - ws2, server2 := mockConn(ctx, t, &Options{ - ReconnectingProcessTimeout: time.Second, - }) - defer server2.Close() - - execer2 := RemoteExecer(ws2) - process2, err := execer2.Start(ctx, command) - assert.Success(t, "attach sh", err) - - // The inactivity timeout should not have been triggered. - time.Sleep(time.Second) - - echoCmd = "echo test:$((2+2))" - data = []byte(echoCmd + "\r\n") - _, err = process2.Stdin().Write(data) - assert.Success(t, "write to stdin", err) - expected = append(expected, echoCmd, "test:4") - - assert.True(t, "find echo", findEcho(t, process2, expected)) + return ctx, command +} - // Test disconnecting while another connection is active. - ws3, server3 := mockConn(ctx, t, &Options{ - // Divide the time to test that the heartbeat keeps it open through multiple - // intervals. - ReconnectingProcessTimeout: time.Second / 4, +// connect connects to a wsep server and runs the provided command. +func connect(ctx context.Context, t *testing.T, command Command, wsepServer *Server, options *Options, error string) (Process, func()) { + if options == nil { + options = &Options{SessionTimeout: time.Second} + } + ws, server := mockConn(ctx, t, wsepServer, options) + t.Cleanup(func() { + server.Close() }) - defer server3.Close() - - execer3 := RemoteExecer(ws3) - process3, err := execer3.Start(ctx, command) - assert.Success(t, "attach sh", err) - - process2.Close() - server2.Close() - time.Sleep(time.Second) - // This connection should still be up. - echoCmd = "echo test:$((3+3))" - data = []byte(echoCmd + "\r\n") - _, err = process3.Stdin().Write(data) - assert.Success(t, "write to stdin", err) - expected = append(expected, echoCmd, "test:6") - - assert.True(t, "find echo", findEcho(t, process3, expected)) - - // Close the remaining connection and wait for inactivity. - process3.Close() - server3.Close() - time.Sleep(time.Second) + process, err := RemoteExecer(ws).Start(ctx, command) + if error != "" { + assert.True(t, fmt.Sprintf("%s contains %s", err.Error(), error), strings.Contains(err.Error(), error)) + } else { + assert.Success(t, "start sh", err) + } - // The next connection should start a new process. - ws4, server4 := mockConn(ctx, t, &Options{ - ReconnectingProcessTimeout: time.Second, - }) - defer server4.Close() + return process, func() { + process.Close() + server.Close() + } +} - execer4 := RemoteExecer(ws4) - process4, err := execer4.Start(ctx, command) - assert.Success(t, "attach sh", err) +// writeUnique writes some unique output to the shell process and returns the +// expected output. +func writeUnique(t *testing.T, process Process) []string { + n := rand.Intn(1000000) + echoCmd := fmt.Sprintf("echo test:$((%d+%d))", n, n) + write(t, process, echoCmd) + return []string{echoCmd, fmt.Sprintf("test:%d", n+n)} +} - // This time no echo since it is a new process. - notExpected := expected - echoCmd = "echo done" - data = []byte(echoCmd + "\r\n") - _, err = process4.Stdin().Write(data) +// write writes the provided input followed by a newline to the shell process. +func write(t *testing.T, process Process, input string) { + _, err := process.Stdin().Write([]byte(input + "\n")) assert.Success(t, "write to stdin", err) - expected = []string{echoCmd, "done"} - assert.True(t, "find echo", findEcho(t, process4, expected, notExpected...)) - assert.Success(t, "context", ctx.Err()) } -func findEcho(t *testing.T, process Process, expected []string, notExpected ...string) bool { +// checkStdout ensures that expected is in the stdout in the specified order. +// On the way if anything in unexpected comes up return false. Return once +// everything in expected has been found or EOF. +func checkStdout(t *testing.T, process Process, expected, unexpected []string) bool { + t.Helper() + i := 0 scanner := bufio.NewScanner(process.Stdout()) -outer: - for _, str := range expected { - for scanner.Scan() { - line := scanner.Text() - t.Logf("bash tty stdout = %s", line) - for _, bad := range notExpected { - if strings.Contains(line, bad) { - return false - } - } + for scanner.Scan() { + line := scanner.Text() + t.Logf("bash tty stdout = %s", strings.ReplaceAll(line, "\x1b", "ESC")) + for _, str := range unexpected { if strings.Contains(line, str) { - continue outer + t.Logf("contains unexpected line %s", line) + return false } } - return false // Reached the end of output without finding str. + if strings.Contains(line, expected[i]) { + t.Logf("contains expected line %s", line) + i = i + 1 + } + if i == len(expected) { + return true + } } - return true + return false } From 5ba238940dd502501bd81c9a00792e3507405d12 Mon Sep 17 00:00:00 2001 From: Asher Date: Mon, 19 Dec 2022 14:57:21 -0900 Subject: [PATCH 3/4] Fix attach resize (#29) * StatusAbnormalClosure cannot be set So says the warning emitted by the websocket library every time we try to set StatusAbnormalClosure. * Refactor screen sessions Previously we would spawn the daemon via -Dm then attach with -x once it was ready. This had a flaw: the daemon starts with a hardcoded 24x80 size and when the attach comes in it resizes and leaves a bunch of confusing whitespace above your prompt. Now we skip the daemon spawn and go straight for the attach but with the addition of -RR which lets screen spawn the daemon for us if it does not yet exist. Consequences: 1. We can only allow one attach at a time because screen has no problem creating multiple daemons with the same name. 2. IDs cannot overlap since screen will do partial matching when we do not include the PID which we no longer have. 3. We do not know when the daemon exits so cleanup only happens on the timeout. 4. We have to kill the session by sending the quit command through screen. When we do this it is possible the session is already dead. 5. If the daemon exits and the user reconnects before the timeout the daemon will be respawned all while the Go program remains blissfully unaware assuming it has been up this whole time. Does not change anything in practice, just a bit different in terms of underlying architecture. In some ways this new architecture is actually simpler with roughly the same functionality but it does not support concurrent attaches and has a greater danger of desyncing from screen's own state. * Log expected/unexpected and strip ansi from tests Hopefully makes it a bit easier to see what is going on. * Add a close reason Trying to figure out why a session is prematurely closing. To do this I am setting the error via setState (so I can add the reason at the close call site) rather than checking for a nil error and returning it in the Attach. Alternatively I was thinking of adding a reason arg to setState but only the close state needs a reason and ultimately it was transformed into the error anyway so might as well do it earlier and skip a step. * Increase test timeout This was canceling while the reconnect test was running. * Decrease number of concurrent test sessions Having too many seems to be causing some to exit unexpectedly. --- dev/client/main.go | 2 +- dev/server/main.go | 2 +- server.go | 37 +++--- session.go | 274 ++++++++++++++++++++++----------------------- tty_test.go | 14 ++- 5 files changed, 173 insertions(+), 156 deletions(-) diff --git a/dev/client/main.go b/dev/client/main.go index 41f1815..c270cdc 100644 --- a/dev/client/main.go +++ b/dev/client/main.go @@ -70,7 +70,7 @@ func do(fl *pflag.FlagSet, tty bool, id string, timeout time.Duration) { if err != nil { flog.Fatal("failed to dial remote executor: %v", err) } - defer conn.Close(websocket.StatusAbnormalClosure, "terminate process") + defer conn.Close(websocket.StatusNormalClosure, "terminate process") executor := wsep.RemoteExecer(conn) diff --git a/dev/server/main.go b/dev/server/main.go index fcd94bf..1085d29 100644 --- a/dev/server/main.go +++ b/dev/server/main.go @@ -29,7 +29,7 @@ func serve(w http.ResponseWriter, r *http.Request) { }) if err != nil { flog.Error("failed to serve execer: %v", err) - ws.Close(websocket.StatusAbnormalClosure, "failed to serve execer") + ws.Close(websocket.StatusInternalError, "failed to serve execer") return } ws.Close(websocket.StatusNormalClosure, "normal closure") diff --git a/server.go b/server.go index 4575a62..86815b3 100644 --- a/server.go +++ b/server.go @@ -69,7 +69,7 @@ func (srv *Server) SessionCount() int { func (srv *Server) Close() { srv.sessions.Range(func(k, rawSession interface{}) bool { if s, ok := rawSession.(*Session); ok { - s.Close() + s.Close("test cleanup") } return true }) @@ -80,6 +80,7 @@ func (srv *Server) Close() { // web socket will not be closed automatically; the caller must call Close() on // the web socket (ideally with a reason) once Serve yields. func (srv *Server) Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Options) error { + // The process will get killed when the connection context ends. ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -145,14 +146,10 @@ func (srv *Server) Serve(ctx context.Context, c *websocket.Conn, execer Execer, // Only TTYs with IDs can be reconnected. if command.TTY && header.ID != "" { - command, err = srv.withSession(ctx, header.ID, command, execer, options) - if err != nil { - return err - } + process, err = srv.withSession(ctx, header.ID, command, execer, options) + } else { + process, err = execer.Start(ctx, *command) } - - // The process will get killed when the connection context ends. - process, err = execer.Start(ctx, *command) if err != nil { return err } @@ -209,13 +206,13 @@ func (srv *Server) Serve(ctx context.Context, c *websocket.Conn, execer Execer, } } -// withSession wraps the command in a session if screen is available. -func (srv *Server) withSession(ctx context.Context, id string, command *Command, execer Execer, options *Options) (*Command, error) { +// withSession runs the command in a session if screen is available. +func (srv *Server) withSession(ctx context.Context, id string, command *Command, execer Execer, options *Options) (Process, error) { // If screen is not installed spawn the command normally. _, err := exec.LookPath("screen") if err != nil { flog.Info("`screen` could not be found; session %s will not persist", id) - return command, nil + return execer.Start(ctx, *command) } var s *Session @@ -224,14 +221,28 @@ func (srv *Server) withSession(ctx context.Context, id string, command *Command, if s, ok = rawSession.(*Session); !ok { return nil, xerrors.Errorf("found invalid type in session map for ID %s", id) } - } else { - s = NewSession(id, command, execer, options) + } + + // It is possible that the session has closed but the goroutine that waits for + // that state and deletes it from the map has not ran yet meaning the session + // is still in the map and we grabbed a closing session. Wait for any pending + // state changes and if it is closed create a new session instead. + if s != nil { + state, _ := s.WaitForState(StateReady) + if state > StateReady { + s = nil + } + } + + if s == nil { + s = NewSession(command, execer, options) srv.sessions.Store(id, s) go func() { // Remove the session from the map once it closes. defer srv.sessions.Delete(id) s.Wait() }() } + srv.sessionsMutex.Unlock() return s.Attach(ctx) diff --git a/session.go b/session.go index b77cf8b..ec5c0bd 100644 --- a/session.go +++ b/session.go @@ -1,6 +1,7 @@ package wsep import ( + "bufio" "context" "fmt" "os" @@ -9,6 +10,8 @@ import ( "sync" "time" + "github.com/google/uuid" + "go.coder.com/flog" "golang.org/x/xerrors" ) @@ -33,154 +36,107 @@ const ( type Session struct { // command is the original command used to spawn the session. command *Command - // cond broadcasts session changes. + // cond broadcasts session changes and any accompanying errors. cond *sync.Cond // configFile is the location of the screen configuration file. configFile string - // error hold any error that occurred during a state change. + // error hold any error that occurred during a state change. It is not safe + // to access outside of cond.L. error error // execer is used to spawn the session and ready commands. execer Execer - // id holds the id of the session for both creating and attaching. + // id holds the id of the session for both creating and attaching. This is + // generated uniquely for each session (rather than using the ID provided by + // the client) because without control of the daemon we do not have its PID + // and without the PID screen will do partial matching. Enforcing a UUID + // should guarantee we match on the right session. id string + // mutex prevents concurrent attaches to the session. This is necessary since + // screen will happily spawn two separate sessions with the same name if + // multiple attaches happen in a close enough interval. We are not able to + // control the daemon ourselves to prevent this because the daemon will spawn + // with a hardcoded 24x80 size which results in confusing padding above the + // prompt once the attach comes in and resizes. + mutex sync.Mutex // options holds options for configuring the session. options *Options // socketsDir is the location of the directory where screen should put its // sockets. socketsDir string - // state holds the current session state. + // state holds the current session state. It is not safe to access this + // outside of cond.L. state State - // timer will close the session when it expires. + // timer will close the session when it expires. The timer will be reset as + // long as there are active connections. timer *time.Timer } -// NewSession creates and immediately starts a new session. Any errors with -// starting are returned on Attach(). The session will close itself if nothing -// is attached for the duration of the session timeout. -func NewSession(id string, command *Command, execer Execer, options *Options) *Session { +const attachTimeout = 30 * time.Second + +// NewSession sets up a new session. Any errors with starting are returned on +// Attach(). The session will close itself if nothing is attached for the +// duration of the session timeout. +func NewSession(command *Command, execer Execer, options *Options) *Session { tempdir := filepath.Join(os.TempDir(), "coder-screen") s := &Session{ command: command, cond: sync.NewCond(&sync.Mutex{}), configFile: filepath.Join(tempdir, "config"), execer: execer, + id: uuid.NewString(), options: options, state: StateStarting, socketsDir: filepath.Join(tempdir, "sockets"), } - go s.lifecycle(id) + go s.lifecycle() return s } // lifecycle manages the lifecycle of the session. -func (s *Session) lifecycle(id string) { - // When this context is done the process is dead. - ctx, cancel := context.WithCancel(context.Background()) - - // Close the session down immediately if there was an error. - process, err := s.start(ctx, id) +func (s *Session) lifecycle() { + err := s.ensureSettings() if err != nil { - defer cancel() - s.setState(StateDone, xerrors.Errorf("process start: %w", err)) + s.setState(StateDone, xerrors.Errorf("ensure settings: %w", err)) return } - // Close the session after a timeout. Timeouts created with AfterFunc can be - // reset; we will reset it as long as there are active connections. The - // initial timeout for starting up is set here and will probably be less than - // the session timeout in most cases. It should be at least long enough for - // screen to be able to start up. - s.timer = time.AfterFunc(30*time.Second, s.Close) - - // Emit the done event when the process exits. - go func() { - defer cancel() - err := process.Wait() - if err != nil { - err = xerrors.Errorf("process exit: %w", err) - } - s.setState(StateDone, err) - }() - - // Handle the close event by killing the process. - go func() { - s.waitForState(StateClosing) - s.timer.Stop() - // process.Close() will send a SIGTERM allowing screen to clean up. If we - // kill it abruptly the socket will be left behind which is probably - // harmless but it causes screen -list to show a bunch of dead sessions. - // The user would likely not see this since we have a custom socket - // directory but it seems ideal to let screen clean up anyway. - process.Close() - select { - case <-ctx.Done(): - return // Process exited on its own. - case <-time.After(5 * time.Second): - // Still running; cancel the context to forcefully terminate the process. - cancel() - } - }() - - // Wait until the session is ready to receive attaches. - err = s.waitReady(ctx) - if err != nil { - defer cancel() - s.setState(StateClosing, xerrors.Errorf("session wait: %w", err)) - return - } + // The initial timeout for starting up is set here and will probably be far + // shorter than the session timeout in most cases. It should be at least long + // enough for the first screen attach to be able to start up the daemon. + s.timer = time.AfterFunc(attachTimeout, func() { + s.Close("session timeout") + }) - // Once the session is ready external callers have until their provided - // timeout to attach something before the session will close itself. - s.timer.Reset(s.options.SessionTimeout) s.setState(StateReady, nil) -} -// start starts the session. -func (s *Session) start(ctx context.Context, id string) (Process, error) { - err := s.ensureSettings() + // Handle the close event by asking screen to quit the session. We have no + // way of knowing when the daemon process dies so the Go side will not get + // cleaned up until the timeout if the process gets killed externally (for + // example via `exit`). + s.WaitForState(StateClosing) + s.timer.Stop() + // If the command errors that the session is already gone that is fine. + err = s.sendCommand(context.Background(), "quit", []string{"No screen session found"}) if err != nil { - return nil, err + flog.Error("failed to kill session %s: %v", s.id, err) + } else { + err = xerrors.Errorf(fmt.Sprintf("session is done")) } - - // -S is for setting the session's name. - // -Dm causes screen to launch a server tied to this process, letting us - // attach to and kill it with the PID of this process (rather than having - // to do something flaky like run `screen -S id -quit`). - // -c is the flag for the config file. - process, err := s.execer.Start(ctx, Command{ - Command: "screen", - Args: append([]string{"-S", id, "-Dmc", s.configFile, s.command.Command}, s.command.Args...), - UID: s.command.UID, - GID: s.command.GID, - Env: append(s.command.Env, "SCREENDIR="+s.socketsDir), - WorkingDir: s.command.WorkingDir, - }) - if err != nil { - return nil, err - } - - // Screen allows targeting sessions via either the session name or - // .. Using the latter form allows us to differentiate - // between sessions with the same name. For example if a session closes due - // to a timeout and a client reconnects around the same time there can be two - // sessions with the same name while the old session is cleaning up. This is - // only a problem while attaching and not creating since the creation command - // used always creates a new session. - s.id = fmt.Sprintf("%d.%s", process.Pid(), id) - return process, nil + s.setState(StateDone, err) } -// waitReady waits for the session to be ready by running a command against the -// session until it works since sometimes if you attach too quickly after -// spawning screen it will say the session does not exist. If the provided -// context finishes the wait is aborted and the context's error is returned. -func (s *Session) waitReady(ctx context.Context) error { - check := func() (bool, error) { - // The `version` command seems to be the only command without a side effect - // so use it to check whether the session is up. +// sendCommand runs a screen command against a session. If the command fails +// with an error matching anything in successErrors it will be considered a +// success state (for example "no session" when quitting). The command will be +// retried until successful, the timeout is reached, or the context ends (in +// which case the context error is returned). +func (s *Session) sendCommand(ctx context.Context, command string, successErrors []string) error { + ctx, cancel := context.WithTimeout(ctx, attachTimeout) + defer cancel() + run := func() (bool, error) { process, err := s.execer.Start(ctx, Command{ Command: "screen", - Args: []string{"-S", s.id, "-X", "version"}, + Args: []string{"-S", s.id, "-X", command}, UID: s.command.UID, GID: s.command.GID, Env: append(s.command.Env, "SCREENDIR="+s.socketsDir), @@ -188,25 +144,31 @@ func (s *Session) waitReady(ctx context.Context) error { if err != nil { return true, err } + stdout := captureStdout(process) err = process.Wait() // Try the context error in case it canceled while we waited. if ctx.Err() != nil { return true, ctx.Err() } - // Session is ready once we executed without an error. - // TODO: Should we specifically check for "no screen to be attached - // matching"? That might be the only error we actually want to retry and - // otherwise we return immediately. But this text may not be stable between - // screen versions or there could be additional errors that can be retried. + if err != nil { + details := <-stdout + for _, se := range successErrors { + if strings.Contains(details, se) { + return true, nil + } + } + } + // Sometimes a command will fail without any error output whatsoever but + // will succeed later so all we can do is keep trying. return err == nil, nil } - // Check immediately. - if done, err := check(); done { + // Run immediately. + if done, err := run(); done { return err } - // Then check on a timer. + // Then run on a timer. ticker := time.NewTicker(250 * time.Millisecond) defer ticker.Stop() @@ -215,32 +177,26 @@ func (s *Session) waitReady(ctx context.Context) error { case <-ctx.Done(): return ctx.Err() case <-ticker.C: - if done, err := check(); done { + if done, err := run(); done { return err } } } } -// Attach waits for the session to become attachable and returns a command that -// can be used to attach to the session. -func (s *Session) Attach(ctx context.Context) (*Command, error) { - state, err := s.waitForState(StateReady) +// Attach attaches to the session, waits for the attach to complete, then +// returns the attached process. +func (s *Session) Attach(ctx context.Context) (Process, error) { + // We need to do this while behind the mutex to ensure another attach does not + // come in and spawn a duplicate session. + s.mutex.Lock() + defer s.mutex.Unlock() + + state, err := s.WaitForState(StateReady) switch state { case StateClosing: - if err == nil { - // No error means s.Close() was called, either by external code or via the - // session timeout. - err = xerrors.Errorf("session is closing") - } return nil, err case StateDone: - if err == nil { - // No error means the process exited with zero, probably after being - // killed due to a call to s.Close() either externally or via the session - // timeout. - err = xerrors.Errorf("session is done") - } return nil, err } @@ -253,9 +209,15 @@ func (s *Session) Attach(ctx context.Context) (*Command, error) { go s.heartbeat(ctx) - return &Command{ + // -S is for setting the session's name. + // -x allows attaching to an already attached session. + // -RR reattaches to the daemon or creates the session daemon if missing. + // -q disables the "New screen..." message that appears for five seconds when + // creating a new session with -RR. + // -c is the flag for the config file. + process, err := s.execer.Start(ctx, Command{ Command: "screen", - Args: []string{"-S", s.id, "-xc", s.configFile}, + Args: append([]string{"-S", s.id, "-xRRqc", s.configFile, s.command.Command}, s.command.Args...), TTY: s.command.TTY, Rows: s.command.Rows, Cols: s.command.Cols, @@ -264,7 +226,21 @@ func (s *Session) Attach(ctx context.Context) (*Command, error) { GID: s.command.GID, Env: append(s.command.Env, "SCREENDIR="+s.socketsDir), WorkingDir: s.command.WorkingDir, - }, nil + }) + if err != nil { + cancel() + return nil, err + } + + // Version seems to be the only command without a side effect so use it to + // wait for the session to come up. + err = s.sendCommand(ctx, "version", nil) + if err != nil { + cancel() + return nil, err + } + + return process, err } // heartbeat keeps the session alive while the provided context is not done. @@ -285,6 +261,13 @@ func (s *Session) heartbeat(ctx context.Context) { return case <-heartbeat.C: } + // The goroutine that cancels the heartbeat on a close state change might + // not run before the next heartbeat which means the heartbeat will start + // the timer again. + state, _ := s.WaitForState(StateReady) + if state > StateReady { + return + } s.timer.Reset(s.options.SessionTimeout) } } @@ -292,15 +275,15 @@ func (s *Session) heartbeat(ctx context.Context) { // Wait waits for the session to close. The underlying process might still be // exiting. func (s *Session) Wait() { - s.waitForState(StateClosing) + s.WaitForState(StateClosing) } // Close attempts to gracefully kill the session's underlying process then waits // for the process to exit. If the session does not exit in a timely manner it // forcefully kills the process. -func (s *Session) Close() { - s.setState(StateClosing, nil) - s.waitForState(StateDone) +func (s *Session) Close(reason string) { + s.setState(StateClosing, xerrors.Errorf(fmt.Sprintf("session is closing: %s", reason))) + s.WaitForState(StateDone) } // ensureSettings writes config settings and creates the socket directory. @@ -359,8 +342,8 @@ func (s *Session) setState(state State, err error) { s.cond.Broadcast() } -// waitForState blocks until the state or a greater one is reached. -func (s *Session) waitForState(state State) (State, error) { +// WaitForState blocks until the state or a greater one is reached. +func (s *Session) WaitForState(state State) (State, error) { s.cond.L.Lock() defer s.cond.L.Unlock() for state > s.state { @@ -383,3 +366,18 @@ func (s *Session) waitForStateOrContext(ctx context.Context, state State) { s.cond.Wait() } } + +// captureStdout captures the first line of stdout. Screen emits errors to +// stdout so this allows logging extra context beyond the exit code. +func captureStdout(process Process) <-chan string { + stdout := make(chan string, 1) + go func() { + scanner := bufio.NewScanner(process.Stdout()) + if scanner.Scan() { + stdout <- scanner.Text() + } else { + stdout <- "no further details" + } + }() + return stdout +} diff --git a/tty_test.go b/tty_test.go index 026c124..5728d1e 100644 --- a/tty_test.go +++ b/tty_test.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "math/rand" + "regexp" "strings" "sync" "testing" @@ -145,7 +146,7 @@ func TestReconnectTTY(t *testing.T) { // Try connecting a bunch of sessions at once. var wg sync.WaitGroup - for i := 0; i < 100; i++ { + for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() @@ -179,7 +180,7 @@ func newServer(t *testing.T) *Server { // newSession returns a command for starting/attaching to a session with a // context for timing out. func newSession(t *testing.T) (context.Context, Command) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) t.Cleanup(cancel) command := Command{ @@ -233,16 +234,21 @@ func write(t *testing.T, process Process, input string) { assert.Success(t, "write to stdin", err) } +const ansi = "[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))" + +var re = regexp.MustCompile(ansi) + // checkStdout ensures that expected is in the stdout in the specified order. // On the way if anything in unexpected comes up return false. Return once // everything in expected has been found or EOF. func checkStdout(t *testing.T, process Process, expected, unexpected []string) bool { t.Helper() i := 0 + t.Logf("expected: %s unexpected: %s", expected, unexpected) scanner := bufio.NewScanner(process.Stdout()) for scanner.Scan() { line := scanner.Text() - t.Logf("bash tty stdout = %s", strings.ReplaceAll(line, "\x1b", "ESC")) + t.Logf("bash tty stdout = %s", re.ReplaceAllString(line, "")) for _, str := range unexpected { if strings.Contains(line, str) { t.Logf("contains unexpected line %s", line) @@ -254,8 +260,10 @@ func checkStdout(t *testing.T, process Process, expected, unexpected []string) b i = i + 1 } if i == len(expected) { + t.Logf("got all expected values from stdout") return true } } + t.Logf("reached end of stdout without seeing all expected values") return false } From 30cd2740af81190fc43e9891ceb47755b5511655 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Tue, 10 Jan 2023 13:13:40 +0000 Subject: [PATCH 4/4] fix: Command: default rows and cols if set to 0 --- server.go | 16 +++++++++++++--- tty_test.go | 12 +++++++++--- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/server.go b/server.go index 86815b3..1274c72 100644 --- a/server.go +++ b/server.go @@ -19,6 +19,11 @@ import ( "cdr.dev/wsep/internal/proto" ) +const ( + defaultRows = 80 + defaultCols = 24 +) + // Options allows configuring the server. type Options struct { SessionTimeout time.Duration @@ -138,9 +143,14 @@ func (srv *Server) Serve(ctx context.Context, c *websocket.Conn, execer Execer, command := mapToClientCmd(header.Command) if command.TTY { - // Enforce rows and columns so the TTY will be correctly sized. - if command.Rows == 0 || command.Cols == 0 { - return xerrors.Errorf("rows and cols must be non-zero") + // If rows and cols are not provided, default to 80x24. + if command.Rows == 0 { + flog.Info("rows not provided, defaulting to 80") + command.Rows = defaultRows + } + if command.Cols == 0 { + flog.Info("cols not provided, defaulting to 24") + command.Cols = defaultCols } } diff --git a/tty_test.go b/tty_test.go index 5728d1e..684f280 100644 --- a/tty_test.go +++ b/tty_test.go @@ -39,7 +39,13 @@ func TestReconnectTTY(t *testing.T) { server := newServer(t) ctx, command := newSession(t) command.Rows = 0 - connect(ctx, t, command, server, nil, "rows and cols must be non-zero") + command.Cols = 0 + ps1, _ := connect(ctx, t, command, server, nil, "") + expected := writeUnique(t, ps1) + assert.True(t, "find initial output", checkStdout(t, ps1, expected, []string{})) + + ps2, _ := connect(ctx, t, command, server, nil, "") + assert.True(t, "find reconnected output", checkStdout(t, ps2, expected, []string{})) }) t.Run("DeprecatedServe", func(t *testing.T) { @@ -188,8 +194,8 @@ func newSession(t *testing.T) (context.Context, Command) { Command: "sh", TTY: true, Stdin: true, - Cols: 100, - Rows: 100, + Cols: defaultCols, + Rows: defaultRows, Env: []string{"TERM=xterm"}, }