diff --git a/agent/stats.go b/agent/stats.go
new file mode 100644
index 0000000000000..2615ab339637b
--- /dev/null
+++ b/agent/stats.go
@@ -0,0 +1,163 @@
+package agent
+
+import (
+	"context"
+	"sync"
+	"time"
+
+	"golang.org/x/xerrors"
+	"tailscale.com/types/netlogtype"
+
+	"cdr.dev/slog"
+	"github.com/coder/coder/v2/agent/proto"
+)
+
+// maxConns is the maxConns argument passed to the networkStatsSource whenever
+// the stats callback is (re)registered.
+const maxConns = 2048
+
+// networkStatsSource is where the stats callback gets registered
+// (tailnet.Conn in prod).
+type networkStatsSource interface {
+	SetConnStatsCallback(maxPeriod time.Duration, maxConns int, dump func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts))
+}
+
+// statsCollector builds a complete proto.Stats from the network stats delivered
+// by the callback (the agent in prod).
+type statsCollector interface {
+	Collect(ctx context.Context, networkStats map[netlogtype.Connection]netlogtype.Counts) *proto.Stats
+}
+
+// statsDest is where stats updates are sent (agent API in prod). Its response
+// carries a report interval, which is passed to the source as the callback period.
+type statsDest interface {
+	UpdateStats(ctx context.Context, req *proto.UpdateStatsRequest) (*proto.UpdateStatsResponse, error)
+}
+
+// statsReporter is a subcomponent of the agent that handles registering the stats callback on the
+// networkStatsSource (tailnet.Conn in prod), handling the callback, calling back to the
+// statsCollector (agent in prod) to collect additional stats, then sending the update to the
+// statsDest (agent API in prod).
+type statsReporter struct {
+	// Cond's lock (L) guards networkStats and unreported; it is broadcast when
+	// new stats arrive or when reportLoop's context is done.
+	*sync.Cond
+	// networkStats points at the most recent virtual stats handed to callback.
+	networkStats *map[netlogtype.Connection]netlogtype.Counts
+	// unreported is true when networkStats has not yet been picked up by reportLoop.
+	unreported bool
+	// lastInterval is the report interval most recently returned by the dest.
+	// It is only read and written by reportLoop and reportLocked.
+	lastInterval time.Duration
+
+	source    networkStatsSource
+	collector statsCollector
+	logger    slog.Logger
+}
+
+// newStatsReporter returns a statsReporter wired to the given source and
+// collector. Nothing is registered on the source until reportLoop is called.
+func newStatsReporter(logger slog.Logger, source networkStatsSource, collector statsCollector) *statsReporter {
+	return &statsReporter{
+		Cond:      sync.NewCond(&sync.Mutex{}),
+		logger:    logger,
+		source:    source,
+		collector: collector,
+	}
+}
+
+// callback is the function registered on the networkStatsSource. It stores the
+// virtual stats, replacing any earlier stats that have not been reported yet,
+// and wakes reportLoop.
+func (s *statsReporter) callback(_, _ time.Time, virtual, _ map[netlogtype.Connection]netlogtype.Counts) {
+	s.L.Lock()
+	defer s.L.Unlock()
+	s.logger.Debug(context.Background(), "got stats callback")
+	s.networkStats = &virtual
+	s.unreported = true
+	s.Broadcast()
+}
+
+// reportLoop programs the source (tailnet.Conn) to send it stats via the
+// callback, then reports them to the dest.
+//
+// It's intended to be called within the larger retry loop that establishes a
+// connection to the agent API, then passes that connection to goroutines like
+// this that use it. There is no retry and we fail on the first error since
+// this will be inside a larger retry loop.
+//
+// It returns nil if ctx is canceled while waiting for stats, and a wrapped
+// error if the initial update or any later report fails.
+func (s *statsReporter) reportLoop(ctx context.Context, dest statsDest) error {
+	// send an initial, blank report to get the interval
+	resp, err := dest.UpdateStats(ctx, &proto.UpdateStatsRequest{})
+	if err != nil {
+		return xerrors.Errorf("initial update: %w", err)
+	}
+	// use the nil-safe getter, same as reportLocked
+	s.lastInterval = resp.GetReportInterval().AsDuration()
+	s.source.SetConnStatsCallback(s.lastInterval, maxConns, s.callback)
+
+	// use a separate goroutine to monitor the context so that we notice immediately, rather than
+	// waiting for the next callback (which might never come if we are closing!)
+	ctxDone := false
+	// loopExited releases the monitor goroutine when we return on an error, so that it doesn't
+	// linger until ctx is eventually canceled.
+	loopExited := make(chan struct{})
+	defer close(loopExited)
+	go func() {
+		select {
+		case <-loopExited:
+			return
+		case <-ctx.Done():
+		}
+		s.L.Lock()
+		defer s.L.Unlock()
+		ctxDone = true
+		s.Broadcast()
+	}()
+	defer s.logger.Debug(ctx, "reportLoop exiting")
+
+	s.L.Lock()
+	defer s.L.Unlock()
+	for {
+		for !s.unreported && !ctxDone {
+			s.Wait()
+		}
+		if ctxDone {
+			return nil
+		}
+		networkStats := *s.networkStats
+		s.unreported = false
+		if err = s.reportLocked(ctx, dest, networkStats); err != nil {
+			return xerrors.Errorf("report stats: %w", err)
+		}
+	}
+}
+
+// reportLocked collects the full stats for networkStats and sends them to dest.
+// If dest responds with a different report interval, the callback is
+// re-registered on the source with the new interval.
+//
+// It must be called with s.L held. The lock is released for the duration of
+// the call and re-acquired before returning.
+func (s *statsReporter) reportLocked(
+	ctx context.Context, dest statsDest, networkStats map[netlogtype.Connection]netlogtype.Counts,
+) error {
+	// here we want to do our collecting/reporting while it is unlocked, but then relock
+	// when we return to reportLoop.
+	s.L.Unlock()
+	defer s.L.Lock()
+	stats := s.collector.Collect(ctx, networkStats)
+	resp, err := dest.UpdateStats(ctx, &proto.UpdateStatsRequest{Stats: stats})
+	if err != nil {
+		return err
+	}
+	interval := resp.GetReportInterval().AsDuration()
+	if interval != s.lastInterval {
+		s.logger.Info(ctx, "new stats report interval", slog.F("interval", interval))
+		s.lastInterval = interval
+		s.source.SetConnStatsCallback(s.lastInterval, maxConns, s.callback)
+	}
+	return nil
+}
diff --git a/agent/stats_internal_test.go b/agent/stats_internal_test.go
new file mode 100644
index 0000000000000..bfd6a3436d499
--- /dev/null
+++ b/agent/stats_internal_test.go
@@ -0,0 +1,224 @@
+package agent
+
+import (
+	"context"
+	"net/netip"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+	"google.golang.org/protobuf/types/known/durationpb"
+	"tailscale.com/types/ipproto"
+
+	"tailscale.com/types/netlogtype"
+
+	"cdr.dev/slog"
+	"cdr.dev/slog/sloggers/slogtest"
+	"github.com/coder/coder/v2/agent/proto"
+	"github.com/coder/coder/v2/testutil"
+)
+
+// TestStatsReporter drives reportLoop against fakes, checking the initial
+// interval request, that only the latest unreported network stats are kept,
+// that an interval change re-registers the callback, and that canceling the
+// context ends the loop without error.
+func TestStatsReporter(t *testing.T) {
+	t.Parallel()
+	ctx := testutil.Context(t, testutil.WaitShort)
+	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+	fSource := newFakeNetworkStatsSource(ctx, t)
+	fCollector := newFakeCollector(t)
+	fDest := newFakeStatsDest()
+	uut := newStatsReporter(logger, fSource, fCollector)
+
+	loopErr := make(chan error, 1)
+	loopCtx, loopCancel := context.WithCancel(ctx)
+	go func() {
+		err := uut.reportLoop(loopCtx, fDest)
+		loopErr <- err
+	}()
+
+	// initial request to get duration
+	req := testutil.RequireRecvCtx(ctx, t, fDest.reqs)
+	require.NotNil(t, req)
+	require.Nil(t, req.Stats)
+	interval := time.Second * 34
+	testutil.RequireSendCtx(ctx, t, fDest.resps, &proto.UpdateStatsResponse{ReportInterval: durationpb.New(interval)})
+
+	// call to source to set the callback and interval
+	gotInterval := testutil.RequireRecvCtx(ctx, t, fSource.period)
+	require.Equal(t, interval, gotInterval)
+
+	// callback returning netstats
+	netStats := map[netlogtype.Connection]netlogtype.Counts{
+		{
+			Proto: ipproto.TCP,
+			Src:   netip.MustParseAddrPort("192.168.1.33:4887"),
+			Dst:   netip.MustParseAddrPort("192.168.2.99:9999"),
+		}: {
+			TxPackets: 22,
+			TxBytes:   23,
+			RxPackets: 24,
+			RxBytes:   25,
+		},
+	}
+	fSource.callback(time.Now(), time.Now(), netStats, nil)
+
+	// collector called to complete the stats
+	gotNetStats := testutil.RequireRecvCtx(ctx, t, fCollector.calls)
+	require.Equal(t, netStats, gotNetStats)
+
+	// while we are collecting the stats, send in two new netStats to simulate
+	// what happens if we don't keep up. Only the latest should be kept.
+	netStats0 := map[netlogtype.Connection]netlogtype.Counts{
+		{
+			Proto: ipproto.TCP,
+			Src:   netip.MustParseAddrPort("192.168.1.33:4887"),
+			Dst:   netip.MustParseAddrPort("192.168.2.99:9999"),
+		}: {
+			TxPackets: 10,
+			TxBytes:   10,
+			RxPackets: 10,
+			RxBytes:   10,
+		},
+	}
+	fSource.callback(time.Now(), time.Now(), netStats0, nil)
+	netStats1 := map[netlogtype.Connection]netlogtype.Counts{
+		{
+			Proto: ipproto.TCP,
+			Src:   netip.MustParseAddrPort("192.168.1.33:4887"),
+			Dst:   netip.MustParseAddrPort("192.168.2.99:9999"),
+		}: {
+			TxPackets: 11,
+			TxBytes:   11,
+			RxPackets: 11,
+			RxBytes:   11,
+		},
+	}
+	fSource.callback(time.Now(), time.Now(), netStats1, nil)
+
+	// complete first collection
+	stats := &proto.Stats{SessionCountJetbrains: 55}
+	testutil.RequireSendCtx(ctx, t, fCollector.stats, stats)
+
+	// destination called to report the first stats
+	update := testutil.RequireRecvCtx(ctx, t, fDest.reqs)
+	require.NotNil(t, update)
+	require.Equal(t, stats, update.Stats)
+	testutil.RequireSendCtx(ctx, t, fDest.resps, &proto.UpdateStatsResponse{ReportInterval: durationpb.New(interval)})
+
+	// second update -- only netStats1 is reported
+	gotNetStats = testutil.RequireRecvCtx(ctx, t, fCollector.calls)
+	require.Equal(t, netStats1, gotNetStats)
+	stats = &proto.Stats{SessionCountJetbrains: 66}
+	testutil.RequireSendCtx(ctx, t, fCollector.stats, stats)
+	update = testutil.RequireRecvCtx(ctx, t, fDest.reqs)
+	require.NotNil(t, update)
+	require.Equal(t, stats, update.Stats)
+	interval2 := 27 * time.Second
+	testutil.RequireSendCtx(ctx, t, fDest.resps, &proto.UpdateStatsResponse{ReportInterval: durationpb.New(interval2)})
+
+	// set the new interval
+	gotInterval = testutil.RequireRecvCtx(ctx, t, fSource.period)
+	require.Equal(t, interval2, gotInterval)
+
+	loopCancel()
+	err := testutil.RequireRecvCtx(ctx, t, loopErr)
+	require.NoError(t, err)
+}
+
+// fakeNetworkStatsSource records the callback it is given and publishes each
+// requested period on the period channel.
+type fakeNetworkStatsSource struct {
+	sync.Mutex
+	ctx      context.Context
+	t        testing.TB
+	callback func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts)
+	period   chan time.Duration
+}
+
+// SetConnStatsCallback stores dump and blocks until the test receives
+// maxPeriod from f.period or f.ctx is done.
+func (f *fakeNetworkStatsSource) SetConnStatsCallback(maxPeriod time.Duration, _ int, dump func(start time.Time, end time.Time, virtual map[netlogtype.Connection]netlogtype.Counts, physical map[netlogtype.Connection]netlogtype.Counts)) {
+	f.Lock()
+	defer f.Unlock()
+	f.callback = dump
+	select {
+	case <-f.ctx.Done():
+		f.t.Error("timeout")
+	case f.period <- maxPeriod:
+		// OK
+	}
+}
+
+func newFakeNetworkStatsSource(ctx context.Context, t testing.TB) *fakeNetworkStatsSource {
+	f := &fakeNetworkStatsSource{
+		ctx:    ctx,
+		t:      t,
+		period: make(chan time.Duration),
+	}
+	return f
+}
+
+// fakeCollector hands each Collect call's network stats to the test via calls
+// and returns whatever the test sends on stats.
+type fakeCollector struct {
+	t     testing.TB
+	calls chan map[netlogtype.Connection]netlogtype.Counts
+	stats chan *proto.Stats
+}
+
+func (f *fakeCollector) Collect(ctx context.Context, networkStats map[netlogtype.Connection]netlogtype.Counts) *proto.Stats {
+	select {
+	case <-ctx.Done():
+		f.t.Error("timeout on collect")
+		return nil
+	case f.calls <- networkStats:
+		// ok
+	}
+	select {
+	case <-ctx.Done():
+		f.t.Error("timeout on collect")
+		return nil
+	case s := <-f.stats:
+		return s
+	}
+}
+
+func newFakeCollector(t testing.TB) *fakeCollector {
+	return &fakeCollector{
+		t:     t,
+		calls: make(chan map[netlogtype.Connection]netlogtype.Counts),
+		stats: make(chan *proto.Stats),
+	}
+}
+
+// fakeStatsDest hands each request to the test via reqs and returns the
+// response the test sends on resps.
+type fakeStatsDest struct {
+	reqs  chan *proto.UpdateStatsRequest
+	resps chan *proto.UpdateStatsResponse
+}
+
+func (f *fakeStatsDest) UpdateStats(ctx context.Context, req *proto.UpdateStatsRequest) (*proto.UpdateStatsResponse, error) {
+	select {
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	case f.reqs <- req:
+		// OK
+	}
+	select {
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	case resp := <-f.resps:
+		return resp, nil
+	}
+}
+
+func newFakeStatsDest() *fakeStatsDest {
+	return &fakeStatsDest{
+		reqs:  make(chan *proto.UpdateStatsRequest),
+		resps: make(chan *proto.UpdateStatsResponse),
+	}
+}