Skip to content

Commit 9011a8a

Browse files
committed
chore: move ssh session tracking from agent to cli
1 parent 9abaa94 commit 9011a8a

File tree

3 files changed

+92
-15
lines changed

3 files changed

+92
-15
lines changed

agent/agentssh/agentssh.go

+32-14
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ const (
5252
// MagicProcessCmdlineJetBrains is a string in a process's command line that
5353
// uniquely identifies it as JetBrains software.
5454
MagicProcessCmdlineJetBrains = "idea.vendor.name=JetBrains"
55+
// MagicProcessCmdlineJetBrainsGateway is used to tell the agent not to report
56+
// session stats because it's now being handled by the CLI.
57+
MagicDisableUsageTrackingEnvironmentVariable = "CODER_SSH_DISABLE_USAGE_TRACKING"
5558

5659
// BlockedFileTransferErrorCode indicates that SSH server restricted the raw command from performing
5760
// the file transfer.
@@ -380,26 +383,41 @@ func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, extraEnv
380383
env := append(session.Environ(), extraEnv...)
381384
var magicType string
382385
for index, kv := range env {
383-
if !strings.HasPrefix(kv, MagicSessionTypeEnvironmentVariable) {
386+
if !strings.HasPrefix(kv, MagicSessionTypeEnvironmentVariable+"=") {
384387
continue
385388
}
386389
magicType = strings.ToLower(strings.TrimPrefix(kv, MagicSessionTypeEnvironmentVariable+"="))
387390
env = append(env[:index], env[index+1:]...)
388391
}
392+
disableUsageTracking := false
393+
for index, kv := range env {
394+
if !strings.HasPrefix(kv, MagicDisableUsageTrackingEnvironmentVariable+"=") {
395+
continue
396+
}
397+
if strings.ToLower(strings.TrimPrefix(kv, MagicDisableUsageTrackingEnvironmentVariable+"=")) == "true" {
398+
disableUsageTracking = true
399+
}
400+
env = append(env[:index], env[index+1:]...)
401+
}
389402

390-
// Always force lowercase checking to be case-insensitive.
391-
switch magicType {
392-
case MagicSessionTypeVSCode:
393-
s.connCountVSCode.Add(1)
394-
defer s.connCountVSCode.Add(-1)
395-
case MagicSessionTypeJetBrains:
396-
// Do nothing here because JetBrains launches hundreds of ssh sessions.
397-
// We instead track JetBrains in the single persistent tcp forwarding channel.
398-
case "":
399-
s.connCountSSHSession.Add(1)
400-
defer s.connCountSSHSession.Add(-1)
401-
default:
402-
logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("type", magicType))
403+
// We only want to track the session counts if they are from
404+
// older clients that haven't migrating to reporting these stats
405+
// from the CLI. This ensures we don't double count sessions.
406+
if !disableUsageTracking {
407+
// Always force lowercase checking to be case-insensitive.
408+
switch magicType {
409+
case MagicSessionTypeVSCode:
410+
s.connCountVSCode.Add(1)
411+
defer s.connCountVSCode.Add(-1)
412+
case MagicSessionTypeJetBrains:
413+
// Do nothing here because JetBrains launches hundreds of ssh sessions.
414+
// We instead track JetBrains in the single persistent tcp forwarding channel.
415+
case "":
416+
s.connCountSSHSession.Add(1)
417+
defer s.connCountSSHSession.Add(-1)
418+
default:
419+
logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("type", magicType))
420+
}
403421
}
404422

405423
magicTypeLabel := magicTypeMetricLabel(magicType)

cli/ssh.go

+59
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"os"
1313
"os/exec"
1414
"path/filepath"
15+
"slices"
1516
"strings"
1617
"sync"
1718
"time"
@@ -28,6 +29,8 @@ import (
2829

2930
"cdr.dev/slog"
3031
"cdr.dev/slog/sloggers/sloghuman"
32+
"github.com/coder/coder/v2/agent/agentssh"
33+
"github.com/coder/coder/v2/apiversion"
3134
"github.com/coder/coder/v2/cli/cliui"
3235
"github.com/coder/coder/v2/cli/cliutil"
3336
"github.com/coder/coder/v2/coderd/autobuild/notify"
@@ -57,6 +60,7 @@ func (r *RootCmd) ssh() *serpent.Command {
5760
logDirPath string
5861
remoteForwards []string
5962
env []string
63+
usageApp string
6064
disableAutostart bool
6165
)
6266
client := new(codersdk.Client)
@@ -196,6 +200,11 @@ func (r *RootCmd) ssh() *serpent.Command {
196200
wait = false
197201
}
198202

203+
experiments, err := client.Experiments(ctx)
204+
if err != nil {
205+
return err
206+
}
207+
199208
templateVersion, err := client.TemplateVersion(ctx, workspace.LatestBuild.TemplateVersionID)
200209
if err != nil {
201210
return err
@@ -251,6 +260,18 @@ func (r *RootCmd) ssh() *serpent.Command {
251260
stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace)
252261
defer stopPolling()
253262

263+
usageAppName := getUsageAppName(usageApp, stdio, experiments, workspaceAgent.APIVersion)
264+
if usageAppName != "" {
265+
closeUsage := client.UpdateWorkspaceUsageWithBodyContext(ctx, workspace.ID, codersdk.PostWorkspaceUsageRequest{
266+
AgentID: workspaceAgent.ID,
267+
AppName: usageAppName,
268+
})
269+
defer closeUsage()
270+
271+
// signal to the agent that we are handling the usage tracking
272+
parsedEnv = append(parsedEnv, [2]string{agentssh.MagicDisableUsageTrackingEnvironmentVariable, "true"})
273+
}
274+
254275
if stdio {
255276
rawSSH, err := conn.SSH(ctx)
256277
if err != nil {
@@ -509,6 +530,12 @@ func (r *RootCmd) ssh() *serpent.Command {
509530
FlagShorthand: "e",
510531
Value: serpent.StringArrayOf(&env),
511532
},
533+
{
534+
Flag: "usage-app",
535+
Description: "Specifies the usage app to use for workspace activity tracking.",
536+
Env: "CODER_SSH_USAGE_APP",
537+
Value: serpent.StringOf(&usageApp),
538+
},
512539
sshDisableAutostartOption(serpent.BoolOf(&disableAutostart)),
513540
}
514541
return cmd
@@ -1044,3 +1071,35 @@ func (r stdioErrLogReader) Read(_ []byte) (int, error) {
10441071
r.l.Error(context.Background(), "reading from stdin in stdio mode is not allowed")
10451072
return 0, io.EOF
10461073
}
1074+
1075+
func getUsageAppName(usageApp string, stdio bool, experiments codersdk.Experiments, agentAPIVersion string) codersdk.UsageAppName {
1076+
// if experiment is not enabled do not report usage
1077+
if !slices.Contains(experiments, codersdk.ExperimentWorkspaceUsage) {
1078+
return ""
1079+
}
1080+
1081+
// need agent version to be at or after 2.2
1082+
major, minor, err := apiversion.Parse(agentAPIVersion)
1083+
if err != nil {
1084+
return ""
1085+
}
1086+
err = apiversion.New(major, minor).Validate("2.2")
1087+
if err != nil {
1088+
return ""
1089+
}
1090+
1091+
// if usageApp is empty and not stdio, default to ssh
1092+
if usageApp == "" && !stdio {
1093+
usageApp = string(codersdk.UsageAppNameSSH)
1094+
}
1095+
allowedUsageApps := []string{
1096+
string(codersdk.UsageAppNameJetbrains),
1097+
string(codersdk.UsageAppNameVscode),
1098+
string(codersdk.UsageAppNameSSH),
1099+
}
1100+
if slices.Contains(allowedUsageApps, usageApp) {
1101+
return codersdk.UsageAppName(usageApp)
1102+
}
1103+
1104+
return ""
1105+
}

tailnet/proto/version.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import (
66

77
const (
88
CurrentMajor = 2
9-
CurrentMinor = 1
9+
CurrentMinor = 2
1010
)
1111

1212
var CurrentVersion = apiversion.New(CurrentMajor, CurrentMinor).WithBackwardCompat(1)

0 commit comments

Comments
 (0)