Skip to content

Commit fb16ae7

Browse files
authored
feat: swap over to websockets if initial derp exchange fails (#8)
* feat: swap over to websockets if initial derp exchange fails * Allow insecure WebSocket connections if the DERP map provides * Allow a nil node for the TLS config * Add WebSockets support to Mac and Windows * Use the first node * Swap to use a callback * Use a pointer to swap the callback * Only use TLS Config if a node exists * Move HTTPClient outside of JS compilation
1 parent 446fc10 commit fb16ae7

File tree

4 files changed

+115
-11
lines changed

4 files changed

+115
-11
lines changed

derp/derphttp/derphttp_client.go

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"runtime"
2626
"strings"
2727
"sync"
28+
"sync/atomic"
2829
"time"
2930

3031
"go4.org/mem"
@@ -52,6 +53,9 @@ type Client struct {
5253
MeshKey string // optional; for trusted clients
5354
IsProber bool // optional; for probers to optional declare themselves as such
5455

56+
forcedWebsocket atomic.Bool // optional; set if the server has failed to upgrade the connection on the DERP server
57+
forcedWebsocketCallback atomic.Pointer[func(int, string)]
58+
5559
privateKey key.NodePrivate
5660
logf logger.Logf
5761
dialer func(ctx context.Context, network, addr string) (net.Conn, error)
@@ -191,13 +195,27 @@ func (c *Client) tlsServerName(node *tailcfg.DERPNode) string {
191195
if c.url != nil {
192196
return c.url.Host
193197
}
198+
if node == nil {
199+
return ""
200+
}
194201
return node.HostName
195202
}
196203

197204
func (c *Client) urlString(node *tailcfg.DERPNode) string {
198205
if c.url != nil {
199206
return c.url.String()
200207
}
208+
if node.HostName == "" {
209+
url := &url.URL{
210+
Scheme: "https",
211+
Host: fmt.Sprintf("%s:%d", node.IPv4, node.DERPPort),
212+
Path: "/derp",
213+
}
214+
if node.ForceHTTP {
215+
url.Scheme = "http"
216+
}
217+
return url.String()
218+
}
201219
return fmt.Sprintf("https://%s/derp", node.HostName)
202220
}
203221

@@ -228,12 +246,15 @@ func (c *Client) preferIPv6() bool {
228246
}
229247

230248
// dialWebsocketFunc is non-nil (set by websocket.go's init) when compiled in.
231-
var dialWebsocketFunc func(ctx context.Context, urlStr string) (net.Conn, error)
249+
var dialWebsocketFunc func(ctx context.Context, urlStr string, tlsConfig *tls.Config) (net.Conn, error)
232250

233-
func useWebsockets() bool {
251+
func (c *Client) useWebsockets() bool {
234252
if runtime.GOOS == "js" {
235253
return true
236254
}
255+
if c.forcedWebsocket.Load() {
256+
return true
257+
}
237258
if dialWebsocketFunc != nil {
238259
return envknob.Bool("TS_DEBUG_DERP_WS_CLIENT")
239260
}
@@ -293,15 +314,18 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien
293314

294315
var node *tailcfg.DERPNode // nil when using c.url to dial
295316
switch {
296-
case useWebsockets():
317+
case c.useWebsockets():
297318
var urlStr string
319+
var tlsConfig *tls.Config
298320
if c.url != nil {
299321
urlStr = c.url.String()
322+
tlsConfig = c.tlsConfig(nil)
300323
} else {
301324
urlStr = c.urlString(reg.Nodes[0])
325+
tlsConfig = c.tlsConfig(reg.Nodes[0])
302326
}
303327
c.logf("%s: connecting websocket to %v", caller, urlStr)
304-
conn, err := dialWebsocketFunc(ctx, urlStr)
328+
conn, err := dialWebsocketFunc(ctx, urlStr, tlsConfig)
305329
if err != nil {
306330
c.logf("%s: websocket to %v error: %v", caller, urlStr, err)
307331
return nil, 0, err
@@ -435,6 +459,19 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien
435459
if resp.StatusCode != http.StatusSwitchingProtocols {
436460
b, _ := io.ReadAll(resp.Body)
437461
resp.Body.Close()
462+
463+
// Only on StatusCode < 500 in case a gateway timeout
464+
// or network intermittency occurs!
465+
if resp.StatusCode < 500 {
466+
reason := fmt.Sprintf("GET failed with status code %d: %s", resp.StatusCode, b)
467+
c.logf("We'll use WebSockets on the next connection attempt. A proxy could be disallowing the use of 'Upgrade: derp': %s", reason)
468+
c.forcedWebsocket.Store(true)
469+
forcedWebsocketCallback := c.forcedWebsocketCallback.Load()
470+
if forcedWebsocketCallback != nil {
471+
go (*forcedWebsocketCallback)(reg.RegionID, reason)
472+
}
473+
}
474+
438475
return nil, 0, fmt.Errorf("GET failed: %v: %s", err, b)
439476
}
440477
}
@@ -472,6 +509,12 @@ func (c *Client) SetURLDialer(dialer func(ctx context.Context, network, addr str
472509
c.dialer = dialer
473510
}
474511

512+
// SetForcedWebsocketCallback is a callback that is called when the client
513+
// decides to force WebSockets on the next connection attempt.
514+
func (c *Client) SetForcedWebsocketCallback(callback func(region int, reason string)) {
515+
c.forcedWebsocketCallback.Store(&callback)
516+
}
517+
475518
func (c *Client) dialURL(ctx context.Context) (net.Conn, error) {
476519
host := c.url.Hostname()
477520
if c.dialer != nil {
@@ -525,7 +568,7 @@ func (c *Client) dialRegion(ctx context.Context, reg *tailcfg.DERPRegion) (net.C
525568
return nil, nil, firstErr
526569
}
527570

528-
func (c *Client) tlsClient(nc net.Conn, node *tailcfg.DERPNode) *tls.Conn {
571+
func (c *Client) tlsConfig(node *tailcfg.DERPNode) *tls.Config {
529572
tlsConf := tlsdial.Config(c.tlsServerName(node), c.TLSConfig)
530573
if node != nil {
531574
if node.InsecureForTests {
@@ -536,7 +579,11 @@ func (c *Client) tlsClient(nc net.Conn, node *tailcfg.DERPNode) *tls.Conn {
536579
tlsdial.SetConfigExpectedCert(tlsConf, node.CertName)
537580
}
538581
}
539-
return tls.Client(nc, tlsConf)
582+
return tlsConf
583+
}
584+
585+
func (c *Client) tlsClient(nc net.Conn, node *tailcfg.DERPNode) *tls.Conn {
586+
return tls.Client(nc, c.tlsConfig(node))
540587
}
541588

542589
// DialRegionTLS returns a TLS connection to a DERP node in the given region.

derp/derphttp/websocket.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
1-
// Copyright (c) Tailscale Inc & AUTHORS
2-
// SPDX-License-Identifier: BSD-3-Clause
3-
4-
//go:build linux || js
1+
//go:build !js
52

63
package derphttp
74

85
import (
96
"context"
7+
"crypto/tls"
108
"log"
119
"net"
10+
"net/http"
1211

1312
"nhooyr.io/websocket"
1413
"tailscale.com/net/wsconn"
@@ -18,9 +17,14 @@ func init() {
1817
dialWebsocketFunc = dialWebsocket
1918
}
2019

21-
func dialWebsocket(ctx context.Context, urlStr string) (net.Conn, error) {
20+
func dialWebsocket(ctx context.Context, urlStr string, tlsConfig *tls.Config) (net.Conn, error) {
2221
c, res, err := websocket.Dial(ctx, urlStr, &websocket.DialOptions{
2322
Subprotocols: []string{"derp"},
23+
HTTPClient: &http.Client{
24+
Transport: &http.Transport{
25+
TLSClientConfig: tlsConfig,
26+
},
27+
},
2428
})
2529
if err != nil {
2630
log.Printf("websocket Dial: %v, %+v", err, res)

derp/derphttp/websocket_js.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright (c) Tailscale Inc & AUTHORS
2+
// SPDX-License-Identifier: BSD-3-Clause
3+
4+
//go:build js
5+
6+
package derphttp
7+
8+
import (
9+
"context"
10+
"crypto/tls"
11+
"log"
12+
"net"
13+
14+
"nhooyr.io/websocket"
15+
"tailscale.com/net/wsconn"
16+
)
17+
18+
func init() {
19+
dialWebsocketFunc = dialWebsocket
20+
}
21+
22+
func dialWebsocket(ctx context.Context, urlStr string, _ *tls.Config) (net.Conn, error) {
23+
c, res, err := websocket.Dial(ctx, urlStr, &websocket.DialOptions{
24+
Subprotocols: []string{"derp"},
25+
})
26+
if err != nil {
27+
log.Printf("websocket Dial: %v, %+v", err, res)
28+
return nil, err
29+
}
30+
log.Printf("websocket: connected to %v", urlStr)
31+
netConn := wsconn.NetConn(context.Background(), c, websocket.MessageBinary)
32+
return netConn, nil
33+
}

wgengine/magicsock/magicsock.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,10 @@ type Conn struct {
448448
// peerLastDerp tracks which DERP node we last used to speak with a
449449
// peer. It's only used to quiet logging, so we only log on change.
450450
peerLastDerp map[key.NodePublic]int
451+
452+
// derpForcedWebsocketFunc is a callback that is called when a DERP
453+
// connection is forced to use WebSockets.
454+
derpForcedWebsocketFunc func(region int, reason string)
451455
}
452456

453457
// SetDebugLoggingEnabled controls whether spammy debug logging is enabled.
@@ -842,6 +846,7 @@ func (c *Conn) updateNetInfo(ctx context.Context) (*netcheck.Report, error) {
842846
for rid, d := range report.RegionV6Latency {
843847
ni.DERPLatency[fmt.Sprintf("%d-v6", rid)] = d.Seconds()
844848
}
849+
845850
ni.WorkingIPv6.Set(report.IPv6)
846851
ni.OSHasIPv6.Set(report.OSHasIPv6)
847852
ni.WorkingUDP.Set(report.UDP)
@@ -952,6 +957,20 @@ func (c *Conn) SetNetInfoCallback(fn func(*tailcfg.NetInfo)) {
952957
}
953958
}
954959

960+
// SetDERPForcedWebsocketCallback is called when a DERP connection
961+
// switches to using WebSockets.
962+
func (c *Conn) SetDERPForcedWebsocketCallback(fn func(region int, reason string)) {
963+
if fn == nil {
964+
panic("nil DERPClientForcedWebsocketCallback")
965+
}
966+
c.mu.Lock()
967+
c.derpForcedWebsocketFunc = fn
968+
for _, a := range c.activeDerp {
969+
a.c.SetForcedWebsocketCallback(fn)
970+
}
971+
c.mu.Unlock()
972+
}
973+
955974
// LastRecvActivityOfNodeKey describes the time we last got traffic from
956975
// this endpoint (updated every ~10 seconds).
957976
func (c *Conn) LastRecvActivityOfNodeKey(nk key.NodePublic) string {
@@ -1486,6 +1505,7 @@ func (c *Conn) derpWriteChanOfAddr(addr netip.AddrPort, peer key.NodePublic) cha
14861505
dc.SetCanAckPings(true)
14871506
dc.NotePreferred(c.myDerp == regionID)
14881507
dc.SetAddressFamilySelector(derpAddrFamSelector{c})
1508+
dc.SetForcedWebsocketCallback(c.derpForcedWebsocketFunc)
14891509
dc.DNSCache = dnscache.Get()
14901510

14911511
ctx, cancel := context.WithCancel(c.connCtx)

0 commit comments

Comments
 (0)