From 33200a0c92b4fc6a7130c9122b4a627de09c1952 Mon Sep 17 00:00:00 2001 From: Asher Date: Tue, 23 Aug 2022 11:12:46 -0500 Subject: [PATCH 01/18] Refactor reconnect test to support sub-tests Going to add an alternate screen test next. --- tty_test.go | 179 ++++++++++++++++++++++++++-------------------------- 1 file changed, 89 insertions(+), 90 deletions(-) diff --git a/tty_test.go b/tty_test.go index bfc20f3..7b73559 100644 --- a/tty_test.go +++ b/tty_test.go @@ -66,113 +66,112 @@ func testTTY(ctx context.Context, t *testing.T, e Execer) { func TestReconnectTTY(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() + t.Run("RegularScreen", func(t *testing.T) { + t.Parallel() - ws, server := mockConn(ctx, t, &Options{ - ReconnectingProcessTimeout: time.Second, - }) - defer server.Close() - - command := Command{ - ID: uuid.NewString(), - Command: "sh", - TTY: true, - Stdin: true, - } - execer := RemoteExecer(ws) - process, err := execer.Start(ctx, command) - assert.Success(t, "start sh", err) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() - // Write some unique output. - echoCmd := "echo test:$((1+1))" - data := []byte(echoCmd + "\r\n") - _, err = process.Stdin().Write(data) - assert.Success(t, "write to stdin", err) - expected := []string{echoCmd, "test:2"} - - findEcho := func(expected []string) bool { - scanner := bufio.NewScanner(process.Stdout()) - outer: - for _, str := range expected { - for scanner.Scan() { - line := scanner.Text() - t.Logf("bash tty stdout = %s", line) - if strings.Contains(line, str) { - continue outer - } - } - return false // Reached the end of output without finding str. + command := Command{ + ID: uuid.NewString(), + Command: "sh", + TTY: true, + Stdin: true, } - return true - } - assert.True(t, "find echo", findEcho(expected)) + ws, server := mockConn(ctx, t, &Options{ + ReconnectingProcessTimeout: time.Second, + }) + defer server.Close() - // Test disconnecting then reconnecting. - ws.Close(websocket.StatusNormalClosure, "disconnected") - server.Close() + process, err := RemoteExecer(ws).Start(ctx, command) + assert.Success(t, "start sh", err) - ws, server = mockConn(ctx, t, &Options{ - ReconnectingProcessTimeout: time.Second, - }) - defer server.Close() + // Write some unique output. + echoCmd := "echo test:$((1+1))" + _, err = process.Stdin().Write([]byte(echoCmd + "\r\n")) + assert.Success(t, "write to stdin", err) + expected := []string{echoCmd, "test:2"} - execer = RemoteExecer(ws) - process, err = execer.Start(ctx, command) - assert.Success(t, "attach sh", err) + assert.True(t, "find echo", findStdout(t, process, expected)) - // The inactivity timeout should not have been triggered. - time.Sleep(time.Second) + // Test disconnecting then reconnecting. + ws.Close(websocket.StatusNormalClosure, "disconnected") + server.Close() - echoCmd = "echo test:$((2+2))" - data = []byte(echoCmd + "\r\n") - _, err = process.Stdin().Write(data) - assert.Success(t, "write to stdin", err) - expected = append(expected, echoCmd, "test:4") + ws, server = mockConn(ctx, t, &Options{ + ReconnectingProcessTimeout: time.Second, + }) + defer server.Close() - assert.True(t, "find echo", findEcho(expected)) + process, err = RemoteExecer(ws).Start(ctx, command) + assert.Success(t, "attach sh", err) - // Test disconnecting while another connection is active. 
- ws2, server2 := mockConn(ctx, t, &Options{ - // Divide the time to test that the heartbeat keeps it open through multiple - // intervals. - ReconnectingProcessTimeout: time.Second / 4, - }) - defer server2.Close() + // The inactivity timeout should not have been triggered. + time.Sleep(time.Second) - execer = RemoteExecer(ws2) - process, err = execer.Start(ctx, command) - assert.Success(t, "attach sh", err) + echoCmd = "echo test:$((2+2))" + _, err = process.Stdin().Write([]byte(echoCmd + "\r\n")) + assert.Success(t, "write to stdin", err) + expected = append(expected, echoCmd, "test:4") - ws.Close(websocket.StatusNormalClosure, "disconnected") - server.Close() - time.Sleep(time.Second) + assert.True(t, "find echo", findStdout(t, process, expected)) - // This connection should still be up. - echoCmd = "echo test:$((3+3))" - data = []byte(echoCmd + "\r\n") - _, err = process.Stdin().Write(data) - assert.Success(t, "write to stdin", err) - expected = append(expected, echoCmd, "test:6") + // Test disconnecting while another connection is active. + ws2, server2 := mockConn(ctx, t, &Options{ + // Divide the time to test that the heartbeat keeps it open through multiple + // intervals. + ReconnectingProcessTimeout: time.Second / 4, + }) + defer server2.Close() - assert.True(t, "find echo", findEcho(expected)) + process, err = RemoteExecer(ws2).Start(ctx, command) + assert.Success(t, "attach sh", err) - // Close the remaining connection and wait for inactivity. - ws2.Close(websocket.StatusNormalClosure, "disconnected") - server2.Close() - time.Sleep(time.Second) + ws.Close(websocket.StatusNormalClosure, "disconnected") + server.Close() + time.Sleep(time.Second) - // The next connection should start a new process. - ws, server = mockConn(ctx, t, &Options{ - ReconnectingProcessTimeout: time.Second, - }) - defer server.Close() + // This connection should still be up. + echoCmd = "echo test:$((3+3))" + _, err = process.Stdin().Write([]byte(echoCmd + "\r\n")) + assert.Success(t, "write to stdin", err) + expected = append(expected, echoCmd, "test:6") + + assert.True(t, "find echo", findStdout(t, process, expected)) + + // Close the remaining connection and wait for inactivity. + ws2.Close(websocket.StatusNormalClosure, "disconnected") + server2.Close() + time.Sleep(time.Second) + + // The next connection should start a new process. + ws, server = mockConn(ctx, t, &Options{ + ReconnectingProcessTimeout: time.Second, + }) + defer server.Close() - execer = RemoteExecer(ws) - process, err = execer.Start(ctx, command) - assert.Success(t, "attach sh", err) + process, err = RemoteExecer(ws).Start(ctx, command) + assert.Success(t, "attach sh", err) - // This time no echo since it is a new process. - assert.True(t, "find echo", !findEcho(expected)) + // This time no echo since it is a new process. + assert.True(t, "find echo", !findStdout(t, process, expected)) + }) +} + +func findStdout(t *testing.T, process Process, expected []string) bool { + t.Helper() + scanner := bufio.NewScanner(process.Stdout()) +outer: + for _, str := range expected { + for scanner.Scan() { + line := scanner.Text() + t.Logf("bash tty stdout = %s", line) + if strings.Contains(line, str) { + continue outer + } + } + return false // Reached the end of output without finding str. 
+ } + return true } From 684df7aaff70160f8ba0a679d7a6f10a12229351 Mon Sep 17 00:00:00 2001 From: Asher Date: Tue, 13 Sep 2022 13:34:22 -0500 Subject: [PATCH 02/18] Revert "Add reconnecting ptys (#23)" This partially reverts commit 91201718149964a06c6350679d8cda0161edf887. The new method using screen will not share processes which is a fundamental shift so I think it will be easier to start from scratch. Even though we could keep the UUID check I removed it because it seems cool that you could create your own sessions in the terminal then connect to them in the browser (or vice-versa). --- go.mod | 1 - go.sum | 2 - server.go | 232 +++++++--------------------------------------------- tty_test.go | 8 +- 4 files changed, 34 insertions(+), 209 deletions(-) diff --git a/go.mod b/go.mod index 67ac33f..2d51221 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.14 require ( cdr.dev/slog v1.3.0 - github.com/armon/circbuf v0.0.0-20190214190532-5111143e8da2 github.com/creack/pty v1.1.11 github.com/google/go-cmp v0.4.0 github.com/google/uuid v1.3.0 diff --git a/go.sum b/go.sum index 8703f19..c9a9a09 100644 --- a/go.sum +++ b/go.sum @@ -30,8 +30,6 @@ github.com/alecthomas/kong v0.2.1-0.20190708041108-0548c6b1afae/go.mod h1:+inYUS github.com/alecthomas/kong-hcl v0.1.8-0.20190615233001-b21fea9723c8/go.mod h1:MRgZdU3vrFd05IQ89AxUZ0aYdF39BYoNFa324SodPCA= github.com/alecthomas/repr v0.0.0-20180818092828-117648cd9897 h1:p9Sln00KOTlrYkxI1zYWl1QLnEqAqEARBEYa8FQnQcY= github.com/alecthomas/repr v0.0.0-20180818092828-117648cd9897/go.mod h1:xTS7Pm1pD1mvyM075QCDSRqH6qRLXylzS24ZTpRiSzQ= -github.com/armon/circbuf v0.0.0-20190214190532-5111143e8da2 h1:7Ip0wMmLHLRJdrloDxZfhMm0xrLXZS8+COSu2bXmEQs= -github.com/armon/circbuf v0.0.0-20190214190532-5111143e8da2/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/creack/pty v1.1.11 h1:07n33Z8lZxZ2qwegKbObQohDhXDQxiMMz1NOUGYlesw= diff --git a/server.go b/server.go index c9f4045..022ffb8 100644 --- a/server.go +++ b/server.go @@ -7,12 +7,8 @@ import ( "errors" "io" "net" - "sync" "time" - "github.com/armon/circbuf" - "github.com/google/uuid" - "go.coder.com/flog" "golang.org/x/sync/errgroup" "golang.org/x/xerrors" @@ -21,11 +17,9 @@ import ( "cdr.dev/wsep/internal/proto" ) -var reconnectingProcesses sync.Map - // Options allows configuring the server. type Options struct { - ReconnectingProcessTimeout time.Duration + SessionTimeout time.Duration } // Serve runs the server-side of wsep. 
@@ -38,8 +32,8 @@ func Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Optio if options == nil { options = &Options{} } - if options.ReconnectingProcessTimeout == 0 { - options.ReconnectingProcessTimeout = 5 * time.Minute + if options.SessionTimeout == 0 { + options.SessionTimeout = 5 * time.Minute } c.SetReadLimit(maxMessageSize) @@ -66,8 +60,8 @@ func Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Optio } return nil } - headerByt, bodyByt := proto.SplitMessage(byt) + headerByt, bodyByt := proto.SplitMessage(byt) err = json.Unmarshal(headerByt, &header) if err != nil { return xerrors.Errorf("unmarshal header: %w", err) @@ -86,172 +80,36 @@ func Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Optio } command := mapToClientCmd(header.Command) + process, err = execer.Start(ctx, command) + if err != nil { + return err + } + defer process.Close() - // Only allow TTYs with IDs to be reconnected. - if command.TTY && header.ID != "" { - // Enforce a consistent format for IDs. - _, err := uuid.Parse(header.ID) - if err != nil { - flog.Error("%s is not a valid uuid: %w", header.ID, err) - } - - // Get an existing process or create a new one. - var rprocess *reconnectingProcess - rawRProcess, ok := reconnectingProcesses.Load(header.ID) - if ok { - rprocess, ok = rawRProcess.(*reconnectingProcess) - if !ok { - flog.Error("found invalid type in reconnecting process map for ID %s", header.ID) - } - process = rprocess.process - } else { - // The process will be kept alive as long as this context does not - // finish (and as long as the process does not exit on its own). This - // is a new context since the parent context finishes when the request - // ends which would kill the process prematurely. - ctx, cancel := context.WithCancel(context.Background()) - - // The process will be killed if the provided context ends. - process, err = execer.Start(ctx, command) - if err != nil { - cancel() - return err - } - - // Default to buffer 64KB. - ringBuffer, err := circbuf.NewBuffer(64 * 1024) - if err != nil { - cancel() - return xerrors.Errorf("unable to create ring buffer %w", err) - } - - rprocess = &reconnectingProcess{ - activeConns: make(map[string]net.Conn), - process: process, - // Timeouts created with AfterFunc can be reset. - timeout: time.AfterFunc(options.ReconnectingProcessTimeout, cancel), - ringBuffer: ringBuffer, - } - reconnectingProcesses.Store(header.ID, rprocess) - - // If the process exits send the exit code to all listening - // connections then close everything. - go func() { - err = process.Wait() - code := 0 - if exitErr, ok := err.(ExitError); ok { - code = exitErr.Code - } - rprocess.activeConnsMutex.Lock() - for _, conn := range rprocess.activeConns { - _ = sendExitCode(ctx, code, conn) - } - rprocess.activeConnsMutex.Unlock() - rprocess.Close() - reconnectingProcesses.Delete(header.ID) - }() - - // Write to the ring buffer and all connections as we receive stdout. - go func() { - buffer := make([]byte, 32*1024) - for { - read, err := rprocess.process.Stdout().Read(buffer) - if err != nil { - // When the process is closed this is triggered. 
- break - } - part := buffer[:read] - _, err = rprocess.ringBuffer.Write(part) - if err != nil { - flog.Error("reconnecting process %s write buffer: %v", header.ID, err) - cancel() - break - } - rprocess.activeConnsMutex.Lock() - for _, conn := range rprocess.activeConns { - _ = sendOutput(ctx, part, conn) - } - rprocess.activeConnsMutex.Unlock() - } - }() - } - - err = sendPID(ctx, process.Pid(), wsNetConn) - if err != nil { - flog.Error("failed to send pid %d", process.Pid()) - } - - // Write out the initial contents in the ring buffer. - err = sendOutput(ctx, rprocess.ringBuffer.Bytes(), wsNetConn) - if err != nil { - return xerrors.Errorf("write reconnecting process %s buffer: %w", header.ID, err) - } - - // Store this connection on the reconnecting process. All connections - // stored on the process will receive the process's stdout. - connectionID := uuid.NewString() - rprocess.activeConnsMutex.Lock() - rprocess.activeConns[connectionID] = wsNetConn - rprocess.activeConnsMutex.Unlock() - - // Keep resetting the inactivity timer while this connection is alive. - rprocess.timeout.Reset(options.ReconnectingProcessTimeout) - heartbeat := time.NewTicker(options.ReconnectingProcessTimeout / 2) - defer heartbeat.Stop() - go func() { - for { - select { - // Stop looping once this request finishes. - case <-ctx.Done(): - return - case <-heartbeat.C: - } - rprocess.timeout.Reset(options.ReconnectingProcessTimeout) - } - }() - - // Remove this connection from the process's connection list once the - // connection ends so data is no longer sent to it. - defer func() { - wsNetConn.Close() // REVIEW@asher: Not sure if necessary. - rprocess.activeConnsMutex.Lock() - delete(rprocess.activeConns, connectionID) - rprocess.activeConnsMutex.Unlock() - }() - } else { - process, err = execer.Start(ctx, command) - if err != nil { - return err - } + err = sendPID(ctx, process.Pid(), wsNetConn) + if err != nil { + flog.Error("failed to send pid %d", process.Pid()) + } - err = sendPID(ctx, process.Pid(), wsNetConn) - if err != nil { - flog.Error("failed to send pid %d", process.Pid()) + var outputgroup errgroup.Group + outputgroup.Go(func() error { + return copyWithHeader(process.Stdout(), wsNetConn, proto.Header{Type: proto.TypeStdout}) + }) + outputgroup.Go(func() error { + return copyWithHeader(process.Stderr(), wsNetConn, proto.Header{Type: proto.TypeStderr}) + }) + + go func() { + defer wsNetConn.Close() + _ = outputgroup.Wait() + err = process.Wait() + if exitErr, ok := err.(ExitError); ok { + _ = sendExitCode(ctx, exitErr.Code, wsNetConn) + return } + _ = sendExitCode(ctx, 0, wsNetConn) + }() - var outputgroup errgroup.Group - outputgroup.Go(func() error { - return copyWithHeader(process.Stdout(), wsNetConn, proto.Header{Type: proto.TypeStdout}) - }) - outputgroup.Go(func() error { - return copyWithHeader(process.Stderr(), wsNetConn, proto.Header{Type: proto.TypeStderr}) - }) - - go func() { - defer wsNetConn.Close() - _ = outputgroup.Wait() - err = process.Wait() - if exitErr, ok := err.(ExitError); ok { - _ = sendExitCode(ctx, exitErr.Code, wsNetConn) - return - } - _ = sendExitCode(ctx, 0, wsNetConn) - }() - - defer func() { - process.Close() - }() - } case proto.TypeResize: if process == nil { return errors.New("resize sent before command started") @@ -304,15 +162,6 @@ func sendPID(_ context.Context, pid int, conn net.Conn) error { return err } -func sendOutput(_ context.Context, data []byte, conn net.Conn) error { - header, err := json.Marshal(proto.ServerPidHeader{Type: proto.TypeStdout}) - if 
err != nil { - return err - } - _, err = proto.WithHeader(conn, header).Write(data) - return err -} - func copyWithHeader(r io.Reader, w io.Writer, header proto.Header) error { headerByt, err := json.Marshal(header) if err != nil { @@ -325,24 +174,3 @@ func copyWithHeader(r io.Reader, w io.Writer, header proto.Header) error { } return nil } - -type reconnectingProcess struct { - activeConnsMutex sync.Mutex - activeConns map[string]net.Conn - - ringBuffer *circbuf.Buffer - timeout *time.Timer - process Process -} - -// Close ends all connections to the reconnecting process and clears the ring -// buffer. -func (r *reconnectingProcess) Close() { - r.activeConnsMutex.Lock() - defer r.activeConnsMutex.Unlock() - for _, conn := range r.activeConns { - _ = conn.Close() - } - _ = r.process.Close() - r.ringBuffer.Reset() -} diff --git a/tty_test.go b/tty_test.go index 7b73559..a6adc28 100644 --- a/tty_test.go +++ b/tty_test.go @@ -80,7 +80,7 @@ func TestReconnectTTY(t *testing.T) { } ws, server := mockConn(ctx, t, &Options{ - ReconnectingProcessTimeout: time.Second, + SessionTimeout: time.Second, }) defer server.Close() @@ -100,7 +100,7 @@ func TestReconnectTTY(t *testing.T) { server.Close() ws, server = mockConn(ctx, t, &Options{ - ReconnectingProcessTimeout: time.Second, + SessionTimeout: time.Second, }) defer server.Close() @@ -121,7 +121,7 @@ func TestReconnectTTY(t *testing.T) { ws2, server2 := mockConn(ctx, t, &Options{ // Divide the time to test that the heartbeat keeps it open through multiple // intervals. - ReconnectingProcessTimeout: time.Second / 4, + SessionTimeout: time.Second / 4, }) defer server2.Close() @@ -147,7 +147,7 @@ func TestReconnectTTY(t *testing.T) { // The next connection should start a new process. ws, server = mockConn(ctx, t, &Options{ - ReconnectingProcessTimeout: time.Second, + SessionTimeout: time.Second, }) defer server.Close() From 6ee33417f551e6b5f7114852a91a801cf22d8953 Mon Sep 17 00:00:00 2001 From: Asher Date: Tue, 23 Aug 2022 13:12:48 -0500 Subject: [PATCH 03/18] Add test for alternate screen The output test waits for EOF; modify that behavior so we can check that certain strings are not displayed *without* waiting for the timeout. This means to be accurate we should always check for output that should exist after the output that should not exist would have shown up. --- ci/alt.sh | 36 ++++++++++++++++ tty_test.go | 122 ++++++++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 141 insertions(+), 17 deletions(-) create mode 100755 ci/alt.sh diff --git a/ci/alt.sh b/ci/alt.sh new file mode 100755 index 0000000..b92d60e --- /dev/null +++ b/ci/alt.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash +# Script for testing the alt screen. + +# Enter alt screen. +tput smcup + +function display() { + # Clear the screen. + tput clear + # Move cursor to the top left. + tput cup 0 0 + # Display content. + echo "ALT SCREEN" +} + +function redraw() { + display + echo "redrawn" +} + +# Re-display on resize. +trap 'redraw' WINCH + +display + +# The trap will not run while waiting for a command so read input in a loop with +# a timeout. +while true ; do + if read -n 1 -t .1 ; then + # Clear the screen. + tput clear + # Exit alt screen. 
+ tput rmcup + exit + fi +done diff --git a/tty_test.go b/tty_test.go index a6adc28..fbd6ab1 100644 --- a/tty_test.go +++ b/tty_test.go @@ -93,12 +93,13 @@ func TestReconnectTTY(t *testing.T) { assert.Success(t, "write to stdin", err) expected := []string{echoCmd, "test:2"} - assert.True(t, "find echo", findStdout(t, process, expected)) + assert.True(t, "find echo", checkStdout(t, process, expected, []string{})) - // Test disconnecting then reconnecting. + // Disconnect. ws.Close(websocket.StatusNormalClosure, "disconnected") server.Close() + // Reconnect. ws, server = mockConn(ctx, t, &Options{ SessionTimeout: time.Second, }) @@ -107,7 +108,7 @@ func TestReconnectTTY(t *testing.T) { process, err = RemoteExecer(ws).Start(ctx, command) assert.Success(t, "attach sh", err) - // The inactivity timeout should not have been triggered. + // The inactivity timeout should not trigger since we are connected. time.Sleep(time.Second) echoCmd = "echo test:$((2+2))" @@ -115,9 +116,9 @@ func TestReconnectTTY(t *testing.T) { assert.Success(t, "write to stdin", err) expected = append(expected, echoCmd, "test:4") - assert.True(t, "find echo", findStdout(t, process, expected)) + assert.True(t, "find echo", checkStdout(t, process, expected, []string{})) - // Test disconnecting while another connection is active. + // Make a simultaneously active connection. ws2, server2 := mockConn(ctx, t, &Options{ // Divide the time to test that the heartbeat keeps it open through multiple // intervals. @@ -128,8 +129,12 @@ func TestReconnectTTY(t *testing.T) { process, err = RemoteExecer(ws2).Start(ctx, command) assert.Success(t, "attach sh", err) + // Disconnect the first connection. ws.Close(websocket.StatusNormalClosure, "disconnected") server.Close() + + // Wait for inactivity. It should still stay up because of the second + // connection. time.Sleep(time.Second) // This connection should still be up. @@ -138,11 +143,13 @@ func TestReconnectTTY(t *testing.T) { assert.Success(t, "write to stdin", err) expected = append(expected, echoCmd, "test:6") - assert.True(t, "find echo", findStdout(t, process, expected)) + assert.True(t, "find echo", checkStdout(t, process, expected, []string{})) - // Close the remaining connection and wait for inactivity. + // Disconnect the second connection. ws2.Close(websocket.StatusNormalClosure, "disconnected") server2.Close() + + // Wait for inactivity. time.Sleep(time.Second) // The next connection should start a new process. @@ -154,24 +161,105 @@ func TestReconnectTTY(t *testing.T) { process, err = RemoteExecer(ws).Start(ctx, command) assert.Success(t, "attach sh", err) + echoCmd = "echo test:$((4+4))" + _, err = process.Stdin().Write([]byte(echoCmd + "\r\n")) + assert.Success(t, "write to stdin", err) + unexpected := expected + expected = []string{"test:8"} + // This time no echo since it is a new process. - assert.True(t, "find echo", !findStdout(t, process, expected)) + assert.True(t, "find echo", checkStdout(t, process, expected, unexpected)) + }) + + t.Run("AlternateScreen", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + command := Command{ + ID: uuid.NewString(), + Command: "sh", + TTY: true, + Stdin: true, + } + + ws, server := mockConn(ctx, t, &Options{ + SessionTimeout: time.Second, + }) + defer server.Close() + + process, err := RemoteExecer(ws).Start(ctx, command) + assert.Success(t, "attach sh", err) + + // Run an application that enters the alternate screen. 
+ _, err = process.Stdin().Write([]byte("./ci/alt.sh\r\n")) + assert.Success(t, "write to stdin", err) + + assert.True(t, "find output", checkStdout(t, process, []string{"./ci/alt.sh", "ALT SCREEN"}, []string{})) + + // Disconnect. + ws.Close(websocket.StatusNormalClosure, "disconnected") + server.Close() + + // Reconnect; the application should redraw. + ws, server = mockConn(ctx, t, &Options{ + SessionTimeout: time.Second, + }) + defer server.Close() + + process, err = RemoteExecer(ws).Start(ctx, command) + assert.Success(t, "attach sh", err) + + // Should have only the application output. + assert.True(t, "find output", checkStdout(t, process, []string{"ALT SCREEN"}, []string{"./ci/alt.sh"})) + + // Exit the application. + _, err = process.Stdin().Write([]byte("q")) + assert.Success(t, "write to stdin", err) + + // Disconnect. + ws.Close(websocket.StatusNormalClosure, "disconnected") + server.Close() + + // Reconnect. + ws, server = mockConn(ctx, t, &Options{ + SessionTimeout: time.Second, + }) + defer server.Close() + + process, err = RemoteExecer(ws).Start(ctx, command) + assert.Success(t, "attach sh", err) + + echoCmd := "echo test:$((5+5))" + _, err = process.Stdin().Write([]byte(echoCmd + "\r\n")) + assert.Success(t, "write to stdin", err) + + assert.True(t, "find output", checkStdout(t, process, []string{echoCmd, "test:10"}, []string{})) }) } -func findStdout(t *testing.T, process Process, expected []string) bool { +// checkStdout ensures that expected is in the stdout in the specified order. +// On the way if anything in unexpected comes up return false. Return once +// everything in expected has been found or EOF. +func checkStdout(t *testing.T, process Process, expected, unexpected []string) bool { t.Helper() + i := 0 scanner := bufio.NewScanner(process.Stdout()) -outer: - for _, str := range expected { - for scanner.Scan() { - line := scanner.Text() - t.Logf("bash tty stdout = %s", line) + for scanner.Scan() { + line := scanner.Text() + t.Logf("bash tty stdout = %s", strings.ReplaceAll(line, "\x1b", "ESC")) + for _, str := range unexpected { if strings.Contains(line, str) { - continue outer + return false } } - return false // Reached the end of output without finding str. + if strings.Contains(line, expected[i]) { + i = i + 1 + } + if i == len(expected) { + return true + } } - return true + return false } From f2bf68b57e3135cb12f7892d17f13518a8243125 Mon Sep 17 00:00:00 2001 From: Asher Date: Tue, 23 Aug 2022 21:04:15 -0500 Subject: [PATCH 04/18] Add timeout flag to dev client This makes it easier to test reconnects manually. 
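The flag simply arms a timer that closes the websocket, so a dropped
connection can be simulated without killing the whole client. A rough,
hypothetical sketch of driving the same pattern from code (assuming the
nhooyr.io/websocket client this repo uses; serverURL and sessionID are
placeholders, not part of this change):

    package example

    import (
        "context"
        "time"

        "cdr.dev/wsep"
        "nhooyr.io/websocket"
    )

    // runThenDrop starts a TTY session and closes the websocket after
    // the timeout, simulating a client that disconnects mid-session.
    func runThenDrop(ctx context.Context, serverURL, sessionID string, timeout time.Duration) error {
        conn, _, err := websocket.Dial(ctx, serverURL, nil)
        if err != nil {
            return err
        }
        process, err := wsep.RemoteExecer(conn).Start(ctx, wsep.Command{
            ID:      sessionID,
            Command: "sh",
            TTY:     true,
            Stdin:   true,
        })
        if err != nil {
            return err
        }
        time.AfterFunc(timeout, func() {
            conn.Close(websocket.StatusNormalClosure, "normal closure")
        })
        return process.Wait()
    }

Reconnecting afterwards with the same --id should reattach to the same
session rather than starting a new one.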
--- dev/client/main.go | 29 +++++++++++++++++++++++------ dev/server/main.go | 5 ++++- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/dev/client/main.go b/dev/client/main.go index 0c28fd4..2b05c18 100644 --- a/dev/client/main.go +++ b/dev/client/main.go @@ -9,6 +9,7 @@ import ( "os" "os/signal" "syscall" + "time" "cdr.dev/wsep" "github.com/spf13/pflag" @@ -20,41 +21,48 @@ import ( ) type notty struct { + timeout time.Duration } func (c *notty) Run(fl *pflag.FlagSet) { - do(fl, false, "") + do(fl, false, "", c.timeout) } func (c *notty) Spec() cli.CommandSpec { return cli.CommandSpec{ Name: "notty", - Usage: "[flags]", + Usage: "[flags] ", Desc: `Run a command without tty enabled.`, } } +func (c *notty) RegisterFlags(fl *pflag.FlagSet) { + fl.DurationVar(&c.timeout, "timeout", 0, "disconnect after specified timeout") +} + type tty struct { - id string + id string + timeout time.Duration } func (c *tty) Run(fl *pflag.FlagSet) { - do(fl, true, c.id) + do(fl, true, c.id, c.timeout) } func (c *tty) Spec() cli.CommandSpec { return cli.CommandSpec{ Name: "tty", - Usage: "[id] [flags]", + Usage: "[flags] ", Desc: `Run a command with tty enabled. Use the same ID to reconnect.`, } } func (c *tty) RegisterFlags(fl *pflag.FlagSet) { fl.StringVar(&c.id, "id", "", "sets id for reconnection") + fl.DurationVar(&c.timeout, "timeout", 0, "disconnect after the specified timeout") } -func do(fl *pflag.FlagSet, tty bool, id string) { +func do(fl *pflag.FlagSet, tty bool, id string, timeout time.Duration) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -112,6 +120,15 @@ func do(fl *pflag.FlagSet, tty bool, id string) { io.Copy(stdin, os.Stdin) }() + if timeout != 0 { + timer := time.NewTimer(timeout) + defer timer.Stop() + go func() { + <-timer.C + conn.Close(websocket.StatusNormalClosure, "normal closure") + }() + } + err = process.Wait() if err != nil { flog.Error("process failed: %v", err) diff --git a/dev/server/main.go b/dev/server/main.go index 66020d5..fcd94bf 100644 --- a/dev/server/main.go +++ b/dev/server/main.go @@ -2,6 +2,7 @@ package main import ( "net/http" + "time" "cdr.dev/wsep" "go.coder.com/flog" @@ -23,7 +24,9 @@ func serve(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) return } - err = wsep.Serve(r.Context(), ws, wsep.LocalExecer{}, nil) + err = wsep.Serve(r.Context(), ws, wsep.LocalExecer{}, &wsep.Options{ + SessionTimeout: 30 * time.Second, + }) if err != nil { flog.Error("failed to serve execer: %v", err) ws.Close(websocket.StatusAbnormalClosure, "failed to serve execer") From fa6c425d34d4518b2a649ff71da87a6964254aaf Mon Sep 17 00:00:00 2001 From: Asher Date: Thu, 25 Aug 2022 14:28:45 -0500 Subject: [PATCH 05/18] Add size to initial connection This way you do not need a subsequent resize and we can have the right size from the get-go. 
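A rough sketch of the caller side (it mirrors the dev client change
below; executor stands in for any wsep.Execer and the golang.org/x/term
import is an assumption):

    package example

    import (
        "context"
        "os"

        "cdr.dev/wsep"
        "golang.org/x/term"
    )

    // startSizedShell measures the local terminal and sends the size
    // with the start message so no follow-up resize is needed.
    func startSizedShell(ctx context.Context, executor wsep.Execer) (wsep.Process, error) {
        width, height, err := term.GetSize(int(os.Stdin.Fd()))
        if err != nil {
            return nil, err
        }
        return executor.Start(ctx, wsep.Command{
            Command: "sh",
            TTY:     true,
            Stdin:   true,
            Rows:    uint16(height),
            Cols:    uint16(width),
        })
    }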
--- browser/client.ts | 12 +++++++++--- client.go | 9 ++++++--- dev/client/main.go | 6 ++++++ exec.go | 4 ++++ internal/proto/clientmsg.go | 2 ++ localexec_unix.go | 5 ++++- server.go | 8 ++++++++ tty_test.go | 6 ++++++ 8 files changed, 45 insertions(+), 7 deletions(-) diff --git a/browser/client.ts b/browser/client.ts index 009912e..0a5322e 100644 --- a/browser/client.ts +++ b/browser/client.ts @@ -14,7 +14,7 @@ export interface Command { } export type ClientHeader = - | { type: 'start'; command: Command } + | { type: 'start'; id: string; command: Command; cols: number; rows: number; } | { type: 'stdin' } | { type: 'close_stdin' } | { type: 'resize'; cols: number; rows: number }; @@ -42,8 +42,14 @@ export const closeStdin = (ws: WebSocket) => { ws.send(msg.buffer); }; -export const startCommand = (ws: WebSocket, command: Command) => { - const msg = joinMessage({ type: 'start', command: command }); +export const startCommand = ( + ws: WebSocket, + command: Command, + id: string, + rows: number, + cols: number +) => { + const msg = joinMessage({ type: 'start', command, id, rows, cols }); ws.send(msg.buffer); }; diff --git a/client.go b/client.go index 4a22270..5a59ac8 100644 --- a/client.go +++ b/client.go @@ -28,10 +28,13 @@ func RemoteExecer(conn *websocket.Conn) Execer { // Command represents an external command to be run type Command struct { // ID allows reconnecting commands that have a TTY. - ID string - Command string - Args []string + ID string + Command string + Args []string + // Commands with a TTY also require Rows and Cols. TTY bool + Rows uint16 + Cols uint16 Stdin bool UID uint32 GID uint32 diff --git a/dev/client/main.go b/dev/client/main.go index 2b05c18..2917114 100644 --- a/dev/client/main.go +++ b/dev/client/main.go @@ -81,12 +81,18 @@ func do(fl *pflag.FlagSet, tty bool, id string, timeout time.Duration) { if len(fl.Args()) > 1 { args = fl.Args()[1:] } + width, height, err := term.GetSize(int(os.Stdin.Fd())) + if err != nil { + flog.Fatal("unable to get term size") + } process, err := executor.Start(ctx, wsep.Command{ ID: id, Command: fl.Arg(0), Args: args, TTY: tty, Stdin: true, + Rows: uint16(height), + Cols: uint16(width), }) if err != nil { flog.Fatal("failed to start remote command: %v", err) diff --git a/exec.go b/exec.go index e515289..f03de21 100644 --- a/exec.go +++ b/exec.go @@ -49,6 +49,8 @@ func mapToProtoCmd(c Command) proto.Command { Args: c.Args, Stdin: c.Stdin, TTY: c.TTY, + Rows: c.Rows, + Cols: c.Cols, UID: c.UID, GID: c.GID, Env: c.Env, @@ -62,6 +64,8 @@ func mapToClientCmd(c proto.Command) Command { Args: c.Args, Stdin: c.Stdin, TTY: c.TTY, + Rows: c.Rows, + Cols: c.Cols, UID: c.UID, GID: c.GID, Env: c.Env, diff --git a/internal/proto/clientmsg.go b/internal/proto/clientmsg.go index ae05ad5..19c9add 100644 --- a/internal/proto/clientmsg.go +++ b/internal/proto/clientmsg.go @@ -28,6 +28,8 @@ type Command struct { Args []string `json:"args"` Stdin bool `json:"stdin"` TTY bool `json:"tty"` + Rows uint16 `json:"rows"` + Cols uint16 `json:"cols"` UID uint32 `json:"uid"` GID uint32 `json:"gid"` Env []string `json:"env"` diff --git a/localexec_unix.go b/localexec_unix.go index a6d322c..d118cd4 100644 --- a/localexec_unix.go +++ b/localexec_unix.go @@ -61,7 +61,10 @@ func (l LocalExecer) Start(ctx context.Context, c Command) (Process, error) { if c.TTY { // This special WSEP_TTY variable helps debug unexpected TTYs. 
process.cmd.Env = append(process.cmd.Env, "WSEP_TTY=true") - process.tty, err = pty.Start(process.cmd) + process.tty, err = pty.StartWithSize(process.cmd, &pty.Winsize{ + Rows: c.Rows, + Cols: c.Cols, + }) if err != nil { return nil, xerrors.Errorf("start command with pty: %w", err) } diff --git a/server.go b/server.go index 022ffb8..80f060c 100644 --- a/server.go +++ b/server.go @@ -80,6 +80,14 @@ func Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Optio } command := mapToClientCmd(header.Command) + + if command.TTY { + // Enforce rows and columns so the TTY will be correctly sized. + if command.Rows == 0 || command.Cols == 0 { + return xerrors.Errorf("rows and cols must be non-zero: %w", err) + } + } + process, err = execer.Start(ctx, command) if err != nil { return err diff --git a/tty_test.go b/tty_test.go index fbd6ab1..375a400 100644 --- a/tty_test.go +++ b/tty_test.go @@ -33,6 +33,8 @@ func testTTY(ctx context.Context, t *testing.T, e Execer) { Command: "sh", TTY: true, Stdin: true, + Cols: 100, + Rows: 100, }) assert.Success(t, "start sh", err) var wg sync.WaitGroup @@ -77,6 +79,8 @@ func TestReconnectTTY(t *testing.T) { Command: "sh", TTY: true, Stdin: true, + Cols: 100, + Rows: 100, } ws, server := mockConn(ctx, t, &Options{ @@ -182,6 +186,8 @@ func TestReconnectTTY(t *testing.T) { Command: "sh", TTY: true, Stdin: true, + Cols: 100, + Rows: 100, } ws, server := mockConn(ctx, t, &Options{ From eb9b0e59f625b864ab75cdf6a3504ef994afb3b7 Mon Sep 17 00:00:00 2001 From: Asher Date: Thu, 25 Aug 2022 19:22:34 -0500 Subject: [PATCH 06/18] Prevent prompt from rendering twice in tests --- tty_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tty_test.go b/tty_test.go index 375a400..cb629d3 100644 --- a/tty_test.go +++ b/tty_test.go @@ -93,7 +93,7 @@ func TestReconnectTTY(t *testing.T) { // Write some unique output. echoCmd := "echo test:$((1+1))" - _, err = process.Stdin().Write([]byte(echoCmd + "\r\n")) + _, err = process.Stdin().Write([]byte(echoCmd + "\n")) assert.Success(t, "write to stdin", err) expected := []string{echoCmd, "test:2"} @@ -116,7 +116,7 @@ func TestReconnectTTY(t *testing.T) { time.Sleep(time.Second) echoCmd = "echo test:$((2+2))" - _, err = process.Stdin().Write([]byte(echoCmd + "\r\n")) + _, err = process.Stdin().Write([]byte(echoCmd + "\n")) assert.Success(t, "write to stdin", err) expected = append(expected, echoCmd, "test:4") @@ -143,7 +143,7 @@ func TestReconnectTTY(t *testing.T) { // This connection should still be up. echoCmd = "echo test:$((3+3))" - _, err = process.Stdin().Write([]byte(echoCmd + "\r\n")) + _, err = process.Stdin().Write([]byte(echoCmd + "\n")) assert.Success(t, "write to stdin", err) expected = append(expected, echoCmd, "test:6") @@ -199,7 +199,7 @@ func TestReconnectTTY(t *testing.T) { assert.Success(t, "attach sh", err) // Run an application that enters the alternate screen. 
- _, err = process.Stdin().Write([]byte("./ci/alt.sh\r\n")) + _, err = process.Stdin().Write([]byte("./ci/alt.sh\n")) assert.Success(t, "write to stdin", err) assert.True(t, "find output", checkStdout(t, process, []string{"./ci/alt.sh", "ALT SCREEN"}, []string{})) From 341c5ef77f9a168454e1350429cd931aaa53560b Mon Sep 17 00:00:00 2001 From: Asher Date: Thu, 15 Sep 2022 17:07:48 -0500 Subject: [PATCH 07/18] Add Nix flake --- ci/fmt.sh | 3 ++- ci/lint.sh | 2 +- flake.lock | 41 +++++++++++++++++++++++++++++++++++++++++ flake.nix | 19 +++++++++++++++++++ 4 files changed, 63 insertions(+), 2 deletions(-) create mode 100644 flake.lock create mode 100644 flake.nix diff --git a/ci/fmt.sh b/ci/fmt.sh index c3a0c16..a9522dd 100755 --- a/ci/fmt.sh +++ b/ci/fmt.sh @@ -1,4 +1,5 @@ -#!/bin/bash +#!/usr/bin/env bash + echo "Formatting..." go mod tidy diff --git a/ci/lint.sh b/ci/lint.sh index 9a9554c..1e06b61 100755 --- a/ci/lint.sh +++ b/ci/lint.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash echo "Linting..." diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..6066b3b --- /dev/null +++ b/flake.lock @@ -0,0 +1,41 @@ +{ + "nodes": { + "flake-utils": { + "locked": { + "lastModified": 1659877975, + "narHash": "sha256-zllb8aq3YO3h8B/U0/J1WBgAL8EX5yWf5pMj3G0NAmc=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "c0e246b9b83f637f4681389ecabcb2681b4f3af0", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1663235518, + "narHash": "sha256-q8zLK6rK/CLXEguaPgm9yQJcY0VQtOBhAT9EV2UFK/A=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "2277e4c9010b0f27585eb0bed0a86d7cbc079354", + "type": "github" + }, + "original": { + "id": "nixpkgs", + "type": "indirect" + } + }, + "root": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..1b932a4 --- /dev/null +++ b/flake.nix @@ -0,0 +1,19 @@ +{ + description = "wsep"; + + inputs.flake-utils.url = "github:numtide/flake-utils"; + + outputs = { self, nixpkgs, flake-utils }: + flake-utils.lib.eachDefaultSystem + (system: + let pkgs = nixpkgs.legacyPackages.${system}; + in { + devShells.default = pkgs.mkShell { + buildInputs = with pkgs; [ + go + screen + ]; + }; + } + ); +} From f4d980b869f69afdcc6cc7e234cfbfc3ef9f59e4 Mon Sep 17 00:00:00 2001 From: Asher Date: Fri, 16 Sep 2022 13:59:22 -0500 Subject: [PATCH 08/18] Propagate process close error --- client.go | 14 +++++++------- client_test.go | 5 +++-- exec.go | 12 +++++++++--- internal/proto/servermsg.go | 1 + localexec.go | 3 ++- localexec_test.go | 3 ++- server.go | 19 ++++++++++++------- 7 files changed, 36 insertions(+), 21 deletions(-) diff --git a/client.go b/client.go index 5a59ac8..5097c90 100644 --- a/client.go +++ b/client.go @@ -162,7 +162,7 @@ func (r remoteProcess) listen(ctx context.Context) { defer r.conn.Close(websocket.StatusNormalClosure, "normal closure") defer close(r.done) - exitCode := make(chan int, 1) + exitMsg := make(chan proto.ServerExitCodeHeader, 1) var eg errgroup.Group eg.Go(func() error { @@ -195,13 +195,13 @@ func (r remoteProcess) listen(ctx context.Context) { return err } case proto.TypeExitCode: - var exitMsg proto.ServerExitCodeHeader - err = json.Unmarshal(headerByt, &exitMsg) + var msg proto.ServerExitCodeHeader + err = json.Unmarshal(headerByt, &msg) if err != nil { continue } - exitCode <- exitMsg.ExitCode 
+ exitMsg <- msg return nil } } @@ -210,9 +210,9 @@ func (r remoteProcess) listen(ctx context.Context) { err := eg.Wait() select { - case exitCode := <-exitCode: - if exitCode != 0 { - r.done <- ExitError{Code: exitCode} + case exitMsg := <-exitMsg: + if exitMsg.ExitCode != 0 { + r.done <- ExitError{code: exitMsg.ExitCode, error: exitMsg.Error} } default: r.done <- err diff --git a/client_test.go b/client_test.go index 77276da..3ec1a77 100644 --- a/client_test.go +++ b/client_test.go @@ -107,9 +107,10 @@ func testExecerFail(ctx context.Context, t *testing.T, execer Execer) { go io.Copy(ioutil.Discard, process.Stdout()) err = process.Wait() - code, ok := err.(ExitError) + exitErr, ok := err.(ExitError) assert.True(t, "is exit error", ok) - assert.True(t, "exit code is nonzero", code.Code != 0) + assert.True(t, "exit code is nonzero", exitErr.ExitCode() != 0) + assert.Equal(t, "exit error", exitErr.Error(), "exit status 2") assert.Error(t, "wait for process to error", err) } diff --git a/exec.go b/exec.go index f03de21..869cd3c 100644 --- a/exec.go +++ b/exec.go @@ -2,7 +2,6 @@ package wsep import ( "context" - "fmt" "io" "cdr.dev/wsep/internal/proto" @@ -10,11 +9,18 @@ import ( // ExitError is sent when the command terminates. type ExitError struct { - Code int + code int + error string } +// ExitCode returns the exit code of the process. +func (e ExitError) ExitCode() int { + return e.code +} + +// Error returns a string describing why the process errored. func (e ExitError) Error() string { - return fmt.Sprintf("process exited with code %v", e.Code) + return e.error } // Process represents a started command. diff --git a/internal/proto/servermsg.go b/internal/proto/servermsg.go index 61614fa..2132a26 100644 --- a/internal/proto/servermsg.go +++ b/internal/proto/servermsg.go @@ -18,4 +18,5 @@ type ServerPidHeader struct { type ServerExitCodeHeader struct { Type string `json:"type"` ExitCode int `json:"exit_code"` + Error string `json:"error"` } diff --git a/localexec.go b/localexec.go index bdb7237..fe2e955 100644 --- a/localexec.go +++ b/localexec.go @@ -29,7 +29,8 @@ func (l *localProcess) Wait() error { err := l.cmd.Wait() if exitErr, ok := err.(*exec.ExitError); ok { return ExitError{ - Code: exitErr.ExitCode(), + code: exitErr.ExitCode(), + error: exitErr.Error(), } } return err diff --git a/localexec_test.go b/localexec_test.go index 8a454c4..4d2caac 100644 --- a/localexec_test.go +++ b/localexec_test.go @@ -72,7 +72,8 @@ func TestExitCode(t *testing.T) { err = process.Wait() exitErr, ok := err.(ExitError) assert.True(t, "error is ExitError", ok) - assert.Equal(t, "exit error", exitErr.Code, 127) + assert.Equal(t, "exit error code", exitErr.ExitCode(), 127) + assert.Equal(t, "exit error", exitErr.Error(), "exit status 127") } func TestStdin(t *testing.T) { diff --git a/server.go b/server.go index 80f060c..e4e7d55 100644 --- a/server.go +++ b/server.go @@ -110,12 +110,8 @@ func Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Optio go func() { defer wsNetConn.Close() _ = outputgroup.Wait() - err = process.Wait() - if exitErr, ok := err.(ExitError); ok { - _ = sendExitCode(ctx, exitErr.Code, wsNetConn) - return - } - _ = sendExitCode(ctx, 0, wsNetConn) + err := process.Wait() + _ = sendExitCode(ctx, err, wsNetConn) }() case proto.TypeResize: @@ -149,10 +145,19 @@ func Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Optio } } -func sendExitCode(_ context.Context, exitCode int, conn net.Conn) error { +func sendExitCode(_ context.Context, err 
error, conn net.Conn) error { + exitCode := 0 + errorStr := "" + if err != nil { + errorStr = err.Error() + } + if exitErr, ok := err.(ExitError); ok { + exitCode = exitErr.ExitCode() + } header, err := json.Marshal(proto.ServerExitCodeHeader{ Type: proto.TypeExitCode, ExitCode: exitCode, + Error: errorStr, }) if err != nil { return err From 91e0276d2dd921f9ae12a60ef0da4b288d904ea4 Mon Sep 17 00:00:00 2001 From: Asher Date: Tue, 13 Sep 2022 15:10:04 -0500 Subject: [PATCH 09/18] Implement reconnecting TTY with screen --- ci/image/Dockerfile | 2 + dev/client/main.go | 1 + exec.go | 8 +- localexec.go | 3 +- server.go | 238 +++++++++++++++++++++++++++++++++++++++++++- tty_test.go | 62 +++++++++++- 6 files changed, 302 insertions(+), 12 deletions(-) diff --git a/ci/image/Dockerfile b/ci/image/Dockerfile index 5b5c0fb..e08d509 100644 --- a/ci/image/Dockerfile +++ b/ci/image/Dockerfile @@ -3,6 +3,8 @@ FROM golang:1 ENV GOFLAGS="-mod=readonly" ENV CI=true +RUN apt update && apt install -y screen + RUN go install golang.org/x/tools/cmd/goimports@latest RUN go install golang.org/x/lint/golint@latest RUN go install github.com/mattn/goveralls@latest diff --git a/dev/client/main.go b/dev/client/main.go index 2917114..41f1815 100644 --- a/dev/client/main.go +++ b/dev/client/main.go @@ -93,6 +93,7 @@ func do(fl *pflag.FlagSet, tty bool, id string, timeout time.Duration) { Stdin: true, Rows: uint16(height), Cols: uint16(width), + Env: []string{"TERM=" + os.Getenv("TERM")}, }) if err != nil { flog.Fatal("failed to start remote command: %v", err) diff --git a/exec.go b/exec.go index 869cd3c..0ba6def 100644 --- a/exec.go +++ b/exec.go @@ -38,8 +38,8 @@ type Process interface { Resize(ctx context.Context, rows, cols uint16) error // Wait returns ExitError when the command terminates with a non-zero exit code. Wait() error - // Close terminates the process and underlying connection(s). - // It must be called otherwise a connection or process may leak. + // Close sends a SIGTERM to the process. To force a shutdown cancel the + // context passed into the execer. Close() error } @@ -64,8 +64,8 @@ func mapToProtoCmd(c Command) proto.Command { } } -func mapToClientCmd(c proto.Command) Command { - return Command{ +func mapToClientCmd(c proto.Command) *Command { + return &Command{ Command: c.Command, Args: c.Args, Stdin: c.Stdin, diff --git a/localexec.go b/localexec.go index fe2e955..31a1f2c 100644 --- a/localexec.go +++ b/localexec.go @@ -3,6 +3,7 @@ package wsep import ( "io" "os/exec" + "syscall" "golang.org/x/xerrors" ) @@ -37,7 +38,7 @@ func (l *localProcess) Wait() error { } func (l *localProcess) Close() error { - return l.cmd.Process.Kill() + return l.cmd.Process.Signal(syscall.SIGTERM) } func (l *localProcess) Pid() int { diff --git a/server.go b/server.go index e4e7d55..92c4308 100644 --- a/server.go +++ b/server.go @@ -5,8 +5,14 @@ import ( "context" "encoding/json" "errors" + "fmt" "io" "net" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" "time" "go.coder.com/flog" @@ -17,11 +23,32 @@ import ( "cdr.dev/wsep/internal/proto" ) +var sessions sync.Map + // Options allows configuring the server. type Options struct { SessionTimeout time.Duration } +type session struct { + ctx context.Context + screenID string + ready chan error + timeout *time.Timer +} + +// Dispose closes the specified session. 
+func Dispose(id string) { + if rawSession, ok := sessions.Load(id); ok { + if sess, ok := rawSession.(*session); ok { + sess.timeout.Reset(0) + select { + case <-sess.ctx.Done(): + } + } + } +} + // Serve runs the server-side of wsep. // The execer may be another wsep connection for chaining. // Use LocalExecer for local command execution. @@ -42,6 +69,7 @@ func Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Optio process Process wsNetConn = websocket.NetConn(ctx, c, websocket.MessageBinary) ) + defer wsNetConn.Close() for { if err := ctx.Err(); err != nil { return err @@ -88,15 +116,23 @@ func Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Optio } } - process, err = execer.Start(ctx, command) + // Only TTYs with IDs can be reconnected. + if command.TTY && header.ID != "" { + command, err = createSession(ctx, header.ID, command, execer, options) + if err != nil { + return err + } + } + + // The process will get killed when the connection context ends. + process, err = execer.Start(ctx, *command) if err != nil { return err } - defer process.Close() err = sendPID(ctx, process.Pid(), wsNetConn) if err != nil { - flog.Error("failed to send pid %d", process.Pid()) + return xerrors.Errorf("failed to send pid %d: %w", process.Pid(), err) } var outputgroup errgroup.Group @@ -108,7 +144,8 @@ func Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Optio }) go func() { - defer wsNetConn.Close() + // Wait for the readers to close which happens when the connection + // closes or the process dies. _ = outputgroup.Wait() err := process.Wait() _ = sendExitCode(ctx, err, wsNetConn) @@ -187,3 +224,196 @@ func copyWithHeader(r io.Reader, w io.Writer, header proto.Header) error { } return nil } + +func createSession(ctx context.Context, id string, command *Command, execer Execer, options *Options) (*Command, error) { + // If screen is not installed spawn the command normally. + _, err := exec.LookPath("screen") + if err != nil { + flog.Info("`screen` could not be found; session %s will not persist", id) + return command, nil + } + + settings := []string{ + // Tell screen not to handle motion for xterm* terminals which + // allows scrolling the terminal via the mouse wheel or scroll bar + // (by default screen uses it to cycle through the command history). + // There does not seem to be a way to make screen itself scroll on + // mouse wheel. tmux can do it but then there is no scroll bar and + // it kicks you into copy mode where keys stop working until you + // exit copy mode which seems like it could be confusing. + "termcapinfo xterm* ti@:te@", + // Enable alternate screen emulation otherwise applications get + // rendered in the current window which wipes out visible output + // resulting in missing output when scrolling back with the mouse + // wheel (copy mode still works since that is screen itself + // scrolling). + "altscreen on", + // Remap the control key to C-s since C-a may be used in + // applications. C-s cannot actually be used anyway since by + // default it will pause and C-q to resume will just kill the + // browser window. We may not want people using the control key + // anyway since it will not be obvious they are in screen and doing + // things like switching windows makes mouse wheel scroll wonky due + // to the terminal doing the scrolling rather than screen itself + // (but again copy mode will work just fine). 
+ "escape ^Ss", + } + + dir := filepath.Join(os.TempDir(), "coder-screen") + config := filepath.Join(dir, "config") + screendir := filepath.Join(dir, "sockets") + + err = os.MkdirAll(screendir, 0o700) + if err != nil { + return nil, xerrors.Errorf("unable to create %s for session %s: %w", screendir, id, err) + } + + err = os.WriteFile(config, []byte(strings.Join(settings, "\n")), 0o644) + if err != nil { + return nil, xerrors.Errorf("unable to create %s for session %s: %w", config, id, err) + } + + var sess *session + if rawSession, ok := sessions.Load(id); ok { + if sess, ok = rawSession.(*session); !ok { + return nil, xerrors.Errorf("found invalid type in session map for ID %s", id) + } + } else { + ctx, cancel := context.WithCancel(context.Background()) + sess = &session{ + ctx: ctx, + ready: make(chan error, 1), + } + sessions.Store(id, sess) + + // Starting screen with -Dm causes it to launch a server tied to this + // process, letting us attach to and kill it with the PID. + process, err := execer.Start(ctx, Command{ + Command: "screen", + Args: append([]string{"-S", id, "-Dmc", config, command.Command}, command.Args...), + UID: command.UID, + GID: command.GID, + Env: append(command.Env, "SCREENDIR="+screendir), + WorkingDir: command.WorkingDir, + }) + if err != nil { + cancel() + err = xerrors.Errorf("failed to create session %s: %w", id, err) + sess.ready <- err + close(sess.ready) + return nil, err + } + + sess.screenID = fmt.Sprintf("%d.%s", process.Pid(), id) + + // Timeouts created with AfterFunc can be reset. + sess.timeout = time.AfterFunc(options.SessionTimeout, func() { + // Delete immediately in case it takes a while to clean up otherwise new + // connections with this ID will try to connect to the closing screen. + sessions.Delete(id) + // Close() will send a SIGTERM allowing screen to clean up. If we kill it + // abruptly the socket will be left behind which is probably harmless but + // it causes screen -list to show a bunch of dead sessions (of course the + // user would not see this since we have a custom socket directory) and it + // seems ideal to let it clean up anyway. + process.Close() + select { + case <-sess.ctx.Done(): + return + case <-time.After(5 * time.Second): + cancel() // Force screen to exit. + } + }) + + // Remove the session if screen exits. + go func() { + err := process.Wait() + // Screen exits with a one when it gets a SIGTERM. + if exitErr, ok := err.(ExitError); ok && exitErr.ExitCode() != 1 { + flog.Error("session %s exited with error %w", id, err) + } + cancel() + sess.timeout.Stop() + sessions.Delete(id) + }() + + // Sometimes if you attach too quickly after spawning screen it will say the + // session does not exist so run a command against it until it works. + go func() { + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + defer close(sess.ready) + for { + select { + case <-sess.ctx.Done(): + sess.ready <- xerrors.Errorf("session %s is gone", id) + return + case <-ctx.Done(): + sess.ready <- xerrors.Errorf("timed out waiting for session %s", id) + return + default: + process, err := execer.Start(ctx, Command{ + Command: "screen", + Args: []string{"-S", sess.screenID, "-X", "version"}, + UID: command.UID, + GID: command.GID, + Env: append(command.Env, "SCREENDIR="+screendir), + }) + if err != nil { + sess.ready <- xerrors.Errorf("error waiting for session %s: %w", id, err) + return + } + err = process.Wait() + // TODO: Send error if it is anything but "no screen session found". 
+ if err == nil { + return + } + time.Sleep(250 * time.Millisecond) + } + } + }() + } + + // Block until the server is ready. + select { + case err := <-sess.ready: + if err != nil { + return nil, err + } + } + + // Refresh the session timeout now. + sess.timeout.Reset(options.SessionTimeout) + + // Keep refreshing the session timeout while this connection is alive. + heartbeat := time.NewTicker(options.SessionTimeout / 2) + go func() { + defer heartbeat.Stop() + // Reset when the connection closes to ensure the session stays up for the + // full timeout. + defer sess.timeout.Reset(options.SessionTimeout) + for { + select { + // Stop looping once this request finishes. + case <-ctx.Done(): + return + case <-heartbeat.C: + } + sess.timeout.Reset(options.SessionTimeout) + } + }() + + // Use screen to connect to the session. + return &Command{ + Command: "screen", + Args: []string{"-S", sess.screenID, "-xc", config}, + TTY: command.TTY, + Rows: command.Rows, + Cols: command.Cols, + Stdin: command.Stdin, + UID: command.UID, + GID: command.GID, + Env: append(command.Env, "SCREENDIR="+screendir), + WorkingDir: command.WorkingDir, + }, nil +} diff --git a/tty_test.go b/tty_test.go index cb629d3..070a2e6 100644 --- a/tty_test.go +++ b/tty_test.go @@ -35,6 +35,7 @@ func testTTY(ctx context.Context, t *testing.T, e Execer) { Stdin: true, Cols: 100, Rows: 100, + Env: []string{"TERM=xterm"}, }) assert.Success(t, "start sh", err) var wg sync.WaitGroup @@ -66,9 +67,54 @@ func testTTY(ctx context.Context, t *testing.T, e Execer) { } func TestReconnectTTY(t *testing.T) { - t.Parallel() + t.Run("NoScreen", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() - t.Run("RegularScreen", func(t *testing.T) { + t.Setenv("PATH", "/bin") + + command := Command{ + ID: uuid.NewString(), + Command: "sh", + TTY: true, + Stdin: true, + Cols: 100, + Rows: 100, + Env: []string{"TERM=xterm"}, + } + + ws, server := mockConn(ctx, t, nil) + defer server.Close() + + process, err := RemoteExecer(ws).Start(ctx, command) + assert.Success(t, "start sh", err) + + // Write some unique output. + echoCmd := "echo test:$((5+5))" + _, err = process.Stdin().Write([]byte(echoCmd + "\n")) + assert.Success(t, "write to stdin", err) + expected := []string{echoCmd, "test:10"} + + assert.True(t, "find echo", checkStdout(t, process, expected, []string{})) + + // Connect to the same session. + ws, server = mockConn(ctx, t, nil) + defer server.Close() + + process, err = RemoteExecer(ws).Start(ctx, command) + assert.Success(t, "attach sh", err) + + echoCmd = "echo test:$((6+6))" + _, err = process.Stdin().Write([]byte(echoCmd + "\r\n")) + assert.Success(t, "write to stdin", err) + unexpected := expected + expected = []string{"test:12"} + + // No echo since it is a new process. 
+ assert.True(t, "find echo", checkStdout(t, process, expected, unexpected)) + }) + + t.Run("Regular", func(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -81,8 +127,13 @@ func TestReconnectTTY(t *testing.T) { Stdin: true, Cols: 100, Rows: 100, + Env: []string{"TERM=xterm"}, } + t.Cleanup(func() { + Dispose(command.ID) + }) + ws, server := mockConn(ctx, t, &Options{ SessionTimeout: time.Second, }) @@ -175,7 +226,7 @@ func TestReconnectTTY(t *testing.T) { assert.True(t, "find echo", checkStdout(t, process, expected, unexpected)) }) - t.Run("AlternateScreen", func(t *testing.T) { + t.Run("Alternate", func(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -188,8 +239,13 @@ func TestReconnectTTY(t *testing.T) { Stdin: true, Cols: 100, Rows: 100, + Env: []string{"TERM=xterm"}, } + t.Cleanup(func() { + Dispose(command.ID) + }) + ws, server := mockConn(ctx, t, &Options{ SessionTimeout: time.Second, }) From a8eaaa4f04e3a4d1794584b48af2e54b79175ef9 Mon Sep 17 00:00:00 2001 From: Asher Date: Thu, 22 Sep 2022 12:09:53 -0500 Subject: [PATCH 10/18] Comment on screenID --- server.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/server.go b/server.go index 92c4308..6f98cd2 100644 --- a/server.go +++ b/server.go @@ -304,6 +304,12 @@ func createSession(ctx context.Context, id string, command *Command, execer Exec return nil, err } + // screenID will be used to attach to the session that was just created. + // Screen allows attaching to a session via either the session name or + // .. Use the latter form to differentiate between + // sessions with the same name (for example if a session closes due to a + // timeout and a client reconnects around the same time there can be two + // sessions with the same name while the old session cleans up). sess.screenID = fmt.Sprintf("%d.%s", process.Pid(), id) // Timeouts created with AfterFunc can be reset. From b22b565c90eb3cacd791d9106ac21908eddd2b64 Mon Sep 17 00:00:00 2001 From: Asher Date: Thu, 22 Sep 2022 16:09:57 -0500 Subject: [PATCH 11/18] Encapsulate session logic --- server.go | 216 +++---------------------------------- session.go | 309 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 325 insertions(+), 200 deletions(-) create mode 100644 session.go diff --git a/server.go b/server.go index 6f98cd2..5f1deb9 100644 --- a/server.go +++ b/server.go @@ -5,13 +5,9 @@ import ( "context" "encoding/json" "errors" - "fmt" "io" "net" - "os" "os/exec" - "path/filepath" - "strings" "sync" "time" @@ -24,27 +20,18 @@ import ( ) var sessions sync.Map +var sessionsMutex sync.Mutex // Options allows configuring the server. type Options struct { SessionTimeout time.Duration } -type session struct { - ctx context.Context - screenID string - ready chan error - timeout *time.Timer -} - // Dispose closes the specified session. func Dispose(id string) { if rawSession, ok := sessions.Load(id); ok { - if sess, ok := rawSession.(*session); ok { - sess.timeout.Reset(0) - select { - case <-sess.ctx.Done(): - } + if s, ok := rawSession.(*Session); ok { + s.Close() } } } @@ -118,7 +105,7 @@ func Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Optio // Only TTYs with IDs can be reconnected. 
if command.TTY && header.ID != "" { - command, err = createSession(ctx, header.ID, command, execer, options) + command, err = withSession(ctx, header.ID, command, execer, options) if err != nil { return err } @@ -225,7 +212,8 @@ func copyWithHeader(r io.Reader, w io.Writer, header proto.Header) error { return nil } -func createSession(ctx context.Context, id string, command *Command, execer Execer, options *Options) (*Command, error) { +// withSession wraps the command in a session if screen is available. +func withSession(ctx context.Context, id string, command *Command, execer Execer, options *Options) (*Command, error) { // If screen is not installed spawn the command normally. _, err := exec.LookPath("screen") if err != nil { @@ -233,193 +221,21 @@ func createSession(ctx context.Context, id string, command *Command, execer Exec return command, nil } - settings := []string{ - // Tell screen not to handle motion for xterm* terminals which - // allows scrolling the terminal via the mouse wheel or scroll bar - // (by default screen uses it to cycle through the command history). - // There does not seem to be a way to make screen itself scroll on - // mouse wheel. tmux can do it but then there is no scroll bar and - // it kicks you into copy mode where keys stop working until you - // exit copy mode which seems like it could be confusing. - "termcapinfo xterm* ti@:te@", - // Enable alternate screen emulation otherwise applications get - // rendered in the current window which wipes out visible output - // resulting in missing output when scrolling back with the mouse - // wheel (copy mode still works since that is screen itself - // scrolling). - "altscreen on", - // Remap the control key to C-s since C-a may be used in - // applications. C-s cannot actually be used anyway since by - // default it will pause and C-q to resume will just kill the - // browser window. We may not want people using the control key - // anyway since it will not be obvious they are in screen and doing - // things like switching windows makes mouse wheel scroll wonky due - // to the terminal doing the scrolling rather than screen itself - // (but again copy mode will work just fine). - "escape ^Ss", - } - - dir := filepath.Join(os.TempDir(), "coder-screen") - config := filepath.Join(dir, "config") - screendir := filepath.Join(dir, "sockets") - - err = os.MkdirAll(screendir, 0o700) - if err != nil { - return nil, xerrors.Errorf("unable to create %s for session %s: %w", screendir, id, err) - } - - err = os.WriteFile(config, []byte(strings.Join(settings, "\n")), 0o644) - if err != nil { - return nil, xerrors.Errorf("unable to create %s for session %s: %w", config, id, err) - } - - var sess *session + var s *Session + sessionsMutex.Lock() if rawSession, ok := sessions.Load(id); ok { - if sess, ok = rawSession.(*session); !ok { + if s, ok = rawSession.(*Session); !ok { return nil, xerrors.Errorf("found invalid type in session map for ID %s", id) } } else { - ctx, cancel := context.WithCancel(context.Background()) - sess = &session{ - ctx: ctx, - ready: make(chan error, 1), - } - sessions.Store(id, sess) - - // Starting screen with -Dm causes it to launch a server tied to this - // process, letting us attach to and kill it with the PID. 
- process, err := execer.Start(ctx, Command{ - Command: "screen", - Args: append([]string{"-S", id, "-Dmc", config, command.Command}, command.Args...), - UID: command.UID, - GID: command.GID, - Env: append(command.Env, "SCREENDIR="+screendir), - WorkingDir: command.WorkingDir, - }) - if err != nil { - cancel() - err = xerrors.Errorf("failed to create session %s: %w", id, err) - sess.ready <- err - close(sess.ready) - return nil, err - } - - // screenID will be used to attach to the session that was just created. - // Screen allows attaching to a session via either the session name or - // .. Use the latter form to differentiate between - // sessions with the same name (for example if a session closes due to a - // timeout and a client reconnects around the same time there can be two - // sessions with the same name while the old session cleans up). - sess.screenID = fmt.Sprintf("%d.%s", process.Pid(), id) - - // Timeouts created with AfterFunc can be reset. - sess.timeout = time.AfterFunc(options.SessionTimeout, func() { - // Delete immediately in case it takes a while to clean up otherwise new - // connections with this ID will try to connect to the closing screen. - sessions.Delete(id) - // Close() will send a SIGTERM allowing screen to clean up. If we kill it - // abruptly the socket will be left behind which is probably harmless but - // it causes screen -list to show a bunch of dead sessions (of course the - // user would not see this since we have a custom socket directory) and it - // seems ideal to let it clean up anyway. - process.Close() - select { - case <-sess.ctx.Done(): - return - case <-time.After(5 * time.Second): - cancel() // Force screen to exit. - } - }) - - // Remove the session if screen exits. - go func() { - err := process.Wait() - // Screen exits with a one when it gets a SIGTERM. - if exitErr, ok := err.(ExitError); ok && exitErr.ExitCode() != 1 { - flog.Error("session %s exited with error %w", id, err) - } - cancel() - sess.timeout.Stop() - sessions.Delete(id) - }() - - // Sometimes if you attach too quickly after spawning screen it will say the - // session does not exist so run a command against it until it works. - go func() { - ctx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - defer close(sess.ready) - for { - select { - case <-sess.ctx.Done(): - sess.ready <- xerrors.Errorf("session %s is gone", id) - return - case <-ctx.Done(): - sess.ready <- xerrors.Errorf("timed out waiting for session %s", id) - return - default: - process, err := execer.Start(ctx, Command{ - Command: "screen", - Args: []string{"-S", sess.screenID, "-X", "version"}, - UID: command.UID, - GID: command.GID, - Env: append(command.Env, "SCREENDIR="+screendir), - }) - if err != nil { - sess.ready <- xerrors.Errorf("error waiting for session %s: %w", id, err) - return - } - err = process.Wait() - // TODO: Send error if it is anything but "no screen session found". - if err == nil { - return - } - time.Sleep(250 * time.Millisecond) - } - } + s = NewSession(id, command, execer, options) + sessions.Store(id, s) + go func() { // Remove the session from the map once it closes. + defer sessions.Delete(id) + s.Wait() }() } + sessionsMutex.Unlock() - // Block until the server is ready. - select { - case err := <-sess.ready: - if err != nil { - return nil, err - } - } - - // Refresh the session timeout now. - sess.timeout.Reset(options.SessionTimeout) - - // Keep refreshing the session timeout while this connection is alive. 
- heartbeat := time.NewTicker(options.SessionTimeout / 2) - go func() { - defer heartbeat.Stop() - // Reset when the connection closes to ensure the session stays up for the - // full timeout. - defer sess.timeout.Reset(options.SessionTimeout) - for { - select { - // Stop looping once this request finishes. - case <-ctx.Done(): - return - case <-heartbeat.C: - } - sess.timeout.Reset(options.SessionTimeout) - } - }() - - // Use screen to connect to the session. - return &Command{ - Command: "screen", - Args: []string{"-S", sess.screenID, "-xc", config}, - TTY: command.TTY, - Rows: command.Rows, - Cols: command.Cols, - Stdin: command.Stdin, - UID: command.UID, - GID: command.GID, - Env: append(command.Env, "SCREENDIR="+screendir), - WorkingDir: command.WorkingDir, - }, nil + return s.Attach(ctx) } diff --git a/session.go b/session.go new file mode 100644 index 0000000..4cbec38 --- /dev/null +++ b/session.go @@ -0,0 +1,309 @@ +package wsep + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "golang.org/x/xerrors" +) + +// Session represents a `screen` session. +type Session struct { + // close receives nil when the session should be closed. + close chan struct{} + // closing receives nil when the session has begun closing. The underlying + // process may still be exiting. closing implies ready. + closing chan struct{} + // command is the original command used to spawn the session. + command *Command + // configFile is the location of the screen configuration file. + configFile string + // done receives nil when the session has completely shut down and the process + // has exited. done implies closing and ready. + done chan struct{} + // error hold any error that occurred while starting the session. + error error + // execer is used to spawn the session and ready commands. + execer Execer + // id holds the id of the session for both creating and attaching. + id string + // options holds options for configuring the session. + options *Options + // ready receives nil when the session is ready to be attached. error must be + // checked after receiving on this channel. + ready chan struct{} + // socketsDir is the location of the directory where screen should put its + // sockets. + socketsDir string + // timer will close the session when it expires. + timer *time.Timer +} + +// NewSession creates and immediately starts a new session. Any errors with +// starting are returned on Attach(). The session will close itself if nothing +// is attached for the duration of the session timeout. +func NewSession(id string, command *Command, execer Execer, options *Options) *Session { + tempdir := filepath.Join(os.TempDir(), "coder-screen") + s := &Session{ + close: make(chan struct{}), + closing: make(chan struct{}), + command: command, + configFile: filepath.Join(tempdir, "config"), + done: make(chan struct{}), + execer: execer, + options: options, + ready: make(chan struct{}), + socketsDir: filepath.Join(tempdir, "sockets"), + } + go s.lifecycle(id) + return s +} + +// lifecycle manages the lifecycle of the session. +func (s *Session) lifecycle(id string) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Close the session after a timeout. Timeouts created with AfterFunc can be + // reset; we will reset it as long as there are active connections. + s.timer = time.AfterFunc(s.options.SessionTimeout, s.Close) + defer s.timer.Stop() + + // Close the session down immediately if there was an error and store that + // error to emit on Attach() calls. 
+ process, err := s.start(ctx, id) + if err != nil { + s.error = err + close(s.ready) + close(s.closing) + close(s.done) + return + } + + // Mark the session as fully done when the process exits. + go func() { + err = process.Wait() + if err != nil { + s.error = err + } + close(s.done) + }() + + // Wait until the session is ready to receive attaches. + err = s.waitReady(ctx) + if err != nil { + s.error = err + close(s.ready) + close(s.closing) + // The deferred cancel will kill the process if it is still running which + // will then close the done channel. + return + } + + close(s.ready) + + select { + // When the session is closed try gracefully killing the process. + case <-s.close: + // Mark the session as closing so you can use Wait() to stop attaching new + // connections to sessions that are closing down. + close(s.closing) + // process.Close() will send a SIGTERM allowing screen to clean up. If we + // kill it abruptly the socket will be left behind which is probably + // harmless but it causes screen -list to show a bunch of dead sessions. + // The user would likely not see this since we have a custom socket + // directory but it seems ideal to let screen clean up anyway. + process.Close() + select { + case <-s.done: + return // Process exited on its own. + case <-time.After(5 * time.Second): + // Still running; yield so we can run the deferred cancel which will + // forcefully terminate the session. + return + } + // If the process exits on its own the session is also considered closed. + case <-s.done: + close(s.closing) + } +} + +// start starts the session. +func (s *Session) start(ctx context.Context, id string) (Process, error) { + err := s.ensureSettings() + if err != nil { + return nil, err + } + + // -S is for setting the session's name. + // -Dm causes screen to launch a server tied to this process, letting us + // attach to and kill it with the PID of this process (rather than having + // to do something flaky like run `screen -S id -quit`). + // -c is the flag for the config file. + process, err := s.execer.Start(ctx, Command{ + Command: "screen", + Args: append([]string{"-S", id, "-Dmc", s.configFile, s.command.Command}, s.command.Args...), + UID: s.command.UID, + GID: s.command.GID, + Env: append(s.command.Env, "SCREENDIR="+s.socketsDir), + WorkingDir: s.command.WorkingDir, + }) + if err != nil { + return nil, err + } + + // Screen allows targeting sessions via either the session name or + // .. Using the latter form allows us to differentiate + // between sessions with the same name. For example if a session closes due + // to a timeout and a client reconnects around the same time there can be two + // sessions with the same name while the old session is cleaning up. This is + // only a problem while attaching and not creating since the creation command + // used always creates a new session. + s.id = fmt.Sprintf("%d.%s", process.Pid(), id) + return process, nil +} + +// waitReady waits for the session to be ready. Sometimes if you attach too +// quickly after spawning screen it will say the session does not exist so this +// will run a command against the session until it works or times out. 
+func (s *Session) waitReady(ctx context.Context) error { + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + for { + select { + case <-s.close: + return xerrors.Errorf("session has been closed") + case <-s.done: + return xerrors.Errorf("session has exited") + case <-ctx.Done(): + return xerrors.Errorf("timed out waiting for session") + default: + // The `version` command seems to be the only command without a side + // effect so use it to check whether the session is up. + process, err := s.execer.Start(ctx, Command{ + Command: "screen", + Args: []string{"-S", s.id, "-X", "version"}, + UID: s.command.UID, + GID: s.command.GID, + Env: append(s.command.Env, "SCREENDIR="+s.socketsDir), + }) + if err != nil { + return err + } + err = process.Wait() + // TODO: Return error if it is anything but "no screen session found". + if err == nil { + return nil + } + time.Sleep(250 * time.Millisecond) + } + } +} + +// Attach waits for the session to become attachable and returns a command that +// can be used to attach to the session. +func (s *Session) Attach(ctx context.Context) (*Command, error) { + <-s.ready + if s.error != nil { + return nil, s.error + } + + go s.heartbeat(ctx) + + return &Command{ + Command: "screen", + Args: []string{"-S", s.id, "-xc", s.configFile}, + TTY: s.command.TTY, + Rows: s.command.Rows, + Cols: s.command.Cols, + Stdin: s.command.Stdin, + UID: s.command.UID, + GID: s.command.GID, + Env: append(s.command.Env, "SCREENDIR="+s.socketsDir), + WorkingDir: s.command.WorkingDir, + }, nil +} + +// heartbeat keeps the session alive while the provided context is not done. +func (s *Session) heartbeat(ctx context.Context) { + // We just connected so reset the timer now in case it is near the end or this + // is the first connection and the timer has not been set yet. + s.timer.Reset(s.options.SessionTimeout) + + // Reset when the connection closes to ensure the session stays up for the + // full timeout. + defer s.timer.Reset(s.options.SessionTimeout) + + heartbeat := time.NewTicker(s.options.SessionTimeout / 2) + defer heartbeat.Stop() + + for { + select { + case <-s.closing: + return + case <-ctx.Done(): + return + case <-heartbeat.C: + } + s.timer.Reset(s.options.SessionTimeout) + } +} + +// Wait waits for the session to close. The underlying process might still be +// exiting. +func (s *Session) Wait() { + <-s.closing +} + +// Close attempts to gracefully kill the session's underlying process then waits +// for the process to exit. If the session does not exit in a timely manner it +// forcefully kills the process. +func (s *Session) Close() { + select { + case s.close <- struct{}{}: + default: // Do not block; the lifecycle has already completed. + } + <-s.done +} + +// ensureSettings writes config settings and creates the socket directory. +func (s *Session) ensureSettings() error { + settings := []string{ + // Tell screen not to handle motion for xterm* terminals which allows + // scrolling the terminal via the mouse wheel or scroll bar (by default + // screen uses it to cycle through the command history). There does not + // seem to be a way to make screen itself scroll on mouse wheel. tmux can + // do it but then there is no scroll bar and it kicks you into copy mode + // where keys stop working until you exit copy mode which seems like it + // could be confusing. 
+ "termcapinfo xterm* ti@:te@", + // Enable alternate screen emulation otherwise applications get rendered in + // the current window which wipes out visible output resulting in missing + // output when scrolling back with the mouse wheel (copy mode still works + // since that is screen itself scrolling). + "altscreen on", + // Remap the control key to C-s since C-a may be used in applications. C-s + // cannot actually be used anyway since by default it will pause and C-q to + // resume will just kill the browser window. We may not want people using + // the control key anyway since it will not be obvious they are in screen + // and doing things like switching windows makes mouse wheel scroll wonky + // due to the terminal doing the scrolling rather than screen itself (but + // again copy mode will work just fine). + "escape ^Ss", + } + + dir := filepath.Join(os.TempDir(), "coder-screen") + config := filepath.Join(dir, "config") + socketdir := filepath.Join(dir, "sockets") + + err := os.MkdirAll(socketdir, 0o700) + if err != nil { + return err + } + + return os.WriteFile(config, []byte(strings.Join(settings, "\n")), 0o644) +} From cfad1966271f659e21c5bfb43beb4a835a40410b Mon Sep 17 00:00:00 2001 From: Asher Date: Tue, 27 Sep 2022 15:27:44 -0500 Subject: [PATCH 12/18] Localize session map --- client_test.go | 26 ++++++++--- server.go | 118 ++++++++++++++++++++++++++++++++----------------- tty_test.go | 79 ++++++++++++++++++++++++++------- 3 files changed, 162 insertions(+), 61 deletions(-) diff --git a/client_test.go b/client_test.go index 3ec1a77..47c0aaa 100644 --- a/client_test.go +++ b/client_test.go @@ -49,14 +49,18 @@ func TestRemoteStdin(t *testing.T) { } } -func mockConn(ctx context.Context, t *testing.T, options *Options) (*websocket.Conn, *httptest.Server) { +func mockConn(ctx context.Context, t *testing.T, wsepServer *Server, options *Options) (*websocket.Conn, *httptest.Server) { mockServerHandler := func(w http.ResponseWriter, r *http.Request) { ws, err := websocket.Accept(w, r, nil) if err != nil { w.WriteHeader(http.StatusInternalServerError) return } - err = Serve(r.Context(), ws, LocalExecer{}, options) + if wsepServer != nil { + err = wsepServer.Serve(r.Context(), ws, LocalExecer{}, options) + } else { + err = Serve(r.Context(), ws, LocalExecer{}, options) + } if err != nil { t.Errorf("failed to serve execer: %v", err) ws.Close(websocket.StatusAbnormalClosure, "failed to serve execer") @@ -77,7 +81,11 @@ func TestRemoteExec(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() - ws, server := mockConn(ctx, t, nil) + wsepServer := NewServer() + defer wsepServer.Close() + defer assert.Equal(t, "no leaked sessions", 0, wsepServer.SessionCount()) + + ws, server := mockConn(ctx, t, wsepServer, nil) defer server.Close() execer := RemoteExecer(ws) @@ -89,7 +97,11 @@ func TestRemoteExecFail(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() - ws, server := mockConn(ctx, t, nil) + wsepServer := NewServer() + defer wsepServer.Close() + defer assert.Equal(t, "no leaked sessions", 0, wsepServer.SessionCount()) + + ws, server := mockConn(ctx, t, wsepServer, nil) defer server.Close() execer := RemoteExecer(ws) @@ -124,7 +136,11 @@ func TestStderrVsStdout(t *testing.T) { stderr bytes.Buffer ) - ws, server := mockConn(ctx, t, nil) + wsepServer := NewServer() + defer wsepServer.Close() + defer assert.Equal(t, "no leaked sessions", 0, wsepServer.SessionCount()) + + ws, server := 
mockConn(ctx, t, wsepServer, nil) defer server.Close() execer := RemoteExecer(ws) diff --git a/server.go b/server.go index 5f1deb9..fdfa02d 100644 --- a/server.go +++ b/server.go @@ -19,27 +19,65 @@ import ( "cdr.dev/wsep/internal/proto" ) -var sessions sync.Map -var sessionsMutex sync.Mutex - // Options allows configuring the server. type Options struct { SessionTimeout time.Duration } -// Dispose closes the specified session. -func Dispose(id string) { - if rawSession, ok := sessions.Load(id); ok { +// _sessions is a global map of sessions that exists for backwards +// compatibility. Server should be used instead which locally maintains the +// map. +var _sessions sync.Map + +// _sessionsMutex is a global mutex that exists for backwards compatibility. +// Server should be used instead which locally maintains the mutex. +var _sessionsMutex sync.Mutex + +// Serve runs the server-side of wsep. +// Deprecated: Use Server.Serve() instead. +func Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Options) error { + srv := Server{sessions: &_sessions, sessionsMutex: &_sessionsMutex} + return srv.Serve(ctx, c, execer, options) +} + +// Server runs the server-side of wsep. The execer may be another wsep +// connection for chaining. Use LocalExecer for local command execution. +type Server struct { + sessions *sync.Map + sessionsMutex *sync.Mutex +} + +// NewServer returns as new wsep server. +func NewServer() *Server { + return &Server{ + sessions: &sync.Map{}, + sessionsMutex: &sync.Mutex{}, + } +} + +// SessionCount returns the number of sessions. +func (srv *Server) SessionCount() int { + var i int + srv.sessions.Range(func(k, rawSession interface{}) bool { + i++ + return true + }) + return i +} + +// Close closes all sessions. +func (srv *Server) Close() { + srv.sessions.Range(func(k, rawSession interface{}) bool { if s, ok := rawSession.(*Session); ok { s.Close() } - } + return true + }) } -// Serve runs the server-side of wsep. -// The execer may be another wsep connection for chaining. -// Use LocalExecer for local command execution. -func Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Options) error { +// Serve runs the server-side of wsep. The execer may be another wsep +// connection for chaining. Use LocalExecer for local command execution. +func (srv *Server) Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Options) error { ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -105,7 +143,7 @@ func Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Optio // Only TTYs with IDs can be reconnected. if command.TTY && header.ID != "" { - command, err = withSession(ctx, header.ID, command, execer, options) + command, err = srv.withSession(ctx, header.ID, command, execer, options) if err != nil { return err } @@ -169,6 +207,34 @@ func Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Optio } } +// withSession wraps the command in a session if screen is available. +func (srv *Server) withSession(ctx context.Context, id string, command *Command, execer Execer, options *Options) (*Command, error) { + // If screen is not installed spawn the command normally. 
+ _, err := exec.LookPath("screen") + if err != nil { + flog.Info("`screen` could not be found; session %s will not persist", id) + return command, nil + } + + var s *Session + srv.sessionsMutex.Lock() + if rawSession, ok := srv.sessions.Load(id); ok { + if s, ok = rawSession.(*Session); !ok { + return nil, xerrors.Errorf("found invalid type in session map for ID %s", id) + } + } else { + s = NewSession(id, command, execer, options) + srv.sessions.Store(id, s) + go func() { // Remove the session from the map once it closes. + defer srv.sessions.Delete(id) + s.Wait() + }() + } + srv.sessionsMutex.Unlock() + + return s.Attach(ctx) +} + func sendExitCode(_ context.Context, err error, conn net.Conn) error { exitCode := 0 errorStr := "" @@ -211,31 +277,3 @@ func copyWithHeader(r io.Reader, w io.Writer, header proto.Header) error { } return nil } - -// withSession wraps the command in a session if screen is available. -func withSession(ctx context.Context, id string, command *Command, execer Execer, options *Options) (*Command, error) { - // If screen is not installed spawn the command normally. - _, err := exec.LookPath("screen") - if err != nil { - flog.Info("`screen` could not be found; session %s will not persist", id) - return command, nil - } - - var s *Session - sessionsMutex.Lock() - if rawSession, ok := sessions.Load(id); ok { - if s, ok = rawSession.(*Session); !ok { - return nil, xerrors.Errorf("found invalid type in session map for ID %s", id) - } - } else { - s = NewSession(id, command, execer, options) - sessions.Store(id, s) - go func() { // Remove the session from the map once it closes. - defer sessions.Delete(id) - s.Wait() - }() - } - sessionsMutex.Unlock() - - return s.Attach(ctx) -} diff --git a/tty_test.go b/tty_test.go index 070a2e6..3930f06 100644 --- a/tty_test.go +++ b/tty_test.go @@ -20,7 +20,11 @@ func TestTTY(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - ws, server := mockConn(ctx, t, nil) + wsepServer := NewServer() + defer wsepServer.Close() + defer assert.Equal(t, "no leaked sessions", 0, wsepServer.SessionCount()) + + ws, server := mockConn(ctx, t, wsepServer, nil) defer ws.Close(websocket.StatusInternalError, "") defer server.Close() @@ -67,6 +71,45 @@ func testTTY(ctx context.Context, t *testing.T, e Execer) { } func TestReconnectTTY(t *testing.T) { + t.Run("DeprecatedServe", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + command := Command{ + ID: uuid.NewString(), + Command: "sh", + TTY: true, + Stdin: true, + Cols: 100, + Rows: 100, + Env: []string{"TERM=xterm"}, + } + + ws, server := mockConn(ctx, t, nil, nil) + defer server.Close() + + process, err := RemoteExecer(ws).Start(ctx, command) + assert.Success(t, "start sh", err) + + // Write some unique output. + echoCmd := "echo test:$((12+12))" + _, err = process.Stdin().Write([]byte(echoCmd + "\n")) + assert.Success(t, "write to stdin", err) + expected := []string{echoCmd, "test:24"} + + assert.True(t, "find echo", checkStdout(t, process, expected, []string{})) + + // Connect to the same session. + ws, server = mockConn(ctx, t, nil, nil) + defer server.Close() + + process, err = RemoteExecer(ws).Start(ctx, command) + assert.Success(t, "start sh", err) + + // Find the same output. 
+ assert.True(t, "find echo", checkStdout(t, process, expected, []string{})) + }) + t.Run("NoScreen", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -83,7 +126,11 @@ func TestReconnectTTY(t *testing.T) { Env: []string{"TERM=xterm"}, } - ws, server := mockConn(ctx, t, nil) + wsepServer := NewServer() + defer wsepServer.Close() + defer assert.Equal(t, "no leaked sessions", 0, wsepServer.SessionCount()) + + ws, server := mockConn(ctx, t, wsepServer, nil) defer server.Close() process, err := RemoteExecer(ws).Start(ctx, command) @@ -98,7 +145,7 @@ func TestReconnectTTY(t *testing.T) { assert.True(t, "find echo", checkStdout(t, process, expected, []string{})) // Connect to the same session. - ws, server = mockConn(ctx, t, nil) + ws, server = mockConn(ctx, t, wsepServer, nil) defer server.Close() process, err = RemoteExecer(ws).Start(ctx, command) @@ -130,11 +177,11 @@ func TestReconnectTTY(t *testing.T) { Env: []string{"TERM=xterm"}, } - t.Cleanup(func() { - Dispose(command.ID) - }) + wsepServer := NewServer() + defer wsepServer.Close() + defer assert.Equal(t, "no leaked sessions", 0, wsepServer.SessionCount()) - ws, server := mockConn(ctx, t, &Options{ + ws, server := mockConn(ctx, t, wsepServer, &Options{ SessionTimeout: time.Second, }) defer server.Close() @@ -155,7 +202,7 @@ func TestReconnectTTY(t *testing.T) { server.Close() // Reconnect. - ws, server = mockConn(ctx, t, &Options{ + ws, server = mockConn(ctx, t, wsepServer, &Options{ SessionTimeout: time.Second, }) defer server.Close() @@ -174,7 +221,7 @@ func TestReconnectTTY(t *testing.T) { assert.True(t, "find echo", checkStdout(t, process, expected, []string{})) // Make a simultaneously active connection. - ws2, server2 := mockConn(ctx, t, &Options{ + ws2, server2 := mockConn(ctx, t, wsepServer, &Options{ // Divide the time to test that the heartbeat keeps it open through multiple // intervals. SessionTimeout: time.Second / 4, @@ -208,7 +255,7 @@ func TestReconnectTTY(t *testing.T) { time.Sleep(time.Second) // The next connection should start a new process. - ws, server = mockConn(ctx, t, &Options{ + ws, server = mockConn(ctx, t, wsepServer, &Options{ SessionTimeout: time.Second, }) defer server.Close() @@ -242,11 +289,11 @@ func TestReconnectTTY(t *testing.T) { Env: []string{"TERM=xterm"}, } - t.Cleanup(func() { - Dispose(command.ID) - }) + wsepServer := NewServer() + defer wsepServer.Close() + defer assert.Equal(t, "no leaked sessions", 0, wsepServer.SessionCount()) - ws, server := mockConn(ctx, t, &Options{ + ws, server := mockConn(ctx, t, wsepServer, &Options{ SessionTimeout: time.Second, }) defer server.Close() @@ -265,7 +312,7 @@ func TestReconnectTTY(t *testing.T) { server.Close() // Reconnect; the application should redraw. - ws, server = mockConn(ctx, t, &Options{ + ws, server = mockConn(ctx, t, wsepServer, &Options{ SessionTimeout: time.Second, }) defer server.Close() @@ -285,7 +332,7 @@ func TestReconnectTTY(t *testing.T) { server.Close() // Reconnect. - ws, server = mockConn(ctx, t, &Options{ + ws, server = mockConn(ctx, t, wsepServer, &Options{ SessionTimeout: time.Second, }) defer server.Close() From f917a76baec38f438264bacfe8b29baa3ce158e2 Mon Sep 17 00:00:00 2001 From: Asher Date: Mon, 26 Sep 2022 12:55:55 -0500 Subject: [PATCH 13/18] Consolidate test scaffolding into helpers I think this helps make the tests a bit more concise. 
--- tty_test.go | 420 +++++++++++++++++----------------------------------- 1 file changed, 133 insertions(+), 287 deletions(-) diff --git a/tty_test.go b/tty_test.go index 3930f06..b5ab405 100644 --- a/tty_test.go +++ b/tty_test.go @@ -3,7 +3,8 @@ package wsep import ( "bufio" "context" - "io/ioutil" + "fmt" + "math/rand" "strings" "sync" "testing" @@ -17,335 +18,180 @@ import ( func TestTTY(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - wsepServer := NewServer() - defer wsepServer.Close() - defer assert.Equal(t, "no leaked sessions", 0, wsepServer.SessionCount()) - - ws, server := mockConn(ctx, t, wsepServer, nil) - defer ws.Close(websocket.StatusInternalError, "") - defer server.Close() - - execer := RemoteExecer(ws) - testTTY(ctx, t, execer) -} - -func testTTY(ctx context.Context, t *testing.T, e Execer) { - process, err := e.Start(ctx, Command{ - Command: "sh", - TTY: true, - Stdin: true, - Cols: 100, - Rows: 100, - Env: []string{"TERM=xterm"}, - }) - assert.Success(t, "start sh", err) - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - - stdout, err := ioutil.ReadAll(process.Stdout()) - assert.Success(t, "read stdout", err) - - t.Logf("bash tty stdout = %s", stdout) - prompt := string(stdout) - assert.True(t, `bash "$" or "#" prompt found`, - strings.HasSuffix(prompt, "$ ") || strings.HasSuffix(prompt, "# ")) - }() - wg.Add(1) - go func() { - defer wg.Done() - - stderr, err := ioutil.ReadAll(process.Stderr()) - assert.Success(t, "read stderr", err) - t.Logf("bash tty stderr = %s", stderr) - assert.True(t, "stderr is empty", len(stderr) == 0) - }() - time.Sleep(3 * time.Second) - - process.Close() - wg.Wait() + // Run some output in a new session. + server := newServer(t) + ctx, command := newSession(t) + command.ID = "" // No ID so we do not start a reconnectable session. + process, _ := connect(ctx, t, command, server, nil) + expected := writeUnique(t, process) + assert.True(t, "find initial output", checkStdout(t, process, expected, []string{})) + + // Connect to the same session. There should not be shared output since + // these end up being separate sessions due to the lack of an ID. + process, _ = connect(ctx, t, command, server, nil) + unexpected := expected + expected = writeUnique(t, process) + assert.True(t, "find new session output", checkStdout(t, process, expected, unexpected)) } func TestReconnectTTY(t *testing.T) { t.Run("DeprecatedServe", func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - command := Command{ - ID: uuid.NewString(), - Command: "sh", - TTY: true, - Stdin: true, - Cols: 100, - Rows: 100, - Env: []string{"TERM=xterm"}, - } - - ws, server := mockConn(ctx, t, nil, nil) - defer server.Close() - - process, err := RemoteExecer(ws).Start(ctx, command) - assert.Success(t, "start sh", err) - - // Write some unique output. - echoCmd := "echo test:$((12+12))" - _, err = process.Stdin().Write([]byte(echoCmd + "\n")) - assert.Success(t, "write to stdin", err) - expected := []string{echoCmd, "test:24"} - - assert.True(t, "find echo", checkStdout(t, process, expected, []string{})) - - // Connect to the same session. - ws, server = mockConn(ctx, t, nil, nil) - defer server.Close() - - process, err = RemoteExecer(ws).Start(ctx, command) - assert.Success(t, "start sh", err) - - // Find the same output. 
- assert.True(t, "find echo", checkStdout(t, process, expected, []string{})) + // Do something in the first session. + ctx, command := newSession(t) + process, _ := connect(ctx, t, command, nil, nil) + expected := writeUnique(t, process) + assert.True(t, "find initial output", checkStdout(t, process, expected, []string{})) + + // Connect to the same session. Should see the same output. + process, _ = connect(ctx, t, command, nil, nil) + assert.True(t, "find reconnected output", checkStdout(t, process, expected, []string{})) }) t.Run("NoScreen", func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - t.Setenv("PATH", "/bin") - command := Command{ - ID: uuid.NewString(), - Command: "sh", - TTY: true, - Stdin: true, - Cols: 100, - Rows: 100, - Env: []string{"TERM=xterm"}, - } - - wsepServer := NewServer() - defer wsepServer.Close() - defer assert.Equal(t, "no leaked sessions", 0, wsepServer.SessionCount()) - - ws, server := mockConn(ctx, t, wsepServer, nil) - defer server.Close() - - process, err := RemoteExecer(ws).Start(ctx, command) - assert.Success(t, "start sh", err) + // Run some output in a new session. + server := newServer(t) + ctx, command := newSession(t) + process, _ := connect(ctx, t, command, server, nil) + expected := writeUnique(t, process) + assert.True(t, "find initial output", checkStdout(t, process, expected, []string{})) - // Write some unique output. - echoCmd := "echo test:$((5+5))" - _, err = process.Stdin().Write([]byte(echoCmd + "\n")) - assert.Success(t, "write to stdin", err) - expected := []string{echoCmd, "test:10"} - - assert.True(t, "find echo", checkStdout(t, process, expected, []string{})) - - // Connect to the same session. - ws, server = mockConn(ctx, t, wsepServer, nil) - defer server.Close() - - process, err = RemoteExecer(ws).Start(ctx, command) - assert.Success(t, "attach sh", err) - - echoCmd = "echo test:$((6+6))" - _, err = process.Stdin().Write([]byte(echoCmd + "\r\n")) - assert.Success(t, "write to stdin", err) + // Connect to the same session. There should not be shared output since + // these end up being separate sessions due to the lack of screen. + process, _ = connect(ctx, t, command, server, nil) unexpected := expected - expected = []string{"test:12"} - - // No echo since it is a new process. - assert.True(t, "find echo", checkStdout(t, process, expected, unexpected)) + expected = writeUnique(t, process) + assert.True(t, "find new session output", checkStdout(t, process, expected, unexpected)) }) t.Run("Regular", func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - command := Command{ - ID: uuid.NewString(), - Command: "sh", - TTY: true, - Stdin: true, - Cols: 100, - Rows: 100, - Env: []string{"TERM=xterm"}, - } - - wsepServer := NewServer() - defer wsepServer.Close() - defer assert.Equal(t, "no leaked sessions", 0, wsepServer.SessionCount()) - - ws, server := mockConn(ctx, t, wsepServer, &Options{ - SessionTimeout: time.Second, - }) - defer server.Close() - - process, err := RemoteExecer(ws).Start(ctx, command) - assert.Success(t, "start sh", err) - - // Write some unique output. - echoCmd := "echo test:$((1+1))" - _, err = process.Stdin().Write([]byte(echoCmd + "\n")) - assert.Success(t, "write to stdin", err) - expected := []string{echoCmd, "test:2"} - - assert.True(t, "find echo", checkStdout(t, process, expected, []string{})) - - // Disconnect. 
- ws.Close(websocket.StatusNormalClosure, "disconnected") - server.Close() - - // Reconnect. - ws, server = mockConn(ctx, t, wsepServer, &Options{ - SessionTimeout: time.Second, - }) - defer server.Close() - - process, err = RemoteExecer(ws).Start(ctx, command) - assert.Success(t, "attach sh", err) - - // The inactivity timeout should not trigger since we are connected. + // Run some output in a new session. + server := newServer(t) + ctx, command := newSession(t) + process, disconnect := connect(ctx, t, command, server, nil) + expected := writeUnique(t, process) + assert.True(t, "find initial output", checkStdout(t, process, expected, []string{})) + + // Reconnect and sleep; the inactivity timeout should not trigger since we + // were not disconnected during the timeout. + disconnect() + process, disconnect = connect(ctx, t, command, server, nil) time.Sleep(time.Second) - - echoCmd = "echo test:$((2+2))" - _, err = process.Stdin().Write([]byte(echoCmd + "\n")) - assert.Success(t, "write to stdin", err) - expected = append(expected, echoCmd, "test:4") - - assert.True(t, "find echo", checkStdout(t, process, expected, []string{})) + expected = append(expected, writeUnique(t, process)...) + assert.True(t, "find reconnected output", checkStdout(t, process, expected, []string{})) // Make a simultaneously active connection. - ws2, server2 := mockConn(ctx, t, wsepServer, &Options{ - // Divide the time to test that the heartbeat keeps it open through multiple - // intervals. + process2, disconnect2 := connect(ctx, t, command, server, &Options{ + // Divide the time to test that the heartbeat keeps it open through + // multiple intervals. SessionTimeout: time.Second / 4, }) - defer server2.Close() - process, err = RemoteExecer(ws2).Start(ctx, command) - assert.Success(t, "attach sh", err) - - // Disconnect the first connection. - ws.Close(websocket.StatusNormalClosure, "disconnected") - server.Close() - - // Wait for inactivity. It should still stay up because of the second - // connection. + // Disconnect the first connection and wait for inactivity. The session + // should stay up because of the second connection. + disconnect() time.Sleep(time.Second) + expected = append(expected, writeUnique(t, process2)...) + assert.True(t, "find second connection output", checkStdout(t, process2, expected, []string{})) - // This connection should still be up. - echoCmd = "echo test:$((3+3))" - _, err = process.Stdin().Write([]byte(echoCmd + "\n")) - assert.Success(t, "write to stdin", err) - expected = append(expected, echoCmd, "test:6") - - assert.True(t, "find echo", checkStdout(t, process, expected, []string{})) - - // Disconnect the second connection. - ws2.Close(websocket.StatusNormalClosure, "disconnected") - server2.Close() - - // Wait for inactivity. + // Disconnect the second connection and wait for inactivity. The next + // connection should start a new session so we should only see new output + // and not any output from the old session. + disconnect2() time.Sleep(time.Second) - - // The next connection should start a new process. 
- ws, server = mockConn(ctx, t, wsepServer, &Options{ - SessionTimeout: time.Second, - }) - defer server.Close() - - process, err = RemoteExecer(ws).Start(ctx, command) - assert.Success(t, "attach sh", err) - - echoCmd = "echo test:$((4+4))" - _, err = process.Stdin().Write([]byte(echoCmd + "\r\n")) - assert.Success(t, "write to stdin", err) + process, disconnect = connect(ctx, t, command, server, nil) unexpected := expected - expected = []string{"test:8"} - - // This time no echo since it is a new process. - assert.True(t, "find echo", checkStdout(t, process, expected, unexpected)) + expected = writeUnique(t, process) + assert.True(t, "find new session output", checkStdout(t, process, expected, unexpected)) }) t.Run("Alternate", func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - command := Command{ - ID: uuid.NewString(), - Command: "sh", - TTY: true, - Stdin: true, - Cols: 100, - Rows: 100, - Env: []string{"TERM=xterm"}, - } - - wsepServer := NewServer() - defer wsepServer.Close() - defer assert.Equal(t, "no leaked sessions", 0, wsepServer.SessionCount()) - - ws, server := mockConn(ctx, t, wsepServer, &Options{ - SessionTimeout: time.Second, - }) - defer server.Close() - - process, err := RemoteExecer(ws).Start(ctx, command) - assert.Success(t, "attach sh", err) - // Run an application that enters the alternate screen. - _, err = process.Stdin().Write([]byte("./ci/alt.sh\n")) - assert.Success(t, "write to stdin", err) - - assert.True(t, "find output", checkStdout(t, process, []string{"./ci/alt.sh", "ALT SCREEN"}, []string{})) + server := newServer(t) + ctx, command := newSession(t) + process, disconnect := connect(ctx, t, command, server, nil) + write(t, process, "./ci/alt.sh") + assert.True(t, "find alt screen", checkStdout(t, process, []string{"./ci/alt.sh", "ALT SCREEN"}, []string{})) + + // Reconnect; the application should redraw. We should have only the + // application output and not the command that spawned the application. + disconnect() + process, disconnect = connect(ctx, t, command, server, nil) + assert.True(t, "find reconnected alt screen", checkStdout(t, process, []string{"ALT SCREEN"}, []string{"./ci/alt.sh"})) + + // Exit the application and reconnect. Should now be in a regular shell. + write(t, process, "q") + disconnect() + process, _ = connect(ctx, t, command, server, nil) + expected := writeUnique(t, process) + assert.True(t, "find shell output", checkStdout(t, process, expected, []string{})) + }) +} - // Disconnect. - ws.Close(websocket.StatusNormalClosure, "disconnected") +// newServer returns a new wsep server. +func newServer(t *testing.T) *Server { + server := NewServer() + t.Cleanup(func() { server.Close() + assert.Equal(t, "no leaked sessions", 0, server.SessionCount()) + }) + return server +} - // Reconnect; the application should redraw. - ws, server = mockConn(ctx, t, wsepServer, &Options{ - SessionTimeout: time.Second, - }) - defer server.Close() +// newSession returns a command for starting/attaching to a session with a +// context for timing out. 
+func newSession(t *testing.T) (context.Context, Command) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + t.Cleanup(cancel) + + command := Command{ + ID: uuid.NewString(), + Command: "sh", + TTY: true, + Stdin: true, + Cols: 100, + Rows: 100, + Env: []string{"TERM=xterm"}, + } - process, err = RemoteExecer(ws).Start(ctx, command) - assert.Success(t, "attach sh", err) + return ctx, command +} - // Should have only the application output. - assert.True(t, "find output", checkStdout(t, process, []string{"ALT SCREEN"}, []string{"./ci/alt.sh"})) +// connect connects to a wsep server and runs the provided command. +func connect(ctx context.Context, t *testing.T, command Command, wsepServer *Server, options *Options) (Process, func()) { + if options == nil { + options = &Options{SessionTimeout: time.Second} + } + ws, server := mockConn(ctx, t, wsepServer, options) + t.Cleanup(server.Close) - // Exit the application. - _, err = process.Stdin().Write([]byte("q")) - assert.Success(t, "write to stdin", err) + process, err := RemoteExecer(ws).Start(ctx, command) + assert.Success(t, "start sh", err) - // Disconnect. + return process, func() { ws.Close(websocket.StatusNormalClosure, "disconnected") server.Close() + } +} - // Reconnect. - ws, server = mockConn(ctx, t, wsepServer, &Options{ - SessionTimeout: time.Second, - }) - defer server.Close() - - process, err = RemoteExecer(ws).Start(ctx, command) - assert.Success(t, "attach sh", err) - - echoCmd := "echo test:$((5+5))" - _, err = process.Stdin().Write([]byte(echoCmd + "\r\n")) - assert.Success(t, "write to stdin", err) +// writeUnique writes some unique output to the shell process and returns the +// expected output. +func writeUnique(t *testing.T, process Process) []string { + n := rand.Intn(100) + echoCmd := fmt.Sprintf("echo test:$((%d+%d))", n, n) + write(t, process, echoCmd) + return []string{echoCmd, fmt.Sprintf("test:%d", n+n)} +} - assert.True(t, "find output", checkStdout(t, process, []string{echoCmd, "test:10"}, []string{})) - }) +// write writes the provided input followed by a newline to the shell process. +func write(t *testing.T, process Process, input string) { + _, err := process.Stdin().Write([]byte(input + "\n")) + assert.Success(t, "write to stdin", err) } // checkStdout ensures that expected is in the stdout in the specified order. From 23a56e0e62bfc92f6907d73c8c62b7378e7a4b7b Mon Sep 17 00:00:00 2001 From: Asher Date: Mon, 26 Sep 2022 13:08:29 -0500 Subject: [PATCH 14/18] Test many connections at once --- tty_test.go | 36 ++++++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/tty_test.go b/tty_test.go index b5ab405..1cc37c2 100644 --- a/tty_test.go +++ b/tty_test.go @@ -131,6 +131,33 @@ func TestReconnectTTY(t *testing.T) { expected := writeUnique(t, process) assert.True(t, "find shell output", checkStdout(t, process, expected, []string{})) }) + + t.Run("Simultaneous", func(t *testing.T) { + t.Parallel() + + server := newServer(t) + + // Try connecting a bunch of sessions at once. 
+ var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + ctx, command := newSession(t) + process, disconnect := connect(ctx, t, command, server, nil) + expected := writeUnique(t, process) + assert.True(t, "find initial output", checkStdout(t, process, expected, []string{})) + + n := rand.Intn(1000) + time.Sleep(time.Duration(n) * time.Millisecond) + disconnect() + process, disconnect = connect(ctx, t, command, server, nil) + expected = append(expected, writeUnique(t, process)...) + assert.True(t, "find reconnected output", checkStdout(t, process, expected, []string{})) + }() + } + wg.Wait() + }) } // newServer returns a new wsep server. @@ -168,7 +195,10 @@ func connect(ctx context.Context, t *testing.T, command Command, wsepServer *Ser options = &Options{SessionTimeout: time.Second} } ws, server := mockConn(ctx, t, wsepServer, options) - t.Cleanup(server.Close) + t.Cleanup(func() { + ws.Close(websocket.StatusNormalClosure, "disconnected") + server.Close() + }) process, err := RemoteExecer(ws).Start(ctx, command) assert.Success(t, "start sh", err) @@ -182,7 +212,7 @@ func connect(ctx context.Context, t *testing.T, command Command, wsepServer *Ser // writeUnique writes some unique output to the shell process and returns the // expected output. func writeUnique(t *testing.T, process Process) []string { - n := rand.Intn(100) + n := rand.Intn(1000000) echoCmd := fmt.Sprintf("echo test:$((%d+%d))", n, n) write(t, process, echoCmd) return []string{echoCmd, fmt.Sprintf("test:%d", n+n)} @@ -206,10 +236,12 @@ func checkStdout(t *testing.T, process Process, expected, unexpected []string) b t.Logf("bash tty stdout = %s", strings.ReplaceAll(line, "\x1b", "ESC")) for _, str := range unexpected { if strings.Contains(line, str) { + t.Logf("contains unexpected line %s", line) return false } } if strings.Contains(line, expected[i]) { + t.Logf("contains expected line %s", line) i = i + 1 } if i == len(expected) { From 2d8ff00090e6cc3296dd3cf27800fbe12e891d5c Mon Sep 17 00:00:00 2001 From: Asher Date: Tue, 27 Sep 2022 17:39:06 -0500 Subject: [PATCH 15/18] Fix errors not propagating through web socket close Since the server closed the socket the caller has no chance to close with the right code and reason. Also abnormal closure is not a valid close code. --- README.md | 6 ++---- client_test.go | 3 +-- server.go | 8 +++++--- tty_test.go | 45 ++++++++++++++++++++++++++++----------------- 4 files changed, 36 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index 3718d23..8ce1144 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ Error handling is omitted for brevity. 
```golang conn, _, _ := websocket.Dial(ctx, "ws://remote.exec.addr", nil) -defer conn.Close(websocket.StatusAbnormalClosure, "terminate process") +defer conn.Close(websocket.StatusNormalClosure, "normal closure") execer := wsep.RemoteExecer(conn) process, _ := execer.Start(ctx, wsep.Command{ @@ -28,7 +28,6 @@ go io.Copy(os.Stderr, process.Stderr()) go io.Copy(os.Stdout, process.Stdout()) process.Wait() -conn.Close(websocket.StatusNormalClosure, "normal closure") ``` ### Server @@ -36,10 +35,9 @@ conn.Close(websocket.StatusNormalClosure, "normal closure") ```golang func (s server) ServeHTTP(w http.ResponseWriter, r *http.Request) { conn, _ := websocket.Accept(w, r, nil) + defer conn.Close(websocket.StatusNormalClosure, "normal closure") wsep.Serve(r.Context(), conn, wsep.LocalExecer{}) - - ws.Close(websocket.StatusNormalClosure, "normal closure") } ``` diff --git a/client_test.go b/client_test.go index 47c0aaa..2890dc0 100644 --- a/client_test.go +++ b/client_test.go @@ -62,8 +62,7 @@ func mockConn(ctx context.Context, t *testing.T, wsepServer *Server, options *Op err = Serve(r.Context(), ws, LocalExecer{}, options) } if err != nil { - t.Errorf("failed to serve execer: %v", err) - ws.Close(websocket.StatusAbnormalClosure, "failed to serve execer") + ws.Close(websocket.StatusInternalError, err.Error()) return } ws.Close(websocket.StatusNormalClosure, "normal closure") diff --git a/server.go b/server.go index fdfa02d..4575a62 100644 --- a/server.go +++ b/server.go @@ -76,7 +76,9 @@ func (srv *Server) Close() { } // Serve runs the server-side of wsep. The execer may be another wsep -// connection for chaining. Use LocalExecer for local command execution. +// connection for chaining. Use LocalExecer for local command execution. The +// web socket will not be closed automatically; the caller must call Close() on +// the web socket (ideally with a reason) once Serve yields. func (srv *Server) Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Options) error { ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -94,7 +96,7 @@ func (srv *Server) Serve(ctx context.Context, c *websocket.Conn, execer Execer, process Process wsNetConn = websocket.NetConn(ctx, c, websocket.MessageBinary) ) - defer wsNetConn.Close() + for { if err := ctx.Err(); err != nil { return err @@ -137,7 +139,7 @@ func (srv *Server) Serve(ctx context.Context, c *websocket.Conn, execer Execer, if command.TTY { // Enforce rows and columns so the TTY will be correctly sized. if command.Rows == 0 || command.Cols == 0 { - return xerrors.Errorf("rows and cols must be non-zero: %w", err) + return xerrors.Errorf("rows and cols must be non-zero") } } diff --git a/tty_test.go b/tty_test.go index 1cc37c2..e73c9c5 100644 --- a/tty_test.go +++ b/tty_test.go @@ -22,28 +22,35 @@ func TestTTY(t *testing.T) { server := newServer(t) ctx, command := newSession(t) command.ID = "" // No ID so we do not start a reconnectable session. - process, _ := connect(ctx, t, command, server, nil) + process, _ := connect(ctx, t, command, server, nil, "") expected := writeUnique(t, process) assert.True(t, "find initial output", checkStdout(t, process, expected, []string{})) // Connect to the same session. There should not be shared output since // these end up being separate sessions due to the lack of an ID. 
- process, _ = connect(ctx, t, command, server, nil) + process, _ = connect(ctx, t, command, server, nil, "") unexpected := expected expected = writeUnique(t, process) assert.True(t, "find new session output", checkStdout(t, process, expected, unexpected)) } func TestReconnectTTY(t *testing.T) { + t.Run("NoSize", func(t *testing.T) { + server := newServer(t) + ctx, command := newSession(t) + command.Rows = 0 + connect(ctx, t, command, server, nil, "rows and cols must be non-zero") + }) + t.Run("DeprecatedServe", func(t *testing.T) { // Do something in the first session. ctx, command := newSession(t) - process, _ := connect(ctx, t, command, nil, nil) + process, _ := connect(ctx, t, command, nil, nil, "") expected := writeUnique(t, process) assert.True(t, "find initial output", checkStdout(t, process, expected, []string{})) // Connect to the same session. Should see the same output. - process, _ = connect(ctx, t, command, nil, nil) + process, _ = connect(ctx, t, command, nil, nil, "") assert.True(t, "find reconnected output", checkStdout(t, process, expected, []string{})) }) @@ -53,13 +60,13 @@ func TestReconnectTTY(t *testing.T) { // Run some output in a new session. server := newServer(t) ctx, command := newSession(t) - process, _ := connect(ctx, t, command, server, nil) + process, _ := connect(ctx, t, command, server, nil, "") expected := writeUnique(t, process) assert.True(t, "find initial output", checkStdout(t, process, expected, []string{})) // Connect to the same session. There should not be shared output since // these end up being separate sessions due to the lack of screen. - process, _ = connect(ctx, t, command, server, nil) + process, _ = connect(ctx, t, command, server, nil, "") unexpected := expected expected = writeUnique(t, process) assert.True(t, "find new session output", checkStdout(t, process, expected, unexpected)) @@ -71,14 +78,14 @@ func TestReconnectTTY(t *testing.T) { // Run some output in a new session. server := newServer(t) ctx, command := newSession(t) - process, disconnect := connect(ctx, t, command, server, nil) + process, disconnect := connect(ctx, t, command, server, nil, "") expected := writeUnique(t, process) assert.True(t, "find initial output", checkStdout(t, process, expected, []string{})) // Reconnect and sleep; the inactivity timeout should not trigger since we // were not disconnected during the timeout. disconnect() - process, disconnect = connect(ctx, t, command, server, nil) + process, disconnect = connect(ctx, t, command, server, nil, "") time.Sleep(time.Second) expected = append(expected, writeUnique(t, process)...) assert.True(t, "find reconnected output", checkStdout(t, process, expected, []string{})) @@ -88,7 +95,7 @@ func TestReconnectTTY(t *testing.T) { // Divide the time to test that the heartbeat keeps it open through // multiple intervals. SessionTimeout: time.Second / 4, - }) + }, "") // Disconnect the first connection and wait for inactivity. The session // should stay up because of the second connection. @@ -102,7 +109,7 @@ func TestReconnectTTY(t *testing.T) { // and not any output from the old session. disconnect2() time.Sleep(time.Second) - process, disconnect = connect(ctx, t, command, server, nil) + process, disconnect = connect(ctx, t, command, server, nil, "") unexpected := expected expected = writeUnique(t, process) assert.True(t, "find new session output", checkStdout(t, process, expected, unexpected)) @@ -114,20 +121,20 @@ func TestReconnectTTY(t *testing.T) { // Run an application that enters the alternate screen. 
server := newServer(t) ctx, command := newSession(t) - process, disconnect := connect(ctx, t, command, server, nil) + process, disconnect := connect(ctx, t, command, server, nil, "") write(t, process, "./ci/alt.sh") assert.True(t, "find alt screen", checkStdout(t, process, []string{"./ci/alt.sh", "ALT SCREEN"}, []string{})) // Reconnect; the application should redraw. We should have only the // application output and not the command that spawned the application. disconnect() - process, disconnect = connect(ctx, t, command, server, nil) + process, disconnect = connect(ctx, t, command, server, nil, "") assert.True(t, "find reconnected alt screen", checkStdout(t, process, []string{"ALT SCREEN"}, []string{"./ci/alt.sh"})) // Exit the application and reconnect. Should now be in a regular shell. write(t, process, "q") disconnect() - process, _ = connect(ctx, t, command, server, nil) + process, _ = connect(ctx, t, command, server, nil, "") expected := writeUnique(t, process) assert.True(t, "find shell output", checkStdout(t, process, expected, []string{})) }) @@ -144,14 +151,14 @@ func TestReconnectTTY(t *testing.T) { go func() { defer wg.Done() ctx, command := newSession(t) - process, disconnect := connect(ctx, t, command, server, nil) + process, disconnect := connect(ctx, t, command, server, nil, "") expected := writeUnique(t, process) assert.True(t, "find initial output", checkStdout(t, process, expected, []string{})) n := rand.Intn(1000) time.Sleep(time.Duration(n) * time.Millisecond) disconnect() - process, disconnect = connect(ctx, t, command, server, nil) + process, disconnect = connect(ctx, t, command, server, nil, "") expected = append(expected, writeUnique(t, process)...) assert.True(t, "find reconnected output", checkStdout(t, process, expected, []string{})) }() @@ -190,7 +197,7 @@ func newSession(t *testing.T) (context.Context, Command) { } // connect connects to a wsep server and runs the provided command. -func connect(ctx context.Context, t *testing.T, command Command, wsepServer *Server, options *Options) (Process, func()) { +func connect(ctx context.Context, t *testing.T, command Command, wsepServer *Server, options *Options, error string) (Process, func()) { if options == nil { options = &Options{SessionTimeout: time.Second} } @@ -201,7 +208,11 @@ func connect(ctx context.Context, t *testing.T, command Command, wsepServer *Ser }) process, err := RemoteExecer(ws).Start(ctx, command) - assert.Success(t, "start sh", err) + if error != "" { + assert.True(t, fmt.Sprintf("%s contains %s", err.Error(), error), strings.Contains(err.Error(), error)) + } else { + assert.Success(t, "start sh", err) + } return process, func() { ws.Close(websocket.StatusNormalClosure, "disconnected") From 4f2d5d1f29ef3d27a31a2ea1273fd0ee4376f8f0 Mon Sep 17 00:00:00 2001 From: Asher Date: Wed, 28 Sep 2022 15:26:51 -0500 Subject: [PATCH 16/18] Fix test flake in reading output Without waiting for the copy you can sometimes get "file already closed". I guess process.Wait must have some side effect. 
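
Roughly, the flaky pattern was:

    go io.Copy(&stdout, process.Stdout())
    go io.Copy(&stderr, process.Stderr())

    err = process.Wait() // presumably closes the pipes before the copies finish

and the change below drains both pipes with an errgroup before calling
process.Wait:

    var outputgroup errgroup.Group
    outputgroup.Go(func() error {
        _, err := io.Copy(&stdout, process.Stdout())
        return err
    })
    outputgroup.Go(func() error {
        _, err := io.Copy(&stderr, process.Stderr())
        return err
    })

    // Only wait on the process once both copies have finished.
    err = outputgroup.Wait()
    assert.Success(t, "wait for output to drain", err)

    err = process.Wait()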
--- localexec_test.go | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/localexec_test.go b/localexec_test.go index 4d2caac..af60586 100644 --- a/localexec_test.go +++ b/localexec_test.go @@ -12,6 +12,7 @@ import ( "time" "cdr.dev/slog/sloggers/slogtest/assert" + "golang.org/x/sync/errgroup" ) func TestLocalExec(t *testing.T) { @@ -140,8 +141,18 @@ func TestStdoutVsStderr(t *testing.T) { }) assert.Success(t, "start command", err) - go io.Copy(&stdout, process.Stdout()) - go io.Copy(&stderr, process.Stderr()) + var outputgroup errgroup.Group + outputgroup.Go(func() error { + _, err := io.Copy(&stdout, process.Stdout()) + return err + }) + outputgroup.Go(func() error { + _, err := io.Copy(&stderr, process.Stderr()) + return err + }) + + err = outputgroup.Wait() + assert.Success(t, "wait for output to drain", err) err = process.Wait() assert.Success(t, "wait for process to complete", err) From 82d7e68f00af64295732851dddc1359a3853d312 Mon Sep 17 00:00:00 2001 From: Asher Date: Wed, 16 Nov 2022 12:57:29 -0600 Subject: [PATCH 17/18] Convert channels to state machine --- session.go | 272 ++++++++++++++++++++++++++++++++++------------------- 1 file changed, 174 insertions(+), 98 deletions(-) diff --git a/session.go b/session.go index 4cbec38..b77cf8b 100644 --- a/session.go +++ b/session.go @@ -6,26 +6,38 @@ import ( "os" "path/filepath" "strings" + "sync" "time" "golang.org/x/xerrors" ) +// State represents the current state of the session. States are sequential and +// will only move forward. +type State int + +const ( + // StateStarting is the default/start state. + StateStarting = iota + // StateReady means the session is ready to be attached. + StateReady + // StateClosing means the session has begun closing. The underlying process + // may still be exiting. + StateClosing + // StateDone means the session has completely shut down and the process has + // exited. + StateDone +) + // Session represents a `screen` session. type Session struct { - // close receives nil when the session should be closed. - close chan struct{} - // closing receives nil when the session has begun closing. The underlying - // process may still be exiting. closing implies ready. - closing chan struct{} // command is the original command used to spawn the session. command *Command + // cond broadcasts session changes. + cond *sync.Cond // configFile is the location of the screen configuration file. configFile string - // done receives nil when the session has completely shut down and the process - // has exited. done implies closing and ready. - done chan struct{} - // error hold any error that occurred while starting the session. + // error hold any error that occurred during a state change. error error // execer is used to spawn the session and ready commands. execer Execer @@ -33,12 +45,11 @@ type Session struct { id string // options holds options for configuring the session. options *Options - // ready receives nil when the session is ready to be attached. error must be - // checked after receiving on this channel. - ready chan struct{} // socketsDir is the location of the directory where screen should put its // sockets. socketsDir string + // state holds the current session state. + state State // timer will close the session when it expires. 
timer *time.Timer } @@ -49,14 +60,12 @@ type Session struct { func NewSession(id string, command *Command, execer Execer, options *Options) *Session { tempdir := filepath.Join(os.TempDir(), "coder-screen") s := &Session{ - close: make(chan struct{}), - closing: make(chan struct{}), command: command, + cond: sync.NewCond(&sync.Mutex{}), configFile: filepath.Join(tempdir, "config"), - done: make(chan struct{}), execer: execer, options: options, - ready: make(chan struct{}), + state: StateStarting, socketsDir: filepath.Join(tempdir, "sockets"), } go s.lifecycle(id) @@ -65,53 +74,38 @@ func NewSession(id string, command *Command, execer Execer, options *Options) *S // lifecycle manages the lifecycle of the session. func (s *Session) lifecycle(id string) { + // When this context is done the process is dead. ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // Close the session after a timeout. Timeouts created with AfterFunc can be - // reset; we will reset it as long as there are active connections. - s.timer = time.AfterFunc(s.options.SessionTimeout, s.Close) - defer s.timer.Stop() - // Close the session down immediately if there was an error and store that - // error to emit on Attach() calls. + // Close the session down immediately if there was an error. process, err := s.start(ctx, id) if err != nil { - s.error = err - close(s.ready) - close(s.closing) - close(s.done) + defer cancel() + s.setState(StateDone, xerrors.Errorf("process start: %w", err)) return } - // Mark the session as fully done when the process exits. + // Close the session after a timeout. Timeouts created with AfterFunc can be + // reset; we will reset it as long as there are active connections. The + // initial timeout for starting up is set here and will probably be less than + // the session timeout in most cases. It should be at least long enough for + // screen to be able to start up. + s.timer = time.AfterFunc(30*time.Second, s.Close) + + // Emit the done event when the process exits. go func() { - err = process.Wait() + defer cancel() + err := process.Wait() if err != nil { - s.error = err + err = xerrors.Errorf("process exit: %w", err) } - close(s.done) + s.setState(StateDone, err) }() - // Wait until the session is ready to receive attaches. - err = s.waitReady(ctx) - if err != nil { - s.error = err - close(s.ready) - close(s.closing) - // The deferred cancel will kill the process if it is still running which - // will then close the done channel. - return - } - - close(s.ready) - - select { - // When the session is closed try gracefully killing the process. - case <-s.close: - // Mark the session as closing so you can use Wait() to stop attaching new - // connections to sessions that are closing down. - close(s.closing) + // Handle the close event by killing the process. + go func() { + s.waitForState(StateClosing) + s.timer.Stop() // process.Close() will send a SIGTERM allowing screen to clean up. If we // kill it abruptly the socket will be left behind which is probably // harmless but it causes screen -list to show a bunch of dead sessions. @@ -119,17 +113,26 @@ func (s *Session) lifecycle(id string) { // directory but it seems ideal to let screen clean up anyway. process.Close() select { - case <-s.done: + case <-ctx.Done(): return // Process exited on its own. case <-time.After(5 * time.Second): - // Still running; yield so we can run the deferred cancel which will - // forcefully terminate the session. 
- return + // Still running; cancel the context to forcefully terminate the process. + cancel() } - // If the process exits on its own the session is also considered closed. - case <-s.done: - close(s.closing) + }() + + // Wait until the session is ready to receive attaches. + err = s.waitReady(ctx) + if err != nil { + defer cancel() + s.setState(StateClosing, xerrors.Errorf("session wait: %w", err)) + return } + + // Once the session is ready external callers have until their provided + // timeout to attach something before the session will close itself. + s.timer.Reset(s.options.SessionTimeout) + s.setState(StateReady, nil) } // start starts the session. @@ -167,39 +170,54 @@ func (s *Session) start(ctx context.Context, id string) (Process, error) { return process, nil } -// waitReady waits for the session to be ready. Sometimes if you attach too -// quickly after spawning screen it will say the session does not exist so this -// will run a command against the session until it works or times out. +// waitReady waits for the session to be ready by running a command against the +// session until it works since sometimes if you attach too quickly after +// spawning screen it will say the session does not exist. If the provided +// context finishes the wait is aborted and the context's error is returned. func (s *Session) waitReady(ctx context.Context) error { - ctx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() + check := func() (bool, error) { + // The `version` command seems to be the only command without a side effect + // so use it to check whether the session is up. + process, err := s.execer.Start(ctx, Command{ + Command: "screen", + Args: []string{"-S", s.id, "-X", "version"}, + UID: s.command.UID, + GID: s.command.GID, + Env: append(s.command.Env, "SCREENDIR="+s.socketsDir), + }) + if err != nil { + return true, err + } + err = process.Wait() + // Try the context error in case it canceled while we waited. + if ctx.Err() != nil { + return true, ctx.Err() + } + // Session is ready once we executed without an error. + // TODO: Should we specifically check for "no screen to be attached + // matching"? That might be the only error we actually want to retry and + // otherwise we return immediately. But this text may not be stable between + // screen versions or there could be additional errors that can be retried. + return err == nil, nil + } + + // Check immediately. + if done, err := check(); done { + return err + } + + // Then check on a timer. + ticker := time.NewTicker(250 * time.Millisecond) + defer ticker.Stop() + for { select { - case <-s.close: - return xerrors.Errorf("session has been closed") - case <-s.done: - return xerrors.Errorf("session has exited") case <-ctx.Done(): - return xerrors.Errorf("timed out waiting for session") - default: - // The `version` command seems to be the only command without a side - // effect so use it to check whether the session is up. - process, err := s.execer.Start(ctx, Command{ - Command: "screen", - Args: []string{"-S", s.id, "-X", "version"}, - UID: s.command.UID, - GID: s.command.GID, - Env: append(s.command.Env, "SCREENDIR="+s.socketsDir), - }) - if err != nil { + return ctx.Err() + case <-ticker.C: + if done, err := check(); done { return err } - err = process.Wait() - // TODO: Return error if it is anything but "no screen session found". 
- if err == nil { - return nil - } - time.Sleep(250 * time.Millisecond) } } } @@ -207,11 +225,32 @@ func (s *Session) waitReady(ctx context.Context) error { // Attach waits for the session to become attachable and returns a command that // can be used to attach to the session. func (s *Session) Attach(ctx context.Context) (*Command, error) { - <-s.ready - if s.error != nil { - return nil, s.error + state, err := s.waitForState(StateReady) + switch state { + case StateClosing: + if err == nil { + // No error means s.Close() was called, either by external code or via the + // session timeout. + err = xerrors.Errorf("session is closing") + } + return nil, err + case StateDone: + if err == nil { + // No error means the process exited with zero, probably after being + // killed due to a call to s.Close() either externally or via the session + // timeout. + err = xerrors.Errorf("session is done") + } + return nil, err } + // Abort the heartbeat when the session closes. + ctx, cancel := context.WithCancel(ctx) + go func() { + defer cancel() + s.waitForStateOrContext(ctx, StateClosing) + }() + go s.heartbeat(ctx) return &Command{ @@ -230,8 +269,7 @@ func (s *Session) Attach(ctx context.Context) (*Command, error) { // heartbeat keeps the session alive while the provided context is not done. func (s *Session) heartbeat(ctx context.Context) { - // We just connected so reset the timer now in case it is near the end or this - // is the first connection and the timer has not been set yet. + // We just connected so reset the timer now in case it is near the end. s.timer.Reset(s.options.SessionTimeout) // Reset when the connection closes to ensure the session stays up for the @@ -243,8 +281,6 @@ func (s *Session) heartbeat(ctx context.Context) { for { select { - case <-s.closing: - return case <-ctx.Done(): return case <-heartbeat.C: @@ -256,18 +292,15 @@ func (s *Session) heartbeat(ctx context.Context) { // Wait waits for the session to close. The underlying process might still be // exiting. func (s *Session) Wait() { - <-s.closing + s.waitForState(StateClosing) } // Close attempts to gracefully kill the session's underlying process then waits // for the process to exit. If the session does not exit in a timely manner it // forcefully kills the process. func (s *Session) Close() { - select { - case s.close <- struct{}{}: - default: // Do not block; the lifecycle has already completed. - } - <-s.done + s.setState(StateClosing, nil) + s.waitForState(StateDone) } // ensureSettings writes config settings and creates the socket directory. @@ -307,3 +340,46 @@ func (s *Session) ensureSettings() error { return os.WriteFile(config, []byte(strings.Join(settings, "\n")), 0o644) } + +// setState sets and broadcasts the provided state if it is greater than the +// current state and the error if one has not already been set. +func (s *Session) setState(state State, err error) { + s.cond.L.Lock() + defer s.cond.L.Unlock() + // Cannot regress states (for example trying to close after the process is + // done should leave us in the done state and not the closing state). + if state <= s.state { + return + } + // Keep the first error we get. + if s.error == nil { + s.error = err + } + s.state = state + s.cond.Broadcast() +} + +// waitForState blocks until the state or a greater one is reached. 
+func (s *Session) waitForState(state State) (State, error) { + s.cond.L.Lock() + defer s.cond.L.Unlock() + for state > s.state { + s.cond.Wait() + } + return s.state, s.error +} + +// waitForStateOrContext blocks until the state or a greater one is reached or +// the provided context ends. If the context ends all goroutines will be woken. +func (s *Session) waitForStateOrContext(ctx context.Context, state State) { + go func() { + // Wake up when the context ends. + defer s.cond.Broadcast() + <-ctx.Done() + }() + s.cond.L.Lock() + defer s.cond.L.Unlock() + for ctx.Err() == nil && state > s.state { + s.cond.Wait() + } +} From cefda59e107aa4ef4a7c78100d0ba4c8ebf92130 Mon Sep 17 00:00:00 2001 From: Asher Date: Wed, 16 Nov 2022 12:57:34 -0600 Subject: [PATCH 18/18] Prevent close frame error in tests with long errors --- client_test.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/client_test.go b/client_test.go index 8836411..96c10b0 100644 --- a/client_test.go +++ b/client_test.go @@ -64,7 +64,12 @@ func mockConn(ctx context.Context, t *testing.T, wsepServer *Server, options *Op err = Serve(r.Context(), ws, LocalExecer{}, options) } if err != nil { - ws.Close(websocket.StatusInternalError, err.Error()) + // Max reason string length is 123. + errStr := err.Error() + if len(errStr) > 123 { + errStr = errStr[:123] + } + ws.Close(websocket.StatusInternalError, errStr) return } ws.Close(websocket.StatusNormalClosure, "normal closure")
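
For context on the magic number above: RFC 6455 caps a control frame's payload at 125 bytes, and a close frame spends 2 of those on the status code, leaving 123 bytes for the reason text. A small sketch of the same guard as a standalone helper (the function name is illustrative, not part of the library):

    package example

    // truncateCloseReason trims a websocket close reason to the 123-byte limit
    // (125-byte control-frame payload minus the 2-byte status code). The slice
    // is by bytes, matching the inline check in the test above.
    func truncateCloseReason(reason string) string {
    	const maxCloseReason = 123
    	if len(reason) > maxCloseReason {
    		return reason[:maxCloseReason]
    	}
    	return reason
    }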