diff --git a/derp/derphttp/derphttp_client.go b/derp/derphttp/derphttp_client.go index 7899868581afe..c306a03b8f46f 100644 --- a/derp/derphttp/derphttp_client.go +++ b/derp/derphttp/derphttp_client.go @@ -61,6 +61,10 @@ type Client struct { logf logger.Logf dialer func(ctx context.Context, network, addr string) (net.Conn, error) + // regionDialer allows the caller to override the dialer used to + // connect to DERP regions. If nil, the default dialer is used. + regionDialer func(ctx context.Context, r *tailcfg.DERPRegion) net.Conn + // Either url or getRegion is non-nil: url *url.URL getRegion func() *tailcfg.DERPRegion @@ -313,6 +317,31 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien } }() + if c.regionDialer != nil { + conn := c.regionDialer(ctx, reg) + if conn != nil { + brw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) + derpClient, err := derp.NewClient(c.privateKey, conn, brw, c.logf, + derp.MeshKey(c.MeshKey), + derp.CanAckPings(c.canAckPings), + derp.IsProber(c.IsProber), + ) + if err != nil { + return nil, 0, err + } + if c.preferred { + // It's important that this happens in a goroutine because + // of synchronous I/O woes. + go derpClient.NotePreferred(true) + } + c.serverPubKey = derpClient.ServerPublicKey() + c.client = derpClient + c.netConn = conn + c.connGen++ + return c.client, c.connGen, nil + } + } + var node *tailcfg.DERPNode // nil when using c.url to dial switch { case c.useWebsockets(): @@ -513,6 +542,13 @@ func (c *Client) SetURLDialer(dialer func(ctx context.Context, network, addr str c.dialer = dialer } +// SetRegionDialer sets the dialer to use for dialing DERP regions. +func (c *Client) SetRegionDialer(dialer func(ctx context.Context, region *tailcfg.DERPRegion) net.Conn) { + c.mu.Lock() + c.regionDialer = dialer + c.mu.Unlock() +} + // SetForcedWebsocketCallback is a callback that is called when the client // decides to force WebSockets on the next connection attempt. func (c *Client) SetForcedWebsocketCallback(callback func(region int, reason string)) { diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 86255342dbf78..41ce49fb1e192 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -349,6 +349,9 @@ type Conn struct { // headers that are passed to the DERP HTTP client derpHeader atomic.Pointer[http.Header] + // derpRegionDialer is passed to the DERP client + derpRegionDialer atomic.Pointer[func(ctx context.Context, region *tailcfg.DERPRegion) net.Conn] + // stats maintains per-connection counters. stats atomic.Pointer[connstats.Statistics] @@ -1514,6 +1517,10 @@ func (c *Conn) derpWriteChanOfAddr(addr netip.AddrPort, peer key.NodePublic) cha if header != nil { dc.Header = header.Clone() } + dialer := c.derpRegionDialer.Load() + if dialer != nil { + dc.SetRegionDialer(*dialer) + } ctx, cancel := context.WithCancel(c.connCtx) ch := make(chan derpWriteRequest, bufferedDerpWritesBeforeDrop) @@ -2367,6 +2374,15 @@ func (c *Conn) SetDERPHeader(header http.Header) { c.derpHeader.Store(&header) } +func (c *Conn) SetDERPRegionDialer(dialer func(ctx context.Context, region *tailcfg.DERPRegion) net.Conn) { + c.derpRegionDialer.Store(&dialer) + c.mu.Lock() + defer c.mu.Unlock() + for _, d := range c.activeDerp { + d.c.SetRegionDialer(dialer) + } +} + func (c *Conn) SetNetworkUp(up bool) { c.mu.Lock() defer c.mu.Unlock()