
Commit 82b8659

feat(agent): send metadata in batches
1 parent eeb4adb commit 82b8659

File tree: 2 files changed (agent/agent.go, agent/agent_test.go), +203 −131 lines
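
The gist of the change: instead of posting each metadata result to the API as soon as it is collected, the agent coalesces results into a map and flushes them as a single batched request on each report tick. Below is a minimal, self-contained sketch of that idea; the `postBatch` function and string results are illustrative stand-ins, not the agent's actual types.

```go
package main

import (
	"fmt"
	"time"
)

// postBatch is a stand-in for a batched API call such as PostMetadata.
func postBatch(batch map[string]string) {
	fmt.Printf("posting %d results in one request: %v\n", len(batch), batch)
}

func main() {
	results := make(chan [2]string) // key/value pairs from collectors
	tick := time.NewTicker(250 * time.Millisecond)
	defer tick.Stop()

	// Simulated collector producing results faster than the report interval.
	go func() {
		for i := 0; ; i++ {
			results <- [2]string{"cpu", fmt.Sprint(i)}
			time.Sleep(100 * time.Millisecond)
		}
	}()

	pending := make(map[string]string) // latest result per key; newer overwrites older
	deadline := time.After(time.Second)
	for {
		select {
		case kv := <-results:
			pending[kv[0]] = kv[1] // coalesce: only the most recent value matters
		case <-tick.C:
			if len(pending) == 0 {
				continue
			}
			batch := pending
			pending = make(map[string]string)
			postBatch(batch) // one request per tick instead of one per result
		case <-deadline:
			return
		}
	}
}
```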

agent/agent.go

Lines changed: 185 additions & 122 deletions

```diff
@@ -362,147 +362,210 @@ func (t *trySingleflight) Do(key string, fn func()) {
 }
 
 func (a *agent) reportMetadataLoop(ctx context.Context) {
-	const metadataLimit = 128
+	tickerDone := make(chan struct{})
+	collectDone := make(chan struct{})
+	ctx, cancel := context.WithCancel(ctx)
+	defer func() {
+		cancel()
+		<-collectDone
+		<-tickerDone
+	}()
 
 	var (
-		baseTicker        = time.NewTicker(a.reportMetadataInterval)
-		lastCollectedAtMu sync.RWMutex
-		lastCollectedAts  = make(map[string]time.Time)
-		metadataResults   = make(chan metadataResultAndKey, metadataLimit)
-		logger            = a.logger.Named("metadata")
+		logger          = a.logger.Named("metadata")
+		report          = make(chan struct{}, 1)
+		collect         = make(chan struct{}, 1)
+		metadataResults = make(chan metadataResultAndKey, 1)
 	)
-	defer baseTicker.Stop()
-
-	// We use a custom singleflight that immediately returns if there is already
-	// a goroutine running for a given key. This is to prevent a build-up of
-	// goroutines waiting on Do when the script takes many multiples of
-	// baseInterval to run.
-	flight := trySingleflight{m: map[string]struct{}{}}
-
-	postMetadata := func(mr metadataResultAndKey) {
-		err := a.client.PostMetadata(ctx, agentsdk.PostMetadataRequest{
-			Metadata: []agentsdk.Metadata{
-				{
-					Key:                          mr.key,
-					WorkspaceAgentMetadataResult: *mr.result,
-				},
-			},
-		})
-		if err != nil {
-			a.logger.Error(ctx, "agent failed to report metadata", slog.Error(err))
-		}
-	}
 
-	for {
-		select {
-		case <-ctx.Done():
-			return
-		case mr := <-metadataResults:
-			postMetadata(mr)
-			continue
-		case <-baseTicker.C:
+	// Set up collect and report as a single ticker with two channels,
+	// this is to allow collection and reporting to be triggered
+	// independently of each other.
+	go func() {
+		t := time.NewTicker(a.reportMetadataInterval)
+		defer func() {
+			t.Stop()
+			close(report)
+			close(collect)
+			close(tickerDone)
+		}()
+		wake := func(c chan<- struct{}) {
+			select {
+			case c <- struct{}{}:
+			default:
+			}
 		}
+		wake(collect) // Start immediately.
 
-		if len(metadataResults) > 0 {
-			// The inner collection loop expects the channel is empty before spinning up
-			// all the collection goroutines.
-			logger.Debug(ctx, "metadata collection backpressured",
-				slog.F("queue_len", len(metadataResults)),
-			)
-			continue
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case <-t.C:
+				wake(report)
+				wake(collect)
+			}
 		}
+	}()
 
-		manifest := a.manifest.Load()
-		if manifest == nil {
-			continue
-		}
+	go func() {
+		defer close(collectDone)
+
+		var (
+			// We use a custom singleflight that immediately returns if there is already
+			// a goroutine running for a given key. This is to prevent a build-up of
+			// goroutines waiting on Do when the script takes many multiples of
+			// baseInterval to run.
+			flight            = trySingleflight{m: map[string]struct{}{}}
+			lastCollectedAtMu sync.RWMutex
+			lastCollectedAts  = make(map[string]time.Time)
+		)
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case <-collect:
+			}
 
-		if len(manifest.Metadata) > metadataLimit {
-			logger.Error(
-				ctx, "metadata limit exceeded",
-				slog.F("limit", metadataLimit), slog.F("got", len(manifest.Metadata)),
-			)
-			continue
-		}
+			manifest := a.manifest.Load()
+			if manifest == nil {
+				continue
+			}
 
-		// If the manifest changes (e.g. on agent reconnect) we need to
-		// purge old cache values to prevent lastCollectedAt from growing
-		// boundlessly.
-		lastCollectedAtMu.Lock()
-		for key := range lastCollectedAts {
-			if slices.IndexFunc(manifest.Metadata, func(md codersdk.WorkspaceAgentMetadataDescription) bool {
-				return md.Key == key
-			}) < 0 {
-				logger.Debug(ctx, "deleting lastCollected key, missing from manifest",
-					slog.F("key", key),
-				)
-				delete(lastCollectedAts, key)
+			// If the manifest changes (e.g. on agent reconnect) we need to
+			// purge old cache values to prevent lastCollectedAt from growing
+			// boundlessly.
+			lastCollectedAtMu.Lock()
+			for key := range lastCollectedAts {
+				if slices.IndexFunc(manifest.Metadata, func(md codersdk.WorkspaceAgentMetadataDescription) bool {
+					return md.Key == key
+				}) < 0 {
+					logger.Debug(ctx, "deleting lastCollected key, missing from manifest",
+						slog.F("key", key),
+					)
+					delete(lastCollectedAts, key)
+				}
 			}
-		}
-		lastCollectedAtMu.Unlock()
-
-		// Spawn a goroutine for each metadata collection, and use a
-		// channel to synchronize the results and avoid both messy
-		// mutex logic and overloading the API.
-		for _, md := range manifest.Metadata {
-			md := md
-			// We send the result to the channel in the goroutine to avoid
-			// sending the same result multiple times. So, we don't care about
-			// the return values.
-			go flight.Do(md.Key, func() {
-				ctx := slog.With(ctx, slog.F("key", md.Key))
-				lastCollectedAtMu.RLock()
-				collectedAt, ok := lastCollectedAts[md.Key]
-				lastCollectedAtMu.RUnlock()
-				if ok {
-					// If the interval is zero, we assume the user just wants
-					// a single collection at startup, not a spinning loop.
-					if md.Interval == 0 {
-						return
+			lastCollectedAtMu.Unlock()
+
+			// Spawn a goroutine for each metadata collection, and use a
+			// channel to synchronize the results and avoid both messy
+			// mutex logic and overloading the API.
+			for _, md := range manifest.Metadata {
+				md := md
+				// We send the result to the channel in the goroutine to avoid
+				// sending the same result multiple times. So, we don't care about
+				// the return values.
+				go flight.Do(md.Key, func() {
+					ctx := slog.With(ctx, slog.F("key", md.Key))
+					lastCollectedAtMu.RLock()
+					collectedAt, ok := lastCollectedAts[md.Key]
+					lastCollectedAtMu.RUnlock()
+					if ok {
+						// If the interval is zero, we assume the user just wants
+						// a single collection at startup, not a spinning loop.
+						if md.Interval == 0 {
+							return
+						}
+						intervalUnit := time.Second
+						// reportMetadataInterval is only less than a second in tests,
+						// so adjust the interval unit for them.
+						if a.reportMetadataInterval < time.Second {
+							intervalUnit = 100 * time.Millisecond
+						}
+						// The last collected value isn't quite stale yet, so we skip it.
+						if collectedAt.Add(time.Duration(md.Interval) * intervalUnit).After(time.Now()) {
+							return
+						}
 					}
-					intervalUnit := time.Second
-					// reportMetadataInterval is only less than a second in tests,
-					// so adjust the interval unit for them.
-					if a.reportMetadataInterval < time.Second {
-						intervalUnit = 100 * time.Millisecond
+
+					timeout := md.Timeout
+					if timeout == 0 {
+						if md.Interval != 0 {
+							timeout = md.Interval
+						} else if interval := int64(a.reportMetadataInterval.Seconds()); interval != 0 {
+							// Fallback to the report interval
+							timeout = interval * 3
+						} else {
+							// If the interval is still 0 (possible if the interval
+							// is less than a second), default to 5. This was
+							// randomly picked.
+							timeout = 5
+						}
 					}
-					// The last collected value isn't quite stale yet, so we skip it.
-					if collectedAt.Add(time.Duration(md.Interval) * intervalUnit).After(time.Now()) {
-						return
+					ctxTimeout := time.Duration(timeout) * time.Second
+					ctx, cancel := context.WithTimeout(ctx, ctxTimeout)
+					defer cancel()
+
+					now := time.Now()
+					select {
+					case <-ctx.Done():
+						logger.Warn(ctx, "metadata collection timed out", slog.F("timeout", ctxTimeout))
+					case metadataResults <- metadataResultAndKey{
+						key:    md.Key,
+						result: a.collectMetadata(ctx, md, now),
+					}:
+						lastCollectedAtMu.Lock()
+						lastCollectedAts[md.Key] = now
+						lastCollectedAtMu.Unlock()
 					}
-				}
+				})
+			}
+		}
+	}()
 
-				timeout := md.Timeout
-				if timeout == 0 {
-					if md.Interval != 0 {
-						timeout = md.Interval
-					} else if interval := int64(a.reportMetadataInterval.Seconds()); interval != 0 {
-						// Fallback to the report interval
-						timeout = interval * 3
-					} else {
-						// If the interval is still 0 (possible if the interval
-						// is less than a second), default to 5. This was
-						// randomly picked.
-						timeout = 5
-					}
+	// Gather metadata updates and report them once every interval. If a
+	// previous report is in flight, wait for it to complete before
+	// sending a new one. If the network conditions are bad, we won't
+	// benefit from canceling the previous send and starting a new one.
+	var (
+		updatedMetadata = make(map[string]*codersdk.WorkspaceAgentMetadataResult)
+		reportTimeout   = 30 * time.Second
+		reportSemaphore = make(chan struct{}, 1)
+	)
+	reportSemaphore <- struct{}{}
+
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case mr := <-metadataResults:
+			// This can overwrite unsent values, but that's fine because
+			// we're only interested about up-to-date values.
+			updatedMetadata[mr.key] = mr.result
+			continue
+		case <-report:
+			if len(updatedMetadata) > 0 {
+				metadata := make([]agentsdk.Metadata, 0, len(updatedMetadata))
+				for key, result := range updatedMetadata {
+					metadata = append(metadata, agentsdk.Metadata{
+						Key:                          key,
+						WorkspaceAgentMetadataResult: *result,
+					})
+					delete(updatedMetadata, key)
 				}
-				ctxTimeout := time.Duration(timeout) * time.Second
-				ctx, cancel := context.WithTimeout(ctx, ctxTimeout)
-				defer cancel()
 
-				now := time.Now()
 				select {
-				case <-ctx.Done():
-					logger.Warn(ctx, "metadata collection timed out", slog.F("timeout", ctxTimeout))
-				case metadataResults <- metadataResultAndKey{
-					key:    md.Key,
-					result: a.collectMetadata(ctx, md, now),
-				}:
-					lastCollectedAtMu.Lock()
-					lastCollectedAts[md.Key] = now
-					lastCollectedAtMu.Unlock()
+				case <-reportSemaphore:
+				default:
+					// If there's already a report in flight, don't send
+					// another one, wait for next tick instead.
+					continue
 				}
-			})
+
+				go func() {
+					ctx, cancel := context.WithTimeout(ctx, reportTimeout)
+					defer func() {
+						cancel()
+						reportSemaphore <- struct{}{}
+					}()
+
+					err := a.client.PostMetadata(ctx, agentsdk.PostMetadataRequest{Metadata: metadata})
+					if err != nil {
+						a.logger.Error(ctx, "agent failed to report metadata", slog.Error(err))
+					}
+				}()
+			}
 		}
 	}
 }
```
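For reference, the collect/report wake-up trick in the diff above boils down to a non-blocking send on a 1-buffered channel: ticks that arrive while the receiver is still busy coalesce into a single pending wake-up instead of piling up. A small standalone sketch of the pattern (the names and timings here are illustrative, not the agent's):

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	var (
		report  = make(chan struct{}, 1)
		collect = make(chan struct{}, 1)
	)

	// wake performs a non-blocking send: if a wake-up is already
	// pending on the channel, the new one is dropped rather than queued.
	wake := func(c chan<- struct{}) {
		select {
		case c <- struct{}{}:
		default:
		}
	}

	// One ticker drives both channels, so collection and reporting
	// can be consumed independently and at their own pace.
	go func() {
		t := time.NewTicker(50 * time.Millisecond)
		defer t.Stop()
		for range t.C {
			wake(report)
			wake(collect)
		}
	}()

	// A slow consumer: many ticks fire while it sleeps, but it only
	// ever finds at most one pending wake-up when it returns.
	for i := 0; i < 3; i++ {
		<-collect
		fmt.Println("collecting...", i)
		time.Sleep(200 * time.Millisecond)
	}
	<-report
	fmt.Println("reporting once")
}
```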

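Similarly, the `reportSemaphore` in the new report loop is a one-slot channel used as a try-lock: a new batch is only sent when the previous PostMetadata call has returned its token, otherwise the batch waits for the next tick. A hedged, standalone sketch of that pattern (the `send` function is a placeholder for the real API call):

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	// One token in the channel means "no report in flight".
	semaphore := make(chan struct{}, 1)
	semaphore <- struct{}{}

	send := func(batch int) {
		defer func() { semaphore <- struct{}{} }() // return the token when done
		fmt.Println("sending batch", batch)
		time.Sleep(300 * time.Millisecond) // simulate a slow network call
	}

	tick := time.NewTicker(100 * time.Millisecond)
	defer tick.Stop()

	for batch := 0; batch < 10; batch++ {
		<-tick.C
		select {
		case <-semaphore:
			// Acquired the token: it is safe to start a new send.
			go send(batch)
		default:
			// A send is still in flight; skip this tick rather than
			// stacking up concurrent requests.
			fmt.Println("skipping batch", batch, "(previous send still in flight)")
		}
	}
}
```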
agent/agent_test.go

Lines changed: 18 additions & 9 deletions

```diff
@@ -1066,34 +1066,43 @@ func TestAgent_Metadata(t *testing.T) {
 
 	t.Run("Once", func(t *testing.T) {
 		t.Parallel()
+
 		//nolint:dogsled
 		_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
 			Metadata: []codersdk.WorkspaceAgentMetadataDescription{
 				{
-					Key:      "greeting",
+					Key:      "greeting1",
 					Interval: 0,
 					Script:   echoHello,
 				},
+				{
+					Key:      "greeting2",
+					Interval: 1,
+					Script:   echoHello,
+				},
 			},
 		}, 0, func(_ *agenttest.Client, opts *agent.Options) {
-			opts.ReportMetadataInterval = 100 * time.Millisecond
+			opts.ReportMetadataInterval = testutil.IntervalFast
 		})
 
 		var gotMd map[string]agentsdk.Metadata
 		require.Eventually(t, func() bool {
 			gotMd = client.GetMetadata()
-			return len(gotMd) == 1
-		}, testutil.WaitShort, testutil.IntervalMedium)
+			return len(gotMd) == 2
+		}, testutil.WaitShort, testutil.IntervalFast/2)
 
-		collectedAt := gotMd["greeting"].CollectedAt
+		collectedAt1 := gotMd["greeting1"].CollectedAt
+		collectedAt2 := gotMd["greeting2"].CollectedAt
 
-		require.Never(t, func() bool {
+		require.Eventually(t, func() bool {
 			gotMd = client.GetMetadata()
-			if len(gotMd) != 1 {
+			if len(gotMd) != 2 {
 				panic("unexpected number of metadata")
 			}
-			return !gotMd["greeting"].CollectedAt.Equal(collectedAt)
-		}, testutil.WaitShort, testutil.IntervalMedium)
+			return !gotMd["greeting2"].CollectedAt.Equal(collectedAt2)
+		}, testutil.WaitShort, testutil.IntervalFast/2)
+
+		require.Equal(t, gotMd["greeting1"].CollectedAt, collectedAt1, "metadata should not be collected again")
 	})
 
 	t.Run("Many", func(t *testing.T) {
```
