diff --git a/coderd/files.go b/coderd/files.go index d5379c4d8b036..3ce2152b46dfc 100644 --- a/coderd/files.go +++ b/coderd/files.go @@ -12,10 +12,10 @@ import ( "io" "net/http" - "cdr.dev/slog" "github.com/go-chi/chi/v5" "github.com/google/uuid" + "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/httpapi" diff --git a/coderd/tailnet.go b/coderd/tailnet.go index 7a2ff9afc42e4..fed86ab5aecb0 100644 --- a/coderd/tailnet.go +++ b/coderd/tailnet.go @@ -99,15 +99,14 @@ func NewServerTailnet( transport: tailnetTransport.Clone(), } tn.transport.DialContext = tn.dialContext - - // Bugfix: for some reason all calls to tn.dialContext come from - // "localhost", causing connections to be cached and requests to go to the - // wrong workspaces. This disables keepalives for now until the root cause - // can be found. - tn.transport.MaxIdleConnsPerHost = -1 - tn.transport.DisableKeepAlives = true - + // These options are mostly just picked at random, and they can likely be + // fine tuned further. Generally, users are running applications in dev mode + // which can generate hundreds of requests per page load, so we increased + // MaxIdleConnsPerHost from 2 to 6 and removed the limit of total idle + // conns. + tn.transport.MaxIdleConnsPerHost = 6 tn.transport.MaxIdleConns = 0 + tn.transport.IdleConnTimeout = 10 * time.Minute // We intentionally don't verify the certificate chain here. // The connection to the workspace is already established and most // apps are already going to be accessed over plain HTTP, this config @@ -308,7 +307,15 @@ type ServerTailnet struct { } func (s *ServerTailnet) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) *httputil.ReverseProxy { - proxy := httputil.NewSingleHostReverseProxy(targetURL) + // Rewrite the targetURL's Host to point to the agent's IP. 
This is + // necessary because, due to TCP connection caching, each agent needs to be + // addressed individually. Otherwise, all connections get dialed as + // "localhost:port", causing connections to be shared across agents. + tgt := *targetURL + _, port, _ := net.SplitHostPort(tgt.Host) + tgt.Host = net.JoinHostPort(tailnet.IPFromUUID(agentID).String(), port) + + proxy := httputil.NewSingleHostReverseProxy(&tgt) proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { site.RenderStaticErrorPage(w, r, site.ErrorPageData{ Status: http.StatusBadGateway, diff --git a/coderd/tailnet_test.go b/coderd/tailnet_test.go index d6a6c143fe051..cffe818424827 100644 --- a/coderd/tailnet_test.go +++ b/coderd/tailnet_test.go @@ -3,10 +3,13 @@ package coderd_test import ( "context" "fmt" + "io" + "net" "net/http" "net/http/httptest" - "net/netip" "net/url" + "strconv" + "sync/atomic" "testing" "github.com/google/uuid" @@ -35,9 +38,10 @@ func TestServerTailnet_AgentConn_OK(t *testing.T) { defer cancel() // Connect through the ServerTailnet - agentID, _, serverTailnet := setupAgent(t, nil) + agents, serverTailnet := setupServerTailnetAgent(t, 1) + a := agents[0] - conn, release, err := serverTailnet.AgentConn(ctx, agentID) + conn, release, err := serverTailnet.AgentConn(ctx, a.id) require.NoError(t, err) defer release() @@ -53,12 +57,13 @@ func TestServerTailnet_ReverseProxy(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - agentID, _, serverTailnet := setupAgent(t, nil) + agents, serverTailnet := setupServerTailnetAgent(t, 1) + a := agents[0] u, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", codersdk.WorkspaceAgentHTTPAPIServerPort)) require.NoError(t, err) - rp := serverTailnet.ReverseProxy(u, u, agentID) + rp := serverTailnet.ReverseProxy(u, u, a.id) rw := httptest.NewRecorder() req := httptest.NewRequest( @@ -74,13 +79,147 @@ func TestServerTailnet_ReverseProxy(t *testing.T) { assert.Equal(t,
http.StatusOK, res.StatusCode) }) + t.Run("HostRewrite", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + agents, serverTailnet := setupServerTailnetAgent(t, 1) + a := agents[0] + + u, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", codersdk.WorkspaceAgentHTTPAPIServerPort)) + require.NoError(t, err) + + rp := serverTailnet.ReverseProxy(u, u, a.id) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) + require.NoError(t, err) + + // Ensure the reverse proxy director rewrites the url host to the agent's IP. + rp.Director(req) + assert.Equal(t, + fmt.Sprintf("[%s]:%d", tailnet.IPFromUUID(a.id).String(), codersdk.WorkspaceAgentHTTPAPIServerPort), + req.URL.Host, + ) + }) + + t.Run("CachesConnection", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + agents, serverTailnet := setupServerTailnetAgent(t, 1) + a := agents[0] + port := ":4444" + ln, err := a.TailnetConn().Listen("tcp", port) + require.NoError(t, err) + wln := &wrappedListener{Listener: ln} + + serverClosed := make(chan struct{}) + go func() { + defer close(serverClosed) + //nolint:gosec + _ = http.Serve(wln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("hello from agent")) + })) + }() + defer func() { + // wait for server to close + <-serverClosed + }() + + defer ln.Close() + + u, err := url.Parse("http://127.0.0.1" + port) + require.NoError(t, err) + + rp := serverTailnet.ReverseProxy(u, u, a.id) + + for i := 0; i < 5; i++ { + rw := httptest.NewRecorder() + req := httptest.NewRequest( + http.MethodGet, + u.String(), + nil, + ).WithContext(ctx) + + rp.ServeHTTP(rw, req) + res := rw.Result() + + _, _ = io.Copy(io.Discard, res.Body) + res.Body.Close() + assert.Equal(t, http.StatusOK, res.StatusCode) + } + + assert.Equal(t, 1, 
wln.getDials()) + }) + + t.Run("NotReusedBetweenAgents", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + agents, serverTailnet := setupServerTailnetAgent(t, 2) + port := ":4444" + + for i, ag := range agents { + i := i + ln, err := ag.TailnetConn().Listen("tcp", port) + require.NoError(t, err) + wln := &wrappedListener{Listener: ln} + + serverClosed := make(chan struct{}) + go func() { + defer close(serverClosed) + //nolint:gosec + _ = http.Serve(wln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(strconv.Itoa(i))) + })) + }() + defer func() { //nolint:revive + // wait for server to close + <-serverClosed + }() + + defer ln.Close() //nolint:revive + } + + u, err := url.Parse("http://127.0.0.1" + port) + require.NoError(t, err) + + for i, ag := range agents { + rp := serverTailnet.ReverseProxy(u, u, ag.id) + + rw := httptest.NewRecorder() + req := httptest.NewRequest( + http.MethodGet, + u.String(), + nil, + ).WithContext(ctx) + + rp.ServeHTTP(rw, req) + res := rw.Result() + + body, _ := io.ReadAll(res.Body) + res.Body.Close() + assert.Equal(t, http.StatusOK, res.StatusCode) + assert.Equal(t, strconv.Itoa(i), string(body)) + } + }) + t.Run("HTTPSProxy", func(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - agentID, _, serverTailnet := setupAgent(t, nil) + agents, serverTailnet := setupServerTailnetAgent(t, 1) + a := agents[0] const expectedResponseCode = 209 // Test that we can proxy HTTPS traffic. 
@@ -92,7 +231,7 @@ func TestServerTailnet_ReverseProxy(t *testing.T) { uri, err := url.Parse(s.URL) require.NoError(t, err) - rp := serverTailnet.ReverseProxy(uri, uri, agentID) + rp := serverTailnet.ReverseProxy(uri, uri, a.id) rw := httptest.NewRecorder() req := httptest.NewRequest( @@ -109,44 +248,74 @@ func TestServerTailnet_ReverseProxy(t *testing.T) { }) } -func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.Agent, *coderd.ServerTailnet) { +type wrappedListener struct { + net.Listener + dials int32 +} + +func (w *wrappedListener) Accept() (net.Conn, error) { + conn, err := w.Listener.Accept() + if err != nil { + return nil, err + } + + atomic.AddInt32(&w.dials, 1) + return conn, nil +} + +func (w *wrappedListener) getDials() int { + return int(atomic.LoadInt32(&w.dials)) +} + +type agentWithID struct { + id uuid.UUID + agent.Agent +} + +func setupServerTailnetAgent(t *testing.T, agentNum int) ([]agentWithID, *coderd.ServerTailnet) { logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) derpMap, derpServer := tailnettest.RunDERPAndSTUN(t) - manifest := agentsdk.Manifest{ - AgentID: uuid.New(), - DERPMap: derpMap, - } coord := tailnet.NewCoordinator(logger) t.Cleanup(func() { _ = coord.Close() }) - c := agenttest.NewClient(t, logger, manifest.AgentID, manifest, make(chan *agentsdk.Stats, 50), coord) - t.Cleanup(c.Close) + agents := []agentWithID{} - options := agent.Options{ - Client: c, - Filesystem: afero.NewMemMapFs(), - Logger: logger.Named("agent"), - Addresses: agentAddresses, - } + for i := 0; i < agentNum; i++ { + manifest := agentsdk.Manifest{ + AgentID: uuid.New(), + DERPMap: derpMap, + } - ag := agent.New(options) - t.Cleanup(func() { - _ = ag.Close() - }) + c := agenttest.NewClient(t, logger, manifest.AgentID, manifest, make(chan *agentsdk.Stats, 50), coord) + t.Cleanup(c.Close) + + options := agent.Options{ + Client: c, + Filesystem: afero.NewMemMapFs(), + Logger: logger.Named("agent"), + } - // Wait for the agent to 
connect. - require.Eventually(t, func() bool { - return coord.Node(manifest.AgentID) != nil - }, testutil.WaitShort, testutil.IntervalFast) + ag := agent.New(options) + t.Cleanup(func() { + _ = ag.Close() + }) + + // Wait for the agent to connect. + require.Eventually(t, func() bool { + return coord.Node(manifest.AgentID) != nil + }, testutil.WaitShort, testutil.IntervalFast) + + agents = append(agents, agentWithID{id: manifest.AgentID, Agent: ag}) + } serverTailnet, err := coderd.NewServerTailnet( context.Background(), logger, derpServer, - func() *tailcfg.DERPMap { return manifest.DERPMap }, + func() *tailcfg.DERPMap { return derpMap }, false, func(context.Context) (tailnet.MultiAgentConn, error) { return coord.ServeMultiAgent(uuid.New()), nil }, trace.NewNoopTracerProvider(), @@ -157,5 +326,5 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A _ = serverTailnet.Close() }) - return manifest.AgentID, ag, serverTailnet + return agents, serverTailnet }