From b850ab4ca9f3b1ebac0680d595c4fa54980626f9 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Tue, 6 Feb 2024 14:19:46 +0400 Subject: [PATCH] feat: change agent to use v2 API for reporting stats --- agent/agent.go | 191 +++++++++++++------------------ agent/agent_test.go | 35 +++--- agent/agenttest/client.go | 61 +++------- agent/metrics.go | 29 +++-- coderd/tailnet_test.go | 3 +- coderd/workspaceagents_test.go | 1 + codersdk/agentsdk/agentsdk.go | 65 +---------- codersdk/workspaceagents_test.go | 51 --------- 8 files changed, 135 insertions(+), 301 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index 9973c0bbe54ee..48c8f66694844 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -89,7 +89,6 @@ type Options struct { type Client interface { ConnectRPC(ctx context.Context) (drpc.Conn, error) - ReportStats(ctx context.Context, log slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) PostLifecycle(ctx context.Context, state agentsdk.PostLifecycleRequest) error PostMetadata(ctx context.Context, req agentsdk.PostMetadataRequest) error PatchLogs(ctx context.Context, req agentsdk.PatchLogs) error @@ -158,7 +157,6 @@ func New(options Options) Agent { lifecycleStates: []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}}, ignorePorts: options.IgnorePorts, portCacheDuration: options.PortCacheDuration, - connStatsChan: make(chan *agentsdk.Stats, 1), reportMetadataInterval: options.ReportMetadataInterval, serviceBannerRefreshInterval: options.ServiceBannerRefreshInterval, sshMaxTimeout: options.SSHMaxTimeout, @@ -216,8 +214,7 @@ type agent struct { network *tailnet.Conn addresses []netip.Prefix - connStatsChan chan *agentsdk.Stats - latestStat atomic.Pointer[agentsdk.Stats] + statsReporter *statsReporter connCountReconnectingPTY atomic.Int64 @@ -822,14 +819,13 @@ func (a *agent) run(ctx context.Context) error { closed := a.isClosed() if !closed { a.network = network + a.statsReporter = newStatsReporter(a.logger, network, a) } a.closeMutex.Unlock() if closed { _ = network.Close() return xerrors.New("agent is closed") } - - a.startReportingConnectionStats(ctx) } else { // Update the wireguard IPs if the agent ID changed. err := network.SetAddresses(a.wireguardAddresses(manifest.AgentID)) @@ -871,6 +867,15 @@ func (a *agent) run(ctx context.Context) error { return nil }) + eg.Go(func() error { + a.logger.Debug(egCtx, "running stats report loop") + err := a.statsReporter.reportLoop(egCtx, aAPI) + if err != nil { + return xerrors.Errorf("report stats loop: %w", err) + } + return nil + }) + return eg.Wait() } @@ -1218,115 +1223,83 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m return rpty.Attach(ctx, connectionID, conn, msg.Height, msg.Width, connLogger) } -// startReportingConnectionStats runs the connection stats reporting goroutine. -func (a *agent) startReportingConnectionStats(ctx context.Context) { - reportStats := func(networkStats map[netlogtype.Connection]netlogtype.Counts) { - a.logger.Debug(ctx, "computing stats report") - stats := &agentsdk.Stats{ - ConnectionCount: int64(len(networkStats)), - ConnectionsByProto: map[string]int64{}, - } - for conn, counts := range networkStats { - stats.ConnectionsByProto[conn.Proto.String()]++ - stats.RxBytes += int64(counts.RxBytes) - stats.RxPackets += int64(counts.RxPackets) - stats.TxBytes += int64(counts.TxBytes) - stats.TxPackets += int64(counts.TxPackets) - } - - // The count of active sessions. - sshStats := a.sshServer.ConnStats() - stats.SessionCountSSH = sshStats.Sessions - stats.SessionCountVSCode = sshStats.VSCode - stats.SessionCountJetBrains = sshStats.JetBrains - - stats.SessionCountReconnectingPTY = a.connCountReconnectingPTY.Load() - - // Compute the median connection latency! - a.logger.Debug(ctx, "starting peer latency measurement for stats") - var wg sync.WaitGroup - var mu sync.Mutex - status := a.network.Status() - durations := []float64{} - pingCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second) - defer cancelFunc() - for nodeID, peer := range status.Peer { - if !peer.Active { - continue - } - addresses, found := a.network.NodeAddresses(nodeID) - if !found { - continue - } - if len(addresses) == 0 { - continue - } - wg.Add(1) - go func() { - defer wg.Done() - duration, _, _, err := a.network.Ping(pingCtx, addresses[0].Addr()) - if err != nil { - return - } - mu.Lock() - durations = append(durations, float64(duration.Microseconds())) - mu.Unlock() - }() +// Collect collects additional stats from the agent +func (a *agent) Collect(ctx context.Context, networkStats map[netlogtype.Connection]netlogtype.Counts) *proto.Stats { + a.logger.Debug(context.Background(), "computing stats report") + stats := &proto.Stats{ + ConnectionCount: int64(len(networkStats)), + ConnectionsByProto: map[string]int64{}, + } + for conn, counts := range networkStats { + stats.ConnectionsByProto[conn.Proto.String()]++ + stats.RxBytes += int64(counts.RxBytes) + stats.RxPackets += int64(counts.RxPackets) + stats.TxBytes += int64(counts.TxBytes) + stats.TxPackets += int64(counts.TxPackets) + } + + // The count of active sessions. + sshStats := a.sshServer.ConnStats() + stats.SessionCountSsh = sshStats.Sessions + stats.SessionCountVscode = sshStats.VSCode + stats.SessionCountJetbrains = sshStats.JetBrains + + stats.SessionCountReconnectingPty = a.connCountReconnectingPTY.Load() + + // Compute the median connection latency! + a.logger.Debug(ctx, "starting peer latency measurement for stats") + var wg sync.WaitGroup + var mu sync.Mutex + status := a.network.Status() + durations := []float64{} + pingCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second) + defer cancelFunc() + for nodeID, peer := range status.Peer { + if !peer.Active { + continue } - wg.Wait() - sort.Float64s(durations) - durationsLength := len(durations) - if durationsLength == 0 { - stats.ConnectionMedianLatencyMS = -1 - } else if durationsLength%2 == 0 { - stats.ConnectionMedianLatencyMS = (durations[durationsLength/2-1] + durations[durationsLength/2]) / 2 - } else { - stats.ConnectionMedianLatencyMS = durations[durationsLength/2] + addresses, found := a.network.NodeAddresses(nodeID) + if !found { + continue } - // Convert from microseconds to milliseconds. - stats.ConnectionMedianLatencyMS /= 1000 - - // Collect agent metrics. - // Agent metrics are changing all the time, so there is no need to perform - // reflect.DeepEqual to see if stats should be transferred. - - metricsCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second) - defer cancelFunc() - a.logger.Debug(ctx, "collecting agent metrics for stats") - stats.Metrics = a.collectMetrics(metricsCtx) - - a.latestStat.Store(stats) - - a.logger.Debug(ctx, "about to send stats") - select { - case a.connStatsChan <- stats: - a.logger.Debug(ctx, "successfully sent stats") - case <-a.closed: - a.logger.Debug(ctx, "didn't send stats because we are closed") + if len(addresses) == 0 { + continue } + wg.Add(1) + go func() { + defer wg.Done() + duration, _, _, err := a.network.Ping(pingCtx, addresses[0].Addr()) + if err != nil { + return + } + mu.Lock() + defer mu.Unlock() + durations = append(durations, float64(duration.Microseconds())) + }() } - - // Report statistics from the created network. - cl, err := a.client.ReportStats(ctx, a.logger, a.connStatsChan, func(d time.Duration) { - a.network.SetConnStatsCallback(d, 2048, - func(_, _ time.Time, virtual, _ map[netlogtype.Connection]netlogtype.Counts) { - reportStats(virtual) - }, - ) - }) - if err != nil { - a.logger.Error(ctx, "agent failed to report stats", slog.Error(err)) + wg.Wait() + sort.Float64s(durations) + durationsLength := len(durations) + if durationsLength == 0 { + stats.ConnectionMedianLatencyMs = -1 + } else if durationsLength%2 == 0 { + stats.ConnectionMedianLatencyMs = (durations[durationsLength/2-1] + durations[durationsLength/2]) / 2 } else { - if err = a.trackConnGoroutine(func() { - // This is OK because the agent never re-creates the tailnet - // and the only shutdown indicator is agent.Close(). - <-a.closed - _ = cl.Close() - }); err != nil { - a.logger.Debug(ctx, "report stats goroutine", slog.Error(err)) - _ = cl.Close() - } + stats.ConnectionMedianLatencyMs = durations[durationsLength/2] } + // Convert from microseconds to milliseconds. + stats.ConnectionMedianLatencyMs /= 1000 + + // Collect agent metrics. + // Agent metrics are changing all the time, so there is no need to perform + // reflect.DeepEqual to see if stats should be transferred. + + metricsCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second) + defer cancelFunc() + a.logger.Debug(ctx, "collecting agent metrics for stats") + stats.Metrics = a.collectMetrics(metricsCtx) + + return stats } var prioritizedProcs = []string{"coder agent"} diff --git a/agent/agent_test.go b/agent/agent_test.go index f7cbe41e96ec0..f30dc430addc7 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -52,6 +52,7 @@ import ( "github.com/coder/coder/v2/agent/agentproc/agentproctest" "github.com/coder/coder/v2/agent/agentssh" "github.com/coder/coder/v2/agent/agenttest" + "github.com/coder/coder/v2/agent/proto" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/pty/ptytest" @@ -85,11 +86,11 @@ func TestAgent_Stats_SSH(t *testing.T) { err = session.Shell() require.NoError(t, err) - var s *agentsdk.Stats + var s *proto.Stats require.Eventuallyf(t, func() bool { var ok bool s, ok = <-stats - return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountSSH == 1 + return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountSsh == 1 }, testutil.WaitLong, testutil.IntervalFast, "never saw stats: %+v", s, ) @@ -118,11 +119,11 @@ func TestAgent_Stats_ReconnectingPTY(t *testing.T) { _, err = ptyConn.Write(data) require.NoError(t, err) - var s *agentsdk.Stats + var s *proto.Stats require.Eventuallyf(t, func() bool { var ok bool s, ok = <-stats - return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountReconnectingPTY == 1 + return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountReconnectingPty == 1 }, testutil.WaitLong, testutil.IntervalFast, "never saw stats: %+v", s, ) @@ -177,14 +178,14 @@ func TestAgent_Stats_Magic(t *testing.T) { require.Eventuallyf(t, func() bool { s, ok := <-stats t.Logf("got stats: ok=%t, ConnectionCount=%d, RxBytes=%d, TxBytes=%d, SessionCountVSCode=%d, ConnectionMedianLatencyMS=%f", - ok, s.ConnectionCount, s.RxBytes, s.TxBytes, s.SessionCountVSCode, s.ConnectionMedianLatencyMS) + ok, s.ConnectionCount, s.RxBytes, s.TxBytes, s.SessionCountVscode, s.ConnectionMedianLatencyMs) return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && // Ensure that the connection didn't count as a "normal" SSH session. // This was a special one, so it should be labeled specially in the stats! - s.SessionCountVSCode == 1 && + s.SessionCountVscode == 1 && // Ensure that connection latency is being counted! // If it isn't, it's set to -1. - s.ConnectionMedianLatencyMS >= 0 + s.ConnectionMedianLatencyMs >= 0 }, testutil.WaitLong, testutil.IntervalFast, "never saw stats", ) @@ -243,9 +244,9 @@ func TestAgent_Stats_Magic(t *testing.T) { require.Eventuallyf(t, func() bool { s, ok := <-stats t.Logf("got stats with conn open: ok=%t, ConnectionCount=%d, SessionCountJetBrains=%d", - ok, s.ConnectionCount, s.SessionCountJetBrains) + ok, s.ConnectionCount, s.SessionCountJetbrains) return ok && s.ConnectionCount > 0 && - s.SessionCountJetBrains == 1 + s.SessionCountJetbrains == 1 }, testutil.WaitLong, testutil.IntervalFast, "never saw stats with conn open", ) @@ -258,9 +259,9 @@ func TestAgent_Stats_Magic(t *testing.T) { require.Eventuallyf(t, func() bool { s, ok := <-stats t.Logf("got stats after disconnect %t, %d", - ok, s.SessionCountJetBrains) + ok, s.SessionCountJetbrains) return ok && - s.SessionCountJetBrains == 0 + s.SessionCountJetbrains == 0 }, testutil.WaitLong, testutil.IntervalFast, "never saw stats after conn closes", ) @@ -1346,7 +1347,7 @@ func TestAgent_Lifecycle(t *testing.T) { RunOnStop: true, }}, }, - make(chan *agentsdk.Stats, 50), + make(chan *proto.Stats, 50), tailnet.NewCoordinator(logger), ) defer client.Close() @@ -1667,7 +1668,7 @@ func TestAgent_UpdatedDERP(t *testing.T) { _ = coordinator.Close() }) agentID := uuid.New() - statsCh := make(chan *agentsdk.Stats, 50) + statsCh := make(chan *proto.Stats, 50) fs := afero.NewMemMapFs() client := agenttest.NewClient(t, logger.Named("agent"), @@ -1816,7 +1817,7 @@ func TestAgent_Reconnect(t *testing.T) { defer coordinator.Close() agentID := uuid.New() - statsCh := make(chan *agentsdk.Stats, 50) + statsCh := make(chan *proto.Stats, 50) derpMap, _ := tailnettest.RunDERPAndSTUN(t) client := agenttest.NewClient(t, logger, @@ -1861,7 +1862,7 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) { GitAuthConfigs: 1, DERPMap: &tailcfg.DERPMap{}, }, - make(chan *agentsdk.Stats, 50), + make(chan *proto.Stats, 50), coordinator, ) defer client.Close() @@ -2018,7 +2019,7 @@ func setupSSHSession( func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Duration, opts ...func(*agenttest.Client, *agent.Options)) ( *codersdk.WorkspaceAgentConn, *agenttest.Client, - <-chan *agentsdk.Stats, + <-chan *proto.Stats, afero.Fs, agent.Agent, ) { @@ -2046,7 +2047,7 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati t.Cleanup(func() { _ = coordinator.Close() }) - statsCh := make(chan *agentsdk.Stats, 50) + statsCh := make(chan *proto.Stats, 50) fs := afero.NewMemMapFs() c := agenttest.NewClient(t, logger.Named("agent"), metadata.AgentID, metadata, statsCh, coordinator) t.Cleanup(c.Close) diff --git a/agent/agenttest/client.go b/agent/agenttest/client.go index 0b7832ac6739a..50b67379cc5c3 100644 --- a/agent/agenttest/client.go +++ b/agent/agenttest/client.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/exp/maps" "golang.org/x/xerrors" + "google.golang.org/protobuf/types/known/durationpb" "storj.io/drpc" "storj.io/drpc/drpcmux" "storj.io/drpc/drpcserver" @@ -27,11 +28,13 @@ import ( "github.com/coder/coder/v2/testutil" ) +const statsInterval = 500 * time.Millisecond + func NewClient(t testing.TB, logger slog.Logger, agentID uuid.UUID, manifest agentsdk.Manifest, - statsChan chan *agentsdk.Stats, + statsChan chan *agentproto.Stats, coordinator tailnet.Coordinator, ) *Client { if manifest.AgentID == uuid.Nil { @@ -51,7 +54,7 @@ func NewClient(t testing.TB, require.NoError(t, err) mp, err := agentsdk.ProtoFromManifest(manifest) require.NoError(t, err) - fakeAAPI := NewFakeAgentAPI(t, logger, mp) + fakeAAPI := NewFakeAgentAPI(t, logger, mp, statsChan) err = agentproto.DRPCRegisterAgent(mux, fakeAAPI) require.NoError(t, err) server := drpcserver.NewWithOptions(mux, drpcserver.Options{ @@ -66,7 +69,6 @@ func NewClient(t testing.TB, t: t, logger: logger.Named("client"), agentID: agentID, - statsChan: statsChan, coordinator: coordinator, server: server, fakeAgentAPI: fakeAAPI, @@ -79,7 +81,6 @@ type Client struct { logger slog.Logger agentID uuid.UUID metadata map[string]agentsdk.Metadata - statsChan chan *agentsdk.Stats coordinator tailnet.Coordinator server *drpcserver.Server fakeAgentAPI *FakeAgentAPI @@ -121,38 +122,6 @@ func (c *Client) ConnectRPC(ctx context.Context) (drpc.Conn, error) { return conn, nil } -func (c *Client) ReportStats(ctx context.Context, _ slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) { - doneCh := make(chan struct{}) - ctx, cancel := context.WithCancel(ctx) - - go func() { - defer close(doneCh) - - setInterval(500 * time.Millisecond) - for { - select { - case <-ctx.Done(): - return - case stat := <-statsChan: - select { - case c.statsChan <- stat: - case <-ctx.Done(): - return - default: - // We don't want to send old stats. - continue - } - } - } - }() - return closeFunc(func() error { - cancel() - <-doneCh - close(c.statsChan) - return nil - }), nil -} - func (c *Client) GetLifecycleStates() []codersdk.WorkspaceAgentLifecycle { c.mu.Lock() defer c.mu.Unlock() @@ -223,12 +192,6 @@ func (c *Client) PushDERPMapUpdate(update *tailcfg.DERPMap) error { return nil } -type closeFunc func() error - -func (c closeFunc) Close() error { - return c() -} - type FakeAgentAPI struct { sync.Mutex t testing.TB @@ -236,6 +199,7 @@ type FakeAgentAPI struct { manifest *agentproto.Manifest startupCh chan *agentproto.Startup + statsCh chan *agentproto.Stats getServiceBannerFunc func() (codersdk.ServiceBannerConfig, error) } @@ -264,9 +228,13 @@ func (f *FakeAgentAPI) GetServiceBanner(context.Context, *agentproto.GetServiceB return agentsdk.ProtoFromServiceBanner(sb), nil } -func (*FakeAgentAPI) UpdateStats(context.Context, *agentproto.UpdateStatsRequest) (*agentproto.UpdateStatsResponse, error) { - // TODO implement me - panic("implement me") +func (f *FakeAgentAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsRequest) (*agentproto.UpdateStatsResponse, error) { + f.logger.Debug(ctx, "update stats called", slog.F("req", req)) + // empty request is sent to get the interval; but our tests don't want empty stats requests + if req.Stats != nil { + f.statsCh <- req.Stats + } + return &agentproto.UpdateStatsResponse{ReportInterval: durationpb.New(statsInterval)}, nil } func (*FakeAgentAPI) UpdateLifecycle(context.Context, *agentproto.UpdateLifecycleRequest) (*agentproto.Lifecycle, error) { @@ -294,11 +262,12 @@ func (*FakeAgentAPI) BatchCreateLogs(context.Context, *agentproto.BatchCreateLog panic("implement me") } -func NewFakeAgentAPI(t testing.TB, logger slog.Logger, manifest *agentproto.Manifest) *FakeAgentAPI { +func NewFakeAgentAPI(t testing.TB, logger slog.Logger, manifest *agentproto.Manifest, statsCh chan *agentproto.Stats) *FakeAgentAPI { return &FakeAgentAPI{ t: t, logger: logger.Named("FakeAgentAPI"), manifest: manifest, + statsCh: statsCh, startupCh: make(chan *agentproto.Startup, 100), } } diff --git a/agent/metrics.go b/agent/metrics.go index d987bad9a50c0..5a60740c4c969 100644 --- a/agent/metrics.go +++ b/agent/metrics.go @@ -10,8 +10,7 @@ import ( "tailscale.com/util/clientmetric" "cdr.dev/slog" - - "github.com/coder/coder/v2/codersdk/agentsdk" + "github.com/coder/coder/v2/agent/proto" ) type agentMetrics struct { @@ -53,8 +52,8 @@ func newAgentMetrics(registerer prometheus.Registerer) *agentMetrics { } } -func (a *agent) collectMetrics(ctx context.Context) []agentsdk.AgentMetric { - var collected []agentsdk.AgentMetric +func (a *agent) collectMetrics(ctx context.Context) []*proto.Stats_Metric { + var collected []*proto.Stats_Metric // Tailscale internal metrics metrics := clientmetric.Metrics() @@ -63,7 +62,7 @@ func (a *agent) collectMetrics(ctx context.Context) []agentsdk.AgentMetric { continue } - collected = append(collected, agentsdk.AgentMetric{ + collected = append(collected, &proto.Stats_Metric{ Name: m.Name(), Type: asMetricType(m.Type()), Value: float64(m.Value()), @@ -81,16 +80,16 @@ func (a *agent) collectMetrics(ctx context.Context) []agentsdk.AgentMetric { labels := toAgentMetricLabels(metric.Label) if metric.Counter != nil { - collected = append(collected, agentsdk.AgentMetric{ + collected = append(collected, &proto.Stats_Metric{ Name: metricFamily.GetName(), - Type: agentsdk.AgentMetricTypeCounter, + Type: proto.Stats_Metric_COUNTER, Value: metric.Counter.GetValue(), Labels: labels, }) } else if metric.Gauge != nil { - collected = append(collected, agentsdk.AgentMetric{ + collected = append(collected, &proto.Stats_Metric{ Name: metricFamily.GetName(), - Type: agentsdk.AgentMetricTypeGauge, + Type: proto.Stats_Metric_GAUGE, Value: metric.Gauge.GetValue(), Labels: labels, }) @@ -102,14 +101,14 @@ func (a *agent) collectMetrics(ctx context.Context) []agentsdk.AgentMetric { return collected } -func toAgentMetricLabels(metricLabels []*prompb.LabelPair) []agentsdk.AgentMetricLabel { +func toAgentMetricLabels(metricLabels []*prompb.LabelPair) []*proto.Stats_Metric_Label { if len(metricLabels) == 0 { return nil } - labels := make([]agentsdk.AgentMetricLabel, 0, len(metricLabels)) + labels := make([]*proto.Stats_Metric_Label, 0, len(metricLabels)) for _, metricLabel := range metricLabels { - labels = append(labels, agentsdk.AgentMetricLabel{ + labels = append(labels, &proto.Stats_Metric_Label{ Name: metricLabel.GetName(), Value: metricLabel.GetValue(), }) @@ -130,12 +129,12 @@ func isIgnoredMetric(metricName string) bool { return false } -func asMetricType(typ clientmetric.Type) agentsdk.AgentMetricType { +func asMetricType(typ clientmetric.Type) proto.Stats_Metric_Type { switch typ { case clientmetric.TypeGauge: - return agentsdk.AgentMetricTypeGauge + return proto.Stats_Metric_GAUGE case clientmetric.TypeCounter: - return agentsdk.AgentMetricTypeCounter + return proto.Stats_Metric_COUNTER default: panic(fmt.Sprintf("unknown metric type: %d", typ)) } diff --git a/coderd/tailnet_test.go b/coderd/tailnet_test.go index 73ccba701b632..7da5b51d2d7ca 100644 --- a/coderd/tailnet_test.go +++ b/coderd/tailnet_test.go @@ -24,6 +24,7 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/agent" "github.com/coder/coder/v2/agent/agenttest" + "github.com/coder/coder/v2/agent/proto" "github.com/coder/coder/v2/coderd" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" @@ -327,7 +328,7 @@ func setupServerTailnetAgent(t *testing.T, agentNum int) ([]agentWithID, *coderd DERPMap: derpMap, } - c := agenttest.NewClient(t, logger, manifest.AgentID, manifest, make(chan *agentsdk.Stats, 50), coord) + c := agenttest.NewClient(t, logger, manifest.AgentID, manifest, make(chan *proto.Stats, 50), coord) t.Cleanup(c.Close) options := agent.Options{ diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index 6457fc3771df7..fe6d229cd1fac 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -890,6 +890,7 @@ func TestWorkspaceAgentAppHealth(t *testing.T) { require.EqualValues(t, codersdk.WorkspaceAppHealthUnhealthy, manifest.Apps[1].Health) } +// TestWorkspaceAgentReportStats tests the legacy (agent API v1) report stats endpoint. func TestWorkspaceAgentReportStats(t *testing.T) { t.Parallel() diff --git a/codersdk/agentsdk/agentsdk.go b/codersdk/agentsdk/agentsdk.go index bdbf61a4ad1f2..82a8e697eeed3 100644 --- a/codersdk/agentsdk/agentsdk.go +++ b/codersdk/agentsdk/agentsdk.go @@ -24,7 +24,6 @@ import ( "github.com/coder/coder/v2/agent/proto" "github.com/coder/coder/v2/codersdk" drpcsdk "github.com/coder/coder/v2/codersdk/drpc" - "github.com/coder/retry" ) // ExternalLogSourceID is the statically-defined ID of a log-source that @@ -390,61 +389,6 @@ func (c *Client) AuthAzureInstanceIdentity(ctx context.Context) (AuthenticateRes return resp, json.NewDecoder(res.Body).Decode(&resp) } -// ReportStats begins a stat streaming connection with the Coder server. -// It is resilient to network failures and intermittent coderd issues. -func (c *Client) ReportStats(ctx context.Context, log slog.Logger, statsChan <-chan *Stats, setInterval func(time.Duration)) (io.Closer, error) { - var interval time.Duration - ctx, cancel := context.WithCancel(ctx) - exited := make(chan struct{}) - - postStat := func(stat *Stats) { - var nextInterval time.Duration - for r := retry.New(100*time.Millisecond, time.Minute); r.Wait(ctx); { - resp, err := c.PostStats(ctx, stat) - if err != nil { - if !xerrors.Is(err, context.Canceled) { - log.Error(ctx, "report stats", slog.Error(err)) - } - continue - } - - nextInterval = resp.ReportInterval - break - } - - if nextInterval != 0 && interval != nextInterval { - setInterval(nextInterval) - } - interval = nextInterval - } - - // Send an empty stat to get the interval. - postStat(&Stats{}) - - go func() { - defer close(exited) - - for { - select { - case <-ctx.Done(): - return - case stat, ok := <-statsChan: - if !ok { - return - } - - postStat(stat) - } - } - }() - - return closeFunc(func() error { - cancel() - <-exited - return nil - }), nil -} - // Stats records the Agent's network connection statistics for use in // user-facing metrics and debugging. type Stats struct { @@ -509,6 +453,9 @@ type StatsResponse struct { ReportInterval time.Duration `json:"report_interval"` } +// PostStats sends agent stats to the coder server +// +// Deprecated: uses agent API v1 endpoint func (c *Client) PostStats(ctx context.Context, stats *Stats) (StatsResponse, error) { res, err := c.SDK.Request(ctx, http.MethodPost, "/api/v2/workspaceagents/me/report-stats", stats) if err != nil { @@ -649,12 +596,6 @@ func (c *Client) ExternalAuth(ctx context.Context, req ExternalAuthRequest) (Ext return authResp, json.NewDecoder(res.Body).Decode(&authResp) } -type closeFunc func() error - -func (c closeFunc) Close() error { - return c() -} - // wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func // is called if a read or write error is encountered. type wsNetConn struct { diff --git a/codersdk/workspaceagents_test.go b/codersdk/workspaceagents_test.go index 4ae07f4dc66c2..31a516bfdd96f 100644 --- a/codersdk/workspaceagents_test.go +++ b/codersdk/workspaceagents_test.go @@ -1,22 +1,13 @@ package codersdk_test import ( - "context" - "net/http" - "net/http/httptest" "net/url" - "sync/atomic" "testing" - "time" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "tailscale.com/tailcfg" - "cdr.dev/slog/sloggers/slogtest" - "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/codersdk/agentsdk" - "github.com/coder/coder/v2/testutil" ) func TestWorkspaceRewriteDERPMap(t *testing.T) { @@ -46,45 +37,3 @@ func TestWorkspaceRewriteDERPMap(t *testing.T) { require.Equal(t, "coconuts.org", node.HostName) require.Equal(t, 44558, node.DERPPort) } - -func TestAgentReportStats(t *testing.T) { - t.Parallel() - - var ( - numReports atomic.Int64 - numIntervalCalls atomic.Int64 - wantInterval = 5 * time.Millisecond - ) - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - numReports.Add(1) - httpapi.Write(context.Background(), w, http.StatusOK, agentsdk.StatsResponse{ - ReportInterval: wantInterval, - }) - })) - parsed, err := url.Parse(srv.URL) - require.NoError(t, err) - client := agentsdk.New(parsed) - - assertStatInterval := func(interval time.Duration) { - numIntervalCalls.Add(1) - assert.Equal(t, wantInterval, interval) - } - - chanLen := 3 - statCh := make(chan *agentsdk.Stats, chanLen) - for i := 0; i < chanLen; i++ { - statCh <- &agentsdk.Stats{ConnectionsByProto: map[string]int64{}} - } - - ctx := context.Background() - closeStream, err := client.ReportStats(ctx, slogtest.Make(t, nil), statCh, assertStatInterval) - require.NoError(t, err) - defer closeStream.Close() - - require.Eventually(t, - func() bool { return numReports.Load() >= 3 }, - testutil.WaitMedium, testutil.IntervalFast, - ) - closeStream.Close() - require.Equal(t, int64(1), numIntervalCalls.Load()) -}