From 4ad81b0bb1875d8b6a15ff79a91ab490a7d2b5a5 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Fri, 3 May 2024 11:33:42 +0400 Subject: [PATCH] fix: make handleManifest always signal dependents --- agent/agent.go | 66 ++++++++++++++++--------------- agent/checkpoint.go | 51 ++++++++++++++++++++++++ agent/checkpoint_internal_test.go | 49 +++++++++++++++++++++++ 3 files changed, 134 insertions(+), 32 deletions(-) create mode 100644 agent/checkpoint.go create mode 100644 agent/checkpoint_internal_test.go diff --git a/agent/agent.go b/agent/agent.go index abaaed4c313c0..8125bbc5f70d6 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -807,23 +807,21 @@ func (a *agent) run() (retErr error) { // coordination <--------------------------+ // derp map subscriber <----------------+ // stats report loop <---------------+ - networkOK := make(chan struct{}) - manifestOK := make(chan struct{}) + networkOK := newCheckpoint(a.logger) + manifestOK := newCheckpoint(a.logger) connMan.start("handle manifest", gracefulShutdownBehaviorStop, a.handleManifest(manifestOK)) connMan.start("app health reporter", gracefulShutdownBehaviorStop, func(ctx context.Context, conn drpc.Conn) error { - select { - case <-ctx.Done(): - return nil - case <-manifestOK: - manifest := a.manifest.Load() - NewWorkspaceAppHealthReporter( - a.logger, manifest.Apps, agentsdk.AppHealthPoster(proto.NewDRPCAgentClient(conn)), - )(ctx) - return nil + if err := manifestOK.wait(ctx); err != nil { + return xerrors.Errorf("no manifest: %w", err) } + manifest := a.manifest.Load() + NewWorkspaceAppHealthReporter( + a.logger, manifest.Apps, agentsdk.AppHealthPoster(proto.NewDRPCAgentClient(conn)), + )(ctx) + return nil }) connMan.start("create or update network", gracefulShutdownBehaviorStop, @@ -831,10 +829,8 @@ func (a *agent) run() (retErr error) { connMan.start("coordination", gracefulShutdownBehaviorStop, func(ctx context.Context, conn drpc.Conn) error { - select { - case <-ctx.Done(): - return nil - case <-networkOK: + if err := networkOK.wait(ctx); err != nil { + return xerrors.Errorf("no network: %w", err) } return a.runCoordinator(ctx, conn, a.network) }, @@ -842,10 +838,8 @@ func (a *agent) run() (retErr error) { connMan.start("derp map subscriber", gracefulShutdownBehaviorStop, func(ctx context.Context, conn drpc.Conn) error { - select { - case <-ctx.Done(): - return nil - case <-networkOK: + if err := networkOK.wait(ctx); err != nil { + return xerrors.Errorf("no network: %w", err) } return a.runDERPMapSubscriber(ctx, conn, a.network) }) @@ -853,10 +847,8 @@ func (a *agent) run() (retErr error) { connMan.start("fetch service banner loop", gracefulShutdownBehaviorStop, a.fetchServiceBannerLoop) connMan.start("stats report loop", gracefulShutdownBehaviorStop, func(ctx context.Context, conn drpc.Conn) error { - select { - case <-ctx.Done(): - return nil - case <-networkOK: + if err := networkOK.wait(ctx); err != nil { + return xerrors.Errorf("no network: %w", err) } return a.statsReporter.reportLoop(ctx, proto.NewDRPCAgentClient(conn)) }) @@ -865,8 +857,17 @@ func (a *agent) run() (retErr error) { } // handleManifest returns a function that fetches and processes the manifest -func (a *agent) handleManifest(manifestOK chan<- struct{}) func(ctx context.Context, conn drpc.Conn) error { +func (a *agent) handleManifest(manifestOK *checkpoint) func(ctx context.Context, conn drpc.Conn) error { return func(ctx context.Context, conn drpc.Conn) error { + var ( + sentResult = false + err error + ) + defer func() { + if !sentResult { + manifestOK.complete(err) + } + }() aAPI := proto.NewDRPCAgentClient(conn) mp, err := aAPI.GetManifest(ctx, &proto.GetManifestRequest{}) if err != nil { @@ -907,7 +908,8 @@ func (a *agent) handleManifest(manifestOK chan<- struct{}) func(ctx context.Cont } oldManifest := a.manifest.Swap(&manifest) - close(manifestOK) + manifestOK.complete(nil) + sentResult = true // The startup script should only execute on the first run! if oldManifest == nil { @@ -968,14 +970,15 @@ func (a *agent) handleManifest(manifestOK chan<- struct{}) func(ctx context.Cont // createOrUpdateNetwork waits for the manifest to be set using manifestOK, then creates or updates // the tailnet using the information in the manifest -func (a *agent) createOrUpdateNetwork(manifestOK <-chan struct{}, networkOK chan<- struct{}) func(context.Context, drpc.Conn) error { - return func(ctx context.Context, _ drpc.Conn) error { - select { - case <-ctx.Done(): - return nil - case <-manifestOK: +func (a *agent) createOrUpdateNetwork(manifestOK, networkOK *checkpoint) func(context.Context, drpc.Conn) error { + return func(ctx context.Context, _ drpc.Conn) (retErr error) { + if err := manifestOK.wait(ctx); err != nil { + return xerrors.Errorf("no manifest: %w", err) } var err error + defer func() { + networkOK.complete(retErr) + }() manifest := a.manifest.Load() a.closeMutex.Lock() network := a.network @@ -1011,7 +1014,6 @@ func (a *agent) createOrUpdateNetwork(manifestOK <-chan struct{}, networkOK chan network.SetDERPForceWebSockets(manifest.DERPForceWebSockets) network.SetBlockEndpoints(manifest.DisableDirectConnections) } - close(networkOK) return nil } } diff --git a/agent/checkpoint.go b/agent/checkpoint.go new file mode 100644 index 0000000000000..3f6c7b2c6d299 --- /dev/null +++ b/agent/checkpoint.go @@ -0,0 +1,51 @@ +package agent + +import ( + "context" + "runtime" + "sync" + + "cdr.dev/slog" +) + +// checkpoint allows a goroutine to communicate when it is OK to proceed beyond some async condition +// to other dependent goroutines. +type checkpoint struct { + logger slog.Logger + mu sync.Mutex + called bool + done chan struct{} + err error +} + +// complete the checkpoint. Pass nil to indicate the checkpoint was ok. It is an error to call this +// more than once. +func (c *checkpoint) complete(err error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.called { + b := make([]byte, 2048) + n := runtime.Stack(b, false) + c.logger.Critical(context.Background(), "checkpoint complete called more than once", slog.F("stacktrace", b[:n])) + return + } + c.called = true + c.err = err + close(c.done) +} + +func (c *checkpoint) wait(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-c.done: + return c.err + } +} + +func newCheckpoint(logger slog.Logger) *checkpoint { + return &checkpoint{ + logger: logger, + done: make(chan struct{}), + } +} diff --git a/agent/checkpoint_internal_test.go b/agent/checkpoint_internal_test.go new file mode 100644 index 0000000000000..17567a0e3c587 --- /dev/null +++ b/agent/checkpoint_internal_test.go @@ -0,0 +1,49 @@ +package agent + +import ( + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/testutil" +) + +func TestCheckpoint_CompleteWait(t *testing.T) { + t.Parallel() + logger := slogtest.Make(t, nil) + ctx := testutil.Context(t, testutil.WaitShort) + uut := newCheckpoint(logger) + err := xerrors.New("test") + uut.complete(err) + got := uut.wait(ctx) + require.Equal(t, err, got) +} + +func TestCheckpoint_CompleteTwice(t *testing.T) { + t.Parallel() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + ctx := testutil.Context(t, testutil.WaitShort) + uut := newCheckpoint(logger) + err := xerrors.New("test") + uut.complete(err) + uut.complete(nil) // drops CRITICAL log + got := uut.wait(ctx) + require.Equal(t, err, got) +} + +func TestCheckpoint_WaitComplete(t *testing.T) { + t.Parallel() + logger := slogtest.Make(t, nil) + ctx := testutil.Context(t, testutil.WaitShort) + uut := newCheckpoint(logger) + err := xerrors.New("test") + errCh := make(chan error, 1) + go func() { + errCh <- uut.wait(ctx) + }() + uut.complete(err) + got := testutil.RequireRecvCtx(ctx, t, errCh) + require.Equal(t, err, got) +}