Skip to content

fix: Buffer tailnet nodes from connection initialization #4159

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 22, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 66 additions & 50 deletions tailnet/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,25 @@ func NewConn(options *Options) (*Conn, error) {
},
wireguardEngine: wireguardEngine,
}
wireguardEngine.SetStatusCallback(func(s *wgengine.Status, err error) {
if err != nil {
return
}
server.lastMutex.Lock()
server.lastEndpoints = make([]string, 0, len(s.LocalAddrs))
for _, addr := range s.LocalAddrs {
server.lastEndpoints = append(server.lastEndpoints, addr.Addr.String())
}
server.lastMutex.Unlock()
server.sendNode()
})
wireguardEngine.SetNetInfoCallback(func(ni *tailcfg.NetInfo) {
server.lastMutex.Lock()
server.lastPreferredDERP = ni.PreferredDERP
server.lastDERPLatency = ni.DERPLatency
server.lastMutex.Unlock()
server.sendNode()
})
netStack.ForwardTCPIn = server.forwardTCP
return server, nil
}
Expand Down Expand Up @@ -225,12 +244,15 @@ type Conn struct {
listeners map[listenKey]*listener
forwardTCPCallback func(conn net.Conn, listenerExists bool) net.Conn

lastMutex sync.Mutex
lastMutex sync.Mutex
nodeSending bool
nodeChanged bool
// It's only possible to store these values via status functions,
// so the values must be stored for retrieval later on.
lastEndpoints []string
lastPreferredDERP int
lastDERPLatency map[string]float64
nodeCallback func(node *Node)
}

// SetForwardTCPCallback is called every time a TCP connection is initiated inbound.
Expand All @@ -244,56 +266,11 @@ func (c *Conn) SetForwardTCPCallback(callback func(conn net.Conn, listenerExists
c.forwardTCPCallback = callback
}

// SetNodeCallback is triggered when a network change occurs and peer
// renegotiation may be required. Clients should constantly be emitting
// node changes.
func (c *Conn) SetNodeCallback(callback func(node *Node)) {
makeNode := func() *Node {
return &Node{
ID: c.netMap.SelfNode.ID,
Key: c.netMap.SelfNode.Key,
Addresses: c.netMap.SelfNode.Addresses,
AllowedIPs: c.netMap.SelfNode.AllowedIPs,
DiscoKey: c.magicConn.DiscoPublicKey(),
Endpoints: c.lastEndpoints,
PreferredDERP: c.lastPreferredDERP,
DERPLatency: c.lastDERPLatency,
}
}
// A send queue must be used so nodes are sent in order.
queue := make(chan *Node, 16)
go func() {
for {
select {
case <-c.closed:
return
case node := <-queue:
c.logger.Debug(context.Background(), "send node callback", slog.F("node", node))
callback(node)
}
}
}()
c.wireguardEngine.SetNetInfoCallback(func(ni *tailcfg.NetInfo) {
c.lastMutex.Lock()
c.lastPreferredDERP = ni.PreferredDERP
c.lastDERPLatency = ni.DERPLatency
node := makeNode()
queue <- node
c.lastMutex.Unlock()
})
c.wireguardEngine.SetStatusCallback(func(s *wgengine.Status, err error) {
if err != nil {
return
}
c.lastMutex.Lock()
c.lastEndpoints = make([]string, 0, len(s.LocalAddrs))
for _, addr := range s.LocalAddrs {
c.lastEndpoints = append(c.lastEndpoints, addr.Addr.String())
}
node := makeNode()
queue <- node
c.lastMutex.Unlock()
})
c.lastMutex.Lock()
c.nodeCallback = callback
c.lastMutex.Unlock()
c.sendNode()
}

// SetDERPMap updates the DERPMap of a connection.
Expand Down Expand Up @@ -361,6 +338,9 @@ func (c *Conn) UpdateNodes(nodes []*Node) error {
}
err = c.wireguardEngine.Reconfig(cfg, c.wireguardRouter, &dns.Config{}, &tailcfg.Debug{})
if err != nil {
if c.isClosed() {
return nil
}
return xerrors.Errorf("reconfig: %w", err)
}
return nil
Expand Down Expand Up @@ -416,6 +396,42 @@ func (c *Conn) isClosed() bool {
}
}

func (c *Conn) sendNode() {
c.lastMutex.Lock()
defer c.lastMutex.Unlock()
if c.nodeSending {
c.nodeChanged = true
return
}
node := &Node{
ID: c.netMap.SelfNode.ID,
Key: c.netMap.SelfNode.Key,
Addresses: c.netMap.SelfNode.Addresses,
AllowedIPs: c.netMap.SelfNode.AllowedIPs,
DiscoKey: c.magicConn.DiscoPublicKey(),
Endpoints: c.lastEndpoints,
PreferredDERP: c.lastPreferredDERP,
DERPLatency: c.lastDERPLatency,
}
nodeCallback := c.nodeCallback
if nodeCallback == nil {
return
}
c.nodeSending = true
go func() {
nodeCallback(node)
c.lastMutex.Lock()
c.nodeSending = false
if c.nodeChanged {
c.nodeChanged = false
c.lastMutex.Unlock()
c.sendNode()
return
}
c.lastMutex.Unlock()
}()
}

// This and below is taken _mostly_ verbatim from Tailscale:
// https://github.com/tailscale/tailscale/blob/c88bd53b1b7b2fcf7ba302f2e53dd1ce8c32dad4/tsnet/tsnet.go#L459-L494

Expand Down