
Commit e801e87

feat: add agent acks to in-memory coordinator (coder#12786)
When an agent receives a node, it responds with an ACK which is relayed to the client. After the client receives the ACK, it's allowed to begin pinging.
1 parent 9cf2358 commit e801e87

13 files changed: +878 −122
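
The ack ordering the commit message describes — the client publishes its node, the agent replies with a READY_FOR_HANDSHAKE that the coordinator relays back, and only then does the client start pinging — can be sketched as follows. This is a minimal illustration, not the coordinator's actual API; `fakeCoordinator` and its methods are hypothetical stand-ins for the real proto message exchange.

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// fakeCoordinator stands in for the tailnet coordinator client; the real code
// exchanges proto.CoordinateRequest/CoordinateResponse messages.
type fakeCoordinator struct {
	acks chan struct{} // relayed READY_FOR_HANDSHAKE acks from the agent
}

func (c fakeCoordinator) sendNode()            {} // publish our node to the agent
func (c fakeCoordinator) ack() <-chan struct{} { return c.acks }

func connect(ctx context.Context, coord fakeCoordinator) error {
	// 1. Publish our node so the agent can program us into wireguard.
	coord.sendNode()
	// 2. Wait for the agent's ack before generating any traffic.
	select {
	case <-coord.ack():
	case <-ctx.Done():
		return ctx.Err()
	}
	// 3. Only now is the client allowed to begin pinging.
	fmt.Println("ack received; pinging may begin")
	return nil
}

func main() {
	coord := fakeCoordinator{acks: make(chan struct{}, 1)}
	coord.acks <- struct{}{} // simulate the agent's ack arriving
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	_ = connect(ctx, coord)
}
```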

codersdk/workspacesdk/connector.go

Lines changed: 3 additions & 1 deletion
```diff
@@ -86,9 +86,11 @@ func runTailnetAPIConnector(
 func (tac *tailnetAPIConnector) manageGracefulTimeout() {
 	defer tac.cancelGracefulCtx()
 	<-tac.ctx.Done()
+	timer := time.NewTimer(time.Second)
+	defer timer.Stop()
 	select {
 	case <-tac.closed:
-	case <-time.After(time.Second):
+	case <-timer.C:
 	}
 }
```
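
This change swaps `time.After` for an explicit timer: the Timer behind `time.After` cannot be stopped, so it lingers until it fires even when `tac.closed` wins the select, whereas `time.NewTimer` plus `defer timer.Stop()` releases it promptly. A minimal sketch of the pattern (a hypothetical helper, not code from this repo):

```go
package example

import "time"

// waitWithTimeout mirrors the pattern in manageGracefulTimeout: an explicit
// timer stopped on return, instead of time.After, whose underlying Timer is
// not released until it fires.
func waitWithTimeout(done <-chan struct{}, d time.Duration) bool {
	timer := time.NewTimer(d)
	defer timer.Stop() // free the timer promptly if done wins the race
	select {
	case <-done:
		return true // closed before the deadline
	case <-timer.C:
		return false // timed out
	}
}
```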

codersdk/workspacesdk/workspacesdk_internal_test.go renamed to codersdk/workspacesdk/connector_internal_test.go

Lines changed: 2 additions & 0 deletions
```diff
@@ -102,6 +102,8 @@ func (*fakeTailnetConn) SetNodeCallback(func(*tailnet.Node)) {}
 
 func (*fakeTailnetConn) SetDERPMap(*tailcfg.DERPMap) {}
 
+func (*fakeTailnetConn) SetTunnelDestination(uuid.UUID) {}
+
 func newFakeTailnetConn() *fakeTailnetConn {
 	return &fakeTailnetConn{}
 }
```

enterprise/coderd/workspaceproxy.go

Lines changed: 1 addition & 1 deletion
```diff
@@ -658,7 +658,7 @@ func (api *API) workspaceProxyRegister(rw http.ResponseWriter, r *http.Request)
 		if err != nil {
 			return xerrors.Errorf("insert replica: %w", err)
 		}
-	} else if err != nil {
+	} else {
 		return xerrors.Errorf("get replica: %w", err)
 	}
```
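
This change drops a redundant check rather than altering behavior: the final branch of the chain is only reachable when the lookup returned an error the earlier arms didn't handle, so `err` is necessarily non-nil there. A hedged reconstruction of the presumed surrounding shape (only the last two branches appear in the hunk; `Store`, `Replica`, and the method names are hypothetical stand-ins):

```go
package example

import (
	"context"
	"database/sql"
	"errors"

	"golang.org/x/xerrors"
)

type Replica struct{ ID string }

// Store stands in for the database interface.
type Store interface {
	ReplicaByID(ctx context.Context, id string) (Replica, error)
	InsertReplica(ctx context.Context, id string) (Replica, error)
}

func getOrInsertReplica(ctx context.Context, db Store, id string) (Replica, error) {
	replica, err := db.ReplicaByID(ctx, id)
	if err == nil {
		return replica, nil // found
	} else if errors.Is(err, sql.ErrNoRows) {
		replica, err = db.InsertReplica(ctx, id) // not found: create it
		if err != nil {
			return Replica{}, xerrors.Errorf("insert replica: %w", err)
		}
		return replica, nil
	} else {
		// err is necessarily non-nil here, which is why the commit replaces
		// `else if err != nil` with a plain `else`.
		return Replica{}, xerrors.Errorf("get replica: %w", err)
	}
}
```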

tailnet/configmaps.go

Lines changed: 160 additions & 40 deletions
```diff
@@ -186,7 +186,7 @@ func (c *configMaps) close() {
 	c.L.Lock()
 	defer c.L.Unlock()
 	for _, lc := range c.peers {
-		lc.resetTimer()
+		lc.resetLostTimer()
 	}
 	c.closing = true
 	c.Broadcast()
```
```diff
@@ -216,6 +216,12 @@ func (c *configMaps) netMapLocked() *netmap.NetworkMap {
 func (c *configMaps) peerConfigLocked() []*tailcfg.Node {
 	out := make([]*tailcfg.Node, 0, len(c.peers))
 	for _, p := range c.peers {
+		// Don't add nodes that we haven't received a READY_FOR_HANDSHAKE for
+		// yet, if they're a destination. If we received a READY_FOR_HANDSHAKE
+		// for a peer before we receive their node, the node will be nil.
+		if (!p.readyForHandshake && p.isDestination) || p.node == nil {
+			continue
+		}
 		n := p.node.Clone()
 		if c.blockEndpoints {
 			n.Endpoints = nil
```
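
This guard is the heart of the feature: a peer only enters the generated netmap once we actually hold its node, and, when we initiated the tunnel (a destination), only after its READY_FOR_HANDSHAKE has arrived. The predicate is factored into `validForWireguard` later in this file; a standalone sketch of the same logic:

```go
package example

// programmable mirrors the guard in peerConfigLocked: skip peers whose node
// we haven't received, and hold back destinations until they ack.
func programmable(hasNode, isDestination, readyForHandshake bool) bool {
	if !hasNode {
		return false // READY_FOR_HANDSHAKE may arrive before the NODE update
	}
	if isDestination {
		return readyForHandshake // wait for the agent's ack (or the timeout)
	}
	return true // non-destinations are programmed as soon as the node arrives
}
```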
```diff
@@ -225,6 +231,19 @@ func (c *configMaps) peerConfigLocked() []*tailcfg.Node {
 	return out
 }
 
+func (c *configMaps) setTunnelDestination(id uuid.UUID) {
+	c.L.Lock()
+	defer c.L.Unlock()
+	lc, ok := c.peers[id]
+	if !ok {
+		lc = &peerLifecycle{
+			peerID: id,
+		}
+		c.peers[id] = lc
+	}
+	lc.isDestination = true
+}
+
 // setAddresses sets the addresses belonging to this node to the given slice. It
 // triggers configuration of the engine if the addresses have changed.
 // c.L MUST NOT be held.
```
```diff
@@ -331,8 +350,10 @@ func (c *configMaps) updatePeers(updates []*proto.CoordinateResponse_PeerUpdate)
 	// worry about them being up-to-date when handling updates below, and it covers
 	// all peers, not just the ones we got updates about.
 	for _, lc := range c.peers {
-		if peerStatus, ok := status.Peer[lc.node.Key]; ok {
-			lc.lastHandshake = peerStatus.LastHandshake
+		if lc.node != nil {
+			if peerStatus, ok := status.Peer[lc.node.Key]; ok {
+				lc.lastHandshake = peerStatus.LastHandshake
+			}
 		}
 	}
```

```diff
@@ -363,7 +384,7 @@ func (c *configMaps) updatePeerLocked(update *proto.CoordinateResponse_PeerUpdate
 		return false
 	}
 	logger := c.logger.With(slog.F("peer_id", id))
-	lc, ok := c.peers[id]
+	lc, peerOk := c.peers[id]
 	var node *tailcfg.Node
 	if update.Kind == proto.CoordinateResponse_PeerUpdate_NODE {
 		// If no preferred DERP is provided, we can't reach the node.
```
```diff
@@ -377,48 +398,76 @@ func (c *configMaps) updatePeerLocked(update *proto.CoordinateResponse_PeerUpdate
 			return false
 		}
 		logger = logger.With(slog.F("key_id", node.Key.ShortString()), slog.F("node", node))
-		peerStatus, ok := status.Peer[node.Key]
-		// Starting KeepAlive messages at the initialization of a connection
-		// causes a race condition. If we send the handshake before the peer has
-		// our node, we'll have to wait for 5 seconds before trying again.
-		// Ideally, the first handshake starts when the user first initiates a
-		// connection to the peer. After a successful connection we enable
-		// keep alives to persist the connection and keep it from becoming idle.
-		// SSH connections don't send packets while idle, so we use keep alives
-		// to avoid random hangs while we set up the connection again after
-		// inactivity.
-		node.KeepAlive = ok && peerStatus.Active
+		node.KeepAlive = c.nodeKeepalive(lc, status, node)
 	}
 	switch {
-	case !ok && update.Kind == proto.CoordinateResponse_PeerUpdate_NODE:
+	case !peerOk && update.Kind == proto.CoordinateResponse_PeerUpdate_NODE:
 		// new!
 		var lastHandshake time.Time
 		if ps, ok := status.Peer[node.Key]; ok {
 			lastHandshake = ps.LastHandshake
 		}
-		c.peers[id] = &peerLifecycle{
+		lc = &peerLifecycle{
 			peerID:        id,
 			node:          node,
 			lastHandshake: lastHandshake,
 			lost:          false,
 		}
+		c.peers[id] = lc
 		logger.Debug(context.Background(), "adding new peer")
-		return true
-	case ok && update.Kind == proto.CoordinateResponse_PeerUpdate_NODE:
+		return lc.validForWireguard()
+	case peerOk && update.Kind == proto.CoordinateResponse_PeerUpdate_NODE:
 		// update
-		node.Created = lc.node.Created
+		if lc.node != nil {
+			node.Created = lc.node.Created
+		}
 		dirty = !lc.node.Equal(node)
 		lc.node = node
+		// validForWireguard checks that the node is non-nil, so should be
+		// called after we update the node.
+		dirty = dirty && lc.validForWireguard()
 		lc.lost = false
-		lc.resetTimer()
+		lc.resetLostTimer()
+		if lc.isDestination && !lc.readyForHandshake {
+			// We received the node of a destination peer before we've received
+			// their READY_FOR_HANDSHAKE. Set a timer.
+			lc.setReadyForHandshakeTimer(c)
+			logger.Debug(context.Background(), "setting ready for handshake timeout")
+		}
 		logger.Debug(context.Background(), "node update to existing peer", slog.F("dirty", dirty))
 		return dirty
-	case !ok:
+	case peerOk && update.Kind == proto.CoordinateResponse_PeerUpdate_READY_FOR_HANDSHAKE:
+		dirty := !lc.readyForHandshake
+		lc.readyForHandshake = true
+		if lc.readyForHandshakeTimer != nil {
+			lc.readyForHandshakeTimer.Stop()
+		}
+		if lc.node != nil {
+			old := lc.node.KeepAlive
+			lc.node.KeepAlive = c.nodeKeepalive(lc, status, lc.node)
+			dirty = dirty || (old != lc.node.KeepAlive)
+		}
+		logger.Debug(context.Background(), "peer ready for handshake")
+		// only force a reconfig if the node is populated
+		return dirty && lc.node != nil
+	case !peerOk && update.Kind == proto.CoordinateResponse_PeerUpdate_READY_FOR_HANDSHAKE:
+		// When we receive a READY_FOR_HANDSHAKE for a peer we don't know about,
+		// we create a peerLifecycle with the peerID and set readyForHandshake
+		// to true. Eventually we should receive a NODE update for this peer,
+		// and it'll be programmed into wireguard.
+		logger.Debug(context.Background(), "got peer ready for handshake for unknown peer")
+		lc = &peerLifecycle{
+			peerID:            id,
+			readyForHandshake: true,
+		}
+		c.peers[id] = lc
+		return false
+	case !peerOk:
 		// disconnected or lost, but we don't have the node. No op
 		logger.Debug(context.Background(), "skipping update for peer we don't recognize")
 		return false
 	case update.Kind == proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
-		lc.resetTimer()
+		lc.resetLostTimer()
 		delete(c.peers, id)
 		logger.Debug(context.Background(), "disconnected peer")
 		return true
```
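
The switch now has to converge regardless of whether NODE or READY_FOR_HANDSHAKE arrives first: node-first arms a 5-second timer, ack-first parks a placeholder `peerLifecycle` until the node shows up. A toy model of that ordering (not the real types; the lost/timer handling is elided):

```go
package example

// peer is a toy stand-in for peerLifecycle.
type peer struct {
	hasNode, isDestination, ready bool
}

// applyNode and applyReady condense the NODE / READY_FOR_HANDSHAKE arms of
// updatePeerLocked; each returns whether wireguard should be reconfigured.
func (p *peer) applyNode() bool {
	p.hasNode = true
	// A destination that hasn't acked is held back; a timer (elided here)
	// later forces ready = true so the tunnel can't stall forever.
	return !p.isDestination || p.ready
}

func (p *peer) applyReady() bool {
	p.ready = true
	return p.hasNode // only reconfigure once the node is populated
}
```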
```diff
@@ -476,10 +525,12 @@ func (c *configMaps) peerLostTimeout(id uuid.UUID) {
 			"timeout triggered for peer that is removed from the map")
 		return
 	}
-	if peerStatus, ok := status.Peer[lc.node.Key]; ok {
-		lc.lastHandshake = peerStatus.LastHandshake
+	if lc.node != nil {
+		if peerStatus, ok := status.Peer[lc.node.Key]; ok {
+			lc.lastHandshake = peerStatus.LastHandshake
+		}
+		logger = logger.With(slog.F("key_id", lc.node.Key.ShortString()))
 	}
-	logger = logger.With(slog.F("key_id", lc.node.Key.ShortString()))
 	if !lc.lost {
 		logger.Debug(context.Background(),
 			"timeout triggered for peer that is no longer lost")
```
```diff
@@ -522,7 +573,7 @@ func (c *configMaps) nodeAddresses(publicKey key.NodePublic) ([]netip.Prefix, bool) {
 	c.L.Lock()
 	defer c.L.Unlock()
 	for _, lc := range c.peers {
-		if lc.node.Key == publicKey {
+		if lc.node != nil && lc.node.Key == publicKey {
 			return lc.node.Addresses, true
 		}
 	}
```
```diff
@@ -539,9 +590,10 @@ func (c *configMaps) fillPeerDiagnostics(d *PeerDiagnostics, peerID uuid.UUID) {
 		}
 	}
 	lc, ok := c.peers[peerID]
-	if !ok {
+	if !ok || lc.node == nil {
 		return
 	}
+
 	d.ReceivedNode = lc.node
 	ps, ok := status.Peer[lc.node.Key]
 	if !ok {
```
```diff
@@ -550,34 +602,102 @@ func (c *configMaps) fillPeerDiagnostics(d *PeerDiagnostics, peerID uuid.UUID) {
 	d.LastWireguardHandshake = ps.LastHandshake
 }
 
+func (c *configMaps) peerReadyForHandshakeTimeout(peerID uuid.UUID) {
+	logger := c.logger.With(slog.F("peer_id", peerID))
+	logger.Debug(context.Background(), "peer ready for handshake timeout")
+	c.L.Lock()
+	defer c.L.Unlock()
+	lc, ok := c.peers[peerID]
+	if !ok {
+		logger.Debug(context.Background(),
+			"ready for handshake timeout triggered for peer that is removed from the map")
+		return
+	}
+
+	wasReady := lc.readyForHandshake
+	lc.readyForHandshake = true
+	if !wasReady {
+		logger.Info(context.Background(), "setting peer ready for handshake after timeout")
+		c.netmapDirty = true
+		c.Broadcast()
+	}
+}
+
+func (*configMaps) nodeKeepalive(lc *peerLifecycle, status *ipnstate.Status, node *tailcfg.Node) bool {
+	// If the peer is already active, keepalives should be enabled.
+	if peerStatus, statusOk := status.Peer[node.Key]; statusOk && peerStatus.Active {
+		return true
+	}
+	// If the peer is a destination, we should only enable keepalives if we've
+	// received the READY_FOR_HANDSHAKE.
+	if lc != nil && lc.isDestination && lc.readyForHandshake {
+		return true
+	}
+
+	// If none of the above are true, keepalives should not be enabled.
+	return false
+}
+
 type peerLifecycle struct {
-	peerID        uuid.UUID
-	node          *tailcfg.Node
-	lost          bool
-	lastHandshake time.Time
-	timer         *clock.Timer
+	peerID uuid.UUID
+	// isDestination specifies if the peer is a destination, meaning we
+	// initiated a tunnel to the peer. When the peer is a destination, we do not
+	// respond to node updates with `READY_FOR_HANDSHAKE`s, and we wait to
+	// program the peer into wireguard until we receive a READY_FOR_HANDSHAKE
+	// from the peer or the timeout is reached.
+	isDestination bool
+	// node is the tailcfg.Node for the peer. It may be nil until we receive a
+	// NODE update for it.
+	node                   *tailcfg.Node
+	lost                   bool
+	lastHandshake          time.Time
+	lostTimer              *clock.Timer
+	readyForHandshake      bool
+	readyForHandshakeTimer *clock.Timer
 }
 
-func (l *peerLifecycle) resetTimer() {
-	if l.timer != nil {
-		l.timer.Stop()
-		l.timer = nil
+func (l *peerLifecycle) resetLostTimer() {
+	if l.lostTimer != nil {
+		l.lostTimer.Stop()
+		l.lostTimer = nil
 	}
 }
 
 func (l *peerLifecycle) setLostTimer(c *configMaps) {
-	if l.timer != nil {
-		l.timer.Stop()
+	if l.lostTimer != nil {
+		l.lostTimer.Stop()
 	}
 	ttl := lostTimeout - c.clock.Since(l.lastHandshake)
 	if ttl <= 0 {
 		ttl = time.Nanosecond
 	}
-	l.timer = c.clock.AfterFunc(ttl, func() {
+	l.lostTimer = c.clock.AfterFunc(ttl, func() {
 		c.peerLostTimeout(l.peerID)
 	})
 }
 
+const readyForHandshakeTimeout = 5 * time.Second
+
+func (l *peerLifecycle) setReadyForHandshakeTimer(c *configMaps) {
+	if l.readyForHandshakeTimer != nil {
+		l.readyForHandshakeTimer.Stop()
+	}
+	l.readyForHandshakeTimer = c.clock.AfterFunc(readyForHandshakeTimeout, func() {
+		c.logger.Debug(context.Background(), "ready for handshake timeout", slog.F("peer_id", l.peerID))
+		c.peerReadyForHandshakeTimeout(l.peerID)
+	})
+}
+
+// validForWireguard returns true if the peer is ready to be programmed into
+// wireguard.
+func (l *peerLifecycle) validForWireguard() bool {
+	valid := l.node != nil
+	if l.isDestination {
+		return valid && l.readyForHandshake
+	}
+	return valid
+}
+
 // prefixesDifferent returns true if the two slices contain different prefixes
 // where order doesn't matter.
 func prefixesDifferent(a, b []netip.Prefix) bool {
```
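
All of the new timers run through the injected `c.clock`, whose `*clock.Timer` field type matches `github.com/benbjohnson/clock` (an assumption based on the types in this diff). That indirection is what lets tests drive the 5-second `readyForHandshakeTimeout` deterministically instead of sleeping in real time:

```go
package main

import (
	"fmt"
	"time"

	"github.com/benbjohnson/clock"
)

func main() {
	mock := clock.NewMock() // a controllable clock, as a test would inject
	fired := false
	// Stand-in for setReadyForHandshakeTimer's 5-second timeout.
	mock.AfterFunc(5*time.Second, func() { fired = true })

	mock.Add(4 * time.Second)
	fmt.Println(fired) // false: still waiting for the ack

	mock.Add(time.Second)
	fmt.Println(fired) // true: the timeout forces readyForHandshake
}
```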
