Skip to content

Commit 939a07d

Browse files
committed
Use more scalable connection tracking
1 parent 58e55a2 commit 939a07d

File tree

6 files changed

+143
-52
lines changed

6 files changed

+143
-52
lines changed

agent/agent.go

+8-2
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ func New(dialer Dialer, options *Options) io.Closer {
9595
postKeys: options.UploadWireguardKeys,
9696
listenWireguardPeers: options.ListenWireguardPeers,
9797
stats: &Stats{
98-
ActiveConns: make(map[int64]*ConnStats),
98+
ProtocolStats: make(map[string]*ProtocolStats),
9999
},
100100
statsReporter: options.StatsReporter,
101101
}
@@ -350,13 +350,19 @@ func (a *agent) init(ctx context.Context) {
350350

351351
go a.run(ctx)
352352
if a.statsReporter != nil {
353-
err := a.statsReporter(ctx, a.logger, func() *Stats {
353+
cl, err := a.statsReporter(ctx, a.logger, func() *Stats {
354354
return a.stats.Copy()
355355
})
356356
if err != nil {
357357
a.logger.Error(ctx, "report stats", slog.Error(err))
358358
return
359359
}
360+
a.connCloseWait.Add(1)
361+
go func() {
362+
defer a.connCloseWait.Done()
363+
<-a.closed
364+
cl.Close()
365+
}()
360366
}
361367
}
362368

agent/stats.go

+42-33
Original file line numberDiff line numberDiff line change
@@ -2,87 +2,96 @@ package agent
22

33
import (
44
"context"
5+
"io"
56
"net"
67
"sync"
78
"sync/atomic"
8-
"time"
9-
10-
"golang.org/x/exp/maps"
119

1210
"cdr.dev/slog"
1311
)
1412

1513
// ConnStats wraps a net.Conn and accounts the bytes read and written
// through it to the embedded ProtocolStats.
1614
type ConnStats struct {
17-
CreatedAt time.Time `json:"created_at,omitempty"`
18-
Protocol string `json:"protocol,omitempty"`
19-
20-
// RxBytes must be read with atomic.
21-
RxBytes uint64 `json:"rx_bytes,omitempty"`
22-
23-
// TxBytes must be read with atomic.
24-
TxBytes uint64 `json:"tx_bytes,omitempty"`
15+
// ProtocolStats holds the RxBytes/TxBytes counters that Read and Write
// add to with sync/atomic. goConn hands the same pointer to every
// connection of a protocol, so the counters are shared between them.
*ProtocolStats
2516
net.Conn `json:"-"`
2617
}
2718

2819
var _ net.Conn = new(ConnStats)
2920

3021
// Read reads from the wrapped connection and atomically adds the number of
// bytes read to RxBytes. The count is added even when the read also returns
// an error.
func (c *ConnStats) Read(b []byte) (n int, err error) {
3122
n, err = c.Conn.Read(b)
32-
atomic.AddUint64(&c.RxBytes, uint64(n))
23+
atomic.AddInt64(&c.RxBytes, int64(n))
3324
return n, err
3425
}
3526

3627
// Write writes to the wrapped connection and atomically adds the number of
// bytes written to TxBytes. The count is added even when the write also
// returns an error.
func (c *ConnStats) Write(b []byte) (n int, err error) {
3728
n, err = c.Conn.Write(b)
38-
atomic.AddUint64(&c.TxBytes, uint64(n))
29+
atomic.AddInt64(&c.TxBytes, int64(n))
3930
return n, err
4031
}
4132

33+
// ProtocolStats aggregates connection statistics for a single protocol.
//
// goConn shares one *ProtocolStats between every ConnStats of the same
// protocol, so all fields are updated concurrently and must be accessed
// with sync/atomic.
type ProtocolStats struct {
	// NumConns is the number of connections currently being processed. It
	// is incremented when a connection's goroutine starts and decremented
	// when it returns. It must be read with atomic.
	//
	// NOTE(review): the JSON key is spelled "num_comms", which looks like a
	// typo for "num_conns". It is left as-is because the key is part of the
	// reported/stored payload; confirm with consumers before renaming.
	NumConns int64 `json:"num_comms,omitempty"`

	// RxBytes is the total number of bytes read. It must be read with atomic.
	RxBytes int64 `json:"rx_bytes,omitempty"`

	// TxBytes is the total number of bytes written. It must be read with atomic.
	TxBytes int64 `json:"tx_bytes,omitempty"`
}
42+
4243
var _ net.Conn = new(ConnStats)
4344

4445
// Stats records the Agent's network connection statistics for use in
4546
// user-facing metrics and debugging.
4647
type Stats struct {
47-
sync.RWMutex `json:"-"`
48-
// ActiveConns are identified by their start time in nanoseconds.
49-
ActiveConns map[int64]*ConnStats `json:"active_conns,omitempty"`
48+
// The embedded mutex guards the ProtocolStats map itself (lookups and
// inserts); the counters inside each entry are accessed with sync/atomic.
sync.RWMutex `json:"-"`
49+
// ProtocolStats maps the protocol name passed to goConn to the
// aggregated counters for that protocol.
ProtocolStats map[string]*ProtocolStats `json:"conn_stats,omitempty"`
5050
}
5151

5252
// Copy returns a snapshot of s that shares no *ProtocolStats with the
// original, so the result can be read without holding the lock or using
// atomics.
func (s *Stats) Copy() *Stats {
5353
// The read lock only protects iteration over the map; the counters are
// still being updated concurrently by live connections.
s.RLock()
54-
ss := &Stats{
55-
ActiveConns: maps.Clone(s.ActiveConns),
54+
ss := Stats{ProtocolStats: make(map[string]*ProtocolStats, len(s.ProtocolStats))}
55+
for k, cs := range s.ProtocolStats {
56+
// Each counter is loaded atomically on its own, so the three values of
// one entry are not guaranteed to come from the same instant.
ss.ProtocolStats[k] = &ProtocolStats{
57+
NumConns: atomic.LoadInt64(&cs.NumConns),
58+
RxBytes: atomic.LoadInt64(&cs.RxBytes),
59+
TxBytes: atomic.LoadInt64(&cs.TxBytes),
60+
}
5661
}
5762
s.RUnlock()
58-
return ss
63+
return &ss
5964
}
6065

6166
// goConn launches a new connection-processing goroutine, accounting for
6267
// the connection in s.ProtocolStats in a thread-safe manner.
//
// fn is called in the new goroutine with conn wrapped in a ConnStats, so
// bytes read and written through it are added to the protocol's counters.
6368
func (s *Stats) goConn(conn net.Conn, protocol string, fn func(conn net.Conn)) {
64-
sc := &ConnStats{
65-
CreatedAt: time.Now(),
66-
Protocol: protocol,
67-
Conn: conn,
68-
}
69-
70-
key := sc.CreatedAt.UnixNano()
71-
7269
s.Lock()
73-
s.ActiveConns[key] = sc
70+
// Look up the shared counters for this protocol, creating the entry on
// first use. Only the map access needs the lock.
ps, ok := s.ProtocolStats[protocol]
71+
if !ok {
72+
ps = &ProtocolStats{}
73+
s.ProtocolStats[protocol] = ps
74+
}
7475
s.Unlock()
7576

77+
cs := &ConnStats{
78+
ProtocolStats: ps,
79+
Conn: conn,
80+
}
81+
7682
go func() {
83+
// NumConns counts the connection only while fn is running: it is
// incremented here, inside the goroutine, and decremented when fn returns.
atomic.AddInt64(&ps.NumConns, 1)
7784
defer func() {
78-
s.Lock()
79-
delete(s.ActiveConns, key)
80-
s.Unlock()
85+
atomic.AddInt64(&ps.NumConns, -1)
8186
}()
8287

83-
fn(sc)
88+
fn(cs)
8489
}()
8590
}
8691

8792
// StatsReporter periodically accepts and records agent stats.
//
// The stats argument returns the Stats to report; the agent passes a
// function that returns a copy of its current stats. The agent calls Close
// on the returned io.Closer once the agent itself has been closed.
88-
type StatsReporter func(ctx context.Context, log slog.Logger, stats func() *Stats) error
93+
type StatsReporter func(
94+
ctx context.Context,
95+
log slog.Logger,
96+
stats func() *Stats,
97+
) (io.Closer, error)

coderd/database/databasefake/databasefake.go

+18
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ func New() database.Store {
2323
mutex: &sync.RWMutex{},
2424
data: &data{
2525
apiKeys: make([]database.APIKey, 0),
26+
agentStats: make([]database.AgentStat, 0),
2627
organizationMembers: make([]database.OrganizationMember, 0),
2728
organizations: make([]database.Organization, 0),
2829
users: make([]database.User, 0),
@@ -78,6 +79,7 @@ type data struct {
7879
userLinks []database.UserLink
7980

8081
// New tables
82+
agentStats []database.AgentStat
8183
auditLogs []database.AuditLog
8284
files []database.File
8385
gitSSHKey []database.GitSSHKey
@@ -135,6 +137,22 @@ func (q *fakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.Acqu
135137
return database.ProvisionerJob{}, sql.ErrNoRows
136138
}
137139

140+
func (q *fakeQuerier) InsertAgentStat(_ context.Context, p database.InsertAgentStatParams) (database.AgentStat, error) {
141+
q.mutex.Lock()
142+
defer q.mutex.Unlock()
143+
144+
stat := database.AgentStat{
145+
ID: p.ID,
146+
CreatedAt: p.CreatedAt,
147+
WorkspaceID: p.WorkspaceID,
148+
AgentID: p.AgentID,
149+
UserID: p.UserID,
150+
Payload: p.Payload,
151+
}
152+
q.agentStats = append(q.agentStats, stat)
153+
return stat, nil
154+
}
155+
138156
func (q *fakeQuerier) ParameterValue(_ context.Context, id uuid.UUID) (database.ParameterValue, error) {
139157
q.mutex.Lock()
140158
defer q.mutex.Unlock()

coderd/workspaceagents.go

+47-4
Original file line numberDiff line numberDiff line change
@@ -164,10 +164,28 @@ func (api *API) workspaceAgentReportStats(rw http.ResponseWriter, r *http.Reques
164164
defer api.websocketWaitGroup.Done()
165165

166166
workspaceAgent := httpmw.WorkspaceAgent(r)
167-
workspace, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID)
167+
resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID)
168168
if err != nil {
169169
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
170-
Message: "Failed to accept websocket.",
170+
Message: "Failed to get workspace resource.",
171+
Detail: err.Error(),
172+
})
173+
return
174+
}
175+
176+
build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID)
177+
if err != nil {
178+
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
179+
Message: "Failed to get build.",
180+
Detail: err.Error(),
181+
})
182+
return
183+
}
184+
185+
workspace, err := api.Database.GetWorkspaceByID(r.Context(), build.WorkspaceID)
186+
if err != nil {
187+
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
188+
Message: "Failed to get workspace.",
171189
Detail: err.Error(),
172190
})
173191
return
@@ -191,7 +209,7 @@ func (api *API) workspaceAgentReportStats(rw http.ResponseWriter, r *http.Reques
191209
}
192210

193211
ctx := r.Context()
194-
timer := time.NewTimer(interval)
212+
timer := time.NewTicker(interval)
195213
for {
196214
err := wsjson.Write(ctx, conn, codersdk.AgentStatsReportRequest{})
197215
if err != nil {
@@ -212,11 +230,36 @@ func (api *API) workspaceAgentReportStats(rw http.ResponseWriter, r *http.Reques
212230
return
213231
}
214232

233+
repJSON, err := json.Marshal(rep)
234+
if err != nil {
235+
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
236+
Message: "Failed to marshal stat json.",
237+
Detail: err.Error(),
238+
})
239+
return
240+
}
241+
215242
api.Logger.Debug(ctx, "read stats report",
216243
slog.F("agent", workspaceAgent.ID),
244+
slog.F("resource", resource.ID),
217245
slog.F("workspace", workspace.ID),
218-
slog.F("report", rep),
246+
slog.F("conns", rep.ProtocolStats),
219247
)
248+
_, err = api.Database.InsertAgentStat(ctx, database.InsertAgentStatParams{
249+
ID: uuid.NewString(),
250+
CreatedAt: time.Now(),
251+
AgentID: workspaceAgent.ID,
252+
WorkspaceID: build.WorkspaceID,
253+
UserID: workspace.OwnerID,
254+
Payload: json.RawMessage(repJSON),
255+
})
256+
if err != nil {
257+
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
258+
Message: "Failed to insert agent stat.",
259+
Detail: err.Error(),
260+
})
261+
return
262+
}
220263

221264
select {
222265
case <-timer.C:

coderd/workspaceagents_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ func TestWorkspaceReportStats(t *testing.T) {
137137
_, err = session.Output("echo hello")
138138
require.NoError(t, err)
139139

140-
time.Sleep(time.Second * 10)
140+
time.Sleep(time.Second * 1)
141141
require.NoError(t, err)
142142
}
143143

codersdk/workspaceagents.go

+27-12
Original file line numberDiff line numberDiff line change
@@ -484,28 +484,44 @@ func (c *Client) turnProxyDialer(ctx context.Context, httpClient *http.Client, p
484484
})
485485
}
486486

487+
// CloseFunc adapts an ordinary function to the io.Closer interface.
type CloseFunc func() error

// Close calls c and returns its result. A nil CloseFunc is treated as a
// no-op and returns nil instead of panicking.
func (c CloseFunc) Close() error {
	if c == nil {
		return nil
	}
	return c()
}
492+
487493
// AgentReportStats begins a stat streaming connection with the Coder server.
488494
// It is resilient to network failures and intermittent coderd issues.
489-
func (c *Client) AgentReportStats(ctx context.Context, log slog.Logger, stats func() *agent.Stats) error {
495+
func (c *Client) AgentReportStats(
496+
ctx context.Context,
497+
log slog.Logger,
498+
stats func() *agent.Stats,
499+
) (io.Closer, error) {
490500
serverURL, err := c.URL.Parse("/api/v2/workspaceagents/me/report-stats")
491501
if err != nil {
492-
return xerrors.Errorf("parse url: %w", err)
502+
return nil, xerrors.Errorf("parse url: %w", err)
493503
}
494504

495505
jar, err := cookiejar.New(nil)
496506
if err != nil {
497-
return xerrors.Errorf("create cookie jar: %w", err)
507+
return nil, xerrors.Errorf("create cookie jar: %w", err)
498508
}
499509

500510
jar.SetCookies(serverURL, []*http.Cookie{{
501511
Name: SessionTokenKey,
502512
Value: c.SessionToken,
503513
}})
514+
504515
httpClient := &http.Client{
505516
Jar: jar,
506517
}
507518

519+
doneCh := make(chan struct{})
520+
ctx, cancel := context.WithCancel(ctx)
521+
508522
go func() {
523+
defer close(doneCh)
524+
509525
for r := retry.New(time.Second, time.Hour); r.Wait(ctx); {
510526
err = func() error {
511527
conn, res, err := websocket.Dial(ctx, serverURL.String(), &websocket.DialOptions{
@@ -527,13 +543,8 @@ func (c *Client) AgentReportStats(ctx context.Context, log slog.Logger, stats fu
527543
return err
528544
}
529545

530-
s := stats()
531546
resp := AgentStatsReportResponse{
532-
Conns: make([]agent.ConnStats, 0, len(s.ActiveConns)),
533-
}
534-
535-
for _, cs := range s.ActiveConns {
536-
resp.Conns = append(resp.Conns, *cs)
547+
ProtocolStats: stats().ProtocolStats,
537548
}
538549

539550
err = wsjson.Write(ctx, conn, resp)
@@ -542,13 +553,17 @@ func (c *Client) AgentReportStats(ctx context.Context, log slog.Logger, stats fu
542553
}
543554
}
544555
}()
545-
if err != nil {
556+
if err != nil && ctx.Err() == nil {
546557
log.Error(ctx, "report stats", slog.Error(err))
547558
}
548559
}
549560
}()
550561

551-
return nil
562+
return CloseFunc(func() error {
563+
cancel()
564+
<-doneCh
565+
return nil
566+
}), nil
552567
}
553568

554569
// AgentStatsReportRequest is a WebSocket request by coderd
@@ -559,5 +574,5 @@ type AgentStatsReportRequest struct {
559574
// AgentStatsReportResponse is returned for each report
560575
// request by the agent.
561576
type AgentStatsReportResponse struct {
562-
Conns []agent.ConnStats `json:"conns,omitempty"`
577+
// ProtocolStats carries the agent's per-protocol counters, taken from
// the ProtocolStats map of the agent.Stats being reported.
ProtocolStats map[string]*agent.ProtocolStats `json:"conn_stats,omitempty"`
563578
}

0 commit comments

Comments
 (0)