Skip to content

Commit cf1ff18

Browse files
mafredriaslilac
authored and committed
feat(agent): add connection reporting for SSH and reconnecting PTY (#16652)
Updates #15139
1 parent 7de918d commit cf1ff18

File tree

7 files changed

+382
-32
lines changed

7 files changed

+382
-32
lines changed

agent/agent.go

+158
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"fmt"
99
"hash/fnv"
1010
"io"
11+
"net"
1112
"net/http"
1213
"net/netip"
1314
"os"
@@ -28,6 +29,7 @@ import (
2829
"golang.org/x/exp/slices"
2930
"golang.org/x/sync/errgroup"
3031
"golang.org/x/xerrors"
32+
"google.golang.org/protobuf/types/known/timestamppb"
3133
"tailscale.com/net/speedtest"
3234
"tailscale.com/tailcfg"
3335
"tailscale.com/types/netlogtype"
@@ -90,6 +92,7 @@ type Options struct {
9092
ContainerLister agentcontainers.Lister
9193

9294
ExperimentalContainersEnabled bool
95+
ExperimentalConnectionReports bool
9396
}
9497

9598
type Client interface {
@@ -177,6 +180,7 @@ func New(options Options) Agent {
177180
lifecycleUpdate: make(chan struct{}, 1),
178181
lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1),
179182
lifecycleStates: []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}},
183+
reportConnectionsUpdate: make(chan struct{}, 1),
180184
ignorePorts: options.IgnorePorts,
181185
portCacheDuration: options.PortCacheDuration,
182186
reportMetadataInterval: options.ReportMetadataInterval,
@@ -192,6 +196,7 @@ func New(options Options) Agent {
192196
lister: options.ContainerLister,
193197

194198
experimentalDevcontainersEnabled: options.ExperimentalContainersEnabled,
199+
experimentalConnectionReports: options.ExperimentalConnectionReports,
195200
}
196201
// Initially, we have a closed channel, reflecting the fact that we are not initially connected.
197202
// Each time we connect we replace the channel (while holding the closeMutex) with a new one
@@ -252,6 +257,10 @@ type agent struct {
252257
lifecycleStates []agentsdk.PostLifecycleRequest
253258
lifecycleLastReportedIndex int // Keeps track of the last lifecycle state we successfully reported.
254259

260+
reportConnectionsUpdate chan struct{}
261+
reportConnectionsMu sync.Mutex
262+
reportConnections []*proto.ReportConnectionRequest
263+
255264
network *tailnet.Conn
256265
statsReporter *statsReporter
257266
logSender *agentsdk.LogSender
@@ -264,6 +273,7 @@ type agent struct {
264273
lister agentcontainers.Lister
265274

266275
experimentalDevcontainersEnabled bool
276+
experimentalConnectionReports bool
267277
}
268278

269279
func (a *agent) TailnetConn() *tailnet.Conn {
@@ -279,6 +289,24 @@ func (a *agent) init() {
279289
UpdateEnv: a.updateCommandEnv,
280290
WorkingDirectory: func() string { return a.manifest.Load().Directory },
281291
BlockFileTransfer: a.blockFileTransfer,
292+
ReportConnection: func(id uuid.UUID, magicType agentssh.MagicSessionType, ip string) func(code int, reason string) {
293+
var connectionType proto.Connection_Type
294+
switch magicType {
295+
case agentssh.MagicSessionTypeSSH:
296+
connectionType = proto.Connection_SSH
297+
case agentssh.MagicSessionTypeVSCode:
298+
connectionType = proto.Connection_VSCODE
299+
case agentssh.MagicSessionTypeJetBrains:
300+
connectionType = proto.Connection_JETBRAINS
301+
case agentssh.MagicSessionTypeUnknown:
302+
connectionType = proto.Connection_TYPE_UNSPECIFIED
303+
default:
304+
a.logger.Error(a.hardCtx, "unhandled magic session type when reporting connection", slog.F("magic_type", magicType))
305+
connectionType = proto.Connection_TYPE_UNSPECIFIED
306+
}
307+
308+
return a.reportConnection(id, connectionType, ip)
309+
},
282310
})
283311
if err != nil {
284312
panic(err)
@@ -301,6 +329,9 @@ func (a *agent) init() {
301329
a.reconnectingPTYServer = reconnectingpty.NewServer(
302330
a.logger.Named("reconnecting-pty"),
303331
a.sshServer,
332+
func(id uuid.UUID, ip string) func(code int, reason string) {
333+
return a.reportConnection(id, proto.Connection_RECONNECTING_PTY, ip)
334+
},
304335
a.metrics.connectionsTotal, a.metrics.reconnectingPTYErrors,
305336
a.reconnectingPTYTimeout,
306337
func(s *reconnectingpty.Server) {
@@ -713,6 +744,129 @@ func (a *agent) setLifecycle(state codersdk.WorkspaceAgentLifecycle) {
713744
}
714745
}
715746

747+
// reportConnectionsLoop reports connections to the agent for auditing.
748+
func (a *agent) reportConnectionsLoop(ctx context.Context, aAPI proto.DRPCAgentClient24) error {
749+
for {
750+
select {
751+
case <-a.reportConnectionsUpdate:
752+
case <-ctx.Done():
753+
return ctx.Err()
754+
}
755+
756+
for {
757+
a.reportConnectionsMu.Lock()
758+
if len(a.reportConnections) == 0 {
759+
a.reportConnectionsMu.Unlock()
760+
break
761+
}
762+
payload := a.reportConnections[0]
763+
// Release lock while we send the payload, this is safe
764+
// since we only append to the slice.
765+
a.reportConnectionsMu.Unlock()
766+
767+
logger := a.logger.With(slog.F("payload", payload))
768+
logger.Debug(ctx, "reporting connection")
769+
_, err := aAPI.ReportConnection(ctx, payload)
770+
if err != nil {
771+
return xerrors.Errorf("failed to report connection: %w", err)
772+
}
773+
774+
logger.Debug(ctx, "successfully reported connection")
775+
776+
// Remove the payload we sent.
777+
a.reportConnectionsMu.Lock()
778+
a.reportConnections[0] = nil // Release the pointer from the underlying array.
779+
a.reportConnections = a.reportConnections[1:]
780+
a.reportConnectionsMu.Unlock()
781+
}
782+
}
783+
}
784+
785+
// reportConnectionBufferLimit is the maximum number of connection reports
// kept in the in-memory buffer; once it is reached, further reports are
// dropped so the buffer cannot grow without bound. Hitting the limit is only
// expected when the agent has been disconnected from coderd for a long time
// or is being flooded with connections.
//
// At roughly 150 bytes per report this amounts to about 300KB of memory,
// which seems acceptable. If that ever needs to shrink, the buffer could
// store something smaller than the proto struct.
const reportConnectionBufferLimit = 2048
796+
797+
func (a *agent) reportConnection(id uuid.UUID, connectionType proto.Connection_Type, ip string) (disconnected func(code int, reason string)) {
798+
// If the experiment hasn't been enabled, we don't report connections.
799+
if !a.experimentalConnectionReports {
800+
return func(int, string) {} // Noop.
801+
}
802+
803+
// Remove the port from the IP because ports are not supported in coderd.
804+
if host, _, err := net.SplitHostPort(ip); err != nil {
805+
a.logger.Error(a.hardCtx, "split host and port for connection report failed", slog.F("ip", ip), slog.Error(err))
806+
} else {
807+
// Best effort.
808+
ip = host
809+
}
810+
811+
a.reportConnectionsMu.Lock()
812+
defer a.reportConnectionsMu.Unlock()
813+
814+
if len(a.reportConnections) >= reportConnectionBufferLimit {
815+
a.logger.Warn(a.hardCtx, "connection report buffer limit reached, dropping connect",
816+
slog.F("limit", reportConnectionBufferLimit),
817+
slog.F("connection_id", id),
818+
slog.F("connection_type", connectionType),
819+
slog.F("ip", ip),
820+
)
821+
} else {
822+
a.reportConnections = append(a.reportConnections, &proto.ReportConnectionRequest{
823+
Connection: &proto.Connection{
824+
Id: id[:],
825+
Action: proto.Connection_CONNECT,
826+
Type: connectionType,
827+
Timestamp: timestamppb.New(time.Now()),
828+
Ip: ip,
829+
StatusCode: 0,
830+
Reason: nil,
831+
},
832+
})
833+
select {
834+
case a.reportConnectionsUpdate <- struct{}{}:
835+
default:
836+
}
837+
}
838+
839+
return func(code int, reason string) {
840+
a.reportConnectionsMu.Lock()
841+
defer a.reportConnectionsMu.Unlock()
842+
if len(a.reportConnections) >= reportConnectionBufferLimit {
843+
a.logger.Warn(a.hardCtx, "connection report buffer limit reached, dropping disconnect",
844+
slog.F("limit", reportConnectionBufferLimit),
845+
slog.F("connection_id", id),
846+
slog.F("connection_type", connectionType),
847+
slog.F("ip", ip),
848+
)
849+
return
850+
}
851+
852+
a.reportConnections = append(a.reportConnections, &proto.ReportConnectionRequest{
853+
Connection: &proto.Connection{
854+
Id: id[:],
855+
Action: proto.Connection_DISCONNECT,
856+
Type: connectionType,
857+
Timestamp: timestamppb.New(time.Now()),
858+
Ip: ip,
859+
StatusCode: int32(code), //nolint:gosec
860+
Reason: &reason,
861+
},
862+
})
863+
select {
864+
case a.reportConnectionsUpdate <- struct{}{}:
865+
default:
866+
}
867+
}
868+
}
869+
716870
// fetchServiceBannerLoop fetches the service banner on an interval. It will
717871
// not be fetched immediately; the expectation is that it is primed elsewhere
718872
// (and must be done before the session actually starts).
@@ -823,6 +977,10 @@ func (a *agent) run() (retErr error) {
823977
return resourcesmonitor.Start(ctx)
824978
})
825979

980+
// Connection reports are part of auditing, we should keep sending them via
981+
// gracefulShutdownBehaviorRemain.
982+
connMan.startAgentAPI("report connections", gracefulShutdownBehaviorRemain, a.reportConnectionsLoop)
983+
826984
// channels to sync goroutines below
827985
// handle manifest
828986
// |

0 commit comments

Comments
 (0)