@@ -26,6 +26,7 @@ import (
26
26
"github.com/spf13/afero"
27
27
"go.uber.org/atomic"
28
28
gossh "golang.org/x/crypto/ssh"
29
+ "golang.org/x/exp/slices"
29
30
"golang.org/x/xerrors"
30
31
31
32
"cdr.dev/slog"
@@ -42,14 +43,6 @@ const (
42
43
// unlikely to shadow other exit codes, which are typically 1, 2, 3, etc.
43
44
MagicSessionErrorCode = 229
44
45
45
- // MagicSessionTypeEnvironmentVariable is used to track the purpose behind an SSH connection.
46
- // This is stripped from any commands being executed, and is counted towards connection stats.
47
- MagicSessionTypeEnvironmentVariable = "CODER_SSH_SESSION_TYPE"
48
- // MagicSessionTypeVSCode is set in the SSH config by the VS Code extension to identify itself.
49
- MagicSessionTypeVSCode = "vscode"
50
- // MagicSessionTypeJetBrains is set in the SSH config by the JetBrains
51
- // extension to identify itself.
52
- MagicSessionTypeJetBrains = "jetbrains"
53
46
// MagicProcessCmdlineJetBrains is a string in a process's command line that
54
47
// uniquely identifies it as JetBrains software.
55
48
MagicProcessCmdlineJetBrains = "idea.vendor.name=JetBrains"
@@ -60,6 +53,29 @@ const (
60
53
BlockedFileTransferErrorMessage = "File transfer has been disabled."
61
54
)
62
55
56
+ // MagicSessionType is a type that represents the type of session that is being
57
+ // established.
58
+ type MagicSessionType string
59
+
60
+ const (
61
+ // MagicSessionTypeEnvironmentVariable is used to track the purpose behind an SSH connection.
62
+ // This is stripped from any commands being executed, and is counted towards connection stats.
63
+ MagicSessionTypeEnvironmentVariable = "CODER_SSH_SESSION_TYPE"
64
+ )
65
+
66
+ // MagicSessionType enums.
67
+ const (
68
+ // MagicSessionTypeUnknown means the session type could not be determined.
69
+ MagicSessionTypeUnknown MagicSessionType = "unknown"
70
+ // MagicSessionTypeSSH is the default session type.
71
+ MagicSessionTypeSSH MagicSessionType = "ssh"
72
+ // MagicSessionTypeVSCode is set in the SSH config by the VS Code extension to identify itself.
73
+ MagicSessionTypeVSCode MagicSessionType = "vscode"
74
+ // MagicSessionTypeJetBrains is set in the SSH config by the JetBrains
75
+ // extension to identify itself.
76
+ MagicSessionTypeJetBrains MagicSessionType = "jetbrains"
77
+ )
78
+
63
79
// BlockedFileTransferCommands contains a list of restricted file transfer commands.
64
80
var BlockedFileTransferCommands = []string {"nc" , "rsync" , "scp" , "sftp" }
65
81
@@ -255,14 +271,42 @@ func (s *Server) ConnStats() ConnStats {
255
271
}
256
272
}
257
273
274
+ func extractMagicSessionType (env []string ) (magicType MagicSessionType , rawType string , filteredEnv []string ) {
275
+ for _ , kv := range env {
276
+ if ! strings .HasPrefix (kv , MagicSessionTypeEnvironmentVariable ) {
277
+ continue
278
+ }
279
+
280
+ rawType = strings .TrimPrefix (kv , MagicSessionTypeEnvironmentVariable + "=" )
281
+ // Keep going, we'll use the last instance of the env.
282
+ }
283
+
284
+ // Always force lowercase checking to be case-insensitive.
285
+ switch MagicSessionType (strings .ToLower (rawType )) {
286
+ case MagicSessionTypeVSCode :
287
+ magicType = MagicSessionTypeVSCode
288
+ case MagicSessionTypeJetBrains :
289
+ magicType = MagicSessionTypeJetBrains
290
+ case "" , MagicSessionTypeSSH :
291
+ magicType = MagicSessionTypeSSH
292
+ default :
293
+ magicType = MagicSessionTypeUnknown
294
+ }
295
+
296
+ return magicType , rawType , slices .DeleteFunc (env , func (kv string ) bool {
297
+ return strings .HasPrefix (kv , MagicSessionTypeEnvironmentVariable + "=" )
298
+ })
299
+ }
300
+
258
301
func (s * Server ) sessionHandler (session ssh.Session ) {
259
302
ctx := session .Context ()
303
+ id := uuid .New ()
260
304
logger := s .logger .With (
261
305
slog .F ("remote_addr" , session .RemoteAddr ()),
262
306
slog .F ("local_addr" , session .LocalAddr ()),
263
307
// Assigning a random uuid for each session is useful for tracking
264
308
// logs for the same ssh session.
265
- slog .F ("id" , uuid . NewString ()),
309
+ slog .F ("id" , id . String ()),
266
310
)
267
311
logger .Info (ctx , "handling ssh session" )
268
312
@@ -274,16 +318,21 @@ func (s *Server) sessionHandler(session ssh.Session) {
274
318
}
275
319
defer s .trackSession (session , false )
276
320
277
- extraEnv := make ([]string , 0 )
278
- x11 , hasX11 := session .X11 ()
279
- if hasX11 {
280
- display , handled := s .x11Handler (session .Context (), x11 )
281
- if ! handled {
282
- _ = session .Exit (1 )
283
- logger .Error (ctx , "x11 handler failed" )
284
- return
285
- }
286
- extraEnv = append (extraEnv , fmt .Sprintf ("DISPLAY=localhost:%d.%d" , display , x11 .ScreenNumber ))
321
+ env := session .Environ ()
322
+ magicType , magicTypeRaw , env := extractMagicSessionType (env )
323
+
324
+ switch magicType {
325
+ case MagicSessionTypeVSCode :
326
+ s .connCountVSCode .Add (1 )
327
+ defer s .connCountVSCode .Add (- 1 )
328
+ case MagicSessionTypeJetBrains :
329
+ // Do nothing here because JetBrains launches hundreds of ssh sessions.
330
+ // We instead track JetBrains in the single persistent tcp forwarding channel.
331
+ case MagicSessionTypeSSH :
332
+ s .connCountSSHSession .Add (1 )
333
+ defer s .connCountSSHSession .Add (- 1 )
334
+ case MagicSessionTypeUnknown :
335
+ logger .Warn (ctx , "invalid magic ssh session type specified" , slog .F ("raw_type" , magicTypeRaw ))
287
336
}
288
337
289
338
if s .fileTransferBlocked (session ) {
@@ -309,7 +358,18 @@ func (s *Server) sessionHandler(session ssh.Session) {
309
358
return
310
359
}
311
360
312
- err := s .sessionStart (logger , session , extraEnv )
361
+ x11 , hasX11 := session .X11 ()
362
+ if hasX11 {
363
+ display , handled := s .x11Handler (session .Context (), x11 )
364
+ if ! handled {
365
+ _ = session .Exit (1 )
366
+ logger .Error (ctx , "x11 handler failed" )
367
+ return
368
+ }
369
+ env = append (env , fmt .Sprintf ("DISPLAY=localhost:%d.%d" , display , x11 .ScreenNumber ))
370
+ }
371
+
372
+ err := s .sessionStart (logger , session , env , magicType )
313
373
var exitError * exec.ExitError
314
374
if xerrors .As (err , & exitError ) {
315
375
code := exitError .ExitCode ()
@@ -379,32 +439,8 @@ func (s *Server) fileTransferBlocked(session ssh.Session) bool {
379
439
return false
380
440
}
381
441
382
- func (s * Server ) sessionStart (logger slog.Logger , session ssh.Session , extraEnv []string ) (retErr error ) {
442
+ func (s * Server ) sessionStart (logger slog.Logger , session ssh.Session , env []string , magicType MagicSessionType ) (retErr error ) {
383
443
ctx := session .Context ()
384
- env := append (session .Environ (), extraEnv ... )
385
- var magicType string
386
- for index , kv := range env {
387
- if ! strings .HasPrefix (kv , MagicSessionTypeEnvironmentVariable ) {
388
- continue
389
- }
390
- magicType = strings .ToLower (strings .TrimPrefix (kv , MagicSessionTypeEnvironmentVariable + "=" ))
391
- env = append (env [:index ], env [index + 1 :]... )
392
- }
393
-
394
- // Always force lowercase checking to be case-insensitive.
395
- switch magicType {
396
- case MagicSessionTypeVSCode :
397
- s .connCountVSCode .Add (1 )
398
- defer s .connCountVSCode .Add (- 1 )
399
- case MagicSessionTypeJetBrains :
400
- // Do nothing here because JetBrains launches hundreds of ssh sessions.
401
- // We instead track JetBrains in the single persistent tcp forwarding channel.
402
- case "" :
403
- s .connCountSSHSession .Add (1 )
404
- defer s .connCountSSHSession .Add (- 1 )
405
- default :
406
- logger .Warn (ctx , "invalid magic ssh session type specified" , slog .F ("type" , magicType ))
407
- }
408
444
409
445
magicTypeLabel := magicTypeMetricLabel (magicType )
410
446
sshPty , windowSize , isPty := session .Pty ()
@@ -473,7 +509,7 @@ func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, mag
473
509
}()
474
510
go func () {
475
511
for sig := range sigs {
476
- s . handleSignal (logger , sig , cmd .Process , magicTypeLabel )
512
+ handleSignal (logger , sig , cmd .Process , s . metrics , magicTypeLabel )
477
513
}
478
514
}()
479
515
return cmd .Wait ()
@@ -558,7 +594,7 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy
558
594
sigs = nil
559
595
continue
560
596
}
561
- s . handleSignal (logger , sig , process , magicTypeLabel )
597
+ handleSignal (logger , sig , process , s . metrics , magicTypeLabel )
562
598
case win , ok := <- windowSize :
563
599
if ! ok {
564
600
windowSize = nil
@@ -612,15 +648,15 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy
612
648
return nil
613
649
}
614
650
615
- func ( s * Server ) handleSignal (logger slog.Logger , ssig ssh.Signal , signaler interface { Signal (os.Signal ) error }, magicTypeLabel string ) {
651
+ func handleSignal (logger slog.Logger , ssig ssh.Signal , signaler interface { Signal (os.Signal ) error }, metrics * sshServerMetrics , magicTypeLabel string ) {
616
652
ctx := context .Background ()
617
653
sig := osSignalFrom (ssig )
618
654
logger = logger .With (slog .F ("ssh_signal" , ssig ), slog .F ("signal" , sig .String ()))
619
655
logger .Info (ctx , "received signal from client" )
620
656
err := signaler .Signal (sig )
621
657
if err != nil {
622
658
logger .Warn (ctx , "signaling the process failed" , slog .Error (err ))
623
- s . metrics .sessionErrors .WithLabelValues (magicTypeLabel , "yes" , "signal" ).Add (1 )
659
+ metrics .sessionErrors .WithLabelValues (magicTypeLabel , "yes" , "signal" ).Add (1 )
624
660
}
625
661
}
626
662
0 commit comments