Skip to content

feat: swap over to websockets if initial derp exchange fails #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 1, 2023
59 changes: 53 additions & 6 deletions derp/derphttp/derphttp_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"runtime"
"strings"
"sync"
"sync/atomic"
"time"

"go4.org/mem"
Expand Down Expand Up @@ -52,6 +53,9 @@ type Client struct {
MeshKey string // optional; for trusted clients
IsProber bool // optional; for probers to optional declare themselves as such

forcedWebsocket atomic.Bool // optional; set if the server has failed to upgrade the connection on the DERP server
forcedWebsocketCallback atomic.Pointer[func(int, string)]

privateKey key.NodePrivate
logf logger.Logf
dialer func(ctx context.Context, network, addr string) (net.Conn, error)
Expand Down Expand Up @@ -191,13 +195,27 @@ func (c *Client) tlsServerName(node *tailcfg.DERPNode) string {
if c.url != nil {
return c.url.Host
}
if node == nil {
return ""
}
return node.HostName
}

func (c *Client) urlString(node *tailcfg.DERPNode) string {
if c.url != nil {
return c.url.String()
}
if node.HostName == "" {
url := &url.URL{
Scheme: "https",
Host: fmt.Sprintf("%s:%d", node.IPv4, node.DERPPort),
Path: "/derp",
}
if node.ForceHTTP {
url.Scheme = "http"
}
return url.String()
}
return fmt.Sprintf("https://%s/derp", node.HostName)
}

Expand Down Expand Up @@ -228,12 +246,15 @@ func (c *Client) preferIPv6() bool {
}

// dialWebsocketFunc is non-nil (set by websocket.go's init) when compiled in.
var dialWebsocketFunc func(ctx context.Context, urlStr string) (net.Conn, error)
var dialWebsocketFunc func(ctx context.Context, urlStr string, tlsConfig *tls.Config) (net.Conn, error)

func useWebsockets() bool {
func (c *Client) useWebsockets() bool {
if runtime.GOOS == "js" {
return true
}
if c.forcedWebsocket.Load() {
return true
}
if dialWebsocketFunc != nil {
return envknob.Bool("TS_DEBUG_DERP_WS_CLIENT")
}
Expand Down Expand Up @@ -293,15 +314,18 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien

var node *tailcfg.DERPNode // nil when using c.url to dial
switch {
case useWebsockets():
case c.useWebsockets():
var urlStr string
var tlsConfig *tls.Config
if c.url != nil {
urlStr = c.url.String()
tlsConfig = c.tlsConfig(nil)
} else {
urlStr = c.urlString(reg.Nodes[0])
tlsConfig = c.tlsConfig(reg.Nodes[0])
}
c.logf("%s: connecting websocket to %v", caller, urlStr)
conn, err := dialWebsocketFunc(ctx, urlStr)
conn, err := dialWebsocketFunc(ctx, urlStr, tlsConfig)
if err != nil {
c.logf("%s: websocket to %v error: %v", caller, urlStr, err)
return nil, 0, err
Expand Down Expand Up @@ -435,6 +459,19 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien
if resp.StatusCode != http.StatusSwitchingProtocols {
b, _ := io.ReadAll(resp.Body)
resp.Body.Close()

// Only on StatusCode < 500 in case a gateway timeout
// or network intermittency occurs!
if resp.StatusCode < 500 {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the only status code it could be? I could see a proxy just dropping the connection with no status if it didn't like the Upgrade, or returning 400 perhaps?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

400 seems to be the standard status code if an upgrade fails, but I'm not sure we can guarantee that with all load balancers that clouds may have. That's why I chose < 500 because it'd generally mean that the server intentionally chose to not accept the WebSocket for some reason, which coderd will never do.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we limit this to just 400 error codes or do we run into others as well?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The honest answer is that I'm not really sure, maybe it's safe to assume only 400, but I didn't want to be presumptuous for what all proxies might do... we get the error message surfaced, so it's not impossible for us to narrow this down in the future.

reason := fmt.Sprintf("GET failed with status code %d: %s", resp.StatusCode, b)
c.logf("We'll use WebSockets on the next connection attempt. A proxy could be disallowing the use of 'Upgrade: derp': %s", reason)
c.forcedWebsocket.Store(true)
forcedWebsocketCallback := c.forcedWebsocketCallback.Load()
if forcedWebsocketCallback != nil {
go (*forcedWebsocketCallback)(reg.RegionID, reason)
}
}

return nil, 0, fmt.Errorf("GET failed: %v: %s", err, b)
}
}
Expand Down Expand Up @@ -472,6 +509,12 @@ func (c *Client) SetURLDialer(dialer func(ctx context.Context, network, addr str
c.dialer = dialer
}

// SetForcedWebsocketCallback is a callback that is called when the client
// decides to force WebSockets on the next connection attempt.
func (c *Client) SetForcedWebsocketCallback(callback func(region int, reason string)) {
c.forcedWebsocketCallback.Store(&callback)
}

func (c *Client) dialURL(ctx context.Context) (net.Conn, error) {
host := c.url.Hostname()
if c.dialer != nil {
Expand Down Expand Up @@ -525,7 +568,7 @@ func (c *Client) dialRegion(ctx context.Context, reg *tailcfg.DERPRegion) (net.C
return nil, nil, firstErr
}

func (c *Client) tlsClient(nc net.Conn, node *tailcfg.DERPNode) *tls.Conn {
func (c *Client) tlsConfig(node *tailcfg.DERPNode) *tls.Config {
tlsConf := tlsdial.Config(c.tlsServerName(node), c.TLSConfig)
if node != nil {
if node.InsecureForTests {
Expand All @@ -536,7 +579,11 @@ func (c *Client) tlsClient(nc net.Conn, node *tailcfg.DERPNode) *tls.Conn {
tlsdial.SetConfigExpectedCert(tlsConf, node.CertName)
}
}
return tls.Client(nc, tlsConf)
return tlsConf
}

func (c *Client) tlsClient(nc net.Conn, node *tailcfg.DERPNode) *tls.Conn {
return tls.Client(nc, c.tlsConfig(node))
}

// DialRegionTLS returns a TLS connection to a DERP node in the given region.
Expand Down
14 changes: 9 additions & 5 deletions derp/derphttp/websocket.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

//go:build linux || js
//go:build !js

package derphttp

import (
"context"
"crypto/tls"
"log"
"net"
"net/http"

"nhooyr.io/websocket"
"tailscale.com/net/wsconn"
Expand All @@ -18,9 +17,14 @@ func init() {
dialWebsocketFunc = dialWebsocket
}

func dialWebsocket(ctx context.Context, urlStr string) (net.Conn, error) {
func dialWebsocket(ctx context.Context, urlStr string, tlsConfig *tls.Config) (net.Conn, error) {
c, res, err := websocket.Dial(ctx, urlStr, &websocket.DialOptions{
Subprotocols: []string{"derp"},
HTTPClient: &http.Client{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't available in js, so we might want to move this assignment out to a separate file behind build tag. Maybe a function like wsDialOptionsWithHTTPClient(opts, tlsConfig).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point! Will fix.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check now!

Transport: &http.Transport{
TLSClientConfig: tlsConfig,
},
},
})
if err != nil {
log.Printf("websocket Dial: %v, %+v", err, res)
Expand Down
33 changes: 33 additions & 0 deletions derp/derphttp/websocket_js.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

//go:build js

package derphttp

import (
"context"
"crypto/tls"
"log"
"net"

"nhooyr.io/websocket"
"tailscale.com/net/wsconn"
)

func init() {
dialWebsocketFunc = dialWebsocket
}

func dialWebsocket(ctx context.Context, urlStr string, _ *tls.Config) (net.Conn, error) {
c, res, err := websocket.Dial(ctx, urlStr, &websocket.DialOptions{
Subprotocols: []string{"derp"},
})
if err != nil {
log.Printf("websocket Dial: %v, %+v", err, res)
return nil, err
}
log.Printf("websocket: connected to %v", urlStr)
netConn := wsconn.NetConn(context.Background(), c, websocket.MessageBinary)
return netConn, nil
}
20 changes: 20 additions & 0 deletions wgengine/magicsock/magicsock.go
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,10 @@ type Conn struct {
// peerLastDerp tracks which DERP node we last used to speak with a
// peer. It's only used to quiet logging, so we only log on change.
peerLastDerp map[key.NodePublic]int

// derpForcedWebsocketFunc is a callback that is called when a DERP
// connection is forced to use WebSockets.
derpForcedWebsocketFunc func(region int, reason string)
}

// SetDebugLoggingEnabled controls whether spammy debug logging is enabled.
Expand Down Expand Up @@ -842,6 +846,7 @@ func (c *Conn) updateNetInfo(ctx context.Context) (*netcheck.Report, error) {
for rid, d := range report.RegionV6Latency {
ni.DERPLatency[fmt.Sprintf("%d-v6", rid)] = d.Seconds()
}

ni.WorkingIPv6.Set(report.IPv6)
ni.OSHasIPv6.Set(report.OSHasIPv6)
ni.WorkingUDP.Set(report.UDP)
Expand Down Expand Up @@ -952,6 +957,20 @@ func (c *Conn) SetNetInfoCallback(fn func(*tailcfg.NetInfo)) {
}
}

// SetDERPForcedWebsocketCallback is called when a DERP connection
// switches to using WebSockets.
func (c *Conn) SetDERPForcedWebsocketCallback(fn func(region int, reason string)) {
if fn == nil {
panic("nil DERPClientForcedWebsocketCallback")
}
c.mu.Lock()
c.derpForcedWebsocketFunc = fn
for _, a := range c.activeDerp {
a.c.SetForcedWebsocketCallback(fn)
}
c.mu.Unlock()
}

// LastRecvActivityOfNodeKey describes the time we last got traffic from
// this endpoint (updated every ~10 seconds).
func (c *Conn) LastRecvActivityOfNodeKey(nk key.NodePublic) string {
Expand Down Expand Up @@ -1486,6 +1505,7 @@ func (c *Conn) derpWriteChanOfAddr(addr netip.AddrPort, peer key.NodePublic) cha
dc.SetCanAckPings(true)
dc.NotePreferred(c.myDerp == regionID)
dc.SetAddressFamilySelector(derpAddrFamSelector{c})
dc.SetForcedWebsocketCallback(c.derpForcedWebsocketFunc)
dc.DNSCache = dnscache.Get()

ctx, cancel := context.WithCancel(c.connCtx)
Expand Down