
Commit 76ec6f7

chore: add support for peer updates to tailnet.configMaps
1 parent b03a760 commit 76ec6f7

File tree

2 files changed: +801 -21 lines


tailnet/configmaps.go

+228 -10
@@ -3,11 +3,14 @@ package tailnet
 import (
 	"context"
 	"errors"
+	"fmt"
 	"net/netip"
 	"sync"
+	"time"

 	"github.com/google/uuid"
 	"go4.org/netipx"
+	"tailscale.com/ipn/ipnstate"
 	"tailscale.com/net/dns"
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/ipproto"
@@ -23,10 +26,13 @@ import (
 	"github.com/coder/coder/v2/tailnet/proto"
 )

+const lostTimeout = 15 * time.Minute
+
 // engineConfigurable is the subset of wgengine.Engine that we use for configuration.
 //
 // This allows us to test configuration code without faking the whole interface.
 type engineConfigurable interface {
+	UpdateStatus(*ipnstate.StatusBuilder)
 	SetNetworkMap(*netmap.NetworkMap)
 	Reconfig(*wgcfg.Config, *router.Config, *dns.Config, *tailcfg.Debug) error
 	SetDERPMap(*tailcfg.DERPMap)
@@ -41,6 +47,20 @@ const (
 	closed
 )

+type clock interface {
+	Now() time.Time
+	AfterFunc(d time.Duration, f func()) stopper
+}
+
+type stopper interface {
+	Stop() bool
+}
+
+type stdClock struct{}
+
+func (stdClock) Now() time.Time { return time.Now() }
+func (stdClock) AfterFunc(d time.Duration, f func()) stopper { return time.AfterFunc(d, f) }
+
 type configMaps struct {
 	sync.Cond
 	netmapDirty bool
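
The clock and stopper interfaces above are the seam that lets tests control time instead of waiting out the real 15-minute lostTimeout; stdClock is the production implementation. A minimal in-package test double could look like the sketch below. fakeClock and fakeTimer are names invented here for illustration, not part of this commit, and the commit's own tests may well use a different helper.

type fakeTimer struct {
	f       func()
	stopped bool
}

func (t *fakeTimer) Stop() bool {
	if t.stopped {
		return false
	}
	t.stopped = true
	return true
}

type fakeClock struct {
	now    time.Time
	timers []*fakeTimer
}

func (c *fakeClock) Now() time.Time { return c.now }

func (c *fakeClock) AfterFunc(d time.Duration, f func()) stopper {
	t := &fakeTimer{f: f}
	c.timers = append(c.timers, t)
	return t
}

// advance moves fake time forward and fires every timer that is still armed.
// (A fuller fake would track per-timer deadlines; this sketch fires them all.)
func (c *fakeClock) advance(d time.Duration) {
	c.now = c.now.Add(d)
	for _, t := range c.timers {
		if !t.stopped {
			t.stopped = true
			t.f()
		}
	}
}

In a test, a *fakeClock can be assigned to the configMaps clock field after construction, and advance(lostTimeout) then forces the lost-peer timeout path without any real waiting.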
@@ -49,12 +69,16 @@ type configMaps struct {
 	closing bool
 	phase phase

-	engine    engineConfigurable
-	static    netmap.NetworkMap
-	peers     map[uuid.UUID]*peerLifecycle
-	addresses []netip.Prefix
-	derpMap   *proto.DERPMap
-	logger    slog.Logger
+	engine         engineConfigurable
+	static         netmap.NetworkMap
+	peers          map[uuid.UUID]*peerLifecycle
+	addresses      []netip.Prefix
+	derpMap        *proto.DERPMap
+	logger         slog.Logger
+	blockEndpoints bool
+
+	// for testing
+	clock clock
 }

 func newConfigMaps(logger slog.Logger, engine engineConfigurable, nodeID tailcfg.NodeID, nodeKey key.NodePrivate, discoKey key.DiscoPublic, addresses []netip.Prefix) *configMaps {
@@ -101,6 +125,7 @@ func newConfigMaps(logger slog.Logger, engine engineConfigurable, nodeID tailcfg
 		},
 		peers:     make(map[uuid.UUID]*peerLifecycle),
 		addresses: addresses,
+		clock:     stdClock{},
 	}
 	go c.configLoop()
 	return c
@@ -165,6 +190,9 @@ func (c *configMaps) configLoop() {
 func (c *configMaps) close() {
 	c.L.Lock()
 	defer c.L.Unlock()
+	for _, lc := range c.peers {
+		lc.resetTimer()
+	}
 	c.closing = true
 	c.Broadcast()
 	for c.phase != closed {
@@ -248,11 +276,201 @@ func (c *configMaps) filterLocked() *filter.Filter {
 	)
 }

+func (c *configMaps) updatePeers(updates []*proto.CoordinateResponse_PeerUpdate) {
+	status := c.status()
+	c.L.Lock()
+	defer c.L.Unlock()
+
+	// Update all the lastHandshake values here. That way we don't have to
+	// worry about them being up-to-date when handling updates below, and it covers
+	// all peers, not just the ones we got updates about.
+	for _, lc := range c.peers {
+		if peerStatus, ok := status.Peer[lc.node.Key]; ok {
+			lc.lastHandshake = peerStatus.LastHandshake
+		}
+	}
+
+	for _, update := range updates {
+		if dirty := c.updatePeerLocked(update, status); dirty {
+			c.netmapDirty = true
+		}
+	}
+	if c.netmapDirty {
+		c.Broadcast()
+	}
+}
+
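updatePeers is the entry point a coordination client would call as updates arrive from the coordinator. That call site is not part of this diff, so the wiring below is only a sketch: handleCoordinateResponse is a made-up name, and it assumes the CoordinateResponse proto exposes its updates through a GetPeerUpdates() getter.

// Sketch only: resp is assumed to be a *proto.CoordinateResponse received from
// the coordinator stream; real code would run this inside the client's receive loop.
func handleCoordinateResponse(c *configMaps, resp *proto.CoordinateResponse) {
	// Each response may carry several peer updates (NODE, DISCONNECTED, LOST);
	// configMaps applies them under its lock and, if the netmap changed,
	// wakes the config loop once via Broadcast.
	c.updatePeers(resp.GetPeerUpdates())
}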
+func (c *configMaps) status() *ipnstate.Status {
+	sb := &ipnstate.StatusBuilder{WantPeers: true}
+	c.engine.UpdateStatus(sb)
+	return sb.Status()
+}
+
+func (c *configMaps) updatePeerLocked(update *proto.CoordinateResponse_PeerUpdate, status *ipnstate.Status) (dirty bool) {
+	id, err := uuid.FromBytes(update.Id)
+	if err != nil {
+		c.logger.Critical(context.Background(), "received update with bad id", slog.F("id", update.Id))
+		return false
+	}
+	logger := c.logger.With(slog.F("peer_id", id))
+	lc, ok := c.peers[id]
+	var node *tailcfg.Node
+	if update.Kind == proto.CoordinateResponse_PeerUpdate_NODE {
+		// If no preferred DERP is provided, we can't reach the node.
+		if update.Node.PreferredDerp == 0 {
+			logger.Warn(context.Background(), "no preferred DERP, peer update", slog.F("node_proto", update.Node))
+			return false
+		}
+		node, err = c.protoNodeToTailcfg(update.Node)
+		if err != nil {
+			logger.Critical(context.Background(), "failed to convert proto node to tailcfg", slog.F("node_proto", update.Node))
+			return false
+		}
+		logger = logger.With(slog.F("key_id", node.Key.ShortString()), slog.F("node", node))
+		peerStatus, ok := status.Peer[node.Key]
+		// Starting KeepAlive messages at the initialization of a connection
+		// causes a race condition. If we send the handshake before the peer has
+		// our node, we'll have to wait for 5 seconds before trying again.
+		// Ideally, the first handshake starts when the user first initiates a
+		// connection to the peer. After a successful connection we enable
+		// keep alives to persist the connection and keep it from becoming idle.
+		// SSH connections don't send packets while idle, so we use keep alives
+		// to avoid random hangs while we set up the connection again after
+		// inactivity.
+		node.KeepAlive = ok && peerStatus.Active
+		if c.blockEndpoints {
+			node.Endpoints = nil
+		}
+	}
+	switch {
+	case !ok && update.Kind == proto.CoordinateResponse_PeerUpdate_NODE:
+		// new!
+		var lastHandshake time.Time
+		if ps, ok := status.Peer[node.Key]; ok {
+			lastHandshake = ps.LastHandshake
+		}
+		c.peers[id] = &peerLifecycle{
+			peerID:        id,
+			node:          node,
+			lastHandshake: lastHandshake,
+			lost:          false,
+		}
+		logger.Debug(context.Background(), "adding new peer")
+		return true
+	case ok && update.Kind == proto.CoordinateResponse_PeerUpdate_NODE:
+		// update
+		node.Created = lc.node.Created
+		dirty = !lc.node.Equal(node)
+		lc.node = node
+		lc.lost = false
+		lc.resetTimer()
+		logger.Debug(context.Background(), "node update to existing peer", slog.F("dirty", dirty))
+		return dirty
+	case !ok:
+		// disconnected or lost, but we don't have the node. No op
+		logger.Debug(context.Background(), "skipping update for peer we don't recognize")
+		return false
+	case update.Kind == proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
+		lc.resetTimer()
+		delete(c.peers, id)
+		logger.Debug(context.Background(), "disconnected peer")
+		return true
+	case update.Kind == proto.CoordinateResponse_PeerUpdate_LOST:
+		lc.lost = true
+		lc.setLostTimer(c)
+		logger.Debug(context.Background(), "marked peer lost")
+		// marking a node lost doesn't change anything right now, so dirty=false
+		return false
+	default:
+		logger.Warn(context.Background(), "unknown peer update", slog.F("kind", update.Kind))
+		return false
+	}
+}
+
+func (c *configMaps) peerLostTimeout(id uuid.UUID) {
+	logger := c.logger.With(slog.F("peer_id", id))
+	logger.Debug(context.Background(),
+		"peer lost timeout")
+
+	// First do a status update to see if the peer did a handshake while we were
+	// waiting
+	status := c.status()
+	c.L.Lock()
+	defer c.L.Unlock()
+
+	lc, ok := c.peers[id]
+	if !ok {
+		logger.Debug(context.Background(),
+			"timeout triggered for peer that is removed from the map")
+		return
+	}
+	if peerStatus, ok := status.Peer[lc.node.Key]; ok {
+		lc.lastHandshake = peerStatus.LastHandshake
+	}
+	logger = logger.With(slog.F("key_id", lc.node.Key.ShortString()))
+	if !lc.lost {
+		logger.Debug(context.Background(),
+			"timeout triggered for peer that is no longer lost")
+		return
+	}
+	since := c.clock.Now().Sub(lc.lastHandshake)
+	if since >= lostTimeout {
+		logger.Info(
+			context.Background(), "removing lost peer")
+		delete(c.peers, id)
+		c.netmapDirty = true
+		c.Broadcast()
+		return
+	}
+	logger.Debug(context.Background(),
+		"timeout triggered for peer but it had handshake in meantime")
+	lc.setLostTimer(c)
+}
+
+func (c *configMaps) protoNodeToTailcfg(p *proto.Node) (*tailcfg.Node, error) {
+	node, err := ProtoToNode(p)
+	if err != nil {
+		return nil, err
+	}
+	return &tailcfg.Node{
+		ID:         tailcfg.NodeID(p.GetId()),
+		Created:    c.clock.Now(),
+		Key:        node.Key,
+		DiscoKey:   node.DiscoKey,
+		Addresses:  node.Addresses,
+		AllowedIPs: node.AllowedIPs,
+		Endpoints:  node.Endpoints,
+		DERP:       fmt.Sprintf("%s:%d", tailcfg.DerpMagicIP, node.PreferredDERP),
+		Hostinfo:   (&tailcfg.Hostinfo{}).View(),
+	}, nil
+}
+
 type peerLifecycle struct {
-	node *tailcfg.Node
-	// TODO: implement timers to track lost peers
-	// lastHandshake time.Time
-	// timer time.Timer
+	peerID        uuid.UUID
+	node          *tailcfg.Node
+	lost          bool
+	lastHandshake time.Time
+	timer         stopper
+}
+
+func (l *peerLifecycle) resetTimer() {
+	if l.timer != nil {
+		l.timer.Stop()
+		l.timer = nil
+	}
+}
+
+func (l *peerLifecycle) setLostTimer(c *configMaps) {
+	if l.timer != nil {
+		l.timer.Stop()
+	}
+	ttl := lostTimeout - c.clock.Now().Sub(l.lastHandshake)
+	if ttl <= 0 {
+		ttl = time.Nanosecond
+	}
+	l.timer = c.clock.AfterFunc(ttl, func() {
+		c.peerLostTimeout(l.peerID)
+	})
 }

 // prefixesDifferent returns true if the two slices contain different prefixes
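
The lost-peer handling added above reduces to a small piece of arithmetic: a peer marked LOST is only removed once lostTimeout (15 minutes) has elapsed since its last handshake, and setLostTimer arms the timer for whatever remains of that window. A worked sketch of that computation, with the ten-minute figure and the helper name invented for illustration:

// illustrateLostTTL is a throwaway helper showing how setLostTimer derives the
// timer duration; the ten-minute figure is invented for illustration.
func illustrateLostTTL() time.Duration {
	now := time.Now()
	lastHandshake := now.Add(-10 * time.Minute) // pretend the last handshake was 10 minutes ago

	ttl := lostTimeout - now.Sub(lastHandshake) // 15m - 10m = 5m left before removal
	if ttl <= 0 {
		// Handshake already older than lostTimeout: fire (almost) immediately;
		// peerLostTimeout still re-checks live status before removing the peer.
		ttl = time.Nanosecond
	}
	return ttl
}

When the timer fires, peerLostTimeout refreshes lastHandshake from the engine status; if a handshake happened in the meantime it re-arms the timer instead of dropping the peer, and if the peer is no longer marked lost the timeout is a no-op.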
