Skip to content

Commit 8c43c94

Browse files
committed
feat: add statsReporter for reporting stats on agent v2 API
1 parent 619bdd1 commit 8c43c94

File tree

2 files changed

+338
-0
lines changed

2 files changed

+338
-0
lines changed

agent/stats.go

+126
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
package agent
2+
3+
import (
4+
"context"
5+
"sync"
6+
"time"
7+
8+
"golang.org/x/xerrors"
9+
"tailscale.com/types/netlogtype"
10+
11+
"cdr.dev/slog"
12+
"github.com/coder/coder/v2/agent/proto"
13+
)
14+
15+
const maxConns = 2048
16+
17+
type networkStatsSource interface {
18+
SetConnStatsCallback(maxPeriod time.Duration, maxConns int, dump func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts))
19+
}
20+
21+
type statsCollector interface {
22+
Collect(ctx context.Context, networkStats map[netlogtype.Connection]netlogtype.Counts) *proto.Stats
23+
}
24+
25+
type statsDest interface {
26+
UpdateStats(ctx context.Context, req *proto.UpdateStatsRequest) (*proto.UpdateStatsResponse, error)
27+
}
28+
29+
// statsReporter is a subcomponent of the agent that handles registering the stats callback on the
30+
// networkStatsSource (tailnet.Conn in prod), handling the callback, calling back to the
31+
// statsCollector (agent in prod) to collect additional stats, then sending the update to the
32+
// statsDest (agent API in prod)
33+
type statsReporter struct {
34+
*sync.Cond
35+
networkStats *map[netlogtype.Connection]netlogtype.Counts
36+
unreported bool
37+
lastInterval time.Duration
38+
39+
source networkStatsSource
40+
collector statsCollector
41+
logger slog.Logger
42+
}
43+
44+
func newStatsReporter(logger slog.Logger, source networkStatsSource, collector statsCollector) *statsReporter {
45+
return &statsReporter{
46+
Cond: sync.NewCond(&sync.Mutex{}),
47+
logger: logger,
48+
source: source,
49+
collector: collector,
50+
}
51+
}
52+
53+
func (s *statsReporter) callback(_, _ time.Time, virtual, _ map[netlogtype.Connection]netlogtype.Counts) {
54+
s.L.Lock()
55+
defer s.L.Unlock()
56+
s.logger.Debug(context.Background(), "got stats callback")
57+
s.networkStats = &virtual
58+
s.unreported = true
59+
s.Broadcast()
60+
}
61+
62+
// reportLoop programs the source (tailnet.Conn) to send it stats via the
63+
// callback, then reports them to the dest.
64+
//
65+
// It's intended to be called within the larger retry loop that establishes a
66+
// connection to the agent API, then passes that connection to go routines like
67+
// this that use it. There is no retry and we fail on the first error since
68+
// this will be inside a larger retry loop.
69+
func (s *statsReporter) reportLoop(ctx context.Context, dest statsDest) error {
70+
// send an initial, blank report to get the interval
71+
resp, err := dest.UpdateStats(ctx, &proto.UpdateStatsRequest{})
72+
if err != nil {
73+
return xerrors.Errorf("initial update: %w", err)
74+
}
75+
s.lastInterval = resp.ReportInterval.AsDuration()
76+
s.source.SetConnStatsCallback(s.lastInterval, maxConns, s.callback)
77+
78+
// use a separate goroutine to monitor the context so that we notice immediately, rather than
79+
// waiting for the next callback (which might never come if we are closing!)
80+
ctxDone := false
81+
go func() {
82+
<-ctx.Done()
83+
s.L.Lock()
84+
defer s.L.Unlock()
85+
ctxDone = true
86+
s.Broadcast()
87+
}()
88+
defer s.logger.Debug(ctx, "reportLoop exiting")
89+
90+
s.L.Lock()
91+
defer s.L.Unlock()
92+
for {
93+
for !s.unreported && !ctxDone {
94+
s.Wait()
95+
}
96+
if ctxDone {
97+
return nil
98+
}
99+
networkStats := *s.networkStats
100+
s.unreported = false
101+
if err = s.reportLocked(ctx, dest, networkStats); err != nil {
102+
return xerrors.Errorf("report stats: %w", err)
103+
}
104+
}
105+
}
106+
107+
func (s *statsReporter) reportLocked(
108+
ctx context.Context, dest statsDest, networkStats map[netlogtype.Connection]netlogtype.Counts,
109+
) error {
110+
// here we want to do our collecting/reporting while it is unlocked, but then relock
111+
// when we return to reportLoop.
112+
s.L.Unlock()
113+
defer s.L.Lock()
114+
stats := s.collector.Collect(ctx, networkStats)
115+
resp, err := dest.UpdateStats(ctx, &proto.UpdateStatsRequest{Stats: stats})
116+
if err != nil {
117+
return err
118+
}
119+
interval := resp.GetReportInterval().AsDuration()
120+
if interval != s.lastInterval {
121+
s.logger.Info(ctx, "new stats report interval", slog.F("interval", interval))
122+
s.lastInterval = interval
123+
s.source.SetConnStatsCallback(s.lastInterval, maxConns, s.callback)
124+
}
125+
return nil
126+
}

agent/stats_internal_test.go

+212
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
package agent
2+
3+
import (
4+
"context"
5+
"net/netip"
6+
"sync"
7+
"testing"
8+
"time"
9+
10+
"github.com/stretchr/testify/require"
11+
"google.golang.org/protobuf/types/known/durationpb"
12+
"tailscale.com/types/ipproto"
13+
14+
"tailscale.com/types/netlogtype"
15+
16+
"cdr.dev/slog"
17+
"cdr.dev/slog/sloggers/slogtest"
18+
"github.com/coder/coder/v2/agent/proto"
19+
"github.com/coder/coder/v2/testutil"
20+
)
21+
22+
func TestStatsReporter(t *testing.T) {
23+
t.Parallel()
24+
ctx := testutil.Context(t, testutil.WaitShort)
25+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
26+
fSource := newFakeNetworkStatsSource(ctx, t)
27+
fCollector := newFakeCollector(t)
28+
fDest := newFakeStatsDest()
29+
uut := newStatsReporter(logger, fSource, fCollector)
30+
31+
loopErr := make(chan error, 1)
32+
loopCtx, loopCancel := context.WithCancel(ctx)
33+
go func() {
34+
err := uut.reportLoop(loopCtx, fDest)
35+
loopErr <- err
36+
}()
37+
38+
// initial request to get duration
39+
req := testutil.RequireRecvCtx(ctx, t, fDest.reqs)
40+
require.NotNil(t, req)
41+
require.Nil(t, req.Stats)
42+
interval := time.Second * 34
43+
testutil.RequireSendCtx(ctx, t, fDest.resps, &proto.UpdateStatsResponse{ReportInterval: durationpb.New(interval)})
44+
45+
// call to source to set the callback and interval
46+
gotInterval := testutil.RequireRecvCtx(ctx, t, fSource.period)
47+
require.Equal(t, interval, gotInterval)
48+
49+
// callback returning netstats
50+
netStats := map[netlogtype.Connection]netlogtype.Counts{
51+
{
52+
Proto: ipproto.TCP,
53+
Src: netip.MustParseAddrPort("192.168.1.33:4887"),
54+
Dst: netip.MustParseAddrPort("192.168.2.99:9999"),
55+
}: {
56+
TxPackets: 22,
57+
TxBytes: 23,
58+
RxPackets: 24,
59+
RxBytes: 25,
60+
},
61+
}
62+
fSource.callback(time.Now(), time.Now(), netStats, nil)
63+
64+
// collector called to complete the stats
65+
gotNetStats := testutil.RequireRecvCtx(ctx, t, fCollector.calls)
66+
require.Equal(t, netStats, gotNetStats)
67+
68+
// while we are collecting the stats, send in two new netStats to simulate
69+
// what happens if we don't keep up. Only the latest should be kept.
70+
netStats0 := map[netlogtype.Connection]netlogtype.Counts{
71+
{
72+
Proto: ipproto.TCP,
73+
Src: netip.MustParseAddrPort("192.168.1.33:4887"),
74+
Dst: netip.MustParseAddrPort("192.168.2.99:9999"),
75+
}: {
76+
TxPackets: 10,
77+
TxBytes: 10,
78+
RxPackets: 10,
79+
RxBytes: 10,
80+
},
81+
}
82+
fSource.callback(time.Now(), time.Now(), netStats0, nil)
83+
netStats1 := map[netlogtype.Connection]netlogtype.Counts{
84+
{
85+
Proto: ipproto.TCP,
86+
Src: netip.MustParseAddrPort("192.168.1.33:4887"),
87+
Dst: netip.MustParseAddrPort("192.168.2.99:9999"),
88+
}: {
89+
TxPackets: 11,
90+
TxBytes: 11,
91+
RxPackets: 11,
92+
RxBytes: 11,
93+
},
94+
}
95+
fSource.callback(time.Now(), time.Now(), netStats1, nil)
96+
97+
// complete first collection
98+
stats := &proto.Stats{SessionCountJetbrains: 55}
99+
testutil.RequireSendCtx(ctx, t, fCollector.stats, stats)
100+
101+
// destination called to report the first stats
102+
update := testutil.RequireRecvCtx(ctx, t, fDest.reqs)
103+
require.NotNil(t, update)
104+
require.Equal(t, stats, update.Stats)
105+
testutil.RequireSendCtx(ctx, t, fDest.resps, &proto.UpdateStatsResponse{ReportInterval: durationpb.New(interval)})
106+
107+
// second update -- only netStats1 is reported
108+
gotNetStats = testutil.RequireRecvCtx(ctx, t, fCollector.calls)
109+
require.Equal(t, netStats1, gotNetStats)
110+
stats = &proto.Stats{SessionCountJetbrains: 66}
111+
testutil.RequireSendCtx(ctx, t, fCollector.stats, stats)
112+
update = testutil.RequireRecvCtx(ctx, t, fDest.reqs)
113+
require.NotNil(t, update)
114+
require.Equal(t, stats, update.Stats)
115+
interval2 := 27 * time.Second
116+
testutil.RequireSendCtx(ctx, t, fDest.resps, &proto.UpdateStatsResponse{ReportInterval: durationpb.New(interval2)})
117+
118+
// set the new interval
119+
gotInterval = testutil.RequireRecvCtx(ctx, t, fSource.period)
120+
require.Equal(t, interval2, gotInterval)
121+
122+
loopCancel()
123+
err := testutil.RequireRecvCtx(ctx, t, loopErr)
124+
require.NoError(t, err)
125+
}
126+
127+
type fakeNetworkStatsSource struct {
128+
sync.Mutex
129+
ctx context.Context
130+
t testing.TB
131+
callback func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts)
132+
period chan time.Duration
133+
}
134+
135+
func (f *fakeNetworkStatsSource) SetConnStatsCallback(maxPeriod time.Duration, _ int, dump func(start time.Time, end time.Time, virtual map[netlogtype.Connection]netlogtype.Counts, physical map[netlogtype.Connection]netlogtype.Counts)) {
136+
f.Lock()
137+
defer f.Unlock()
138+
f.callback = dump
139+
select {
140+
case <-f.ctx.Done():
141+
f.t.Error("timeout")
142+
case f.period <- maxPeriod:
143+
// OK
144+
}
145+
}
146+
147+
func newFakeNetworkStatsSource(ctx context.Context, t testing.TB) *fakeNetworkStatsSource {
148+
f := &fakeNetworkStatsSource{
149+
ctx: ctx,
150+
t: t,
151+
period: make(chan time.Duration),
152+
}
153+
return f
154+
}
155+
156+
type fakeCollector struct {
157+
t testing.TB
158+
calls chan map[netlogtype.Connection]netlogtype.Counts
159+
stats chan *proto.Stats
160+
}
161+
162+
func (f *fakeCollector) Collect(ctx context.Context, networkStats map[netlogtype.Connection]netlogtype.Counts) *proto.Stats {
163+
select {
164+
case <-ctx.Done():
165+
f.t.Error("timeout on collect")
166+
return nil
167+
case f.calls <- networkStats:
168+
// ok
169+
}
170+
select {
171+
case <-ctx.Done():
172+
f.t.Error("timeout on collect")
173+
return nil
174+
case s := <-f.stats:
175+
return s
176+
}
177+
}
178+
179+
func newFakeCollector(t testing.TB) *fakeCollector {
180+
return &fakeCollector{
181+
t: t,
182+
calls: make(chan map[netlogtype.Connection]netlogtype.Counts),
183+
stats: make(chan *proto.Stats),
184+
}
185+
}
186+
187+
type fakeStatsDest struct {
188+
reqs chan *proto.UpdateStatsRequest
189+
resps chan *proto.UpdateStatsResponse
190+
}
191+
192+
func (f *fakeStatsDest) UpdateStats(ctx context.Context, req *proto.UpdateStatsRequest) (*proto.UpdateStatsResponse, error) {
193+
select {
194+
case <-ctx.Done():
195+
return nil, ctx.Err()
196+
case f.reqs <- req:
197+
// OK
198+
}
199+
select {
200+
case <-ctx.Done():
201+
return nil, ctx.Err()
202+
case resp := <-f.resps:
203+
return resp, nil
204+
}
205+
}
206+
207+
func newFakeStatsDest() *fakeStatsDest {
208+
return &fakeStatsDest{
209+
reqs: make(chan *proto.UpdateStatsRequest),
210+
resps: make(chan *proto.UpdateStatsResponse),
211+
}
212+
}

0 commit comments

Comments
 (0)