@@ -3,11 +3,14 @@ package tailnet
 import (
 	"context"
 	"errors"
+	"fmt"
 	"net/netip"
 	"sync"
+	"time"
 
 	"github.com/google/uuid"
 	"go4.org/netipx"
+	"tailscale.com/ipn/ipnstate"
 	"tailscale.com/net/dns"
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/ipproto"
@@ -23,10 +26,13 @@ import (
 	"github.com/coder/coder/v2/tailnet/proto"
 )
 
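+// lostTimeout is how long a peer may remain lost, measured from its last
+// WireGuard handshake, before we remove it from the netmap.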
+const lostTimeout = 15 * time.Minute
+
 // engineConfigurable is the subset of wgengine.Engine that we use for configuration.
 //
 // This allows us to test configuration code without faking the whole interface.
 type engineConfigurable interface {
+	UpdateStatus(*ipnstate.StatusBuilder)
 	SetNetworkMap(*netmap.NetworkMap)
 	Reconfig(*wgcfg.Config, *router.Config, *dns.Config, *tailcfg.Debug) error
 	SetDERPMap(*tailcfg.DERPMap)
@@ -41,6 +47,20 @@ const (
 	closed
 )
 
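+// clock and stopper wrap the small subset of the time package we use, so tests
+// can substitute a fake clock instead of waiting on real timers.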
+type clock interface {
+	Now() time.Time
+	AfterFunc(d time.Duration, f func()) stopper
+}
+
+type stopper interface {
+	Stop() bool
+}
+
+type stdClock struct{}
+
+func (stdClock) Now() time.Time { return time.Now() }
+func (stdClock) AfterFunc(d time.Duration, f func()) stopper { return time.AfterFunc(d, f) }
+
 type configMaps struct {
 	sync.Cond
 	netmapDirty bool
@@ -49,12 +69,16 @@ type configMaps struct
 	closing      bool
 	phase        phase
 
-	engine    engineConfigurable
-	static    netmap.NetworkMap
-	peers     map[uuid.UUID]*peerLifecycle
-	addresses []netip.Prefix
-	derpMap   *proto.DERPMap
-	logger    slog.Logger
+	engine         engineConfigurable
+	static         netmap.NetworkMap
+	peers          map[uuid.UUID]*peerLifecycle
+	addresses      []netip.Prefix
+	derpMap        *proto.DERPMap
+	logger         slog.Logger
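+	// blockEndpoints, when true, strips direct endpoints from peer nodes before
+	// they are written to the netmap, so connections relay over DERP.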
+	blockEndpoints bool
+
+	// for testing
+	clock clock
 }
 
 func newConfigMaps(logger slog.Logger, engine engineConfigurable, nodeID tailcfg.NodeID, nodeKey key.NodePrivate, discoKey key.DiscoPublic, addresses []netip.Prefix) *configMaps {
@@ -101,6 +125,7 @@ func newConfigMaps(logger slog.Logger, engine engineConfigurable, nodeID tailcfg
 		},
 		peers:     make(map[uuid.UUID]*peerLifecycle),
 		addresses: addresses,
+		clock:     stdClock{},
 	}
 	go c.configLoop()
 	return c
@@ -165,6 +190,9 @@ func (c *configMaps) configLoop() {
 func (c *configMaps) close() {
 	c.L.Lock()
 	defer c.L.Unlock()
+	for _, lc := range c.peers {
+		lc.resetTimer()
+	}
 	c.closing = true
 	c.Broadcast()
 	for c.phase != closed {
@@ -248,11 +276,201 @@ func (c *configMaps) filterLocked() *filter.Filter {
 	)
 }
 
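+// updatePeers handles a batch of peer updates from the coordinator. It first
+// refreshes lastHandshake for every known peer from the engine status, then
+// applies each update and wakes the config loop if the netmap changed.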
+func (c *configMaps) updatePeers(updates []*proto.CoordinateResponse_PeerUpdate) {
+	status := c.status()
+	c.L.Lock()
+	defer c.L.Unlock()
+
+	// Update all the lastHandshake values here. That way we don't have to
+	// worry about them being up-to-date when handling updates below, and it covers
+	// all peers, not just the ones we got updates about.
+	for _, lc := range c.peers {
+		if peerStatus, ok := status.Peer[lc.node.Key]; ok {
+			lc.lastHandshake = peerStatus.LastHandshake
+		}
+	}
+
+	for _, update := range updates {
+		if dirty := c.updatePeerLocked(update, status); dirty {
+			c.netmapDirty = true
+		}
+	}
+	if c.netmapDirty {
+		c.Broadcast()
+	}
+}
+
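+// status queries the engine for its current WireGuard status, including peers.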
+func (c *configMaps) status() *ipnstate.Status {
+	sb := &ipnstate.StatusBuilder{WantPeers: true}
+	c.engine.UpdateStatus(sb)
+	return sb.Status()
+}
+
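+// updatePeerLocked applies a single peer update and reports whether the netmap
+// became dirty. c.L must be held.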
+func (c *configMaps) updatePeerLocked(update *proto.CoordinateResponse_PeerUpdate, status *ipnstate.Status) (dirty bool) {
+	id, err := uuid.FromBytes(update.Id)
+	if err != nil {
+		c.logger.Critical(context.Background(), "received update with bad id", slog.F("id", update.Id))
+		return false
+	}
+	logger := c.logger.With(slog.F("peer_id", id))
+	lc, ok := c.peers[id]
+	var node *tailcfg.Node
+	if update.Kind == proto.CoordinateResponse_PeerUpdate_NODE {
+		// If no preferred DERP is provided, we can't reach the node.
+		if update.Node.PreferredDerp == 0 {
+			logger.Warn(context.Background(), "no preferred DERP, peer update", slog.F("node_proto", update.Node))
+			return false
+		}
+		node, err = c.protoNodeToTailcfg(update.Node)
+		if err != nil {
+			logger.Critical(context.Background(), "failed to convert proto node to tailcfg", slog.F("node_proto", update.Node))
+			return false
+		}
+		logger = logger.With(slog.F("key_id", node.Key.ShortString()), slog.F("node", node))
+		peerStatus, ok := status.Peer[node.Key]
+		// Starting KeepAlive messages at the initialization of a connection
+		// causes a race condition. If we send the handshake before the peer has
+		// our node, we'll have to wait for 5 seconds before trying again.
+		// Ideally, the first handshake starts when the user first initiates a
+		// connection to the peer. After a successful connection we enable
+		// keep alives to persist the connection and keep it from becoming idle.
+		// SSH connections don't send packets while idle, so we use keep alives
+		// to avoid random hangs while we set up the connection again after
+		// inactivity.
+		node.KeepAlive = ok && peerStatus.Active
+		if c.blockEndpoints {
+			node.Endpoints = nil
+		}
+	}
+	switch {
+	case !ok && update.Kind == proto.CoordinateResponse_PeerUpdate_NODE:
+		// new!
+		var lastHandshake time.Time
+		if ps, ok := status.Peer[node.Key]; ok {
+			lastHandshake = ps.LastHandshake
+		}
+		c.peers[id] = &peerLifecycle{
+			peerID:        id,
+			node:          node,
+			lastHandshake: lastHandshake,
+			lost:          false,
+		}
+		logger.Debug(context.Background(), "adding new peer")
+		return true
+	case ok && update.Kind == proto.CoordinateResponse_PeerUpdate_NODE:
+		// update
+		node.Created = lc.node.Created
+		dirty = !lc.node.Equal(node)
+		lc.node = node
+		lc.lost = false
+		lc.resetTimer()
+		logger.Debug(context.Background(), "node update to existing peer", slog.F("dirty", dirty))
+		return dirty
+	case !ok:
+		// disconnected or lost, but we don't have the node. No op
+		logger.Debug(context.Background(), "skipping update for peer we don't recognize")
+		return false
+	case update.Kind == proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
+		lc.resetTimer()
+		delete(c.peers, id)
+		logger.Debug(context.Background(), "disconnected peer")
+		return true
+	case update.Kind == proto.CoordinateResponse_PeerUpdate_LOST:
+		lc.lost = true
+		lc.setLostTimer(c)
+		logger.Debug(context.Background(), "marked peer lost")
+		// marking a node lost doesn't change anything right now, so dirty=false
+		return false
+	default:
+		logger.Warn(context.Background(), "unknown peer update", slog.F("kind", update.Kind))
+		return false
+	}
+}
+
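+// peerLostTimeout fires when a lost peer's timer expires. It re-checks the
+// handshake time and removes the peer from the netmap if it has been lost for
+// at least lostTimeout; otherwise it re-arms the timer.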
+func (c *configMaps) peerLostTimeout(id uuid.UUID) {
+	logger := c.logger.With(slog.F("peer_id", id))
+	logger.Debug(context.Background(),
+		"peer lost timeout")
+
+	// First do a status update to see if the peer did a handshake while we were
+	// waiting
+	status := c.status()
+	c.L.Lock()
+	defer c.L.Unlock()
+
+	lc, ok := c.peers[id]
+	if !ok {
+		logger.Debug(context.Background(),
+			"timeout triggered for peer that is removed from the map")
+		return
+	}
+	if peerStatus, ok := status.Peer[lc.node.Key]; ok {
+		lc.lastHandshake = peerStatus.LastHandshake
+	}
+	logger = logger.With(slog.F("key_id", lc.node.Key.ShortString()))
+	if !lc.lost {
+		logger.Debug(context.Background(),
+			"timeout triggered for peer that is no longer lost")
+		return
+	}
+	since := c.clock.Now().Sub(lc.lastHandshake)
+	if since >= lostTimeout {
+		logger.Info(
+			context.Background(), "removing lost peer")
+		delete(c.peers, id)
+		c.netmapDirty = true
+		c.Broadcast()
+		return
+	}
+	logger.Debug(context.Background(),
+		"timeout triggered for peer but it had handshake in meantime")
+	lc.setLostTimer(c)
+}
+
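+// protoNodeToTailcfg converts a protobuf Node from the coordinator into the
+// tailcfg.Node we hand to the WireGuard engine.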
+func (c *configMaps) protoNodeToTailcfg(p *proto.Node) (*tailcfg.Node, error) {
+	node, err := ProtoToNode(p)
+	if err != nil {
+		return nil, err
+	}
+	return &tailcfg.Node{
+		ID:         tailcfg.NodeID(p.GetId()),
+		Created:    c.clock.Now(),
+		Key:        node.Key,
+		DiscoKey:   node.DiscoKey,
+		Addresses:  node.Addresses,
+		AllowedIPs: node.AllowedIPs,
+		Endpoints:  node.Endpoints,
+		DERP:       fmt.Sprintf("%s:%d", tailcfg.DerpMagicIP, node.PreferredDERP),
+		Hostinfo:   (&tailcfg.Hostinfo{}).View(),
+	}, nil
+}
+
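+// peerLifecycle tracks the state we keep for each peer: its latest node, whether
+// the coordinator considers it lost, its last handshake time, and the timer that
+// will remove it if it stays lost.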
 type peerLifecycle struct {
-	node *tailcfg.Node
-	// TODO: implement timers to track lost peers
-	// lastHandshake time.Time
-	// timer time.Timer
+	peerID        uuid.UUID
+	node          *tailcfg.Node
+	lost          bool
+	lastHandshake time.Time
+	timer         stopper
+}
+
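+// resetTimer stops and clears any pending lost-peer timer.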
+func (l *peerLifecycle) resetTimer() {
+	if l.timer != nil {
+		l.timer.Stop()
+		l.timer = nil
+	}
+}
+
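+// setLostTimer (re)schedules removal of the peer once lostTimeout has elapsed
+// since its last handshake.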
+func (l *peerLifecycle) setLostTimer(c *configMaps) {
+	if l.timer != nil {
+		l.timer.Stop()
+	}
+	ttl := lostTimeout - c.clock.Now().Sub(l.lastHandshake)
+	if ttl <= 0 {
+		ttl = time.Nanosecond
+	}
+	l.timer = c.clock.AfterFunc(ttl, func() {
+		c.peerLostTimeout(l.peerID)
+	})
 }
 
 // prefixesDifferent returns true if the two slices contain different prefixes