Skip to content

Commit 45b969a

Browse files
committed
Push websocket tracker into tailnet as well
1 parent bf64a43 commit 45b969a

File tree

5 files changed

+58
-67
lines changed

5 files changed

+58
-67
lines changed

coderd/sockets.go renamed to coderd/activewebsockets/sockets.go

+9-9
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package coderd
1+
package activewebsockets
22

33
import (
44
"context"
@@ -12,26 +12,26 @@ import (
1212
"github.com/coder/coder/codersdk"
1313
)
1414

15-
// ActiveWebsockets is a helper struct that can be used to track active
15+
// Active is a helper struct that can be used to track active
1616
// websocket connections. All connections will be closed when the parent
1717
// context is canceled.
18-
type ActiveWebsockets struct {
18+
type Active struct {
1919
ctx context.Context
2020
cancel func()
2121

2222
wg sync.WaitGroup
2323
}
2424

25-
func NewActiveWebsockets(ctx context.Context) *ActiveWebsockets {
25+
func New(ctx context.Context) *Active {
2626
ctx, cancel := context.WithCancel(ctx)
27-
return &ActiveWebsockets{
27+
return &Active{
2828
ctx: ctx,
2929
cancel: cancel,
3030
}
3131
}
3232

3333
// Accept accepts a websocket connection and calls f with the connection.
34-
// The function will be tracked by the ActiveWebsockets struct and will be
34+
// The function will be tracked by the Active struct and will be
3535
// closed when the parent context is canceled.
3636
// Steps:
3737
// 1. Ensure we are still accepting websocket connections, and not shutting down.
@@ -41,7 +41,7 @@ func NewActiveWebsockets(ctx context.Context) *ActiveWebsockets {
4141
// 4a. If there is an error, write the error to the response writer and return.
4242
// 5. Launch go routine to kill websocket if the parent context is canceled.
4343
// 6. Call 'f' with the websocket connection.
44-
func (a *ActiveWebsockets) Accept(rw http.ResponseWriter, r *http.Request, options *websocket.AcceptOptions, f func(conn *websocket.Conn)) {
44+
func (a *Active) Accept(rw http.ResponseWriter, r *http.Request, options *websocket.AcceptOptions, f func(conn *websocket.Conn)) {
4545
// Ensure we are still accepting websocket connections, and not shutting down.
4646
if err := a.ctx.Err(); err != nil {
4747
httpapi.Write(context.Background(), rw, http.StatusBadRequest, codersdk.Response{
@@ -79,7 +79,7 @@ func (a *ActiveWebsockets) Accept(rw http.ResponseWriter, r *http.Request, optio
7979
// and close a websocket connection if that context is canceled.
8080
func closeConnOnContext(ctx context.Context, conn *websocket.Conn) {
8181
// Labeling the go routine for goroutine dumps/debugging.
82-
go pprof.Do(ctx, pprof.Labels("service", "api-server", "function", "ActiveWebsockets.track"), func(ctx context.Context) {
82+
go pprof.Do(ctx, pprof.Labels("service", "ActiveWebsockets"), func(ctx context.Context) {
8383
select {
8484
case <-ctx.Done():
8585
_ = conn.Close(websocket.StatusNormalClosure, "")
@@ -89,7 +89,7 @@ func closeConnOnContext(ctx context.Context, conn *websocket.Conn) {
8989

9090
// Close will close all active websocket connections and wait for them to
9191
// finish.
92-
func (a *ActiveWebsockets) Close() {
92+
func (a *Active) Close() {
9393
a.cancel()
9494
a.wg.Wait()
9595
}

coderd/coderd.go

+4-3
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import (
3838
"github.com/coder/coder/buildinfo"
3939

4040
// Used to serve the Swagger endpoint
41+
"github.com/coder/coder/coderd/activewebsockets"
4142
_ "github.com/coder/coder/coderd/apidoc"
4243
"github.com/coder/coder/coderd/audit"
4344
"github.com/coder/coder/coderd/awsidentity"
@@ -315,7 +316,7 @@ func New(options *Options) *API {
315316
TemplateScheduleStore: options.TemplateScheduleStore,
316317
Experiments: experiments,
317318
healthCheckGroup: &singleflight.Group[string, *healthcheck.Report]{},
318-
WebsocketWatch: NewActiveWebsockets(ctx),
319+
WebsocketWatch: activewebsockets.New(ctx),
319320
}
320321
if options.UpdateCheckOptions != nil {
321322
api.updateChecker = updatecheck.New(
@@ -355,7 +356,7 @@ func New(options *Options) *API {
355356
apiRateLimiter := httpmw.RateLimit(options.APIRateLimit, time.Minute)
356357

357358
derpHandler := derphttp.Handler(api.DERPServer)
358-
derpHandler, api.derpCloseFunc = tailnet.WithWebsocketSupport(api.DERPServer, derpHandler)
359+
derpHandler = tailnet.WithWebsocketSupport(api.WebsocketWatch, api.DERPServer, derpHandler)
359360

360361
r.Use(
361362
httpmw.Recover(api.Logger),
@@ -784,7 +785,7 @@ type API struct {
784785

785786
siteHandler http.Handler
786787

787-
WebsocketWatch *ActiveWebsockets
788+
WebsocketWatch *activewebsockets.Active
788789
derpCloseFunc func()
789790

790791
metricsCache *metricscache.Cache

coderd/healthcheck/derp_test.go

+8-3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"tailscale.com/tailcfg"
1818
"tailscale.com/types/key"
1919

20+
"github.com/coder/coder/coderd/activewebsockets"
2021
"github.com/coder/coder/coderd/healthcheck"
2122
"github.com/coder/coder/tailnet"
2223
)
@@ -124,10 +125,15 @@ func TestDERP(t *testing.T) {
124125
t.Run("ForceWebsockets", func(t *testing.T) {
125126
t.Parallel()
126127

128+
ctx, cancel := context.WithCancel(context.Background())
129+
defer cancel()
130+
127131
derpSrv := derp.NewServer(key.NewNode(), func(format string, args ...any) { t.Logf(format, args...) })
128132
defer derpSrv.Close()
129-
handler, closeHandler := tailnet.WithWebsocketSupport(derpSrv, derphttp.Handler(derpSrv))
130-
defer closeHandler()
133+
134+
sockets := activewebsockets.New(ctx)
135+
handler := tailnet.WithWebsocketSupport(sockets, derpSrv, derphttp.Handler(derpSrv))
136+
defer sockets.Close()
131137

132138
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
133139
if r.Header.Get("Upgrade") == "DERP" {
@@ -140,7 +146,6 @@ func TestDERP(t *testing.T) {
140146
}))
141147

142148
var (
143-
ctx = context.Background()
144149
report = healthcheck.DERPReport{}
145150
derpURL, _ = url.Parse(srv.URL)
146151
opts = &healthcheck.DERPReportOptions{

tailnet/derp.go

+28-49
Original file line numberDiff line numberDiff line change
@@ -2,71 +2,50 @@ package tailnet
22

33
import (
44
"bufio"
5-
"context"
6-
"log"
75
"net/http"
86
"strings"
9-
"sync"
107

118
"nhooyr.io/websocket"
129
"tailscale.com/derp"
1310
"tailscale.com/net/wsconn"
11+
12+
"github.com/coder/coder/coderd/activewebsockets"
1413
)
1514

1615
// WithWebsocketSupport returns an http.Handler that upgrades
1716
// connections to the "derp" subprotocol to WebSockets and
1817
// passes them to the DERP server.
1918
// Taken from: https://github.com/tailscale/tailscale/blob/e3211ff88ba85435f70984cf67d9b353f3d650d8/cmd/derper/websocket.go#L21
20-
func WithWebsocketSupport(s *derp.Server, base http.Handler) (http.Handler, func()) {
21-
var mu sync.Mutex
22-
var waitGroup sync.WaitGroup
23-
ctx, cancelFunc := context.WithCancel(context.Background())
24-
19+
func WithWebsocketSupport(sockets *activewebsockets.Active, s *derp.Server, base http.Handler) http.Handler {
2520
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
26-
up := strings.ToLower(r.Header.Get("Upgrade"))
21+
up := strings.ToLower(r.Header.Get("Upgrade"))
2722

28-
// Very early versions of Tailscale set "Upgrade: WebSocket" but didn't actually
29-
// speak WebSockets (they still assumed DERP's binary framing). So to distinguish
30-
// clients that actually want WebSockets, look for an explicit "derp" subprotocol.
31-
if up != "websocket" || !strings.Contains(r.Header.Get("Sec-Websocket-Protocol"), "derp") {
32-
base.ServeHTTP(w, r)
33-
return
34-
}
23+
// Very early versions of Tailscale set "Upgrade: WebSocket" but didn't actually
24+
// speak WebSockets (they still assumed DERP's binary framing). So to distinguish
25+
// clients that actually want WebSockets, look for an explicit "derp" subprotocol.
26+
if up != "websocket" || !strings.Contains(r.Header.Get("Sec-Websocket-Protocol"), "derp") {
27+
base.ServeHTTP(w, r)
28+
return
29+
}
3530

36-
mu.Lock()
37-
if ctx.Err() != nil {
38-
mu.Unlock()
31+
sockets.Accept(w, r, &websocket.AcceptOptions{
32+
Subprotocols: []string{"derp"},
33+
OriginPatterns: []string{"*"},
34+
// Disable compression because we transmit WireGuard messages that
35+
// are not compressible.
36+
// Additionally, Safari has a broken implementation of compression
37+
// (see https://github.com/nhooyr/websocket/issues/218) that makes
38+
// enabling it actively harmful.
39+
CompressionMode: websocket.CompressionDisabled,
40+
}, func(conn *websocket.Conn) {
41+
defer conn.Close(websocket.StatusInternalError, "closing")
42+
if conn.Subprotocol() != "derp" {
43+
conn.Close(websocket.StatusPolicyViolation, "client must speak the derp subprotocol")
3944
return
4045
}
41-
waitGroup.Add(1)
42-
mu.Unlock()
43-
defer waitGroup.Done()
44-
c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
45-
Subprotocols: []string{"derp"},
46-
OriginPatterns: []string{"*"},
47-
// Disable compression because we transmit WireGuard messages that
48-
// are not compressible.
49-
// Additionally, Safari has a broken implementation of compression
50-
// (see https://github.com/nhooyr/websocket/issues/218) that makes
51-
// enabling it actively harmful.
52-
CompressionMode: websocket.CompressionDisabled,
53-
})
54-
if err != nil {
55-
log.Printf("websocket.Accept: %v", err)
56-
return
57-
}
58-
defer c.Close(websocket.StatusInternalError, "closing")
59-
if c.Subprotocol() != "derp" {
60-
c.Close(websocket.StatusPolicyViolation, "client must speak the derp subprotocol")
61-
return
62-
}
63-
wc := wsconn.NetConn(ctx, c, websocket.MessageBinary)
46+
wc := wsconn.NetConn(r.Context(), conn, websocket.MessageBinary)
6447
brw := bufio.NewReadWriter(bufio.NewReader(wc), bufio.NewWriter(wc))
65-
s.Accept(ctx, wc, brw, r.RemoteAddr)
66-
}), func() {
67-
cancelFunc()
68-
mu.Lock()
69-
waitGroup.Wait()
70-
mu.Unlock()
71-
}
48+
s.Accept(r.Context(), wc, brw, r.RemoteAddr)
49+
})
50+
})
7251
}

tailnet/tailnettest/tailnettest.go

+9-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package tailnettest
22

33
import (
4+
"context"
45
"crypto/tls"
56
"fmt"
67
"html"
@@ -18,6 +19,7 @@ import (
1819
"tailscale.com/types/nettype"
1920

2021
"cdr.dev/slog/sloggers/slogtest"
22+
"github.com/coder/coder/coderd/activewebsockets"
2123
"github.com/coder/coder/tailnet"
2224
)
2325

@@ -71,8 +73,12 @@ func RunDERPOnlyWebSockets(t *testing.T) *tailcfg.DERPMap {
7173
logf := tailnet.Logger(slogtest.Make(t, nil))
7274
d := derp.NewServer(key.NewNode(), logf)
7375
handler := derphttp.Handler(d)
74-
var closeFunc func()
75-
handler, closeFunc = tailnet.WithWebsocketSupport(d, handler)
76+
77+
ctx, cancel := context.WithCancel(context.Background())
78+
defer cancel()
79+
sockets := activewebsockets.New(ctx)
80+
81+
handler = tailnet.WithWebsocketSupport(sockets, d, handler)
7682
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
7783
if r.URL.Path != "/derp" {
7884
handler.ServeHTTP(w, r)
@@ -91,7 +97,7 @@ func RunDERPOnlyWebSockets(t *testing.T) *tailcfg.DERPMap {
9197
t.Cleanup(func() {
9298
server.CloseClientConnections()
9399
server.Close()
94-
closeFunc()
100+
sockets.Close()
95101
d.Close()
96102
})
97103

0 commit comments

Comments
 (0)