Skip to content

Commit dbbf8ac

Browse files
authored
fix: track JetBrains connections (coder#10968)
* feat: implement jetbrains agentssh tracking Based on tcp forwarding instead of ssh connections * Add JetBrains tracking to bottom bar
1 parent 51687c7 commit dbbf8ac

File tree

8 files changed

+347
-5
lines changed

8 files changed

+347
-5
lines changed

agent/agent_test.go

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package agent_test
22

33
import (
4+
"bufio"
45
"bytes"
56
"context"
67
"encoding/json"
@@ -152,7 +153,7 @@ func TestAgent_Stats_Magic(t *testing.T) {
152153
require.NoError(t, err)
153154
require.Equal(t, expected, strings.TrimSpace(string(output)))
154155
})
155-
t.Run("Tracks", func(t *testing.T) {
156+
t.Run("TracksVSCode", func(t *testing.T) {
156157
t.Parallel()
157158
if runtime.GOOS == "window" {
158159
t.Skip("Sleeping for infinity doesn't work on Windows")
@@ -191,6 +192,77 @@ func TestAgent_Stats_Magic(t *testing.T) {
191192
err = session.Wait()
192193
require.NoError(t, err)
193194
})
195+
196+
t.Run("TracksJetBrains", func(t *testing.T) {
197+
t.Parallel()
198+
if runtime.GOOS != "linux" {
199+
t.Skip("JetBrains tracking is only supported on Linux")
200+
}
201+
202+
ctx := testutil.Context(t, testutil.WaitLong)
203+
204+
// JetBrains tracking works by looking at the process name listening on the
205+
// forwarded port. If the process's command line includes the magic string
206+
// we are looking for, then we assume it is a JetBrains editor. So when we
207+
// connect to the port we must ensure the process includes that magic string
208+
// to fool the agent into thinking this is JetBrains. To do this we need to
209+
// spawn an external process (in this case a simple echo server) so we can
210+
// control the process name. The -D here is just to mimic how Java options
211+
// are set but is not necessary as the agent looks only for the magic
212+
// string itself anywhere in the command.
213+
_, b, _, ok := runtime.Caller(0)
214+
require.True(t, ok)
215+
dir := filepath.Join(filepath.Dir(b), "../scripts/echoserver/main.go")
216+
echoServerCmd := exec.Command("go", "run", dir,
217+
"-D", agentssh.MagicProcessCmdlineJetBrains)
218+
stdout, err := echoServerCmd.StdoutPipe()
219+
require.NoError(t, err)
220+
err = echoServerCmd.Start()
221+
require.NoError(t, err)
222+
defer echoServerCmd.Process.Kill()
223+
224+
// The echo server prints its port as the first line.
225+
sc := bufio.NewScanner(stdout)
226+
sc.Scan()
227+
remotePort := sc.Text()
228+
229+
//nolint:dogsled
230+
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
231+
sshClient, err := conn.SSHClient(ctx)
232+
require.NoError(t, err)
233+
234+
tunneledConn, err := sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%s", remotePort))
235+
require.NoError(t, err)
236+
t.Cleanup(func() {
237+
// always close on failure of test
238+
_ = conn.Close()
239+
_ = tunneledConn.Close()
240+
})
241+
242+
var s *agentsdk.Stats
243+
require.Eventuallyf(t, func() bool {
244+
var ok bool
245+
s, ok = <-stats
246+
return ok && s.ConnectionCount > 0 &&
247+
s.SessionCountJetBrains == 1
248+
}, testutil.WaitLong, testutil.IntervalFast,
249+
"never saw stats with conn open: %+v", s,
250+
)
251+
252+
// Kill the server and connection after checking for the echo.
253+
requireEcho(t, tunneledConn)
254+
_ = echoServerCmd.Process.Kill()
255+
_ = tunneledConn.Close()
256+
257+
require.Eventuallyf(t, func() bool {
258+
var ok bool
259+
s, ok = <-stats
260+
return ok && s.ConnectionCount == 0 &&
261+
s.SessionCountJetBrains == 0
262+
}, testutil.WaitLong, testutil.IntervalFast,
263+
"never saw stats after conn closes: %+v", s,
264+
)
265+
})
194266
}
195267

196268
func TestAgent_SessionExec(t *testing.T) {

agent/agentssh/agentssh.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,12 @@ const (
4747
MagicSessionTypeEnvironmentVariable = "CODER_SSH_SESSION_TYPE"
4848
// MagicSessionTypeVSCode is set in the SSH config by the VS Code extension to identify itself.
4949
MagicSessionTypeVSCode = "vscode"
50-
// MagicSessionTypeJetBrains is set in the SSH config by the JetBrains extension to identify itself.
50+
// MagicSessionTypeJetBrains is set in the SSH config by the JetBrains
51+
// extension to identify itself.
5152
MagicSessionTypeJetBrains = "jetbrains"
53+
// MagicProcessCmdlineJetBrains is a string in a process's command line that
54+
// uniquely identifies it as JetBrains software.
55+
MagicProcessCmdlineJetBrains = "idea.vendor.name=JetBrains"
5256
)
5357

5458
type Server struct {
@@ -111,7 +115,11 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
111115

112116
srv := &ssh.Server{
113117
ChannelHandlers: map[string]ssh.ChannelHandler{
114-
"direct-tcpip": ssh.DirectTCPIPHandler,
118+
"direct-tcpip": func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
119+
// Wrapper is designed to find and track JetBrains Gateway connections.
120+
wrapped := NewJetbrainsChannelWatcher(ctx, s.logger, newChan, &s.connCountJetBrains)
121+
ssh.DirectTCPIPHandler(srv, conn, wrapped, ctx)
122+
},
115123
"direct-streamlocal@openssh.com": directStreamLocalHandler,
116124
"session": ssh.DefaultSessionHandler,
117125
},
@@ -291,8 +299,8 @@ func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, extraEnv
291299
s.connCountVSCode.Add(1)
292300
defer s.connCountVSCode.Add(-1)
293301
case MagicSessionTypeJetBrains:
294-
s.connCountJetBrains.Add(1)
295-
defer s.connCountJetBrains.Add(-1)
302+
// Do nothing here because JetBrains launches hundreds of ssh sessions.
303+
// We instead track JetBrains in the single persistent tcp forwarding channel.
296304
case "":
297305
s.connCountSSHSession.Add(1)
298306
defer s.connCountSSHSession.Add(-1)

agent/agentssh/jetbrainstrack.go

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
package agentssh
2+
3+
import (
4+
"strings"
5+
"sync"
6+
7+
"cdr.dev/slog"
8+
"github.com/gliderlabs/ssh"
9+
"go.uber.org/atomic"
10+
gossh "golang.org/x/crypto/ssh"
11+
)
12+
13+
// localForwardChannelData is copied from the ssh package.
14+
type localForwardChannelData struct {
15+
DestAddr string
16+
DestPort uint32
17+
18+
OriginAddr string
19+
OriginPort uint32
20+
}
21+
22+
// JetbrainsChannelWatcher is used to track JetBrains port forwarded (Gateway)
23+
// channels. If the port forward is something other than JetBrains, this struct
24+
// is a noop.
25+
type JetbrainsChannelWatcher struct {
26+
gossh.NewChannel
27+
jetbrainsCounter *atomic.Int64
28+
}
29+
30+
func NewJetbrainsChannelWatcher(ctx ssh.Context, logger slog.Logger, newChannel gossh.NewChannel, counter *atomic.Int64) gossh.NewChannel {
31+
d := localForwardChannelData{}
32+
if err := gossh.Unmarshal(newChannel.ExtraData(), &d); err != nil {
33+
// If the data fails to unmarshal, do nothing.
34+
logger.Warn(ctx, "failed to unmarshal port forward data", slog.Error(err))
35+
return newChannel
36+
}
37+
38+
// If we do get a port, we should be able to get the matching PID and from
39+
// there look up the invocation.
40+
cmdline, err := getListeningPortProcessCmdline(d.DestPort)
41+
if err != nil {
42+
logger.Warn(ctx, "failed to inspect port",
43+
slog.F("destination_port", d.DestPort),
44+
slog.Error(err))
45+
return newChannel
46+
}
47+
48+
// If this is not JetBrains, then we do not need to do anything special. We
49+
// attempt to match on something that appears unique to JetBrains software.
50+
if !strings.Contains(strings.ToLower(cmdline), strings.ToLower(MagicProcessCmdlineJetBrains)) {
51+
return newChannel
52+
}
53+
54+
logger.Debug(ctx, "discovered forwarded JetBrains process",
55+
slog.F("destination_port", d.DestPort))
56+
57+
return &JetbrainsChannelWatcher{
58+
NewChannel: newChannel,
59+
jetbrainsCounter: counter,
60+
}
61+
}
62+
63+
func (w *JetbrainsChannelWatcher) Accept() (gossh.Channel, <-chan *gossh.Request, error) {
64+
c, r, err := w.NewChannel.Accept()
65+
if err != nil {
66+
return c, r, err
67+
}
68+
w.jetbrainsCounter.Add(1)
69+
70+
return &ChannelOnClose{
71+
Channel: c,
72+
done: func() {
73+
w.jetbrainsCounter.Add(-1)
74+
},
75+
}, r, err
76+
}
77+
78+
type ChannelOnClose struct {
79+
gossh.Channel
80+
// once ensures close only decrements the counter once.
81+
// Because close can be called multiple times.
82+
once sync.Once
83+
done func()
84+
}
85+
86+
func (c *ChannelOnClose) Close() error {
87+
c.once.Do(c.done)
88+
return c.Channel.Close()
89+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//go:build linux
2+
3+
package agentssh
4+
5+
import (
6+
"fmt"
7+
"os"
8+
9+
"github.com/cakturk/go-netstat/netstat"
10+
"golang.org/x/xerrors"
11+
)
12+
13+
func getListeningPortProcessCmdline(port uint32) (string, error) {
14+
tabs, err := netstat.TCPSocks(func(s *netstat.SockTabEntry) bool {
15+
return s.LocalAddr != nil && uint32(s.LocalAddr.Port) == port
16+
})
17+
if err != nil {
18+
return "", xerrors.Errorf("inspect port %d: %w", port, err)
19+
}
20+
if len(tabs) == 0 {
21+
return "", nil
22+
}
23+
// The process name provided by go-netstat does not include the full command
24+
// line so grab that instead.
25+
pid := tabs[0].Process.Pid
26+
data, err := os.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pid))
27+
if err != nil {
28+
return "", xerrors.Errorf("read /proc/%d/cmdline: %w", pid, err)
29+
}
30+
return string(data), nil
31+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//go:build !linux
2+
3+
package agentssh
4+
5+
func getListeningPortProcessCmdline(port uint32) (string, error) {
6+
// We are not worrying about other platforms at the moment because Gateway
7+
// only supports Linux anyway.
8+
return "", nil
9+
}

scripts/echoserver/main.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package main
2+
3+
// A simple echo server. It listens on a random port, prints that port, then
4+
// echos back anything sent to it.
5+
6+
import (
7+
"errors"
8+
"fmt"
9+
"io"
10+
"log"
11+
"net"
12+
)
13+
14+
func main() {
15+
l, err := net.Listen("tcp", "127.0.0.1:0")
16+
if err != nil {
17+
log.Fatalf("listen error: err=%s", err)
18+
}
19+
20+
defer l.Close()
21+
tcpAddr, valid := l.Addr().(*net.TCPAddr)
22+
if !valid {
23+
log.Fatal("address is not valid")
24+
}
25+
26+
remotePort := tcpAddr.Port
27+
_, err = fmt.Println(remotePort)
28+
if err != nil {
29+
log.Fatalf("print error: err=%s", err)
30+
}
31+
32+
for {
33+
conn, err := l.Accept()
34+
if err != nil {
35+
log.Fatalf("accept error, err=%s", err)
36+
return
37+
}
38+
39+
go func() {
40+
defer conn.Close()
41+
_, err := io.Copy(conn, conn)
42+
43+
if errors.Is(err, io.EOF) {
44+
return
45+
} else if err != nil {
46+
log.Fatalf("copy error, err=%s", err)
47+
}
48+
}()
49+
}
50+
}

site/src/components/Dashboard/DeploymentBanner/DeploymentBannerView.tsx

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import BuildingIcon from "@mui/icons-material/Build";
1515
import Tooltip from "@mui/material/Tooltip";
1616
import { Link as RouterLink } from "react-router-dom";
1717
import Link from "@mui/material/Link";
18+
import { JetBrainsIcon } from "components/Icons/JetBrainsIcon";
1819
import { VSCodeIcon } from "components/Icons/VSCodeIcon";
1920
import DownloadIcon from "@mui/icons-material/CloudDownload";
2021
import UploadIcon from "@mui/icons-material/CloudUpload";
@@ -248,6 +249,21 @@ export const DeploymentBannerView: FC<DeploymentBannerViewProps> = ({
248249
</div>
249250
</Tooltip>
250251
<ValueSeparator />
252+
<Tooltip title="JetBrains Editors">
253+
<div css={styles.value}>
254+
<JetBrainsIcon
255+
css={css`
256+
& * {
257+
fill: currentColor;
258+
}
259+
`}
260+
/>
261+
{typeof stats?.session_count.jetbrains === "undefined"
262+
? "-"
263+
: stats?.session_count.jetbrains}
264+
</div>
265+
</Tooltip>
266+
<ValueSeparator />
251267
<Tooltip title="SSH Sessions">
252268
<div css={styles.value}>
253269
<TerminalIcon />

0 commit comments

Comments
 (0)