Skip to content

Commit d201025

Browse files
Merge branch 'main' into dm-experiment-autostart
2 parents 100f54c + 0896f33 commit d201025

File tree

76 files changed

+2914
-514
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

76 files changed

+2914
-514
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ yarn-error.log
1717
# Allow VSCode recommendations and default settings in project root.
1818
!/.vscode/extensions.json
1919
!/.vscode/settings.json
20+
# Allow code snippets
21+
!/.vscode/*.code-snippets
2022

2123
# Front-end ignore patterns.
2224
.next/

.prettierignore

+2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ yarn-error.log
2020
# Allow VSCode recommendations and default settings in project root.
2121
!/.vscode/extensions.json
2222
!/.vscode/settings.json
23+
# Allow code snippets
24+
!/.vscode/*.code-snippets
2325

2426
# Front-end ignore patterns.
2527
.next/

.vscode/markdown.code-snippets

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
{
2+
// For info about snippets, visit https://code.visualstudio.com/docs/editor/userdefinedsnippets
3+
4+
"admonition": {
5+
"prefix": "#callout",
6+
"body": [
7+
"<blockquote class=\"admonition ${1|caution,important,note,tip,warning|}\">\n",
8+
"${TM_SELECTED_TEXT:${2:add info here}}\n",
9+
"</blockquote>\n"
10+
],
11+
"description": "callout admonition caution info note tip warning"
12+
},
13+
"fenced code block": {
14+
"prefix": "#codeblock",
15+
"body": ["```${1|apache,bash,console,diff,Dockerfile,env,go,hcl,ini,json,lisp,md,powershell,shell,sql,text,tf,tsx,yaml|}", "${TM_SELECTED_TEXT}$0", "```"],
16+
"description": "fenced code block"
17+
},
18+
"image": {
19+
"prefix": "#image",
20+
"body": "![${TM_SELECTED_TEXT:${1:alt}}](${2:url})$0",
21+
"description": "image"
22+
},
23+
"tabs": {
24+
"prefix": "#tabs",
25+
"body": [
26+
"<div class=\"tabs\">\n",
27+
"${1:optional description}\n",
28+
"## ${2:tab title}\n",
29+
"${TM_SELECTED_TEXT:${3:first tab content}}\n",
30+
"## ${4:tab title}\n",
31+
"${5:second tab content}\n",
32+
"## ${6:tab title}\n",
33+
"${7:third tab content}\n",
34+
"</div>\n"
35+
],
36+
"description": "tabs"
37+
}
38+
}

agent/agent.go

+17-138
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,10 @@ package agent
33
import (
44
"bytes"
55
"context"
6-
"encoding/binary"
76
"encoding/json"
87
"errors"
98
"fmt"
109
"io"
11-
"net"
1210
"net/http"
1311
"net/netip"
1412
"os"
@@ -216,8 +214,8 @@ type agent struct {
216214
portCacheDuration time.Duration
217215
subsystems []codersdk.AgentSubsystem
218216

219-
reconnectingPTYs sync.Map
220217
reconnectingPTYTimeout time.Duration
218+
reconnectingPTYServer *reconnectingpty.Server
221219

222220
// we track 2 contexts and associated cancel functions: "graceful" which is Done when it is time
223221
// to start gracefully shutting down and "hard" which is Done when it is time to close
@@ -252,8 +250,6 @@ type agent struct {
252250
statsReporter *statsReporter
253251
logSender *agentsdk.LogSender
254252

255-
connCountReconnectingPTY atomic.Int64
256-
257253
prometheusRegistry *prometheus.Registry
258254
// metrics are prometheus registered metrics that will be collected and
259255
// labeled in Coder with the agent + workspace.
@@ -297,6 +293,13 @@ func (a *agent) init() {
297293
// Register runner metrics. If the prom registry is nil, the metrics
298294
// will not report anywhere.
299295
a.scriptRunner.RegisterMetrics(a.prometheusRegistry)
296+
297+
a.reconnectingPTYServer = reconnectingpty.NewServer(
298+
a.logger.Named("reconnecting-pty"),
299+
a.sshServer,
300+
a.metrics.connectionsTotal, a.metrics.reconnectingPTYErrors,
301+
a.reconnectingPTYTimeout,
302+
)
300303
go a.runLoop()
301304
}
302305

@@ -1181,55 +1184,12 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t
11811184
}
11821185
}()
11831186
if err = a.trackGoroutine(func() {
1184-
logger := a.logger.Named("reconnecting-pty")
1185-
var wg sync.WaitGroup
1186-
for {
1187-
conn, err := reconnectingPTYListener.Accept()
1188-
if err != nil {
1189-
if !a.isClosed() {
1190-
logger.Debug(ctx, "accept pty failed", slog.Error(err))
1191-
}
1192-
break
1193-
}
1194-
clog := logger.With(
1195-
slog.F("remote", conn.RemoteAddr().String()),
1196-
slog.F("local", conn.LocalAddr().String()))
1197-
clog.Info(ctx, "accepted conn")
1198-
wg.Add(1)
1199-
closed := make(chan struct{})
1200-
go func() {
1201-
select {
1202-
case <-closed:
1203-
case <-a.hardCtx.Done():
1204-
_ = conn.Close()
1205-
}
1206-
wg.Done()
1207-
}()
1208-
go func() {
1209-
defer close(closed)
1210-
// This cannot use a JSON decoder, since that can
1211-
// buffer additional data that is required for the PTY.
1212-
rawLen := make([]byte, 2)
1213-
_, err = conn.Read(rawLen)
1214-
if err != nil {
1215-
return
1216-
}
1217-
length := binary.LittleEndian.Uint16(rawLen)
1218-
data := make([]byte, length)
1219-
_, err = conn.Read(data)
1220-
if err != nil {
1221-
return
1222-
}
1223-
var msg workspacesdk.AgentReconnectingPTYInit
1224-
err = json.Unmarshal(data, &msg)
1225-
if err != nil {
1226-
logger.Warn(ctx, "failed to unmarshal init", slog.F("raw", data))
1227-
return
1228-
}
1229-
_ = a.handleReconnectingPTY(ctx, clog, msg, conn)
1230-
}()
1187+
// Run the reconnecting PTY server on the tailnet listener. A Serve error is
// only reported when it is not explained by shutdown: it is skipped if the
// graceful context is already done or the listener was closed.
rPTYServeErr := a.reconnectingPTYServer.Serve(a.gracefulCtx, a.hardCtx, reconnectingPTYListener)
1188+
if rPTYServeErr != nil &&
1189+
a.gracefulCtx.Err() == nil &&
1190+
!strings.Contains(rPTYServeErr.Error(), "use of closed network connection") {
1191+
// Log the error returned by Serve. The enclosing `err` belongs to the
// surrounding trackGoroutine call and must not be read from this goroutine.
a.logger.Error(ctx, "error serving reconnecting PTY", slog.Error(rPTYServeErr))
12311192
}
1232-
wg.Wait()
12331193
}); err != nil {
12341194
return nil, err
12351195
}
@@ -1308,9 +1268,9 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t
13081268
_ = server.Close()
13091269
}()
13101270

1311-
err := server.Serve(apiListener)
1312-
if err != nil && !xerrors.Is(err, http.ErrServerClosed) && !strings.Contains(err.Error(), "use of closed network connection") {
1313-
a.logger.Critical(ctx, "serve HTTP API server", slog.Error(err))
1271+
apiServErr := server.Serve(apiListener)
1272+
if apiServErr != nil && !xerrors.Is(apiServErr, http.ErrServerClosed) && !strings.Contains(apiServErr.Error(), "use of closed network connection") {
1273+
a.logger.Critical(ctx, "serve HTTP API server", slog.Error(apiServErr))
13141274
}
13151275
}); err != nil {
13161276
return nil, err
@@ -1394,87 +1354,6 @@ func (a *agent) runDERPMapSubscriber(ctx context.Context, tClient tailnetproto.D
13941354
}
13951355
}
13961356

1397-
func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, msg workspacesdk.AgentReconnectingPTYInit, conn net.Conn) (retErr error) {
1398-
defer conn.Close()
1399-
a.metrics.connectionsTotal.Add(1)
1400-
1401-
a.connCountReconnectingPTY.Add(1)
1402-
defer a.connCountReconnectingPTY.Add(-1)
1403-
1404-
connectionID := uuid.NewString()
1405-
connLogger := logger.With(slog.F("message_id", msg.ID), slog.F("connection_id", connectionID))
1406-
connLogger.Debug(ctx, "starting handler")
1407-
1408-
defer func() {
1409-
if err := retErr; err != nil {
1410-
a.closeMutex.Lock()
1411-
closed := a.isClosed()
1412-
a.closeMutex.Unlock()
1413-
1414-
// If the agent is closed, we don't want to
1415-
// log this as an error since it's expected.
1416-
if closed {
1417-
connLogger.Info(ctx, "reconnecting pty failed with attach error (agent closed)", slog.Error(err))
1418-
} else {
1419-
connLogger.Error(ctx, "reconnecting pty failed with attach error", slog.Error(err))
1420-
}
1421-
}
1422-
connLogger.Info(ctx, "reconnecting pty connection closed")
1423-
}()
1424-
1425-
var rpty reconnectingpty.ReconnectingPTY
1426-
sendConnected := make(chan reconnectingpty.ReconnectingPTY, 1)
1427-
// On store, reserve this ID to prevent multiple concurrent new connections.
1428-
waitReady, ok := a.reconnectingPTYs.LoadOrStore(msg.ID, sendConnected)
1429-
if ok {
1430-
close(sendConnected) // Unused.
1431-
connLogger.Debug(ctx, "connecting to existing reconnecting pty")
1432-
c, ok := waitReady.(chan reconnectingpty.ReconnectingPTY)
1433-
if !ok {
1434-
return xerrors.Errorf("found invalid type in reconnecting pty map: %T", waitReady)
1435-
}
1436-
rpty, ok = <-c
1437-
if !ok || rpty == nil {
1438-
return xerrors.Errorf("reconnecting pty closed before connection")
1439-
}
1440-
c <- rpty // Put it back for the next reconnect.
1441-
} else {
1442-
connLogger.Debug(ctx, "creating new reconnecting pty")
1443-
1444-
connected := false
1445-
defer func() {
1446-
if !connected && retErr != nil {
1447-
a.reconnectingPTYs.Delete(msg.ID)
1448-
close(sendConnected)
1449-
}
1450-
}()
1451-
1452-
// Empty command will default to the users shell!
1453-
cmd, err := a.sshServer.CreateCommand(ctx, msg.Command, nil)
1454-
if err != nil {
1455-
a.metrics.reconnectingPTYErrors.WithLabelValues("create_command").Add(1)
1456-
return xerrors.Errorf("create command: %w", err)
1457-
}
1458-
1459-
rpty = reconnectingpty.New(ctx, cmd, &reconnectingpty.Options{
1460-
Timeout: a.reconnectingPTYTimeout,
1461-
Metrics: a.metrics.reconnectingPTYErrors,
1462-
}, logger.With(slog.F("message_id", msg.ID)))
1463-
1464-
if err = a.trackGoroutine(func() {
1465-
rpty.Wait()
1466-
a.reconnectingPTYs.Delete(msg.ID)
1467-
}); err != nil {
1468-
rpty.Close(err)
1469-
return xerrors.Errorf("start routine: %w", err)
1470-
}
1471-
1472-
connected = true
1473-
sendConnected <- rpty
1474-
}
1475-
return rpty.Attach(ctx, connectionID, conn, msg.Height, msg.Width, connLogger)
1476-
}
1477-
14781357
// Collect collects additional stats from the agent
14791358
func (a *agent) Collect(ctx context.Context, networkStats map[netlogtype.Connection]netlogtype.Counts) *proto.Stats {
14801359
a.logger.Debug(context.Background(), "computing stats report")
@@ -1496,7 +1375,7 @@ func (a *agent) Collect(ctx context.Context, networkStats map[netlogtype.Connect
14961375
stats.SessionCountVscode = sshStats.VSCode
14971376
stats.SessionCountJetbrains = sshStats.JetBrains
14981377

1499-
stats.SessionCountReconnectingPty = a.connCountReconnectingPTY.Load()
1378+
stats.SessionCountReconnectingPty = a.reconnectingPTYServer.ConnCount()
15001379

15011380
// Compute the median connection latency!
15021381
a.logger.Debug(ctx, "starting peer latency measurement for stats")

0 commit comments

Comments
 (0)