Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions derp/derphttp/derphttp_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@ type Client struct {
MeshKey string // optional; for trusted clients
IsProber bool // optional; for probers to optional declare themselves as such

// Allow forcing WebSocket fallback for situations where proxies do not
// play well with `Upgrade: derp`. Turning this on will cause the client to
// always try WebSocket on it's first attempt. The callback will not be
// called if this is true.
//
// Example proxies include:
// - Azure Application Proxy (which redirects to login)
ForceWebsockets bool
forcedWebsocket atomic.Bool // optional; set if the server has failed to upgrade the connection on the DERP server
forcedWebsocketCallback atomic.Pointer[func(int, string)]

Expand Down Expand Up @@ -301,6 +309,9 @@ func (c *Client) preferIPv6() bool {
var dialWebsocketFunc func(ctx context.Context, urlStr string, tlsConfig *tls.Config, httpHeader http.Header) (net.Conn, error)

func (c *Client) useWebsockets() bool {
if c.ForceWebsockets {
return true
}
if runtime.GOOS == "js" {
return true
}
Expand Down
60 changes: 60 additions & 0 deletions derp/derphttp/derphttp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,63 @@ func TestHTTP2OnlyServer(t *testing.T) {

c.Close()
}

func TestForceWebsockets(t *testing.T) {
serverPrivateKey := key.NewNode()
s := derp.NewServer(serverPrivateKey, t.Logf)
defer s.Close()

httpsrv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
up := r.Header.Get("Upgrade")
if up == "" {
Handler(s).ServeHTTP(w, r)
return
}
if up != "websocket" {
// Should only attempt to upgrade to websocket.
t.Errorf("unexpected Upgrade header: %q", up)
return
}

c, err := websocket.Accept(w, r, &websocket.AcceptOptions{})
if err != nil {
t.Errorf("websocket.Accept: %v", err)
return
}
defer c.Close(websocket.StatusInternalError, "closing")
wc := wsconn.NetConn(context.Background(), c, websocket.MessageBinary)
brw := bufio.NewReadWriter(bufio.NewReader(wc), bufio.NewWriter(wc))
s.Accept(context.Background(), wc, brw, r.RemoteAddr)
}))
defer httpsrv.Close()
httpsrv.TLS = &tls.Config{
GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
// Add this to ensure fast start works!
cert := httpsrv.TLS.Certificates[0]
cert.Certificate = append(cert.Certificate, s.MetaCert())
return &cert, nil
},
}
httpsrv.StartTLS()

serverURL := httpsrv.URL
t.Logf("server URL: %s", serverURL)

c, err := NewClient(key.NewNode(), serverURL, t.Logf)
if err != nil {
t.Fatalf("NewClient: %v", err)
}
c.ForceWebsockets = true
c.TLSConfig = &tls.Config{
ServerName: "example.com",
RootCAs: httpsrv.Client().Transport.(*http.Transport).TLSClientConfig.RootCAs,
}
defer c.Close()

err = c.Connect(context.Background())
if err != nil {
t.Fatalf("client errored initial connect: %v", err)
}

c.Close()
}
1 change: 1 addition & 0 deletions wgengine/magicsock/derp.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ func (c *Conn) derpWriteChanOfAddr(addr netip.AddrPort, peer key.NodePublic) cha
if header != nil {
dc.Header = header.Clone()
}
dc.ForceWebsockets = c.derpForceWebsockets.Load()
dialer := c.derpRegionDialer.Load()
if dialer != nil {
dc.SetRegionDialer(*dialer)
Expand Down
7 changes: 7 additions & 0 deletions wgengine/magicsock/magicsock.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ type Conn struct {
// headers that are passed to the DERP HTTP client
derpHeader atomic.Pointer[http.Header]

// whether websocket is always used by the DERP HTTP client
derpForceWebsockets atomic.Bool

// derpRegionDialer is passed to the DERP client
derpRegionDialer atomic.Pointer[func(ctx context.Context, region *tailcfg.DERPRegion) net.Conn]

Expand Down Expand Up @@ -1660,6 +1663,10 @@ func (c *Conn) SetDERPHeader(header http.Header) {
c.derpHeader.Store(&header)
}

func (c *Conn) SetDERPForceWebsockets(v bool) {
c.derpForceWebsockets.Store(v)
}

func (c *Conn) SetDERPRegionDialer(dialer func(ctx context.Context, region *tailcfg.DERPRegion) net.Conn) {
c.derpRegionDialer.Store(&dialer)
c.mu.Lock()
Expand Down