From d476a8746dc7aea644b8c35d477b8736531f97b0 Mon Sep 17 00:00:00 2001
From: Spike Curtis
Date: Wed, 17 Jan 2024 15:48:48 +0400
Subject: [PATCH] feat: add setAllPeersLost to configMaps

---
 coderd/coderd_test.go               | 11 +---
 tailnet/configmaps.go               | 23 ++++++++
 tailnet/configmaps_internal_test.go | 85 +++++++++++++++++++++++++++++
 3 files changed, 111 insertions(+), 8 deletions(-)

diff --git a/coderd/coderd_test.go b/coderd/coderd_test.go
index 4c98feffb7546..a5f91fe6fd362 100644
--- a/coderd/coderd_test.go
+++ b/coderd/coderd_test.go
@@ -9,7 +9,6 @@ import (
 	"net/netip"
 	"strconv"
 	"strings"
-	"sync"
 	"sync/atomic"
 	"testing"
 
@@ -59,6 +58,7 @@ func TestBuildInfo(t *testing.T) {
 
 func TestDERP(t *testing.T) {
 	t.Parallel()
+	ctx := testutil.Context(t, testutil.WaitMedium)
 	client := coderdtest.New(t, nil)
 
 	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
@@ -97,8 +97,6 @@ func TestDERP(t *testing.T) {
 	})
 	require.NoError(t, err)
 
-	w2Ready := make(chan struct{})
-	w2ReadyOnce := sync.Once{}
 	w1ID := uuid.New()
 	w1.SetNodeCallback(func(node *tailnet.Node) {
 		pn, err := tailnet.NodeToProto(node)
@@ -110,9 +108,6 @@ func TestDERP(t *testing.T) {
 			Node: pn,
 			Kind: tailnetproto.CoordinateResponse_PeerUpdate_NODE,
 		}})
-		w2ReadyOnce.Do(func() {
-			close(w2Ready)
-		})
 	})
 	w2ID := uuid.New()
 	w2.SetNodeCallback(func(node *tailnet.Node) {
@@ -140,8 +135,8 @@ func TestDERP(t *testing.T) {
 	}()
 
 	<-conn
-	<-w2Ready
-	nc, err := w2.DialContextTCP(context.Background(), netip.AddrPortFrom(w1IP, 35565))
+	w2.AwaitReachable(ctx, w1IP)
+	nc, err := w2.DialContextTCP(ctx, netip.AddrPortFrom(w1IP, 35565))
 	require.NoError(t, err)
 	_ = nc.Close()
 	<-conn
diff --git a/tailnet/configmaps.go b/tailnet/configmaps.go
index 9c9fe7ee8d733..7579140c9f604 100644
--- a/tailnet/configmaps.go
+++ b/tailnet/configmaps.go
@@ -430,6 +430,29 @@ func (c *configMaps) updatePeerLocked(update *proto.CoordinateResponse_PeerUpdat
 	}
 }
 
+// setAllPeersLost marks all peers as lost. Typically, this is called when we lose connection to
+// the Coordinator. (When we reconnect, we will get NODE updates for all peers that are still
+// connected and mark them as not lost.)
+func (c *configMaps) setAllPeersLost() {
+	c.L.Lock()
+	defer c.L.Unlock()
+	for _, lc := range c.peers {
+		if lc.lost {
+			// skip processing already lost nodes, as this just results in timer churn
+			continue
+		}
+		lc.lost = true
+		lc.setLostTimer(c)
+		// it's important to drop a log here so that we see it get marked lost if grepping through
+		// the logs for a specific peer
+		c.logger.Debug(context.Background(),
+			"setAllPeersLost marked peer lost",
+			slog.F("peer_id", lc.peerID),
+			slog.F("key_id", lc.node.Key.ShortString()),
+		)
+	}
+}
+
 // peerLostTimeout is the callback that peerLifecycle uses when a peer is lost and the timeout to
 // receive a handshake fires.
 func (c *configMaps) peerLostTimeout(id uuid.UUID) {
diff --git a/tailnet/configmaps_internal_test.go b/tailnet/configmaps_internal_test.go
index bf04cd8378b76..a6921f939713e 100644
--- a/tailnet/configmaps_internal_test.go
+++ b/tailnet/configmaps_internal_test.go
@@ -491,6 +491,91 @@ func TestConfigMaps_updatePeers_lost_and_found(t *testing.T) {
 	_ = testutil.RequireRecvCtx(ctx, t, done)
 }
 
+func TestConfigMaps_setAllPeersLost(t *testing.T) {
+	t.Parallel()
+	ctx := testutil.Context(t, testutil.WaitShort)
+	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+	fEng := newFakeEngineConfigurable()
+	nodePrivateKey := key.NewNode()
+	nodeID := tailcfg.NodeID(5)
+	discoKey := key.NewDisco()
+	uut := newConfigMaps(logger, fEng, nodeID, nodePrivateKey, discoKey.Public())
+	defer uut.close()
+	start := time.Date(2024, time.January, 1, 8, 0, 0, 0, time.UTC)
+	mClock := clock.NewMock()
+	mClock.Set(start)
+	uut.clock = mClock
+
+	p1ID := uuid.UUID{1}
+	p1Node := newTestNode(1)
+	p1n, err := NodeToProto(p1Node)
+	require.NoError(t, err)
+	p2ID := uuid.UUID{2}
+	p2Node := newTestNode(2)
+	p2n, err := NodeToProto(p2Node)
+	require.NoError(t, err)
+
+	s1 := expectStatusWithHandshake(ctx, t, fEng, p1Node.Key, start)
+
+	updates := []*proto.CoordinateResponse_PeerUpdate{
+		{
+			Id:   p1ID[:],
+			Kind: proto.CoordinateResponse_PeerUpdate_NODE,
+			Node: p1n,
+		},
+		{
+			Id:   p2ID[:],
+			Kind: proto.CoordinateResponse_PeerUpdate_NODE,
+			Node: p2n,
+		},
+	}
+	uut.updatePeers(updates)
+	nm := testutil.RequireRecvCtx(ctx, t, fEng.setNetworkMap)
+	r := testutil.RequireRecvCtx(ctx, t, fEng.reconfig)
+	require.Len(t, nm.Peers, 2)
+	require.Len(t, r.wg.Peers, 2)
+	_ = testutil.RequireRecvCtx(ctx, t, s1)
+
+	mClock.Add(5 * time.Second)
+	uut.setAllPeersLost()
+
+	// No reprogramming yet, since we keep the peer around.
+	select {
+	case <-fEng.setNetworkMap:
+		t.Fatal("should not reprogram")
+	default:
+		// OK!
+	}
+
+	// When we advance the clock, even by a few ms, the timeout for peer 2 pops
+	// because our status only includes a handshake for peer 1.
+	s2 := expectStatusWithHandshake(ctx, t, fEng, p1Node.Key, start)
+	mClock.Add(time.Millisecond * 10)
+	_ = testutil.RequireRecvCtx(ctx, t, s2)
+
+	nm = testutil.RequireRecvCtx(ctx, t, fEng.setNetworkMap)
+	r = testutil.RequireRecvCtx(ctx, t, fEng.reconfig)
+	require.Len(t, nm.Peers, 1)
+	require.Len(t, r.wg.Peers, 1)
+
+	// Finally, advance the clock until after the timeout.
+	s3 := expectStatusWithHandshake(ctx, t, fEng, p1Node.Key, start)
+	mClock.Add(lostTimeout)
+	_ = testutil.RequireRecvCtx(ctx, t, s3)
+
+	nm = testutil.RequireRecvCtx(ctx, t, fEng.setNetworkMap)
+	r = testutil.RequireRecvCtx(ctx, t, fEng.reconfig)
+	require.Len(t, nm.Peers, 0)
+	require.Len(t, r.wg.Peers, 0)
+
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		uut.close()
+	}()
+	_ = testutil.RequireRecvCtx(ctx, t, done)
+}
+
 func TestConfigMaps_setBlockEndpoints_different(t *testing.T) {
 	t.Parallel()
 	ctx := testutil.Context(t, testutil.WaitShort)
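
Note (illustrative sketch, not part of the patch): the setAllPeersLost doc comment describes the intended call pattern — mark everything lost when the Coordinator connection drops, then let NODE updates on the next connection clear the lost flags while unrecovered peers age out via their lost timers. The self-contained Go sketch below shows that pattern under stated assumptions; the names peerConfigurer, superviseCoordination, and dialCoordinator are hypothetical and stand in for however a real caller wires configMaps to its coordinator client, with []string standing in for []*proto.CoordinateResponse_PeerUpdate.

package main

import (
	"context"
	"fmt"
	"time"
)

// peerConfigurer is a hypothetical stand-in for the subset of configMaps
// behavior this patch touches.
type peerConfigurer interface {
	// updatePeers applies coordinator updates; a NODE update for a peer
	// clears that peer's lost flag.
	updatePeers(updates []string)
	// setAllPeersLost marks every known peer lost, starting lost timers.
	setAllPeersLost()
}

// superviseCoordination reconnects to the coordinator in a loop. Whenever a
// connection ends, it marks all peers lost; peers that are still connected
// get re-announced as NODE updates on the next connection and recover, while
// the rest are torn down once their lost timers (lostTimeout in
// configmaps.go) fire.
func superviseCoordination(
	ctx context.Context,
	cfg peerConfigurer,
	dialCoordinator func(context.Context) (<-chan []string, error),
) {
	for ctx.Err() == nil {
		updates, err := dialCoordinator(ctx)
		if err != nil {
			time.Sleep(time.Second) // crude backoff; real code would do better
			continue
		}
		for batch := range updates {
			cfg.updatePeers(batch)
		}
		// The update channel closed: we lost the coordinator, so we can no
		// longer trust that our picture of the peers is current.
		cfg.setAllPeersLost()
	}
}

func main() {
	fmt.Println("sketch only; see superviseCoordination")
}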