Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Push websocket tracker into tailnet as well
  • Loading branch information
Emyrk committed Apr 4, 2023
commit 45b969ab385f8152cec17a77ed1462d78072e7de
18 changes: 9 additions & 9 deletions coderd/sockets.go → coderd/activewebsockets/sockets.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package coderd
package activewebsockets

import (
"context"
Expand All @@ -12,26 +12,26 @@ import (
"github.com/coder/coder/codersdk"
)

// ActiveWebsockets is a helper struct that can be used to track active
// Active is a helper struct that can be used to track active
// websocket connections. All connections will be closed when the parent
// context is canceled.
type ActiveWebsockets struct {
type Active struct {
ctx context.Context
cancel func()

wg sync.WaitGroup
}

func NewActiveWebsockets(ctx context.Context) *ActiveWebsockets {
func New(ctx context.Context) *Active {
ctx, cancel := context.WithCancel(ctx)
return &ActiveWebsockets{
return &Active{
ctx: ctx,
cancel: cancel,
}
}

// Accept accepts a websocket connection and calls f with the connection.
// The function will be tracked by the ActiveWebsockets struct and will be
// The function will be tracked by the Active struct and will be
// closed when the parent context is canceled.
// Steps:
// 1. Ensure we are still accepting websocket connections, and not shutting down.
Expand All @@ -41,7 +41,7 @@ func NewActiveWebsockets(ctx context.Context) *ActiveWebsockets {
// 4a. If there is an error, write the error to the response writer and return.
// 5. Launch go routine to kill websocket if the parent context is canceled.
// 6. Call 'f' with the websocket connection.
func (a *ActiveWebsockets) Accept(rw http.ResponseWriter, r *http.Request, options *websocket.AcceptOptions, f func(conn *websocket.Conn)) {
func (a *Active) Accept(rw http.ResponseWriter, r *http.Request, options *websocket.AcceptOptions, f func(conn *websocket.Conn)) {
// Ensure we are still accepting websocket connections, and not shutting down.
if err := a.ctx.Err(); err != nil {
httpapi.Write(context.Background(), rw, http.StatusBadRequest, codersdk.Response{
Expand Down Expand Up @@ -79,7 +79,7 @@ func (a *ActiveWebsockets) Accept(rw http.ResponseWriter, r *http.Request, optio
// and close a websocket connection if that context is canceled.
func closeConnOnContext(ctx context.Context, conn *websocket.Conn) {
// Labeling the go routine for goroutine dumps/debugging.
go pprof.Do(ctx, pprof.Labels("service", "api-server", "function", "ActiveWebsockets.track"), func(ctx context.Context) {
go pprof.Do(ctx, pprof.Labels("service", "ActiveWebsockets"), func(ctx context.Context) {
select {
case <-ctx.Done():
_ = conn.Close(websocket.StatusNormalClosure, "")
Expand All @@ -89,7 +89,7 @@ func closeConnOnContext(ctx context.Context, conn *websocket.Conn) {

// Close will close all active websocket connections and wait for them to
// finish.
func (a *ActiveWebsockets) Close() {
func (a *Active) Close() {
a.cancel()
a.wg.Wait()
}
7 changes: 4 additions & 3 deletions coderd/coderd.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"github.com/coder/coder/buildinfo"

// Used to serve the Swagger endpoint
"github.com/coder/coder/coderd/activewebsockets"
_ "github.com/coder/coder/coderd/apidoc"
"github.com/coder/coder/coderd/audit"
"github.com/coder/coder/coderd/awsidentity"
Expand Down Expand Up @@ -315,7 +316,7 @@ func New(options *Options) *API {
TemplateScheduleStore: options.TemplateScheduleStore,
Experiments: experiments,
healthCheckGroup: &singleflight.Group[string, *healthcheck.Report]{},
WebsocketWatch: NewActiveWebsockets(ctx),
WebsocketWatch: activewebsockets.New(ctx),
}
if options.UpdateCheckOptions != nil {
api.updateChecker = updatecheck.New(
Expand Down Expand Up @@ -355,7 +356,7 @@ func New(options *Options) *API {
apiRateLimiter := httpmw.RateLimit(options.APIRateLimit, time.Minute)

derpHandler := derphttp.Handler(api.DERPServer)
derpHandler, api.derpCloseFunc = tailnet.WithWebsocketSupport(api.DERPServer, derpHandler)
derpHandler = tailnet.WithWebsocketSupport(api.WebsocketWatch, api.DERPServer, derpHandler)

r.Use(
httpmw.Recover(api.Logger),
Expand Down Expand Up @@ -784,7 +785,7 @@ type API struct {

siteHandler http.Handler

WebsocketWatch *ActiveWebsockets
WebsocketWatch *activewebsockets.Active
derpCloseFunc func()

metricsCache *metricscache.Cache
Expand Down
11 changes: 8 additions & 3 deletions coderd/healthcheck/derp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"tailscale.com/tailcfg"
"tailscale.com/types/key"

"github.com/coder/coder/coderd/activewebsockets"
"github.com/coder/coder/coderd/healthcheck"
"github.com/coder/coder/tailnet"
)
Expand Down Expand Up @@ -124,10 +125,15 @@ func TestDERP(t *testing.T) {
t.Run("ForceWebsockets", func(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

derpSrv := derp.NewServer(key.NewNode(), func(format string, args ...any) { t.Logf(format, args...) })
defer derpSrv.Close()
handler, closeHandler := tailnet.WithWebsocketSupport(derpSrv, derphttp.Handler(derpSrv))
defer closeHandler()

sockets := activewebsockets.New(ctx)
handler := tailnet.WithWebsocketSupport(sockets, derpSrv, derphttp.Handler(derpSrv))
defer sockets.Close()

srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Upgrade") == "DERP" {
Expand All @@ -140,7 +146,6 @@ func TestDERP(t *testing.T) {
}))

var (
ctx = context.Background()
report = healthcheck.DERPReport{}
derpURL, _ = url.Parse(srv.URL)
opts = &healthcheck.DERPReportOptions{
Expand Down
77 changes: 28 additions & 49 deletions tailnet/derp.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,71 +2,50 @@ package tailnet

import (
"bufio"
"context"
"log"
"net/http"
"strings"
"sync"

"nhooyr.io/websocket"
"tailscale.com/derp"
"tailscale.com/net/wsconn"

"github.com/coder/coder/coderd/activewebsockets"
)

// WithWebsocketSupport returns an http.Handler that upgrades
// connections to the "derp" subprotocol to WebSockets and
// passes them to the DERP server.
// Taken from: https://github.com/tailscale/tailscale/blob/e3211ff88ba85435f70984cf67d9b353f3d650d8/cmd/derper/websocket.go#L21
func WithWebsocketSupport(s *derp.Server, base http.Handler) (http.Handler, func()) {
var mu sync.Mutex
var waitGroup sync.WaitGroup
ctx, cancelFunc := context.WithCancel(context.Background())

func WithWebsocketSupport(sockets *activewebsockets.Active, s *derp.Server, base http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
up := strings.ToLower(r.Header.Get("Upgrade"))
up := strings.ToLower(r.Header.Get("Upgrade"))

// Very early versions of Tailscale set "Upgrade: WebSocket" but didn't actually
// speak WebSockets (they still assumed DERP's binary framing). So to distinguish
// clients that actually want WebSockets, look for an explicit "derp" subprotocol.
if up != "websocket" || !strings.Contains(r.Header.Get("Sec-Websocket-Protocol"), "derp") {
base.ServeHTTP(w, r)
return
}
// Very early versions of Tailscale set "Upgrade: WebSocket" but didn't actually
// speak WebSockets (they still assumed DERP's binary framing). So to distinguish
// clients that actually want WebSockets, look for an explicit "derp" subprotocol.
if up != "websocket" || !strings.Contains(r.Header.Get("Sec-Websocket-Protocol"), "derp") {
base.ServeHTTP(w, r)
return
}

mu.Lock()
if ctx.Err() != nil {
mu.Unlock()
sockets.Accept(w, r, &websocket.AcceptOptions{
Subprotocols: []string{"derp"},
OriginPatterns: []string{"*"},
// Disable compression because we transmit WireGuard messages that
// are not compressible.
// Additionally, Safari has a broken implementation of compression
// (see https://github.com/nhooyr/websocket/issues/218) that makes
// enabling it actively harmful.
CompressionMode: websocket.CompressionDisabled,
}, func(conn *websocket.Conn) {
defer conn.Close(websocket.StatusInternalError, "closing")
if conn.Subprotocol() != "derp" {
conn.Close(websocket.StatusPolicyViolation, "client must speak the derp subprotocol")
return
}
waitGroup.Add(1)
mu.Unlock()
defer waitGroup.Done()
c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
Subprotocols: []string{"derp"},
OriginPatterns: []string{"*"},
// Disable compression because we transmit WireGuard messages that
// are not compressible.
// Additionally, Safari has a broken implementation of compression
// (see https://github.com/nhooyr/websocket/issues/218) that makes
// enabling it actively harmful.
CompressionMode: websocket.CompressionDisabled,
})
if err != nil {
log.Printf("websocket.Accept: %v", err)
return
}
defer c.Close(websocket.StatusInternalError, "closing")
if c.Subprotocol() != "derp" {
c.Close(websocket.StatusPolicyViolation, "client must speak the derp subprotocol")
return
}
wc := wsconn.NetConn(ctx, c, websocket.MessageBinary)
wc := wsconn.NetConn(r.Context(), conn, websocket.MessageBinary)
brw := bufio.NewReadWriter(bufio.NewReader(wc), bufio.NewWriter(wc))
s.Accept(ctx, wc, brw, r.RemoteAddr)
}), func() {
cancelFunc()
mu.Lock()
waitGroup.Wait()
mu.Unlock()
}
s.Accept(r.Context(), wc, brw, r.RemoteAddr)
})
})
}
12 changes: 9 additions & 3 deletions tailnet/tailnettest/tailnettest.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tailnettest

import (
"context"
"crypto/tls"
"fmt"
"html"
Expand All @@ -18,6 +19,7 @@ import (
"tailscale.com/types/nettype"

"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/coderd/activewebsockets"
"github.com/coder/coder/tailnet"
)

Expand Down Expand Up @@ -71,8 +73,12 @@ func RunDERPOnlyWebSockets(t *testing.T) *tailcfg.DERPMap {
logf := tailnet.Logger(slogtest.Make(t, nil))
d := derp.NewServer(key.NewNode(), logf)
handler := derphttp.Handler(d)
var closeFunc func()
handler, closeFunc = tailnet.WithWebsocketSupport(d, handler)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sockets := activewebsockets.New(ctx)

handler = tailnet.WithWebsocketSupport(sockets, d, handler)
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/derp" {
handler.ServeHTTP(w, r)
Expand All @@ -91,7 +97,7 @@ func RunDERPOnlyWebSockets(t *testing.T) *tailcfg.DERPMap {
t.Cleanup(func() {
server.CloseClientConnections()
server.Close()
closeFunc()
sockets.Close()
d.Close()
})

Expand Down