diff --git a/agent/agent.go b/agent/agent.go index 3d9b8947b18b1..ea2fae6d430f6 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -21,7 +21,6 @@ import ( "sync" "time" - "github.com/armon/circbuf" "github.com/go-chi/chi/v5" "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" @@ -36,12 +35,12 @@ import ( "cdr.dev/slog" "github.com/coder/coder/agent/agentssh" + "github.com/coder/coder/agent/reconnectingpty" "github.com/coder/coder/buildinfo" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/gitauth" "github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk/agentsdk" - "github.com/coder/coder/pty" "github.com/coder/coder/tailnet" "github.com/coder/retry" ) @@ -92,9 +91,6 @@ type Agent interface { } func New(options Options) Agent { - if options.ReconnectingPTYTimeout == 0 { - options.ReconnectingPTYTimeout = 5 * time.Minute - } if options.Filesystem == nil { options.Filesystem = afero.NewOsFs() } @@ -1075,8 +1071,8 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m defer a.connCountReconnectingPTY.Add(-1) connectionID := uuid.NewString() - logger = logger.With(slog.F("message_id", msg.ID), slog.F("connection_id", connectionID)) - logger.Debug(ctx, "starting handler") + connLogger := logger.With(slog.F("message_id", msg.ID), slog.F("connection_id", connectionID)) + connLogger.Debug(ctx, "starting handler") defer func() { if err := retErr; err != nil { @@ -1087,22 +1083,22 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m // If the agent is closed, we don't want to // log this as an error since it's expected. if closed { - logger.Debug(ctx, "reconnecting PTY failed with session error (agent closed)", slog.Error(err)) + connLogger.Debug(ctx, "reconnecting pty failed with attach error (agent closed)", slog.Error(err)) } else { - logger.Error(ctx, "reconnecting PTY failed with session error", slog.Error(err)) + connLogger.Error(ctx, "reconnecting pty failed with attach error", slog.Error(err)) } } - logger.Debug(ctx, "session closed") + connLogger.Debug(ctx, "reconnecting pty connection closed") }() - var rpty *reconnectingPTY - sendConnected := make(chan *reconnectingPTY, 1) + var rpty reconnectingpty.ReconnectingPTY + sendConnected := make(chan reconnectingpty.ReconnectingPTY, 1) // On store, reserve this ID to prevent multiple concurrent new connections. waitReady, ok := a.reconnectingPTYs.LoadOrStore(msg.ID, sendConnected) if ok { close(sendConnected) // Unused. - logger.Debug(ctx, "connecting to existing session") - c, ok := waitReady.(chan *reconnectingPTY) + connLogger.Debug(ctx, "connecting to existing reconnecting pty") + c, ok := waitReady.(chan reconnectingpty.ReconnectingPTY) if !ok { return xerrors.Errorf("found invalid type in reconnecting pty map: %T", waitReady) } @@ -1112,7 +1108,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m } c <- rpty // Put it back for the next reconnect. } else { - logger.Debug(ctx, "creating new session") + connLogger.Debug(ctx, "creating new reconnecting pty") connected := false defer func() { @@ -1128,169 +1124,24 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m a.metrics.reconnectingPTYErrors.WithLabelValues("create_command").Add(1) return xerrors.Errorf("create command: %w", err) } - cmd.Env = append(cmd.Env, "TERM=xterm-256color") - - // Default to buffer 64KiB. 
- circularBuffer, err := circbuf.NewBuffer(64 << 10) - if err != nil { - return xerrors.Errorf("create circular buffer: %w", err) - } - ptty, process, err := pty.Start(cmd) - if err != nil { - a.metrics.reconnectingPTYErrors.WithLabelValues("start_command").Add(1) - return xerrors.Errorf("start command: %w", err) - } + rpty = reconnectingpty.New(ctx, cmd, &reconnectingpty.Options{ + Timeout: a.reconnectingPTYTimeout, + Metrics: a.metrics.reconnectingPTYErrors, + }, logger.With(slog.F("message_id", msg.ID))) - ctx, cancel := context.WithCancel(ctx) - rpty = &reconnectingPTY{ - activeConns: map[string]net.Conn{ - // We have to put the connection in the map instantly otherwise - // the connection won't be closed if the process instantly dies. - connectionID: conn, - }, - ptty: ptty, - // Timeouts created with an after func can be reset! - timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancel), - circularBuffer: circularBuffer, - } - // We don't need to separately monitor for the process exiting. - // When it exits, our ptty.OutputReader() will return EOF after - // reading all process output. if err = a.trackConnGoroutine(func() { - buffer := make([]byte, 1024) - for { - read, err := rpty.ptty.OutputReader().Read(buffer) - if err != nil { - // When the PTY is closed, this is triggered. - // Error is typically a benign EOF, so only log for debugging. - if errors.Is(err, io.EOF) { - logger.Debug(ctx, "unable to read pty output, command might have exited", slog.Error(err)) - } else { - logger.Warn(ctx, "unable to read pty output, command might have exited", slog.Error(err)) - a.metrics.reconnectingPTYErrors.WithLabelValues("output_reader").Add(1) - } - break - } - part := buffer[:read] - rpty.circularBufferMutex.Lock() - _, err = rpty.circularBuffer.Write(part) - rpty.circularBufferMutex.Unlock() - if err != nil { - logger.Error(ctx, "write to circular buffer", slog.Error(err)) - break - } - rpty.activeConnsMutex.Lock() - for cid, conn := range rpty.activeConns { - _, err = conn.Write(part) - if err != nil { - logger.Warn(ctx, - "error writing to active conn", - slog.F("other_conn_id", cid), - slog.Error(err), - ) - a.metrics.reconnectingPTYErrors.WithLabelValues("write").Add(1) - } - } - rpty.activeConnsMutex.Unlock() - } - - // Cleanup the process, PTY, and delete it's - // ID from memory. - _ = process.Kill() - rpty.Close() + rpty.Wait() a.reconnectingPTYs.Delete(msg.ID) }); err != nil { - _ = process.Kill() - _ = ptty.Close() + rpty.Close(err.Error()) return xerrors.Errorf("start routine: %w", err) } + connected = true sendConnected <- rpty } - // Resize the PTY to initial height + width. - err := rpty.ptty.Resize(msg.Height, msg.Width) - if err != nil { - // We can continue after this, it's not fatal! - logger.Error(ctx, "reconnecting PTY initial resize failed, but will continue", slog.Error(err)) - a.metrics.reconnectingPTYErrors.WithLabelValues("resize").Add(1) - } - // Write any previously stored data for the TTY. - rpty.circularBufferMutex.RLock() - prevBuf := slices.Clone(rpty.circularBuffer.Bytes()) - rpty.circularBufferMutex.RUnlock() - // Note that there is a small race here between writing buffered - // data and storing conn in activeConns. This is likely a very minor - // edge case, but we should look into ways to avoid it. Holding - // activeConnsMutex would be one option, but holding this mutex - // while also holding circularBufferMutex seems dangerous. 
- _, err = conn.Write(prevBuf) - if err != nil { - a.metrics.reconnectingPTYErrors.WithLabelValues("write").Add(1) - return xerrors.Errorf("write buffer to conn: %w", err) - } - // Multiple connections to the same TTY are permitted. - // This could easily be used for terminal sharing, but - // we do it because it's a nice user experience to - // copy/paste a terminal URL and have it _just work_. - rpty.activeConnsMutex.Lock() - rpty.activeConns[connectionID] = conn - rpty.activeConnsMutex.Unlock() - // Resetting this timeout prevents the PTY from exiting. - rpty.timeout.Reset(a.reconnectingPTYTimeout) - - ctx, cancelFunc := context.WithCancel(ctx) - defer cancelFunc() - heartbeat := time.NewTicker(a.reconnectingPTYTimeout / 2) - defer heartbeat.Stop() - go func() { - // Keep updating the activity while this - // connection is alive! - for { - select { - case <-ctx.Done(): - return - case <-heartbeat.C: - } - rpty.timeout.Reset(a.reconnectingPTYTimeout) - } - }() - defer func() { - // After this connection ends, remove it from - // the PTYs active connections. If it isn't - // removed, all PTY data will be sent to it. - rpty.activeConnsMutex.Lock() - delete(rpty.activeConns, connectionID) - rpty.activeConnsMutex.Unlock() - }() - decoder := json.NewDecoder(conn) - var req codersdk.ReconnectingPTYRequest - for { - err = decoder.Decode(&req) - if xerrors.Is(err, io.EOF) { - return nil - } - if err != nil { - logger.Warn(ctx, "reconnecting PTY failed with read error", slog.Error(err)) - return nil - } - _, err = rpty.ptty.InputWriter().Write([]byte(req.Data)) - if err != nil { - logger.Warn(ctx, "reconnecting PTY failed with write error", slog.Error(err)) - a.metrics.reconnectingPTYErrors.WithLabelValues("input_writer").Add(1) - return nil - } - // Check if a resize needs to happen! - if req.Height == 0 || req.Width == 0 { - continue - } - err = rpty.ptty.Resize(req.Height, req.Width) - if err != nil { - // We can continue after this, it's not fatal! - logger.Error(ctx, "reconnecting PTY resize failed, but will continue", slog.Error(err)) - a.metrics.reconnectingPTYErrors.WithLabelValues("resize").Add(1) - } - } + return rpty.Attach(ctx, connectionID, conn, msg.Height, msg.Width, connLogger) } // startReportingConnectionStats runs the connection stats reporting goroutine. @@ -1541,31 +1392,6 @@ lifecycleWaitLoop: return nil } -type reconnectingPTY struct { - activeConnsMutex sync.Mutex - activeConns map[string]net.Conn - - circularBuffer *circbuf.Buffer - circularBufferMutex sync.RWMutex - timeout *time.Timer - ptty pty.PTYCmd -} - -// Close ends all connections to the reconnecting -// PTY and clear the circular buffer. -func (r *reconnectingPTY) Close() { - r.activeConnsMutex.Lock() - defer r.activeConnsMutex.Unlock() - for _, conn := range r.activeConns { - _ = conn.Close() - } - _ = r.ptty.Close() - r.circularBufferMutex.Lock() - r.circularBuffer.Reset() - r.circularBufferMutex.Unlock() - r.timeout.Stop() -} - // userHomeDir returns the home directory of the current user, giving // priority to the $HOME environment variable. 
func userHomeDir() (string, error) { diff --git a/agent/agent_test.go b/agent/agent_test.go index 92be40764f209..e9f8bc772bb0f 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -12,6 +12,7 @@ import ( "net/http/httptest" "net/netip" "os" + "os/exec" "os/user" "path" "path/filepath" @@ -102,7 +103,7 @@ func TestAgent_Stats_ReconnectingPTY(t *testing.T) { //nolint:dogsled conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) - ptyConn, err := conn.ReconnectingPTY(ctx, uuid.New(), 128, 128, "/bin/bash") + ptyConn, err := conn.ReconnectingPTY(ctx, uuid.New(), 128, 128, "bash") require.NoError(t, err) defer ptyConn.Close() @@ -1587,8 +1588,12 @@ func TestAgent_Startup(t *testing.T) { }) } +const ansi = "[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))" + +var re = regexp.MustCompile(ansi) + +//nolint:paralleltest // This test sets an environment variable. func TestAgent_ReconnectingPTY(t *testing.T) { - t.Parallel() if runtime.GOOS == "windows" { // This might be our implementation, or ConPTY itself. // It's difficult to find extensive tests for it, so @@ -1596,61 +1601,139 @@ func TestAgent_ReconnectingPTY(t *testing.T) { t.Skip("ConPTY appears to be inconsistent on Windows.") } - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() + backends := []string{"Buffered", "Screen"} - //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) - id := uuid.New() - netConn, err := conn.ReconnectingPTY(ctx, id, 100, 100, "/bin/bash") - require.NoError(t, err) - defer netConn.Close() + _, err := exec.LookPath("screen") + hasScreen := err == nil - bufRead := bufio.NewReader(netConn) + for _, backendType := range backends { + backendType := backendType + t.Run(backendType, func(t *testing.T) { + if backendType == "Screen" { + t.Parallel() + if runtime.GOOS != "linux" { + t.Skipf("`screen` is not supported on %s", runtime.GOOS) + } else if !hasScreen { + t.Skip("`screen` not found") + } + } else if hasScreen && runtime.GOOS == "linux" { + // Set up a PATH that does not have screen in it. + bashPath, err := exec.LookPath("bash") + require.NoError(t, err) + dir, err := os.MkdirTemp("/tmp", "coder-test-reconnecting-pty-PATH") + require.NoError(t, err, "create temp dir for reconnecting pty PATH") + err = os.Symlink(bashPath, filepath.Join(dir, "bash")) + require.NoError(t, err, "symlink bash into reconnecting pty PATH") + t.Setenv("PATH", dir) + } else { + t.Parallel() + } - // Brief pause to reduce the likelihood that we send keystrokes while - // the shell is simultaneously sending a prompt. - time.Sleep(100 * time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() - data, err := json.Marshal(codersdk.ReconnectingPTYRequest{ - Data: "echo test\r\n", - }) - require.NoError(t, err) - _, err = netConn.Write(data) - require.NoError(t, err) + //nolint:dogsled + conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + id := uuid.New() + netConn1, err := conn.ReconnectingPTY(ctx, id, 100, 100, "bash") + require.NoError(t, err) + defer netConn1.Close() - expectLine := func(matcher func(string) bool) { - for { - line, err := bufRead.ReadString('\n') + scanner1 := bufio.NewScanner(netConn1) + + // A second simultaneous connection. 
+ netConn2, err := conn.ReconnectingPTY(ctx, id, 100, 100, "bash") + require.NoError(t, err) + defer netConn2.Close() + scanner2 := bufio.NewScanner(netConn2) + + // Brief pause to reduce the likelihood that we send keystrokes while + // the shell is simultaneously sending a prompt. + time.Sleep(100 * time.Millisecond) + + data, err := json.Marshal(codersdk.ReconnectingPTYRequest{ + Data: "echo test\r\n", + }) + require.NoError(t, err) + _, err = netConn1.Write(data) require.NoError(t, err) - if matcher(line) { - break + + hasLine := func(scanner *bufio.Scanner, matcher func(string) bool) bool { + for scanner.Scan() { + line := scanner.Text() + t.Logf("bash tty stdout = %s", re.ReplaceAllString(line, "")) + if matcher(line) { + return true + } + } + return false } - } - } - matchEchoCommand := func(line string) bool { - return strings.Contains(line, "echo test") - } - matchEchoOutput := func(line string) bool { - return strings.Contains(line, "test") && !strings.Contains(line, "echo") - } + matchEchoCommand := func(line string) bool { + return strings.Contains(line, "echo test") + } + matchEchoOutput := func(line string) bool { + return strings.Contains(line, "test") && !strings.Contains(line, "echo") + } + matchExitCommand := func(line string) bool { + return strings.Contains(line, "exit") + } + matchExitOutput := func(line string) bool { + return strings.Contains(line, "exit") || strings.Contains(line, "logout") + } - // Once for typing the command... - expectLine(matchEchoCommand) - // And another time for the actual output. - expectLine(matchEchoOutput) + // Once for typing the command... + require.True(t, hasLine(scanner1, matchEchoCommand), "find echo command") + // And another time for the actual output. + require.True(t, hasLine(scanner1, matchEchoOutput), "find echo output") - _ = netConn.Close() - netConn, err = conn.ReconnectingPTY(ctx, id, 100, 100, "/bin/bash") - require.NoError(t, err) - defer netConn.Close() + // Same for the other connection. + require.True(t, hasLine(scanner2, matchEchoCommand), "find echo command") + require.True(t, hasLine(scanner2, matchEchoOutput), "find echo output") + + _ = netConn1.Close() + _ = netConn2.Close() + netConn3, err := conn.ReconnectingPTY(ctx, id, 100, 100, "bash") + require.NoError(t, err) + defer netConn3.Close() + + scanner3 := bufio.NewScanner(netConn3) + + // Same output again! + require.True(t, hasLine(scanner3, matchEchoCommand), "find echo command") + require.True(t, hasLine(scanner3, matchEchoOutput), "find echo output") + + // Exit should cause the connection to close. + data, err = json.Marshal(codersdk.ReconnectingPTYRequest{ + Data: "exit\r\n", + }) + require.NoError(t, err) + _, err = netConn3.Write(data) + require.NoError(t, err) - bufRead = bufio.NewReader(netConn) + // Once for the input and again for the output. + require.True(t, hasLine(scanner3, matchExitCommand), "find exit command") + require.True(t, hasLine(scanner3, matchExitOutput), "find exit output") - // Same output again! - expectLine(matchEchoCommand) - expectLine(matchEchoOutput) + // Wait for the connection to close. + for scanner3.Scan() { + line := scanner3.Text() + t.Logf("bash tty stdout = %s", re.ReplaceAllString(line, "")) + } + + // Try a non-shell command. It should output then immediately exit. 
+ netConn4, err := conn.ReconnectingPTY(ctx, uuid.New(), 100, 100, "echo test") + require.NoError(t, err) + defer netConn4.Close() + + scanner4 := bufio.NewScanner(netConn4) + require.True(t, hasLine(scanner4, matchEchoOutput), "find echo output") + for scanner4.Scan() { + line := scanner4.Text() + t.Logf("bash tty stdout = %s", re.ReplaceAllString(line, "")) + } + }) + } } func TestAgent_Dial(t *testing.T) { diff --git a/agent/reconnectingpty/buffered.go b/agent/reconnectingpty/buffered.go new file mode 100644 index 0000000000000..93241ada29687 --- /dev/null +++ b/agent/reconnectingpty/buffered.go @@ -0,0 +1,241 @@ +package reconnectingpty + +import ( + "context" + "errors" + "io" + "net" + "time" + + "github.com/armon/circbuf" + "github.com/prometheus/client_golang/prometheus" + "golang.org/x/exp/slices" + "golang.org/x/xerrors" + + "cdr.dev/slog" + + "github.com/coder/coder/pty" +) + +// bufferedReconnectingPTY provides a reconnectable PTY by using a ring buffer to store +// scrollback. +type bufferedReconnectingPTY struct { + command *pty.Cmd + + activeConns map[string]net.Conn + circularBuffer *circbuf.Buffer + + ptty pty.PTYCmd + process pty.Process + + metrics *prometheus.CounterVec + + state *ptyState + // timer will close the reconnecting pty when it expires. The timer will be + // reset as long as there are active connections. + timer *time.Timer + timeout time.Duration +} + +// newBuffered starts the buffered pty. If the context ends the process will be +// killed. +func newBuffered(ctx context.Context, cmd *pty.Cmd, options *Options, logger slog.Logger) *bufferedReconnectingPTY { + rpty := &bufferedReconnectingPTY{ + activeConns: map[string]net.Conn{}, + command: cmd, + metrics: options.Metrics, + state: newState(), + timeout: options.Timeout, + } + + // Default to buffer 64KiB. + circularBuffer, err := circbuf.NewBuffer(64 << 10) + if err != nil { + rpty.state.setState(StateDone, xerrors.Errorf("generate screen id: %w", err)) + return rpty + } + rpty.circularBuffer = circularBuffer + + // Add TERM then start the command with a pty. pty.Cmd duplicates Path as the + // first argument so remove it. + cmdWithEnv := pty.CommandContext(ctx, cmd.Path, cmd.Args[1:]...) + cmdWithEnv.Env = append(rpty.command.Env, "TERM=xterm-256color") + cmdWithEnv.Dir = rpty.command.Dir + ptty, process, err := pty.Start(cmdWithEnv) + if err != nil { + rpty.state.setState(StateDone, xerrors.Errorf("generate screen id: %w", err)) + return rpty + } + rpty.ptty = ptty + rpty.process = process + + go rpty.lifecycle(ctx, logger) + + // Multiplex the output onto the circular buffer and each active connection. + // We do not need to separately monitor for the process exiting. When it + // exits, our ptty.OutputReader() will return EOF after reading all process + // output. + go func() { + buffer := make([]byte, 1024) + for { + read, err := ptty.OutputReader().Read(buffer) + if err != nil { + // When the PTY is closed, this is triggered. + // Error is typically a benign EOF, so only log for debugging. + if errors.Is(err, io.EOF) { + logger.Debug(ctx, "unable to read pty output, command might have exited", slog.Error(err)) + } else { + logger.Warn(ctx, "unable to read pty output, command might have exited", slog.Error(err)) + rpty.metrics.WithLabelValues("output_reader").Add(1) + } + // Could have been killed externally or failed to start at all (command + // not found for example). + // TODO: Should we check the process's exit code in case the command was + // invalid? 
+ rpty.Close("unable to read pty output, command might have exited") + break + } + part := buffer[:read] + rpty.state.cond.L.Lock() + _, err = rpty.circularBuffer.Write(part) + if err != nil { + logger.Error(ctx, "write to circular buffer", slog.Error(err)) + rpty.metrics.WithLabelValues("write_buffer").Add(1) + } + // TODO: Instead of ranging over a map, could we send the output to a + // channel and have each individual Attach read from that? + for cid, conn := range rpty.activeConns { + _, err = conn.Write(part) + if err != nil { + logger.Warn(ctx, + "error writing to active connection", + slog.F("connection_id", cid), + slog.Error(err), + ) + rpty.metrics.WithLabelValues("write").Add(1) + } + } + rpty.state.cond.L.Unlock() + } + }() + + return rpty +} + +// lifecycle manages the lifecycle of the reconnecting pty. If the context ends +// or the reconnecting pty closes the pty will be shut down. +func (rpty *bufferedReconnectingPTY) lifecycle(ctx context.Context, logger slog.Logger) { + rpty.timer = time.AfterFunc(attachTimeout, func() { + rpty.Close("reconnecting pty timeout") + }) + + logger.Debug(ctx, "reconnecting pty ready") + rpty.state.setState(StateReady, nil) + + state, reasonErr := rpty.state.waitForStateOrContext(ctx, StateClosing) + if state < StateClosing { + // If we have not closed yet then the context is what unblocked us (which + // means the agent is shutting down) so move into the closing phase. + rpty.Close(reasonErr.Error()) + } + rpty.timer.Stop() + + rpty.state.cond.L.Lock() + // Log these closes only for debugging since the connections or processes + // might have already closed on their own. + for _, conn := range rpty.activeConns { + err := conn.Close() + if err != nil { + logger.Debug(ctx, "closed conn with error", slog.Error(err)) + } + } + // Connections get removed once they close but it is possible there is still + // some data that will be written before that happens so clear the map now to + // avoid writing to closed connections. + rpty.activeConns = map[string]net.Conn{} + rpty.state.cond.L.Unlock() + + // Log close/kill only for debugging since the process might have already + // closed on its own. + err := rpty.ptty.Close() + if err != nil { + logger.Debug(ctx, "closed ptty with error", slog.Error(err)) + } + + err = rpty.process.Kill() + if err != nil { + logger.Debug(ctx, "killed process with error", slog.Error(err)) + } + + logger.Info(ctx, "closed reconnecting pty") + rpty.state.setState(StateDone, xerrors.Errorf("reconnecting pty closed: %w", reasonErr)) +} + +func (rpty *bufferedReconnectingPTY) Attach(ctx context.Context, connID string, conn net.Conn, height, width uint16, logger slog.Logger) error { + logger.Info(ctx, "attach to reconnecting pty") + + // This will kill the heartbeat once we hit EOF or an error. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + err := rpty.doAttach(ctx, connID, conn, height, width, logger) + if err != nil { + return err + } + + defer func() { + rpty.state.cond.L.Lock() + defer rpty.state.cond.L.Unlock() + delete(rpty.activeConns, connID) + }() + + // Pipe conn -> pty and block. pty -> conn is handled in newBuffered(). + readConnLoop(ctx, conn, rpty.ptty, rpty.metrics, logger) + return nil +} + +// doAttach adds the connection to the map, replays the buffer, and starts the +// heartbeat. It exists separately only so we can defer the mutex unlock which +// is not possible in Attach since it blocks. 
+func (rpty *bufferedReconnectingPTY) doAttach(ctx context.Context, connID string, conn net.Conn, height, width uint16, logger slog.Logger) error { + rpty.state.cond.L.Lock() + defer rpty.state.cond.L.Unlock() + + // Write any previously stored data for the TTY. Since the command might be + // short-lived and have already exited, make sure we always at least output + // the buffer before returning, mostly just so tests pass. + prevBuf := slices.Clone(rpty.circularBuffer.Bytes()) + _, err := conn.Write(prevBuf) + if err != nil { + rpty.metrics.WithLabelValues("write").Add(1) + return xerrors.Errorf("write buffer to conn: %w", err) + } + + state, err := rpty.state.waitForStateOrContextLocked(ctx, StateReady) + if state != StateReady { + return xerrors.Errorf("reconnecting pty ready wait: %w", err) + } + + go heartbeat(ctx, rpty.timer, rpty.timeout) + + // Resize the PTY to initial height + width. + err = rpty.ptty.Resize(height, width) + if err != nil { + // We can continue after this, it's not fatal! + logger.Warn(ctx, "reconnecting PTY initial resize failed, but will continue", slog.Error(err)) + rpty.metrics.WithLabelValues("resize").Add(1) + } + + rpty.activeConns[connID] = conn + + return nil +} + +func (rpty *bufferedReconnectingPTY) Wait() { + _, _ = rpty.state.waitForState(StateClosing) +} + +func (rpty *bufferedReconnectingPTY) Close(reason string) { + // The closing state change will be handled by the lifecycle. + rpty.state.setState(StateClosing, xerrors.Errorf("reconnecting pty closing: %s", reason)) +} diff --git a/agent/reconnectingpty/reconnectingpty.go b/agent/reconnectingpty/reconnectingpty.go new file mode 100644 index 0000000000000..e3dbb9024b063 --- /dev/null +++ b/agent/reconnectingpty/reconnectingpty.go @@ -0,0 +1,231 @@ +package reconnectingpty + +import ( + "context" + "encoding/json" + "io" + "net" + "os/exec" + "runtime" + "sync" + "time" + + "github.com/prometheus/client_golang/prometheus" + "golang.org/x/xerrors" + + "cdr.dev/slog" + + "github.com/coder/coder/codersdk" + "github.com/coder/coder/pty" +) + +// attachTimeout is the initial timeout for attaching and will probably be far +// shorter than the reconnect timeout in most cases; in tests it might be +// longer. It should be at least long enough for the first screen attach to be +// able to start up the daemon and for the buffered pty to start. +const attachTimeout = 30 * time.Second + +// Options allows configuring the reconnecting pty. +type Options struct { + // Timeout describes how long to keep the pty alive without any connections. + // Once elapsed the pty will be killed. + Timeout time.Duration + // Metrics tracks various error counters. + Metrics *prometheus.CounterVec +} + +// ReconnectingPTY is a pty that can be reconnected within a timeout and to +// simultaneous connections. The reconnecting pty can be backed by screen if +// installed or a (buggy) buffer replay fallback. +type ReconnectingPTY interface { + // Attach pipes the connection and pty, spawning it if necessary, replays + // history, then blocks until EOF, an error, or the context's end. The + // connection is expected to send JSON-encoded messages and accept raw output + // from the ptty. If the context ends or the process dies the connection will + // be detached. + Attach(ctx context.Context, connID string, conn net.Conn, height, width uint16, logger slog.Logger) error + // Wait waits for the reconnecting pty to close. The underlying process might + // still be exiting. + Wait() + // Close kills the reconnecting pty process. 
+ Close(reason string) +} + +// New sets up a new reconnecting pty that wraps the provided command. Any +// errors with starting are returned on Attach(). The reconnecting pty will +// close itself (and all connections to it) if nothing is attached for the +// duration of the timeout, if the context ends, or the process exits (buffered +// backend only). +func New(ctx context.Context, cmd *pty.Cmd, options *Options, logger slog.Logger) ReconnectingPTY { + if options.Timeout == 0 { + options.Timeout = 5 * time.Minute + } + // Screen seems flaky on Darwin. Locally the tests pass 100% of the time (100 + // runs) but in CI screen often incorrectly claims the session name does not + // exist even though screen -list shows it. For now, restrict screen to + // Linux. + backendType := "buffered" + if runtime.GOOS == "linux" { + _, err := exec.LookPath("screen") + if err == nil { + backendType = "screen" + } + } + + logger.Info(ctx, "start reconnecting pty", slog.F("backend_type", backendType)) + + switch backendType { + case "screen": + return newScreen(ctx, cmd, options, logger) + default: + return newBuffered(ctx, cmd, options, logger) + } +} + +// heartbeat resets timer before timeout elapses and blocks until ctx ends. +func heartbeat(ctx context.Context, timer *time.Timer, timeout time.Duration) { + // Reset now in case it is near the end. + timer.Reset(timeout) + + // Reset when the context ends to ensure the pty stays up for the full + // timeout. + defer timer.Reset(timeout) + + heartbeat := time.NewTicker(timeout / 2) + defer heartbeat.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-heartbeat.C: + timer.Reset(timeout) + } + } +} + +// State represents the current state of the reconnecting pty. States are +// sequential and will only move forward. +type State int + +const ( + // StateStarting is the default/start state. Attaching will block until the + // reconnecting pty becomes ready. + StateStarting = iota + // StateReady means the reconnecting pty is ready to be attached. + StateReady + // StateClosing means the reconnecting pty has begun closing. The underlying + // process may still be exiting. Attaching will result in an error. + StateClosing + // StateDone means the reconnecting pty has completely shut down and the + // process has exited. Attaching will result in an error. + StateDone +) + +// ptyState is a helper for tracking the reconnecting PTY's state. +type ptyState struct { + // cond broadcasts state changes and any accompanying errors. + cond *sync.Cond + // error describes the error that caused the state change, if there was one. + // It is not safe to access outside of cond.L. + error error + // state holds the current reconnecting pty state. It is not safe to access + // this outside of cond.L. + state State +} + +func newState() *ptyState { + return &ptyState{ + cond: sync.NewCond(&sync.Mutex{}), + state: StateStarting, + } +} + +// setState sets and broadcasts the provided state if it is greater than the +// current state and the error if one has not already been set. +func (s *ptyState) setState(state State, err error) { + s.cond.L.Lock() + defer s.cond.L.Unlock() + // Cannot regress states. For example, trying to close after the process is + // done should leave us in the done state and not the closing state. + if state <= s.state { + return + } + s.error = err + s.state = state + s.cond.Broadcast() +} + +// waitForState blocks until the state or a greater one is reached. 
+func (s *ptyState) waitForState(state State) (State, error) { + s.cond.L.Lock() + defer s.cond.L.Unlock() + for state > s.state { + s.cond.Wait() + } + return s.state, s.error +} + +// waitForStateOrContext blocks until the state or a greater one is reached or +// the provided context ends. +func (s *ptyState) waitForStateOrContext(ctx context.Context, state State) (State, error) { + s.cond.L.Lock() + defer s.cond.L.Unlock() + return s.waitForStateOrContextLocked(ctx, state) +} + +// waitForStateOrContextLocked is the same as waitForStateOrContext except it +// assumes the caller has already locked cond. +func (s *ptyState) waitForStateOrContextLocked(ctx context.Context, state State) (State, error) { + nevermind := make(chan struct{}) + defer close(nevermind) + go func() { + select { + case <-ctx.Done(): + // Wake up when the context ends. + s.cond.Broadcast() + case <-nevermind: + } + }() + + for ctx.Err() == nil && state > s.state { + s.cond.Wait() + } + if ctx.Err() != nil { + return s.state, ctx.Err() + } + return s.state, s.error +} + +// readConnLoop reads messages from conn and writes to ptty as needed. Blocks +// until EOF or an error writing to ptty or reading from conn. +func readConnLoop(ctx context.Context, conn net.Conn, ptty pty.PTYCmd, metrics *prometheus.CounterVec, logger slog.Logger) { + decoder := json.NewDecoder(conn) + var req codersdk.ReconnectingPTYRequest + for { + err := decoder.Decode(&req) + if xerrors.Is(err, io.EOF) { + return + } + if err != nil { + logger.Warn(ctx, "reconnecting pty failed with read error", slog.Error(err)) + return + } + _, err = ptty.InputWriter().Write([]byte(req.Data)) + if err != nil { + logger.Warn(ctx, "reconnecting pty failed with write error", slog.Error(err)) + metrics.WithLabelValues("input_writer").Add(1) + return + } + // Check if a resize needs to happen! + if req.Height == 0 || req.Width == 0 { + continue + } + err = ptty.Resize(req.Height, req.Width) + if err != nil { + // We can continue after this, it's not fatal! + logger.Warn(ctx, "reconnecting pty resize failed, but will continue", slog.Error(err)) + metrics.WithLabelValues("resize").Add(1) + } + } +} diff --git a/agent/reconnectingpty/screen.go b/agent/reconnectingpty/screen.go new file mode 100644 index 0000000000000..0203154f83335 --- /dev/null +++ b/agent/reconnectingpty/screen.go @@ -0,0 +1,388 @@ +package reconnectingpty + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/hex" + "errors" + "io" + "net" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/gliderlabs/ssh" + "github.com/prometheus/client_golang/prometheus" + "golang.org/x/xerrors" + + "cdr.dev/slog" + "github.com/coder/coder/pty" +) + +// screenReconnectingPTY provides a reconnectable PTY via `screen`. +type screenReconnectingPTY struct { + command *pty.Cmd + + // id holds the id of the session for both creating and attaching. This will + // be generated uniquely for each session because without control of the + // screen daemon we do not have its PID and without the PID screen will do + // partial matching. Enforcing a unique ID should guarantee we match on the + // right session. + id string + + // mutex prevents concurrent attaches to the session. Screen will happily + // spawn two separate sessions with the same name if multiple attaches happen + // in a close enough interval. 
We are not able to control the screen daemon + // ourselves to prevent this because the daemon will spawn with a hardcoded + // 24x80 size which results in confusing padding above the prompt once the + // attach comes in and resizes. + mutex sync.Mutex + + configFile string + + metrics *prometheus.CounterVec + + state *ptyState + // timer will close the reconnecting pty when it expires. The timer will be + // reset as long as there are active connections. + timer *time.Timer + timeout time.Duration +} + +// newScreen creates a new screen-backed reconnecting PTY. It writes config +// settings and creates the socket directory. If we could, we would want to +// spawn the daemon here and attach each connection to it but since doing that +// spawns the daemon with a hardcoded 24x80 size it is not a very good user +// experience. Instead we will let the attach command spawn the daemon on its +// own which causes it to spawn with the specified size. +func newScreen(ctx context.Context, cmd *pty.Cmd, options *Options, logger slog.Logger) *screenReconnectingPTY { + rpty := &screenReconnectingPTY{ + command: cmd, + metrics: options.Metrics, + state: newState(), + timeout: options.Timeout, + } + + go rpty.lifecycle(ctx, logger) + + // Socket paths are limited to around 100 characters on Linux and macOS which + // depending on the temporary directory can be a problem. To give more leeway + // use a short ID. + buf := make([]byte, 4) + _, err := rand.Read(buf) + if err != nil { + rpty.state.setState(StateDone, xerrors.Errorf("generate screen id: %w", err)) + return rpty + } + rpty.id = hex.EncodeToString(buf) + + settings := []string{ + // Tell screen not to handle motion for xterm* terminals which allows + // scrolling the terminal via the mouse wheel or scroll bar (by default + // screen uses it to cycle through the command history). There does not + // seem to be a way to make screen itself scroll on mouse wheel. tmux can + // do it but then there is no scroll bar and it kicks you into copy mode + // where keys stop working until you exit copy mode which seems like it + // could be confusing. + "termcapinfo xterm* ti@:te@", + // Enable alternate screen emulation otherwise applications get rendered in + // the current window which wipes out visible output resulting in missing + // output when scrolling back with the mouse wheel (copy mode still works + // since that is screen itself scrolling). + "altscreen on", + // Remap the control key to C-s since C-a may be used in applications. C-s + // is chosen because it cannot actually be used because by default it will + // pause and C-q to resume will just kill the browser window. We may not + // want people using the control key anyway since it will not be obvious + // they are in screen and doing things like switching windows makes mouse + // wheel scroll wonky due to the terminal doing the scrolling rather than + // screen itself (but again copy mode will work just fine). + "escape ^Ss", + } + + rpty.configFile = filepath.Join(os.TempDir(), "coder-screen", "config") + err = os.MkdirAll(filepath.Dir(rpty.configFile), 0o700) + if err != nil { + rpty.state.setState(StateDone, xerrors.Errorf("make screen config dir: %w", err)) + return rpty + } + + err = os.WriteFile(rpty.configFile, []byte(strings.Join(settings, "\n")), 0o600) + if err != nil { + rpty.state.setState(StateDone, xerrors.Errorf("create config file: %w", err)) + return rpty + } + + return rpty +} + +// lifecycle manages the lifecycle of the reconnecting pty. 
If the context ends +// the reconnecting pty will be closed. +func (rpty *screenReconnectingPTY) lifecycle(ctx context.Context, logger slog.Logger) { + rpty.timer = time.AfterFunc(attachTimeout, func() { + rpty.Close("reconnecting pty timeout") + }) + + logger.Debug(ctx, "reconnecting pty ready") + rpty.state.setState(StateReady, nil) + + state, reasonErr := rpty.state.waitForStateOrContext(ctx, StateClosing) + if state < StateClosing { + // If we have not closed yet then the context is what unblocked us (which + // means the agent is shutting down) so move into the closing phase. + rpty.Close(reasonErr.Error()) + } + rpty.timer.Stop() + + // If the command errors that the session is already gone that is fine. + err := rpty.sendCommand(context.Background(), "quit", []string{"No screen session found"}) + if err != nil { + logger.Error(ctx, "close screen session", slog.Error(err)) + } + + logger.Info(ctx, "closed reconnecting pty") + rpty.state.setState(StateDone, xerrors.Errorf("reconnecting pty closed: %w", reasonErr)) +} + +func (rpty *screenReconnectingPTY) Attach(ctx context.Context, _ string, conn net.Conn, height, width uint16, logger slog.Logger) error { + logger.Info(ctx, "attach to reconnecting pty") + + // This will kill the heartbeat once we hit EOF or an error. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + state, err := rpty.state.waitForStateOrContext(ctx, StateReady) + if state != StateReady { + return xerrors.Errorf("reconnecting pty ready wait: %w", err) + } + + go heartbeat(ctx, rpty.timer, rpty.timeout) + + ptty, process, err := rpty.doAttach(ctx, conn, height, width, logger) + if err != nil { + if errors.Is(err, context.Canceled) { + // Likely the process was too short-lived and canceled the version command. + // TODO: Is it worth distinguishing between that and a cancel from the + // Attach() caller? Additionally, since this could also happen if + // the command was invalid, should we check the process's exit code? + return nil + } + return err + } + + defer func() { + // Log only for debugging since the process might have already exited on its + // own. + err := ptty.Close() + if err != nil { + logger.Debug(ctx, "closed ptty with error", slog.Error(err)) + } + err = process.Kill() + if err != nil { + logger.Debug(ctx, "killed process with error", slog.Error(err)) + } + }() + + // Pipe conn -> pty and block. + readConnLoop(ctx, conn, ptty, rpty.metrics, logger) + return nil +} + +// doAttach spawns the screen client and starts the heartbeat. It exists +// separately only so we can defer the mutex unlock which is not possible in +// Attach since it blocks. +func (rpty *screenReconnectingPTY) doAttach(ctx context.Context, conn net.Conn, height, width uint16, logger slog.Logger) (pty.PTYCmd, pty.Process, error) { + // Ensure another attach does not come in and spawn a duplicate session. + rpty.mutex.Lock() + defer rpty.mutex.Unlock() + + logger.Debug(ctx, "spawning screen client", slog.F("screen_id", rpty.id)) + + // Wrap the command with screen and tie it to the connection's context. + cmd := pty.CommandContext(ctx, "screen", append([]string{ + // -S is for setting the session's name. + "-S", rpty.id, + // -x allows attaching to an already attached session. + // -RR reattaches to the daemon or creates the session daemon if missing. + // -q disables the "New screen..." message that appears for five seconds + // when creating a new session with -RR. + // -c is the flag for the config file. 
+ "-xRRqc", rpty.configFile, + rpty.command.Path, + // pty.Cmd duplicates Path as the first argument so remove it. + }, rpty.command.Args[1:]...)...) + cmd.Env = append(rpty.command.Env, "TERM=xterm-256color") + cmd.Dir = rpty.command.Dir + ptty, process, err := pty.Start(cmd, pty.WithPTYOption( + pty.WithSSHRequest(ssh.Pty{ + Window: ssh.Window{ + // Make sure to spawn at the right size because if we resize afterward it + // leaves confusing padding (screen will resize such that the screen + // contents are aligned to the bottom). + Height: int(height), + Width: int(width), + }, + }), + )) + if err != nil { + rpty.metrics.WithLabelValues("screen_spawn").Add(1) + return nil, nil, err + } + + // This context lets us abort the version command if the process dies. + versionCtx, versionCancel := context.WithCancel(ctx) + defer versionCancel() + + // Pipe pty -> conn and close the connection when the process exits. + // We do not need to separately monitor for the process exiting. When it + // exits, our ptty.OutputReader() will return EOF after reading all process + // output. + go func() { + defer versionCancel() + defer func() { + err := conn.Close() + if err != nil { + // Log only for debugging since the connection might have already closed + // on its own. + logger.Debug(ctx, "closed connection with error", slog.Error(err)) + } + }() + buffer := make([]byte, 1024) + for { + read, err := ptty.OutputReader().Read(buffer) + if err != nil { + // When the PTY is closed, this is triggered. + // Error is typically a benign EOF, so only log for debugging. + if errors.Is(err, io.EOF) { + logger.Debug(ctx, "unable to read pty output; screen might have exited", slog.Error(err)) + } else { + logger.Warn(ctx, "unable to read pty output; screen might have exited", slog.Error(err)) + rpty.metrics.WithLabelValues("screen_output_reader").Add(1) + } + // The process might have died because the session itself died or it + // might have been separately killed and the session is still up (for + // example `exit` or we killed it when the connection closed). If the + // session is still up we might leave the reconnecting pty in memory + // around longer than it needs to be but it will eventually clean up + // with the timer or context, or the next attach will respawn the screen + // daemon which is fine too. + break + } + part := buffer[:read] + _, err = conn.Write(part) + if err != nil { + // Connection might have been closed. + if errors.Unwrap(err).Error() != "endpoint is closed for send" { + logger.Warn(ctx, "error writing to active conn", slog.Error(err)) + rpty.metrics.WithLabelValues("screen_write").Add(1) + } + break + } + } + }() + + // Version seems to be the only command without a side effect (other than + // making the version pop up briefly) so use it to wait for the session to + // come up. If we do not wait we could end up spawning multiple sessions with + // the same name. + err = rpty.sendCommand(versionCtx, "version", nil) + if err != nil { + // Log only for debugging since the process might already have closed. + closeErr := ptty.Close() + if closeErr != nil { + logger.Debug(ctx, "closed ptty with error", slog.Error(closeErr)) + } + closeErr = process.Kill() + if closeErr != nil { + logger.Debug(ctx, "killed process with error", slog.Error(closeErr)) + } + rpty.metrics.WithLabelValues("screen_wait").Add(1) + return nil, nil, err + } + + return ptty, process, nil +} + +// sendCommand runs a screen command against a running screen session. 
If the +// command fails with an error matching anything in successErrors it will be +// considered a success state (for example "no session" when quitting and the +// session is already dead). The command will be retried until successful, the +// timeout is reached, or the context ends. A canceled context will return the +// canceled context's error as-is while a timed-out context returns together +// with the last error from the command. +func (rpty *screenReconnectingPTY) sendCommand(ctx context.Context, command string, successErrors []string) error { + ctx, cancel := context.WithTimeout(ctx, attachTimeout) + defer cancel() + + var lastErr error + run := func() bool { + var stdout bytes.Buffer + //nolint:gosec + cmd := exec.CommandContext(ctx, "screen", + // -x targets an attached session. + "-x", rpty.id, + // -c is the flag for the config file. + "-c", rpty.configFile, + // -X runs a command in the matching session. + "-X", command, + ) + cmd.Env = append(rpty.command.Env, "TERM=xterm-256color") + cmd.Dir = rpty.command.Dir + cmd.Stdout = &stdout + err := cmd.Run() + if err == nil { + return true + } + + stdoutStr := stdout.String() + for _, se := range successErrors { + if strings.Contains(stdoutStr, se) { + return true + } + } + + // Things like "exit status 1" are imprecise so include stdout as it may + // contain more information ("no screen session found" for example). + if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + lastErr = xerrors.Errorf("`screen -x %s -X %s`: %w: %s", rpty.id, command, err, stdoutStr) + } + + return false + } + + // Run immediately. + if done := run(); done { + return nil + } + + // Then run on an interval. + ticker := time.NewTicker(250 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + if errors.Is(ctx.Err(), context.Canceled) { + return ctx.Err() + } + return errors.Join(ctx.Err(), lastErr) + case <-ticker.C: + if done := run(); done { + return nil + } + } + } +} + +func (rpty *screenReconnectingPTY) Wait() { + _, _ = rpty.state.waitForState(StateClosing) +} + +func (rpty *screenReconnectingPTY) Close(reason string) { + // The closing state change will be handled by the lifecycle. + rpty.state.setState(StateClosing, xerrors.Errorf("reconnecting pty closing: %s", reason)) +} diff --git a/coderd/workspaceapps/apptest/apptest.go b/coderd/workspaceapps/apptest/apptest.go index f64ba7c30bf31..1c20fd19f5b28 100644 --- a/coderd/workspaceapps/apptest/apptest.go +++ b/coderd/workspaceapps/apptest/apptest.go @@ -12,6 +12,7 @@ import ( "net/http/httputil" "net/url" "path" + "regexp" "runtime" "strconv" "strings" @@ -30,6 +31,10 @@ import ( "github.com/coder/coder/testutil" ) +const ansi = "[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))" + +var re = regexp.MustCompile(ansi) + // Run runs the entire workspace app test suite against deployments minted // by the provided factory. 
// @@ -51,23 +56,8 @@ func Run(t *testing.T, appHostIsPrimary bool, factory DeploymentFactory) { t.Skip("ConPTY appears to be inconsistent on Windows.") } - expectLine := func(t *testing.T, r *bufio.Reader, matcher func(string) bool) { - for { - line, err := r.ReadString('\n') - require.NoError(t, err) - if matcher(line) { - break - } - } - } - matchEchoCommand := func(line string) bool { - return strings.Contains(line, "echo test") - } - matchEchoOutput := func(line string) bool { - return strings.Contains(line, "test") && !strings.Contains(line, "echo") - } - t.Run("OK", func(t *testing.T) { + t.Parallel() appDetails := setupProxyTest(t, nil) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) @@ -76,40 +66,13 @@ func Run(t *testing.T, appHostIsPrimary bool, factory DeploymentFactory) { // Run the test against the path app hostname since that's where the // reconnecting-pty proxy server we want to test is mounted. client := appDetails.AppClient(t) - conn, err := client.WorkspaceAgentReconnectingPTY(ctx, codersdk.WorkspaceAgentReconnectingPTYOpts{ + testReconnectingPTY(ctx, t, client, codersdk.WorkspaceAgentReconnectingPTYOpts{ AgentID: appDetails.Agent.ID, Reconnect: uuid.New(), Height: 80, Width: 80, - Command: "/bin/bash", + Command: "bash", }) - require.NoError(t, err) - defer conn.Close() - - // First attempt to resize the TTY. - // The websocket will close if it fails! - data, err := json.Marshal(codersdk.ReconnectingPTYRequest{ - Height: 250, - Width: 250, - }) - require.NoError(t, err) - _, err = conn.Write(data) - require.NoError(t, err) - bufRead := bufio.NewReader(conn) - - // Brief pause to reduce the likelihood that we send keystrokes while - // the shell is simultaneously sending a prompt. - time.Sleep(100 * time.Millisecond) - - data, err = json.Marshal(codersdk.ReconnectingPTYRequest{ - Data: "echo test\r\n", - }) - require.NoError(t, err) - _, err = conn.Write(data) - require.NoError(t, err) - - expectLine(t, bufRead, matchEchoCommand) - expectLine(t, bufRead, matchEchoOutput) }) t.Run("SignedTokenQueryParameter", func(t *testing.T) { @@ -137,41 +100,14 @@ func Run(t *testing.T, appHostIsPrimary bool, factory DeploymentFactory) { // Make an unauthenticated client. unauthedAppClient := codersdk.New(appDetails.AppClient(t).URL) - conn, err := unauthedAppClient.WorkspaceAgentReconnectingPTY(ctx, codersdk.WorkspaceAgentReconnectingPTYOpts{ + testReconnectingPTY(ctx, t, unauthedAppClient, codersdk.WorkspaceAgentReconnectingPTYOpts{ AgentID: appDetails.Agent.ID, Reconnect: uuid.New(), Height: 80, Width: 80, - Command: "/bin/bash", + Command: "bash", SignedToken: issueRes.SignedToken, }) - require.NoError(t, err) - defer conn.Close() - - // First attempt to resize the TTY. - // The websocket will close if it fails! - data, err := json.Marshal(codersdk.ReconnectingPTYRequest{ - Height: 250, - Width: 250, - }) - require.NoError(t, err) - _, err = conn.Write(data) - require.NoError(t, err) - bufRead := bufio.NewReader(conn) - - // Brief pause to reduce the likelihood that we send keystrokes while - // the shell is simultaneously sending a prompt. 
- time.Sleep(100 * time.Millisecond) - - data, err = json.Marshal(codersdk.ReconnectingPTYRequest{ - Data: "echo test\r\n", - }) - require.NoError(t, err) - _, err = conn.Write(data) - require.NoError(t, err) - - expectLine(t, bufRead, matchEchoCommand) - expectLine(t, bufRead, matchEchoOutput) }) }) @@ -1407,3 +1343,75 @@ func Run(t *testing.T, appHostIsPrimary bool, factory DeploymentFactory) { require.Equal(t, []string{"baz"}, resp.Header.Values("X-Foobar")) }) } + +func testReconnectingPTY(ctx context.Context, t *testing.T, client *codersdk.Client, opts codersdk.WorkspaceAgentReconnectingPTYOpts) { + hasLine := func(scanner *bufio.Scanner, matcher func(string) bool) bool { + for scanner.Scan() { + line := scanner.Text() + t.Logf("bash tty stdout = %s", re.ReplaceAllString(line, "")) + if matcher(line) { + return true + } + } + return false + } + matchEchoCommand := func(line string) bool { + return strings.Contains(line, "echo test") + } + matchEchoOutput := func(line string) bool { + return strings.Contains(line, "test") && !strings.Contains(line, "echo") + } + matchExitCommand := func(line string) bool { + return strings.Contains(line, "exit") + } + matchExitOutput := func(line string) bool { + return strings.Contains(line, "exit") || strings.Contains(line, "logout") + } + + conn, err := client.WorkspaceAgentReconnectingPTY(ctx, opts) + require.NoError(t, err) + defer conn.Close() + + // First attempt to resize the TTY. + // The websocket will close if it fails! + data, err := json.Marshal(codersdk.ReconnectingPTYRequest{ + Height: 250, + Width: 250, + }) + require.NoError(t, err) + _, err = conn.Write(data) + require.NoError(t, err) + scanner := bufio.NewScanner(conn) + + // Brief pause to reduce the likelihood that we send keystrokes while + // the shell is simultaneously sending a prompt. + time.Sleep(100 * time.Millisecond) + + data, err = json.Marshal(codersdk.ReconnectingPTYRequest{ + Data: "echo test\r\n", + }) + require.NoError(t, err) + _, err = conn.Write(data) + require.NoError(t, err) + + require.True(t, hasLine(scanner, matchEchoCommand), "find echo command") + require.True(t, hasLine(scanner, matchEchoOutput), "find echo output") + + // Exit should cause the connection to close. + data, err = json.Marshal(codersdk.ReconnectingPTYRequest{ + Data: "exit\r\n", + }) + require.NoError(t, err) + _, err = conn.Write(data) + require.NoError(t, err) + + // Once for the input and again for the output. + require.True(t, hasLine(scanner, matchExitCommand), "find exit command") + require.True(t, hasLine(scanner, matchExitOutput), "find exit output") + + // Ensure the connection closes. 
+ for scanner.Scan() { + line := scanner.Text() + t.Logf("bash tty stdout = %s", re.ReplaceAllString(line, "")) + } +} diff --git a/dogfood/Dockerfile b/dogfood/Dockerfile index c5145ecb98629..f269962369ef9 100644 --- a/dogfood/Dockerfile +++ b/dogfood/Dockerfile @@ -162,6 +162,7 @@ RUN apt-get update --quiet && apt-get install --yes \ fish \ unzip \ zstd \ + screen \ gettext-base && \ # Delete package cache to avoid consuming space in layer apt-get clean && \ diff --git a/flake.nix b/flake.nix index d2e7cd492dd7a..18d06adc4a98f 100644 --- a/flake.nix +++ b/flake.nix @@ -44,6 +44,7 @@ postgresql protoc-gen-go ripgrep + screen shellcheck shfmt sqlc diff --git a/scaletest/workspacetraffic/conn.go b/scaletest/workspacetraffic/conn.go index 167164c5ef33f..4be38a02c6abf 100644 --- a/scaletest/workspacetraffic/conn.go +++ b/scaletest/workspacetraffic/conn.go @@ -19,7 +19,7 @@ func connectPTY(ctx context.Context, client *codersdk.Client, agentID, reconnect Reconnect: reconnect, Height: 25, Width: 80, - Command: "/bin/sh", + Command: "sh", }) if err != nil { return nil, xerrors.Errorf("connect pty: %w", err)
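
A note on the wire protocol the tests above exercise: the client sends JSON-framed codersdk.ReconnectingPTYRequest messages (keystrokes in Data, or a resize via Height/Width) and reads raw PTY output straight off the connection. Below is a minimal client-side sketch under those assumptions, mirroring the testReconnectingPTY helper in apptest.go; attachAndEcho and agentID are hypothetical names, and the usual imports (bufio, context, encoding/json, fmt, github.com/google/uuid, github.com/coder/coder/codersdk) are assumed.

// attachAndEcho opens a reconnecting PTY, runs `echo test`, and streams the
// output until the session ends. It is a sketch of the protocol used by the
// tests in this diff.
func attachAndEcho(ctx context.Context, client *codersdk.Client, agentID uuid.UUID) error {
	conn, err := client.WorkspaceAgentReconnectingPTY(ctx, codersdk.WorkspaceAgentReconnectingPTYOpts{
		AgentID:   agentID,
		Reconnect: uuid.New(),
		Height:    80,
		Width:     80,
		Command:   "bash",
	})
	if err != nil {
		return err
	}
	defer conn.Close()

	// Resize first; the websocket closes if the resize fails.
	resize, _ := json.Marshal(codersdk.ReconnectingPTYRequest{Height: 250, Width: 250})
	if _, err := conn.Write(resize); err != nil {
		return err
	}

	// Keystrokes are JSON-framed; `exit` ends the shell and closes the connection.
	input, _ := json.Marshal(codersdk.ReconnectingPTYRequest{Data: "echo test\r\nexit\r\n"})
	if _, err := conn.Write(input); err != nil {
		return err
	}

	// Output comes back unframed; read until the PTY closes the connection.
	scanner := bufio.NewScanner(conn)
	for scanner.Scan() {
		fmt.Println(scanner.Text())
	}
	return scanner.Err()
}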