diff --git a/tailnet/conn.go b/tailnet/conn.go index be098d16085c4..7ca123c76a5aa 100644 --- a/tailnet/conn.go +++ b/tailnet/conn.go @@ -110,6 +110,8 @@ type Options struct { // DNSConfigurator is optional, and is passed to the underlying wireguard // engine. DNSConfigurator dns.OSConfigurator + // Router is optional, and is passed to the underlying wireguard engine. + Router router.Router } // TelemetrySink allows tailnet.Conn to send network telemetry to the Coder @@ -183,6 +185,7 @@ func NewConn(options *Options) (conn *Conn, err error) { ListenPort: options.ListenPort, SetSubsystem: sys.Set, DNS: options.DNSConfigurator, + Router: options.Router, }) if err != nil { return nil, xerrors.Errorf("create wgengine: %w", err) diff --git a/vpn/router.go b/vpn/router.go new file mode 100644 index 0000000000000..45998ac9540a1 --- /dev/null +++ b/vpn/router.go @@ -0,0 +1,105 @@ +package vpn + +import ( + "net" + "net/netip" + + "tailscale.com/wgengine/router" +) + +func NewRouter(t *Tunnel) router.Router { + return &vpnRouter{tunnel: t} +} + +type vpnRouter struct { + tunnel *Tunnel +} + +func (*vpnRouter) Up() error { + // On macOS, the Desktop app will handle turning the VPN on and off. + // On Windows, this is a no-op. + return nil +} + +func (v *vpnRouter) Set(cfg *router.Config) error { + req := convertRouterConfig(cfg) + return v.tunnel.ApplyNetworkSettings(v.tunnel.ctx, req) +} + +func (*vpnRouter) Close() error { + // There's no cleanup that we need to initiate from within the dylib. + return nil +} + +func convertRouterConfig(cfg *router.Config) *NetworkSettingsRequest { + v4LocalAddrs := make([]string, 0) + v6LocalAddrs := make([]string, 0) + for _, addrs := range cfg.LocalAddrs { + if addrs.Addr().Is4() { + v4LocalAddrs = append(v4LocalAddrs, addrs.String()) + } else if addrs.Addr().Is6() { + v6LocalAddrs = append(v6LocalAddrs, addrs.String()) + } else { + continue + } + } + v4Routes := make([]*NetworkSettingsRequest_IPv4Settings_IPv4Route, 0) + v6Routes := make([]*NetworkSettingsRequest_IPv6Settings_IPv6Route, 0) + for _, route := range cfg.Routes { + if route.Addr().Is4() { + v4Routes = append(v4Routes, convertToIPV4Route(route)) + } else if route.Addr().Is6() { + v6Routes = append(v6Routes, convertToIPV6Route(route)) + } else { + continue + } + } + v4ExcludedRoutes := make([]*NetworkSettingsRequest_IPv4Settings_IPv4Route, 0) + v6ExcludedRoutes := make([]*NetworkSettingsRequest_IPv6Settings_IPv6Route, 0) + for _, route := range cfg.LocalRoutes { + if route.Addr().Is4() { + v4ExcludedRoutes = append(v4ExcludedRoutes, convertToIPV4Route(route)) + } else if route.Addr().Is6() { + v6ExcludedRoutes = append(v6ExcludedRoutes, convertToIPV6Route(route)) + } else { + continue + } + } + + return &NetworkSettingsRequest{ + Mtu: uint32(cfg.NewMTU), + Ipv4Settings: &NetworkSettingsRequest_IPv4Settings{ + Addrs: v4LocalAddrs, + IncludedRoutes: v4Routes, + ExcludedRoutes: v4ExcludedRoutes, + }, + Ipv6Settings: &NetworkSettingsRequest_IPv6Settings{ + Addrs: v6LocalAddrs, + IncludedRoutes: v6Routes, + ExcludedRoutes: v6ExcludedRoutes, + }, + TunnelOverheadBytes: 0, // N/A + TunnelRemoteAddress: "", // N/A + } +} + +func convertToIPV4Route(route netip.Prefix) *NetworkSettingsRequest_IPv4Settings_IPv4Route { + return &NetworkSettingsRequest_IPv4Settings_IPv4Route{ + Destination: route.Addr().String(), + Mask: prefixToSubnetMask(route), + Router: "", // N/A + } +} + +func convertToIPV6Route(route netip.Prefix) *NetworkSettingsRequest_IPv6Settings_IPv6Route { + return &NetworkSettingsRequest_IPv6Settings_IPv6Route{ + Destination: route.Addr().String(), + PrefixLength: uint32(route.Bits()), + Router: "", // N/A + } +} + +func prefixToSubnetMask(prefix netip.Prefix) string { + maskBytes := net.CIDRMask(prefix.Masked().Bits(), net.IPv4len*8) + return net.IP(maskBytes).String() +} diff --git a/vpn/router_internal_test.go b/vpn/router_internal_test.go new file mode 100644 index 0000000000000..777b53940e533 --- /dev/null +++ b/vpn/router_internal_test.go @@ -0,0 +1,74 @@ +package vpn + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/require" + "tailscale.com/wgengine/router" +) + +func TestConvertRouterConfig(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *router.Config + expected *NetworkSettingsRequest + }{ + { + name: "IPv4 and IPv6 configuration", + cfg: &router.Config{ + LocalAddrs: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32"), netip.MustParsePrefix("fd7a:115c:a1e0::1/128")}, + Routes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24"), netip.MustParsePrefix("fd00::/64")}, + LocalRoutes: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8"), netip.MustParsePrefix("2001:db8::/32")}, + NewMTU: 1500, + }, + expected: &NetworkSettingsRequest{ + Mtu: 1500, + Ipv4Settings: &NetworkSettingsRequest_IPv4Settings{ + Addrs: []string{"100.64.0.1/32"}, + IncludedRoutes: []*NetworkSettingsRequest_IPv4Settings_IPv4Route{ + {Destination: "192.168.0.0", Mask: "255.255.255.0", Router: ""}, + }, + ExcludedRoutes: []*NetworkSettingsRequest_IPv4Settings_IPv4Route{ + {Destination: "10.0.0.0", Mask: "255.0.0.0", Router: ""}, + }, + }, + Ipv6Settings: &NetworkSettingsRequest_IPv6Settings{ + Addrs: []string{"fd7a:115c:a1e0::1/128"}, + IncludedRoutes: []*NetworkSettingsRequest_IPv6Settings_IPv6Route{ + {Destination: "fd00::", PrefixLength: 64, Router: ""}, + }, + ExcludedRoutes: []*NetworkSettingsRequest_IPv6Settings_IPv6Route{ + {Destination: "2001:db8::", PrefixLength: 32, Router: ""}, + }, + }, + }, + }, + { + name: "Empty", + cfg: &router.Config{}, + expected: &NetworkSettingsRequest{ + Ipv4Settings: &NetworkSettingsRequest_IPv4Settings{ + Addrs: []string{}, + IncludedRoutes: []*NetworkSettingsRequest_IPv4Settings_IPv4Route{}, + ExcludedRoutes: []*NetworkSettingsRequest_IPv4Settings_IPv4Route{}, + }, + Ipv6Settings: &NetworkSettingsRequest_IPv6Settings{ + Addrs: []string{}, + IncludedRoutes: []*NetworkSettingsRequest_IPv6Settings_IPv6Route{}, + ExcludedRoutes: []*NetworkSettingsRequest_IPv6Settings_IPv6Route{}, + }, + }, + }, + } + //nolint:paralleltest // outdated rule + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := convertRouterConfig(tt.cfg) + require.Equal(t, tt.expected, result) + }) + } +}