Skip to content

Commit a7468cd

Browse files
committed
Add session with screen and buffer implementations
1 parent a8360d1 commit a7468cd

File tree

11 files changed

+782
-210
lines changed

11 files changed

+782
-210
lines changed

agent/agent.go

Lines changed: 23 additions & 184 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import (
2121
"sync"
2222
"time"
2323

24-
"github.com/armon/circbuf"
2524
"github.com/google/uuid"
2625
"github.com/prometheus/client_golang/prometheus"
2726
"github.com/spf13/afero"
@@ -34,12 +33,12 @@ import (
3433

3534
"cdr.dev/slog"
3635
"github.com/coder/coder/agent/agentssh"
36+
"github.com/coder/coder/agent/reconnectingpty"
3737
"github.com/coder/coder/buildinfo"
3838
"github.com/coder/coder/coderd/database"
3939
"github.com/coder/coder/coderd/gitauth"
4040
"github.com/coder/coder/codersdk"
4141
"github.com/coder/coder/codersdk/agentsdk"
42-
"github.com/coder/coder/pty"
4342
"github.com/coder/coder/tailnet"
4443
"github.com/coder/retry"
4544
)
@@ -87,9 +86,6 @@ type Agent interface {
8786
}
8887

8988
func New(options Options) Agent {
90-
if options.ReconnectingPTYTimeout == 0 {
91-
options.ReconnectingPTYTimeout = 5 * time.Minute
92-
}
9389
if options.Filesystem == nil {
9490
options.Filesystem = afero.NewOsFs()
9591
}
@@ -1042,14 +1038,14 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
10421038
logger.Debug(ctx, "session closed")
10431039
}()
10441040

1045-
var rpty *reconnectingPTY
1046-
sendConnected := make(chan *reconnectingPTY, 1)
1041+
var rpty *reconnectingpty.ReconnectingPTY
1042+
sendConnected := make(chan *reconnectingpty.ReconnectingPTY, 1)
10471043
// On store, reserve this ID to prevent multiple concurrent new connections.
10481044
waitReady, ok := a.reconnectingPTYs.LoadOrStore(msg.ID, sendConnected)
10491045
if ok {
10501046
close(sendConnected) // Unused.
10511047
logger.Debug(ctx, "connecting to existing session")
1052-
c, ok := waitReady.(chan *reconnectingPTY)
1048+
c, ok := waitReady.(chan *reconnectingpty.ReconnectingPTY)
10531049
if !ok {
10541050
return xerrors.Errorf("found invalid type in reconnecting pty map: %T", waitReady)
10551051
}
@@ -1075,169 +1071,37 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
10751071
a.metrics.reconnectingPTYErrors.WithLabelValues("create_command").Add(1)
10761072
return xerrors.Errorf("create command: %w", err)
10771073
}
1078-
cmd.Env = append(cmd.Env, "TERM=xterm-256color")
10791074

1080-
// Default to buffer 64KiB.
1081-
circularBuffer, err := circbuf.NewBuffer(64 << 10)
1082-
if err != nil {
1083-
return xerrors.Errorf("create circular buffer: %w", err)
1075+
// The ability to select the backend type is mostly for tests.
1076+
backendType := msg.BackendType
1077+
if backendType == "" {
1078+
_, err = exec.LookPath("screen")
1079+
if err == nil {
1080+
backendType = codersdk.ReconnectingPTYBackendTypeScreen
1081+
} else {
1082+
backendType = codersdk.ReconnectingPTYBackendTypeBuffered
1083+
}
10841084
}
10851085

1086-
ptty, process, err := pty.Start(cmd)
1087-
if err != nil {
1088-
a.metrics.reconnectingPTYErrors.WithLabelValues("start_command").Add(1)
1089-
return xerrors.Errorf("start command: %w", err)
1090-
}
1086+
rpty = reconnectingpty.New(ctx, cmd, &reconnectingpty.Options{
1087+
BackendType: backendType,
1088+
Timeout: a.reconnectingPTYTimeout,
1089+
Metrics: a.metrics.reconnectingPTYErrors,
1090+
Logger: logger,
1091+
})
10911092

1092-
ctx, cancel := context.WithCancel(ctx)
1093-
rpty = &reconnectingPTY{
1094-
activeConns: map[string]net.Conn{
1095-
// We have to put the connection in the map instantly otherwise
1096-
// the connection won't be closed if the process instantly dies.
1097-
connectionID: conn,
1098-
},
1099-
ptty: ptty,
1100-
// Timeouts created with an after func can be reset!
1101-
timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancel),
1102-
circularBuffer: circularBuffer,
1103-
}
1104-
// We don't need to separately monitor for the process exiting.
1105-
// When it exits, our ptty.OutputReader() will return EOF after
1106-
// reading all process output.
11071093
if err = a.trackConnGoroutine(func() {
1108-
buffer := make([]byte, 1024)
1109-
for {
1110-
read, err := rpty.ptty.OutputReader().Read(buffer)
1111-
if err != nil {
1112-
// When the PTY is closed, this is triggered.
1113-
// Error is typically a benign EOF, so only log for debugging.
1114-
if errors.Is(err, io.EOF) {
1115-
logger.Debug(ctx, "unable to read pty output, command might have exited", slog.Error(err))
1116-
} else {
1117-
logger.Warn(ctx, "unable to read pty output, command might have exited", slog.Error(err))
1118-
a.metrics.reconnectingPTYErrors.WithLabelValues("output_reader").Add(1)
1119-
}
1120-
break
1121-
}
1122-
part := buffer[:read]
1123-
rpty.circularBufferMutex.Lock()
1124-
_, err = rpty.circularBuffer.Write(part)
1125-
rpty.circularBufferMutex.Unlock()
1126-
if err != nil {
1127-
logger.Error(ctx, "write to circular buffer", slog.Error(err))
1128-
break
1129-
}
1130-
rpty.activeConnsMutex.Lock()
1131-
for cid, conn := range rpty.activeConns {
1132-
_, err = conn.Write(part)
1133-
if err != nil {
1134-
logger.Warn(ctx,
1135-
"error writing to active conn",
1136-
slog.F("other_conn_id", cid),
1137-
slog.Error(err),
1138-
)
1139-
a.metrics.reconnectingPTYErrors.WithLabelValues("write").Add(1)
1140-
}
1141-
}
1142-
rpty.activeConnsMutex.Unlock()
1143-
}
1144-
1145-
// Clean up the process, PTY, and delete its
1146-
// ID from memory.
1147-
_ = process.Kill()
1148-
rpty.Close()
1094+
rpty.Wait()
11491095
a.reconnectingPTYs.Delete(msg.ID)
11501096
}); err != nil {
1151-
_ = process.Kill()
1152-
_ = ptty.Close()
1097+
rpty.Close(err.Error())
11531098
return xerrors.Errorf("start routine: %w", err)
11541099
}
1100+
11551101
connected = true
11561102
sendConnected <- rpty
11571103
}
1158-
// Resize the PTY to initial height + width.
1159-
err := rpty.ptty.Resize(msg.Height, msg.Width)
1160-
if err != nil {
1161-
// We can continue after this, it's not fatal!
1162-
logger.Error(ctx, "reconnecting PTY initial resize failed, but will continue", slog.Error(err))
1163-
a.metrics.reconnectingPTYErrors.WithLabelValues("resize").Add(1)
1164-
}
1165-
// Write any previously stored data for the TTY.
1166-
rpty.circularBufferMutex.RLock()
1167-
prevBuf := slices.Clone(rpty.circularBuffer.Bytes())
1168-
rpty.circularBufferMutex.RUnlock()
1169-
// Note that there is a small race here between writing buffered
1170-
// data and storing conn in activeConns. This is likely a very minor
1171-
// edge case, but we should look into ways to avoid it. Holding
1172-
// activeConnsMutex would be one option, but holding this mutex
1173-
// while also holding circularBufferMutex seems dangerous.
1174-
_, err = conn.Write(prevBuf)
1175-
if err != nil {
1176-
a.metrics.reconnectingPTYErrors.WithLabelValues("write").Add(1)
1177-
return xerrors.Errorf("write buffer to conn: %w", err)
1178-
}
1179-
// Multiple connections to the same TTY are permitted.
1180-
// This could easily be used for terminal sharing, but
1181-
// we do it because it's a nice user experience to
1182-
// copy/paste a terminal URL and have it _just work_.
1183-
rpty.activeConnsMutex.Lock()
1184-
rpty.activeConns[connectionID] = conn
1185-
rpty.activeConnsMutex.Unlock()
1186-
// Resetting this timeout prevents the PTY from exiting.
1187-
rpty.timeout.Reset(a.reconnectingPTYTimeout)
1188-
1189-
ctx, cancelFunc := context.WithCancel(ctx)
1190-
defer cancelFunc()
1191-
heartbeat := time.NewTicker(a.reconnectingPTYTimeout / 2)
1192-
defer heartbeat.Stop()
1193-
go func() {
1194-
// Keep updating the activity while this
1195-
// connection is alive!
1196-
for {
1197-
select {
1198-
case <-ctx.Done():
1199-
return
1200-
case <-heartbeat.C:
1201-
}
1202-
rpty.timeout.Reset(a.reconnectingPTYTimeout)
1203-
}
1204-
}()
1205-
defer func() {
1206-
// After this connection ends, remove it from
1207-
// the PTYs active connections. If it isn't
1208-
// removed, all PTY data will be sent to it.
1209-
rpty.activeConnsMutex.Lock()
1210-
delete(rpty.activeConns, connectionID)
1211-
rpty.activeConnsMutex.Unlock()
1212-
}()
1213-
decoder := json.NewDecoder(conn)
1214-
var req codersdk.ReconnectingPTYRequest
1215-
for {
1216-
err = decoder.Decode(&req)
1217-
if xerrors.Is(err, io.EOF) {
1218-
return nil
1219-
}
1220-
if err != nil {
1221-
logger.Warn(ctx, "reconnecting PTY failed with read error", slog.Error(err))
1222-
return nil
1223-
}
1224-
_, err = rpty.ptty.InputWriter().Write([]byte(req.Data))
1225-
if err != nil {
1226-
logger.Warn(ctx, "reconnecting PTY failed with write error", slog.Error(err))
1227-
a.metrics.reconnectingPTYErrors.WithLabelValues("input_writer").Add(1)
1228-
return nil
1229-
}
1230-
// Check if a resize needs to happen!
1231-
if req.Height == 0 || req.Width == 0 {
1232-
continue
1233-
}
1234-
err = rpty.ptty.Resize(req.Height, req.Width)
1235-
if err != nil {
1236-
// We can continue after this, it's not fatal!
1237-
logger.Error(ctx, "reconnecting PTY resize failed, but will continue", slog.Error(err))
1238-
a.metrics.reconnectingPTYErrors.WithLabelValues("resize").Add(1)
1239-
}
1240-
}
1104+
return rpty.Attach(ctx, connectionID, conn, msg.Height, msg.Width)
12411105
}
12421106

12431107
// startReportingConnectionStats runs the connection stats reporting goroutine.
@@ -1455,31 +1319,6 @@ lifecycleWaitLoop:
14551319
return nil
14561320
}
14571321

1458-
type reconnectingPTY struct {
1459-
activeConnsMutex sync.Mutex
1460-
activeConns map[string]net.Conn
1461-
1462-
circularBuffer *circbuf.Buffer
1463-
circularBufferMutex sync.RWMutex
1464-
timeout *time.Timer
1465-
ptty pty.PTYCmd
1466-
}
1467-
1468-
// Close ends all connections to the reconnecting
1469-
// PTY and clears the circular buffer.
1470-
func (r *reconnectingPTY) Close() {
1471-
r.activeConnsMutex.Lock()
1472-
defer r.activeConnsMutex.Unlock()
1473-
for _, conn := range r.activeConns {
1474-
_ = conn.Close()
1475-
}
1476-
_ = r.ptty.Close()
1477-
r.circularBufferMutex.Lock()
1478-
r.circularBuffer.Reset()
1479-
r.circularBufferMutex.Unlock()
1480-
r.timeout.Stop()
1481-
}
1482-
14831322
// userHomeDir returns the home directory of the current user, giving
14841323
// priority to the $HOME environment variable.
14851324
func userHomeDir() (string, error) {

agent/agent_test.go

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"net/http/httptest"
1313
"net/netip"
1414
"os"
15+
"os/exec"
1516
"os/user"
1617
"path"
1718
"path/filepath"
@@ -102,7 +103,7 @@ func TestAgent_Stats_ReconnectingPTY(t *testing.T) {
102103
//nolint:dogsled
103104
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
104105

105-
ptyConn, err := conn.ReconnectingPTY(ctx, uuid.New(), 128, 128, "/bin/bash")
106+
ptyConn, err := conn.ReconnectingPTY(ctx, uuid.New(), 128, 128, "bash", codersdk.ReconnectingPTYBackendTypeBuffered)
106107
require.NoError(t, err)
107108
defer ptyConn.Close()
108109

@@ -1596,17 +1597,39 @@ func TestAgent_ReconnectingPTY(t *testing.T) {
15961597
t.Skip("ConPTY appears to be inconsistent on Windows.")
15971598
}
15981599

1600+
t.Run("Buffered", func(t *testing.T) {
1601+
t.Parallel()
1602+
testReconnectingPTY(t, codersdk.ReconnectingPTYBackendTypeBuffered)
1603+
})
1604+
1605+
t.Run("Screen", func(t *testing.T) {
1606+
t.Parallel()
1607+
_, err := exec.LookPath("screen")
1608+
if err != nil {
1609+
t.Skip("`screen` not found; skipping related tests")
1610+
}
1611+
testReconnectingPTY(t, codersdk.ReconnectingPTYBackendTypeScreen)
1612+
})
1613+
}
1614+
1615+
func testReconnectingPTY(t *testing.T, backendType codersdk.ReconnectingPTYBackendType) {
15991616
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
16001617
defer cancel()
16011618

16021619
//nolint:dogsled
16031620
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
16041621
id := uuid.New()
1605-
netConn, err := conn.ReconnectingPTY(ctx, id, 100, 100, "/bin/bash")
1622+
netConn1, err := conn.ReconnectingPTY(ctx, id, 100, 100, "bash", backendType)
16061623
require.NoError(t, err)
1607-
defer netConn.Close()
1624+
defer netConn1.Close()
16081625

1609-
bufRead := bufio.NewReader(netConn)
1626+
bufRead1 := bufio.NewReader(netConn1)
1627+
1628+
// A second simultaneous connection.
1629+
// netConn2, err := conn.ReconnectingPTY(ctx, id, 100, 100, "bash", backendType)
1630+
// require.NoError(t, err)
1631+
// defer netConn2.Close()
1632+
// bufRead2 := bufio.NewReader(netConn2)
16101633

16111634
// Brief pause to reduce the likelihood that we send keystrokes while
16121635
// the shell is simultaneously sending a prompt.
@@ -1616,10 +1639,10 @@ func TestAgent_ReconnectingPTY(t *testing.T) {
16161639
Data: "echo test\r\n",
16171640
})
16181641
require.NoError(t, err)
1619-
_, err = netConn.Write(data)
1642+
_, err = netConn1.Write(data)
16201643
require.NoError(t, err)
16211644

1622-
expectLine := func(matcher func(string) bool) {
1645+
expectLine := func(bufRead *bufio.Reader, matcher func(string) bool) {
16231646
for {
16241647
line, err := bufRead.ReadString('\n')
16251648
require.NoError(t, err)
@@ -1637,20 +1660,25 @@ func TestAgent_ReconnectingPTY(t *testing.T) {
16371660
}
16381661

16391662
// Once for typing the command...
1640-
expectLine(matchEchoCommand)
1663+
expectLine(bufRead1, matchEchoCommand)
16411664
// And another time for the actual output.
1642-
expectLine(matchEchoOutput)
1665+
expectLine(bufRead1, matchEchoOutput)
16431666

1644-
_ = netConn.Close()
1645-
netConn, err = conn.ReconnectingPTY(ctx, id, 100, 100, "/bin/bash")
1646-
require.NoError(t, err)
1647-
defer netConn.Close()
1667+
// // Same for the other connection.
1668+
// expectLine(bufRead2, matchEchoCommand)
1669+
// expectLine(bufRead2, matchEchoOutput)
1670+
1671+
// _ = netConn1.Close()
1672+
// _ = netConn2.Close()
1673+
// netConn3, err := conn.ReconnectingPTY(ctx, id, 100, 100, "bash", backendType)
1674+
// require.NoError(t, err)
1675+
// defer netConn3.Close()
16481676

1649-
bufRead = bufio.NewReader(netConn)
1677+
// bufRead3 := bufio.NewReader(netConn3)
16501678

1651-
// Same output again!
1652-
expectLine(matchEchoCommand)
1653-
expectLine(matchEchoOutput)
1679+
// // Same output again!
1680+
// expectLine(bufRead3, matchEchoCommand)
1681+
// expectLine(bufRead3, matchEchoOutput)
16541682
}
16551683

16561684
func TestAgent_Dial(t *testing.T) {

0 commit comments

Comments
 (0)