Skip to content

Commit 11f3d11

Browse files
committed
feat: add statsReporter for reporting stats on agent v2 API
1 parent 60653bb commit 11f3d11

File tree

2 files changed

+322
-0
lines changed

2 files changed

+322
-0
lines changed

agent/stats.go

+115
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
package agent
2+
3+
import (
4+
"context"
5+
"sync"
6+
"time"
7+
8+
"golang.org/x/xerrors"
9+
"tailscale.com/types/netlogtype"
10+
11+
"cdr.dev/slog"
12+
"github.com/coder/coder/v2/agent/proto"
13+
)
14+
15+
// maxConns is the value the stats reporter passes as the maxConns argument every time it
// registers its callback via SetConnStatsCallback.
const maxConns = 2048
16+
17+
type networkStatsSource interface {
18+
SetConnStatsCallback(maxPeriod time.Duration, maxConns int, dump func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts))
19+
}
20+
21+
type statsCollector interface {
22+
Collect(ctx context.Context, networkStats map[netlogtype.Connection]netlogtype.Counts) *proto.Stats
23+
}
24+
25+
type statsDest interface {
26+
UpdateStats(ctx context.Context, req *proto.UpdateStatsRequest) (*proto.UpdateStatsResponse, error)
27+
}
28+
29+
type statsReporter struct {
30+
sync.Cond
31+
networkStats *map[netlogtype.Connection]netlogtype.Counts
32+
unreported bool
33+
lastInterval time.Duration
34+
35+
source networkStatsSource
36+
collector statsCollector
37+
logger slog.Logger
38+
}
39+
40+
func newStatsReporter(logger slog.Logger, source networkStatsSource, collector statsCollector) *statsReporter {
41+
return &statsReporter{
42+
Cond: *(sync.NewCond(&sync.Mutex{})),
43+
logger: logger,
44+
source: source,
45+
collector: collector,
46+
}
47+
}
48+
49+
func (s *statsReporter) callback(_, _ time.Time, virtual, _ map[netlogtype.Connection]netlogtype.Counts) {
50+
s.L.Lock()
51+
defer s.L.Unlock()
52+
s.logger.Debug(context.Background(), "got stats callback")
53+
s.networkStats = &virtual
54+
s.unreported = true
55+
s.Broadcast()
56+
}
57+
58+
func (s *statsReporter) reportLoop(ctx context.Context, dest statsDest) error {
59+
// send an initial, blank report to get the interval
60+
resp, err := dest.UpdateStats(ctx, &proto.UpdateStatsRequest{})
61+
if err != nil {
62+
return xerrors.Errorf("initial update: %w", err)
63+
}
64+
s.lastInterval = resp.ReportInterval.AsDuration()
65+
s.source.SetConnStatsCallback(s.lastInterval, maxConns, s.callback)
66+
67+
// use a separate goroutine to monitor the context so that we notice immediately, rather than
68+
// waiting for the next callback (which might never come if we are closing!)
69+
ctxDone := false
70+
go func() {
71+
<-ctx.Done()
72+
s.L.Lock()
73+
defer s.L.Unlock()
74+
ctxDone = true
75+
s.Broadcast()
76+
}()
77+
defer s.logger.Debug(ctx, "reportLoop exiting")
78+
79+
s.L.Lock()
80+
defer s.L.Unlock()
81+
for {
82+
for !s.unreported && !ctxDone {
83+
s.Wait()
84+
}
85+
if ctxDone {
86+
return nil
87+
}
88+
networkStats := *s.networkStats
89+
s.unreported = false
90+
if err = s.reportLocked(ctx, dest, networkStats); err != nil {
91+
return xerrors.Errorf("report stats:%w", err)
92+
}
93+
}
94+
}
95+
96+
func (s *statsReporter) reportLocked(
97+
ctx context.Context, dest statsDest, networkStats map[netlogtype.Connection]netlogtype.Counts,
98+
) error {
99+
// here we want to do our collecting/reporting while it is unlocked, but then relock
100+
// when we return to reportLoop.
101+
s.L.Unlock()
102+
defer s.L.Lock()
103+
stats := s.collector.Collect(ctx, networkStats)
104+
resp, err := dest.UpdateStats(ctx, &proto.UpdateStatsRequest{Stats: stats})
105+
if err != nil {
106+
return err
107+
}
108+
interval := resp.GetReportInterval().AsDuration()
109+
if interval != s.lastInterval {
110+
s.logger.Info(ctx, "new stats report interval", slog.F("interval", interval))
111+
s.lastInterval = interval
112+
s.source.SetConnStatsCallback(s.lastInterval, maxConns, s.callback)
113+
}
114+
return nil
115+
}

agent/stats_internal_test.go

+207
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
package agent
2+
3+
import (
4+
"context"
5+
"net/netip"
6+
"sync"
7+
"testing"
8+
"time"
9+
10+
"github.com/stretchr/testify/require"
11+
"google.golang.org/protobuf/types/known/durationpb"
12+
"tailscale.com/types/ipproto"
13+
14+
"tailscale.com/types/netlogtype"
15+
16+
"cdr.dev/slog"
17+
"cdr.dev/slog/sloggers/slogtest"
18+
"github.com/coder/coder/v2/agent/proto"
19+
"github.com/coder/coder/v2/testutil"
20+
)
21+
22+
func TestStatsReporter(t *testing.T) {
23+
t.Parallel()
24+
ctx := testutil.Context(t, testutil.WaitShort)
25+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
26+
fSource := newFakeNetworkStatsSource(ctx, t)
27+
fCollector := newFakeCollector(t)
28+
fDest := newFakeStatsDest()
29+
uut := newStatsReporter(logger, fSource, fCollector)
30+
31+
loopErr := make(chan error, 1)
32+
loopCtx, loopCancel := context.WithCancel(ctx)
33+
go func() {
34+
err := uut.reportLoop(loopCtx, fDest)
35+
loopErr <- err
36+
}()
37+
38+
// initial request to get duration
39+
req := testutil.RequireRecvCtx(ctx, t, fDest.reqs)
40+
require.NotNil(t, req)
41+
require.Nil(t, req.Stats)
42+
interval := time.Second * 34
43+
testutil.RequireSendCtx(ctx, t, fDest.resps, &proto.UpdateStatsResponse{ReportInterval: durationpb.New(interval)})
44+
45+
// call to source to set the callback and interval
46+
gotInterval := testutil.RequireRecvCtx(ctx, t, fSource.period)
47+
require.Equal(t, interval, gotInterval)
48+
49+
// callback returning netstats
50+
netStats := map[netlogtype.Connection]netlogtype.Counts{
51+
{
52+
Proto: ipproto.TCP,
53+
Src: netip.MustParseAddrPort("192.168.1.33:4887"),
54+
Dst: netip.MustParseAddrPort("192.168.2.99:9999"),
55+
}: {
56+
TxPackets: 22,
57+
TxBytes: 23,
58+
RxPackets: 24,
59+
RxBytes: 25,
60+
},
61+
}
62+
fSource.callback(time.Now(), time.Now(), netStats, nil)
63+
64+
// collector called to complete the stats
65+
gotNetStats := testutil.RequireRecvCtx(ctx, t, fCollector.calls)
66+
require.Equal(t, netStats, gotNetStats)
67+
68+
// while we are collecting the stats, send in two new netStats to simulate
69+
// what happens if we don't keep up. Only the latest should be kept.
70+
netStats0 := map[netlogtype.Connection]netlogtype.Counts{
71+
{
72+
Proto: ipproto.TCP,
73+
Src: netip.MustParseAddrPort("192.168.1.33:4887"),
74+
Dst: netip.MustParseAddrPort("192.168.2.99:9999"),
75+
}: {
76+
TxPackets: 10,
77+
TxBytes: 10,
78+
RxPackets: 10,
79+
RxBytes: 10,
80+
},
81+
}
82+
fSource.callback(time.Now(), time.Now(), netStats0, nil)
83+
netStats1 := map[netlogtype.Connection]netlogtype.Counts{
84+
{
85+
Proto: ipproto.TCP,
86+
Src: netip.MustParseAddrPort("192.168.1.33:4887"),
87+
Dst: netip.MustParseAddrPort("192.168.2.99:9999"),
88+
}: {
89+
TxPackets: 11,
90+
TxBytes: 11,
91+
RxPackets: 11,
92+
RxBytes: 11,
93+
},
94+
}
95+
fSource.callback(time.Now(), time.Now(), netStats1, nil)
96+
97+
// complete first collection
98+
stats := &proto.Stats{SessionCountJetbrains: 55}
99+
testutil.RequireSendCtx(ctx, t, fCollector.stats, stats)
100+
101+
// destination called to report the first stats
102+
update := testutil.RequireRecvCtx(ctx, t, fDest.reqs)
103+
require.NotNil(t, update)
104+
require.Equal(t, stats, update.Stats)
105+
testutil.RequireSendCtx(ctx, t, fDest.resps, &proto.UpdateStatsResponse{ReportInterval: durationpb.New(interval)})
106+
107+
// second update -- only netStats1 is reported
108+
gotNetStats = testutil.RequireRecvCtx(ctx, t, fCollector.calls)
109+
require.Equal(t, netStats1, gotNetStats)
110+
stats = &proto.Stats{SessionCountJetbrains: 66}
111+
testutil.RequireSendCtx(ctx, t, fCollector.stats, stats)
112+
update = testutil.RequireRecvCtx(ctx, t, fDest.reqs)
113+
require.NotNil(t, update)
114+
require.Equal(t, stats, update.Stats)
115+
testutil.RequireSendCtx(ctx, t, fDest.resps, &proto.UpdateStatsResponse{ReportInterval: durationpb.New(interval)})
116+
117+
loopCancel()
118+
err := testutil.RequireRecvCtx(ctx, t, loopErr)
119+
require.NoError(t, err)
120+
}
121+
122+
type fakeNetworkStatsSource struct {
123+
sync.Mutex
124+
ctx context.Context
125+
t testing.TB
126+
callback func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts)
127+
period chan time.Duration
128+
}
129+
130+
func (f *fakeNetworkStatsSource) SetConnStatsCallback(maxPeriod time.Duration, _ int, dump func(start time.Time, end time.Time, virtual map[netlogtype.Connection]netlogtype.Counts, physical map[netlogtype.Connection]netlogtype.Counts)) {
131+
f.Lock()
132+
defer f.Unlock()
133+
f.callback = dump
134+
select {
135+
case <-f.ctx.Done():
136+
f.t.Error("timeout")
137+
case f.period <- maxPeriod:
138+
// OK
139+
}
140+
}
141+
142+
func newFakeNetworkStatsSource(ctx context.Context, t testing.TB) *fakeNetworkStatsSource {
143+
f := &fakeNetworkStatsSource{
144+
ctx: ctx,
145+
t: t,
146+
period: make(chan time.Duration),
147+
}
148+
return f
149+
}
150+
151+
type fakeCollector struct {
152+
t testing.TB
153+
calls chan map[netlogtype.Connection]netlogtype.Counts
154+
stats chan *proto.Stats
155+
}
156+
157+
func (f *fakeCollector) Collect(ctx context.Context, networkStats map[netlogtype.Connection]netlogtype.Counts) *proto.Stats {
158+
select {
159+
case <-ctx.Done():
160+
f.t.Error("timeout on collect")
161+
return nil
162+
case f.calls <- networkStats:
163+
// ok
164+
}
165+
select {
166+
case <-ctx.Done():
167+
f.t.Error("timeout on collect")
168+
return nil
169+
case s := <-f.stats:
170+
return s
171+
}
172+
}
173+
174+
func newFakeCollector(t testing.TB) *fakeCollector {
175+
return &fakeCollector{
176+
t: t,
177+
calls: make(chan map[netlogtype.Connection]netlogtype.Counts),
178+
stats: make(chan *proto.Stats),
179+
}
180+
}
181+
182+
type fakeStatsDest struct {
183+
reqs chan *proto.UpdateStatsRequest
184+
resps chan *proto.UpdateStatsResponse
185+
}
186+
187+
func (f *fakeStatsDest) UpdateStats(ctx context.Context, req *proto.UpdateStatsRequest) (*proto.UpdateStatsResponse, error) {
188+
select {
189+
case <-ctx.Done():
190+
return nil, ctx.Err()
191+
case f.reqs <- req:
192+
// OK
193+
}
194+
select {
195+
case <-ctx.Done():
196+
return nil, ctx.Err()
197+
case resp := <-f.resps:
198+
return resp, nil
199+
}
200+
}
201+
202+
func newFakeStatsDest() *fakeStatsDest {
203+
return &fakeStatsDest{
204+
reqs: make(chan *proto.UpdateStatsRequest),
205+
resps: make(chan *proto.UpdateStatsResponse),
206+
}
207+
}

0 commit comments

Comments
 (0)