diff --git a/net/netns/netns_windows.go b/net/netns/netns_windows.go index dd733542fb56f..f7554e222ed6a 100644 --- a/net/netns/netns_windows.go +++ b/net/netns/netns_windows.go @@ -8,7 +8,6 @@ import ( "math/bits" "net" "net/netip" - "strconv" "strings" "syscall" @@ -17,6 +16,7 @@ import ( "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" "tailscale.com/net/interfaces" "tailscale.com/net/netmon" + "tailscale.com/net/tsaddr" "tailscale.com/types/logger" ) @@ -31,17 +31,6 @@ func interfaceIndex(iface *winipcfg.IPAdapterAddresses) uint32 { return iface.IfIndex } -// getBestInterface can be swapped out in tests. -var getBestInterface func(addr windows.Sockaddr, idx *uint32) error = windows.GetBestInterfaceEx - -// isInterfaceCoderInterface can be swapped out in tests. -var isInterfaceCoderInterface func(int) bool = isInterfaceCoderInterfaceDefault - -func isInterfaceCoderInterfaceDefault(idx int) bool { - _, tsif, err := interfaces.Coder() - return err == nil && tsif != nil && tsif.Index == idx -} - func control(logf logger.Logf, netMon *netmon.Monitor) func(network, address string, c syscall.RawConn) error { return func(network, address string, c syscall.RawConn) error { return controlLogf(logf, netMon, network, address, c) @@ -98,30 +87,17 @@ func shouldBindToDefaultInterface(logf logger.Logf, address string) bool { } if coderSoftIsolation.Load() { - sockAddr, err := getSockAddr(address) + addr, err := getAddr(address) if err != nil { - logf("[unexpected] netns: Coder soft isolation: error getting sockaddr for %q, binding to default: %v", address, err) + logf("[unexpected] netns: Coder soft isolation: error getting addr for %q, binding to default: %v", address, err) return true } - if sockAddr == nil { - // Unspecified addresses should not be bound to any interface. + if !addr.IsValid() || addr.IsUnspecified() { + // Invalid or unspecified addresses should not be bound to any + // interface. return false } - - // Ask Windows to find the best interface for this address by consulting - // the routing table. - // - // On macOS this value gets cached, but on Windows we don't need to - // because this API is very fast and doesn't require opening an AF_ROUTE - // socket. - var idx uint32 - err = getBestInterface(sockAddr, &idx) - if err != nil { - logf("[unexpected] netns: Coder soft isolation: error getting best interface, binding to default: %v", err) - return true - } - - if isInterfaceCoderInterface(int(idx)) { + if tsaddr.IsCoderIP(addr) { logf("[unexpected] netns: Coder soft isolation: detected socket destined for Coder interface, binding to default") return true } @@ -187,47 +163,31 @@ func nativeToBigEndian(i uint32) uint32 { return bits.ReverseBytes32(i) } -// getSockAddr returns the Windows sockaddr for the given address, or nil if -// the address is not specified. -func getSockAddr(address string) (windows.Sockaddr, error) { - host, port, err := net.SplitHostPort(address) +// getAddr returns the netip.Addr for the given address, or an invalid address +// if the address is not specified. Use addr.IsValid() to check for this. +func getAddr(address string) (netip.Addr, error) { + host, _, err := net.SplitHostPort(address) if err != nil { - return nil, fmt.Errorf("invalid address %q: %w", address, err) + return netip.Addr{}, fmt.Errorf("invalid address %q: %w", address, err) } if host == "" { // netip.ParseAddr("") will fail - return nil, nil + return netip.Addr{}, nil } addr, err := netip.ParseAddr(host) if err != nil { - return nil, fmt.Errorf("invalid address %q: %w", address, err) + return netip.Addr{}, fmt.Errorf("invalid address %q: %w", address, err) } if addr.Zone() != "" { // Addresses with zones *can* be represented as a Sockaddr with extra // effort, but we don't use or support them currently. - return nil, fmt.Errorf("invalid address %q, has zone: %w", address, err) + return netip.Addr{}, fmt.Errorf("invalid address %q, has zone: %w", address, err) } if addr.IsUnspecified() { // This covers the cases of 0.0.0.0 and [::]. - return nil, nil - } - - portInt, err := strconv.ParseUint(port, 10, 16) - if err != nil { - return nil, fmt.Errorf("invalid port %q: %w", port, err) + return netip.Addr{}, nil } - if addr.Is4() { - return &windows.SockaddrInet4{ - Port: int(portInt), // nolint:gosec // portInt is always in range - Addr: addr.As4(), - }, nil - } else if addr.Is6() { - return &windows.SockaddrInet6{ - Port: int(portInt), // nolint:gosec // portInt is always in range - Addr: addr.As16(), - }, nil - } - return nil, fmt.Errorf("invalid address %q, is not IPv4 or IPv6: %w", address, err) + return addr, nil } diff --git a/net/netns/netns_windows_test.go b/net/netns/netns_windows_test.go index f11fb0a2861a4..ab4eb110e9270 100644 --- a/net/netns/netns_windows_test.go +++ b/net/netns/netns_windows_test.go @@ -4,10 +4,10 @@ package netns import ( - "strconv" + "fmt" "testing" - "golang.org/x/sys/windows" + "tailscale.com/net/tsaddr" ) func TestShouldBindToDefaultInterface(t *testing.T) { @@ -34,59 +34,44 @@ func TestShouldBindToDefaultInterface(t *testing.T) { t.Run("CoderSoftIsolation", func(t *testing.T) { SetCoderSoftIsolation(true) - getBestInterface = func(addr windows.Sockaddr, idx *uint32) error { - *idx = 1 - return nil - } t.Cleanup(func() { SetCoderSoftIsolation(false) - getBestInterface = windows.GetBestInterfaceEx }) tests := []struct { - address string - isCoderInterface bool - want bool + address string + want bool }{ - // isCoderInterface shouldn't even matter for localhost since it has - // a special exemption. - {"127.0.0.1:0", false, false}, - {"127.0.0.1:0", true, false}, - {"127.0.0.1:1234", false, false}, - {"127.0.0.1:1234", true, false}, - - {"1.2.3.4:0", false, false}, - {"1.2.3.4:0", true, true}, - {"1.2.3.4:1234", false, false}, - {"1.2.3.4:1234", true, true}, + // localhost should still not bind to any interface. + {"127.0.0.1:0", false}, + {"127.0.0.1:0", false}, + {"127.0.0.1:1234", false}, + {"127.0.0.1:1234", false}, // Unspecified addresses should not be bound to any interface. - {":1234", false, false}, - {":1234", true, false}, - {"0.0.0.0:1234", false, false}, - {"0.0.0.0:1234", true, false}, - {"[::]:1234", false, false}, - {"[::]:1234", true, false}, + {":1234", false}, + {":1234", false}, + {"0.0.0.0:1234", false}, + {"0.0.0.0:1234", false}, + {"[::]:1234", false}, + {"[::]:1234", false}, // Special cases should always bind to default: - {"[::%eth0]:1234", false, true}, // zones are not supported - {"1.2.3.4:", false, true}, // port is empty - {"1.2.3.4:a", false, true}, // port is not a number - {"1.2.3.4:-1", false, true}, // port is negative - {"1.2.3.4:65536", false, true}, // port is too large + {"[::%eth0]:1234", true}, // zones are not supported + {"a:1234", true}, // not an IP + + // Coder IPs should bind to default. + {fmt.Sprintf("[%s]:8080", tsaddr.CoderServiceIPv6()), true}, + {fmt.Sprintf("[%s]:8080", tsaddr.CoderV6Range().Addr().Next()), true}, + // Non-Coder IPs should not bind to default. + {fmt.Sprintf("[%s]:8080", tsaddr.TailscaleServiceIPv6()), false}, + {fmt.Sprintf("%s:8080", tsaddr.TailscaleServiceIP()), false}, + {"1.2.3.4:8080", false}, } for _, test := range tests { - name := test.address + " (isCoderInterface=" + strconv.FormatBool(test.isCoderInterface) + ")" - t.Run(name, func(t *testing.T) { - isInterfaceCoderInterface = func(_ int) bool { - return test.isCoderInterface - } - defer func() { - isInterfaceCoderInterface = isInterfaceCoderInterfaceDefault - }() - + t.Run(test.address, func(t *testing.T) { got := shouldBindToDefaultInterface(t.Logf, test.address) if got != test.want { t.Errorf("want %v, got %v", test.want, got)