From ead3e059b2b79ef7dd14a165bc650b1dae799c95 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 23 Mar 2023 14:13:29 +0000 Subject: [PATCH] fix: use a background context when piping derp connections This was causing boatloads of connects to reestablish every time... See https://github.com/coder/coder/issues/6746 --- coderd/workspaceagents.go | 14 +++++++------- coderd/workspaceapps.go | 2 +- coderd/wsconncache/wsconncache.go | 6 +++--- coderd/wsconncache/wsconncache_test.go | 20 ++++++++++---------- 4 files changed, 21 insertions(+), 21 deletions(-) diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 51ddf387e2ff3..a83bba5d9d31e 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -294,7 +294,7 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) { go httpapi.Heartbeat(ctx, conn) - agentConn, release, err := api.workspaceAgentCache.Acquire(r, workspaceAgent.ID) + agentConn, release, err := api.workspaceAgentCache.Acquire(workspaceAgent.ID) if err != nil { _ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err)) return @@ -339,7 +339,7 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req return } - agentConn, release, err := api.workspaceAgentCache.Acquire(r, workspaceAgent.ID) + agentConn, release, err := api.workspaceAgentCache.Acquire(workspaceAgent.ID) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error dialing workspace agent.", @@ -414,10 +414,8 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req httpapi.Write(ctx, rw, http.StatusOK, portsResponse) } -func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { - ctx := r.Context() +func (api *API) dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { clientConn, serverConn := net.Pipe() - conn, err := tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, DERPMap: api.DERPMap, @@ -428,6 +426,7 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (* _ = serverConn.Close() return nil, xerrors.Errorf("create tailnet conn: %w", err) } + ctx, cancel := context.WithCancel(api.ctx) conn.SetDERPRegionDialer(func(_ context.Context, region *tailcfg.DERPRegion) net.Conn { if !region.EmbeddedRelay { return nil @@ -437,7 +436,7 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (* defer left.Close() defer right.Close() brw := bufio.NewReadWriter(bufio.NewReader(right), bufio.NewWriter(right)) - api.DERPServer.Accept(ctx, right, brw, r.RemoteAddr) + api.DERPServer.Accept(ctx, right, brw, "internal") }() return left }) @@ -453,6 +452,7 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (* agentConn := &codersdk.WorkspaceAgentConn{ Conn: conn, CloseFunc: func() { + cancel() _ = clientConn.Close() _ = serverConn.Close() }, @@ -460,7 +460,7 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (* go func() { err := (*api.TailnetCoordinator.Load()).ServeClient(serverConn, uuid.New(), agentID) if err != nil { - api.Logger.Warn(r.Context(), "tailnet coordinator client error", slog.Error(err)) + api.Logger.Warn(ctx, "tailnet coordinator client error", slog.Error(err)) _ = agentConn.Close() } }() diff --git a/coderd/workspaceapps.go b/coderd/workspaceapps.go index ce43788b791a4..cf3dcfb0c1b86 100644 --- a/coderd/workspaceapps.go +++ b/coderd/workspaceapps.go @@ -639,7 +639,7 @@ func (api *API) proxyWorkspaceApplication(rw http.ResponseWriter, r *http.Reques }) } - conn, release, err := api.workspaceAgentCache.Acquire(r, ticket.AgentID) + conn, release, err := api.workspaceAgentCache.Acquire(ticket.AgentID) if err != nil { site.RenderStaticErrorPage(rw, r, site.ErrorPageData{ Status: http.StatusBadGateway, diff --git a/coderd/wsconncache/wsconncache.go b/coderd/wsconncache/wsconncache.go index 436581858cf4e..19c7f65f9fb74 100644 --- a/coderd/wsconncache/wsconncache.go +++ b/coderd/wsconncache/wsconncache.go @@ -32,7 +32,7 @@ func New(dialer Dialer, inactiveTimeout time.Duration) *Cache { } // Dialer creates a new agent connection by ID. -type Dialer func(r *http.Request, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) +type Dialer func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) // Conn wraps an agent connection with a reusable HTTP transport. type Conn struct { @@ -78,7 +78,7 @@ type Cache struct { // The returned function is used to release a lock on the connection. Once zero // locks exist on a connection, the inactive timeout will begin to tick down. // After the time expires, the connection will be cleared from the cache. -func (c *Cache) Acquire(r *http.Request, id uuid.UUID) (*Conn, func(), error) { +func (c *Cache) Acquire(id uuid.UUID) (*Conn, func(), error) { rawConn, found := c.connMap.Load(id.String()) // If the connection isn't found, establish a new one! if !found { @@ -95,7 +95,7 @@ func (c *Cache) Acquire(r *http.Request, id uuid.UUID) (*Conn, func(), error) { } c.closeGroup.Add(1) c.closeMutex.Unlock() - agentConn, err := c.dialer(r, id) + agentConn, err := c.dialer(id) if err != nil { c.closeGroup.Done() return nil, xerrors.Errorf("dial: %w", err) diff --git a/coderd/wsconncache/wsconncache_test.go b/coderd/wsconncache/wsconncache_test.go index 6abc5609ec6dc..e217172c6d776 100644 --- a/coderd/wsconncache/wsconncache_test.go +++ b/coderd/wsconncache/wsconncache_test.go @@ -40,33 +40,33 @@ func TestCache(t *testing.T) { t.Parallel() t.Run("Same", func(t *testing.T) { t.Parallel() - cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { + cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { return setupAgent(t, agentsdk.Metadata{}, 0), nil }, 0) defer func() { _ = cache.Close() }() - conn1, _, err := cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil) + conn1, _, err := cache.Acquire(uuid.Nil) require.NoError(t, err) - conn2, _, err := cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil) + conn2, _, err := cache.Acquire(uuid.Nil) require.NoError(t, err) require.True(t, conn1 == conn2) }) t.Run("Expire", func(t *testing.T) { t.Parallel() called := atomic.NewInt32(0) - cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { + cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { called.Add(1) return setupAgent(t, agentsdk.Metadata{}, 0), nil }, time.Microsecond) defer func() { _ = cache.Close() }() - conn, release, err := cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil) + conn, release, err := cache.Acquire(uuid.Nil) require.NoError(t, err) release() <-conn.Closed() - conn, release, err = cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil) + conn, release, err = cache.Acquire(uuid.Nil) require.NoError(t, err) release() <-conn.Closed() @@ -74,13 +74,13 @@ func TestCache(t *testing.T) { }) t.Run("NoExpireWhenLocked", func(t *testing.T) { t.Parallel() - cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { + cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { return setupAgent(t, agentsdk.Metadata{}, 0), nil }, time.Microsecond) defer func() { _ = cache.Close() }() - conn, release, err := cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil) + conn, release, err := cache.Acquire(uuid.Nil) require.NoError(t, err) time.Sleep(time.Millisecond) release() @@ -107,7 +107,7 @@ func TestCache(t *testing.T) { }() go server.Serve(random) - cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { + cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { return setupAgent(t, agentsdk.Metadata{}, 0), nil }, time.Microsecond) defer func() { @@ -130,7 +130,7 @@ func TestCache(t *testing.T) { defer cancel() req := httptest.NewRequest(http.MethodGet, "/", nil) req = req.WithContext(ctx) - conn, release, err := cache.Acquire(req, uuid.Nil) + conn, release, err := cache.Acquire(uuid.Nil) if !assert.NoError(t, err) { return }