Skip to content

Commit 9f5ad23

Browse files
authored
refactor(agent/agentssh): move parsing of magic session and create type (#16630)
This change refactors the parsing of MagicSessionEnvs in the agentssh package and moves the logic to an earlier stage. Also intoduces enums for MagicSessionType. Refs #15139
1 parent 570e42b commit 9f5ad23

File tree

3 files changed

+92
-56
lines changed

3 files changed

+92
-56
lines changed

agent/agent_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ func TestAgent_Stats_Magic(t *testing.T) {
138138
defer sshClient.Close()
139139
session, err := sshClient.NewSession()
140140
require.NoError(t, err)
141-
session.Setenv(agentssh.MagicSessionTypeEnvironmentVariable, agentssh.MagicSessionTypeVSCode)
141+
session.Setenv(agentssh.MagicSessionTypeEnvironmentVariable, string(agentssh.MagicSessionTypeVSCode))
142142
defer session.Close()
143143

144144
command := "sh -c 'echo $" + agentssh.MagicSessionTypeEnvironmentVariable + "'"
@@ -165,7 +165,7 @@ func TestAgent_Stats_Magic(t *testing.T) {
165165
defer sshClient.Close()
166166
session, err := sshClient.NewSession()
167167
require.NoError(t, err)
168-
session.Setenv(agentssh.MagicSessionTypeEnvironmentVariable, agentssh.MagicSessionTypeVSCode)
168+
session.Setenv(agentssh.MagicSessionTypeEnvironmentVariable, string(agentssh.MagicSessionTypeVSCode))
169169
defer session.Close()
170170
stdin, err := session.StdinPipe()
171171
require.NoError(t, err)

agent/agentssh/agentssh.go

Lines changed: 85 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"github.com/spf13/afero"
2727
"go.uber.org/atomic"
2828
gossh "golang.org/x/crypto/ssh"
29+
"golang.org/x/exp/slices"
2930
"golang.org/x/xerrors"
3031

3132
"cdr.dev/slog"
@@ -42,14 +43,6 @@ const (
4243
// unlikely to shadow other exit codes, which are typically 1, 2, 3, etc.
4344
MagicSessionErrorCode = 229
4445

45-
// MagicSessionTypeEnvironmentVariable is used to track the purpose behind an SSH connection.
46-
// This is stripped from any commands being executed, and is counted towards connection stats.
47-
MagicSessionTypeEnvironmentVariable = "CODER_SSH_SESSION_TYPE"
48-
// MagicSessionTypeVSCode is set in the SSH config by the VS Code extension to identify itself.
49-
MagicSessionTypeVSCode = "vscode"
50-
// MagicSessionTypeJetBrains is set in the SSH config by the JetBrains
51-
// extension to identify itself.
52-
MagicSessionTypeJetBrains = "jetbrains"
5346
// MagicProcessCmdlineJetBrains is a string in a process's command line that
5447
// uniquely identifies it as JetBrains software.
5548
MagicProcessCmdlineJetBrains = "idea.vendor.name=JetBrains"
@@ -60,6 +53,29 @@ const (
6053
BlockedFileTransferErrorMessage = "File transfer has been disabled."
6154
)
6255

56+
// MagicSessionType is a type that represents the type of session that is being
57+
// established.
58+
type MagicSessionType string
59+
60+
const (
61+
// MagicSessionTypeEnvironmentVariable is used to track the purpose behind an SSH connection.
62+
// This is stripped from any commands being executed, and is counted towards connection stats.
63+
MagicSessionTypeEnvironmentVariable = "CODER_SSH_SESSION_TYPE"
64+
)
65+
66+
// MagicSessionType enums.
67+
const (
68+
// MagicSessionTypeUnknown means the session type could not be determined.
69+
MagicSessionTypeUnknown MagicSessionType = "unknown"
70+
// MagicSessionTypeSSH is the default session type.
71+
MagicSessionTypeSSH MagicSessionType = "ssh"
72+
// MagicSessionTypeVSCode is set in the SSH config by the VS Code extension to identify itself.
73+
MagicSessionTypeVSCode MagicSessionType = "vscode"
74+
// MagicSessionTypeJetBrains is set in the SSH config by the JetBrains
75+
// extension to identify itself.
76+
MagicSessionTypeJetBrains MagicSessionType = "jetbrains"
77+
)
78+
6379
// BlockedFileTransferCommands contains a list of restricted file transfer commands.
6480
var BlockedFileTransferCommands = []string{"nc", "rsync", "scp", "sftp"}
6581

@@ -255,14 +271,42 @@ func (s *Server) ConnStats() ConnStats {
255271
}
256272
}
257273

274+
func extractMagicSessionType(env []string) (magicType MagicSessionType, rawType string, filteredEnv []string) {
275+
for _, kv := range env {
276+
if !strings.HasPrefix(kv, MagicSessionTypeEnvironmentVariable) {
277+
continue
278+
}
279+
280+
rawType = strings.TrimPrefix(kv, MagicSessionTypeEnvironmentVariable+"=")
281+
// Keep going, we'll use the last instance of the env.
282+
}
283+
284+
// Always force lowercase checking to be case-insensitive.
285+
switch MagicSessionType(strings.ToLower(rawType)) {
286+
case MagicSessionTypeVSCode:
287+
magicType = MagicSessionTypeVSCode
288+
case MagicSessionTypeJetBrains:
289+
magicType = MagicSessionTypeJetBrains
290+
case "", MagicSessionTypeSSH:
291+
magicType = MagicSessionTypeSSH
292+
default:
293+
magicType = MagicSessionTypeUnknown
294+
}
295+
296+
return magicType, rawType, slices.DeleteFunc(env, func(kv string) bool {
297+
return strings.HasPrefix(kv, MagicSessionTypeEnvironmentVariable+"=")
298+
})
299+
}
300+
258301
func (s *Server) sessionHandler(session ssh.Session) {
259302
ctx := session.Context()
303+
id := uuid.New()
260304
logger := s.logger.With(
261305
slog.F("remote_addr", session.RemoteAddr()),
262306
slog.F("local_addr", session.LocalAddr()),
263307
// Assigning a random uuid for each session is useful for tracking
264308
// logs for the same ssh session.
265-
slog.F("id", uuid.NewString()),
309+
slog.F("id", id.String()),
266310
)
267311
logger.Info(ctx, "handling ssh session")
268312

@@ -274,16 +318,21 @@ func (s *Server) sessionHandler(session ssh.Session) {
274318
}
275319
defer s.trackSession(session, false)
276320

277-
extraEnv := make([]string, 0)
278-
x11, hasX11 := session.X11()
279-
if hasX11 {
280-
display, handled := s.x11Handler(session.Context(), x11)
281-
if !handled {
282-
_ = session.Exit(1)
283-
logger.Error(ctx, "x11 handler failed")
284-
return
285-
}
286-
extraEnv = append(extraEnv, fmt.Sprintf("DISPLAY=localhost:%d.%d", display, x11.ScreenNumber))
321+
env := session.Environ()
322+
magicType, magicTypeRaw, env := extractMagicSessionType(env)
323+
324+
switch magicType {
325+
case MagicSessionTypeVSCode:
326+
s.connCountVSCode.Add(1)
327+
defer s.connCountVSCode.Add(-1)
328+
case MagicSessionTypeJetBrains:
329+
// Do nothing here because JetBrains launches hundreds of ssh sessions.
330+
// We instead track JetBrains in the single persistent tcp forwarding channel.
331+
case MagicSessionTypeSSH:
332+
s.connCountSSHSession.Add(1)
333+
defer s.connCountSSHSession.Add(-1)
334+
case MagicSessionTypeUnknown:
335+
logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("raw_type", magicTypeRaw))
287336
}
288337

289338
if s.fileTransferBlocked(session) {
@@ -309,7 +358,18 @@ func (s *Server) sessionHandler(session ssh.Session) {
309358
return
310359
}
311360

312-
err := s.sessionStart(logger, session, extraEnv)
361+
x11, hasX11 := session.X11()
362+
if hasX11 {
363+
display, handled := s.x11Handler(session.Context(), x11)
364+
if !handled {
365+
_ = session.Exit(1)
366+
logger.Error(ctx, "x11 handler failed")
367+
return
368+
}
369+
env = append(env, fmt.Sprintf("DISPLAY=localhost:%d.%d", display, x11.ScreenNumber))
370+
}
371+
372+
err := s.sessionStart(logger, session, env, magicType)
313373
var exitError *exec.ExitError
314374
if xerrors.As(err, &exitError) {
315375
code := exitError.ExitCode()
@@ -379,32 +439,8 @@ func (s *Server) fileTransferBlocked(session ssh.Session) bool {
379439
return false
380440
}
381441

382-
func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, extraEnv []string) (retErr error) {
442+
func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, env []string, magicType MagicSessionType) (retErr error) {
383443
ctx := session.Context()
384-
env := append(session.Environ(), extraEnv...)
385-
var magicType string
386-
for index, kv := range env {
387-
if !strings.HasPrefix(kv, MagicSessionTypeEnvironmentVariable) {
388-
continue
389-
}
390-
magicType = strings.ToLower(strings.TrimPrefix(kv, MagicSessionTypeEnvironmentVariable+"="))
391-
env = append(env[:index], env[index+1:]...)
392-
}
393-
394-
// Always force lowercase checking to be case-insensitive.
395-
switch magicType {
396-
case MagicSessionTypeVSCode:
397-
s.connCountVSCode.Add(1)
398-
defer s.connCountVSCode.Add(-1)
399-
case MagicSessionTypeJetBrains:
400-
// Do nothing here because JetBrains launches hundreds of ssh sessions.
401-
// We instead track JetBrains in the single persistent tcp forwarding channel.
402-
case "":
403-
s.connCountSSHSession.Add(1)
404-
defer s.connCountSSHSession.Add(-1)
405-
default:
406-
logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("type", magicType))
407-
}
408444

409445
magicTypeLabel := magicTypeMetricLabel(magicType)
410446
sshPty, windowSize, isPty := session.Pty()
@@ -473,7 +509,7 @@ func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, mag
473509
}()
474510
go func() {
475511
for sig := range sigs {
476-
s.handleSignal(logger, sig, cmd.Process, magicTypeLabel)
512+
handleSignal(logger, sig, cmd.Process, s.metrics, magicTypeLabel)
477513
}
478514
}()
479515
return cmd.Wait()
@@ -558,7 +594,7 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy
558594
sigs = nil
559595
continue
560596
}
561-
s.handleSignal(logger, sig, process, magicTypeLabel)
597+
handleSignal(logger, sig, process, s.metrics, magicTypeLabel)
562598
case win, ok := <-windowSize:
563599
if !ok {
564600
windowSize = nil
@@ -612,15 +648,15 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy
612648
return nil
613649
}
614650

615-
func (s *Server) handleSignal(logger slog.Logger, ssig ssh.Signal, signaler interface{ Signal(os.Signal) error }, magicTypeLabel string) {
651+
func handleSignal(logger slog.Logger, ssig ssh.Signal, signaler interface{ Signal(os.Signal) error }, metrics *sshServerMetrics, magicTypeLabel string) {
616652
ctx := context.Background()
617653
sig := osSignalFrom(ssig)
618654
logger = logger.With(slog.F("ssh_signal", ssig), slog.F("signal", sig.String()))
619655
logger.Info(ctx, "received signal from client")
620656
err := signaler.Signal(sig)
621657
if err != nil {
622658
logger.Warn(ctx, "signaling the process failed", slog.Error(err))
623-
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "signal").Add(1)
659+
metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "signal").Add(1)
624660
}
625661
}
626662

agent/agentssh/metrics.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,15 @@ func newSSHServerMetrics(registerer prometheus.Registerer) *sshServerMetrics {
7171
}
7272
}
7373

74-
func magicTypeMetricLabel(magicType string) string {
74+
func magicTypeMetricLabel(magicType MagicSessionType) string {
7575
switch magicType {
7676
case MagicSessionTypeVSCode:
7777
case MagicSessionTypeJetBrains:
78-
case "":
79-
magicType = "ssh"
78+
case MagicSessionTypeSSH:
79+
case MagicSessionTypeUnknown:
8080
default:
81-
magicType = "unknown"
81+
magicType = MagicSessionTypeUnknown
8282
}
8383
// Always be case insensitive
84-
return strings.ToLower(magicType)
84+
return strings.ToLower(string(magicType))
8585
}

0 commit comments

Comments
 (0)