diff --git a/tailnet/conn.go b/tailnet/conn.go index 8dcd835fe525a..4a10d0df8dba0 100644 --- a/tailnet/conn.go +++ b/tailnet/conn.go @@ -188,6 +188,25 @@ func NewConn(options *Options) (*Conn, error) { }, wireguardEngine: wireguardEngine, } + wireguardEngine.SetStatusCallback(func(s *wgengine.Status, err error) { + if err != nil { + return + } + server.lastMutex.Lock() + server.lastEndpoints = make([]string, 0, len(s.LocalAddrs)) + for _, addr := range s.LocalAddrs { + server.lastEndpoints = append(server.lastEndpoints, addr.Addr.String()) + } + server.lastMutex.Unlock() + server.sendNode() + }) + wireguardEngine.SetNetInfoCallback(func(ni *tailcfg.NetInfo) { + server.lastMutex.Lock() + server.lastPreferredDERP = ni.PreferredDERP + server.lastDERPLatency = ni.DERPLatency + server.lastMutex.Unlock() + server.sendNode() + }) netStack.ForwardTCPIn = server.forwardTCP return server, nil } @@ -225,12 +244,15 @@ type Conn struct { listeners map[listenKey]*listener forwardTCPCallback func(conn net.Conn, listenerExists bool) net.Conn - lastMutex sync.Mutex + lastMutex sync.Mutex + nodeSending bool + nodeChanged bool // It's only possible to store these values via status functions, // so the values must be stored for retrieval later on. lastEndpoints []string lastPreferredDERP int lastDERPLatency map[string]float64 + nodeCallback func(node *Node) } // SetForwardTCPCallback is called every time a TCP connection is initiated inbound. @@ -244,56 +266,11 @@ func (c *Conn) SetForwardTCPCallback(callback func(conn net.Conn, listenerExists c.forwardTCPCallback = callback } -// SetNodeCallback is triggered when a network change occurs and peer -// renegotiation may be required. Clients should constantly be emitting -// node changes. func (c *Conn) SetNodeCallback(callback func(node *Node)) { - makeNode := func() *Node { - return &Node{ - ID: c.netMap.SelfNode.ID, - Key: c.netMap.SelfNode.Key, - Addresses: c.netMap.SelfNode.Addresses, - AllowedIPs: c.netMap.SelfNode.AllowedIPs, - DiscoKey: c.magicConn.DiscoPublicKey(), - Endpoints: c.lastEndpoints, - PreferredDERP: c.lastPreferredDERP, - DERPLatency: c.lastDERPLatency, - } - } - // A send queue must be used so nodes are sent in order. - queue := make(chan *Node, 16) - go func() { - for { - select { - case <-c.closed: - return - case node := <-queue: - c.logger.Debug(context.Background(), "send node callback", slog.F("node", node)) - callback(node) - } - } - }() - c.wireguardEngine.SetNetInfoCallback(func(ni *tailcfg.NetInfo) { - c.lastMutex.Lock() - c.lastPreferredDERP = ni.PreferredDERP - c.lastDERPLatency = ni.DERPLatency - node := makeNode() - queue <- node - c.lastMutex.Unlock() - }) - c.wireguardEngine.SetStatusCallback(func(s *wgengine.Status, err error) { - if err != nil { - return - } - c.lastMutex.Lock() - c.lastEndpoints = make([]string, 0, len(s.LocalAddrs)) - for _, addr := range s.LocalAddrs { - c.lastEndpoints = append(c.lastEndpoints, addr.Addr.String()) - } - node := makeNode() - queue <- node - c.lastMutex.Unlock() - }) + c.lastMutex.Lock() + c.nodeCallback = callback + c.lastMutex.Unlock() + c.sendNode() } // SetDERPMap updates the DERPMap of a connection. @@ -361,6 +338,9 @@ func (c *Conn) UpdateNodes(nodes []*Node) error { } err = c.wireguardEngine.Reconfig(cfg, c.wireguardRouter, &dns.Config{}, &tailcfg.Debug{}) if err != nil { + if c.isClosed() { + return nil + } return xerrors.Errorf("reconfig: %w", err) } return nil @@ -416,6 +396,42 @@ func (c *Conn) isClosed() bool { } } +func (c *Conn) sendNode() { + c.lastMutex.Lock() + defer c.lastMutex.Unlock() + if c.nodeSending { + c.nodeChanged = true + return + } + node := &Node{ + ID: c.netMap.SelfNode.ID, + Key: c.netMap.SelfNode.Key, + Addresses: c.netMap.SelfNode.Addresses, + AllowedIPs: c.netMap.SelfNode.AllowedIPs, + DiscoKey: c.magicConn.DiscoPublicKey(), + Endpoints: c.lastEndpoints, + PreferredDERP: c.lastPreferredDERP, + DERPLatency: c.lastDERPLatency, + } + nodeCallback := c.nodeCallback + if nodeCallback == nil { + return + } + c.nodeSending = true + go func() { + nodeCallback(node) + c.lastMutex.Lock() + c.nodeSending = false + if c.nodeChanged { + c.nodeChanged = false + c.lastMutex.Unlock() + c.sendNode() + return + } + c.lastMutex.Unlock() + }() +} + // This and below is taken _mostly_ verbatim from Tailscale: // https://github.com/tailscale/tailscale/blob/c88bd53b1b7b2fcf7ba302f2e53dd1ce8c32dad4/tsnet/tsnet.go#L459-L494