From 7fb14b168fd544661564ce3cfbfee7ae9a01b6f7 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Mon, 18 Sep 2023 13:31:37 +0000 Subject: [PATCH] fix: use magicsock DERPHeaders in netcheck --- net/netcheck/netcheck.go | 13 +++++ net/netcheck/netcheck_test.go | 91 +++++++++++++++++++++++++++++++++ wgengine/magicsock/magicsock.go | 7 +++ 3 files changed, 111 insertions(+) diff --git a/net/netcheck/netcheck.go b/net/netcheck/netcheck.go index d571657ceef4a..4a573b7b4443b 100644 --- a/net/netcheck/netcheck.go +++ b/net/netcheck/netcheck.go @@ -203,6 +203,10 @@ type Client struct { // If false, the default net.Resolver will be used, with no caching. UseDNSCache bool + // GetDERPHeaders optionally specifies headers to send with all HTTP(S) DERP + // probes. + GetDERPHeaders func() http.Header + // For tests testEnoughRegions int testCaptivePortalDelay time.Duration @@ -1277,7 +1281,13 @@ func (c *Client) measureHTTPLatency(ctx context.Context, reg *tailcfg.DERPRegion var ip netip.Addr + derpHeaders := http.Header{} + if c.GetDERPHeaders != nil { + derpHeaders = c.GetDERPHeaders() + } + dc := derphttp.NewNetcheckClient(c.logf) + dc.Header = derpHeaders defer dc.Close() var hasForceHTTPNode = false @@ -1356,6 +1366,9 @@ func (c *Client) measureHTTPLatency(ctx context.Context, reg *tailcfg.DERPRegion if err != nil { return 0, ip, err } + for k := range derpHeaders { + req.Header.Set(k, derpHeaders.Get(k)) + } resp, err := hc.Do(req) if err != nil { diff --git a/net/netcheck/netcheck_test.go b/net/netcheck/netcheck_test.go index eeb407d99f852..217af5e536dc7 100644 --- a/net/netcheck/netcheck_test.go +++ b/net/netcheck/netcheck_test.go @@ -6,9 +6,11 @@ package netcheck import ( "bytes" "context" + "crypto/tls" "fmt" "net" "net/http" + "net/http/httptest" "net/netip" "reflect" "sort" @@ -18,11 +20,15 @@ import ( "testing" "time" + "tailscale.com/derp" + "tailscale.com/derp/derphttp" "tailscale.com/net/interfaces" "tailscale.com/net/stun" "tailscale.com/net/stun/stuntest" "tailscale.com/tailcfg" "tailscale.com/tstest" + "tailscale.com/types/key" + "tailscale.com/types/logger" ) func TestHairpinSTUN(t *testing.T) { @@ -966,3 +972,88 @@ func TestNodeAddrResolve(t *testing.T) { }) } } + +func TestProbeHeaders(t *testing.T) { + logf, closeLogf := logger.LogfCloser(t.Logf) + defer closeLogf() + + // Create a DERP server manually, without a STUN server and with a custom + // handler. + derpServer := derp.NewServer(key.NewNode(), logf) + derpHandler := derphttp.Handler(derpServer) + + expectedHeaders := http.Header{} + expectedHeaders.Set("X-Cool-Test", "yes") + expectedHeaders.Set("X-Proxy-Auth-Key", "blah blah blah") + + var called atomic.Bool + httpsrv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called.Store(true) + for k, v := range expectedHeaders { + if got := r.Header[k]; !reflect.DeepEqual(got, v) { + t.Errorf("unexpected header %q: got %q; want %q", k, got, v) + } + } + + if r.URL.Path == "/derp/latency-check" { + w.WriteHeader(http.StatusOK) + return + } + if r.URL.Path == "/derp" { + derpHandler.ServeHTTP(w, r) + return + } + + t.Errorf("unexpected request: %v", r.URL) + w.WriteHeader(http.StatusNotFound) + })) + httpsrv.Config.ErrorLog = logger.StdLogger(logf) + httpsrv.Config.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) + httpsrv.StartTLS() + t.Cleanup(func() { + httpsrv.CloseClientConnections() + httpsrv.Close() + derpServer.Close() + }) + + derpMap := &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: { + RegionID: 1, + RegionCode: "derpy", + Nodes: []*tailcfg.DERPNode{ + { + Name: "d1", + RegionID: 1, + HostName: "localhost", + // Don't specify an IP address to avoid ICMP pinging, + // which will bypass the artificial latency. + IPv4: "", + IPv6: "", + STUNPort: -1, + DERPPort: httpsrv.Listener.Addr().(*net.TCPAddr).Port, + InsecureForTests: true, + }, + }, + }, + }, + } + + c := &Client{ + Logf: t.Logf, + UDPBindAddr: "127.0.0.1:0", + GetDERPHeaders: func() http.Header { return expectedHeaders.Clone() }, + } + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + _, err := c.GetReport(ctx, derpMap) + if err != nil { + t.Fatal(err) + } + + if !called.Load() { + t.Error("didn't call test handler") + } +} diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 10e2d4a67336e..0d46a0424bb65 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -443,6 +443,13 @@ func NewConn(opts Options) (*Conn, error) { SkipExternalNetwork: inTest(), PortMapper: c.portMapper, UseDNSCache: true, + GetDERPHeaders: func() http.Header { + h := c.derpHeader.Load() + if h == nil { + return nil + } + return h.Clone() + }, } c.ignoreSTUNPackets()