diff --git a/agent/agent.go b/agent/agent.go index 285636cd31344..504fff2386826 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -8,6 +8,7 @@ import ( "fmt" "hash/fnv" "io" + "net" "net/http" "net/netip" "os" @@ -28,6 +29,7 @@ import ( "golang.org/x/exp/slices" "golang.org/x/sync/errgroup" "golang.org/x/xerrors" + "google.golang.org/protobuf/types/known/timestamppb" "tailscale.com/net/speedtest" "tailscale.com/tailcfg" "tailscale.com/types/netlogtype" @@ -90,6 +92,7 @@ type Options struct { ContainerLister agentcontainers.Lister ExperimentalContainersEnabled bool + ExperimentalConnectionReports bool } type Client interface { @@ -177,6 +180,7 @@ func New(options Options) Agent { lifecycleUpdate: make(chan struct{}, 1), lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1), lifecycleStates: []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}}, + reportConnectionsUpdate: make(chan struct{}, 1), ignorePorts: options.IgnorePorts, portCacheDuration: options.PortCacheDuration, reportMetadataInterval: options.ReportMetadataInterval, @@ -192,6 +196,7 @@ func New(options Options) Agent { lister: options.ContainerLister, experimentalDevcontainersEnabled: options.ExperimentalContainersEnabled, + experimentalConnectionReports: options.ExperimentalConnectionReports, } // Initially, we have a closed channel, reflecting the fact that we are not initially connected. // Each time we connect we replace the channel (while holding the closeMutex) with a new one @@ -252,6 +257,10 @@ type agent struct { lifecycleStates []agentsdk.PostLifecycleRequest lifecycleLastReportedIndex int // Keeps track of the last lifecycle state we successfully reported. + reportConnectionsUpdate chan struct{} + reportConnectionsMu sync.Mutex + reportConnections []*proto.ReportConnectionRequest + network *tailnet.Conn statsReporter *statsReporter logSender *agentsdk.LogSender @@ -264,6 +273,7 @@ type agent struct { lister agentcontainers.Lister experimentalDevcontainersEnabled bool + experimentalConnectionReports bool } func (a *agent) TailnetConn() *tailnet.Conn { @@ -279,6 +289,24 @@ func (a *agent) init() { UpdateEnv: a.updateCommandEnv, WorkingDirectory: func() string { return a.manifest.Load().Directory }, BlockFileTransfer: a.blockFileTransfer, + ReportConnection: func(id uuid.UUID, magicType agentssh.MagicSessionType, ip string) func(code int, reason string) { + var connectionType proto.Connection_Type + switch magicType { + case agentssh.MagicSessionTypeSSH: + connectionType = proto.Connection_SSH + case agentssh.MagicSessionTypeVSCode: + connectionType = proto.Connection_VSCODE + case agentssh.MagicSessionTypeJetBrains: + connectionType = proto.Connection_JETBRAINS + case agentssh.MagicSessionTypeUnknown: + connectionType = proto.Connection_TYPE_UNSPECIFIED + default: + a.logger.Error(a.hardCtx, "unhandled magic session type when reporting connection", slog.F("magic_type", magicType)) + connectionType = proto.Connection_TYPE_UNSPECIFIED + } + + return a.reportConnection(id, connectionType, ip) + }, }) if err != nil { panic(err) @@ -301,6 +329,9 @@ func (a *agent) init() { a.reconnectingPTYServer = reconnectingpty.NewServer( a.logger.Named("reconnecting-pty"), a.sshServer, + func(id uuid.UUID, ip string) func(code int, reason string) { + return a.reportConnection(id, proto.Connection_RECONNECTING_PTY, ip) + }, a.metrics.connectionsTotal, a.metrics.reconnectingPTYErrors, a.reconnectingPTYTimeout, func(s *reconnectingpty.Server) { @@ -713,6 +744,129 @@ func (a *agent) setLifecycle(state codersdk.WorkspaceAgentLifecycle) { } } +// reportConnectionsLoop reports connections to the agent for auditing. +func (a *agent) reportConnectionsLoop(ctx context.Context, aAPI proto.DRPCAgentClient24) error { + for { + select { + case <-a.reportConnectionsUpdate: + case <-ctx.Done(): + return ctx.Err() + } + + for { + a.reportConnectionsMu.Lock() + if len(a.reportConnections) == 0 { + a.reportConnectionsMu.Unlock() + break + } + payload := a.reportConnections[0] + // Release lock while we send the payload, this is safe + // since we only append to the slice. + a.reportConnectionsMu.Unlock() + + logger := a.logger.With(slog.F("payload", payload)) + logger.Debug(ctx, "reporting connection") + _, err := aAPI.ReportConnection(ctx, payload) + if err != nil { + return xerrors.Errorf("failed to report connection: %w", err) + } + + logger.Debug(ctx, "successfully reported connection") + + // Remove the payload we sent. + a.reportConnectionsMu.Lock() + a.reportConnections[0] = nil // Release the pointer from the underlying array. + a.reportConnections = a.reportConnections[1:] + a.reportConnectionsMu.Unlock() + } + } +} + +const ( + // reportConnectionBufferLimit limits the number of connection reports we + // buffer to avoid growing the buffer indefinitely. This should not happen + // unless the agent has lost connection to coderd for a long time or if + // the agent is being spammed with connections. + // + // If we assume ~150 byte per connection report, this would be around 300KB + // of memory which seems acceptable. We could reduce this if necessary by + // not using the proto struct directly. + reportConnectionBufferLimit = 2048 +) + +func (a *agent) reportConnection(id uuid.UUID, connectionType proto.Connection_Type, ip string) (disconnected func(code int, reason string)) { + // If the experiment hasn't been enabled, we don't report connections. + if !a.experimentalConnectionReports { + return func(int, string) {} // Noop. + } + + // Remove the port from the IP because ports are not supported in coderd. + if host, _, err := net.SplitHostPort(ip); err != nil { + a.logger.Error(a.hardCtx, "split host and port for connection report failed", slog.F("ip", ip), slog.Error(err)) + } else { + // Best effort. + ip = host + } + + a.reportConnectionsMu.Lock() + defer a.reportConnectionsMu.Unlock() + + if len(a.reportConnections) >= reportConnectionBufferLimit { + a.logger.Warn(a.hardCtx, "connection report buffer limit reached, dropping connect", + slog.F("limit", reportConnectionBufferLimit), + slog.F("connection_id", id), + slog.F("connection_type", connectionType), + slog.F("ip", ip), + ) + } else { + a.reportConnections = append(a.reportConnections, &proto.ReportConnectionRequest{ + Connection: &proto.Connection{ + Id: id[:], + Action: proto.Connection_CONNECT, + Type: connectionType, + Timestamp: timestamppb.New(time.Now()), + Ip: ip, + StatusCode: 0, + Reason: nil, + }, + }) + select { + case a.reportConnectionsUpdate <- struct{}{}: + default: + } + } + + return func(code int, reason string) { + a.reportConnectionsMu.Lock() + defer a.reportConnectionsMu.Unlock() + if len(a.reportConnections) >= reportConnectionBufferLimit { + a.logger.Warn(a.hardCtx, "connection report buffer limit reached, dropping disconnect", + slog.F("limit", reportConnectionBufferLimit), + slog.F("connection_id", id), + slog.F("connection_type", connectionType), + slog.F("ip", ip), + ) + return + } + + a.reportConnections = append(a.reportConnections, &proto.ReportConnectionRequest{ + Connection: &proto.Connection{ + Id: id[:], + Action: proto.Connection_DISCONNECT, + Type: connectionType, + Timestamp: timestamppb.New(time.Now()), + Ip: ip, + StatusCode: int32(code), //nolint:gosec + Reason: &reason, + }, + }) + select { + case a.reportConnectionsUpdate <- struct{}{}: + default: + } + } +} + // fetchServiceBannerLoop fetches the service banner on an interval. It will // not be fetched immediately; the expectation is that it is primed elsewhere // (and must be done before the session actually starts). @@ -823,6 +977,10 @@ func (a *agent) run() (retErr error) { return resourcesmonitor.Start(ctx) }) + // Connection reports are part of auditing, we should keep sending them via + // gracefulShutdownBehaviorRemain. + connMan.startAgentAPI("report connections", gracefulShutdownBehaviorRemain, a.reportConnectionsLoop) + // channels to sync goroutines below // handle manifest // | diff --git a/agent/agent_test.go b/agent/agent_test.go index 935309e98d873..7ccce20ae776e 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -163,7 +163,9 @@ func TestAgent_Stats_Magic(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() //nolint:dogsled - conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + conn, agentClient, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { + o.ExperimentalConnectionReports = true + }) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() @@ -193,6 +195,8 @@ func TestAgent_Stats_Magic(t *testing.T) { _ = stdin.Close() err = session.Wait() require.NoError(t, err) + + assertConnectionReport(t, agentClient, proto.Connection_VSCODE, 0, "") }) t.Run("TracksJetBrains", func(t *testing.T) { @@ -229,7 +233,9 @@ func TestAgent_Stats_Magic(t *testing.T) { remotePort := sc.Text() //nolint:dogsled - conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + conn, agentClient, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { + o.ExperimentalConnectionReports = true + }) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) @@ -265,6 +271,8 @@ func TestAgent_Stats_Magic(t *testing.T) { }, testutil.WaitLong, testutil.IntervalFast, "never saw stats after conn closes", ) + + assertConnectionReport(t, agentClient, proto.Connection_JETBRAINS, 0, "") }) } @@ -922,7 +930,9 @@ func TestAgent_SFTP(t *testing.T) { home = "/" + strings.ReplaceAll(home, "\\", "/") } //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { + o.ExperimentalConnectionReports = true + }) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() @@ -945,6 +955,10 @@ func TestAgent_SFTP(t *testing.T) { require.NoError(t, err) _, err = os.Stat(tempFile) require.NoError(t, err) + + // Close the client to trigger disconnect event. + _ = client.Close() + assertConnectionReport(t, agentClient, proto.Connection_SSH, 0, "") } func TestAgent_SCP(t *testing.T) { @@ -954,7 +968,9 @@ func TestAgent_SCP(t *testing.T) { defer cancel() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { + o.ExperimentalConnectionReports = true + }) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() @@ -967,6 +983,10 @@ func TestAgent_SCP(t *testing.T) { require.NoError(t, err) _, err = os.Stat(tempFile) require.NoError(t, err) + + // Close the client to trigger disconnect event. + scpClient.Close() + assertConnectionReport(t, agentClient, proto.Connection_SSH, 0, "") } func TestAgent_FileTransferBlocked(t *testing.T) { @@ -991,8 +1011,9 @@ func TestAgent_FileTransferBlocked(t *testing.T) { defer cancel() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { + conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { o.BlockFileTransfer = true + o.ExperimentalConnectionReports = true }) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) @@ -1000,6 +1021,8 @@ func TestAgent_FileTransferBlocked(t *testing.T) { _, err = sftp.NewClient(sshClient) require.Error(t, err) assertFileTransferBlocked(t, err.Error()) + + assertConnectionReport(t, agentClient, proto.Connection_SSH, agentssh.BlockedFileTransferErrorCode, "") }) t.Run("SCP with go-scp package", func(t *testing.T) { @@ -1009,8 +1032,9 @@ func TestAgent_FileTransferBlocked(t *testing.T) { defer cancel() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { + conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { o.BlockFileTransfer = true + o.ExperimentalConnectionReports = true }) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) @@ -1022,6 +1046,8 @@ func TestAgent_FileTransferBlocked(t *testing.T) { err = scpClient.CopyFile(context.Background(), strings.NewReader("hello world"), tempFile, "0755") require.Error(t, err) assertFileTransferBlocked(t, err.Error()) + + assertConnectionReport(t, agentClient, proto.Connection_SSH, agentssh.BlockedFileTransferErrorCode, "") }) t.Run("Forbidden commands", func(t *testing.T) { @@ -1035,8 +1061,9 @@ func TestAgent_FileTransferBlocked(t *testing.T) { defer cancel() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { + conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { o.BlockFileTransfer = true + o.ExperimentalConnectionReports = true }) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) @@ -1057,6 +1084,8 @@ func TestAgent_FileTransferBlocked(t *testing.T) { msg, err := io.ReadAll(stdout) require.NoError(t, err) assertFileTransferBlocked(t, string(msg)) + + assertConnectionReport(t, agentClient, proto.Connection_SSH, agentssh.BlockedFileTransferErrorCode, "") }) } }) @@ -1665,8 +1694,18 @@ func TestAgent_ReconnectingPTY(t *testing.T) { defer cancel() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { + o.ExperimentalConnectionReports = true + }) id := uuid.New() + + // Test that the connection is reported. This must be tested in the + // first connection because we care about verifying all of these. + netConn0, err := conn.ReconnectingPTY(ctx, id, 80, 80, "bash --norc") + require.NoError(t, err) + _ = netConn0.Close() + assertConnectionReport(t, agentClient, proto.Connection_RECONNECTING_PTY, 0, "") + // --norc disables executing .bashrc, which is often used to customize the bash prompt netConn1, err := conn.ReconnectingPTY(ctx, id, 80, 80, "bash --norc") require.NoError(t, err) @@ -2763,3 +2802,35 @@ func requireEcho(t *testing.T, conn net.Conn) { require.NoError(t, err) require.Equal(t, "test", string(b)) } + +func assertConnectionReport(t testing.TB, agentClient *agenttest.Client, connectionType proto.Connection_Type, status int, reason string) { + t.Helper() + + var reports []*proto.ReportConnectionRequest + if !assert.Eventually(t, func() bool { + reports = agentClient.GetConnectionReports() + return len(reports) >= 2 + }, testutil.WaitMedium, testutil.IntervalFast, "waiting for 2 connection reports or more; got %d", len(reports)) { + return + } + + assert.Len(t, reports, 2, "want 2 connection reports") + + assert.Equal(t, proto.Connection_CONNECT, reports[0].GetConnection().GetAction(), "first report should be connect") + assert.Equal(t, proto.Connection_DISCONNECT, reports[1].GetConnection().GetAction(), "second report should be disconnect") + assert.Equal(t, connectionType, reports[0].GetConnection().GetType(), "connect type should be %s", connectionType) + assert.Equal(t, connectionType, reports[1].GetConnection().GetType(), "disconnect type should be %s", connectionType) + t1 := reports[0].GetConnection().GetTimestamp().AsTime() + t2 := reports[1].GetConnection().GetTimestamp().AsTime() + assert.True(t, t1.Before(t2) || t1.Equal(t2), "connect timestamp should be before or equal to disconnect timestamp") + assert.NotEmpty(t, reports[0].GetConnection().GetIp(), "connect ip should not be empty") + assert.NotEmpty(t, reports[1].GetConnection().GetIp(), "disconnect ip should not be empty") + assert.Equal(t, 0, int(reports[0].GetConnection().GetStatusCode()), "connect status code should be 0") + assert.Equal(t, status, int(reports[1].GetConnection().GetStatusCode()), "disconnect status code should be %d", status) + assert.Equal(t, "", reports[0].GetConnection().GetReason(), "connect reason should be empty") + if reason != "" { + assert.Contains(t, reports[1].GetConnection().GetReason(), reason, "disconnect reason should contain %s", reason) + } else { + t.Logf("connection report disconnect reason: %s", reports[1].GetConnection().GetReason()) + } +} diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index 3b09df0e388dd..4a5d3215db911 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -78,6 +78,8 @@ const ( // BlockedFileTransferCommands contains a list of restricted file transfer commands. var BlockedFileTransferCommands = []string{"nc", "rsync", "scp", "sftp"} +type reportConnectionFunc func(id uuid.UUID, sessionType MagicSessionType, ip string) (disconnected func(code int, reason string)) + // Config sets configuration parameters for the agent SSH server. type Config struct { // MaxTimeout sets the absolute connection timeout, none if empty. If set to @@ -100,6 +102,8 @@ type Config struct { X11DisplayOffset *int // BlockFileTransfer restricts use of file transfer applications. BlockFileTransfer bool + // ReportConnection. + ReportConnection reportConnectionFunc } type Server struct { @@ -152,6 +156,9 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom return home } } + if config.ReportConnection == nil { + config.ReportConnection = func(uuid.UUID, MagicSessionType, string) func(int, string) { return func(int, string) {} } + } forwardHandler := &ssh.ForwardedTCPHandler{} unixForwardHandler := newForwardedUnixHandler(logger) @@ -174,7 +181,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom ChannelHandlers: map[string]ssh.ChannelHandler{ "direct-tcpip": func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) { // Wrapper is designed to find and track JetBrains Gateway connections. - wrapped := NewJetbrainsChannelWatcher(ctx, s.logger, newChan, &s.connCountJetBrains) + wrapped := NewJetbrainsChannelWatcher(ctx, s.logger, s.config.ReportConnection, newChan, &s.connCountJetBrains) ssh.DirectTCPIPHandler(srv, conn, wrapped, ctx) }, "direct-streamlocal@openssh.com": directStreamLocalHandler, @@ -288,6 +295,35 @@ func extractMagicSessionType(env []string) (magicType MagicSessionType, rawType }) } +// sessionCloseTracker is a wrapper around Session that tracks the exit code. +type sessionCloseTracker struct { + ssh.Session + exitOnce sync.Once + code atomic.Int64 +} + +var _ ssh.Session = &sessionCloseTracker{} + +func (s *sessionCloseTracker) track(code int) { + s.exitOnce.Do(func() { + s.code.Store(int64(code)) + }) +} + +func (s *sessionCloseTracker) exitCode() int { + return int(s.code.Load()) +} + +func (s *sessionCloseTracker) Exit(code int) error { + s.track(code) + return s.Session.Exit(code) +} + +func (s *sessionCloseTracker) Close() error { + s.track(1) + return s.Session.Close() +} + func (s *Server) sessionHandler(session ssh.Session) { ctx := session.Context() id := uuid.New() @@ -300,17 +336,23 @@ func (s *Server) sessionHandler(session ssh.Session) { ) logger.Info(ctx, "handling ssh session") + env := session.Environ() + magicType, magicTypeRaw, env := extractMagicSessionType(env) + if !s.trackSession(session, true) { + reason := "unable to accept new session, server is closing" + // Report connection attempt even if we couldn't accept it. + disconnected := s.config.ReportConnection(id, magicType, session.RemoteAddr().String()) + defer disconnected(1, reason) + + logger.Info(ctx, reason) // See (*Server).Close() for why we call Close instead of Exit. _ = session.Close() - logger.Info(ctx, "unable to accept new session, server is closing") return } defer s.trackSession(session, false) - env := session.Environ() - magicType, magicTypeRaw, env := extractMagicSessionType(env) - + reportSession := true switch magicType { case MagicSessionTypeVSCode: s.connCountVSCode.Add(1) @@ -318,6 +360,7 @@ func (s *Server) sessionHandler(session ssh.Session) { case MagicSessionTypeJetBrains: // Do nothing here because JetBrains launches hundreds of ssh sessions. // We instead track JetBrains in the single persistent tcp forwarding channel. + reportSession = false case MagicSessionTypeSSH: s.connCountSSHSession.Add(1) defer s.connCountSSHSession.Add(-1) @@ -325,6 +368,20 @@ func (s *Server) sessionHandler(session ssh.Session) { logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("raw_type", magicTypeRaw)) } + closeCause := func(string) {} + if reportSession { + var reason string + closeCause = func(r string) { reason = r } + + scr := &sessionCloseTracker{Session: session} + session = scr + + disconnected := s.config.ReportConnection(id, magicType, session.RemoteAddr().String()) + defer func() { + disconnected(scr.exitCode(), reason) + }() + } + if s.fileTransferBlocked(session) { s.logger.Warn(ctx, "file transfer blocked", slog.F("session_subsystem", session.Subsystem()), slog.F("raw_command", session.RawCommand())) @@ -333,6 +390,7 @@ func (s *Server) sessionHandler(session ssh.Session) { errorMessage := fmt.Sprintf("\x02%s\n", BlockedFileTransferErrorMessage) _, _ = session.Write([]byte(errorMessage)) } + closeCause("file transfer blocked") _ = session.Exit(BlockedFileTransferErrorCode) return } @@ -340,10 +398,14 @@ func (s *Server) sessionHandler(session ssh.Session) { switch ss := session.Subsystem(); ss { case "": case "sftp": - s.sftpHandler(logger, session) + err := s.sftpHandler(logger, session) + if err != nil { + closeCause(err.Error()) + } return default: logger.Warn(ctx, "unsupported subsystem", slog.F("subsystem", ss)) + closeCause(fmt.Sprintf("unsupported subsystem: %s", ss)) _ = session.Exit(1) return } @@ -352,8 +414,9 @@ func (s *Server) sessionHandler(session ssh.Session) { if hasX11 { display, handled := s.x11Handler(session.Context(), x11) if !handled { - _ = session.Exit(1) logger.Error(ctx, "x11 handler failed") + closeCause("x11 handler failed") + _ = session.Exit(1) return } env = append(env, fmt.Sprintf("DISPLAY=localhost:%d.%d", display, x11.ScreenNumber)) @@ -380,6 +443,8 @@ func (s *Server) sessionHandler(session ssh.Session) { slog.F("exit_code", code), ) + closeCause(fmt.Sprintf("process exited with error status: %d", exitError.ExitCode())) + // TODO(mafredri): For signal exit, there's also an "exit-signal" // request (session.Exit sends "exit-status"), however, since it's // not implemented on the session interface and not used by @@ -391,6 +456,7 @@ func (s *Server) sessionHandler(session ssh.Session) { logger.Warn(ctx, "ssh session failed", slog.Error(err)) // This exit code is designed to be unlikely to be confused for a legit exit code // from the process. + closeCause(err.Error()) _ = session.Exit(MagicSessionErrorCode) return } @@ -650,7 +716,7 @@ func handleSignal(logger slog.Logger, ssig ssh.Signal, signaler interface{ Signa } } -func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) { +func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) error { s.metrics.sftpConnectionsTotal.Add(1) ctx := session.Context() @@ -674,7 +740,7 @@ func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) { server, err := sftp.NewServer(session, opts...) if err != nil { logger.Debug(ctx, "initialize sftp server", slog.Error(err)) - return + return xerrors.Errorf("initialize sftp server: %w", err) } defer server.Close() @@ -689,11 +755,12 @@ func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) { // code but `scp` on macOS does (when using the default // SFTP backend). _ = session.Exit(0) - return + return nil } logger.Warn(ctx, "sftp server closed with error", slog.Error(err)) s.metrics.sftpServerErrors.Add(1) _ = session.Exit(1) + return xerrors.Errorf("sftp server closed with error: %w", err) } // CreateCommand processes raw command input with OpenSSH-like behavior. diff --git a/agent/agentssh/jetbrainstrack.go b/agent/agentssh/jetbrainstrack.go index 534f2899b11ae..9b2fdf83b21d0 100644 --- a/agent/agentssh/jetbrainstrack.go +++ b/agent/agentssh/jetbrainstrack.go @@ -6,6 +6,7 @@ import ( "sync" "github.com/gliderlabs/ssh" + "github.com/google/uuid" "go.uber.org/atomic" gossh "golang.org/x/crypto/ssh" @@ -28,9 +29,11 @@ type JetbrainsChannelWatcher struct { gossh.NewChannel jetbrainsCounter *atomic.Int64 logger slog.Logger + originAddr string + reportConnection reportConnectionFunc } -func NewJetbrainsChannelWatcher(ctx ssh.Context, logger slog.Logger, newChannel gossh.NewChannel, counter *atomic.Int64) gossh.NewChannel { +func NewJetbrainsChannelWatcher(ctx ssh.Context, logger slog.Logger, reportConnection reportConnectionFunc, newChannel gossh.NewChannel, counter *atomic.Int64) gossh.NewChannel { d := localForwardChannelData{} if err := gossh.Unmarshal(newChannel.ExtraData(), &d); err != nil { // If the data fails to unmarshal, do nothing. @@ -61,12 +64,17 @@ func NewJetbrainsChannelWatcher(ctx ssh.Context, logger slog.Logger, newChannel NewChannel: newChannel, jetbrainsCounter: counter, logger: logger.With(slog.F("destination_port", d.DestPort)), + originAddr: d.OriginAddr, + reportConnection: reportConnection, } } func (w *JetbrainsChannelWatcher) Accept() (gossh.Channel, <-chan *gossh.Request, error) { + disconnected := w.reportConnection(uuid.New(), MagicSessionTypeJetBrains, w.originAddr) + c, r, err := w.NewChannel.Accept() if err != nil { + disconnected(1, err.Error()) return c, r, err } w.jetbrainsCounter.Add(1) @@ -77,6 +85,7 @@ func (w *JetbrainsChannelWatcher) Accept() (gossh.Channel, <-chan *gossh.Request Channel: c, done: func() { w.jetbrainsCounter.Add(-1) + disconnected(0, "") // nolint: gocritic // JetBrains is a proper noun and should be capitalized w.logger.Debug(context.Background(), "JetBrains watcher channel closed") }, diff --git a/agent/agenttest/client.go b/agent/agenttest/client.go index ed734c6df9f6c..b5fa6ea8c2189 100644 --- a/agent/agenttest/client.go +++ b/agent/agenttest/client.go @@ -158,20 +158,24 @@ func (c *Client) SetLogsChannel(ch chan<- *agentproto.BatchCreateLogsRequest) { c.fakeAgentAPI.SetLogsChannel(ch) } +func (c *Client) GetConnectionReports() []*agentproto.ReportConnectionRequest { + return c.fakeAgentAPI.GetConnectionReports() +} + type FakeAgentAPI struct { sync.Mutex t testing.TB logger slog.Logger - manifest *agentproto.Manifest - startupCh chan *agentproto.Startup - statsCh chan *agentproto.Stats - appHealthCh chan *agentproto.BatchUpdateAppHealthRequest - logsCh chan<- *agentproto.BatchCreateLogsRequest - lifecycleStates []codersdk.WorkspaceAgentLifecycle - metadata map[string]agentsdk.Metadata - timings []*agentproto.Timing - connections []*agentproto.Connection + manifest *agentproto.Manifest + startupCh chan *agentproto.Startup + statsCh chan *agentproto.Stats + appHealthCh chan *agentproto.BatchUpdateAppHealthRequest + logsCh chan<- *agentproto.BatchCreateLogsRequest + lifecycleStates []codersdk.WorkspaceAgentLifecycle + metadata map[string]agentsdk.Metadata + timings []*agentproto.Timing + connectionReports []*agentproto.ReportConnectionRequest getAnnouncementBannersFunc func() ([]codersdk.BannerConfig, error) getResourcesMonitoringConfigurationFunc func() (*agentproto.GetResourcesMonitoringConfigurationResponse, error) @@ -348,12 +352,18 @@ func (f *FakeAgentAPI) ScriptCompleted(_ context.Context, req *agentproto.Worksp func (f *FakeAgentAPI) ReportConnection(_ context.Context, req *agentproto.ReportConnectionRequest) (*emptypb.Empty, error) { f.Lock() - f.connections = append(f.connections, req.GetConnection()) + f.connectionReports = append(f.connectionReports, req) f.Unlock() return &emptypb.Empty{}, nil } +func (f *FakeAgentAPI) GetConnectionReports() []*agentproto.ReportConnectionRequest { + f.Lock() + defer f.Unlock() + return slices.Clone(f.connectionReports) +} + func NewFakeAgentAPI(t testing.TB, logger slog.Logger, manifest *agentproto.Manifest, statsCh chan *agentproto.Stats) *FakeAgentAPI { return &FakeAgentAPI{ t: t, diff --git a/agent/reconnectingpty/server.go b/agent/reconnectingpty/server.go index ab4ce854c789c..7ad7db976c8b0 100644 --- a/agent/reconnectingpty/server.go +++ b/agent/reconnectingpty/server.go @@ -20,11 +20,14 @@ import ( "github.com/coder/coder/v2/codersdk/workspacesdk" ) +type reportConnectionFunc func(id uuid.UUID, ip string) (disconnected func(code int, reason string)) + type Server struct { logger slog.Logger connectionsTotal prometheus.Counter errorsTotal *prometheus.CounterVec commandCreator *agentssh.Server + reportConnection reportConnectionFunc connCount atomic.Int64 reconnectingPTYs sync.Map timeout time.Duration @@ -33,13 +36,19 @@ type Server struct { } // NewServer returns a new ReconnectingPTY server -func NewServer(logger slog.Logger, commandCreator *agentssh.Server, +func NewServer(logger slog.Logger, commandCreator *agentssh.Server, reportConnection reportConnectionFunc, connectionsTotal prometheus.Counter, errorsTotal *prometheus.CounterVec, timeout time.Duration, opts ...func(*Server), ) *Server { + if reportConnection == nil { + reportConnection = func(uuid.UUID, string) func(int, string) { + return func(int, string) {} + } + } s := &Server{ logger: logger, commandCreator: commandCreator, + reportConnection: reportConnection, connectionsTotal: connectionsTotal, errorsTotal: errorsTotal, timeout: timeout, @@ -67,20 +76,31 @@ func (s *Server) Serve(ctx, hardCtx context.Context, l net.Listener) (retErr err slog.F("local", conn.LocalAddr().String())) clog.Info(ctx, "accepted conn") wg.Add(1) + disconnected := s.reportConnection(uuid.New(), conn.RemoteAddr().String()) closed := make(chan struct{}) go func() { + defer wg.Done() select { case <-closed: case <-hardCtx.Done(): + disconnected(1, "server shut down") _ = conn.Close() } - wg.Done() }() wg.Add(1) go func() { defer close(closed) defer wg.Done() - _ = s.handleConn(ctx, clog, conn) + err := s.handleConn(ctx, clog, conn) + if err != nil { + if ctx.Err() != nil { + disconnected(1, "server shutting down") + } else { + disconnected(1, err.Error()) + } + } else { + disconnected(0, "") + } }() } wg.Wait() diff --git a/cli/agent.go b/cli/agent.go index 01d6c36f7a045..638f7083805ab 100644 --- a/cli/agent.go +++ b/cli/agent.go @@ -54,6 +54,8 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { agentHeaderCommand string agentHeader []string devcontainersEnabled bool + + experimentalConnectionReports bool ) cmd := &serpent.Command{ Use: "agent", @@ -325,6 +327,10 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { containerLister = agentcontainers.NewDocker(execer) } + if experimentalConnectionReports { + logger.Info(ctx, "experimental connection reports enabled") + } + agnt := agent.New(agent.Options{ Client: client, Logger: logger, @@ -353,6 +359,7 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { ContainerLister: containerLister, ExperimentalContainersEnabled: devcontainersEnabled, + ExperimentalConnectionReports: experimentalConnectionReports, }) promHandler := agent.PrometheusMetricsHandler(prometheusRegistry, logger) @@ -482,6 +489,14 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { Description: "Allow the agent to automatically detect running devcontainers.", Value: serpent.BoolOf(&devcontainersEnabled), }, + { + Flag: "experimental-connection-reports-enable", + Hidden: true, + Default: "false", + Env: "CODER_AGENT_EXPERIMENTAL_CONNECTION_REPORTS_ENABLE", + Description: "Enable experimental connection reports.", + Value: serpent.BoolOf(&experimentalConnectionReports), + }, } return cmd