
feat(agent): add connection reporting for SSH and reconnecting PTY #16652


Merged · 8 commits · Feb 27, 2025
158 changes: 158 additions & 0 deletions agent/agent.go
@@ -8,6 +8,7 @@ import (
"fmt"
"hash/fnv"
"io"
"net"
"net/http"
"net/netip"
"os"
@@ -28,6 +29,7 @@ import (
"golang.org/x/exp/slices"
"golang.org/x/sync/errgroup"
"golang.org/x/xerrors"
"google.golang.org/protobuf/types/known/timestamppb"
"tailscale.com/net/speedtest"
"tailscale.com/tailcfg"
"tailscale.com/types/netlogtype"
@@ -90,6 +92,7 @@ type Options struct {
ContainerLister agentcontainers.Lister

ExperimentalContainersEnabled bool
ExperimentalConnectionReports bool
}

type Client interface {
@@ -177,6 +180,7 @@ func New(options Options) Agent {
lifecycleUpdate: make(chan struct{}, 1),
lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1),
lifecycleStates: []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}},
reportConnectionsUpdate: make(chan struct{}, 1),
ignorePorts: options.IgnorePorts,
portCacheDuration: options.PortCacheDuration,
reportMetadataInterval: options.ReportMetadataInterval,
@@ -192,6 +196,7 @@ func New(options Options) Agent {
lister: options.ContainerLister,

experimentalDevcontainersEnabled: options.ExperimentalContainersEnabled,
experimentalConnectionReports: options.ExperimentalConnectionReports,
}
// Initially, we have a closed channel, reflecting the fact that we are not initially connected.
// Each time we connect we replace the channel (while holding the closeMutex) with a new one
@@ -252,6 +257,10 @@ type agent struct {
lifecycleStates []agentsdk.PostLifecycleRequest
lifecycleLastReportedIndex int // Keeps track of the last lifecycle state we successfully reported.

reportConnectionsUpdate chan struct{}
reportConnectionsMu sync.Mutex
reportConnections []*proto.ReportConnectionRequest

network *tailnet.Conn
statsReporter *statsReporter
logSender *agentsdk.LogSender
@@ -264,6 +273,7 @@ type agent struct {
lister agentcontainers.Lister

experimentalDevcontainersEnabled bool
experimentalConnectionReports bool
}

func (a *agent) TailnetConn() *tailnet.Conn {
@@ -279,6 +289,24 @@ func (a *agent) init() {
UpdateEnv: a.updateCommandEnv,
WorkingDirectory: func() string { return a.manifest.Load().Directory },
BlockFileTransfer: a.blockFileTransfer,
ReportConnection: func(id uuid.UUID, magicType agentssh.MagicSessionType, ip string) func(code int, reason string) {
var connectionType proto.Connection_Type
switch magicType {
case agentssh.MagicSessionTypeSSH:
connectionType = proto.Connection_SSH
case agentssh.MagicSessionTypeVSCode:
connectionType = proto.Connection_VSCODE
case agentssh.MagicSessionTypeJetBrains:
connectionType = proto.Connection_JETBRAINS
case agentssh.MagicSessionTypeUnknown:
connectionType = proto.Connection_TYPE_UNSPECIFIED
default:
a.logger.Error(a.hardCtx, "unhandled magic session type when reporting connection", slog.F("magic_type", magicType))
connectionType = proto.Connection_TYPE_UNSPECIFIED
}

return a.reportConnection(id, connectionType, ip)
},
})
if err != nil {
panic(err)
@@ -301,6 +329,9 @@ func (a *agent) init() {
a.reconnectingPTYServer = reconnectingpty.NewServer(
a.logger.Named("reconnecting-pty"),
a.sshServer,
func(id uuid.UUID, ip string) func(code int, reason string) {
return a.reportConnection(id, proto.Connection_RECONNECTING_PTY, ip)
},
a.metrics.connectionsTotal, a.metrics.reconnectingPTYErrors,
a.reconnectingPTYTimeout,
func(s *reconnectingpty.Server) {
@@ -713,6 +744,129 @@ func (a *agent) setLifecycle(state codersdk.WorkspaceAgentLifecycle) {
}
}

// reportConnectionsLoop reports queued connection events via the agent API for auditing.
func (a *agent) reportConnectionsLoop(ctx context.Context, aAPI proto.DRPCAgentClient24) error {
for {
select {
case <-a.reportConnectionsUpdate:
case <-ctx.Done():
return ctx.Err()
}

for {
a.reportConnectionsMu.Lock()
if len(a.reportConnections) == 0 {
a.reportConnectionsMu.Unlock()
break
Member:
Do we need a label here for clarity? break will always break the innermost loop, but I always have trouble remembering that personally.

Member (Author):
If you review in VS Code, the syntax highlighting can be helpful here!
[screenshot: VS Code highlighting the loop this break exits]

Member:
I only use labels when I need to. I don't think we have examples in our code of labels breaking an inner loop. Do we?

Member:
You got me 😄 ❤️
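For reference, a self-contained toy example of a labeled break in Go (illustration only, not code from this PR):

package main

import "fmt"

func main() {
	// The label names the outer loop, so "break outer" exits it even
	// from inside the inner loop; a bare break would only end the
	// innermost loop, as discussed above.
outer:
	for i := 0; i < 3; i++ {
		for j := 0; j < 3; j++ {
			if i*j > 2 {
				break outer
			}
			fmt.Println(i, j)
		}
	}
}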

}
payload := a.reportConnections[0]
// Release the lock while we send the payload; this is safe
// since we only ever append to the slice.
a.reportConnectionsMu.Unlock()

logger := a.logger.With(slog.F("payload", payload))
logger.Debug(ctx, "reporting connection")
_, err := aAPI.ReportConnection(ctx, payload)
if err != nil {
return xerrors.Errorf("failed to report connection: %w", err)
}

logger.Debug(ctx, "successfully reported connection")

// Remove the payload we sent.
a.reportConnectionsMu.Lock()
a.reportConnections[0] = nil // Release the pointer from the underlying array.
a.reportConnections = a.reportConnections[1:]
Member:
Why not just make reportConnections a channel? It's an append-only slice that is only read in this function. The slice behavior is correct, it just feels like a weaker implementation of a channel.

Member:
I guess it could be a channel, but how big to make it? How many in-flight reports is "too many"?

Member (Author):
That's fair, a channel would not be a terrible option here. It requires upfront allocation, which can be either a good or a bad thing in memory-constrained systems. For now I've limited this to 2048 pending reports, or about 300KB. We can revisit this later if needed.
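For illustration, a rough sketch of the channel-based variant discussed above (hypothetical, not what was merged; it assumes an agent field reportConnectionsCh created as make(chan *proto.ReportConnectionRequest, reportConnectionBufferLimit)):

// Producer side: non-blocking send, dropping the report on overflow.
func (a *agent) enqueueConnectionReport(payload *proto.ReportConnectionRequest) {
	select {
	case a.reportConnectionsCh <- payload:
	default:
		a.logger.Warn(a.hardCtx, "connection report buffer limit reached, dropping report")
	}
}

// Consumer side: receive and send reports until the context is done.
func (a *agent) reportConnectionsLoopCh(ctx context.Context, aAPI proto.DRPCAgentClient24) error {
	for {
		select {
		case payload := <-a.reportConnectionsCh:
			if _, err := aAPI.ReportConnection(ctx, payload); err != nil {
				return xerrors.Errorf("failed to report connection: %w", err)
			}
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}

One trade-off: the merged slice version leaves a payload at the head of the queue until it is sent successfully, so nothing is lost when the loop exits on error and restarts; the channel consumer above would lose the payload it received unless it re-enqueued it.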

a.reportConnectionsMu.Unlock()
}
}
}

const (
// reportConnectionBufferLimit limits the number of connection reports we
// buffer to avoid growing the buffer indefinitely. This should not happen
// unless the agent has lost connection to coderd for a long time or if
// the agent is being spammed with connections.
//
// If we assume ~150 bytes per connection report, this would be around 300KB
// of memory which seems acceptable. We could reduce this if necessary by
// not using the proto struct directly.
reportConnectionBufferLimit = 2048
)

func (a *agent) reportConnection(id uuid.UUID, connectionType proto.Connection_Type, ip string) (disconnected func(code int, reason string)) {
// If the experiment hasn't been enabled, we don't report connections.
if !a.experimentalConnectionReports {
return func(int, string) {} // Noop.
}

// Remove the port from the IP because ports are not supported in coderd.
if host, _, err := net.SplitHostPort(ip); err != nil {
a.logger.Error(a.hardCtx, "split host and port for connection report failed", slog.F("ip", ip), slog.Error(err))
} else {
// Best effort.
ip = host
}

a.reportConnectionsMu.Lock()
defer a.reportConnectionsMu.Unlock()

if len(a.reportConnections) >= reportConnectionBufferLimit {
a.logger.Warn(a.hardCtx, "connection report buffer limit reached, dropping connect",
slog.F("limit", reportConnectionBufferLimit),
slog.F("connection_id", id),
slog.F("connection_type", connectionType),
slog.F("ip", ip),
)
} else {
a.reportConnections = append(a.reportConnections, &proto.ReportConnectionRequest{
Connection: &proto.Connection{
Id: id[:],
Action: proto.Connection_CONNECT,
Type: connectionType,
Timestamp: timestamppb.New(time.Now()),
Ip: ip,
StatusCode: 0,
Reason: nil,
},
})
select {
case a.reportConnectionsUpdate <- struct{}{}:
default:
}
}

return func(code int, reason string) {
a.reportConnectionsMu.Lock()
defer a.reportConnectionsMu.Unlock()
if len(a.reportConnections) >= reportConnectionBufferLimit {
a.logger.Warn(a.hardCtx, "connection report buffer limit reached, dropping disconnect",
slog.F("limit", reportConnectionBufferLimit),
slog.F("connection_id", id),
slog.F("connection_type", connectionType),
slog.F("ip", ip),
)
return
}

a.reportConnections = append(a.reportConnections, &proto.ReportConnectionRequest{
Connection: &proto.Connection{
Id: id[:],
Action: proto.Connection_DISCONNECT,
Type: connectionType,
Timestamp: timestamppb.New(time.Now()),
Ip: ip,
StatusCode: int32(code), //nolint:gosec
Reason: &reason,
},
})
select {
case a.reportConnectionsUpdate <- struct{}{}:
default:
}
}
}
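To make the returned callback's contract concrete, a hypothetical call site might look like this (session and handleSession are assumed names, not part of this PR):

// Hypothetical call site: report the connect up front, then invoke the
// returned callback exactly once when the session ends.
disconnected := a.reportConnection(uuid.New(), proto.Connection_SSH, session.RemoteAddr().String())
if err := handleSession(session); err != nil {
	disconnected(1, err.Error()) // non-zero status code plus reason
} else {
	disconnected(0, "")
}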

// fetchServiceBannerLoop fetches the service banner on an interval. It will
// not be fetched immediately; the expectation is that it is primed elsewhere
// (and must be done before the session actually starts).
@@ -823,6 +977,10 @@ func (a *agent) run() (retErr error) {
return resourcesmonitor.Start(ctx)
})

// Connection reports are part of auditing; we should keep sending them
// via gracefulShutdownBehaviorRemain.
connMan.startAgentAPI("report connections", gracefulShutdownBehaviorRemain, a.reportConnectionsLoop)

// channels to sync goroutines below
// handle manifest
// |