Skip to content

Commit 680677e

Browse files
committed
fix: use a background context when piping derp connections
This was causing boatloads of connections to be reestablished every time the originating request's context was canceled. See #6746
1 parent eaacc26 commit 680677e

File tree

4 files changed

+27
-21
lines changed

4 files changed

+27
-21
lines changed

coderd/workspaceagents.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
294294

295295
go httpapi.Heartbeat(ctx, conn)
296296

297-
agentConn, release, err := api.workspaceAgentCache.Acquire(r, workspaceAgent.ID)
297+
agentConn, release, err := api.workspaceAgentCache.Acquire(workspaceAgent.ID)
298298
if err != nil {
299299
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err))
300300
return
@@ -339,7 +339,7 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req
339339
return
340340
}
341341

342-
agentConn, release, err := api.workspaceAgentCache.Acquire(r, workspaceAgent.ID)
342+
agentConn, release, err := api.workspaceAgentCache.Acquire(workspaceAgent.ID)
343343
if err != nil {
344344
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
345345
Message: "Internal error dialing workspace agent.",
@@ -414,10 +414,8 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req
414414
httpapi.Write(ctx, rw, http.StatusOK, portsResponse)
415415
}
416416

417-
func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
418-
ctx := r.Context()
417+
func (api *API) dialWorkspaceAgentTailnet(ctx context.Context, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
419418
clientConn, serverConn := net.Pipe()
420-
421419
conn, err := tailnet.NewConn(&tailnet.Options{
422420
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
423421
DERPMap: api.DERPMap,
@@ -428,6 +426,7 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*
428426
_ = serverConn.Close()
429427
return nil, xerrors.Errorf("create tailnet conn: %w", err)
430428
}
429+
ctx, cancel := context.WithCancel(ctx)
431430
conn.SetDERPRegionDialer(func(_ context.Context, region *tailcfg.DERPRegion) net.Conn {
432431
if !region.EmbeddedRelay {
433432
return nil
@@ -437,7 +436,7 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*
437436
defer left.Close()
438437
defer right.Close()
439438
brw := bufio.NewReadWriter(bufio.NewReader(right), bufio.NewWriter(right))
440-
api.DERPServer.Accept(ctx, right, brw, r.RemoteAddr)
439+
api.DERPServer.Accept(ctx, right, brw, "internal")
441440
}()
442441
return left
443442
})
@@ -453,14 +452,15 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*
453452
agentConn := &codersdk.WorkspaceAgentConn{
454453
Conn: conn,
455454
CloseFunc: func() {
455+
cancel()
456456
_ = clientConn.Close()
457457
_ = serverConn.Close()
458458
},
459459
}
460460
go func() {
461461
err := (*api.TailnetCoordinator.Load()).ServeClient(serverConn, uuid.New(), agentID)
462462
if err != nil {
463-
api.Logger.Warn(r.Context(), "tailnet coordinator client error", slog.Error(err))
463+
api.Logger.Warn(ctx, "tailnet coordinator client error", slog.Error(err))
464464
_ = agentConn.Close()
465465
}
466466
}()

coderd/workspaceapps.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -639,7 +639,7 @@ func (api *API) proxyWorkspaceApplication(rw http.ResponseWriter, r *http.Reques
639639
})
640640
}
641641

642-
conn, release, err := api.workspaceAgentCache.Acquire(r, ticket.AgentID)
642+
conn, release, err := api.workspaceAgentCache.Acquire(ticket.AgentID)
643643
if err != nil {
644644
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{
645645
Status: http.StatusBadGateway,

coderd/wsconncache/wsconncache.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,18 @@ func New(dialer Dialer, inactiveTimeout time.Duration) *Cache {
2424
if inactiveTimeout == 0 {
2525
inactiveTimeout = 5 * time.Minute
2626
}
27+
ctx, cancelFunc := context.WithCancel(context.Background())
2728
return &Cache{
29+
closeContext: ctx,
30+
closeCancel: cancelFunc,
2831
closed: make(chan struct{}),
2932
dialer: dialer,
3033
inactiveTimeout: inactiveTimeout,
3134
}
3235
}
3336

3437
// Dialer creates a new agent connection by ID.
35-
type Dialer func(r *http.Request, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error)
38+
type Dialer func(ctx context.Context, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error)
3639

3740
// Conn wraps an agent connection with a reusable HTTP transport.
3841
type Conn struct {
@@ -66,6 +69,8 @@ type Cache struct {
6669
closed chan struct{}
6770
closeMutex sync.Mutex
6871
closeGroup sync.WaitGroup
72+
closeContext context.Context
73+
closeCancel context.CancelFunc
6974
connGroup singleflight.Group
7075
connMap sync.Map
7176
dialer Dialer
@@ -78,7 +83,7 @@ type Cache struct {
7883
// The returned function is used to release a lock on the connection. Once zero
7984
// locks exist on a connection, the inactive timeout will begin to tick down.
8085
// After the time expires, the connection will be cleared from the cache.
81-
func (c *Cache) Acquire(r *http.Request, id uuid.UUID) (*Conn, func(), error) {
86+
func (c *Cache) Acquire(id uuid.UUID) (*Conn, func(), error) {
8287
rawConn, found := c.connMap.Load(id.String())
8388
// If the connection isn't found, establish a new one!
8489
if !found {
@@ -95,7 +100,7 @@ func (c *Cache) Acquire(r *http.Request, id uuid.UUID) (*Conn, func(), error) {
95100
}
96101
c.closeGroup.Add(1)
97102
c.closeMutex.Unlock()
98-
agentConn, err := c.dialer(r, id)
103+
agentConn, err := c.dialer(c.closeContext, id)
99104
if err != nil {
100105
c.closeGroup.Done()
101106
return nil, xerrors.Errorf("dial: %w", err)
@@ -161,6 +166,7 @@ func (c *Cache) Close() error {
161166
default:
162167
}
163168
close(c.closed)
169+
c.closeCancel()
164170
c.closeGroup.Wait()
165171
return nil
166172
}

coderd/wsconncache/wsconncache_test.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,47 +40,47 @@ func TestCache(t *testing.T) {
4040
t.Parallel()
4141
t.Run("Same", func(t *testing.T) {
4242
t.Parallel()
43-
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
43+
cache := wsconncache.New(func(_ context.Context, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
4444
return setupAgent(t, agentsdk.Metadata{}, 0), nil
4545
}, 0)
4646
defer func() {
4747
_ = cache.Close()
4848
}()
49-
conn1, _, err := cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil)
49+
conn1, _, err := cache.Acquire(uuid.Nil)
5050
require.NoError(t, err)
51-
conn2, _, err := cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil)
51+
conn2, _, err := cache.Acquire(uuid.Nil)
5252
require.NoError(t, err)
5353
require.True(t, conn1 == conn2)
5454
})
5555
t.Run("Expire", func(t *testing.T) {
5656
t.Parallel()
5757
called := atomic.NewInt32(0)
58-
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
58+
cache := wsconncache.New(func(_ context.Context, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
5959
called.Add(1)
6060
return setupAgent(t, agentsdk.Metadata{}, 0), nil
6161
}, time.Microsecond)
6262
defer func() {
6363
_ = cache.Close()
6464
}()
65-
conn, release, err := cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil)
65+
conn, release, err := cache.Acquire(uuid.Nil)
6666
require.NoError(t, err)
6767
release()
6868
<-conn.Closed()
69-
conn, release, err = cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil)
69+
conn, release, err = cache.Acquire(uuid.Nil)
7070
require.NoError(t, err)
7171
release()
7272
<-conn.Closed()
7373
require.Equal(t, int32(2), called.Load())
7474
})
7575
t.Run("NoExpireWhenLocked", func(t *testing.T) {
7676
t.Parallel()
77-
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
77+
cache := wsconncache.New(func(_ context.Context, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
7878
return setupAgent(t, agentsdk.Metadata{}, 0), nil
7979
}, time.Microsecond)
8080
defer func() {
8181
_ = cache.Close()
8282
}()
83-
conn, release, err := cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil)
83+
conn, release, err := cache.Acquire(uuid.Nil)
8484
require.NoError(t, err)
8585
time.Sleep(time.Millisecond)
8686
release()
@@ -107,7 +107,7 @@ func TestCache(t *testing.T) {
107107
}()
108108
go server.Serve(random)
109109

110-
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
110+
cache := wsconncache.New(func(_ context.Context, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
111111
return setupAgent(t, agentsdk.Metadata{}, 0), nil
112112
}, time.Microsecond)
113113
defer func() {
@@ -130,7 +130,7 @@ func TestCache(t *testing.T) {
130130
defer cancel()
131131
req := httptest.NewRequest(http.MethodGet, "/", nil)
132132
req = req.WithContext(ctx)
133-
conn, release, err := cache.Acquire(req, uuid.Nil)
133+
conn, release, err := cache.Acquire(uuid.Nil)
134134
if !assert.NoError(t, err) {
135135
return
136136
}

0 commit comments

Comments
 (0)