Skip to content

Commit d43228c

Browse files
committed
feat: add --network-info-dir and --network-info-interval flags to coder ssh
This is the first in a series of PRs to enable "coder ssh" to replace "coder vscodessh". This change adds --network-info-dir and --network-info-interval flags to the ssh subcommand. These were formerly only available with the vscodessh subcommand. Subsequent PRs will add a --ssh-host-prefix flag to the ssh subcommand, and adjust the log file naming to contain the parent PID.
1 parent 6ca1e59 commit d43228c

File tree

4 files changed

+309
-163
lines changed

4 files changed

+309
-163
lines changed

cli/ssh.go

Lines changed: 215 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package cli
33
import (
44
"bytes"
55
"context"
6+
"encoding/json"
67
"errors"
78
"fmt"
89
"io"
@@ -13,6 +14,7 @@ import (
1314
"os/exec"
1415
"path/filepath"
1516
"slices"
17+
"strconv"
1618
"strings"
1719
"sync"
1820
"time"
@@ -21,11 +23,14 @@ import (
2123
"github.com/gofrs/flock"
2224
"github.com/google/uuid"
2325
"github.com/mattn/go-isatty"
26+
"github.com/spf13/afero"
2427
gossh "golang.org/x/crypto/ssh"
2528
gosshagent "golang.org/x/crypto/ssh/agent"
2629
"golang.org/x/term"
2730
"golang.org/x/xerrors"
2831
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
32+
"tailscale.com/tailcfg"
33+
"tailscale.com/types/netlogtype"
2934

3035
"cdr.dev/slog"
3136
"cdr.dev/slog/sloggers/sloghuman"
@@ -55,19 +60,21 @@ var (
5560

5661
func (r *RootCmd) ssh() *serpent.Command {
5762
var (
58-
stdio bool
59-
forwardAgent bool
60-
forwardGPG bool
61-
identityAgent string
62-
wsPollInterval time.Duration
63-
waitEnum string
64-
noWait bool
65-
logDirPath string
66-
remoteForwards []string
67-
env []string
68-
usageApp string
69-
disableAutostart bool
70-
appearanceConfig codersdk.AppearanceConfig
63+
stdio bool
64+
forwardAgent bool
65+
forwardGPG bool
66+
identityAgent string
67+
wsPollInterval time.Duration
68+
waitEnum string
69+
noWait bool
70+
logDirPath string
71+
remoteForwards []string
72+
env []string
73+
usageApp string
74+
disableAutostart bool
75+
appearanceConfig codersdk.AppearanceConfig
76+
networkInfoDir string
77+
networkInfoInterval time.Duration
7178
)
7279
client := new(codersdk.Client)
7380
cmd := &serpent.Command{
@@ -274,6 +281,11 @@ func (r *RootCmd) ssh() *serpent.Command {
274281
defer closeUsage()
275282
}
276283

284+
fs, ok := inv.Context().Value("fs").(afero.Fs)
285+
if !ok {
286+
fs = afero.NewOsFs()
287+
}
288+
277289
if stdio {
278290
rawSSH, err := conn.SSH(ctx)
279291
if err != nil {
@@ -284,13 +296,21 @@ func (r *RootCmd) ssh() *serpent.Command {
284296
return err
285297
}
286298

299+
var errCh <-chan error
300+
if networkInfoDir != "" {
301+
errCh, err = setStatsCallback(ctx, conn, fs, logger, networkInfoDir, networkInfoInterval)
302+
if err != nil {
303+
return err
304+
}
305+
}
306+
287307
wg.Add(1)
288308
go func() {
289309
defer wg.Done()
290310
watchAndClose(ctx, func() error {
291311
stack.close(xerrors.New("watchAndClose"))
292312
return nil
293-
}, logger, client, workspace)
313+
}, logger, client, workspace, errCh)
294314
}()
295315
copier.copy(&wg)
296316
return nil
@@ -312,6 +332,14 @@ func (r *RootCmd) ssh() *serpent.Command {
312332
return err
313333
}
314334

335+
var errCh <-chan error
336+
if networkInfoDir != "" {
337+
errCh, err = setStatsCallback(ctx, conn, fs, logger, networkInfoDir, networkInfoInterval)
338+
if err != nil {
339+
return err
340+
}
341+
}
342+
315343
wg.Add(1)
316344
go func() {
317345
defer wg.Done()
@@ -324,6 +352,7 @@ func (r *RootCmd) ssh() *serpent.Command {
324352
logger,
325353
client,
326354
workspace,
355+
errCh,
327356
)
328357
}()
329358

@@ -540,6 +569,17 @@ func (r *RootCmd) ssh() *serpent.Command {
540569
Value: serpent.StringOf(&usageApp),
541570
Hidden: true,
542571
},
572+
{
573+
Flag: "network-info-dir",
574+
Description: "Specifies a directory to write network information periodically.",
575+
Value: serpent.StringOf(&networkInfoDir),
576+
},
577+
{
578+
Flag: "network-info-interval",
579+
Description: "Specifies the interval to update network information.",
580+
Default: "5s",
581+
Value: serpent.DurationOf(&networkInfoInterval),
582+
},
543583
sshDisableAutostartOption(serpent.BoolOf(&disableAutostart)),
544584
}
545585
return cmd
@@ -555,7 +595,7 @@ func (r *RootCmd) ssh() *serpent.Command {
555595
// will usually not propagate.
556596
//
557597
// See: https://github.com/coder/coder/issues/6180
558-
func watchAndClose(ctx context.Context, closer func() error, logger slog.Logger, client *codersdk.Client, workspace codersdk.Workspace) {
598+
func watchAndClose(ctx context.Context, closer func() error, logger slog.Logger, client *codersdk.Client, workspace codersdk.Workspace, errCh <-chan error) {
559599
// Ensure session is ended on both context cancellation
560600
// and workspace stop.
561601
defer func() {
@@ -606,6 +646,9 @@ startWatchLoop:
606646
logger.Info(ctx, "workspace stopped")
607647
return
608648
}
649+
case err := <-errCh:
650+
logger.Error(ctx, "%s", err)
651+
return
609652
}
610653
}
611654
}
@@ -1144,3 +1187,160 @@ func getUsageAppName(usageApp string) codersdk.UsageAppName {
11441187

11451188
return codersdk.UsageAppNameSSH
11461189
}
1190+
1191+
func setStatsCallback(
1192+
ctx context.Context,
1193+
agentConn *workspacesdk.AgentConn,
1194+
fs afero.Fs,
1195+
logger slog.Logger,
1196+
networkInfoDir string,
1197+
networkInfoInterval time.Duration,
1198+
) (<-chan error, error) {
1199+
fs, ok := ctx.Value("fs").(afero.Fs)
1200+
if !ok {
1201+
fs = afero.NewOsFs()
1202+
}
1203+
if err := fs.MkdirAll(networkInfoDir, 0o700); err != nil {
1204+
return nil, xerrors.Errorf("mkdir: %w", err)
1205+
}
1206+
1207+
// The VS Code extension obtains the PID of the SSH process to
1208+
// read files to display logs and network info.
1209+
//
1210+
// We get the parent PID because it's assumed `ssh` is calling this
1211+
// command via the ProxyCommand SSH option.
1212+
pid := os.Getppid()
1213+
1214+
// The VS Code extension obtains the PID of the SSH process to
1215+
// read the file below which contains network information to display.
1216+
//
1217+
// We get the parent PID because it's assumed `ssh` is calling this
1218+
// command via the ProxyCommand SSH option.
1219+
networkInfoFilePath := filepath.Join(networkInfoDir, fmt.Sprintf("%d.json", pid))
1220+
1221+
var (
1222+
firstErrTime time.Time
1223+
errCh = make(chan error, 1)
1224+
)
1225+
cb := func(start, end time.Time, virtual, _ map[netlogtype.Connection]netlogtype.Counts) {
1226+
sendErr := func(tolerate bool, err error) {
1227+
logger.Error(ctx, "collect network stats", slog.Error(err))
1228+
// Tolerate up to 1 minute of errors.
1229+
if tolerate {
1230+
if firstErrTime.IsZero() {
1231+
logger.Info(ctx, "tolerating network stats errors for up to 1 minute")
1232+
firstErrTime = time.Now()
1233+
}
1234+
if time.Since(firstErrTime) < time.Minute {
1235+
return
1236+
}
1237+
}
1238+
1239+
select {
1240+
case errCh <- err:
1241+
default:
1242+
}
1243+
}
1244+
1245+
stats, err := collectNetworkStats(ctx, agentConn, start, end, virtual)
1246+
if err != nil {
1247+
sendErr(true, err)
1248+
return
1249+
}
1250+
1251+
rawStats, err := json.Marshal(stats)
1252+
if err != nil {
1253+
sendErr(false, err)
1254+
return
1255+
}
1256+
err = afero.WriteFile(fs, networkInfoFilePath, rawStats, 0o600)
1257+
if err != nil {
1258+
sendErr(false, err)
1259+
return
1260+
}
1261+
1262+
firstErrTime = time.Time{}
1263+
}
1264+
1265+
now := time.Now()
1266+
cb(now, now.Add(time.Nanosecond), map[netlogtype.Connection]netlogtype.Counts{}, map[netlogtype.Connection]netlogtype.Counts{})
1267+
agentConn.SetConnStatsCallback(networkInfoInterval, 2048, cb)
1268+
return errCh, nil
1269+
}
1270+
1271+
type sshNetworkStats struct {
1272+
P2P bool `json:"p2p"`
1273+
Latency float64 `json:"latency"`
1274+
PreferredDERP string `json:"preferred_derp"`
1275+
DERPLatency map[string]float64 `json:"derp_latency"`
1276+
UploadBytesSec int64 `json:"upload_bytes_sec"`
1277+
DownloadBytesSec int64 `json:"download_bytes_sec"`
1278+
}
1279+
1280+
func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) {
1281+
latency, p2p, pingResult, err := agentConn.Ping(ctx)
1282+
if err != nil {
1283+
return nil, err
1284+
}
1285+
node := agentConn.Node()
1286+
derpMap := agentConn.DERPMap()
1287+
derpLatency := map[string]float64{}
1288+
1289+
// Convert DERP region IDs to friendly names for display in the UI.
1290+
for rawRegion, latency := range node.DERPLatency {
1291+
regionParts := strings.SplitN(rawRegion, "-", 2)
1292+
regionID, err := strconv.Atoi(regionParts[0])
1293+
if err != nil {
1294+
continue
1295+
}
1296+
region, found := derpMap.Regions[regionID]
1297+
if !found {
1298+
// It's possible that a workspace agent is using an old DERPMap
1299+
// and reports regions that do not exist. If that's the case,
1300+
// report the region as unknown!
1301+
region = &tailcfg.DERPRegion{
1302+
RegionID: regionID,
1303+
RegionName: fmt.Sprintf("Unnamed %d", regionID),
1304+
}
1305+
}
1306+
// Convert the microseconds to milliseconds.
1307+
derpLatency[region.RegionName] = latency * 1000
1308+
}
1309+
1310+
totalRx := uint64(0)
1311+
totalTx := uint64(0)
1312+
for _, stat := range counts {
1313+
totalRx += stat.RxBytes
1314+
totalTx += stat.TxBytes
1315+
}
1316+
// Tracking the time since last request is required because
1317+
// ExtractTrafficStats() resets its counters after each call.
1318+
dur := end.Sub(start)
1319+
uploadSecs := float64(totalTx) / dur.Seconds()
1320+
downloadSecs := float64(totalRx) / dur.Seconds()
1321+
1322+
// Sometimes the preferred DERP doesn't match the one we're actually
1323+
// connected with. Perhaps because the agent prefers a different DERP and
1324+
// we're using that server instead.
1325+
preferredDerpID := node.PreferredDERP
1326+
if pingResult.DERPRegionID != 0 {
1327+
preferredDerpID = pingResult.DERPRegionID
1328+
}
1329+
preferredDerp, ok := derpMap.Regions[preferredDerpID]
1330+
preferredDerpName := fmt.Sprintf("Unnamed %d", preferredDerpID)
1331+
if ok {
1332+
preferredDerpName = preferredDerp.RegionName
1333+
}
1334+
if _, ok := derpLatency[preferredDerpName]; !ok {
1335+
derpLatency[preferredDerpName] = 0
1336+
}
1337+
1338+
return &sshNetworkStats{
1339+
P2P: p2p,
1340+
Latency: float64(latency.Microseconds()) / 1000,
1341+
PreferredDERP: preferredDerpName,
1342+
DERPLatency: derpLatency,
1343+
UploadBytesSec: int64(uploadSecs),
1344+
DownloadBytesSec: int64(downloadSecs),
1345+
}, nil
1346+
}

0 commit comments

Comments
 (0)