Skip to content
Merged
74 changes: 73 additions & 1 deletion agent/agent_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package agent_test

import (
"bufio"
"bytes"
"context"
"encoding/json"
Expand Down Expand Up @@ -152,7 +153,7 @@ func TestAgent_Stats_Magic(t *testing.T) {
require.NoError(t, err)
require.Equal(t, expected, strings.TrimSpace(string(output)))
})
t.Run("Tracks", func(t *testing.T) {
t.Run("TracksVSCode", func(t *testing.T) {
t.Parallel()
if runtime.GOOS == "window" {
t.Skip("Sleeping for infinity doesn't work on Windows")
Expand Down Expand Up @@ -191,6 +192,77 @@ func TestAgent_Stats_Magic(t *testing.T) {
err = session.Wait()
require.NoError(t, err)
})

t.Run("TracksJetBrains", func(t *testing.T) {
t.Parallel()
if runtime.GOOS != "linux" {
t.Skip("JetBrains tracking is only supported on Linux")
}

ctx := testutil.Context(t, testutil.WaitLong)

// JetBrains tracking works by looking at the process name listening on the
// forwarded port. If the process's command line includes the magic string
// we are looking for, then we assume it is a JetBrains editor. So when we
// connect to the port we must ensure the process includes that magic string
// to fool the agent into thinking this is JetBrains. To do this we need to
// spawn an external process (in this case a simple echo server) so we can
// control the process name. The -D here is just to mimic how Java options
// are set but is not necessary as the agent looks only for the magic
// string itself anywhere in the command.
_, b, _, ok := runtime.Caller(0)
require.True(t, ok)
dir := filepath.Join(filepath.Dir(b), "../scripts/echoserver/main.go")
echoServerCmd := exec.Command("go", "run", dir,
"-D", agentssh.MagicProcessCmdlineJetBrains)
stdout, err := echoServerCmd.StdoutPipe()
require.NoError(t, err)
err = echoServerCmd.Start()
require.NoError(t, err)
defer echoServerCmd.Process.Kill()

// The echo server prints its port as the first line.
sc := bufio.NewScanner(stdout)
sc.Scan()
remotePort := sc.Text()

//nolint:dogsled
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)

tunneledConn, err := sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%s", remotePort))
require.NoError(t, err)
t.Cleanup(func() {
// always close on failure of test
_ = conn.Close()
_ = tunneledConn.Close()
})

var s *agentsdk.Stats
require.Eventuallyf(t, func() bool {
var ok bool
s, ok = <-stats
return ok && s.ConnectionCount > 0 &&
s.SessionCountJetBrains == 1
}, testutil.WaitLong, testutil.IntervalFast,
"never saw stats with conn open: %+v", s,
)

// Kill the server and connection after checking for the echo.
requireEcho(t, tunneledConn)
_ = echoServerCmd.Process.Kill()
_ = tunneledConn.Close()

require.Eventuallyf(t, func() bool {
var ok bool
s, ok = <-stats
return ok && s.ConnectionCount == 0 &&
s.SessionCountJetBrains == 0
}, testutil.WaitLong, testutil.IntervalFast,
"never saw stats after conn closes: %+v", s,
)
})
}

func TestAgent_SessionExec(t *testing.T) {
Expand Down
16 changes: 12 additions & 4 deletions agent/agentssh/agentssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,12 @@ const (
MagicSessionTypeEnvironmentVariable = "CODER_SSH_SESSION_TYPE"
// MagicSessionTypeVSCode is set in the SSH config by the VS Code extension to identify itself.
MagicSessionTypeVSCode = "vscode"
// MagicSessionTypeJetBrains is set in the SSH config by the JetBrains extension to identify itself.
// MagicSessionTypeJetBrains is set in the SSH config by the JetBrains
// extension to identify itself.
MagicSessionTypeJetBrains = "jetbrains"
// MagicProcessCmdlineJetBrains is a string in a process's command line that
// uniquely identifies it as JetBrains software.
MagicProcessCmdlineJetBrains = "idea.vendor.name=JetBrains"
)

type Server struct {
Expand Down Expand Up @@ -111,7 +115,11 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom

srv := &ssh.Server{
ChannelHandlers: map[string]ssh.ChannelHandler{
"direct-tcpip": ssh.DirectTCPIPHandler,
"direct-tcpip": func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
// Wrapper is designed to find and track JetBrains Gateway connections.
wrapped := NewJetbrainsChannelWatcher(ctx, s.logger, newChan, &s.connCountJetBrains)
ssh.DirectTCPIPHandler(srv, conn, wrapped, ctx)
},
"direct-streamlocal@openssh.com": directStreamLocalHandler,
"session": ssh.DefaultSessionHandler,
},
Expand Down Expand Up @@ -291,8 +299,8 @@ func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, extraEnv
s.connCountVSCode.Add(1)
defer s.connCountVSCode.Add(-1)
case MagicSessionTypeJetBrains:
s.connCountJetBrains.Add(1)
defer s.connCountJetBrains.Add(-1)
// Do nothing here because JetBrains launches hundreds of ssh sessions.
// We instead track JetBrains in the single persistent tcp forwarding channel.
case "":
s.connCountSSHSession.Add(1)
defer s.connCountSSHSession.Add(-1)
Expand Down
89 changes: 89 additions & 0 deletions agent/agentssh/jetbrainstrack.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package agentssh

import (
"strings"
"sync"

"cdr.dev/slog"
"github.com/gliderlabs/ssh"
"go.uber.org/atomic"
gossh "golang.org/x/crypto/ssh"
)

// localForwardChannelData is copied from the ssh package.
type localForwardChannelData struct {
DestAddr string
DestPort uint32

OriginAddr string
OriginPort uint32
}

// JetbrainsChannelWatcher is used to track JetBrains port forwarded (Gateway)
// channels. If the port forward is something other than JetBrains, this struct
// is a noop.
type JetbrainsChannelWatcher struct {
gossh.NewChannel
jetbrainsCounter *atomic.Int64
}

func NewJetbrainsChannelWatcher(ctx ssh.Context, logger slog.Logger, newChannel gossh.NewChannel, counter *atomic.Int64) gossh.NewChannel {
d := localForwardChannelData{}
if err := gossh.Unmarshal(newChannel.ExtraData(), &d); err != nil {
// If the data fails to unmarshal, do nothing.
logger.Warn(ctx, "failed to unmarshal port forward data", slog.Error(err))
return newChannel
}

// If we do get a port, we should be able to get the matching PID and from
// there look up the invocation.
cmdline, err := getListeningPortProcessCmdline(d.DestPort)
if err != nil {
logger.Warn(ctx, "failed to inspect port",
slog.F("destination_port", d.DestPort),
slog.Error(err))
return newChannel
}

// If this is not JetBrains, then we do not need to do anything special. We
// attempt to match on something that appears unique to JetBrains software.
if !strings.Contains(strings.ToLower(cmdline), strings.ToLower(MagicProcessCmdlineJetBrains)) {
return newChannel
}

logger.Debug(ctx, "discovered forwarded JetBrains process",
slog.F("destination_port", d.DestPort))

return &JetbrainsChannelWatcher{
NewChannel: newChannel,
jetbrainsCounter: counter,
}
}

func (w *JetbrainsChannelWatcher) Accept() (gossh.Channel, <-chan *gossh.Request, error) {
c, r, err := w.NewChannel.Accept()
if err != nil {
return c, r, err
}
w.jetbrainsCounter.Add(1)

return &ChannelOnClose{
Channel: c,
done: func() {
w.jetbrainsCounter.Add(-1)
},
}, r, err
}

type ChannelOnClose struct {
gossh.Channel
// once ensures close only decrements the counter once.
// Because close can be called multiple times.
once sync.Once
done func()
}

func (c *ChannelOnClose) Close() error {
c.once.Do(c.done)
return c.Channel.Close()
}
31 changes: 31 additions & 0 deletions agent/agentssh/portinspection_supported.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//go:build linux

package agentssh

import (
"fmt"
"os"

"github.com/cakturk/go-netstat/netstat"
"golang.org/x/xerrors"
)

func getListeningPortProcessCmdline(port uint32) (string, error) {
tabs, err := netstat.TCPSocks(func(s *netstat.SockTabEntry) bool {
return s.LocalAddr != nil && uint32(s.LocalAddr.Port) == port
})
if err != nil {
return "", xerrors.Errorf("inspect port %d: %w", port, err)
}
if len(tabs) == 0 {
return "", nil
}
// The process name provided by go-netstat does not include the full command
// line so grab that instead.
pid := tabs[0].Process.Pid
data, err := os.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pid))
if err != nil {
return "", xerrors.Errorf("read /proc/%d/cmdline: %w", pid, err)
}
return string(data), nil
}
9 changes: 9 additions & 0 deletions agent/agentssh/portinspection_unsupported.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
//go:build !linux

package agentssh

func getListeningPortProcessCmdline(port uint32) (string, error) {
// We are not worrying about other platforms at the moment because Gateway
// only supports Linux anyway.
return "", nil
}
50 changes: 50 additions & 0 deletions scripts/echoserver/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package main

// A simple echo server. It listens on a random port, prints that port, then
// echos back anything sent to it.

import (
"errors"
"fmt"
"io"
"log"
"net"
)

func main() {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
log.Fatalf("listen error: err=%s", err)
}

defer l.Close()
tcpAddr, valid := l.Addr().(*net.TCPAddr)
if !valid {
log.Fatal("address is not valid")
}

remotePort := tcpAddr.Port
_, err = fmt.Println(remotePort)
if err != nil {
log.Fatalf("print error: err=%s", err)
}

for {
conn, err := l.Accept()
if err != nil {
log.Fatalf("accept error, err=%s", err)
return
}

go func() {
defer conn.Close()
_, err := io.Copy(conn, conn)

if errors.Is(err, io.EOF) {
return
} else if err != nil {
log.Fatalf("copy error, err=%s", err)
}
}()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import BuildingIcon from "@mui/icons-material/Build";
import Tooltip from "@mui/material/Tooltip";
import { Link as RouterLink } from "react-router-dom";
import Link from "@mui/material/Link";
import { JetBrainsIcon } from "components/Icons/JetBrainsIcon";
import { VSCodeIcon } from "components/Icons/VSCodeIcon";
import DownloadIcon from "@mui/icons-material/CloudDownload";
import UploadIcon from "@mui/icons-material/CloudUpload";
Expand Down Expand Up @@ -263,6 +264,21 @@ export const DeploymentBannerView: FC<DeploymentBannerViewProps> = ({
</div>
</Tooltip>
<ValueSeparator />
<Tooltip title="JetBrains Editors">
<div css={styles.value}>
<JetBrainsIcon
css={css`
& * {
fill: currentColor;
}
`}
/>
{typeof stats?.session_count.jetbrains === "undefined"
? "-"
: stats?.session_count.jetbrains}
</div>
</Tooltip>
<ValueSeparator />
<Tooltip title="SSH Sessions">
<div css={styles.value}>
<TerminalIcon />
Expand Down
Loading