From 58463c6c9fe207cb8999b877581d90232cbbae41 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Fri, 21 Feb 2025 13:22:04 +0000 Subject: [PATCH 1/7] feat(agent): add connection reporting for SSH and reconnecing PTY Updates #15139 --- agent/agent.go | 116 +++++++++++++++++++++++++++++++ agent/agent_test.go | 74 +++++++++++++++++--- agent/agentssh/agentssh.go | 87 ++++++++++++++++++++--- agent/agentssh/jetbrainstrack.go | 11 ++- agent/agenttest/client.go | 30 +++++--- agent/reconnectingpty/server.go | 26 ++++++- 6 files changed, 312 insertions(+), 32 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index 523892d3f65c9..f9b2754a4730c 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -27,6 +27,7 @@ import ( "golang.org/x/exp/slices" "golang.org/x/sync/errgroup" "golang.org/x/xerrors" + "google.golang.org/protobuf/types/known/timestamppb" "tailscale.com/net/speedtest" "tailscale.com/tailcfg" "tailscale.com/types/netlogtype" @@ -174,6 +175,7 @@ func New(options Options) Agent { lifecycleUpdate: make(chan struct{}, 1), lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1), lifecycleStates: []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}}, + reportConnectionsUpdate: make(chan struct{}, 1), ignorePorts: options.IgnorePorts, portCacheDuration: options.PortCacheDuration, reportMetadataInterval: options.ReportMetadataInterval, @@ -247,6 +249,10 @@ type agent struct { lifecycleStates []agentsdk.PostLifecycleRequest lifecycleLastReportedIndex int // Keeps track of the last lifecycle state we successfully reported. + reportConnectionsUpdate chan struct{} + reportConnectionsMu sync.Mutex + reportConnections []*proto.ReportConnectionRequest + network *tailnet.Conn statsReporter *statsReporter logSender *agentsdk.LogSender @@ -272,6 +278,24 @@ func (a *agent) init() { UpdateEnv: a.updateCommandEnv, WorkingDirectory: func() string { return a.manifest.Load().Directory }, BlockFileTransfer: a.blockFileTransfer, + ReportConnection: func(id uuid.UUID, magicType agentssh.MagicSessionType, ip string) func(code int, reason string) { + var connectionType proto.Connection_Type + switch magicType { + case agentssh.MagicSessionTypeSSH: + connectionType = proto.Connection_SSH + case agentssh.MagicSessionTypeVSCode: + connectionType = proto.Connection_VSCODE + case agentssh.MagicSessionTypeJetBrains: + connectionType = proto.Connection_JETBRAINS + case agentssh.MagicSessionTypeUnknown: + connectionType = proto.Connection_TYPE_UNSPECIFIED + default: + a.logger.Error(a.hardCtx, "unhandled magic session type when reporting connection", slog.F("magic_type", magicType)) + connectionType = proto.Connection_TYPE_UNSPECIFIED + } + + return a.reportConnection(id, connectionType, ip) + }, }) if err != nil { panic(err) @@ -294,6 +318,9 @@ func (a *agent) init() { a.reconnectingPTYServer = reconnectingpty.NewServer( a.logger.Named("reconnecting-pty"), a.sshServer, + func(id uuid.UUID, ip string) func(code int, reason string) { + return a.reportConnection(id, proto.Connection_RECONNECTING_PTY, ip) + }, a.metrics.connectionsTotal, a.metrics.reconnectingPTYErrors, a.reconnectingPTYTimeout, ) @@ -703,6 +730,91 @@ func (a *agent) setLifecycle(state codersdk.WorkspaceAgentLifecycle) { } } +// reportConnectionsLoop reports connections to the agent for auditing. +func (a *agent) reportConnectionsLoop(ctx context.Context, aAPI proto.DRPCAgentClient24) error { + for { + select { + case <-a.reportConnectionsUpdate: + case <-ctx.Done(): + return ctx.Err() + } + + for { + a.reportConnectionsMu.Lock() + if len(a.reportConnections) == 0 { + a.reportConnectionsMu.Unlock() + break + } + payload := a.reportConnections[0] + a.reportConnectionsMu.Unlock() + + logger := a.logger.With(slog.F("payload", payload)) + logger.Debug(ctx, "reporting connection") + _, err := aAPI.ReportConnection(ctx, payload) + if err != nil { + return xerrors.Errorf("failed to report connection: %w", err) + } + + logger.Debug(ctx, "successfully reported connection") + + a.reportConnectionsMu.Lock() + a.reportConnections = a.reportConnections[1:] + count := len(a.reportConnections) + a.reportConnectionsMu.Unlock() + + if count == 0 { + break + } + } + } +} + +func (a *agent) reportConnection(id uuid.UUID, connectionType proto.Connection_Type, ip string) (disconnected func(code int, reason string)) { + // Remove the port from the IP. + if portIndex := strings.LastIndex(ip, ":"); portIndex != -1 { + ip = ip[:portIndex] + ip = strings.Trim(ip, "[]") // IPv6 addresses are wrapped in brackets. + } + + a.reportConnectionsMu.Lock() + defer a.reportConnectionsMu.Unlock() + a.reportConnections = append(a.reportConnections, &proto.ReportConnectionRequest{ + Connection: &proto.Connection{ + Id: id[:], + Action: proto.Connection_CONNECT, + Type: connectionType, + Timestamp: timestamppb.New(time.Now()), + Ip: ip, + StatusCode: 0, + Reason: nil, + }, + }) + select { + case a.reportConnectionsUpdate <- struct{}{}: + default: + } + + return func(code int, reason string) { + a.reportConnectionsMu.Lock() + defer a.reportConnectionsMu.Unlock() + a.reportConnections = append(a.reportConnections, &proto.ReportConnectionRequest{ + Connection: &proto.Connection{ + Id: id[:], + Action: proto.Connection_DISCONNECT, + Type: connectionType, + Timestamp: timestamppb.New(time.Now()), + Ip: ip, + StatusCode: int32(code), //nolint:gosec + Reason: &reason, + }, + }) + select { + case a.reportConnectionsUpdate <- struct{}{}: + default: + } + } +} + // fetchServiceBannerLoop fetches the service banner on an interval. It will // not be fetched immediately; the expectation is that it is primed elsewhere // (and must be done before the session actually starts). @@ -813,6 +925,10 @@ func (a *agent) run() (retErr error) { return resourcesmonitor.Start(ctx) }) + // Connection reports are part of auditing, we should keep sending them via + // gracefulShutdownBehaviorRemain. + connMan.startAgentAPI("report connections", gracefulShutdownBehaviorRemain, a.reportConnectionsLoop) + // channels to sync goroutines below // handle manifest // | diff --git a/agent/agent_test.go b/agent/agent_test.go index 834e0a3e68151..f94152207ab34 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -159,7 +159,7 @@ func TestAgent_Stats_Magic(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() //nolint:dogsled - conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + conn, agentClient, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() @@ -189,6 +189,8 @@ func TestAgent_Stats_Magic(t *testing.T) { _ = stdin.Close() err = session.Wait() require.NoError(t, err) + + assertConnectionReport(t, agentClient, proto.Connection_VSCODE, 0, "") }) t.Run("TracksJetBrains", func(t *testing.T) { @@ -225,7 +227,7 @@ func TestAgent_Stats_Magic(t *testing.T) { remotePort := sc.Text() //nolint:dogsled - conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + conn, agentClient, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) @@ -261,6 +263,8 @@ func TestAgent_Stats_Magic(t *testing.T) { }, testutil.WaitLong, testutil.IntervalFast, "never saw stats after conn closes", ) + + assertConnectionReport(t, agentClient, proto.Connection_JETBRAINS, 0, "") }) } @@ -918,7 +922,7 @@ func TestAgent_SFTP(t *testing.T) { home = "/" + strings.ReplaceAll(home, "\\", "/") } //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() @@ -941,6 +945,10 @@ func TestAgent_SFTP(t *testing.T) { require.NoError(t, err) _, err = os.Stat(tempFile) require.NoError(t, err) + + // Close the client to trigger disconnect event. + _ = client.Close() + assertConnectionReport(t, agentClient, proto.Connection_SSH, 0, "") } func TestAgent_SCP(t *testing.T) { @@ -950,7 +958,7 @@ func TestAgent_SCP(t *testing.T) { defer cancel() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() @@ -963,6 +971,10 @@ func TestAgent_SCP(t *testing.T) { require.NoError(t, err) _, err = os.Stat(tempFile) require.NoError(t, err) + + // Close the client to trigger disconnect event. + scpClient.Close() + assertConnectionReport(t, agentClient, proto.Connection_SSH, 0, "") } func TestAgent_FileTransferBlocked(t *testing.T) { @@ -987,7 +999,7 @@ func TestAgent_FileTransferBlocked(t *testing.T) { defer cancel() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { + conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { o.BlockFileTransfer = true }) sshClient, err := conn.SSHClient(ctx) @@ -996,6 +1008,8 @@ func TestAgent_FileTransferBlocked(t *testing.T) { _, err = sftp.NewClient(sshClient) require.Error(t, err) assertFileTransferBlocked(t, err.Error()) + + assertConnectionReport(t, agentClient, proto.Connection_SSH, agentssh.BlockedFileTransferErrorCode, "") }) t.Run("SCP with go-scp package", func(t *testing.T) { @@ -1005,7 +1019,7 @@ func TestAgent_FileTransferBlocked(t *testing.T) { defer cancel() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { + conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { o.BlockFileTransfer = true }) sshClient, err := conn.SSHClient(ctx) @@ -1018,6 +1032,8 @@ func TestAgent_FileTransferBlocked(t *testing.T) { err = scpClient.CopyFile(context.Background(), strings.NewReader("hello world"), tempFile, "0755") require.Error(t, err) assertFileTransferBlocked(t, err.Error()) + + assertConnectionReport(t, agentClient, proto.Connection_SSH, agentssh.BlockedFileTransferErrorCode, "") }) t.Run("Forbidden commands", func(t *testing.T) { @@ -1031,7 +1047,7 @@ func TestAgent_FileTransferBlocked(t *testing.T) { defer cancel() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { + conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { o.BlockFileTransfer = true }) sshClient, err := conn.SSHClient(ctx) @@ -1053,6 +1069,8 @@ func TestAgent_FileTransferBlocked(t *testing.T) { msg, err := io.ReadAll(stdout) require.NoError(t, err) assertFileTransferBlocked(t, string(msg)) + + assertConnectionReport(t, agentClient, proto.Connection_SSH, agentssh.BlockedFileTransferErrorCode, "") }) } }) @@ -1661,8 +1679,16 @@ func TestAgent_ReconnectingPTY(t *testing.T) { defer cancel() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) id := uuid.New() + + // Test that the connection is reported. This must be tested in the + // first connection because we care about verifying all of these. + netConn0, err := conn.ReconnectingPTY(ctx, id, 80, 80, "bash --norc") + require.NoError(t, err) + _ = netConn0.Close() + assertConnectionReport(t, agentClient, proto.Connection_RECONNECTING_PTY, 0, "") + // --norc disables executing .bashrc, which is often used to customize the bash prompt netConn1, err := conn.ReconnectingPTY(ctx, id, 80, 80, "bash --norc") require.NoError(t, err) @@ -2691,3 +2717,35 @@ func requireEcho(t *testing.T, conn net.Conn) { require.NoError(t, err) require.Equal(t, "test", string(b)) } + +func assertConnectionReport(t testing.TB, agentClient *agenttest.Client, connectionType proto.Connection_Type, status int, reason string) { + t.Helper() + + var reports []*proto.ReportConnectionRequest + if !assert.Eventually(t, func() bool { + reports = agentClient.GetConnectionReports() + return len(reports) >= 2 + }, testutil.WaitMedium, testutil.IntervalFast, "waiting for 2 connection reports or more; got %d", len(reports)) { + return + } + + assert.Len(t, reports, 2, "want 2 connection reports") + + assert.Equal(t, proto.Connection_CONNECT, reports[0].GetConnection().GetAction(), "first report should be connect") + assert.Equal(t, proto.Connection_DISCONNECT, reports[1].GetConnection().GetAction(), "second report should be disconnect") + assert.Equal(t, connectionType, reports[0].GetConnection().GetType(), "connect type should be %s", connectionType) + assert.Equal(t, connectionType, reports[1].GetConnection().GetType(), "disconnect type should be %s", connectionType) + t1 := reports[0].GetConnection().GetTimestamp().AsTime() + t2 := reports[1].GetConnection().GetTimestamp().AsTime() + assert.True(t, t1.Before(t2) || t1.Equal(t2), "connect timestamp should be before or equal to disconnect timestamp") + assert.NotEmpty(t, reports[0].GetConnection().GetIp(), "connect ip should not be empty") + assert.NotEmpty(t, reports[1].GetConnection().GetIp(), "disconnect ip should not be empty") + assert.Equal(t, 0, int(reports[0].GetConnection().GetStatusCode()), "connect status code should be 0") + assert.Equal(t, status, int(reports[1].GetConnection().GetStatusCode()), "disconnect status code should be %d", status) + assert.Equal(t, "", reports[0].GetConnection().GetReason(), "connect reason should be empty") + if reason != "" { + assert.Contains(t, reports[1].GetConnection().GetReason(), reason, "disconnect reason should contain %s", reason) + } else { + t.Logf("connection report disconnect reason: %s", reports[1].GetConnection().GetReason()) + } +} diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index 0f7d0adadc865..8e3659fd44164 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -79,6 +79,8 @@ const ( // BlockedFileTransferCommands contains a list of restricted file transfer commands. var BlockedFileTransferCommands = []string{"nc", "rsync", "scp", "sftp"} +type reportConnectionFunc func(id uuid.UUID, sessionType MagicSessionType, ip string) (disconnected func(code int, reason string)) + // Config sets configuration parameters for the agent SSH server. type Config struct { // MaxTimeout sets the absolute connection timeout, none if empty. If set to @@ -101,6 +103,8 @@ type Config struct { X11DisplayOffset *int // BlockFileTransfer restricts use of file transfer applications. BlockFileTransfer bool + // ReportConnection. + ReportConnection reportConnectionFunc } type Server struct { @@ -164,6 +168,9 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom return home } } + if config.ReportConnection == nil { + config.ReportConnection = func(uuid.UUID, MagicSessionType, string) func(int, string) { return func(int, string) {} } + } forwardHandler := &ssh.ForwardedTCPHandler{} unixForwardHandler := newForwardedUnixHandler(logger) @@ -186,7 +193,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom ChannelHandlers: map[string]ssh.ChannelHandler{ "direct-tcpip": func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) { // Wrapper is designed to find and track JetBrains Gateway connections. - wrapped := NewJetbrainsChannelWatcher(ctx, s.logger, newChan, &s.connCountJetBrains) + wrapped := NewJetbrainsChannelWatcher(ctx, s.logger, s.config.ReportConnection, newChan, &s.connCountJetBrains) ssh.DirectTCPIPHandler(srv, conn, wrapped, ctx) }, "direct-streamlocal@openssh.com": directStreamLocalHandler, @@ -298,6 +305,35 @@ func extractMagicSessionType(env []string) (magicType MagicSessionType, rawType }) } +// sessionCloseTracker is a wrapper around Session that tracks the exit code. +type sessionCloseTracker struct { + ssh.Session + exitOnce sync.Once + code atomic.Int64 +} + +var _ ssh.Session = &sessionCloseTracker{} + +func (s *sessionCloseTracker) track(code int) { + s.exitOnce.Do(func() { + s.code.Store(int64(code)) + }) +} + +func (s *sessionCloseTracker) exitCode() int { + return int(s.code.Load()) +} + +func (s *sessionCloseTracker) Exit(code int) error { + s.track(code) + return s.Session.Exit(code) +} + +func (s *sessionCloseTracker) Close() error { + s.track(1) + return s.Session.Close() +} + func (s *Server) sessionHandler(session ssh.Session) { ctx := session.Context() id := uuid.New() @@ -310,17 +346,23 @@ func (s *Server) sessionHandler(session ssh.Session) { ) logger.Info(ctx, "handling ssh session") + env := session.Environ() + magicType, magicTypeRaw, env := extractMagicSessionType(env) + if !s.trackSession(session, true) { + reason := "unable to accept new session, server is closing" + // Report connection attempt even if we couldn't accept it. + disconnected := s.config.ReportConnection(id, magicType, session.RemoteAddr().String()) + defer disconnected(1, reason) + + logger.Info(ctx, reason) // See (*Server).Close() for why we call Close instead of Exit. _ = session.Close() - logger.Info(ctx, "unable to accept new session, server is closing") return } defer s.trackSession(session, false) - env := session.Environ() - magicType, magicTypeRaw, env := extractMagicSessionType(env) - + reportSession := true switch magicType { case MagicSessionTypeVSCode: s.connCountVSCode.Add(1) @@ -328,6 +370,7 @@ func (s *Server) sessionHandler(session ssh.Session) { case MagicSessionTypeJetBrains: // Do nothing here because JetBrains launches hundreds of ssh sessions. // We instead track JetBrains in the single persistent tcp forwarding channel. + reportSession = false case MagicSessionTypeSSH: s.connCountSSHSession.Add(1) defer s.connCountSSHSession.Add(-1) @@ -335,6 +378,20 @@ func (s *Server) sessionHandler(session ssh.Session) { logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("raw_type", magicTypeRaw)) } + closeCause := func(string) {} + if reportSession { + var reason string + closeCause = func(r string) { reason = r } + + scr := &sessionCloseTracker{Session: session} + session = scr + + disconnected := s.config.ReportConnection(id, magicType, session.RemoteAddr().String()) + defer func() { + disconnected(scr.exitCode(), reason) + }() + } + if s.fileTransferBlocked(session) { s.logger.Warn(ctx, "file transfer blocked", slog.F("session_subsystem", session.Subsystem()), slog.F("raw_command", session.RawCommand())) @@ -343,6 +400,7 @@ func (s *Server) sessionHandler(session ssh.Session) { errorMessage := fmt.Sprintf("\x02%s\n", BlockedFileTransferErrorMessage) _, _ = session.Write([]byte(errorMessage)) } + closeCause("file transfer blocked") _ = session.Exit(BlockedFileTransferErrorCode) return } @@ -350,10 +408,14 @@ func (s *Server) sessionHandler(session ssh.Session) { switch ss := session.Subsystem(); ss { case "": case "sftp": - s.sftpHandler(logger, session) + err := s.sftpHandler(logger, session) + if err != nil { + closeCause(err.Error()) + } return default: logger.Warn(ctx, "unsupported subsystem", slog.F("subsystem", ss)) + closeCause(fmt.Sprintf("unsupported subsystem: %s", ss)) _ = session.Exit(1) return } @@ -362,8 +424,9 @@ func (s *Server) sessionHandler(session ssh.Session) { if hasX11 { display, handled := s.x11Handler(session.Context(), x11) if !handled { - _ = session.Exit(1) logger.Error(ctx, "x11 handler failed") + closeCause("x11 handler failed") + _ = session.Exit(1) return } env = append(env, fmt.Sprintf("DISPLAY=localhost:%d.%d", display, x11.ScreenNumber)) @@ -390,6 +453,8 @@ func (s *Server) sessionHandler(session ssh.Session) { slog.F("exit_code", code), ) + closeCause(fmt.Sprintf("process exited with error status: %d", exitError.ExitCode())) + // TODO(mafredri): For signal exit, there's also an "exit-signal" // request (session.Exit sends "exit-status"), however, since it's // not implemented on the session interface and not used by @@ -401,6 +466,7 @@ func (s *Server) sessionHandler(session ssh.Session) { logger.Warn(ctx, "ssh session failed", slog.Error(err)) // This exit code is designed to be unlikely to be confused for a legit exit code // from the process. + closeCause(err.Error()) _ = session.Exit(MagicSessionErrorCode) return } @@ -660,7 +726,7 @@ func handleSignal(logger slog.Logger, ssig ssh.Signal, signaler interface{ Signa } } -func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) { +func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) error { s.metrics.sftpConnectionsTotal.Add(1) ctx := session.Context() @@ -684,7 +750,7 @@ func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) { server, err := sftp.NewServer(session, opts...) if err != nil { logger.Debug(ctx, "initialize sftp server", slog.Error(err)) - return + return xerrors.Errorf("initialize sftp server: %w", err) } defer server.Close() @@ -699,11 +765,12 @@ func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) { // code but `scp` on macOS does (when using the default // SFTP backend). _ = session.Exit(0) - return + return nil } logger.Warn(ctx, "sftp server closed with error", slog.Error(err)) s.metrics.sftpServerErrors.Add(1) _ = session.Exit(1) + return xerrors.Errorf("sftp server closed with error: %w", err) } // EnvInfoer encapsulates external information required by CreateCommand. diff --git a/agent/agentssh/jetbrainstrack.go b/agent/agentssh/jetbrainstrack.go index 534f2899b11ae..9b2fdf83b21d0 100644 --- a/agent/agentssh/jetbrainstrack.go +++ b/agent/agentssh/jetbrainstrack.go @@ -6,6 +6,7 @@ import ( "sync" "github.com/gliderlabs/ssh" + "github.com/google/uuid" "go.uber.org/atomic" gossh "golang.org/x/crypto/ssh" @@ -28,9 +29,11 @@ type JetbrainsChannelWatcher struct { gossh.NewChannel jetbrainsCounter *atomic.Int64 logger slog.Logger + originAddr string + reportConnection reportConnectionFunc } -func NewJetbrainsChannelWatcher(ctx ssh.Context, logger slog.Logger, newChannel gossh.NewChannel, counter *atomic.Int64) gossh.NewChannel { +func NewJetbrainsChannelWatcher(ctx ssh.Context, logger slog.Logger, reportConnection reportConnectionFunc, newChannel gossh.NewChannel, counter *atomic.Int64) gossh.NewChannel { d := localForwardChannelData{} if err := gossh.Unmarshal(newChannel.ExtraData(), &d); err != nil { // If the data fails to unmarshal, do nothing. @@ -61,12 +64,17 @@ func NewJetbrainsChannelWatcher(ctx ssh.Context, logger slog.Logger, newChannel NewChannel: newChannel, jetbrainsCounter: counter, logger: logger.With(slog.F("destination_port", d.DestPort)), + originAddr: d.OriginAddr, + reportConnection: reportConnection, } } func (w *JetbrainsChannelWatcher) Accept() (gossh.Channel, <-chan *gossh.Request, error) { + disconnected := w.reportConnection(uuid.New(), MagicSessionTypeJetBrains, w.originAddr) + c, r, err := w.NewChannel.Accept() if err != nil { + disconnected(1, err.Error()) return c, r, err } w.jetbrainsCounter.Add(1) @@ -77,6 +85,7 @@ func (w *JetbrainsChannelWatcher) Accept() (gossh.Channel, <-chan *gossh.Request Channel: c, done: func() { w.jetbrainsCounter.Add(-1) + disconnected(0, "") // nolint: gocritic // JetBrains is a proper noun and should be capitalized w.logger.Debug(context.Background(), "JetBrains watcher channel closed") }, diff --git a/agent/agenttest/client.go b/agent/agenttest/client.go index ed734c6df9f6c..b5fa6ea8c2189 100644 --- a/agent/agenttest/client.go +++ b/agent/agenttest/client.go @@ -158,20 +158,24 @@ func (c *Client) SetLogsChannel(ch chan<- *agentproto.BatchCreateLogsRequest) { c.fakeAgentAPI.SetLogsChannel(ch) } +func (c *Client) GetConnectionReports() []*agentproto.ReportConnectionRequest { + return c.fakeAgentAPI.GetConnectionReports() +} + type FakeAgentAPI struct { sync.Mutex t testing.TB logger slog.Logger - manifest *agentproto.Manifest - startupCh chan *agentproto.Startup - statsCh chan *agentproto.Stats - appHealthCh chan *agentproto.BatchUpdateAppHealthRequest - logsCh chan<- *agentproto.BatchCreateLogsRequest - lifecycleStates []codersdk.WorkspaceAgentLifecycle - metadata map[string]agentsdk.Metadata - timings []*agentproto.Timing - connections []*agentproto.Connection + manifest *agentproto.Manifest + startupCh chan *agentproto.Startup + statsCh chan *agentproto.Stats + appHealthCh chan *agentproto.BatchUpdateAppHealthRequest + logsCh chan<- *agentproto.BatchCreateLogsRequest + lifecycleStates []codersdk.WorkspaceAgentLifecycle + metadata map[string]agentsdk.Metadata + timings []*agentproto.Timing + connectionReports []*agentproto.ReportConnectionRequest getAnnouncementBannersFunc func() ([]codersdk.BannerConfig, error) getResourcesMonitoringConfigurationFunc func() (*agentproto.GetResourcesMonitoringConfigurationResponse, error) @@ -348,12 +352,18 @@ func (f *FakeAgentAPI) ScriptCompleted(_ context.Context, req *agentproto.Worksp func (f *FakeAgentAPI) ReportConnection(_ context.Context, req *agentproto.ReportConnectionRequest) (*emptypb.Empty, error) { f.Lock() - f.connections = append(f.connections, req.GetConnection()) + f.connectionReports = append(f.connectionReports, req) f.Unlock() return &emptypb.Empty{}, nil } +func (f *FakeAgentAPI) GetConnectionReports() []*agentproto.ReportConnectionRequest { + f.Lock() + defer f.Unlock() + return slices.Clone(f.connectionReports) +} + func NewFakeAgentAPI(t testing.TB, logger slog.Logger, manifest *agentproto.Manifest, statsCh chan *agentproto.Stats) *FakeAgentAPI { return &FakeAgentAPI{ t: t, diff --git a/agent/reconnectingpty/server.go b/agent/reconnectingpty/server.go index 465667c616180..bd9518aaad690 100644 --- a/agent/reconnectingpty/server.go +++ b/agent/reconnectingpty/server.go @@ -18,24 +18,33 @@ import ( "github.com/coder/coder/v2/codersdk/workspacesdk" ) +type reportConnectionFunc func(id uuid.UUID, ip string) (disconnected func(code int, reason string)) + type Server struct { logger slog.Logger connectionsTotal prometheus.Counter errorsTotal *prometheus.CounterVec commandCreator *agentssh.Server + reportConnection reportConnectionFunc connCount atomic.Int64 reconnectingPTYs sync.Map timeout time.Duration } // NewServer returns a new ReconnectingPTY server -func NewServer(logger slog.Logger, commandCreator *agentssh.Server, +func NewServer(logger slog.Logger, commandCreator *agentssh.Server, reportConnection reportConnectionFunc, connectionsTotal prometheus.Counter, errorsTotal *prometheus.CounterVec, timeout time.Duration, ) *Server { + if reportConnection == nil { + reportConnection = func(uuid.UUID, string) func(int, string) { + return func(int, string) {} + } + } return &Server{ logger: logger, commandCreator: commandCreator, + reportConnection: reportConnection, connectionsTotal: connectionsTotal, errorsTotal: errorsTotal, timeout: timeout, @@ -59,20 +68,31 @@ func (s *Server) Serve(ctx, hardCtx context.Context, l net.Listener) (retErr err slog.F("local", conn.LocalAddr().String())) clog.Info(ctx, "accepted conn") wg.Add(1) + disconnected := s.reportConnection(uuid.New(), conn.RemoteAddr().String()) closed := make(chan struct{}) go func() { + defer wg.Done() select { case <-closed: case <-hardCtx.Done(): + disconnected(1, "server shut down") _ = conn.Close() } - wg.Done() }() wg.Add(1) go func() { defer close(closed) defer wg.Done() - _ = s.handleConn(ctx, clog, conn) + err := s.handleConn(ctx, clog, conn) + if err != nil { + if ctx.Err() != nil { + disconnected(1, "server shutting down") + } else { + disconnected(1, err.Error()) + } + } else { + disconnected(0, "") + } }() } wg.Wait() From a77ceacf65df8f1c8d9e872a51d75518c34474e2 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Mon, 24 Feb 2025 13:16:21 +0000 Subject: [PATCH 2/7] chore: put connection reports behind experimental flag --- agent/agent.go | 11 +++++++++++ agent/agent_test.go | 23 ++++++++++++++++++----- cli/agent.go | 16 ++++++++++++++++ 3 files changed, 45 insertions(+), 5 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index f9b2754a4730c..06f96c4b93866 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -88,6 +88,8 @@ type Options struct { BlockFileTransfer bool Execer agentexec.Execer ContainerLister agentcontainers.Lister + + ExperimentalConnectionReports bool } type Client interface { @@ -189,6 +191,8 @@ func New(options Options) Agent { metrics: newAgentMetrics(prometheusRegistry), execer: options.Execer, lister: options.ContainerLister, + + experimentalConnectionReports: options.ExperimentalConnectionReports, } // Initially, we have a closed channel, reflecting the fact that we are not initially connected. // Each time we connect we replace the channel (while holding the closeMutex) with a new one @@ -263,6 +267,8 @@ type agent struct { metrics *agentMetrics execer agentexec.Execer lister agentcontainers.Lister + + experimentalConnectionReports bool } func (a *agent) TailnetConn() *tailnet.Conn { @@ -770,6 +776,11 @@ func (a *agent) reportConnectionsLoop(ctx context.Context, aAPI proto.DRPCAgentC } func (a *agent) reportConnection(id uuid.UUID, connectionType proto.Connection_Type, ip string) (disconnected func(code int, reason string)) { + // If the experiment hasn't been enabled, we don't report connections. + if !a.experimentalConnectionReports { + return func(int, string) {} // Noop. + } + // Remove the port from the IP. if portIndex := strings.LastIndex(ip, ":"); portIndex != -1 { ip = ip[:portIndex] diff --git a/agent/agent_test.go b/agent/agent_test.go index f94152207ab34..8aca8ccbf1fa9 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -159,7 +159,9 @@ func TestAgent_Stats_Magic(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() //nolint:dogsled - conn, agentClient, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + conn, agentClient, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { + o.ExperimentalConnectionReports = true + }) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() @@ -227,7 +229,9 @@ func TestAgent_Stats_Magic(t *testing.T) { remotePort := sc.Text() //nolint:dogsled - conn, agentClient, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + conn, agentClient, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { + o.ExperimentalConnectionReports = true + }) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) @@ -922,7 +926,9 @@ func TestAgent_SFTP(t *testing.T) { home = "/" + strings.ReplaceAll(home, "\\", "/") } //nolint:dogsled - conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { + o.ExperimentalConnectionReports = true + }) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() @@ -958,7 +964,9 @@ func TestAgent_SCP(t *testing.T) { defer cancel() //nolint:dogsled - conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { + o.ExperimentalConnectionReports = true + }) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() @@ -1001,6 +1009,7 @@ func TestAgent_FileTransferBlocked(t *testing.T) { //nolint:dogsled conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { o.BlockFileTransfer = true + o.ExperimentalConnectionReports = true }) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) @@ -1021,6 +1030,7 @@ func TestAgent_FileTransferBlocked(t *testing.T) { //nolint:dogsled conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { o.BlockFileTransfer = true + o.ExperimentalConnectionReports = true }) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) @@ -1049,6 +1059,7 @@ func TestAgent_FileTransferBlocked(t *testing.T) { //nolint:dogsled conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { o.BlockFileTransfer = true + o.ExperimentalConnectionReports = true }) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) @@ -1679,7 +1690,9 @@ func TestAgent_ReconnectingPTY(t *testing.T) { defer cancel() //nolint:dogsled - conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { + o.ExperimentalConnectionReports = true + }) id := uuid.New() // Test that the connection is reported. This must be tested in the diff --git a/cli/agent.go b/cli/agent.go index e8a46a84e071c..3f874a94916b5 100644 --- a/cli/agent.go +++ b/cli/agent.go @@ -54,6 +54,8 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { agentHeaderCommand string agentHeader []string devcontainersEnabled bool + + experimentalConnectionReports bool ) cmd := &serpent.Command{ Use: "agent", @@ -325,6 +327,10 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { containerLister = agentcontainers.NewDocker(execer) } + if experimentalConnectionReports { + logger.Info(ctx, "experimental connection reports enabled") + } + agnt := agent.New(agent.Options{ Client: client, Logger: logger, @@ -351,6 +357,8 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { BlockFileTransfer: blockFileTransfer, Execer: execer, ContainerLister: containerLister, + + ExperimentalConnectionReports: experimentalConnectionReports, }) promHandler := agent.PrometheusMetricsHandler(prometheusRegistry, logger) @@ -480,6 +488,14 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { Description: "Allow the agent to automatically detect running devcontainers.", Value: serpent.BoolOf(&devcontainersEnabled), }, + { + Flag: "experimental-connection-reports-enable", + Hidden: true, + Default: "false", + Env: "CODER_AGENT_EXPERIMENTAL_CONNECTION_REPORTS_ENABLE", + Description: "Enable experimental connection reports.", + Value: serpent.BoolOf(&experimentalConnectionReports), + }, } return cmd From 17ddb8e622d6e9595b709d11262a5d5dcef32588 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Mon, 24 Feb 2025 13:59:50 +0000 Subject: [PATCH 3/7] simplify --- agent/agent.go | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index 06f96c4b93866..17e738d532681 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -752,6 +752,8 @@ func (a *agent) reportConnectionsLoop(ctx context.Context, aAPI proto.DRPCAgentC break } payload := a.reportConnections[0] + // Release lock while we send the payload, this is safe + // since we only append to the slice. a.reportConnectionsMu.Unlock() logger := a.logger.With(slog.F("payload", payload)) @@ -763,14 +765,10 @@ func (a *agent) reportConnectionsLoop(ctx context.Context, aAPI proto.DRPCAgentC logger.Debug(ctx, "successfully reported connection") + // Remove the payload we sent. a.reportConnectionsMu.Lock() a.reportConnections = a.reportConnections[1:] - count := len(a.reportConnections) a.reportConnectionsMu.Unlock() - - if count == 0 { - break - } } } } From 82b6fab00bb8c09d6675fc01612358d3e1f7c02d Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Mon, 24 Feb 2025 14:12:52 +0000 Subject: [PATCH 4/7] net.SplitHostPort --- agent/agent.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index 17e738d532681..57ddf30a70dc8 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "net" "net/http" "net/netip" "os" @@ -779,10 +780,12 @@ func (a *agent) reportConnection(id uuid.UUID, connectionType proto.Connection_T return func(int, string) {} // Noop. } - // Remove the port from the IP. - if portIndex := strings.LastIndex(ip, ":"); portIndex != -1 { - ip = ip[:portIndex] - ip = strings.Trim(ip, "[]") // IPv6 addresses are wrapped in brackets. + // Remove the port from the IP because ports are not supported in coderd. + if host, _, err := net.SplitHostPort(ip); err != nil { + a.logger.Error(a.hardCtx, "split host and port for connection report failed", slog.F("ip", ip), slog.Error(err)) + } else { + // Best effort. + ip = host } a.reportConnectionsMu.Lock() From 0318212af1193d95b2ec81c308d4072c74a3e39d Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Thu, 27 Feb 2025 10:18:43 +0000 Subject: [PATCH 5/7] add buffer limit --- agent/agent.go | 60 ++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 46 insertions(+), 14 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index ab5d70a9fe13f..bff60e95919db 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -781,6 +781,18 @@ func (a *agent) reportConnectionsLoop(ctx context.Context, aAPI proto.DRPCAgentC } } +const ( + // reportConnectionBufferLimit limits the number of connection reports we + // buffer to avoid growing the buffer indefinitely. This should not happen + // unless the agent has lost connection to coderd for a long time or if + // the agent is being spammed with connections. + // + // If we assume ~150 byte per connection report, this would be around 300KB + // of memory which seems acceptable. We could reduce this if necessary by + // not using the proto struct directly. + reportConnectionBufferLimit = 2048 +) + func (a *agent) reportConnection(id uuid.UUID, connectionType proto.Connection_Type, ip string) (disconnected func(code int, reason string)) { // If the experiment hasn't been enabled, we don't report connections. if !a.experimentalConnectionReports { @@ -797,25 +809,45 @@ func (a *agent) reportConnection(id uuid.UUID, connectionType proto.Connection_T a.reportConnectionsMu.Lock() defer a.reportConnectionsMu.Unlock() - a.reportConnections = append(a.reportConnections, &proto.ReportConnectionRequest{ - Connection: &proto.Connection{ - Id: id[:], - Action: proto.Connection_CONNECT, - Type: connectionType, - Timestamp: timestamppb.New(time.Now()), - Ip: ip, - StatusCode: 0, - Reason: nil, - }, - }) - select { - case a.reportConnectionsUpdate <- struct{}{}: - default: + + if len(a.reportConnections) >= reportConnectionBufferLimit { + a.logger.Warn(a.hardCtx, "connection report buffer limit reached, dropping connect", + slog.F("limit", reportConnectionBufferLimit), + slog.F("connection_id", id), + slog.F("connection_type", connectionType), + slog.F("ip", ip), + ) + } else { + a.reportConnections = append(a.reportConnections, &proto.ReportConnectionRequest{ + Connection: &proto.Connection{ + Id: id[:], + Action: proto.Connection_CONNECT, + Type: connectionType, + Timestamp: timestamppb.New(time.Now()), + Ip: ip, + StatusCode: 0, + Reason: nil, + }, + }) + select { + case a.reportConnectionsUpdate <- struct{}{}: + default: + } } return func(code int, reason string) { a.reportConnectionsMu.Lock() defer a.reportConnectionsMu.Unlock() + if len(a.reportConnections) >= reportConnectionBufferLimit { + a.logger.Warn(a.hardCtx, "connection report buffer limit reached, dropping connect", + slog.F("limit", reportConnectionBufferLimit), + slog.F("connection_id", id), + slog.F("connection_type", connectionType), + slog.F("ip", ip), + ) + return + } + a.reportConnections = append(a.reportConnections, &proto.ReportConnectionRequest{ Connection: &proto.Connection{ Id: id[:], From d47e8a4616e539be781fa127c898e175451cd973 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Thu, 27 Feb 2025 10:27:24 +0000 Subject: [PATCH 6/7] release pointer --- agent/agent.go | 1 + 1 file changed, 1 insertion(+) diff --git a/agent/agent.go b/agent/agent.go index bff60e95919db..c610e6ae59e1d 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -775,6 +775,7 @@ func (a *agent) reportConnectionsLoop(ctx context.Context, aAPI proto.DRPCAgentC // Remove the payload we sent. a.reportConnectionsMu.Lock() + a.reportConnections[0] = nil // Release the pointer from the underlying array. a.reportConnections = a.reportConnections[1:] a.reportConnectionsMu.Unlock() } From 601b6f4216c8dc3feb04a3b8b7243f6e999d9a18 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Thu, 27 Feb 2025 10:29:47 +0000 Subject: [PATCH 7/7] typo --- agent/agent.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agent/agent.go b/agent/agent.go index c610e6ae59e1d..504fff2386826 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -840,7 +840,7 @@ func (a *agent) reportConnection(id uuid.UUID, connectionType proto.Connection_T a.reportConnectionsMu.Lock() defer a.reportConnectionsMu.Unlock() if len(a.reportConnections) >= reportConnectionBufferLimit { - a.logger.Warn(a.hardCtx, "connection report buffer limit reached, dropping connect", + a.logger.Warn(a.hardCtx, "connection report buffer limit reached, dropping disconnect", slog.F("limit", reportConnectionBufferLimit), slog.F("connection_id", id), slog.F("connection_type", connectionType),