From c9a226f96c5f95784683aa1875ac7e2db1c9ee36 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 30 Nov 2023 16:08:03 -0600 Subject: [PATCH 01/12] feat: implement jetbrains agentssh tracking Based on tcp forwarding instead of ssh connections --- agent/agentssh/agentssh.go | 6 ++- agent/agentssh/jetbrainstrack.go | 63 ++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 1 deletion(-) create mode 100644 agent/agentssh/jetbrainstrack.go diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index f88446ecf30e5..3437fe423633e 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -111,7 +111,11 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom srv := &ssh.Server{ ChannelHandlers: map[string]ssh.ChannelHandler{ - "direct-tcpip": ssh.DirectTCPIPHandler, + "direct-tcpip": func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) { + // wrapper is designed to find and track jetbrains gateway connections. + wrapped := NewChannelAcceptWatcher(s.logger, newChan, &s.connCountJetBrains) + ssh.DirectTCPIPHandler(srv, conn, wrapped, ctx) + }, "direct-streamlocal@openssh.com": directStreamLocalHandler, "session": ssh.DefaultSessionHandler, }, diff --git a/agent/agentssh/jetbrainstrack.go b/agent/agentssh/jetbrainstrack.go new file mode 100644 index 0000000000000..d8254ad81993a --- /dev/null +++ b/agent/agentssh/jetbrainstrack.go @@ -0,0 +1,63 @@ +package agentssh + +import ( + "cdr.dev/slog" + "go.uber.org/atomic" + gossh "golang.org/x/crypto/ssh" +) + +type localForwardChannelData struct { + DestAddr string + DestPort uint32 + + OriginAddr string + OriginPort uint32 +} + +type ChannelAcceptWatcher struct { + gossh.NewChannel + jetbrainsCounter *atomic.Int64 +} + +func NewChannelAcceptWatcher(logger slog.Logger, newChannel gossh.NewChannel, counter *atomic.Int64) gossh.NewChannel { + d := localForwardChannelData{} + if err := gossh.Unmarshal(newChannel.ExtraData(), &d); err != nil { + // If the data fails to unmarshal, do nothing + return newChannel + } + + //if !jetbrains { + // If this isn't jetbrains, then we don't need to do anything special. + //return newChannel + //} + + return &ChannelAcceptWatcher{ + NewChannel: newChannel, + jetbrainsCounter: counter, + } +} + +func (w *ChannelAcceptWatcher) Accept() (gossh.Channel, <-chan *gossh.Request, error) { + c, r, err := w.NewChannel.Accept() + if err != nil { + return c, r, err + } + w.jetbrainsCounter.Add(1) + + return &ChannelOnClose{ + Channel: c, + done: func() { + w.jetbrainsCounter.Add(-1) + }, + }, r, err +} + +type ChannelOnClose struct { + gossh.Channel + done func() +} + +func (c *ChannelOnClose) Close() error { + c.done() + return c.Channel.Close() +} From 244797063de85c86dcf01333d824e87cecd68ce6 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 30 Nov 2023 16:20:55 -0600 Subject: [PATCH 02/12] Add unit test to confirm tracking --- agent/agent_test.go | 22 ++++++++++++++++++++++ agent/agentssh/agentssh.go | 4 ++-- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/agent/agent_test.go b/agent/agent_test.go index 31f1448f34018..42983af89aba8 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -191,6 +191,28 @@ func TestAgent_Stats_Magic(t *testing.T) { err = session.Wait() require.NoError(t, err) }) + + // This test name being "Jetbrains" is required to be a certain string. + // It must match the regex check in the agent for Jetbrains. + t.Run("Jetbrains", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + + rl, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer rl.Close() + tcpAddr, valid := rl.Addr().(*net.TCPAddr) + require.True(t, valid) + remotePort := tcpAddr.Port + go echoOnce(t, rl) + + sshClient := setupAgentSSHClient(ctx, t) + + conn, err := sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", remotePort)) + require.NoError(t, err) + defer conn.Close() + requireEcho(t, conn) + }) } func TestAgent_SessionExec(t *testing.T) { diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index 3437fe423633e..9d82dab2b2b3f 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -295,8 +295,8 @@ func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, extraEnv s.connCountVSCode.Add(1) defer s.connCountVSCode.Add(-1) case MagicSessionTypeJetBrains: - s.connCountJetBrains.Add(1) - defer s.connCountJetBrains.Add(-1) + // Do nothing here because jetbrains launches hundreds of ssh sessions. + // We instead track jetbrains in the single persistent tcp forwarding channel. case "": s.connCountSSHSession.Add(1) defer s.connCountSSHSession.Add(-1) From 76d3a24094f8146594d6cda68be1c72a43b42d8c Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 30 Nov 2023 16:39:42 -0600 Subject: [PATCH 03/12] implement unit test to verify jetbrains functionality --- agent/agent_test.go | 37 ++++++++++++++++++++++++++++---- agent/agentssh/jetbrainstrack.go | 7 +++++- 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/agent/agent_test.go b/agent/agent_test.go index 42983af89aba8..7300b89370b5e 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -206,12 +206,41 @@ func TestAgent_Stats_Magic(t *testing.T) { remotePort := tcpAddr.Port go echoOnce(t, rl) - sshClient := setupAgentSSHClient(ctx, t) + //nolint:dogsled + conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + sshClient, err := conn.SSHClient(ctx) + require.NoError(t, err) - conn, err := sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", remotePort)) + tunneledConn, err := sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", remotePort)) require.NoError(t, err) - defer conn.Close() - requireEcho(t, conn) + t.Cleanup(func() { + // always close on failure of test + _ = conn.Close() + _ = tunneledConn.Close() + }) + + var s *agentsdk.Stats + require.Eventuallyf(t, func() bool { + var ok bool + s, ok = <-stats + return ok && s.ConnectionCount > 0 && + s.SessionCountJetBrains == 1 + }, testutil.WaitLong, testutil.IntervalFast, + "never saw stats with conn open: %+v", s, + ) + + // Manually closing the connection + requireEcho(t, tunneledConn) + _ = rl.Close() + + require.Eventuallyf(t, func() bool { + var ok bool + s, ok = <-stats + return ok && s.ConnectionCount == 0 && + s.SessionCountJetBrains == 0 + }, testutil.WaitLong, testutil.IntervalFast, + "never saw stats after conn closes: %+v", s, + ) }) } diff --git a/agent/agentssh/jetbrainstrack.go b/agent/agentssh/jetbrainstrack.go index d8254ad81993a..4a996e0da683d 100644 --- a/agent/agentssh/jetbrainstrack.go +++ b/agent/agentssh/jetbrainstrack.go @@ -1,6 +1,8 @@ package agentssh import ( + "sync" + "cdr.dev/slog" "go.uber.org/atomic" gossh "golang.org/x/crypto/ssh" @@ -54,10 +56,13 @@ func (w *ChannelAcceptWatcher) Accept() (gossh.Channel, <-chan *gossh.Request, e type ChannelOnClose struct { gossh.Channel + // once ensures close only decrements the counter once. + // Because close can be called multiple times. + once sync.Once done func() } func (c *ChannelOnClose) Close() error { - c.done() + c.once.Do(c.done) return c.Channel.Close() } From adf2fb37a942cb7b1885ae692262103e79355517 Mon Sep 17 00:00:00 2001 From: Asher Date: Thu, 30 Nov 2023 14:58:13 -0900 Subject: [PATCH 04/12] Implement port process inspection --- agent/agent_test.go | 9 ++++-- agent/agentssh/agentssh.go | 2 +- agent/agentssh/jetbrainstrack.go | 32 ++++++++++++++++---- agent/agentssh/portinspection_supported.go | 31 +++++++++++++++++++ agent/agentssh/portinspection_unsupported.go | 9 ++++++ 5 files changed, 73 insertions(+), 10 deletions(-) create mode 100644 agent/agentssh/portinspection_supported.go create mode 100644 agent/agentssh/portinspection_unsupported.go diff --git a/agent/agent_test.go b/agent/agent_test.go index 7300b89370b5e..901427a74d0a6 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -192,10 +192,13 @@ func TestAgent_Stats_Magic(t *testing.T) { require.NoError(t, err) }) - // This test name being "Jetbrains" is required to be a certain string. - // It must match the regex check in the agent for Jetbrains. - t.Run("Jetbrains", func(t *testing.T) { + // This test name must contain the string checked for by the agent, since it + // looks for this string in the process name. + t.Run("TracksIdea.vendor.name=JetBrains", func(t *testing.T) { t.Parallel() + if runtime.GOOS != "linux" { + t.Skip("JetBrains tracking is only supported on Linux") + } ctx := testutil.Context(t, testutil.WaitLong) rl, err := net.Listen("tcp", "127.0.0.1:0") diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index 9d82dab2b2b3f..da115f89a2509 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -113,7 +113,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom ChannelHandlers: map[string]ssh.ChannelHandler{ "direct-tcpip": func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) { // wrapper is designed to find and track jetbrains gateway connections. - wrapped := NewChannelAcceptWatcher(s.logger, newChan, &s.connCountJetBrains) + wrapped := NewChannelAcceptWatcher(ctx, s.logger, newChan, &s.connCountJetBrains) ssh.DirectTCPIPHandler(srv, conn, wrapped, ctx) }, "direct-streamlocal@openssh.com": directStreamLocalHandler, diff --git a/agent/agentssh/jetbrainstrack.go b/agent/agentssh/jetbrainstrack.go index 4a996e0da683d..c171ac98e7e6c 100644 --- a/agent/agentssh/jetbrainstrack.go +++ b/agent/agentssh/jetbrainstrack.go @@ -1,9 +1,11 @@ package agentssh import ( + "strings" "sync" "cdr.dev/slog" + "github.com/gliderlabs/ssh" "go.uber.org/atomic" gossh "golang.org/x/crypto/ssh" ) @@ -21,17 +23,35 @@ type ChannelAcceptWatcher struct { jetbrainsCounter *atomic.Int64 } -func NewChannelAcceptWatcher(logger slog.Logger, newChannel gossh.NewChannel, counter *atomic.Int64) gossh.NewChannel { +func NewChannelAcceptWatcher(ctx ssh.Context, logger slog.Logger, newChannel gossh.NewChannel, counter *atomic.Int64) gossh.NewChannel { d := localForwardChannelData{} if err := gossh.Unmarshal(newChannel.ExtraData(), &d); err != nil { - // If the data fails to unmarshal, do nothing + // If the data fails to unmarshal, do nothing. return newChannel } - //if !jetbrains { - // If this isn't jetbrains, then we don't need to do anything special. - //return newChannel - //} + // If we do get a port, we should be able to get the matching PID and from + // there look up the invocation. + cmdline, err := getListeningPortProcessCmdline(d.DestPort) + if err != nil { + logger.Warn(ctx, "port inspection failed", + slog.F("destination_port", d.DestPort), + slog.Error(err)) + return newChannel + } + logger.Debug(ctx, "checking forwarded process", + slog.F("cmdline", cmdline), + slog.F("destination_port", d.DestPort)) + + // If this is not JetBrains, then we do not need to do anything special. We + // attempt to match on something that appears unique to JetBrains software and + // the vendor name flag seems like it might be a reasonable choice. + if !strings.Contains(strings.ToLower(cmdline), "idea.vendor.name=jetbrains") { + return newChannel + } + + logger.Debug(ctx, "discovered forwarded JetBrains process", + slog.F("destination_port", d.DestPort)) return &ChannelAcceptWatcher{ NewChannel: newChannel, diff --git a/agent/agentssh/portinspection_supported.go b/agent/agentssh/portinspection_supported.go new file mode 100644 index 0000000000000..45f59accc40f6 --- /dev/null +++ b/agent/agentssh/portinspection_supported.go @@ -0,0 +1,31 @@ +//go:build linux + +package agentssh + +import ( + "fmt" + "os" + + "github.com/cakturk/go-netstat/netstat" + "golang.org/x/xerrors" +) + +func getListeningPortProcessCmdline(port uint32) (string, error) { + tabs, err := netstat.TCPSocks(func(s *netstat.SockTabEntry) bool { + return s.LocalAddr != nil && uint32(s.LocalAddr.Port) == port + }) + if err != nil { + return "", xerrors.Errorf("inspect port %d: %w", port, err) + } + if len(tabs) == 0 { + return "", nil + } + // The process name provided by go-netstat does not include the full command + // line so grab that instead. + pid := tabs[0].Process.Pid + data, err := os.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pid)) + if err != nil { + return "", xerrors.Errorf("read /proc/%d/cmdline: %w", pid, err) + } + return string(data), nil +} diff --git a/agent/agentssh/portinspection_unsupported.go b/agent/agentssh/portinspection_unsupported.go new file mode 100644 index 0000000000000..f010d0385815f --- /dev/null +++ b/agent/agentssh/portinspection_unsupported.go @@ -0,0 +1,9 @@ +//go:build !linux + +package agentssh + +func getListeningPortProcessCmdline(port uint32) (string, error) { + // We are not worrying about other platforms at the moment because Gateway + // only supports Linux anyway. + return "", nil +} From ad034f245c53b15b5e357504c5c43a49617ed5cc Mon Sep 17 00:00:00 2001 From: Asher Date: Fri, 1 Dec 2023 10:18:28 -0900 Subject: [PATCH 05/12] Add JetBrains tracking to bottom bar --- .../DeploymentBanner/DeploymentBannerView.tsx | 16 +++++ site/src/components/Icons/JetBrainsIcon.tsx | 67 +++++++++++++++++++ 2 files changed, 83 insertions(+) create mode 100644 site/src/components/Icons/JetBrainsIcon.tsx diff --git a/site/src/components/Dashboard/DeploymentBanner/DeploymentBannerView.tsx b/site/src/components/Dashboard/DeploymentBanner/DeploymentBannerView.tsx index 16a01a15c3ac3..e6f3cffe19b15 100644 --- a/site/src/components/Dashboard/DeploymentBanner/DeploymentBannerView.tsx +++ b/site/src/components/Dashboard/DeploymentBanner/DeploymentBannerView.tsx @@ -15,6 +15,7 @@ import BuildingIcon from "@mui/icons-material/Build"; import Tooltip from "@mui/material/Tooltip"; import { Link as RouterLink } from "react-router-dom"; import Link from "@mui/material/Link"; +import { JetBrainsIcon } from "components/Icons/JetBrainsIcon"; import { VSCodeIcon } from "components/Icons/VSCodeIcon"; import DownloadIcon from "@mui/icons-material/CloudDownload"; import UploadIcon from "@mui/icons-material/CloudUpload"; @@ -263,6 +264,21 @@ export const DeploymentBannerView: FC = ({ + +
+ + {typeof stats?.session_count.jetbrains === "undefined" + ? "-" + : stats?.session_count.jetbrains} +
+
+
diff --git a/site/src/components/Icons/JetBrainsIcon.tsx b/site/src/components/Icons/JetBrainsIcon.tsx new file mode 100644 index 0000000000000..fb551a7e52f7a --- /dev/null +++ b/site/src/components/Icons/JetBrainsIcon.tsx @@ -0,0 +1,67 @@ +import SvgIcon, { SvgIconProps } from "@mui/material/SvgIcon"; + +export const JetBrainsIcon = (props: SvgIconProps) => ( + + + + + + + + + + + + + + + + + + + + + +); From 34b7c5eb6003cf3667f64b6a229a254052f1636c Mon Sep 17 00:00:00 2001 From: Asher Date: Fri, 1 Dec 2023 11:47:33 -0900 Subject: [PATCH 06/12] Elaborate on process name check comment Co-authored-by: Steven Masley --- agent/agent_test.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/agent/agent_test.go b/agent/agent_test.go index 901427a74d0a6..11922c5c6d1f3 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -194,6 +194,13 @@ func TestAgent_Stats_Magic(t *testing.T) { // This test name must contain the string checked for by the agent, since it // looks for this string in the process name. + // + // This test sets up a port forward that emulates what Jetbrains IDE's do when + // using gateway. The remote server side of the port forward is spun up using + // the gotest process, which includes the test name. + // So this unit test emulates a PID in the workspace with a similar + // name to the jetbrains IDE. That makes the agent this this SSH port + // forward is a "jetbrains" session. t.Run("TracksIdea.vendor.name=JetBrains", func(t *testing.T) { t.Parallel() if runtime.GOOS != "linux" { From dce56fd1a1cf66a6c8013ed17af73caef0ca2592 Mon Sep 17 00:00:00 2001 From: Asher Date: Fri, 1 Dec 2023 11:48:04 -0900 Subject: [PATCH 07/12] Comment that localForwardChannelData is copied Co-authored-by: Steven Masley --- agent/agentssh/jetbrainstrack.go | 1 + 1 file changed, 1 insertion(+) diff --git a/agent/agentssh/jetbrainstrack.go b/agent/agentssh/jetbrainstrack.go index c171ac98e7e6c..63b07413b2a69 100644 --- a/agent/agentssh/jetbrainstrack.go +++ b/agent/agentssh/jetbrainstrack.go @@ -10,6 +10,7 @@ import ( gossh "golang.org/x/crypto/ssh" ) +// localForwardChannelData is copied from the ssh package. type localForwardChannelData struct { DestAddr string DestPort uint32 From 7139448235a78112a2e6421aa43c97e1802de986 Mon Sep 17 00:00:00 2001 From: Asher Date: Fri, 1 Dec 2023 11:48:23 -0900 Subject: [PATCH 08/12] Comment ChannelAccepterWatcher Co-authored-by: Steven Masley --- agent/agentssh/jetbrainstrack.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/agent/agentssh/jetbrainstrack.go b/agent/agentssh/jetbrainstrack.go index 63b07413b2a69..9c1d27b087db5 100644 --- a/agent/agentssh/jetbrainstrack.go +++ b/agent/agentssh/jetbrainstrack.go @@ -19,6 +19,9 @@ type localForwardChannelData struct { OriginPort uint32 } +// ChannelAcceptWatcher is used to track jetbrains port forwarding (gateway) +// connections. If the port forward is something other than jetbrains, this +// struct is a noop. type ChannelAcceptWatcher struct { gossh.NewChannel jetbrainsCounter *atomic.Int64 From 6e8f235e38e0bf71021c35f4df08b094e42148ea Mon Sep 17 00:00:00 2001 From: Asher Date: Fri, 1 Dec 2023 11:51:53 -0900 Subject: [PATCH 09/12] Rename channel watcher to be specific to Jetbrains --- agent/agentssh/agentssh.go | 8 ++++---- agent/agentssh/jetbrainstrack.go | 14 +++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index da115f89a2509..74e93410cc4a9 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -112,8 +112,8 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom srv := &ssh.Server{ ChannelHandlers: map[string]ssh.ChannelHandler{ "direct-tcpip": func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) { - // wrapper is designed to find and track jetbrains gateway connections. - wrapped := NewChannelAcceptWatcher(ctx, s.logger, newChan, &s.connCountJetBrains) + // Wrapper is designed to find and track JetBrains Gateway connections. + wrapped := NewJetbrainsChannelWatcher(ctx, s.logger, newChan, &s.connCountJetBrains) ssh.DirectTCPIPHandler(srv, conn, wrapped, ctx) }, "direct-streamlocal@openssh.com": directStreamLocalHandler, @@ -295,8 +295,8 @@ func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, extraEnv s.connCountVSCode.Add(1) defer s.connCountVSCode.Add(-1) case MagicSessionTypeJetBrains: - // Do nothing here because jetbrains launches hundreds of ssh sessions. - // We instead track jetbrains in the single persistent tcp forwarding channel. + // Do nothing here because JetBrains launches hundreds of ssh sessions. + // We instead track JetBrains in the single persistent tcp forwarding channel. case "": s.connCountSSHSession.Add(1) defer s.connCountSSHSession.Add(-1) diff --git a/agent/agentssh/jetbrainstrack.go b/agent/agentssh/jetbrainstrack.go index 9c1d27b087db5..38bba1cc65465 100644 --- a/agent/agentssh/jetbrainstrack.go +++ b/agent/agentssh/jetbrainstrack.go @@ -19,15 +19,15 @@ type localForwardChannelData struct { OriginPort uint32 } -// ChannelAcceptWatcher is used to track jetbrains port forwarding (gateway) -// connections. If the port forward is something other than jetbrains, this -// struct is a noop. -type ChannelAcceptWatcher struct { +// JetbrainsChannelWatcher is used to track JetBrains port forwarded (Gateway) +// channels. If the port forward is something other than JetBrains, this struct +// is a noop. +type JetbrainsChannelWatcher struct { gossh.NewChannel jetbrainsCounter *atomic.Int64 } -func NewChannelAcceptWatcher(ctx ssh.Context, logger slog.Logger, newChannel gossh.NewChannel, counter *atomic.Int64) gossh.NewChannel { +func NewJetbrainsChannelWatcher(ctx ssh.Context, logger slog.Logger, newChannel gossh.NewChannel, counter *atomic.Int64) gossh.NewChannel { d := localForwardChannelData{} if err := gossh.Unmarshal(newChannel.ExtraData(), &d); err != nil { // If the data fails to unmarshal, do nothing. @@ -57,13 +57,13 @@ func NewChannelAcceptWatcher(ctx ssh.Context, logger slog.Logger, newChannel gos logger.Debug(ctx, "discovered forwarded JetBrains process", slog.F("destination_port", d.DestPort)) - return &ChannelAcceptWatcher{ + return &JetbrainsChannelWatcher{ NewChannel: newChannel, jetbrainsCounter: counter, } } -func (w *ChannelAcceptWatcher) Accept() (gossh.Channel, <-chan *gossh.Request, error) { +func (w *JetbrainsChannelWatcher) Accept() (gossh.Channel, <-chan *gossh.Request, error) { c, r, err := w.NewChannel.Accept() if err != nil { return c, r, err From 4d654782c420345f0f69febd73af87654688304b Mon Sep 17 00:00:00 2001 From: Asher Date: Fri, 1 Dec 2023 12:14:04 -0900 Subject: [PATCH 10/12] Log unmarshal failure --- agent/agentssh/jetbrainstrack.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/agent/agentssh/jetbrainstrack.go b/agent/agentssh/jetbrainstrack.go index 38bba1cc65465..761262ab9a1df 100644 --- a/agent/agentssh/jetbrainstrack.go +++ b/agent/agentssh/jetbrainstrack.go @@ -31,6 +31,7 @@ func NewJetbrainsChannelWatcher(ctx ssh.Context, logger slog.Logger, newChannel d := localForwardChannelData{} if err := gossh.Unmarshal(newChannel.ExtraData(), &d); err != nil { // If the data fails to unmarshal, do nothing. + logger.Warn(ctx, "failed to unmarshal port forward data", slog.Error(err)) return newChannel } @@ -38,14 +39,11 @@ func NewJetbrainsChannelWatcher(ctx ssh.Context, logger slog.Logger, newChannel // there look up the invocation. cmdline, err := getListeningPortProcessCmdline(d.DestPort) if err != nil { - logger.Warn(ctx, "port inspection failed", + logger.Warn(ctx, "failed to inspect port", slog.F("destination_port", d.DestPort), slog.Error(err)) return newChannel } - logger.Debug(ctx, "checking forwarded process", - slog.F("cmdline", cmdline), - slog.F("destination_port", d.DestPort)) // If this is not JetBrains, then we do not need to do anything special. We // attempt to match on something that appears unique to JetBrains software and From 254a5b6a6a0245d0c11d74af226804a8e0eee945 Mon Sep 17 00:00:00 2001 From: Asher Date: Fri, 1 Dec 2023 13:38:48 -0900 Subject: [PATCH 11/12] Add constant for JetBrains magic string --- agent/agentssh/agentssh.go | 6 +++++- agent/agentssh/jetbrainstrack.go | 5 ++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index 74e93410cc4a9..1021d04592629 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -47,8 +47,12 @@ const ( MagicSessionTypeEnvironmentVariable = "CODER_SSH_SESSION_TYPE" // MagicSessionTypeVSCode is set in the SSH config by the VS Code extension to identify itself. MagicSessionTypeVSCode = "vscode" - // MagicSessionTypeJetBrains is set in the SSH config by the JetBrains extension to identify itself. + // MagicSessionTypeJetBrains is set in the SSH config by the JetBrains + // extension to identify itself. MagicSessionTypeJetBrains = "jetbrains" + // MagicProcessCmdlineJetBrains is a string in a process's command line that + // uniquely identifies it as JetBrains software. + MagicProcessCmdlineJetBrains = "idea.vendor.name=JetBrains" ) type Server struct { diff --git a/agent/agentssh/jetbrainstrack.go b/agent/agentssh/jetbrainstrack.go index 761262ab9a1df..25c8f04dd6e78 100644 --- a/agent/agentssh/jetbrainstrack.go +++ b/agent/agentssh/jetbrainstrack.go @@ -46,9 +46,8 @@ func NewJetbrainsChannelWatcher(ctx ssh.Context, logger slog.Logger, newChannel } // If this is not JetBrains, then we do not need to do anything special. We - // attempt to match on something that appears unique to JetBrains software and - // the vendor name flag seems like it might be a reasonable choice. - if !strings.Contains(strings.ToLower(cmdline), "idea.vendor.name=jetbrains") { + // attempt to match on something that appears unique to JetBrains software. + if !strings.Contains(strings.ToLower(cmdline), strings.ToLower(MagicProcessCmdlineJetBrains)) { return newChannel } From a75ed6c26930375588837dbefcc4d27e3eddf423 Mon Sep 17 00:00:00 2001 From: Asher Date: Fri, 1 Dec 2023 12:27:08 -0900 Subject: [PATCH 12/12] Fix JetBrains tracking test The test name only shows up in the process name if you are running that test directly so we have to spawn a separate process instead. --- agent/agent_test.go | 51 +++++++++++++++++++++++--------------- scripts/echoserver/main.go | 50 +++++++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 20 deletions(-) create mode 100644 scripts/echoserver/main.go diff --git a/agent/agent_test.go b/agent/agent_test.go index 11922c5c6d1f3..225270f67bf2a 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -1,6 +1,7 @@ package agent_test import ( + "bufio" "bytes" "context" "encoding/json" @@ -152,7 +153,7 @@ func TestAgent_Stats_Magic(t *testing.T) { require.NoError(t, err) require.Equal(t, expected, strings.TrimSpace(string(output))) }) - t.Run("Tracks", func(t *testing.T) { + t.Run("TracksVSCode", func(t *testing.T) { t.Parallel() if runtime.GOOS == "window" { t.Skip("Sleeping for infinity doesn't work on Windows") @@ -192,36 +193,45 @@ func TestAgent_Stats_Magic(t *testing.T) { require.NoError(t, err) }) - // This test name must contain the string checked for by the agent, since it - // looks for this string in the process name. - // - // This test sets up a port forward that emulates what Jetbrains IDE's do when - // using gateway. The remote server side of the port forward is spun up using - // the gotest process, which includes the test name. - // So this unit test emulates a PID in the workspace with a similar - // name to the jetbrains IDE. That makes the agent this this SSH port - // forward is a "jetbrains" session. - t.Run("TracksIdea.vendor.name=JetBrains", func(t *testing.T) { + t.Run("TracksJetBrains", func(t *testing.T) { t.Parallel() if runtime.GOOS != "linux" { t.Skip("JetBrains tracking is only supported on Linux") } + ctx := testutil.Context(t, testutil.WaitLong) - rl, err := net.Listen("tcp", "127.0.0.1:0") + // JetBrains tracking works by looking at the process name listening on the + // forwarded port. If the process's command line includes the magic string + // we are looking for, then we assume it is a JetBrains editor. So when we + // connect to the port we must ensure the process includes that magic string + // to fool the agent into thinking this is JetBrains. To do this we need to + // spawn an external process (in this case a simple echo server) so we can + // control the process name. The -D here is just to mimic how Java options + // are set but is not necessary as the agent looks only for the magic + // string itself anywhere in the command. + _, b, _, ok := runtime.Caller(0) + require.True(t, ok) + dir := filepath.Join(filepath.Dir(b), "../scripts/echoserver/main.go") + echoServerCmd := exec.Command("go", "run", dir, + "-D", agentssh.MagicProcessCmdlineJetBrains) + stdout, err := echoServerCmd.StdoutPipe() require.NoError(t, err) - defer rl.Close() - tcpAddr, valid := rl.Addr().(*net.TCPAddr) - require.True(t, valid) - remotePort := tcpAddr.Port - go echoOnce(t, rl) + err = echoServerCmd.Start() + require.NoError(t, err) + defer echoServerCmd.Process.Kill() + + // The echo server prints its port as the first line. + sc := bufio.NewScanner(stdout) + sc.Scan() + remotePort := sc.Text() //nolint:dogsled conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) - tunneledConn, err := sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", remotePort)) + tunneledConn, err := sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%s", remotePort)) require.NoError(t, err) t.Cleanup(func() { // always close on failure of test @@ -239,9 +249,10 @@ func TestAgent_Stats_Magic(t *testing.T) { "never saw stats with conn open: %+v", s, ) - // Manually closing the connection + // Kill the server and connection after checking for the echo. requireEcho(t, tunneledConn) - _ = rl.Close() + _ = echoServerCmd.Process.Kill() + _ = tunneledConn.Close() require.Eventuallyf(t, func() bool { var ok bool diff --git a/scripts/echoserver/main.go b/scripts/echoserver/main.go new file mode 100644 index 0000000000000..cb30a0b3839df --- /dev/null +++ b/scripts/echoserver/main.go @@ -0,0 +1,50 @@ +package main + +// A simple echo server. It listens on a random port, prints that port, then +// echos back anything sent to it. + +import ( + "errors" + "fmt" + "io" + "log" + "net" +) + +func main() { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + log.Fatalf("listen error: err=%s", err) + } + + defer l.Close() + tcpAddr, valid := l.Addr().(*net.TCPAddr) + if !valid { + log.Fatal("address is not valid") + } + + remotePort := tcpAddr.Port + _, err = fmt.Println(remotePort) + if err != nil { + log.Fatalf("print error: err=%s", err) + } + + for { + conn, err := l.Accept() + if err != nil { + log.Fatalf("accept error, err=%s", err) + return + } + + go func() { + defer conn.Close() + _, err := io.Copy(conn, conn) + + if errors.Is(err, io.EOF) { + return + } else if err != nil { + log.Fatalf("copy error, err=%s", err) + } + }() + } +}