Commit b7b9365

feat: add setAllPeersLost to the configMaps subcomponent (#11665)
Adds setAllPeersLost to the configMaps subcomponent of tailnet.Conn. We'll call this when we disconnect from a coordinator, so that if peers disconnect while we are retrying the coordinator connection (or we never succeed in reconnecting), we still eventually clean them up.
1 parent f01cab9 commit b7b9365
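
The call site is not part of this commit. As a rough sketch of the intended wiring (the reconnect loop and the coordinateUntilError helper below are illustrative, not actual tailnet code; only setAllPeersLost and the configMaps subcomponent come from this change), the coordinator goroutine would mark all peers lost whenever the connection drops, then retry:

// Hypothetical sketch of the intended call site, not code from this commit.
func (c *Conn) runCoordinator(ctx context.Context) {
	for {
		// coordinateUntilError is an illustrative helper that blocks
		// until the coordinator connection fails.
		if err := c.coordinateUntilError(ctx); err != nil && ctx.Err() != nil {
			return // shutting down
		}
		// While disconnected we can no longer hear about peers leaving,
		// so start their lost-timers; a successful reconnect delivers
		// NODE updates that mark still-connected peers as found again.
		c.configMaps.setAllPeersLost()
		// ...back off, then retry the coordinator connection...
	}
}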

3 files changed: +111 -8

coderd/coderd_test.go (+3 -8)
@@ -9,7 +9,6 @@ import (
 	"net/netip"
 	"strconv"
 	"strings"
-	"sync"
 	"sync/atomic"
 	"testing"
 

@@ -59,6 +58,7 @@ func TestBuildInfo(t *testing.T) {
 
 func TestDERP(t *testing.T) {
 	t.Parallel()
+	ctx := testutil.Context(t, testutil.WaitMedium)
 	client := coderdtest.New(t, nil)
 
 	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)

@@ -97,8 +97,6 @@ func TestDERP(t *testing.T) {
 	})
 	require.NoError(t, err)
 
-	w2Ready := make(chan struct{})
-	w2ReadyOnce := sync.Once{}
 	w1ID := uuid.New()
 	w1.SetNodeCallback(func(node *tailnet.Node) {
 		pn, err := tailnet.NodeToProto(node)

@@ -110,9 +108,6 @@ func TestDERP(t *testing.T) {
 			Node: pn,
 			Kind: tailnetproto.CoordinateResponse_PeerUpdate_NODE,
 		}})
-		w2ReadyOnce.Do(func() {
-			close(w2Ready)
-		})
 	})
 	w2ID := uuid.New()
 	w2.SetNodeCallback(func(node *tailnet.Node) {

@@ -140,8 +135,8 @@ func TestDERP(t *testing.T) {
 	}()
 
 	<-conn
-	<-w2Ready
-	nc, err := w2.DialContextTCP(context.Background(), netip.AddrPortFrom(w1IP, 35565))
+	w2.AwaitReachable(ctx, w1IP)
+	nc, err := w2.DialContextTCP(ctx, netip.AddrPortFrom(w1IP, 35565))
 	require.NoError(t, err)
 	_ = nc.Close()
 	<-conn
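
The readiness change above replaces a channel that only signaled that w1's node callback had run with an active wait for reachability before dialing. Assuming AwaitReachable pings the address until it answers or the context expires, and reports the outcome as a bool (its signature is not shown in this diff), a caller that prefers a crisp failure over a blocked dial could check the result:

// Sketch, assuming AwaitReachable returns whether the peer became
// reachable before ctx expired (signature not shown in this diff).
if !w2.AwaitReachable(ctx, w1IP) {
	t.Fatal("w1 never became reachable over the tunnel")
}
nc, err := w2.DialContextTCP(ctx, netip.AddrPortFrom(w1IP, 35565))
require.NoError(t, err)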

tailnet/configmaps.go (+23 -0)
@@ -430,6 +430,29 @@ func (c *configMaps) updatePeerLocked(update *proto.CoordinateResponse_PeerUpdate
 	}
 }
 
+// setAllPeersLost marks all peers as lost. Typically, this is called when we lose connection to
+// the Coordinator. (When we reconnect, we will get NODE updates for all peers that are still connected
+// and mark them as not lost.)
+func (c *configMaps) setAllPeersLost() {
+	c.L.Lock()
+	defer c.L.Unlock()
+	for _, lc := range c.peers {
+		if lc.lost {
+			// skip processing already lost nodes, as this just results in timer churn
+			continue
+		}
+		lc.lost = true
+		lc.setLostTimer(c)
+		// it's important to drop a log here so that we see it get marked lost if grepping through
+		// the logs for a specific peer
+		c.logger.Debug(context.Background(),
+			"setAllPeersLost marked peer lost",
+			slog.F("peer_id", lc.peerID),
+			slog.F("key_id", lc.node.Key.ShortString()),
+		)
+	}
+}
+
 // peerLostTimeout is the callback that peerLifecycle uses when a peer is lost and the timeout to
 // receive a handshake fires.
 func (c *configMaps) peerLostTimeout(id uuid.UUID) {
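
setAllPeersLost only arms timers; the actual cleanup happens when the lost timer fires without an intervening NODE update. A standalone model of that lifecycle, using only the standard library (peerMap, markAlive, and the 15-minute lostTimeout value here are illustrative stand-ins, not the tailnet types):

// Standalone sketch of the lost-peer lifecycle, not the actual tailnet code.
package main

import (
	"fmt"
	"sync"
	"time"
)

const lostTimeout = 15 * time.Minute // illustrative; the real constant lives in the tailnet package

type peer struct {
	lost  bool
	timer *time.Timer
}

type peerMap struct {
	mu    sync.Mutex
	peers map[string]*peer
}

// setAllPeersLost mirrors the shape of the new method: mark every peer
// lost and arm a removal timer, skipping peers that are already lost.
func (m *peerMap) setAllPeersLost() {
	m.mu.Lock()
	defer m.mu.Unlock()
	for id, p := range m.peers {
		if p.lost {
			continue // already lost; rearming the timer is just churn
		}
		p.lost = true
		id := id
		p.timer = time.AfterFunc(lostTimeout, func() { m.lostTimeout(id) })
	}
}

// markAlive models receiving a NODE update after reconnecting to the
// coordinator: the peer is no longer lost, so cancel its removal timer.
func (m *peerMap) markAlive(id string) {
	m.mu.Lock()
	defer m.mu.Unlock()
	if p, ok := m.peers[id]; ok && p.lost {
		p.lost = false
		p.timer.Stop()
	}
}

// lostTimeout removes a peer that stayed lost for the whole grace period.
func (m *peerMap) lostTimeout(id string) {
	m.mu.Lock()
	defer m.mu.Unlock()
	if p, ok := m.peers[id]; ok && p.lost {
		delete(m.peers, id)
		fmt.Println("removed lost peer", id)
	}
}

func main() {
	m := &peerMap{peers: map[string]*peer{"p1": {}, "p2": {}}}
	m.setAllPeersLost()
	m.markAlive("p1") // p1 survives; p2 would be removed after lostTimeout
	// (this program exits before the timer fires)
}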

tailnet/configmaps_internal_test.go (+85 -0)
@@ -491,6 +491,91 @@ func TestConfigMaps_updatePeers_lost_and_found(t *testing.T) {
 	_ = testutil.RequireRecvCtx(ctx, t, done)
 }
 
+func TestConfigMaps_setAllPeersLost(t *testing.T) {
+	t.Parallel()
+	ctx := testutil.Context(t, testutil.WaitShort)
+	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+	fEng := newFakeEngineConfigurable()
+	nodePrivateKey := key.NewNode()
+	nodeID := tailcfg.NodeID(5)
+	discoKey := key.NewDisco()
+	uut := newConfigMaps(logger, fEng, nodeID, nodePrivateKey, discoKey.Public())
+	defer uut.close()
+	start := time.Date(2024, time.January, 1, 8, 0, 0, 0, time.UTC)
+	mClock := clock.NewMock()
+	mClock.Set(start)
+	uut.clock = mClock
+
+	p1ID := uuid.UUID{1}
+	p1Node := newTestNode(1)
+	p1n, err := NodeToProto(p1Node)
+	require.NoError(t, err)
+	p2ID := uuid.UUID{2}
+	p2Node := newTestNode(2)
+	p2n, err := NodeToProto(p2Node)
+	require.NoError(t, err)
+
+	s1 := expectStatusWithHandshake(ctx, t, fEng, p1Node.Key, start)
+
+	updates := []*proto.CoordinateResponse_PeerUpdate{
+		{
+			Id:   p1ID[:],
+			Kind: proto.CoordinateResponse_PeerUpdate_NODE,
+			Node: p1n,
+		},
+		{
+			Id:   p2ID[:],
+			Kind: proto.CoordinateResponse_PeerUpdate_NODE,
+			Node: p2n,
+		},
+	}
+	uut.updatePeers(updates)
+	nm := testutil.RequireRecvCtx(ctx, t, fEng.setNetworkMap)
+	r := testutil.RequireRecvCtx(ctx, t, fEng.reconfig)
+	require.Len(t, nm.Peers, 2)
+	require.Len(t, r.wg.Peers, 2)
+	_ = testutil.RequireRecvCtx(ctx, t, s1)
+
+	mClock.Add(5 * time.Second)
+	uut.setAllPeersLost()
+
+	// No reprogramming yet, since we keep the peer around.
+	select {
+	case <-fEng.setNetworkMap:
+		t.Fatal("should not reprogram")
+	default:
+		// OK!
+	}
+
+	// When we advance the clock, even by a few ms, the timeout for peer 2 pops
+	// because our status only includes a handshake for peer 1.
+	s2 := expectStatusWithHandshake(ctx, t, fEng, p1Node.Key, start)
+	mClock.Add(time.Millisecond * 10)
+	_ = testutil.RequireRecvCtx(ctx, t, s2)
+
+	nm = testutil.RequireRecvCtx(ctx, t, fEng.setNetworkMap)
+	r = testutil.RequireRecvCtx(ctx, t, fEng.reconfig)
+	require.Len(t, nm.Peers, 1)
+	require.Len(t, r.wg.Peers, 1)
+
+	// Finally, advance the clock until after the timeout.
+	s3 := expectStatusWithHandshake(ctx, t, fEng, p1Node.Key, start)
+	mClock.Add(lostTimeout)
+	_ = testutil.RequireRecvCtx(ctx, t, s3)
+
+	nm = testutil.RequireRecvCtx(ctx, t, fEng.setNetworkMap)
+	r = testutil.RequireRecvCtx(ctx, t, fEng.reconfig)
+	require.Len(t, nm.Peers, 0)
+	require.Len(t, r.wg.Peers, 0)
+
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		uut.close()
+	}()
+	_ = testutil.RequireRecvCtx(ctx, t, done)
+}
+
 func TestConfigMaps_setBlockEndpoints_different(t *testing.T) {
 	t.Parallel()
 	ctx := testutil.Context(t, testutil.WaitShort)
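
The test drives all of this deterministically because configMaps takes its clock as a dependency. The mock's NewMock/Set/Add calls match the github.com/benbjohnson/clock package, though the import is not shown in this diff; a minimal sketch of the same pattern, under that assumption:

// Minimal sketch of mock-clock-driven timer testing, assuming
// github.com/benbjohnson/clock is the package in use here.
package main

import (
	"fmt"
	"time"

	"github.com/benbjohnson/clock"
)

func main() {
	mClock := clock.NewMock()
	mClock.Set(time.Date(2024, time.January, 1, 8, 0, 0, 0, time.UTC))

	fired := make(chan struct{})
	// The timer is scheduled against the mock's notion of "now", so the
	// test controls exactly when it fires.
	mClock.AfterFunc(15*time.Minute, func() { close(fired) })

	mClock.Add(5 * time.Second) // deadline not reached yet
	select {
	case <-fired:
		fmt.Println("unexpected: fired early")
	default:
		fmt.Println("not fired yet")
	}

	mClock.Add(15 * time.Minute) // jump past the deadline
	<-fired
	fmt.Println("fired deterministically")
}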
