Skip to content

feat: add statsReporter for reporting stats on agent v2 API #11920

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions agent/stats.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package agent

import (
"context"
"sync"
"time"

"golang.org/x/xerrors"
"tailscale.com/types/netlogtype"

"cdr.dev/slog"
"github.com/coder/coder/v2/agent/proto"
)

// maxConns is the value passed as the maxConns argument every time the statsReporter
// registers its callback via networkStatsSource.SetConnStatsCallback.
const maxConns = 2048

// networkStatsSource is the component the statsReporter registers its stats callback on.
type networkStatsSource interface {
// SetConnStatsCallback registers dump as the stats callback together with the maxPeriod
// and maxConns settings. The statsReporter only consumes the "virtual" map handed to dump.
// NOTE(review): presumably maxPeriod bounds the time between dump calls — confirm against
// the implementation.
SetConnStatsCallback(maxPeriod time.Duration, maxConns int, dump func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts))
}

// statsCollector turns a snapshot of network connection counts into the proto.Stats
// message that gets reported. The statsReporter calls it once per report.
type statsCollector interface {
// Collect builds the stats to report from the given network stats.
Collect(ctx context.Context, networkStats map[netlogtype.Connection]netlogtype.Counts) *proto.Stats
}

// statsDest is where stats updates are sent. The statsReporter reads the report interval
// from each response it gets back.
type statsDest interface {
// UpdateStats sends one stats update and returns the response or an error.
UpdateStats(ctx context.Context, req *proto.UpdateStatsRequest) (*proto.UpdateStatsResponse, error)
}

// statsReporter is a subcomponent of the agent that handles registering the stats callback on the
// networkStatsSource (tailnet.Conn in prod), handling the callback, calling back to the
// statsCollector (agent in prod) to collect additional stats, then sending the update to the
// statsDest (agent API in prod)
//
// The embedded sync.Cond is used by callback and the context monitor in reportLoop to wake
// the reporting loop; its locker (L) is held whenever networkStats and unreported are
// read or written.
type statsReporter struct {
*sync.Cond
// networkStats points at the most recent "virtual" stats map delivered to callback.
networkStats *map[netlogtype.Connection]netlogtype.Counts
// unreported is set by callback and cleared by reportLoop when it picks up networkStats.
unreported bool
// lastInterval is the report interval from the most recent response. It is only touched
// by reportLoop and reportLocked, without holding the lock.
lastInterval time.Duration

source networkStatsSource
collector statsCollector
logger slog.Logger
}

// newStatsReporter builds a statsReporter wired to the given logger, stats source and
// collector. The embedded condition variable gets its own fresh mutex. Nothing is
// registered on the source here; that happens in reportLoop.
func newStatsReporter(logger slog.Logger, source networkStatsSource, collector statsCollector) *statsReporter {
	reporter := new(statsReporter)
	reporter.Cond = sync.NewCond(new(sync.Mutex))
	reporter.source = source
	reporter.collector = collector
	reporter.logger = logger
	return reporter
}

// callback is the function registered on the networkStatsSource. Under the lock it stores
// the latest virtual connection counts (the time range and the physical counts are
// ignored), marks them as unreported, and wakes every goroutine waiting on the condition
// variable. A snapshot that has not been picked up yet is overwritten.
func (s *statsReporter) callback(_, _ time.Time, virtual, _ map[netlogtype.Connection]netlogtype.Counts) {
	s.Cond.L.Lock()
	defer s.Cond.L.Unlock()

	s.logger.Debug(context.Background(), "got stats callback")
	s.unreported = true
	s.networkStats = &virtual
	s.Cond.Broadcast()
}

// reportLoop programs the source (tailnet.Conn) to send it stats via the
// callback, then reports them to the dest.
//
// It first sends an empty UpdateStatsRequest to learn the report interval and
// registers the callback on the source with that interval. It then blocks,
// reporting each new snapshot delivered by the callback. It returns nil once
// ctx is done, and a wrapped error if the initial update or any later report
// fails.
//
// It's intended to be called within the larger retry loop that establishes a
// connection to the agent API, then passes that connection to goroutines like
// this that use it. There is no retry and we fail on the first error since
// this will be inside a larger retry loop.
func (s *statsReporter) reportLoop(ctx context.Context, dest statsDest) error {
	// send an initial, blank report to get the interval
	resp, err := dest.UpdateStats(ctx, &proto.UpdateStatsRequest{})
	if err != nil {
		return xerrors.Errorf("initial update: %w", err)
	}
	// use the nil-safe getter, as reportLocked does
	s.lastInterval = resp.GetReportInterval().AsDuration()
	s.source.SetConnStatsCallback(s.lastInterval, maxConns, s.callback)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to also unset the callback on exit? Mostly a safety precaution in case callback ever has code that may block without a consumer.


	// use a separate goroutine to monitor the context so that we notice immediately, rather than
	// waiting for the next callback (which might never come if we are closing!)
	//
	// The monitor watches a child context that is canceled when reportLoop returns, so the
	// goroutine also exits when we bail out on a report error instead of lingering until
	// the caller's ctx is eventually canceled.
	monitorCtx, stopMonitor := context.WithCancel(ctx)
	defer stopMonitor()
	ctxDone := false
	go func() {
		<-monitorCtx.Done()
		s.L.Lock()
		defer s.L.Unlock()
		ctxDone = true
		s.Broadcast()
	}()
	defer s.logger.Debug(ctx, "reportLoop exiting")

	s.L.Lock()
	defer s.L.Unlock()
	for {
		// sleep until there is something new to report or we are told to stop
		for !s.unreported && !ctxDone {
			s.Wait()
		}
		if ctxDone {
			return nil
		}
		networkStats := *s.networkStats
		s.unreported = false
		// reportLocked drops the lock while it collects and sends, and re-takes it before
		// returning, so callbacks arriving in the meantime are not blocked.
		if err = s.reportLocked(ctx, dest, networkStats); err != nil {
			return xerrors.Errorf("report stats: %w", err)
		}
	}
}

// reportLocked performs one collect-and-send cycle for the given network stats. It must be
// entered with s.L held: the lock is released while collecting and sending, and taken
// again before control returns to reportLoop, on both the success and the error path.
//
// When the response carries a report interval different from the last one seen, the new
// interval is logged and remembered, and the callback is re-registered on the source with
// it. An error from dest.UpdateStats is returned unchanged.
func (s *statsReporter) reportLocked(
	ctx context.Context, dest statsDest, networkStats map[netlogtype.Connection]netlogtype.Counts,
) error {
	s.L.Unlock()
	defer s.L.Lock()

	req := &proto.UpdateStatsRequest{Stats: s.collector.Collect(ctx, networkStats)}
	resp, err := dest.UpdateStats(ctx, req)
	if err != nil {
		return err
	}

	newInterval := resp.GetReportInterval().AsDuration()
	if newInterval == s.lastInterval {
		// interval unchanged, nothing to reprogram
		return nil
	}
	s.logger.Info(ctx, "new stats report interval", slog.F("interval", newInterval))
	s.lastInterval = newInterval
	s.source.SetConnStatsCallback(newInterval, maxConns, s.callback)
	return nil
}
212 changes: 212 additions & 0 deletions agent/stats_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
package agent

import (
"context"
"net/netip"
"sync"
"testing"
"time"

"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/durationpb"
"tailscale.com/types/ipproto"

"tailscale.com/types/netlogtype"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/testutil"
)

// TestStatsReporter drives reportLoop against a fake source, collector and destination.
// It checks, in order: the initial blank request; that the interval from the response is
// passed to the source; that a callback triggers a collect and a report; that only the
// newest of two snapshots arriving during a collection is reported next; that a changed
// interval is re-registered on the source; and that canceling the loop context makes
// reportLoop return nil.
func TestStatsReporter(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
fSource := newFakeNetworkStatsSource(ctx, t)
fCollector := newFakeCollector(t)
fDest := newFakeStatsDest()
uut := newStatsReporter(logger, fSource, fCollector)

// run the loop in the background on its own cancelable context; its result is delivered
// on loopErr
loopErr := make(chan error, 1)
loopCtx, loopCancel := context.WithCancel(ctx)
go func() {
err := uut.reportLoop(loopCtx, fDest)
loopErr <- err
}()

// initial request to get duration
req := testutil.RequireRecvCtx(ctx, t, fDest.reqs)
require.NotNil(t, req)
require.Nil(t, req.Stats)
interval := time.Second * 34
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a special meaning to this magic value?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just wanted something that's not likely to be some internal constant and proves that we actually use the interval in the response.

testutil.RequireSendCtx(ctx, t, fDest.resps, &proto.UpdateStatsResponse{ReportInterval: durationpb.New(interval)})

// call to source to set the callback and interval
gotInterval := testutil.RequireRecvCtx(ctx, t, fSource.period)
require.Equal(t, interval, gotInterval)

// callback returning netstats
netStats := map[netlogtype.Connection]netlogtype.Counts{
{
Proto: ipproto.TCP,
Src: netip.MustParseAddrPort("192.168.1.33:4887"),
Dst: netip.MustParseAddrPort("192.168.2.99:9999"),
}: {
TxPackets: 22,
TxBytes: 23,
RxPackets: 24,
RxBytes: 25,
},
}
fSource.callback(time.Now(), time.Now(), netStats, nil)

// collector called to complete the stats
gotNetStats := testutil.RequireRecvCtx(ctx, t, fCollector.calls)
require.Equal(t, netStats, gotNetStats)

// while we are collecting the stats, send in two new netStats to simulate
// what happens if we don't keep up. Only the latest should be kept.
netStats0 := map[netlogtype.Connection]netlogtype.Counts{
{
Proto: ipproto.TCP,
Src: netip.MustParseAddrPort("192.168.1.33:4887"),
Dst: netip.MustParseAddrPort("192.168.2.99:9999"),
}: {
TxPackets: 10,
TxBytes: 10,
RxPackets: 10,
RxBytes: 10,
},
}
fSource.callback(time.Now(), time.Now(), netStats0, nil)
netStats1 := map[netlogtype.Connection]netlogtype.Counts{
{
Proto: ipproto.TCP,
Src: netip.MustParseAddrPort("192.168.1.33:4887"),
Dst: netip.MustParseAddrPort("192.168.2.99:9999"),
}: {
TxPackets: 11,
TxBytes: 11,
RxPackets: 11,
RxBytes: 11,
},
}
fSource.callback(time.Now(), time.Now(), netStats1, nil)

// complete first collection
stats := &proto.Stats{SessionCountJetbrains: 55}
testutil.RequireSendCtx(ctx, t, fCollector.stats, stats)

// destination called to report the first stats
update := testutil.RequireRecvCtx(ctx, t, fDest.reqs)
require.NotNil(t, update)
require.Equal(t, stats, update.Stats)
// same interval as before, so no new call to the source is expected here
testutil.RequireSendCtx(ctx, t, fDest.resps, &proto.UpdateStatsResponse{ReportInterval: durationpb.New(interval)})

// second update -- only netStats1 is reported
gotNetStats = testutil.RequireRecvCtx(ctx, t, fCollector.calls)
require.Equal(t, netStats1, gotNetStats)
stats = &proto.Stats{SessionCountJetbrains: 66}
testutil.RequireSendCtx(ctx, t, fCollector.stats, stats)
update = testutil.RequireRecvCtx(ctx, t, fDest.reqs)
require.NotNil(t, update)
require.Equal(t, stats, update.Stats)
// respond with a different interval this time
interval2 := 27 * time.Second
testutil.RequireSendCtx(ctx, t, fDest.resps, &proto.UpdateStatsResponse{ReportInterval: durationpb.New(interval2)})

// set the new interval
gotInterval = testutil.RequireRecvCtx(ctx, t, fSource.period)
require.Equal(t, interval2, gotInterval)

// canceling the loop context must end the loop without an error
loopCancel()
err := testutil.RequireRecvCtx(ctx, t, loopErr)
require.NoError(t, err)
}

// fakeNetworkStatsSource is a test double for networkStatsSource. It records the callback
// it was given and hands each requested period to the test over a channel.
type fakeNetworkStatsSource struct {
// the embedded mutex is held for the duration of SetConnStatsCallback
sync.Mutex
// ctx bounds how long SetConnStatsCallback waits for the test to read from period
ctx context.Context
t testing.TB
// callback is the most recently registered dump function; the test calls it directly
callback func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts)
// period receives the maxPeriod of every SetConnStatsCallback call
period chan time.Duration
}

// SetConnStatsCallback stores dump as the current callback and then hands maxPeriod to the
// test over the period channel, all while holding the fake's mutex. If the fake's context
// is done before the test receives the period, the test is marked failed with "timeout".
func (f *fakeNetworkStatsSource) SetConnStatsCallback(maxPeriod time.Duration, _ int, dump func(start time.Time, end time.Time, virtual map[netlogtype.Connection]netlogtype.Counts, physical map[netlogtype.Connection]netlogtype.Counts)) {
	f.Mutex.Lock()
	defer f.Mutex.Unlock()

	f.callback = dump
	select {
	case f.period <- maxPeriod:
		// the test picked up the period
	case <-f.ctx.Done():
		f.t.Error("timeout")
	}
}

// newFakeNetworkStatsSource returns a fake source bound to ctx and t, with an unbuffered
// period channel and no callback registered yet.
func newFakeNetworkStatsSource(ctx context.Context, t testing.TB) *fakeNetworkStatsSource {
	return &fakeNetworkStatsSource{
		period: make(chan time.Duration),
		t:      t,
		ctx:    ctx,
	}
}

// fakeCollector is a test double for statsCollector whose Collect is scripted by the test
// through two channels.
type fakeCollector struct {
t testing.TB
// calls receives the networkStats argument of each Collect call
calls chan map[netlogtype.Connection]netlogtype.Counts
// stats supplies the value each Collect call returns
stats chan *proto.Stats
}

// Collect hands networkStats to the test over the calls channel, then waits for the test
// to supply the result over the stats channel and returns it. If ctx is done during
// either step, the test is marked failed with "timeout on collect" and nil is returned.
func (f *fakeCollector) Collect(ctx context.Context, networkStats map[netlogtype.Connection]netlogtype.Counts) *proto.Stats {
	// step 1: publish what we were called with
	select {
	case f.calls <- networkStats:
	case <-ctx.Done():
		f.t.Error("timeout on collect")
		return nil
	}

	// step 2: wait for the scripted result
	var result *proto.Stats
	select {
	case result = <-f.stats:
	case <-ctx.Done():
		f.t.Error("timeout on collect")
		return nil
	}
	return result
}

// newFakeCollector returns a fake collector bound to t, with unbuffered calls and stats
// channels.
func newFakeCollector(t testing.TB) *fakeCollector {
	fc := &fakeCollector{t: t}
	fc.calls = make(chan map[netlogtype.Connection]netlogtype.Counts)
	fc.stats = make(chan *proto.Stats)
	return fc
}

// fakeStatsDest is a test double for statsDest whose UpdateStats is scripted by the test
// through two channels.
type fakeStatsDest struct {
// reqs receives the request of each UpdateStats call
reqs chan *proto.UpdateStatsRequest
// resps supplies the response each UpdateStats call returns
resps chan *proto.UpdateStatsResponse
}

// UpdateStats hands req to the test over the reqs channel, then waits for the test to
// supply a response over the resps channel and returns it with a nil error. If ctx is
// done during either step, it returns a nil response and ctx.Err().
func (f *fakeStatsDest) UpdateStats(ctx context.Context, req *proto.UpdateStatsRequest) (*proto.UpdateStatsResponse, error) {
	// step 1: publish the request
	select {
	case f.reqs <- req:
	case <-ctx.Done():
		return nil, ctx.Err()
	}

	// step 2: wait for the scripted response
	select {
	case scripted := <-f.resps:
		return scripted, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}

// newFakeStatsDest returns a fake destination with unbuffered reqs and resps channels.
func newFakeStatsDest() *fakeStatsDest {
	fd := new(fakeStatsDest)
	fd.reqs = make(chan *proto.UpdateStatsRequest)
	fd.resps = make(chan *proto.UpdateStatsResponse)
	return fd
}