Skip to content

feat: add agent stats for different connection types #6412

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 139 additions & 45 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"os/user"
"path/filepath"
"runtime"
"sort"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -56,6 +57,14 @@ const (
// command just returning a nonzero exit code, and is chosen as an arbitrary, high number
// unlikely to shadow other exit codes, which are typically 1, 2, 3, etc.
MagicSessionErrorCode = 229

// MagicSSHSessionTypeEnvironmentVariable is used to track the purpose behind an SSH connection.
// This is stripped from any commands being executed, and is counted towards connection stats.
MagicSSHSessionTypeEnvironmentVariable = "__CODER_SSH_SESSION_TYPE"
// MagicSSHSessionTypeVSCode is set in the SSH config by the VS Code extension to identify itself.
MagicSSHSessionTypeVSCode = "vscode"
// MagicSSHSessionTypeJetBrains is set in the SSH config by the JetBrains extension to identify itself.
MagicSSHSessionTypeJetBrains = "jetbrains"
)

type Options struct {
Expand Down Expand Up @@ -146,6 +155,15 @@ type agent struct {

network *tailnet.Conn
connStatsChan chan *agentsdk.Stats

statRxPackets atomic.Int64
statRxBytes atomic.Int64
statTxPackets atomic.Int64
statTxBytes atomic.Int64
connCountVSCode atomic.Int64
connCountJetBrains atomic.Int64
connCountReconnectingPTY atomic.Int64
connCountSSHSession atomic.Int64
}

// runLoop attempts to start the agent in a retry loop.
Expand Down Expand Up @@ -350,33 +368,7 @@ func (a *agent) run(ctx context.Context) error {
return xerrors.New("agent is closed")
}

setStatInterval := func(d time.Duration) {
network.SetConnStatsCallback(d, 2048,
func(_, _ time.Time, virtual, _ map[netlogtype.Connection]netlogtype.Counts) {
select {
case a.connStatsChan <- convertAgentStats(virtual):
default:
a.logger.Warn(ctx, "network stat dropped")
}
},
)
}

// Report statistics from the created network.
cl, err := a.client.ReportStats(ctx, a.logger, a.connStatsChan, setStatInterval)
if err != nil {
a.logger.Error(ctx, "report stats", slog.Error(err))
} else {
if err = a.trackConnGoroutine(func() {
// This is OK because the agent never re-creates the tailnet
// and the only shutdown indicator is agent.Close().
<-a.closed
_ = cl.Close()
}); err != nil {
a.logger.Debug(ctx, "report stats goroutine", slog.Error(err))
_ = cl.Close()
}
}
a.startReportingConnectionStats(ctx)
} else {
// Update the DERP map!
network.SetDERPMap(metadata.DERPMap)
Expand Down Expand Up @@ -765,23 +757,6 @@ func (a *agent) init(ctx context.Context) {
go a.runLoop(ctx)
}

// convertAgentStats folds the per-connection counters in counts into a single
// agentsdk.Stats value: the number of connections, a per-protocol connection
// tally, and the summed received/transmitted packet and byte counts.
func convertAgentStats(counts map[netlogtype.Connection]netlogtype.Counts) *agentsdk.Stats {
	var rxPackets, rxBytes, txPackets, txBytes int64
	byProto := make(map[string]int64)

	for c, n := range counts {
		// One entry in counts is one connection, so bump its protocol's tally.
		byProto[c.Proto.String()]++
		rxPackets += int64(n.RxPackets)
		rxBytes += int64(n.RxBytes)
		txPackets += int64(n.TxPackets)
		txBytes += int64(n.TxBytes)
	}

	return &agentsdk.Stats{
		ConnectionsByProto: byProto,
		ConnectionCount:    int64(len(counts)),
		RxPackets:          rxPackets,
		RxBytes:            rxBytes,
		TxPackets:          txPackets,
		TxBytes:            txBytes,
	}
}

// createCommand processes raw command input with OpenSSH-like behavior.
// If the rawCommand provided is empty, it will default to the users shell.
// This injects environment variables specified by the user at launch too.
Expand Down Expand Up @@ -892,7 +867,27 @@ func (a *agent) createCommand(ctx context.Context, rawCommand string, env []stri

func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
ctx := session.Context()
cmd, err := a.createCommand(ctx, session.RawCommand(), session.Environ())
env := session.Environ()
var magicType string
for index, kv := range env {
if !strings.HasPrefix(kv, MagicSSHSessionTypeEnvironmentVariable) {
continue
}
magicType = strings.TrimPrefix(kv, MagicSSHSessionTypeEnvironmentVariable+"=")
env = append(env[:index], env[index+1:]...)
}
switch magicType {
case MagicSSHSessionTypeVSCode:
a.connCountVSCode.Add(1)
case MagicSSHSessionTypeJetBrains:
a.connCountJetBrains.Add(1)
case "":
a.connCountSSHSession.Add(1)
default:
a.logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("type", magicType))
}

cmd, err := a.createCommand(ctx, session.RawCommand(), env)
if err != nil {
return err
}
Expand Down Expand Up @@ -990,6 +985,8 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, msg codersdk.WorkspaceAgentReconnectingPTYInit, conn net.Conn) (retErr error) {
defer conn.Close()

a.connCountReconnectingPTY.Add(1)

connectionID := uuid.NewString()
logger = logger.With(slog.F("id", msg.ID), slog.F("connection_id", connectionID))

Expand Down Expand Up @@ -1180,6 +1177,103 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
}
}

// startReportingConnectionStats hooks the network's connection-stats callback
// up to the stats reporter returned by a.client.ReportStats.
//
// ReportStats is handed a function that installs the callback on a.network
// with whatever interval it is called with. Every time the callback fires, a
// report is assembled from:
//   - the per-connection counters passed to the callback, which are added to
//     the agent's atomic totals so the reported byte/packet counts are
//     cumulative rather than per-interval;
//   - the session counters (SSH, VS Code, JetBrains, reconnecting PTY);
//   - the median ping latency, in milliseconds, to the active peers, or
//     exactly -1 when no peer could be pinged.
//
// The report is sent on a.connStatsChan without blocking; if the send cannot
// proceed immediately the report is dropped and a warning is logged.
//
// If ReportStats fails the error is logged and nothing else happens.
// Otherwise the returned closer is closed once a.closed is signalled (or
// right away if the tracking goroutine cannot be started).
func (a *agent) startReportingConnectionStats(ctx context.Context) {
	reportStats := func(networkStats map[netlogtype.Connection]netlogtype.Counts) {
		stats := &agentsdk.Stats{
			ConnectionCount:    int64(len(networkStats)),
			ConnectionsByProto: map[string]int64{},
		}
		// Tailscale resets counts on every report!
		// We'd rather have these compound, like Linux does!
		for conn, counts := range networkStats {
			stats.ConnectionsByProto[conn.Proto.String()]++
			a.statRxBytes.Add(int64(counts.RxBytes))
			a.statRxPackets.Add(int64(counts.RxPackets))
			a.statTxBytes.Add(int64(counts.TxBytes))
			a.statTxPackets.Add(int64(counts.TxPackets))
		}
		// Read the totals after the loop so that a report with no connections
		// still carries the accumulated values instead of zeros.
		stats.RxBytes = a.statRxBytes.Load()
		stats.RxPackets = a.statRxPackets.Load()
		stats.TxBytes = a.statTxBytes.Load()
		stats.TxPackets = a.statTxPackets.Load()

		// Tailscale's connection stats are not cumulative, but it makes no sense to make
		// ours temporary.
		stats.SessionCountSSH = a.connCountSSHSession.Load()
		stats.SessionCountVSCode = a.connCountVSCode.Load()
		stats.SessionCountJetBrains = a.connCountJetBrains.Load()
		stats.SessionCountReconnectingPTY = a.connCountReconnectingPTY.Load()

		// Compute the median connection latency!
		var (
			wg        sync.WaitGroup
			mu        sync.Mutex // guards durations
			durations []float64  // successful ping round trips, in microseconds
		)
		status := a.network.Status()
		// Bound the total time spent pinging peers for a single report.
		ctx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
		defer cancelFunc()
		for nodeID, peer := range status.Peer {
			if !peer.Active {
				continue
			}
			addresses, found := a.network.NodeAddresses(nodeID)
			if !found || len(addresses) == 0 {
				continue
			}
			// Only the first address of each peer is pinged.
			addr := addresses[0].Addr()
			wg.Add(1)
			go func() {
				defer wg.Done()
				duration, _, _, err := a.network.Ping(ctx, addr)
				if err != nil {
					// Failed pings simply don't contribute to the median.
					return
				}
				mu.Lock()
				durations = append(durations, float64(duration.Microseconds()))
				mu.Unlock()
			}()
		}
		wg.Wait()
		sort.Float64s(durations)
		// The microsecond-to-millisecond conversion is applied only to real
		// measurements so the "no data" sentinel stays exactly -1.
		switch n := len(durations); {
		case n == 0:
			stats.ConnectionMedianLatencyMS = -1
		case n%2 == 0:
			stats.ConnectionMedianLatencyMS = (durations[n/2-1] + durations[n/2]) / 2 / 1000
		default:
			stats.ConnectionMedianLatencyMS = durations[n/2] / 1000
		}

		select {
		case a.connStatsChan <- stats:
		default:
			a.logger.Warn(ctx, "network stat dropped")
		}
	}

	// Report statistics from the created network.
	cl, err := a.client.ReportStats(ctx, a.logger, a.connStatsChan, func(d time.Duration) {
		a.network.SetConnStatsCallback(d, 2048,
			func(_, _ time.Time, virtual, _ map[netlogtype.Connection]netlogtype.Counts) {
				reportStats(virtual)
			},
		)
	})
	if err != nil {
		a.logger.Error(ctx, "report stats", slog.Error(err))
		return
	}
	if err = a.trackConnGoroutine(func() {
		// This is OK because the agent never re-creates the tailnet
		// and the only shutdown indicator is agent.Close().
		<-a.closed
		_ = cl.Close()
	}); err != nil {
		a.logger.Debug(ctx, "report stats goroutine", slog.Error(err))
		_ = cl.Close()
	}
}

// isClosed returns whether the API is closed or not.
func (a *agent) isClosed() bool {
select {
Expand Down
44 changes: 42 additions & 2 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func TestAgent_Stats_SSH(t *testing.T) {
require.Eventuallyf(t, func() bool {
var ok bool
s, ok = <-stats
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountSSH == 1
}, testutil.WaitLong, testutil.IntervalFast,
"never saw stats: %+v", s,
)
Expand Down Expand Up @@ -102,7 +102,47 @@ func TestAgent_Stats_ReconnectingPTY(t *testing.T) {
require.Eventuallyf(t, func() bool {
var ok bool
s, ok = <-stats
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountReconnectingPTY == 1
}, testutil.WaitLong, testutil.IntervalFast,
"never saw stats: %+v", s,
)
}

func TestAgent_Stats_Magic(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()

conn, _, stats, _ := setupAgent(t, agentsdk.Metadata{}, 0)
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
session, err := sshClient.NewSession()
require.NoError(t, err)
session.Setenv(agent.MagicSSHSessionTypeEnvironmentVariable, agent.MagicSSHSessionTypeVSCode)
defer session.Close()

command := "sh -c 'echo $" + agent.MagicSSHSessionTypeEnvironmentVariable + "'"
expected := ""
if runtime.GOOS == "windows" {
expected = "%" + agent.MagicSSHSessionTypeEnvironmentVariable + "%"
command = "cmd.exe /c echo " + expected
}
output, err := session.Output(command)
require.NoError(t, err)
require.Equal(t, expected, strings.TrimSpace(string(output)))
var s *agentsdk.Stats
require.Eventuallyf(t, func() bool {
var ok bool
s, ok = <-stats
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 &&
// Ensure that the connection didn't count as a "normal" SSH session.
// This was a special one, so it should be labeled specially in the stats!
s.SessionCountVSCode == 1 &&
// Ensure that connection latency is being counted!
// If it isn't, it's set to -1.
s.ConnectionMedianLatencyMS >= 0
}, testutil.WaitLong, testutil.IntervalFast,
"never saw stats: %+v", s,
)
Expand Down
30 changes: 25 additions & 5 deletions coderd/apidoc/docs.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 25 additions & 5 deletions coderd/apidoc/swagger.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading