Skip to content

Commit c0746cf

Browse files
Maisem Ali (maisem)
authored and committed
net/tsdial: add SystemDial as a wrapper on netns.Dial
The connections returned from SystemDial are automatically closed when
there is a major link change.

Also plumb through the dialer to the noise client so that connections
are auto-reset when moving from cellular to WiFi etc.

Updates tailscale#3363

Signed-off-by: Maisem Ali <maisem@tailscale.com>
(cherry picked from commit 5a1ef1b)
1 parent 5ff23cb commit c0746cf

File tree

9 files changed

+138
-20
lines changed

9 files changed

+138
-20
lines changed

cmd/tailscaled/tailscaled.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,7 @@ func run() error {
332332
socksListener, httpProxyListener := mustStartProxyListeners(args.socksAddr, args.httpProxyAddr)
333333

334334
dialer := new(tsdial.Dialer) // mutated below (before used)
335+
dialer.Logf = logf
335336
e, useNetstack, err := createEngine(logf, linkMon, dialer)
336337
if err != nil {
337338
return fmt.Errorf("createEngine: %w", err)
@@ -394,6 +395,7 @@ func run() error {
394395
// want to keep running.
395396
signal.Ignore(syscall.SIGPIPE)
396397
go func() {
398+
defer dialer.Close()
397399
select {
398400
case s := <-interrupt:
399401
logf("tailscaled got signal %v; shutting down", s)

control/controlclient/direct.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ import (
3838
"tailscale.com/net/dnscache"
3939
"tailscale.com/net/dnsfallback"
4040
"tailscale.com/net/interfaces"
41-
"tailscale.com/net/netns"
4241
"tailscale.com/net/netutil"
4342
"tailscale.com/net/tlsdial"
43+
"tailscale.com/net/tsdial"
4444
"tailscale.com/net/tshttpproxy"
4545
"tailscale.com/tailcfg"
4646
"tailscale.com/types/key"
@@ -57,7 +57,8 @@ import (
5757
// Direct is the client that connects to a tailcontrol server for a node.
5858
type Direct struct {
5959
httpc *http.Client // HTTP client used to talk to tailcontrol
60-
serverURL string // URL of the tailcontrol server
60+
dialer *tsdial.Dialer
61+
serverURL string // URL of the tailcontrol server
6162
timeNow func() time.Time
6263
lastPrintMap time.Time
6364
newDecompressor func() (Decompressor, error)
@@ -106,6 +107,7 @@ type Options struct {
106107
DebugFlags []string // debug settings to send to control
107108
LinkMonitor *monitor.Mon // optional link monitor
108109
PopBrowserURL func(url string) // optional func to open browser
110+
Dialer *tsdial.Dialer // non-nil
109111

110112
// KeepSharerAndUserSplit controls whether the client
111113
// understands Node.Sharer. If false, the Sharer is mapped to the User.
@@ -170,13 +172,12 @@ func NewDirect(opts Options) (*Direct, error) {
170172
UseLastGood: true,
171173
LookupIPFallback: dnsfallback.Lookup,
172174
}
173-
dialer := netns.NewDialer(opts.Logf)
174175
tr := http.DefaultTransport.(*http.Transport).Clone()
175176
tr.Proxy = tshttpproxy.ProxyFromEnvironment
176177
tshttpproxy.SetTransportGetProxyConnectHeader(tr)
177178
tr.TLSClientConfig = tlsdial.Config(serverURL.Hostname(), tr.TLSClientConfig)
178-
tr.DialContext = dnscache.Dialer(dialer.DialContext, dnsCache)
179-
tr.DialTLSContext = dnscache.TLSDialer(dialer.DialContext, dnsCache, tr.TLSClientConfig)
179+
tr.DialContext = dnscache.Dialer(opts.Dialer.SystemDial, dnsCache)
180+
tr.DialTLSContext = dnscache.TLSDialer(opts.Dialer.SystemDial, dnsCache, tr.TLSClientConfig)
180181
tr.ForceAttemptHTTP2 = true
181182
// Disable implicit gzip compression; the various
182183
// handlers (register, map, set-dns, etc) do their own
@@ -202,6 +203,7 @@ func NewDirect(opts Options) (*Direct, error) {
202203
skipIPForwardingCheck: opts.SkipIPForwardingCheck,
203204
pinger: opts.Pinger,
204205
popBrowser: opts.PopBrowserURL,
206+
dialer: opts.Dialer,
205207
}
206208
if opts.Hostinfo == nil {
207209
c.SetHostinfo(hostinfo.New())
@@ -1278,7 +1280,7 @@ func (c *Direct) getNoiseClient() (*noiseClient, error) {
12781280
return nil, err
12791281
}
12801282

1281-
nc, err = newNoiseClient(k, serverNoiseKey, c.serverURL)
1283+
nc, err = newNoiseClient(k, serverNoiseKey, c.serverURL, c.dialer)
12821284
if err != nil {
12831285
return nil, err
12841286
}

control/controlclient/direct_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"inet.af/netaddr"
1515
"tailscale.com/hostinfo"
1616
"tailscale.com/ipn/ipnstate"
17+
"tailscale.com/net/tsdial"
1718
"tailscale.com/tailcfg"
1819
"tailscale.com/types/key"
1920
)
@@ -30,6 +31,7 @@ func TestNewDirect(t *testing.T) {
3031
GetMachinePrivateKey: func() (key.MachinePrivate, error) {
3132
return k, nil
3233
},
34+
Dialer: new(tsdial.Dialer),
3335
}
3436
c, err := NewDirect(opts)
3537
if err != nil {
@@ -106,6 +108,7 @@ func TestTsmpPing(t *testing.T) {
106108
GetMachinePrivateKey: func() (key.MachinePrivate, error) {
107109
return k, nil
108110
},
111+
Dialer: new(tsdial.Dialer),
109112
}
110113

111114
c, err := NewDirect(opts)

control/controlclient/noise.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"golang.org/x/net/http2"
1919
"tailscale.com/control/controlbase"
2020
"tailscale.com/control/controlhttp"
21+
"tailscale.com/net/tsdial"
2122
"tailscale.com/tailcfg"
2223
"tailscale.com/types/key"
2324
"tailscale.com/util/mak"
@@ -46,6 +47,7 @@ func (c *noiseConn) Close() error {
4647
// the ts2021 protocol.
4748
type noiseClient struct {
4849
*http.Client // HTTP client used to talk to tailcontrol
50+
dialer *tsdial.Dialer
4951
privKey key.MachinePrivate
5052
serverPubKey key.MachinePublic
5153
serverHost string // the host:port part of serverURL
@@ -58,7 +60,7 @@ type noiseClient struct {
5860

5961
// newNoiseClient returns a new noiseClient for the provided server and machine key.
6062
// serverURL is of the form https://<host>:<port> (no trailing slash).
61-
func newNoiseClient(priKey key.MachinePrivate, serverPubKey key.MachinePublic, serverURL string) (*noiseClient, error) {
63+
func newNoiseClient(priKey key.MachinePrivate, serverPubKey key.MachinePublic, serverURL string, dialer *tsdial.Dialer) (*noiseClient, error) {
6264
u, err := url.Parse(serverURL)
6365
if err != nil {
6466
return nil, err
@@ -75,6 +77,7 @@ func newNoiseClient(priKey key.MachinePrivate, serverPubKey key.MachinePublic, s
7577
serverPubKey: serverPubKey,
7678
privKey: priKey,
7779
serverHost: host,
80+
dialer: dialer,
7881
}
7982

8083
// Create the HTTP/2 Transport using a net/http.Transport
@@ -151,7 +154,7 @@ func (nc *noiseClient) dial(_, _ string, _ *tls.Config) (net.Conn, error) {
151154
// thousand version numbers before getting to this point.
152155
panic("capability version is too high to fit in the wire protocol")
153156
}
154-
conn, err := controlhttp.Dial(ctx, nc.serverHost, nc.privKey, nc.serverPubKey, uint16(tailcfg.CurrentCapabilityVersion))
157+
conn, err := controlhttp.Dial(ctx, nc.serverHost, nc.privKey, nc.serverPubKey, uint16(tailcfg.CurrentCapabilityVersion), nc.dialer.SystemDial)
155158
if err != nil {
156159
return nil, err
157160
}

control/controlhttp/client.go

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import (
2525
"errors"
2626
"fmt"
2727
"io"
28-
"log"
2928
"net"
3029
"net/http"
3130
"net/http/httptrace"
@@ -35,7 +34,6 @@ import (
3534
"tailscale.com/control/controlbase"
3635
"tailscale.com/net/dnscache"
3736
"tailscale.com/net/dnsfallback"
38-
"tailscale.com/net/netns"
3937
"tailscale.com/net/netutil"
4038
"tailscale.com/net/tlsdial"
4139
"tailscale.com/net/tshttpproxy"
@@ -66,7 +64,7 @@ const (
6664
//
6765
// The provided ctx is only used for the initial connection, until
6866
// Dial returns. It does not affect the connection once established.
69-
func Dial(ctx context.Context, addr string, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*controlbase.Conn, error) {
67+
func Dial(ctx context.Context, addr string, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16, dialer dnscache.DialContextFunc) (*controlbase.Conn, error) {
7068
host, port, err := net.SplitHostPort(addr)
7169
if err != nil {
7270
return nil, err
@@ -80,6 +78,7 @@ func Dial(ctx context.Context, addr string, machineKey key.MachinePrivate, contr
8078
controlKey: controlKey,
8179
version: protocolVersion,
8280
proxyFunc: tshttpproxy.ProxyFromEnvironment,
81+
dialer: dialer,
8382
}
8483
return a.dial()
8584
}
@@ -93,6 +92,7 @@ type dialParams struct {
9392
controlKey key.MachinePublic
9493
version uint16
9594
proxyFunc func(*http.Request) (*url.URL, error) // or nil
95+
dialer dnscache.DialContextFunc
9696

9797
// For tests only
9898
insecureTLS bool
@@ -196,12 +196,11 @@ func (a *dialParams) tryURL(ctx context.Context, u *url.URL, init []byte) (net.C
196196
LookupIPFallback: dnsfallback.Lookup,
197197
UseLastGood: true,
198198
}
199-
dialer := netns.NewDialer(log.Printf)
200199
tr := http.DefaultTransport.(*http.Transport).Clone()
201200
defer tr.CloseIdleConnections()
202201
tr.Proxy = a.proxyFunc
203202
tshttpproxy.SetTransportGetProxyConnectHeader(tr)
204-
tr.DialContext = dnscache.Dialer(dialer.DialContext, dns)
203+
tr.DialContext = dnscache.Dialer(a.dialer, dns)
205204
// Disable HTTP2, since h2 can't do protocol switching.
206205
tr.TLSClientConfig.NextProtos = []string{}
207206
tr.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{}
@@ -210,7 +209,7 @@ func (a *dialParams) tryURL(ctx context.Context, u *url.URL, init []byte) (net.C
210209
tr.TLSClientConfig.InsecureSkipVerify = true
211210
tr.TLSClientConfig.VerifyConnection = nil
212211
}
213-
tr.DialTLSContext = dnscache.TLSDialer(dialer.DialContext, dns, tr.TLSClientConfig)
212+
tr.DialTLSContext = dnscache.TLSDialer(a.dialer, dns, tr.TLSClientConfig)
214213
tr.DisableCompression = true
215214

216215
// (mis)use httptrace to extract the underlying net.Conn from the

control/controlhttp/http_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020

2121
"tailscale.com/control/controlbase"
2222
"tailscale.com/net/socks5"
23+
"tailscale.com/net/tsdial"
2324
"tailscale.com/types/key"
2425
)
2526

@@ -155,6 +156,7 @@ func testControlHTTP(t *testing.T, proxy proxy) {
155156
controlKey: server.Public(),
156157
version: testProtocolVersion,
157158
insecureTLS: true,
159+
dialer: new(tsdial.Dialer).SystemDial,
158160
}
159161

160162
if proxy != nil {

ipn/ipnlocal/local.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,6 +1034,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
10341034
LinkMonitor: b.e.GetLinkMonitor(),
10351035
Pinger: b.e,
10361036
PopBrowserURL: b.tellClientToBrowseToURL,
1037+
Dialer: b.Dialer(),
10371038

10381039
// Don't warn about broken Linux IP forwarding when
10391040
// netstack is being used.

net/tsdial/tsdial.go

Lines changed: 111 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,12 @@ import (
2020

2121
"inet.af/netaddr"
2222
"tailscale.com/net/dnscache"
23+
"tailscale.com/net/interfaces"
2324
"tailscale.com/net/netknob"
25+
"tailscale.com/net/netns"
26+
"tailscale.com/types/logger"
2427
"tailscale.com/types/netmap"
28+
"tailscale.com/util/mak"
2529
"tailscale.com/wgengine/monitor"
2630
)
2731

@@ -30,6 +34,7 @@ import (
3034
// (TUN, netstack), the OS network sandboxing style (macOS/iOS
3135
// Extension, none), user-selected route acceptance prefs, etc.
3236
type Dialer struct {
37+
Logf logger.Logf
3338
// UseNetstackForIP if non-nil is whether NetstackDialTCP (if
3439
// it's non-nil) should be used to dial the provided IP.
3540
UseNetstackForIP func(netaddr.IP) bool
@@ -46,12 +51,33 @@ type Dialer struct {
4651
peerDialerOnce sync.Once
4752
peerDialer *net.Dialer
4853

49-
mu sync.Mutex
50-
dns dnsMap
51-
tunName string // tun device name
52-
linkMon *monitor.Mon
53-
exitDNSDoHBase string // non-empty if DoH-proxying exit node in use; base URL+path (without '?')
54-
dnsCache *dnscache.MessageCache // nil until first non-empty SetExitDNSDoH
54+
netnsDialerOnce sync.Once
55+
netnsDialer netns.Dialer
56+
57+
mu sync.Mutex
58+
closed bool
59+
dns dnsMap
60+
tunName string // tun device name
61+
linkMon *monitor.Mon
62+
linkMonUnregister func()
63+
exitDNSDoHBase string // non-empty if DoH-proxying exit node in use; base URL+path (without '?')
64+
dnsCache *dnscache.MessageCache // nil until first non-empty SetExitDNSDoH
65+
nextSysConnID int
66+
activeSysConns map[int]net.Conn // active connections not yet closed
67+
}
68+
69+
// sysConn wraps a net.Conn that was created using d.SystemDial.
70+
// It exists to track which connections are still open, and should be
71+
// closed on major link changes.
72+
type sysConn struct {
73+
net.Conn
74+
id int
75+
d *Dialer
76+
}
77+
78+
func (c sysConn) Close() error {
79+
c.d.closeSysConn(c.id)
80+
return nil
5581
}
5682

5783
// SetTUNName sets the name of the tun device in use ("tailscale0", "utun6",
@@ -91,10 +117,53 @@ func (d *Dialer) SetExitDNSDoH(doh string) {
91117
}
92118
}
93119

120+
func (d *Dialer) Close() error {
121+
d.mu.Lock()
122+
defer d.mu.Unlock()
123+
d.closed = true
124+
if d.linkMonUnregister != nil {
125+
d.linkMonUnregister()
126+
d.linkMonUnregister = nil
127+
}
128+
for _, c := range d.activeSysConns {
129+
c.Close()
130+
}
131+
d.activeSysConns = nil
132+
return nil
133+
}
134+
94135
func (d *Dialer) SetLinkMonitor(mon *monitor.Mon) {
95136
d.mu.Lock()
96137
defer d.mu.Unlock()
138+
if d.linkMonUnregister != nil {
139+
go d.linkMonUnregister()
140+
d.linkMonUnregister = nil
141+
}
97142
d.linkMon = mon
143+
d.linkMonUnregister = d.linkMon.RegisterChangeCallback(d.linkChanged)
144+
}
145+
146+
func (d *Dialer) linkChanged(major bool, state *interfaces.State) {
147+
if !major {
148+
return
149+
}
150+
d.mu.Lock()
151+
defer d.mu.Unlock()
152+
for id, c := range d.activeSysConns {
153+
go c.Close()
154+
delete(d.activeSysConns, id)
155+
}
156+
}
157+
158+
func (d *Dialer) closeSysConn(id int) {
159+
d.mu.Lock()
160+
defer d.mu.Unlock()
161+
c, ok := d.activeSysConns[id]
162+
if !ok {
163+
return
164+
}
165+
delete(d.activeSysConns, id)
166+
go c.Close() // ignore the error
98167
}
99168

100169
func (d *Dialer) interfaceIndexLocked(ifName string) (index int, ok bool) {
@@ -197,6 +266,42 @@ func ipNetOfNetwork(n string) string {
197266
return "ip"
198267
}
199268

269+
// SystemDial connects to the provided network address without going over
270+
// Tailscale. It prefers going over the default interface and closes existing
271+
// connections if the default interface changes. It is used to connect to
272+
// Control and (in the future, as of 2022-04-27) DERPs.
273+
func (d *Dialer) SystemDial(ctx context.Context, network, addr string) (net.Conn, error) {
274+
d.mu.Lock()
275+
closed := d.closed
276+
d.mu.Unlock()
277+
if closed {
278+
return nil, net.ErrClosed
279+
}
280+
281+
d.netnsDialerOnce.Do(func() {
282+
logf := d.Logf
283+
if logf == nil {
284+
logf = logger.Discard
285+
}
286+
d.netnsDialer = netns.NewDialer(logf)
287+
})
288+
c, err := d.netnsDialer.DialContext(ctx, network, addr)
289+
if err != nil {
290+
return nil, err
291+
}
292+
d.mu.Lock()
293+
defer d.mu.Unlock()
294+
id := d.nextSysConnID
295+
d.nextSysConnID++
296+
mak.Set(&d.activeSysConns, id, c)
297+
298+
return sysConn{
299+
id: id,
300+
d: d,
301+
Conn: c,
302+
}, nil
303+
}
304+
200305
// UserDial connects to the provided network address as if a user were initiating the dial.
201306
// (e.g. from a SOCKS or HTTP outbound proxy)
202307
func (d *Dialer) UserDial(ctx context.Context, network, addr string) (net.Conn, error) {

tsnet/tsnet.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ func (s *Server) Close() error {
105105
s.shutdownCancel()
106106
s.lb.Shutdown()
107107
s.linkMon.Close()
108+
s.dialer.Close()
108109
s.localAPIListener.Close()
109110

110111
s.mu.Lock()

0 commit comments

Comments
 (0)