Skip to content

Commit b40c93b

Browse files
committed
Merge branch 'main' into v2
2 parents 5771fb1 + 37f9d4b commit b40c93b

File tree

16 files changed

+1174
-327
lines changed

16 files changed

+1174
-327
lines changed

agent/agent.go

Lines changed: 18 additions & 193 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import (
2121
"sync"
2222
"time"
2323

24-
"github.com/armon/circbuf"
2524
"github.com/go-chi/chi/v5"
2625
"github.com/google/uuid"
2726
"github.com/prometheus/client_golang/prometheus"
@@ -41,7 +40,6 @@ import (
4140
"github.com/coder/coder/v2/coderd/gitauth"
4241
"github.com/coder/coder/v2/codersdk"
4342
"github.com/coder/coder/v2/codersdk/agentsdk"
44-
"github.com/coder/coder/v2/pty"
4543
"github.com/coder/coder/v2/tailnet"
4644
"github.com/coder/retry"
4745
)
@@ -92,9 +90,6 @@ type Agent interface {
9290
}
9391

9492
func New(options Options) Agent {
95-
if options.ReconnectingPTYTimeout == 0 {
96-
options.ReconnectingPTYTimeout = 5 * time.Minute
97-
}
9893
if options.Filesystem == nil {
9994
options.Filesystem = afero.NewOsFs()
10095
}
@@ -1075,8 +1070,8 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
10751070
defer a.connCountReconnectingPTY.Add(-1)
10761071

10771072
connectionID := uuid.NewString()
1078-
logger = logger.With(slog.F("message_id", msg.ID), slog.F("connection_id", connectionID))
1079-
logger.Debug(ctx, "starting handler")
1073+
connLogger := logger.With(slog.F("message_id", msg.ID), slog.F("connection_id", connectionID))
1074+
connLogger.Debug(ctx, "starting handler")
10801075

10811076
defer func() {
10821077
if err := retErr; err != nil {
@@ -1087,22 +1082,22 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
10871082
// If the agent is closed, we don't want to
10881083
// log this as an error since it's expected.
10891084
if closed {
1090-
logger.Debug(ctx, "reconnecting PTY failed with session error (agent closed)", slog.Error(err))
1085+
connLogger.Debug(ctx, "reconnecting pty failed with attach error (agent closed)", slog.Error(err))
10911086
} else {
1092-
logger.Error(ctx, "reconnecting PTY failed with session error", slog.Error(err))
1087+
connLogger.Error(ctx, "reconnecting pty failed with attach error", slog.Error(err))
10931088
}
10941089
}
1095-
logger.Debug(ctx, "session closed")
1090+
connLogger.Debug(ctx, "reconnecting pty connection closed")
10961091
}()
10971092

1098-
var rpty *reconnectingPTY
1099-
sendConnected := make(chan *reconnectingPTY, 1)
1093+
var rpty reconnectingpty.ReconnectingPTY
1094+
sendConnected := make(chan reconnectingpty.ReconnectingPTY, 1)
11001095
// On store, reserve this ID to prevent multiple concurrent new connections.
11011096
waitReady, ok := a.reconnectingPTYs.LoadOrStore(msg.ID, sendConnected)
11021097
if ok {
11031098
close(sendConnected) // Unused.
1104-
logger.Debug(ctx, "connecting to existing session")
1105-
c, ok := waitReady.(chan *reconnectingPTY)
1099+
connLogger.Debug(ctx, "connecting to existing reconnecting pty")
1100+
c, ok := waitReady.(chan reconnectingpty.ReconnectingPTY)
11061101
if !ok {
11071102
return xerrors.Errorf("found invalid type in reconnecting pty map: %T", waitReady)
11081103
}
@@ -1112,7 +1107,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
11121107
}
11131108
c <- rpty // Put it back for the next reconnect.
11141109
} else {
1115-
logger.Debug(ctx, "creating new session")
1110+
connLogger.Debug(ctx, "creating new reconnecting pty")
11161111

11171112
connected := false
11181113
defer func() {
@@ -1128,169 +1123,24 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
11281123
a.metrics.reconnectingPTYErrors.WithLabelValues("create_command").Add(1)
11291124
return xerrors.Errorf("create command: %w", err)
11301125
}
1131-
cmd.Env = append(cmd.Env, "TERM=xterm-256color")
1132-
1133-
// Default to buffer 64KiB.
1134-
circularBuffer, err := circbuf.NewBuffer(64 << 10)
1135-
if err != nil {
1136-
return xerrors.Errorf("create circular buffer: %w", err)
1137-
}
11381126

1139-
ptty, process, err := pty.Start(cmd)
1140-
if err != nil {
1141-
a.metrics.reconnectingPTYErrors.WithLabelValues("start_command").Add(1)
1142-
return xerrors.Errorf("start command: %w", err)
1143-
}
1127+
rpty = reconnectingpty.New(ctx, cmd, &reconnectingpty.Options{
1128+
Timeout: a.reconnectingPTYTimeout,
1129+
Metrics: a.metrics.reconnectingPTYErrors,
1130+
}, logger.With(slog.F("message_id", msg.ID)))
11441131

1145-
ctx, cancel := context.WithCancel(ctx)
1146-
rpty = &reconnectingPTY{
1147-
activeConns: map[string]net.Conn{
1148-
// We have to put the connection in the map instantly otherwise
1149-
// the connection won't be closed if the process instantly dies.
1150-
connectionID: conn,
1151-
},
1152-
ptty: ptty,
1153-
// Timeouts created with an after func can be reset!
1154-
timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancel),
1155-
circularBuffer: circularBuffer,
1156-
}
1157-
// We don't need to separately monitor for the process exiting.
1158-
// When it exits, our ptty.OutputReader() will return EOF after
1159-
// reading all process output.
11601132
if err = a.trackConnGoroutine(func() {
1161-
buffer := make([]byte, 1024)
1162-
for {
1163-
read, err := rpty.ptty.OutputReader().Read(buffer)
1164-
if err != nil {
1165-
// When the PTY is closed, this is triggered.
1166-
// Error is typically a benign EOF, so only log for debugging.
1167-
if errors.Is(err, io.EOF) {
1168-
logger.Debug(ctx, "unable to read pty output, command might have exited", slog.Error(err))
1169-
} else {
1170-
logger.Warn(ctx, "unable to read pty output, command might have exited", slog.Error(err))
1171-
a.metrics.reconnectingPTYErrors.WithLabelValues("output_reader").Add(1)
1172-
}
1173-
break
1174-
}
1175-
part := buffer[:read]
1176-
rpty.circularBufferMutex.Lock()
1177-
_, err = rpty.circularBuffer.Write(part)
1178-
rpty.circularBufferMutex.Unlock()
1179-
if err != nil {
1180-
logger.Error(ctx, "write to circular buffer", slog.Error(err))
1181-
break
1182-
}
1183-
rpty.activeConnsMutex.Lock()
1184-
for cid, conn := range rpty.activeConns {
1185-
_, err = conn.Write(part)
1186-
if err != nil {
1187-
logger.Warn(ctx,
1188-
"error writing to active conn",
1189-
slog.F("other_conn_id", cid),
1190-
slog.Error(err),
1191-
)
1192-
a.metrics.reconnectingPTYErrors.WithLabelValues("write").Add(1)
1193-
}
1194-
}
1195-
rpty.activeConnsMutex.Unlock()
1196-
}
1197-
1198-
// Clean up the process, PTY, and delete its
1199-
// ID from memory.
1200-
_ = process.Kill()
1201-
rpty.Close()
1133+
rpty.Wait()
12021134
a.reconnectingPTYs.Delete(msg.ID)
12031135
}); err != nil {
1204-
_ = process.Kill()
1205-
_ = ptty.Close()
1136+
rpty.Close(err.Error())
12061137
return xerrors.Errorf("start routine: %w", err)
12071138
}
1139+
12081140
connected = true
12091141
sendConnected <- rpty
12101142
}
1211-
// Resize the PTY to initial height + width.
1212-
err := rpty.ptty.Resize(msg.Height, msg.Width)
1213-
if err != nil {
1214-
// We can continue after this, it's not fatal!
1215-
logger.Error(ctx, "reconnecting PTY initial resize failed, but will continue", slog.Error(err))
1216-
a.metrics.reconnectingPTYErrors.WithLabelValues("resize").Add(1)
1217-
}
1218-
// Write any previously stored data for the TTY.
1219-
rpty.circularBufferMutex.RLock()
1220-
prevBuf := slices.Clone(rpty.circularBuffer.Bytes())
1221-
rpty.circularBufferMutex.RUnlock()
1222-
// Note that there is a small race here between writing buffered
1223-
// data and storing conn in activeConns. This is likely a very minor
1224-
// edge case, but we should look into ways to avoid it. Holding
1225-
// activeConnsMutex would be one option, but holding this mutex
1226-
// while also holding circularBufferMutex seems dangerous.
1227-
_, err = conn.Write(prevBuf)
1228-
if err != nil {
1229-
a.metrics.reconnectingPTYErrors.WithLabelValues("write").Add(1)
1230-
return xerrors.Errorf("write buffer to conn: %w", err)
1231-
}
1232-
// Multiple connections to the same TTY are permitted.
1233-
// This could easily be used for terminal sharing, but
1234-
// we do it because it's a nice user experience to
1235-
// copy/paste a terminal URL and have it _just work_.
1236-
rpty.activeConnsMutex.Lock()
1237-
rpty.activeConns[connectionID] = conn
1238-
rpty.activeConnsMutex.Unlock()
1239-
// Resetting this timeout prevents the PTY from exiting.
1240-
rpty.timeout.Reset(a.reconnectingPTYTimeout)
1241-
1242-
ctx, cancelFunc := context.WithCancel(ctx)
1243-
defer cancelFunc()
1244-
heartbeat := time.NewTicker(a.reconnectingPTYTimeout / 2)
1245-
defer heartbeat.Stop()
1246-
go func() {
1247-
// Keep updating the activity while this
1248-
// connection is alive!
1249-
for {
1250-
select {
1251-
case <-ctx.Done():
1252-
return
1253-
case <-heartbeat.C:
1254-
}
1255-
rpty.timeout.Reset(a.reconnectingPTYTimeout)
1256-
}
1257-
}()
1258-
defer func() {
1259-
// After this connection ends, remove it from
1260-
// the PTYs active connections. If it isn't
1261-
// removed, all PTY data will be sent to it.
1262-
rpty.activeConnsMutex.Lock()
1263-
delete(rpty.activeConns, connectionID)
1264-
rpty.activeConnsMutex.Unlock()
1265-
}()
1266-
decoder := json.NewDecoder(conn)
1267-
var req codersdk.ReconnectingPTYRequest
1268-
for {
1269-
err = decoder.Decode(&req)
1270-
if xerrors.Is(err, io.EOF) {
1271-
return nil
1272-
}
1273-
if err != nil {
1274-
logger.Warn(ctx, "reconnecting PTY failed with read error", slog.Error(err))
1275-
return nil
1276-
}
1277-
_, err = rpty.ptty.InputWriter().Write([]byte(req.Data))
1278-
if err != nil {
1279-
logger.Warn(ctx, "reconnecting PTY failed with write error", slog.Error(err))
1280-
a.metrics.reconnectingPTYErrors.WithLabelValues("input_writer").Add(1)
1281-
return nil
1282-
}
1283-
// Check if a resize needs to happen!
1284-
if req.Height == 0 || req.Width == 0 {
1285-
continue
1286-
}
1287-
err = rpty.ptty.Resize(req.Height, req.Width)
1288-
if err != nil {
1289-
// We can continue after this, it's not fatal!
1290-
logger.Error(ctx, "reconnecting PTY resize failed, but will continue", slog.Error(err))
1291-
a.metrics.reconnectingPTYErrors.WithLabelValues("resize").Add(1)
1292-
}
1293-
}
1143+
return rpty.Attach(ctx, connectionID, conn, msg.Height, msg.Width, connLogger)
12941144
}
12951145

12961146
// startReportingConnectionStats runs the connection stats reporting goroutine.
@@ -1541,31 +1391,6 @@ lifecycleWaitLoop:
15411391
return nil
15421392
}
15431393

1544-
type reconnectingPTY struct {
1545-
activeConnsMutex sync.Mutex
1546-
activeConns map[string]net.Conn
1547-
1548-
circularBuffer *circbuf.Buffer
1549-
circularBufferMutex sync.RWMutex
1550-
timeout *time.Timer
1551-
ptty pty.PTYCmd
1552-
}
1553-
1554-
// Close ends all connections to the reconnecting
1555-
// PTY and clears the circular buffer.
1556-
func (r *reconnectingPTY) Close() {
1557-
r.activeConnsMutex.Lock()
1558-
defer r.activeConnsMutex.Unlock()
1559-
for _, conn := range r.activeConns {
1560-
_ = conn.Close()
1561-
}
1562-
_ = r.ptty.Close()
1563-
r.circularBufferMutex.Lock()
1564-
r.circularBuffer.Reset()
1565-
r.circularBufferMutex.Unlock()
1566-
r.timeout.Stop()
1567-
}
1568-
15691394
// userHomeDir returns the home directory of the current user, giving
15701395
// priority to the $HOME environment variable.
15711396
func userHomeDir() (string, error) {

0 commit comments

Comments
 (0)