diff --git a/tailnet/conn.go b/tailnet/conn.go index 1217bdeb6f0f7..be098d16085c4 100644 --- a/tailnet/conn.go +++ b/tailnet/conn.go @@ -22,6 +22,7 @@ import ( "tailscale.com/envknob" "tailscale.com/ipn/ipnstate" "tailscale.com/net/connstats" + "tailscale.com/net/dns" "tailscale.com/net/netmon" "tailscale.com/net/netns" "tailscale.com/net/tsdial" @@ -106,6 +107,9 @@ type Options struct { ClientType proto.TelemetryEvent_ClientType // TelemetrySink is optional. TelemetrySink TelemetrySink + // DNSConfigurator is optional, and is passed to the underlying wireguard + // engine. + DNSConfigurator dns.OSConfigurator } // TelemetrySink allows tailnet.Conn to send network telemetry to the Coder @@ -178,6 +182,7 @@ func NewConn(options *Options) (conn *Conn, err error) { Dialer: dialer, ListenPort: options.ListenPort, SetSubsystem: sys.Set, + DNS: options.DNSConfigurator, }) if err != nil { return nil, xerrors.Errorf("create wgengine: %w", err) diff --git a/vpn/dns.go b/vpn/dns.go new file mode 100644 index 0000000000000..7e4ea5bbd29a0 --- /dev/null +++ b/vpn/dns.go @@ -0,0 +1,58 @@ +package vpn + +import "tailscale.com/net/dns" + +func NewDNSConfigurator(t *Tunnel) dns.OSConfigurator { + return &dnsManager{tunnel: t} +} + +type dnsManager struct { + tunnel *Tunnel +} + +func (v *dnsManager) SetDNS(cfg dns.OSConfig) error { + settings := convertDNSConfig(cfg) + return v.tunnel.ApplyNetworkSettings(v.tunnel.ctx, &NetworkSettingsRequest{ + DnsSettings: settings, + }) +} + +func (*dnsManager) GetBaseConfig() (dns.OSConfig, error) { + // Tailscale calls this function to blend the OS's DNS configuration with + // it's own, so this is only called if `SupportsSplitDNS` returns false. + return dns.OSConfig{}, dns.ErrGetBaseConfigNotSupported +} + +func (*dnsManager) SupportsSplitDNS() bool { + // macOS & Windows 10+ support split DNS, so we'll assume all CoderVPN + // clients do too. + return true +} + +// Close implements dns.OSConfigurator. +func (*dnsManager) Close() error { + // There's no cleanup that we need to initiate from within the dylib. + return nil +} + +func convertDNSConfig(cfg dns.OSConfig) *NetworkSettingsRequest_DNSSettings { + servers := make([]string, 0, len(cfg.Nameservers)) + for _, ns := range cfg.Nameservers { + servers = append(servers, ns.String()) + } + searchDomains := make([]string, 0, len(cfg.SearchDomains)) + for _, domain := range cfg.SearchDomains { + searchDomains = append(searchDomains, domain.WithoutTrailingDot()) + } + matchDomains := make([]string, 0, len(cfg.MatchDomains)) + for _, domain := range cfg.MatchDomains { + matchDomains = append(matchDomains, domain.WithoutTrailingDot()) + } + return &NetworkSettingsRequest_DNSSettings{ + Servers: servers, + SearchDomains: searchDomains, + DomainName: "coder", + MatchDomains: matchDomains, + MatchDomainsNoSearch: false, + } +} diff --git a/vpn/dns_internal_test.go b/vpn/dns_internal_test.go new file mode 100644 index 0000000000000..a4fa61aec1d66 --- /dev/null +++ b/vpn/dns_internal_test.go @@ -0,0 +1,73 @@ +package vpn + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/require" + "tailscale.com/net/dns" + "tailscale.com/util/dnsname" +) + +func TestConvertDNSConfig(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input dns.OSConfig + expected *NetworkSettingsRequest_DNSSettings + }{ + { + name: "Basic", + input: dns.OSConfig{ + Nameservers: []netip.Addr{ + netip.MustParseAddr("1.1.1.1"), + netip.MustParseAddr("8.8.8.8"), + }, + SearchDomains: []dnsname.FQDN{ + "example.com.", + "test.local.", + }, + MatchDomains: []dnsname.FQDN{ + "internal.domain.", + }, + }, + expected: &NetworkSettingsRequest_DNSSettings{ + Servers: []string{"1.1.1.1", "8.8.8.8"}, + SearchDomains: []string{"example.com", "test.local"}, + DomainName: "coder", + MatchDomains: []string{"internal.domain"}, + MatchDomainsNoSearch: false, + }, + }, + { + name: "Empty", + input: dns.OSConfig{ + Nameservers: []netip.Addr{}, + SearchDomains: []dnsname.FQDN{}, + MatchDomains: []dnsname.FQDN{}, + }, + expected: &NetworkSettingsRequest_DNSSettings{ + Servers: []string{}, + SearchDomains: []string{}, + DomainName: "coder", + MatchDomains: []string{}, + MatchDomainsNoSearch: false, + }, + }, + } + + //nolint:paralleltest // outdated rule + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + result := convertDNSConfig(tt.input) + require.Equal(t, tt.expected.Servers, result.Servers) + require.Equal(t, tt.expected.SearchDomains, result.SearchDomains) + require.Equal(t, tt.expected.DomainName, result.DomainName) + require.Equal(t, tt.expected.MatchDomains, result.MatchDomains) + require.Equal(t, tt.expected.MatchDomainsNoSearch, result.MatchDomainsNoSearch) + }) + } +}