Skip to content

Commit 9392155

Browse files
committed
fix: make handleManifest always signal dependents
1 parent 2a73bb4 commit 9392155

File tree

3 files changed

+103
-34
lines changed

3 files changed

+103
-34
lines changed

agent/agent.go

Lines changed: 36 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -501,8 +501,8 @@ func (a *agent) reportMetadata(ctx context.Context, conn drpc.Conn) error {
501501
// mutex logic and overloading the API.
502502
for _, md := range manifest.Metadata {
503503
md := md
504-
// We send the result to the channel in the goroutine to avoid
505-
// sending the same result multiple times. So, we don't care about
504+
// We send the complete to the channel in the goroutine to avoid
505+
// sending the same complete multiple times. So, we don't care about
506506
// the return values.
507507
go flight.Do(md.Key, func() {
508508
ctx := slog.With(ctx, slog.F("key", md.Key))
@@ -807,56 +807,48 @@ func (a *agent) run() (retErr error) {
807807
// coordination <--------------------------+
808808
// derp map subscriber <----------------+
809809
// stats report loop <---------------+
810-
networkOK := make(chan struct{})
811-
manifestOK := make(chan struct{})
810+
networkOK := newCheckpoint()
811+
manifestOK := newCheckpoint()
812812

813813
connMan.start("handle manifest", gracefulShutdownBehaviorStop, a.handleManifest(manifestOK))
814814

815815
connMan.start("app health reporter", gracefulShutdownBehaviorStop,
816816
func(ctx context.Context, conn drpc.Conn) error {
817-
select {
818-
case <-ctx.Done():
819-
return nil
820-
case <-manifestOK:
821-
manifest := a.manifest.Load()
822-
NewWorkspaceAppHealthReporter(
823-
a.logger, manifest.Apps, agentsdk.AppHealthPoster(proto.NewDRPCAgentClient(conn)),
824-
)(ctx)
825-
return nil
817+
if err := manifestOK.waitCtx(ctx); err != nil {
818+
return xerrors.Errorf("no manifest: %w", err)
826819
}
820+
manifest := a.manifest.Load()
821+
NewWorkspaceAppHealthReporter(
822+
a.logger, manifest.Apps, agentsdk.AppHealthPoster(proto.NewDRPCAgentClient(conn)),
823+
)(ctx)
824+
return nil
827825
})
828826

829827
connMan.start("create or update network", gracefulShutdownBehaviorStop,
830828
a.createOrUpdateNetwork(manifestOK, networkOK))
831829

832830
connMan.start("coordination", gracefulShutdownBehaviorStop,
833831
func(ctx context.Context, conn drpc.Conn) error {
834-
select {
835-
case <-ctx.Done():
836-
return nil
837-
case <-networkOK:
832+
if err := networkOK.waitCtx(ctx); err != nil {
833+
return xerrors.Errorf("no network: %w", err)
838834
}
839835
return a.runCoordinator(ctx, conn, a.network)
840836
},
841837
)
842838

843839
connMan.start("derp map subscriber", gracefulShutdownBehaviorStop,
844840
func(ctx context.Context, conn drpc.Conn) error {
845-
select {
846-
case <-ctx.Done():
847-
return nil
848-
case <-networkOK:
841+
if err := networkOK.waitCtx(ctx); err != nil {
842+
return xerrors.Errorf("no network: %w", err)
849843
}
850844
return a.runDERPMapSubscriber(ctx, conn, a.network)
851845
})
852846

853847
connMan.start("fetch service banner loop", gracefulShutdownBehaviorStop, a.fetchServiceBannerLoop)
854848

855849
connMan.start("stats report loop", gracefulShutdownBehaviorStop, func(ctx context.Context, conn drpc.Conn) error {
856-
select {
857-
case <-ctx.Done():
858-
return nil
859-
case <-networkOK:
850+
if networkOK.waitCtx(ctx); err != nil {
851+
return xerrors.Errorf("no network: %w", err)
860852
}
861853
return a.statsReporter.reportLoop(ctx, proto.NewDRPCAgentClient(conn))
862854
})
@@ -865,8 +857,17 @@ func (a *agent) run() (retErr error) {
865857
}
866858

867859
// handleManifest returns a function that fetches and processes the manifest
868-
func (a *agent) handleManifest(manifestOK chan<- struct{}) func(ctx context.Context, conn drpc.Conn) error {
860+
func (a *agent) handleManifest(manifestOK *checkpoint) func(ctx context.Context, conn drpc.Conn) error {
869861
return func(ctx context.Context, conn drpc.Conn) error {
862+
var (
863+
sentResult = false
864+
err error
865+
)
866+
defer func() {
867+
if !sentResult {
868+
manifestOK.complete(err)
869+
}
870+
}()
870871
aAPI := proto.NewDRPCAgentClient(conn)
871872
mp, err := aAPI.GetManifest(ctx, &proto.GetManifestRequest{})
872873
if err != nil {
@@ -907,7 +908,8 @@ func (a *agent) handleManifest(manifestOK chan<- struct{}) func(ctx context.Cont
907908
}
908909

909910
oldManifest := a.manifest.Swap(&manifest)
910-
close(manifestOK)
911+
manifestOK.complete(nil)
912+
sentResult = true
911913

912914
// The startup script should only execute on the first run!
913915
if oldManifest == nil {
@@ -968,14 +970,15 @@ func (a *agent) handleManifest(manifestOK chan<- struct{}) func(ctx context.Cont
968970

969971
// createOrUpdateNetwork waits for the manifest to be set using manifestOK, then creates or updates
970972
// the tailnet using the information in the manifest
971-
func (a *agent) createOrUpdateNetwork(manifestOK <-chan struct{}, networkOK chan<- struct{}) func(context.Context, drpc.Conn) error {
972-
return func(ctx context.Context, _ drpc.Conn) error {
973-
select {
974-
case <-ctx.Done():
975-
return nil
976-
case <-manifestOK:
973+
func (a *agent) createOrUpdateNetwork(manifestOK, networkOK *checkpoint) func(context.Context, drpc.Conn) error {
974+
return func(ctx context.Context, _ drpc.Conn) (retErr error) {
975+
if err := manifestOK.waitCtx(ctx); err != nil {
976+
return xerrors.Errorf("no manifest: %w", err)
977977
}
978978
var err error
979+
defer func() {
980+
networkOK.complete(retErr)
981+
}()
979982
manifest := a.manifest.Load()
980983
a.closeMutex.Lock()
981984
network := a.network
@@ -1011,7 +1014,6 @@ func (a *agent) createOrUpdateNetwork(manifestOK <-chan struct{}, networkOK chan
10111014
network.SetDERPForceWebSockets(manifest.DERPForceWebSockets)
10121015
network.SetBlockEndpoints(manifest.DisableDirectConnections)
10131016
}
1014-
close(networkOK)
10151017
return nil
10161018
}
10171019
}

agent/checkpoint.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package agent
2+
3+
import (
4+
"context"
5+
)
6+
7+
// checkpoint allows a goroutine to communicate when it is OK to proceed beyond some async condition
8+
// to other dependent goroutines.
9+
type checkpoint struct {
10+
done chan struct{}
11+
err error
12+
}
13+
14+
// complete the checkpoint. Pass nil to indicate the checkpoint was ok.
15+
func (c *checkpoint) complete(err error) {
16+
c.err = err
17+
close(c.done)
18+
}
19+
20+
func (c *checkpoint) waitCtx(ctx context.Context) error {
21+
select {
22+
case <-ctx.Done():
23+
return ctx.Err()
24+
case <-c.done:
25+
return c.err
26+
}
27+
}
28+
29+
func newCheckpoint() *checkpoint {
30+
return &checkpoint{
31+
done: make(chan struct{}),
32+
}
33+
}

agent/checkpoint_internal_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package agent
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
"golang.org/x/xerrors"
8+
9+
"github.com/coder/coder/v2/testutil"
10+
)
11+
12+
func TestCheckpoint_CompleteWait(t *testing.T) {
13+
t.Parallel()
14+
ctx := testutil.Context(t, testutil.WaitShort)
15+
uut := newCheckpoint()
16+
err := xerrors.New("test")
17+
uut.complete(err)
18+
got := uut.waitCtx(ctx)
19+
require.Equal(t, err, got)
20+
}
21+
22+
func TestCheckpoint_WaitComplete(t *testing.T) {
23+
t.Parallel()
24+
ctx := testutil.Context(t, testutil.WaitShort)
25+
uut := newCheckpoint()
26+
err := xerrors.New("test")
27+
errCh := make(chan error, 1)
28+
go func() {
29+
errCh <- uut.waitCtx(ctx)
30+
}()
31+
uut.complete(err)
32+
got := testutil.RequireRecvCtx(ctx, t, errCh)
33+
require.Equal(t, err, got)
34+
}

0 commit comments

Comments
 (0)