
Commit c8d65de

test(agent): fix TestAgent_Metadata/Once flake (#8613)
1 parent deb9261 commit c8d65de

File tree: 4 files changed (+87 -41 lines) — agent/agent.go, agent/agent_test.go, agent/agenttest/client.go, coderd/tailnet_test.go

agent/agent.go (+65 -33)

@@ -242,15 +242,15 @@ func (a *agent) runLoop(ctx context.Context) {
 	}
 }
 
-func (a *agent) collectMetadata(ctx context.Context, md codersdk.WorkspaceAgentMetadataDescription) *codersdk.WorkspaceAgentMetadataResult {
+func (a *agent) collectMetadata(ctx context.Context, md codersdk.WorkspaceAgentMetadataDescription, now time.Time) *codersdk.WorkspaceAgentMetadataResult {
 	var out bytes.Buffer
 	result := &codersdk.WorkspaceAgentMetadataResult{
 		// CollectedAt is set here for testing purposes and overrode by
 		// coderd to the time of server receipt to solve clock skew.
 		//
 		// In the future, the server may accept the timestamp from the agent
 		// if it can guarantee the clocks are synchronized.
-		CollectedAt: time.Now(),
+		CollectedAt: now,
 	}
 	cmdPty, err := a.sshServer.CreateCommand(ctx, md.Script, nil)
 	if err != nil {
@@ -298,54 +298,64 @@ type metadataResultAndKey struct {
 }
 
 type trySingleflight struct {
-	m sync.Map
+	mu sync.Mutex
+	m  map[string]struct{}
 }
 
 func (t *trySingleflight) Do(key string, fn func()) {
-	_, loaded := t.m.LoadOrStore(key, struct{}{})
-	if !loaded {
-		// There is already a goroutine running for this key.
+	t.mu.Lock()
+	_, ok := t.m[key]
+	if ok {
+		t.mu.Unlock()
 		return
 	}
 
-	defer t.m.Delete(key)
+	t.m[key] = struct{}{}
+	t.mu.Unlock()
+	defer func() {
+		t.mu.Lock()
+		delete(t.m, key)
+		t.mu.Unlock()
+	}()
+
 	fn()
 }
 
 func (a *agent) reportMetadataLoop(ctx context.Context) {
 	const metadataLimit = 128
 
 	var (
-		baseTicker       = time.NewTicker(a.reportMetadataInterval)
-		lastCollectedAts = make(map[string]time.Time)
-		metadataResults  = make(chan metadataResultAndKey, metadataLimit)
+		baseTicker        = time.NewTicker(a.reportMetadataInterval)
+		lastCollectedAtMu sync.RWMutex
+		lastCollectedAts  = make(map[string]time.Time)
+		metadataResults   = make(chan metadataResultAndKey, metadataLimit)
+		logger            = a.logger.Named("metadata")
 	)
 	defer baseTicker.Stop()
 
 	// We use a custom singleflight that immediately returns if there is already
 	// a goroutine running for a given key. This is to prevent a build-up of
 	// goroutines waiting on Do when the script takes many multiples of
 	// baseInterval to run.
-	var flight trySingleflight
+	flight := trySingleflight{m: map[string]struct{}{}}
 
 	for {
 		select {
 		case <-ctx.Done():
 			return
 		case mr := <-metadataResults:
-			lastCollectedAts[mr.key] = mr.result.CollectedAt
 			err := a.client.PostMetadata(ctx, mr.key, *mr.result)
 			if err != nil {
 				a.logger.Error(ctx, "agent failed to report metadata", slog.Error(err))
 			}
+			continue
 		case <-baseTicker.C:
 		}
 
 		if len(metadataResults) > 0 {
 			// The inner collection loop expects the channel is empty before spinning up
 			// all the collection goroutines.
-			a.logger.Debug(
-				ctx, "metadata collection backpressured",
+			logger.Debug(ctx, "metadata collection backpressured",
 				slog.F("queue_len", len(metadataResults)),
 			)
 			continue
@@ -357,7 +367,7 @@ func (a *agent) reportMetadataLoop(ctx context.Context) {
 		}
 
 		if len(manifest.Metadata) > metadataLimit {
-			a.logger.Error(
+			logger.Error(
 				ctx, "metadata limit exceeded",
 				slog.F("limit", metadataLimit), slog.F("got", len(manifest.Metadata)),
 			)
@@ -367,51 +377,73 @@ func (a *agent) reportMetadataLoop(ctx context.Context) {
 		// If the manifest changes (e.g. on agent reconnect) we need to
 		// purge old cache values to prevent lastCollectedAt from growing
 		// boundlessly.
+		lastCollectedAtMu.Lock()
 		for key := range lastCollectedAts {
 			if slices.IndexFunc(manifest.Metadata, func(md codersdk.WorkspaceAgentMetadataDescription) bool {
 				return md.Key == key
 			}) < 0 {
+				logger.Debug(ctx, "deleting lastCollected key, missing from manifest",
+					slog.F("key", key),
+				)
 				delete(lastCollectedAts, key)
 			}
 		}
+		lastCollectedAtMu.Unlock()
 
 		// Spawn a goroutine for each metadata collection, and use a
 		// channel to synchronize the results and avoid both messy
 		// mutex logic and overloading the API.
 		for _, md := range manifest.Metadata {
-			collectedAt, ok := lastCollectedAts[md.Key]
-			if ok {
-				// If the interval is zero, we assume the user just wants
-				// a single collection at startup, not a spinning loop.
-				if md.Interval == 0 {
-					continue
-				}
-				// The last collected value isn't quite stale yet, so we skip it.
-				if collectedAt.Add(a.reportMetadataInterval).After(time.Now()) {
-					continue
-				}
-			}
-
 			md := md
 			// We send the result to the channel in the goroutine to avoid
 			// sending the same result multiple times. So, we don't care about
 			// the return values.
 			go flight.Do(md.Key, func() {
+				ctx := slog.With(ctx, slog.F("key", md.Key))
+				lastCollectedAtMu.RLock()
+				collectedAt, ok := lastCollectedAts[md.Key]
+				lastCollectedAtMu.RUnlock()
+				if ok {
+					// If the interval is zero, we assume the user just wants
+					// a single collection at startup, not a spinning loop.
+					if md.Interval == 0 {
+						return
+					}
+					// The last collected value isn't quite stale yet, so we skip it.
+					if collectedAt.Add(a.reportMetadataInterval).After(time.Now()) {
+						return
+					}
+				}
+
 				timeout := md.Timeout
 				if timeout == 0 {
-					timeout = md.Interval
+					if md.Interval != 0 {
+						timeout = md.Interval
+					} else if interval := int64(a.reportMetadataInterval.Seconds()); interval != 0 {
+						// Fallback to the report interval
+						timeout = interval * 3
+					} else {
+						// If the interval is still 0 (possible if the interval
+						// is less than a second), default to 5. This was
+						// randomly picked.
+						timeout = 5
+					}
 				}
-				ctx, cancel := context.WithTimeout(ctx,
-					time.Duration(timeout)*time.Second,
-				)
+				ctxTimeout := time.Duration(timeout) * time.Second
+				ctx, cancel := context.WithTimeout(ctx, ctxTimeout)
 				defer cancel()
 
+				now := time.Now()
 				select {
 				case <-ctx.Done():
+					logger.Warn(ctx, "metadata collection timed out", slog.F("timeout", ctxTimeout))
 				case metadataResults <- metadataResultAndKey{
 					key:    md.Key,
-					result: a.collectMetadata(ctx, md),
+					result: a.collectMetadata(ctx, md, now),
 				}:
+					lastCollectedAtMu.Lock()
+					lastCollectedAts[md.Key] = now
+					lastCollectedAtMu.Unlock()
 				}
 			})
 		}
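The trySingleflight rewrite is the piece worth reading in isolation: unlike a conventional singleflight, duplicate callers for a key do not wait for and share the running call's result; they return immediately, which is what keeps slow metadata scripts from stacking goroutines up behind Do (per the comment in reportMetadataLoop). Below is a self-contained sketch that copies the rewritten type from the diff and adds a small harness for illustration; the main function, the "load_avg" key, and the sleep are invented for the demo and are not part of the commit.

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

// trySingleflight mirrors the rewritten helper above: a plain map guarded by
// a mutex. Do runs fn only when no goroutine is currently running for key;
// otherwise it returns immediately instead of queueing behind the running one.
type trySingleflight struct {
	mu sync.Mutex
	m  map[string]struct{}
}

func (t *trySingleflight) Do(key string, fn func()) {
	t.mu.Lock()
	if _, ok := t.m[key]; ok {
		t.mu.Unlock()
		return
	}
	t.m[key] = struct{}{}
	t.mu.Unlock()
	defer func() {
		t.mu.Lock()
		delete(t.m, key)
		t.mu.Unlock()
	}()
	fn()
}

func main() {
	flight := trySingleflight{m: map[string]struct{}{}}

	var (
		wg   sync.WaitGroup
		runs int32
	)
	// Launch several concurrent collections for the same (made-up) key; while
	// the first is still sleeping, the others find the key in flight and bail.
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			flight.Do("load_avg", func() {
				atomic.AddInt32(&runs, 1)
				time.Sleep(50 * time.Millisecond) // stand-in for a slow metadata script
			})
		}()
	}
	wg.Wait()
	fmt.Println("executions:", atomic.LoadInt32(&runs)) // typically 1
}

Running this typically prints "executions: 1", since the later calls return without executing fn.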

agent/agent_test.go (+6 -1)

@@ -1066,6 +1066,7 @@ func TestAgent_StartupScript(t *testing.T) {
 		t.Parallel()
 		logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
 		client := agenttest.NewClient(t,
+			logger,
 			uuid.New(),
 			agentsdk.Manifest{
 				StartupScript: command,
@@ -1097,6 +1098,7 @@ func TestAgent_StartupScript(t *testing.T) {
 		t.Parallel()
 		logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
 		client := agenttest.NewClient(t,
+			logger,
 			uuid.New(),
 			agentsdk.Manifest{
 				StartupScript: command,
@@ -1470,6 +1472,7 @@ func TestAgent_Lifecycle(t *testing.T) {
 		derpMap, _ := tailnettest.RunDERPAndSTUN(t)
 
 		client := agenttest.NewClient(t,
+			logger,
 			uuid.New(),
 			agentsdk.Manifest{
 				DERPMap: derpMap,
@@ -1742,6 +1745,7 @@ func TestAgent_Reconnect(t *testing.T) {
 	statsCh := make(chan *agentsdk.Stats, 50)
 	derpMap, _ := tailnettest.RunDERPAndSTUN(t)
 	client := agenttest.NewClient(t,
+		logger,
 		agentID,
 		agentsdk.Manifest{
 			DERPMap: derpMap,
@@ -1776,6 +1780,7 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) {
 	defer coordinator.Close()
 
 	client := agenttest.NewClient(t,
+		logger,
 		uuid.New(),
 		agentsdk.Manifest{
 			GitAuthConfigs: 1,
@@ -1900,7 +1905,7 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
 	})
 	statsCh := make(chan *agentsdk.Stats, 50)
 	fs := afero.NewMemMapFs()
-	c := agenttest.NewClient(t, metadata.AgentID, metadata, statsCh, coordinator)
+	c := agenttest.NewClient(t, logger.Named("agent"), metadata.AgentID, metadata, statsCh, coordinator)
 
 	options := agent.Options{
 		Client: c,

agent/agenttest/client.go (+15 -6)

@@ -18,6 +18,7 @@ import (
 )
 
 func NewClient(t testing.TB,
+	logger slog.Logger,
 	agentID uuid.UUID,
 	manifest agentsdk.Manifest,
 	statsChan chan *agentsdk.Stats,
@@ -28,6 +29,7 @@ func NewClient(t testing.TB,
 	}
 	return &Client{
 		t:         t,
+		logger:    logger.Named("client"),
 		agentID:   agentID,
 		manifest:  manifest,
 		statsChan: statsChan,
@@ -37,6 +39,7 @@ func NewClient(t testing.TB,
 
 type Client struct {
 	t         testing.TB
+	logger    slog.Logger
 	agentID   uuid.UUID
 	manifest  agentsdk.Manifest
 	metadata  map[string]agentsdk.PostMetadataRequest
@@ -110,14 +113,16 @@ func (c *Client) GetLifecycleStates() []codersdk.WorkspaceAgentLifecycle {
 	return c.lifecycleStates
 }
 
-func (c *Client) PostLifecycle(_ context.Context, req agentsdk.PostLifecycleRequest) error {
+func (c *Client) PostLifecycle(ctx context.Context, req agentsdk.PostLifecycleRequest) error {
 	c.mu.Lock()
 	defer c.mu.Unlock()
 	c.lifecycleStates = append(c.lifecycleStates, req.State)
+	c.logger.Debug(ctx, "post lifecycle", slog.F("req", req))
 	return nil
 }
 
-func (*Client) PostAppHealth(_ context.Context, _ agentsdk.PostAppHealthsRequest) error {
+func (c *Client) PostAppHealth(ctx context.Context, req agentsdk.PostAppHealthsRequest) error {
+	c.logger.Debug(ctx, "post app health", slog.F("req", req))
 	return nil
 }
 
@@ -133,20 +138,22 @@ func (c *Client) GetMetadata() map[string]agentsdk.PostMetadataRequest {
 	return maps.Clone(c.metadata)
 }
 
-func (c *Client) PostMetadata(_ context.Context, key string, req agentsdk.PostMetadataRequest) error {
+func (c *Client) PostMetadata(ctx context.Context, key string, req agentsdk.PostMetadataRequest) error {
 	c.mu.Lock()
 	defer c.mu.Unlock()
 	if c.metadata == nil {
 		c.metadata = make(map[string]agentsdk.PostMetadataRequest)
 	}
 	c.metadata[key] = req
+	c.logger.Debug(ctx, "post metadata", slog.F("key", key), slog.F("req", req))
 	return nil
 }
 
-func (c *Client) PostStartup(_ context.Context, startup agentsdk.PostStartupRequest) error {
+func (c *Client) PostStartup(ctx context.Context, startup agentsdk.PostStartupRequest) error {
 	c.mu.Lock()
 	defer c.mu.Unlock()
 	c.startup = startup
+	c.logger.Debug(ctx, "post startup", slog.F("req", startup))
 	return nil
 }
 
@@ -156,13 +163,14 @@ func (c *Client) GetStartupLogs() []agentsdk.StartupLog {
 	return c.logs
 }
 
-func (c *Client) PatchStartupLogs(_ context.Context, logs agentsdk.PatchStartupLogs) error {
+func (c *Client) PatchStartupLogs(ctx context.Context, logs agentsdk.PatchStartupLogs) error {
 	c.mu.Lock()
 	defer c.mu.Unlock()
 	if c.PatchWorkspaceLogs != nil {
 		return c.PatchWorkspaceLogs()
 	}
 	c.logs = append(c.logs, logs.Logs...)
+	c.logger.Debug(ctx, "patch startup logs", slog.F("req", logs))
 	return nil
 }
 
@@ -173,9 +181,10 @@ func (c *Client) SetServiceBannerFunc(f func() (codersdk.ServiceBannerConfig, er
 	c.GetServiceBannerFunc = f
 }
 
-func (c *Client) GetServiceBanner(_ context.Context) (codersdk.ServiceBannerConfig, error) {
+func (c *Client) GetServiceBanner(ctx context.Context) (codersdk.ServiceBannerConfig, error) {
 	c.mu.Lock()
 	defer c.mu.Unlock()
+	c.logger.Debug(ctx, "get service banner")
 	if c.GetServiceBannerFunc != nil {
 		return c.GetServiceBannerFunc()
 	}
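Every call site now threads a logger in right after t, and NewClient further scopes it with .Named("client"). A hedged sketch of the new call shape, assuming the repo's agenttest, agentsdk, uuid, and slogtest packages and an already-constructed tailnet coordinator from the surrounding test setup; the empty manifest and channel capacity are placeholders, not part of the commit:

// Sketch only: `coordinator` is assumed to come from the surrounding test
// setup, and the manifest here is an empty placeholder.
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
client := agenttest.NewClient(t,
	logger.Named("agent"),
	uuid.New(),
	agentsdk.Manifest{},
	make(chan *agentsdk.Stats, 50),
	coordinator,
)
_ = client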

coderd/tailnet_test.go (+1 -1)

@@ -176,7 +176,7 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A
 		_ = coord.Close()
 	})
 
-	c := agenttest.NewClient(t, manifest.AgentID, manifest, make(chan *agentsdk.Stats, 50), coord)
+	c := agenttest.NewClient(t, logger, manifest.AgentID, manifest, make(chan *agentsdk.Stats, 50), coord)
 
 	options := agent.Options{
 		Client: c,
