Skip to content

Commit 2ff1c6d

Browse files
authored
feat: add agent stats for different connection types (#6412)
This allows us to track when our extensions are used, when the web terminal is used, and average connection latency to the agent.
1 parent 537547f commit 2ff1c6d

18 files changed

+412
-131
lines changed

agent/agent.go

+139-45
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"os/user"
1919
"path/filepath"
2020
"runtime"
21+
"sort"
2122
"strconv"
2223
"strings"
2324
"sync"
@@ -56,6 +57,14 @@ const (
5657
// command just returning a nonzero exit code, and is chosen as an arbitrary, high number
5758
// unlikely to shadow other exit codes, which are typically 1, 2, 3, etc.
5859
MagicSessionErrorCode = 229
60+
61+
// MagicSSHSessionTypeEnvironmentVariable is used to track the purpose behind an SSH connection.
62+
// This is stripped from any commands being executed, and is counted towards connection stats.
63+
MagicSSHSessionTypeEnvironmentVariable = "__CODER_SSH_SESSION_TYPE"
64+
// MagicSSHSessionTypeVSCode is set in the SSH config by the VS Code extension to identify itself.
65+
MagicSSHSessionTypeVSCode = "vscode"
66+
// MagicSSHSessionTypeJetBrains is set in the SSH config by the JetBrains extension to identify itself.
67+
MagicSSHSessionTypeJetBrains = "jetbrains"
5968
)
6069

6170
type Options struct {
@@ -146,6 +155,15 @@ type agent struct {
146155

147156
network *tailnet.Conn
148157
connStatsChan chan *agentsdk.Stats
158+
159+
statRxPackets atomic.Int64
160+
statRxBytes atomic.Int64
161+
statTxPackets atomic.Int64
162+
statTxBytes atomic.Int64
163+
connCountVSCode atomic.Int64
164+
connCountJetBrains atomic.Int64
165+
connCountReconnectingPTY atomic.Int64
166+
connCountSSHSession atomic.Int64
149167
}
150168

151169
// runLoop attempts to start the agent in a retry loop.
@@ -350,33 +368,7 @@ func (a *agent) run(ctx context.Context) error {
350368
return xerrors.New("agent is closed")
351369
}
352370

353-
setStatInterval := func(d time.Duration) {
354-
network.SetConnStatsCallback(d, 2048,
355-
func(_, _ time.Time, virtual, _ map[netlogtype.Connection]netlogtype.Counts) {
356-
select {
357-
case a.connStatsChan <- convertAgentStats(virtual):
358-
default:
359-
a.logger.Warn(ctx, "network stat dropped")
360-
}
361-
},
362-
)
363-
}
364-
365-
// Report statistics from the created network.
366-
cl, err := a.client.ReportStats(ctx, a.logger, a.connStatsChan, setStatInterval)
367-
if err != nil {
368-
a.logger.Error(ctx, "report stats", slog.Error(err))
369-
} else {
370-
if err = a.trackConnGoroutine(func() {
371-
// This is OK because the agent never re-creates the tailnet
372-
// and the only shutdown indicator is agent.Close().
373-
<-a.closed
374-
_ = cl.Close()
375-
}); err != nil {
376-
a.logger.Debug(ctx, "report stats goroutine", slog.Error(err))
377-
_ = cl.Close()
378-
}
379-
}
371+
a.startReportingConnectionStats(ctx)
380372
} else {
381373
// Update the DERP map!
382374
network.SetDERPMap(metadata.DERPMap)
@@ -765,23 +757,6 @@ func (a *agent) init(ctx context.Context) {
765757
go a.runLoop(ctx)
766758
}
767759

768-
func convertAgentStats(counts map[netlogtype.Connection]netlogtype.Counts) *agentsdk.Stats {
769-
stats := &agentsdk.Stats{
770-
ConnectionsByProto: map[string]int64{},
771-
ConnectionCount: int64(len(counts)),
772-
}
773-
774-
for conn, count := range counts {
775-
stats.ConnectionsByProto[conn.Proto.String()]++
776-
stats.RxPackets += int64(count.RxPackets)
777-
stats.RxBytes += int64(count.RxBytes)
778-
stats.TxPackets += int64(count.TxPackets)
779-
stats.TxBytes += int64(count.TxBytes)
780-
}
781-
782-
return stats
783-
}
784-
785760
// createCommand processes raw command input with OpenSSH-like behavior.
786761
// If the rawCommand provided is empty, it will default to the users shell.
787762
// This injects environment variables specified by the user at launch too.
@@ -892,7 +867,27 @@ func (a *agent) createCommand(ctx context.Context, rawCommand string, env []stri
892867

893868
func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
894869
ctx := session.Context()
895-
cmd, err := a.createCommand(ctx, session.RawCommand(), session.Environ())
870+
env := session.Environ()
871+
var magicType string
872+
for index, kv := range env {
873+
if !strings.HasPrefix(kv, MagicSSHSessionTypeEnvironmentVariable) {
874+
continue
875+
}
876+
magicType = strings.TrimPrefix(kv, MagicSSHSessionTypeEnvironmentVariable+"=")
877+
env = append(env[:index], env[index+1:]...)
878+
}
879+
switch magicType {
880+
case MagicSSHSessionTypeVSCode:
881+
a.connCountVSCode.Add(1)
882+
case MagicSSHSessionTypeJetBrains:
883+
a.connCountJetBrains.Add(1)
884+
case "":
885+
a.connCountSSHSession.Add(1)
886+
default:
887+
a.logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("type", magicType))
888+
}
889+
890+
cmd, err := a.createCommand(ctx, session.RawCommand(), env)
896891
if err != nil {
897892
return err
898893
}
@@ -990,6 +985,8 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
990985
func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, msg codersdk.WorkspaceAgentReconnectingPTYInit, conn net.Conn) (retErr error) {
991986
defer conn.Close()
992987

988+
a.connCountReconnectingPTY.Add(1)
989+
993990
connectionID := uuid.NewString()
994991
logger = logger.With(slog.F("id", msg.ID), slog.F("connection_id", connectionID))
995992

@@ -1180,6 +1177,103 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
11801177
}
11811178
}
11821179

1180+
// startReportingConnectionStats runs the connection stats reporting goroutine.
1181+
func (a *agent) startReportingConnectionStats(ctx context.Context) {
1182+
reportStats := func(networkStats map[netlogtype.Connection]netlogtype.Counts) {
1183+
stats := &agentsdk.Stats{
1184+
ConnectionCount: int64(len(networkStats)),
1185+
ConnectionsByProto: map[string]int64{},
1186+
}
1187+
// Tailscale resets counts on every report!
1188+
// We'd rather have these compound, like Linux does!
1189+
for conn, counts := range networkStats {
1190+
stats.ConnectionsByProto[conn.Proto.String()]++
1191+
stats.RxBytes = a.statRxBytes.Add(int64(counts.RxBytes))
1192+
stats.RxPackets = a.statRxPackets.Add(int64(counts.RxPackets))
1193+
stats.TxBytes = a.statTxBytes.Add(int64(counts.TxBytes))
1194+
stats.TxPackets = a.statTxPackets.Add(int64(counts.TxPackets))
1195+
}
1196+
1197+
// Tailscale's connection stats are not cumulative, but it makes no sense to make
1198+
// ours temporary.
1199+
stats.SessionCountSSH = a.connCountSSHSession.Load()
1200+
stats.SessionCountVSCode = a.connCountVSCode.Load()
1201+
stats.SessionCountJetBrains = a.connCountJetBrains.Load()
1202+
stats.SessionCountReconnectingPTY = a.connCountReconnectingPTY.Load()
1203+
1204+
// Compute the median connection latency!
1205+
var wg sync.WaitGroup
1206+
var mu sync.Mutex
1207+
status := a.network.Status()
1208+
durations := []float64{}
1209+
ctx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
1210+
defer cancelFunc()
1211+
for nodeID, peer := range status.Peer {
1212+
if !peer.Active {
1213+
continue
1214+
}
1215+
addresses, found := a.network.NodeAddresses(nodeID)
1216+
if !found {
1217+
continue
1218+
}
1219+
if len(addresses) == 0 {
1220+
continue
1221+
}
1222+
wg.Add(1)
1223+
go func() {
1224+
defer wg.Done()
1225+
duration, _, _, err := a.network.Ping(ctx, addresses[0].Addr())
1226+
if err != nil {
1227+
return
1228+
}
1229+
mu.Lock()
1230+
durations = append(durations, float64(duration.Microseconds()))
1231+
mu.Unlock()
1232+
}()
1233+
}
1234+
wg.Wait()
1235+
sort.Float64s(durations)
1236+
durationsLength := len(durations)
1237+
if durationsLength == 0 {
1238+
stats.ConnectionMedianLatencyMS = -1
1239+
} else if durationsLength%2 == 0 {
1240+
stats.ConnectionMedianLatencyMS = (durations[durationsLength/2-1] + durations[durationsLength/2]) / 2
1241+
} else {
1242+
stats.ConnectionMedianLatencyMS = durations[durationsLength/2]
1243+
}
1244+
// Convert from microseconds to milliseconds.
1245+
stats.ConnectionMedianLatencyMS /= 1000
1246+
1247+
select {
1248+
case a.connStatsChan <- stats:
1249+
default:
1250+
a.logger.Warn(ctx, "network stat dropped")
1251+
}
1252+
}
1253+
1254+
// Report statistics from the created network.
1255+
cl, err := a.client.ReportStats(ctx, a.logger, a.connStatsChan, func(d time.Duration) {
1256+
a.network.SetConnStatsCallback(d, 2048,
1257+
func(_, _ time.Time, virtual, _ map[netlogtype.Connection]netlogtype.Counts) {
1258+
reportStats(virtual)
1259+
},
1260+
)
1261+
})
1262+
if err != nil {
1263+
a.logger.Error(ctx, "report stats", slog.Error(err))
1264+
} else {
1265+
if err = a.trackConnGoroutine(func() {
1266+
// This is OK because the agent never re-creates the tailnet
1267+
// and the only shutdown indicator is agent.Close().
1268+
<-a.closed
1269+
_ = cl.Close()
1270+
}); err != nil {
1271+
a.logger.Debug(ctx, "report stats goroutine", slog.Error(err))
1272+
_ = cl.Close()
1273+
}
1274+
}
1275+
}
1276+
11831277
// isClosed returns whether the API is closed or not.
11841278
func (a *agent) isClosed() bool {
11851279
select {

agent/agent_test.go

+42-2
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ func TestAgent_Stats_SSH(t *testing.T) {
7373
require.Eventuallyf(t, func() bool {
7474
var ok bool
7575
s, ok = <-stats
76-
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0
76+
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountSSH == 1
7777
}, testutil.WaitLong, testutil.IntervalFast,
7878
"never saw stats: %+v", s,
7979
)
@@ -102,7 +102,47 @@ func TestAgent_Stats_ReconnectingPTY(t *testing.T) {
102102
require.Eventuallyf(t, func() bool {
103103
var ok bool
104104
s, ok = <-stats
105-
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0
105+
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountReconnectingPTY == 1
106+
}, testutil.WaitLong, testutil.IntervalFast,
107+
"never saw stats: %+v", s,
108+
)
109+
}
110+
111+
func TestAgent_Stats_Magic(t *testing.T) {
112+
t.Parallel()
113+
114+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
115+
defer cancel()
116+
117+
conn, _, stats, _ := setupAgent(t, agentsdk.Metadata{}, 0)
118+
sshClient, err := conn.SSHClient(ctx)
119+
require.NoError(t, err)
120+
defer sshClient.Close()
121+
session, err := sshClient.NewSession()
122+
require.NoError(t, err)
123+
session.Setenv(agent.MagicSSHSessionTypeEnvironmentVariable, agent.MagicSSHSessionTypeVSCode)
124+
defer session.Close()
125+
126+
command := "sh -c 'echo $" + agent.MagicSSHSessionTypeEnvironmentVariable + "'"
127+
expected := ""
128+
if runtime.GOOS == "windows" {
129+
expected = "%" + agent.MagicSSHSessionTypeEnvironmentVariable + "%"
130+
command = "cmd.exe /c echo " + expected
131+
}
132+
output, err := session.Output(command)
133+
require.NoError(t, err)
134+
require.Equal(t, expected, strings.TrimSpace(string(output)))
135+
var s *agentsdk.Stats
136+
require.Eventuallyf(t, func() bool {
137+
var ok bool
138+
s, ok = <-stats
139+
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 &&
140+
// Ensure that the connection didn't count as a "normal" SSH session.
141+
// This was a special one, so it should be labeled specially in the stats!
142+
s.SessionCountVSCode == 1 &&
143+
// Ensure that connection latency is being counted!
144+
// If it isn't, it's set to -1.
145+
s.ConnectionMedianLatencyMS >= 0
106146
}, testutil.WaitLong, testutil.IntervalFast,
107147
"never saw stats: %+v", s,
108148
)

coderd/apidoc/docs.go

+25-5
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/apidoc/swagger.json

+25-5
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)