test(agent): fix TestAgent_Metadata/Once flake #8613

Merged
11 commits, merged Jul 20, 2023
98 changes: 65 additions & 33 deletions agent/agent.go
@@ -242,15 +242,15 @@ func (a *agent) runLoop(ctx context.Context) {
}
}

func (a *agent) collectMetadata(ctx context.Context, md codersdk.WorkspaceAgentMetadataDescription) *codersdk.WorkspaceAgentMetadataResult {
func (a *agent) collectMetadata(ctx context.Context, md codersdk.WorkspaceAgentMetadataDescription, now time.Time) *codersdk.WorkspaceAgentMetadataResult {
var out bytes.Buffer
result := &codersdk.WorkspaceAgentMetadataResult{
// CollectedAt is set here for testing purposes and overridden by
// coderd to the time of server receipt to solve clock skew.
//
// In the future, the server may accept the timestamp from the agent
// if it can guarantee the clocks are synchronized.
CollectedAt: time.Now(),
CollectedAt: now,
}
cmdPty, err := a.sshServer.CreateCommand(ctx, md.Script, nil)
if err != nil {
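
The hunk above threads the collection timestamp in from the caller instead of calling time.Now() inside collectMetadata, which is what lets a test pin CollectedAt to a known instant. A minimal standalone sketch of that clock-injection pattern, with illustrative names rather than the agent's real types:

package main

import (
	"fmt"
	"time"
)

// result stands in for a metadata result: it records when the value was
// collected.
type result struct {
	CollectedAt time.Time
	Value       string
}

// collect stamps the result with the caller-supplied time instead of
// reading the clock itself, so the caller controls the timestamp.
func collect(now time.Time, value string) result {
	return result{CollectedAt: now, Value: value}
}

func main() {
	// Production code would pass time.Now(); a test can pass a fixed
	// instant and assert on it deterministically.
	fixed := time.Date(2023, 7, 20, 0, 0, 0, 0, time.UTC)
	r := collect(fixed, "load_avg")
	fmt.Println(r.CollectedAt.Equal(fixed)) // true
}
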
@@ -298,54 +298,64 @@ type metadataResultAndKey struct {
}

type trySingleflight struct {
m sync.Map
mu sync.Mutex
m map[string]struct{}
}

func (t *trySingleflight) Do(key string, fn func()) {
_, loaded := t.m.LoadOrStore(key, struct{}{})
if !loaded {
// There is already a goroutine running for this key.
t.mu.Lock()
_, ok := t.m[key]
if ok {
t.mu.Unlock()
return
}

defer t.m.Delete(key)
t.m[key] = struct{}{}
t.mu.Unlock()
defer func() {
t.mu.Lock()
delete(t.m, key)
t.mu.Unlock()
}()

fn()
}

func (a *agent) reportMetadataLoop(ctx context.Context) {
const metadataLimit = 128

var (
baseTicker = time.NewTicker(a.reportMetadataInterval)
lastCollectedAts = make(map[string]time.Time)
metadataResults = make(chan metadataResultAndKey, metadataLimit)
baseTicker = time.NewTicker(a.reportMetadataInterval)
lastCollectedAtMu sync.RWMutex
lastCollectedAts = make(map[string]time.Time)
metadataResults = make(chan metadataResultAndKey, metadataLimit)
logger = a.logger.Named("metadata")
)
defer baseTicker.Stop()

// We use a custom singleflight that immediately returns if there is already
// a goroutine running for a given key. This is to prevent a build-up of
// goroutines waiting on Do when the script takes many multiples of
// baseInterval to run.
var flight trySingleflight
flight := trySingleflight{m: map[string]struct{}{}}

for {
select {
case <-ctx.Done():
return
case mr := <-metadataResults:
lastCollectedAts[mr.key] = mr.result.CollectedAt
err := a.client.PostMetadata(ctx, mr.key, *mr.result)
if err != nil {
a.logger.Error(ctx, "agent failed to report metadata", slog.Error(err))
}
continue
case <-baseTicker.C:
}

if len(metadataResults) > 0 {
// The inner collection loop expects the channel is empty before spinning up
// all the collection goroutines.
a.logger.Debug(
ctx, "metadata collection backpressured",
logger.Debug(ctx, "metadata collection backpressured",
slog.F("queue_len", len(metadataResults)),
)
continue
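
For reference, a standalone sketch of the mutex-and-map try-singleflight this hunk switches to: Do runs fn only when no other goroutine currently holds the key and returns immediately otherwise, so a slow metadata script cannot build up a queue of waiting goroutines. It mirrors the shape of the diff above but is an illustration, not the package code:

package main

import (
	"fmt"
	"sync"
	"time"
)

// trySingleflight runs at most one fn per key at a time; callers that find
// the key already in flight return immediately instead of waiting.
type trySingleflight struct {
	mu sync.Mutex
	m  map[string]struct{}
}

func (t *trySingleflight) Do(key string, fn func()) {
	t.mu.Lock()
	if _, ok := t.m[key]; ok {
		// Another goroutine is already running fn for this key.
		t.mu.Unlock()
		return
	}
	t.m[key] = struct{}{}
	t.mu.Unlock()
	defer func() {
		t.mu.Lock()
		delete(t.m, key)
		t.mu.Unlock()
	}()
	fn()
}

func main() {
	flight := trySingleflight{m: map[string]struct{}{}}
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			flight.Do("slow-script", func() {
				// Typically printed once: later callers return immediately
				// while the first call is still sleeping.
				fmt.Println("collecting")
				time.Sleep(100 * time.Millisecond)
			})
		}()
	}
	wg.Wait()
}
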
@@ -357,7 +367,7 @@ func (a *agent) reportMetadataLoop(ctx context.Context) {
}

if len(manifest.Metadata) > metadataLimit {
a.logger.Error(
logger.Error(
ctx, "metadata limit exceeded",
slog.F("limit", metadataLimit), slog.F("got", len(manifest.Metadata)),
)
@@ -367,51 +377,73 @@ func (a *agent) reportMetadataLoop(ctx context.Context) {
// If the manifest changes (e.g. on agent reconnect) we need to
// purge old cache values to prevent lastCollectedAt from growing
// boundlessly.
lastCollectedAtMu.Lock()
for key := range lastCollectedAts {
if slices.IndexFunc(manifest.Metadata, func(md codersdk.WorkspaceAgentMetadataDescription) bool {
return md.Key == key
}) < 0 {
logger.Debug(ctx, "deleting lastCollected key, missing from manifest",
slog.F("key", key),
)
delete(lastCollectedAts, key)
}
}
lastCollectedAtMu.Unlock()

// Spawn a goroutine for each metadata collection, and use a
// channel to synchronize the results and avoid both messy
// mutex logic and overloading the API.
for _, md := range manifest.Metadata {
collectedAt, ok := lastCollectedAts[md.Key]
if ok {
// If the interval is zero, we assume the user just wants
// a single collection at startup, not a spinning loop.
if md.Interval == 0 {
continue
}
// The last collected value isn't quite stale yet, so we skip it.
if collectedAt.Add(a.reportMetadataInterval).After(time.Now()) {
continue
}
}

md := md
// We send the result to the channel in the goroutine to avoid
// sending the same result multiple times. So, we don't care about
// the return values.
go flight.Do(md.Key, func() {
ctx := slog.With(ctx, slog.F("key", md.Key))
lastCollectedAtMu.RLock()
collectedAt, ok := lastCollectedAts[md.Key]
lastCollectedAtMu.RUnlock()
if ok {
// If the interval is zero, we assume the user just wants
// a single collection at startup, not a spinning loop.
if md.Interval == 0 {
return
}
// The last collected value isn't quite stale yet, so we skip it.
if collectedAt.Add(a.reportMetadataInterval).After(time.Now()) {
return
}
}

timeout := md.Timeout
if timeout == 0 {
timeout = md.Interval
if md.Interval != 0 {
timeout = md.Interval
} else if interval := int64(a.reportMetadataInterval.Seconds()); interval != 0 {
// Fallback to the report interval
timeout = interval * 3
} else {
// If the interval is still 0 (possible if the interval
// is less than a second), default to 5. This was
// randomly picked.
timeout = 5
}
}
ctx, cancel := context.WithTimeout(ctx,
time.Duration(timeout)*time.Second,
)
ctxTimeout := time.Duration(timeout) * time.Second
ctx, cancel := context.WithTimeout(ctx, ctxTimeout)
defer cancel()

now := time.Now()
select {
case <-ctx.Done():
logger.Warn(ctx, "metadata collection timed out", slog.F("timeout", ctxTimeout))
case metadataResults <- metadataResultAndKey{
key: md.Key,
result: a.collectMetadata(ctx, md),
result: a.collectMetadata(ctx, md, now),
}:
lastCollectedAtMu.Lock()
lastCollectedAts[md.Key] = now
lastCollectedAtMu.Unlock()
}
})
}
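
The timeout selection added in the last hunk falls back from the item's explicit timeout, to its interval, to three report intervals, and finally to a five-second floor when everything rounds down to zero. A standalone sketch of that decision chain with stand-in names, mirroring the diff's logic rather than importing the agent package:

package main

import (
	"fmt"
	"time"
)

// collectionTimeout mirrors the fallback chain in the diff: explicit
// timeout, else the item's interval, else three report intervals, else
// five seconds when the report interval rounds down to zero seconds.
func collectionTimeout(timeoutSec, intervalSec int64, reportInterval time.Duration) time.Duration {
	timeout := timeoutSec
	if timeout == 0 {
		if intervalSec != 0 {
			timeout = intervalSec
		} else if ri := int64(reportInterval.Seconds()); ri != 0 {
			timeout = ri * 3
		} else {
			timeout = 5
		}
	}
	return time.Duration(timeout) * time.Second
}

func main() {
	fmt.Println(collectionTimeout(2, 10, time.Second))         // 2s: explicit timeout wins
	fmt.Println(collectionTimeout(0, 10, time.Second))         // 10s: falls back to interval
	fmt.Println(collectionTimeout(0, 0, time.Second))          // 3s: three report intervals
	fmt.Println(collectionTimeout(0, 0, 100*time.Millisecond)) // 5s: floor
}
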
7 changes: 6 additions & 1 deletion agent/agent_test.go
@@ -1066,6 +1066,7 @@ func TestAgent_StartupScript(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
client := agenttest.NewClient(t,
logger,
uuid.New(),
agentsdk.Manifest{
StartupScript: command,
@@ -1097,6 +1098,7 @@ func TestAgent_StartupScript(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
client := agenttest.NewClient(t,
logger,
uuid.New(),
agentsdk.Manifest{
StartupScript: command,
@@ -1470,6 +1472,7 @@ func TestAgent_Lifecycle(t *testing.T) {
derpMap, _ := tailnettest.RunDERPAndSTUN(t)

client := agenttest.NewClient(t,
logger,
uuid.New(),
agentsdk.Manifest{
DERPMap: derpMap,
@@ -1742,6 +1745,7 @@ func TestAgent_Reconnect(t *testing.T) {
statsCh := make(chan *agentsdk.Stats, 50)
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
client := agenttest.NewClient(t,
logger,
agentID,
agentsdk.Manifest{
DERPMap: derpMap,
@@ -1776,6 +1780,7 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) {
defer coordinator.Close()

client := agenttest.NewClient(t,
logger,
uuid.New(),
agentsdk.Manifest{
GitAuthConfigs: 1,
@@ -1900,7 +1905,7 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
})
statsCh := make(chan *agentsdk.Stats, 50)
fs := afero.NewMemMapFs()
c := agenttest.NewClient(t, metadata.AgentID, metadata, statsCh, coordinator)
c := agenttest.NewClient(t, logger.Named("agent"), metadata.AgentID, metadata, statsCh, coordinator)

options := agent.Options{
Client: c,
21 changes: 15 additions & 6 deletions agent/agenttest/client.go
@@ -18,6 +18,7 @@ import (
)

func NewClient(t testing.TB,
logger slog.Logger,
agentID uuid.UUID,
manifest agentsdk.Manifest,
statsChan chan *agentsdk.Stats,
@@ -28,6 +29,7 @@ func NewClient(t testing.TB,
}
return &Client{
t: t,
logger: logger.Named("client"),
agentID: agentID,
manifest: manifest,
statsChan: statsChan,
@@ -37,6 +39,7 @@ func NewClient(t testing.TB,

type Client struct {
t testing.TB
logger slog.Logger
agentID uuid.UUID
manifest agentsdk.Manifest
metadata map[string]agentsdk.PostMetadataRequest
@@ -110,14 +113,16 @@ func (c *Client) GetLifecycleStates() []codersdk.WorkspaceAgentLifecycle {
return c.lifecycleStates
}

func (c *Client) PostLifecycle(_ context.Context, req agentsdk.PostLifecycleRequest) error {
func (c *Client) PostLifecycle(ctx context.Context, req agentsdk.PostLifecycleRequest) error {
c.mu.Lock()
defer c.mu.Unlock()
c.lifecycleStates = append(c.lifecycleStates, req.State)
c.logger.Debug(ctx, "post lifecycle", slog.F("req", req))
return nil
}

func (*Client) PostAppHealth(_ context.Context, _ agentsdk.PostAppHealthsRequest) error {
func (c *Client) PostAppHealth(ctx context.Context, req agentsdk.PostAppHealthsRequest) error {
c.logger.Debug(ctx, "post app health", slog.F("req", req))
return nil
}

@@ -133,20 +138,22 @@ func (c *Client) GetMetadata() map[string]agentsdk.PostMetadataRequest {
return maps.Clone(c.metadata)
}

func (c *Client) PostMetadata(_ context.Context, key string, req agentsdk.PostMetadataRequest) error {
func (c *Client) PostMetadata(ctx context.Context, key string, req agentsdk.PostMetadataRequest) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.metadata == nil {
c.metadata = make(map[string]agentsdk.PostMetadataRequest)
}
c.metadata[key] = req
c.logger.Debug(ctx, "post metadata", slog.F("key", key), slog.F("req", req))
return nil
}

func (c *Client) PostStartup(_ context.Context, startup agentsdk.PostStartupRequest) error {
func (c *Client) PostStartup(ctx context.Context, startup agentsdk.PostStartupRequest) error {
c.mu.Lock()
defer c.mu.Unlock()
c.startup = startup
c.logger.Debug(ctx, "post startup", slog.F("req", startup))
return nil
}

@@ -156,13 +163,14 @@ func (c *Client) GetStartupLogs() []agentsdk.StartupLog {
return c.logs
}

func (c *Client) PatchStartupLogs(_ context.Context, logs agentsdk.PatchStartupLogs) error {
func (c *Client) PatchStartupLogs(ctx context.Context, logs agentsdk.PatchStartupLogs) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.PatchWorkspaceLogs != nil {
return c.PatchWorkspaceLogs()
}
c.logs = append(c.logs, logs.Logs...)
c.logger.Debug(ctx, "patch startup logs", slog.F("req", logs))
return nil
}

@@ -173,9 +181,10 @@ func (c *Client) SetServiceBannerFunc(f func() (codersdk.ServiceBannerConfig, er
c.GetServiceBannerFunc = f
}

func (c *Client) GetServiceBanner(_ context.Context) (codersdk.ServiceBannerConfig, error) {
func (c *Client) GetServiceBanner(ctx context.Context) (codersdk.ServiceBannerConfig, error) {
c.mu.Lock()
defer c.mu.Unlock()
c.logger.Debug(ctx, "get service banner")
if c.GetServiceBannerFunc != nil {
return c.GetServiceBannerFunc()
}
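
The client.go changes follow a common test-double pattern: each fake API call is recorded under a mutex and also logged through an injected logger, so test output shows when and in what order the agent hit the fake endpoints. A minimal standalone sketch of that pattern with stand-in types (standard library only, not the coder SDK):

package main

import (
	"context"
	"log"
	"os"
	"sync"
)

// fakeClient records posted metadata and logs each call, mirroring the
// shape of the test client in the diff with stand-in types.
type fakeClient struct {
	mu       sync.Mutex
	logger   *log.Logger
	metadata map[string]string
}

func (c *fakeClient) PostMetadata(_ context.Context, key, value string) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.metadata == nil {
		c.metadata = make(map[string]string)
	}
	c.metadata[key] = value
	c.logger.Printf("post metadata key=%q value=%q", key, value)
	return nil
}

// GetMetadata returns a copy so callers can poll it from another
// goroutine without racing the writer.
func (c *fakeClient) GetMetadata() map[string]string {
	c.mu.Lock()
	defer c.mu.Unlock()
	out := make(map[string]string, len(c.metadata))
	for k, v := range c.metadata {
		out[k] = v
	}
	return out
}

func main() {
	c := &fakeClient{logger: log.New(os.Stderr, "client: ", log.LstdFlags)}
	_ = c.PostMetadata(context.Background(), "load", "0.1")
	log.Println(c.GetMetadata())
}
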
2 changes: 1 addition & 1 deletion coderd/tailnet_test.go
@@ -137,7 +137,7 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A
_ = coord.Close()
})

c := agenttest.NewClient(t, manifest.AgentID, manifest, make(chan *agentsdk.Stats, 50), coord)
c := agenttest.NewClient(t, logger, manifest.AgentID, manifest, make(chan *agentsdk.Stats, 50), coord)

options := agent.Options{
Client: c,