Skip to content

Commit 9ae2251

Browse files
committed
Add screen backend for reconnecting ptys
The screen portion is a port from wsep. There is an interface that lets you choose between screen and the previous method. By default it will choose screen if it is installed but this can be overidden (mostly for tests). The tests use a scanner instead of a reader now because the reader will loop infinitely at the end of a stream. Relpace /bin/bash with bash since bash is not always in /bin.
1 parent 9ffbdc6 commit 9ae2251

File tree

20 files changed

+1103
-398
lines changed

20 files changed

+1103
-398
lines changed

agent/agent.go

Lines changed: 19 additions & 191 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import (
2121
"sync"
2222
"time"
2323

24-
"github.com/armon/circbuf"
2524
"github.com/google/uuid"
2625
"github.com/prometheus/client_golang/prometheus"
2726
"github.com/spf13/afero"
@@ -35,12 +34,12 @@ import (
3534

3635
"cdr.dev/slog"
3736
"github.com/coder/coder/agent/agentssh"
37+
"github.com/coder/coder/agent/reconnectingpty"
3838
"github.com/coder/coder/buildinfo"
3939
"github.com/coder/coder/coderd/database"
4040
"github.com/coder/coder/coderd/gitauth"
4141
"github.com/coder/coder/codersdk"
4242
"github.com/coder/coder/codersdk/agentsdk"
43-
"github.com/coder/coder/pty"
4443
"github.com/coder/coder/tailnet"
4544
"github.com/coder/retry"
4645
)
@@ -89,9 +88,6 @@ type Agent interface {
8988
}
9089

9190
func New(options Options) Agent {
92-
if options.ReconnectingPTYTimeout == 0 {
93-
options.ReconnectingPTYTimeout = 5 * time.Minute
94-
}
9591
if options.Filesystem == nil {
9692
options.Filesystem = afero.NewOsFs()
9793
}
@@ -1078,22 +1074,22 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
10781074
// If the agent is closed, we don't want to
10791075
// log this as an error since it's expected.
10801076
if closed {
1081-
logger.Debug(ctx, "reconnecting PTY failed with session error (agent closed)", slog.Error(err))
1077+
logger.Debug(ctx, "reconnecting pty failed with attach error (agent closed)", slog.Error(err))
10821078
} else {
1083-
logger.Error(ctx, "reconnecting PTY failed with session error", slog.Error(err))
1079+
logger.Error(ctx, "reconnecting pty failed with attach error", slog.Error(err))
10841080
}
10851081
}
1086-
logger.Debug(ctx, "session closed")
1082+
logger.Debug(ctx, "reconnecting pty connection closed")
10871083
}()
10881084

1089-
var rpty *reconnectingPTY
1090-
sendConnected := make(chan *reconnectingPTY, 1)
1085+
var rpty *reconnectingpty.ReconnectingPTY
1086+
sendConnected := make(chan *reconnectingpty.ReconnectingPTY, 1)
10911087
// On store, reserve this ID to prevent multiple concurrent new connections.
10921088
waitReady, ok := a.reconnectingPTYs.LoadOrStore(msg.ID, sendConnected)
10931089
if ok {
10941090
close(sendConnected) // Unused.
1095-
logger.Debug(ctx, "connecting to existing session")
1096-
c, ok := waitReady.(chan *reconnectingPTY)
1091+
logger.Debug(ctx, "connecting to existing reconnecting pty")
1092+
c, ok := waitReady.(chan *reconnectingpty.ReconnectingPTY)
10971093
if !ok {
10981094
return xerrors.Errorf("found invalid type in reconnecting pty map: %T", waitReady)
10991095
}
@@ -1103,7 +1099,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
11031099
}
11041100
c <- rpty // Put it back for the next reconnect.
11051101
} else {
1106-
logger.Debug(ctx, "creating new session")
1102+
logger.Debug(ctx, "creating new reconnecting pty", slog.F("backend", msg.BackendType))
11071103

11081104
connected := false
11091105
defer func() {
@@ -1119,169 +1115,26 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
11191115
a.metrics.reconnectingPTYErrors.WithLabelValues("create_command").Add(1)
11201116
return xerrors.Errorf("create command: %w", err)
11211117
}
1122-
cmd.Env = append(cmd.Env, "TERM=xterm-256color")
1123-
1124-
// Default to buffer 64KiB.
1125-
circularBuffer, err := circbuf.NewBuffer(64 << 10)
1126-
if err != nil {
1127-
return xerrors.Errorf("create circular buffer: %w", err)
1128-
}
11291118

1130-
ptty, process, err := pty.Start(cmd)
1131-
if err != nil {
1132-
a.metrics.reconnectingPTYErrors.WithLabelValues("start_command").Add(1)
1133-
return xerrors.Errorf("start command: %w", err)
1134-
}
1119+
rpty = reconnectingpty.New(ctx, cmd, &reconnectingpty.Options{
1120+
BackendType: msg.BackendType,
1121+
Timeout: a.reconnectingPTYTimeout,
1122+
Metrics: a.metrics.reconnectingPTYErrors,
1123+
Logger: logger,
1124+
})
11351125

1136-
ctx, cancel := context.WithCancel(ctx)
1137-
rpty = &reconnectingPTY{
1138-
activeConns: map[string]net.Conn{
1139-
// We have to put the connection in the map instantly otherwise
1140-
// the connection won't be closed if the process instantly dies.
1141-
connectionID: conn,
1142-
},
1143-
ptty: ptty,
1144-
// Timeouts created with an after func can be reset!
1145-
timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancel),
1146-
circularBuffer: circularBuffer,
1147-
}
1148-
// We don't need to separately monitor for the process exiting.
1149-
// When it exits, our ptty.OutputReader() will return EOF after
1150-
// reading all process output.
11511126
if err = a.trackConnGoroutine(func() {
1152-
buffer := make([]byte, 1024)
1153-
for {
1154-
read, err := rpty.ptty.OutputReader().Read(buffer)
1155-
if err != nil {
1156-
// When the PTY is closed, this is triggered.
1157-
// Error is typically a benign EOF, so only log for debugging.
1158-
if errors.Is(err, io.EOF) {
1159-
logger.Debug(ctx, "unable to read pty output, command might have exited", slog.Error(err))
1160-
} else {
1161-
logger.Warn(ctx, "unable to read pty output, command might have exited", slog.Error(err))
1162-
a.metrics.reconnectingPTYErrors.WithLabelValues("output_reader").Add(1)
1163-
}
1164-
break
1165-
}
1166-
part := buffer[:read]
1167-
rpty.circularBufferMutex.Lock()
1168-
_, err = rpty.circularBuffer.Write(part)
1169-
rpty.circularBufferMutex.Unlock()
1170-
if err != nil {
1171-
logger.Error(ctx, "write to circular buffer", slog.Error(err))
1172-
break
1173-
}
1174-
rpty.activeConnsMutex.Lock()
1175-
for cid, conn := range rpty.activeConns {
1176-
_, err = conn.Write(part)
1177-
if err != nil {
1178-
logger.Warn(ctx,
1179-
"error writing to active conn",
1180-
slog.F("other_conn_id", cid),
1181-
slog.Error(err),
1182-
)
1183-
a.metrics.reconnectingPTYErrors.WithLabelValues("write").Add(1)
1184-
}
1185-
}
1186-
rpty.activeConnsMutex.Unlock()
1187-
}
1188-
1189-
// Cleanup the process, PTY, and delete it's
1190-
// ID from memory.
1191-
_ = process.Kill()
1192-
rpty.Close()
1127+
rpty.Wait()
11931128
a.reconnectingPTYs.Delete(msg.ID)
11941129
}); err != nil {
1195-
_ = process.Kill()
1196-
_ = ptty.Close()
1130+
rpty.Close(err.Error())
11971131
return xerrors.Errorf("start routine: %w", err)
11981132
}
1133+
11991134
connected = true
12001135
sendConnected <- rpty
12011136
}
1202-
// Resize the PTY to initial height + width.
1203-
err := rpty.ptty.Resize(msg.Height, msg.Width)
1204-
if err != nil {
1205-
// We can continue after this, it's not fatal!
1206-
logger.Error(ctx, "reconnecting PTY initial resize failed, but will continue", slog.Error(err))
1207-
a.metrics.reconnectingPTYErrors.WithLabelValues("resize").Add(1)
1208-
}
1209-
// Write any previously stored data for the TTY.
1210-
rpty.circularBufferMutex.RLock()
1211-
prevBuf := slices.Clone(rpty.circularBuffer.Bytes())
1212-
rpty.circularBufferMutex.RUnlock()
1213-
// Note that there is a small race here between writing buffered
1214-
// data and storing conn in activeConns. This is likely a very minor
1215-
// edge case, but we should look into ways to avoid it. Holding
1216-
// activeConnsMutex would be one option, but holding this mutex
1217-
// while also holding circularBufferMutex seems dangerous.
1218-
_, err = conn.Write(prevBuf)
1219-
if err != nil {
1220-
a.metrics.reconnectingPTYErrors.WithLabelValues("write").Add(1)
1221-
return xerrors.Errorf("write buffer to conn: %w", err)
1222-
}
1223-
// Multiple connections to the same TTY are permitted.
1224-
// This could easily be used for terminal sharing, but
1225-
// we do it because it's a nice user experience to
1226-
// copy/paste a terminal URL and have it _just work_.
1227-
rpty.activeConnsMutex.Lock()
1228-
rpty.activeConns[connectionID] = conn
1229-
rpty.activeConnsMutex.Unlock()
1230-
// Resetting this timeout prevents the PTY from exiting.
1231-
rpty.timeout.Reset(a.reconnectingPTYTimeout)
1232-
1233-
ctx, cancelFunc := context.WithCancel(ctx)
1234-
defer cancelFunc()
1235-
heartbeat := time.NewTicker(a.reconnectingPTYTimeout / 2)
1236-
defer heartbeat.Stop()
1237-
go func() {
1238-
// Keep updating the activity while this
1239-
// connection is alive!
1240-
for {
1241-
select {
1242-
case <-ctx.Done():
1243-
return
1244-
case <-heartbeat.C:
1245-
}
1246-
rpty.timeout.Reset(a.reconnectingPTYTimeout)
1247-
}
1248-
}()
1249-
defer func() {
1250-
// After this connection ends, remove it from
1251-
// the PTYs active connections. If it isn't
1252-
// removed, all PTY data will be sent to it.
1253-
rpty.activeConnsMutex.Lock()
1254-
delete(rpty.activeConns, connectionID)
1255-
rpty.activeConnsMutex.Unlock()
1256-
}()
1257-
decoder := json.NewDecoder(conn)
1258-
var req codersdk.ReconnectingPTYRequest
1259-
for {
1260-
err = decoder.Decode(&req)
1261-
if xerrors.Is(err, io.EOF) {
1262-
return nil
1263-
}
1264-
if err != nil {
1265-
logger.Warn(ctx, "reconnecting PTY failed with read error", slog.Error(err))
1266-
return nil
1267-
}
1268-
_, err = rpty.ptty.InputWriter().Write([]byte(req.Data))
1269-
if err != nil {
1270-
logger.Warn(ctx, "reconnecting PTY failed with write error", slog.Error(err))
1271-
a.metrics.reconnectingPTYErrors.WithLabelValues("input_writer").Add(1)
1272-
return nil
1273-
}
1274-
// Check if a resize needs to happen!
1275-
if req.Height == 0 || req.Width == 0 {
1276-
continue
1277-
}
1278-
err = rpty.ptty.Resize(req.Height, req.Width)
1279-
if err != nil {
1280-
// We can continue after this, it's not fatal!
1281-
logger.Error(ctx, "reconnecting PTY resize failed, but will continue", slog.Error(err))
1282-
a.metrics.reconnectingPTYErrors.WithLabelValues("resize").Add(1)
1283-
}
1284-
}
1137+
return rpty.Attach(ctx, connectionID, conn, msg.Height, msg.Width)
12851138
}
12861139

12871140
// startReportingConnectionStats runs the connection stats reporting goroutine.
@@ -1499,31 +1352,6 @@ lifecycleWaitLoop:
14991352
return nil
15001353
}
15011354

1502-
type reconnectingPTY struct {
1503-
activeConnsMutex sync.Mutex
1504-
activeConns map[string]net.Conn
1505-
1506-
circularBuffer *circbuf.Buffer
1507-
circularBufferMutex sync.RWMutex
1508-
timeout *time.Timer
1509-
ptty pty.PTYCmd
1510-
}
1511-
1512-
// Close ends all connections to the reconnecting
1513-
// PTY and clear the circular buffer.
1514-
func (r *reconnectingPTY) Close() {
1515-
r.activeConnsMutex.Lock()
1516-
defer r.activeConnsMutex.Unlock()
1517-
for _, conn := range r.activeConns {
1518-
_ = conn.Close()
1519-
}
1520-
_ = r.ptty.Close()
1521-
r.circularBufferMutex.Lock()
1522-
r.circularBuffer.Reset()
1523-
r.circularBufferMutex.Unlock()
1524-
r.timeout.Stop()
1525-
}
1526-
15271355
// userHomeDir returns the home directory of the current user, giving
15281356
// priority to the $HOME environment variable.
15291357
func userHomeDir() (string, error) {

0 commit comments

Comments
 (0)