@@ -25,6 +25,7 @@ import (
25
25
"runtime"
26
26
"strings"
27
27
"sync"
28
+ "sync/atomic"
28
29
"time"
29
30
30
31
"go4.org/mem"
@@ -52,6 +53,9 @@ type Client struct {
52
53
MeshKey string // optional; for trusted clients
53
54
IsProber bool // optional; for probers to optional declare themselves as such
54
55
56
+ forcedWebsocket atomic.Bool // optional; set if the server has failed to upgrade the connection on the DERP server
57
+ forcedWebsocketCallback atomic.Pointer [func (int , string )]
58
+
55
59
privateKey key.NodePrivate
56
60
logf logger.Logf
57
61
dialer func (ctx context.Context , network , addr string ) (net.Conn , error )
@@ -191,13 +195,27 @@ func (c *Client) tlsServerName(node *tailcfg.DERPNode) string {
191
195
if c .url != nil {
192
196
return c .url .Host
193
197
}
198
+ if node == nil {
199
+ return ""
200
+ }
194
201
return node .HostName
195
202
}
196
203
197
204
func (c * Client ) urlString (node * tailcfg.DERPNode ) string {
198
205
if c .url != nil {
199
206
return c .url .String ()
200
207
}
208
+ if node .HostName == "" {
209
+ url := & url.URL {
210
+ Scheme : "https" ,
211
+ Host : fmt .Sprintf ("%s:%d" , node .IPv4 , node .DERPPort ),
212
+ Path : "/derp" ,
213
+ }
214
+ if node .ForceHTTP {
215
+ url .Scheme = "http"
216
+ }
217
+ return url .String ()
218
+ }
201
219
return fmt .Sprintf ("https://%s/derp" , node .HostName )
202
220
}
203
221
@@ -228,12 +246,15 @@ func (c *Client) preferIPv6() bool {
228
246
}
229
247
230
248
// dialWebsocketFunc is non-nil (set by websocket.go's init) when compiled in.
231
- var dialWebsocketFunc func (ctx context.Context , urlStr string ) (net.Conn , error )
249
+ var dialWebsocketFunc func (ctx context.Context , urlStr string , tlsConfig * tls. Config ) (net.Conn , error )
232
250
233
- func useWebsockets () bool {
251
+ func ( c * Client ) useWebsockets () bool {
234
252
if runtime .GOOS == "js" {
235
253
return true
236
254
}
255
+ if c .forcedWebsocket .Load () {
256
+ return true
257
+ }
237
258
if dialWebsocketFunc != nil {
238
259
return envknob .Bool ("TS_DEBUG_DERP_WS_CLIENT" )
239
260
}
@@ -293,15 +314,18 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien
293
314
294
315
var node * tailcfg.DERPNode // nil when using c.url to dial
295
316
switch {
296
- case useWebsockets ():
317
+ case c . useWebsockets ():
297
318
var urlStr string
319
+ var tlsConfig * tls.Config
298
320
if c .url != nil {
299
321
urlStr = c .url .String ()
322
+ tlsConfig = c .tlsConfig (nil )
300
323
} else {
301
324
urlStr = c .urlString (reg .Nodes [0 ])
325
+ tlsConfig = c .tlsConfig (reg .Nodes [0 ])
302
326
}
303
327
c .logf ("%s: connecting websocket to %v" , caller , urlStr )
304
- conn , err := dialWebsocketFunc (ctx , urlStr )
328
+ conn , err := dialWebsocketFunc (ctx , urlStr , tlsConfig )
305
329
if err != nil {
306
330
c .logf ("%s: websocket to %v error: %v" , caller , urlStr , err )
307
331
return nil , 0 , err
@@ -435,6 +459,19 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien
435
459
if resp .StatusCode != http .StatusSwitchingProtocols {
436
460
b , _ := io .ReadAll (resp .Body )
437
461
resp .Body .Close ()
462
+
463
+ // Only on StatusCode < 500 in case a gateway timeout
464
+ // or network intermittency occurs!
465
+ if resp .StatusCode < 500 {
466
+ reason := fmt .Sprintf ("GET failed with status code %d: %s" , resp .StatusCode , b )
467
+ c .logf ("We'll use WebSockets on the next connection attempt. A proxy could be disallowing the use of 'Upgrade: derp': %s" , reason )
468
+ c .forcedWebsocket .Store (true )
469
+ forcedWebsocketCallback := c .forcedWebsocketCallback .Load ()
470
+ if forcedWebsocketCallback != nil {
471
+ go (* forcedWebsocketCallback )(reg .RegionID , reason )
472
+ }
473
+ }
474
+
438
475
return nil , 0 , fmt .Errorf ("GET failed: %v: %s" , err , b )
439
476
}
440
477
}
@@ -472,6 +509,12 @@ func (c *Client) SetURLDialer(dialer func(ctx context.Context, network, addr str
472
509
c .dialer = dialer
473
510
}
474
511
512
+ // SetForcedWebsocketCallback is a callback that is called when the client
513
+ // decides to force WebSockets on the next connection attempt.
514
+ func (c * Client ) SetForcedWebsocketCallback (callback func (region int , reason string )) {
515
+ c .forcedWebsocketCallback .Store (& callback )
516
+ }
517
+
475
518
func (c * Client ) dialURL (ctx context.Context ) (net.Conn , error ) {
476
519
host := c .url .Hostname ()
477
520
if c .dialer != nil {
@@ -525,7 +568,7 @@ func (c *Client) dialRegion(ctx context.Context, reg *tailcfg.DERPRegion) (net.C
525
568
return nil , nil , firstErr
526
569
}
527
570
528
- func (c * Client ) tlsClient ( nc net. Conn , node * tailcfg.DERPNode ) * tls.Conn {
571
+ func (c * Client ) tlsConfig ( node * tailcfg.DERPNode ) * tls.Config {
529
572
tlsConf := tlsdial .Config (c .tlsServerName (node ), c .TLSConfig )
530
573
if node != nil {
531
574
if node .InsecureForTests {
@@ -536,7 +579,11 @@ func (c *Client) tlsClient(nc net.Conn, node *tailcfg.DERPNode) *tls.Conn {
536
579
tlsdial .SetConfigExpectedCert (tlsConf , node .CertName )
537
580
}
538
581
}
539
- return tls .Client (nc , tlsConf )
582
+ return tlsConf
583
+ }
584
+
585
+ func (c * Client ) tlsClient (nc net.Conn , node * tailcfg.DERPNode ) * tls.Conn {
586
+ return tls .Client (nc , c .tlsConfig (node ))
540
587
}
541
588
542
589
// DialRegionTLS returns a TLS connection to a DERP node in the given region.
0 commit comments