Skip to content

Commit b8ec5c7

Browse files
authored
fix: Ensure tailnet coordinations are sent orderly (coder#4198)
1 parent c37ecdb commit b8ec5c7

File tree

3 files changed

+35
-26
lines changed

3 files changed

+35
-26
lines changed

codersdk/agentconn.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ type AgentConn struct {
4646
func (c *AgentConn) Ping() (time.Duration, error) {
4747
errCh := make(chan error, 1)
4848
durCh := make(chan time.Duration, 1)
49-
c.Conn.Ping(TailnetIP, tailcfg.PingICMP, func(pr *ipnstate.PingResult) {
49+
c.Conn.Ping(TailnetIP, tailcfg.PingDisco, func(pr *ipnstate.PingResult) {
5050
if pr.Err != "" {
5151
errCh <- xerrors.New(pr.Err)
5252
return

tailnet/conn.go

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ func NewConn(options *Options) (*Conn, error) {
182182
magicConn: magicConn,
183183
dialer: dialer,
184184
listeners: map[listenKey]*listener{},
185+
peerMap: map[tailcfg.NodeID]*tailcfg.Node{},
185186
tunDevice: tunDevice,
186187
netMap: netMap,
187188
netStack: netStack,
@@ -192,10 +193,17 @@ func NewConn(options *Options) (*Conn, error) {
192193
wireguardEngine: wireguardEngine,
193194
}
194195
wireguardEngine.SetStatusCallback(func(s *wgengine.Status, err error) {
196+
server.logger.Info(context.Background(), "wireguard status", slog.F("status", s), slog.F("err", err))
195197
if err != nil {
196198
return
197199
}
198200
server.lastMutex.Lock()
201+
if s.AsOf.Before(server.lastStatus) {
202+
// Don't process outdated status!
203+
server.lastMutex.Unlock()
204+
return
205+
}
206+
server.lastStatus = s.AsOf
199207
server.lastEndpoints = make([]string, 0, len(s.LocalAddrs))
200208
for _, addr := range s.LocalAddrs {
201209
server.lastEndpoints = append(server.lastEndpoints, addr.Addr.String())
@@ -240,6 +248,7 @@ type Conn struct {
240248

241249
dialer *tsdial.Dialer
242250
tunDevice *tstun.Wrapper
251+
peerMap map[tailcfg.NodeID]*tailcfg.Node
243252
netMap *netmap.NetworkMap
244253
netStack *netstack.Impl
245254
magicConn *magicsock.Conn
@@ -254,6 +263,7 @@ type Conn struct {
254263
nodeChanged bool
255264
// It's only possible to store these values via status functions,
256265
// so the values must be stored for retrieval later on.
266+
lastStatus time.Time
257267
lastEndpoints []string
258268
lastPreferredDERP int
259269
lastDERPLatency map[string]float64
@@ -282,8 +292,9 @@ func (c *Conn) SetNodeCallback(callback func(node *Node)) {
282292
func (c *Conn) SetDERPMap(derpMap *tailcfg.DERPMap) {
283293
c.mutex.Lock()
284294
defer c.mutex.Unlock()
285-
c.netMap.DERPMap = derpMap
286295
c.logger.Debug(context.Background(), "updating derp map", slog.F("derp_map", derpMap))
296+
c.netMap.DERPMap = derpMap
297+
c.wireguardEngine.SetNetworkMap(c.netMap)
287298
c.wireguardEngine.SetDERPMap(derpMap)
288299
}
289300

@@ -292,18 +303,24 @@ func (c *Conn) SetDERPMap(derpMap *tailcfg.DERPMap) {
292303
func (c *Conn) UpdateNodes(nodes []*Node) error {
293304
c.mutex.Lock()
294305
defer c.mutex.Unlock()
295-
peerMap := map[tailcfg.NodeID]*tailcfg.Node{}
296306
status := c.Status()
297307
for _, peer := range c.netMap.Peers {
298-
if peerStatus, ok := status.Peer[peer.Key]; ok {
299-
// Clear out inactive connections!
300-
// If a connection hasn't been active for a minute post creation, we assume it's dead.
301-
if !peerStatus.Active && peer.Created.Before(time.Now().Add(-time.Minute)) {
302-
c.logger.Debug(context.Background(), "clearing peer", slog.F("peerStatus", peerStatus))
303-
continue
304-
}
308+
peerStatus, ok := status.Peer[peer.Key]
309+
if !ok {
310+
continue
305311
}
306-
peerMap[peer.ID] = peer
312+
// If this peer was added in the last 5 minutes, assume it
313+
// could still be active.
314+
if time.Since(peer.Created) < 5*time.Minute {
315+
continue
316+
}
317+
// We double-check that it's safe to remove by ensuring no
318+
// handshake has been sent in the past 5 minutes as well. Connections that
319+
// are actively exchanging IP traffic will handshake every 2 minutes.
320+
if time.Since(peerStatus.LastHandshake) < 5*time.Minute {
321+
continue
322+
}
323+
delete(c.peerMap, peer.ID)
307324
}
308325
for _, node := range nodes {
309326
peerStatus, ok := status.Peer[node.Key]
@@ -322,18 +339,11 @@ func (c *Conn) UpdateNodes(nodes []*Node) error {
322339
// reason. TODO: @kylecarbs debug this!
323340
KeepAlive: ok && peerStatus.Active,
324341
}
325-
existingNode, ok := peerMap[node.ID]
326-
if ok {
327-
peerNode.Created = existingNode.Created
328-
c.logger.Debug(context.Background(), "updating peer", slog.F("peer", peerNode))
329-
} else {
330-
c.logger.Debug(context.Background(), "adding peer", slog.F("peer", peerNode))
331-
}
332-
peerMap[node.ID] = peerNode
342+
c.peerMap[node.ID] = peerNode
333343
}
334-
c.netMap.Peers = make([]*tailcfg.Node, 0, len(peerMap))
335-
for _, peer := range peerMap {
336-
c.netMap.Peers = append(c.netMap.Peers, peer)
344+
c.netMap.Peers = make([]*tailcfg.Node, 0, len(c.peerMap))
345+
for _, peer := range c.peerMap {
346+
c.netMap.Peers = append(c.netMap.Peers, peer.Clone())
337347
}
338348
netMapCopy := *c.netMap
339349
c.wireguardEngine.SetNetworkMap(&netMapCopy)
@@ -425,6 +435,7 @@ func (c *Conn) sendNode() {
425435
}
426436
c.nodeSending = true
427437
go func() {
438+
c.logger.Info(context.Background(), "sending node", slog.F("node", node))
428439
nodeCallback(node)
429440
c.lastMutex.Lock()
430441
c.nodeSending = false

tailnet/coordinator.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ func (c *Coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID)
164164
c.mutex.Unlock()
165165
continue
166166
}
167+
c.mutex.Unlock()
167168
// Write the new node from this client to the actively
168169
// connected agent.
169170
data, err := json.Marshal([]*Node{&node})
@@ -173,14 +174,11 @@ func (c *Coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID)
173174
}
174175
_, err = agentSocket.Write(data)
175176
if errors.Is(err, io.EOF) {
176-
c.mutex.Unlock()
177177
return nil
178178
}
179179
if err != nil {
180-
c.mutex.Unlock()
181180
return xerrors.Errorf("write json: %w", err)
182181
}
183-
c.mutex.Unlock()
184182
}
185183
}
186184

@@ -259,7 +257,7 @@ func (c *Coordinator) ServeAgent(conn net.Conn, id uuid.UUID) error {
259257
wg.Done()
260258
}()
261259
}
262-
wg.Wait()
263260
c.mutex.Unlock()
261+
wg.Wait()
264262
}
265263
}

0 commit comments

Comments
 (0)