Skip to content

Commit 3ace798

Browse files
authored
fix: rewrite url to agent ip in single tailnet (#11810)
This restores the previous behavior of being able to cache connections across agents in a single tailnet.
1 parent 073d1f7 commit 3ace798

File tree

2 files changed

+215
-39
lines changed

2 files changed

+215
-39
lines changed

coderd/tailnet.go

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -99,15 +99,14 @@ func NewServerTailnet(
9999
transport: tailnetTransport.Clone(),
100100
}
101101
tn.transport.DialContext = tn.dialContext
102-
103-
// Bugfix: for some reason all calls to tn.dialContext come from
104-
// "localhost", causing connections to be cached and requests to go to the
105-
// wrong workspaces. This disables keepalives for now until the root cause
106-
// can be found.
107-
tn.transport.MaxIdleConnsPerHost = -1
108-
tn.transport.DisableKeepAlives = true
109-
102+
// These options are mostly just picked at random, and they can likely be
103+
// fine tuned further. Generally, users are running applications in dev mode
104+
// which can generate hundreds of requests per page load, so we increased
105+
// MaxIdleConnsPerHost from 2 to 6 and removed the limit of total idle
106+
// conns.
107+
tn.transport.MaxIdleConnsPerHost = 6
110108
tn.transport.MaxIdleConns = 0
109+
tn.transport.IdleConnTimeout = 10 * time.Minute
111110
// We intentionally don't verify the certificate chain here.
112111
// The connection to the workspace is already established and most
113112
// apps are already going to be accessed over plain HTTP, this config
@@ -308,7 +307,15 @@ type ServerTailnet struct {
308307
}
309308

310309
func (s *ServerTailnet) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) *httputil.ReverseProxy {
311-
proxy := httputil.NewSingleHostReverseProxy(targetURL)
310+
// Rewrite the targetURL's Host to point to the agent's IP. This is
311+
// necessary because, due to TCP connection caching, each agent needs to be
312+
// addressed individually. Otherwise, all connections get dialed as
313+
// "localhost:port", causing connections to be shared across agents.
314+
tgt := *targetURL
315+
_, port, _ := net.SplitHostPort(tgt.Host)
316+
tgt.Host = net.JoinHostPort(tailnet.IPFromUUID(agentID).String(), port)
317+
318+
proxy := httputil.NewSingleHostReverseProxy(&tgt)
312319
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
313320
site.RenderStaticErrorPage(w, r, site.ErrorPageData{
314321
Status: http.StatusBadGateway,

coderd/tailnet_test.go

Lines changed: 199 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@ package coderd_test
33
import (
44
"context"
55
"fmt"
6+
"io"
7+
"net"
68
"net/http"
79
"net/http/httptest"
8-
"net/netip"
910
"net/url"
11+
"strconv"
12+
"sync/atomic"
1013
"testing"
1114

1215
"github.com/google/uuid"
@@ -35,9 +38,10 @@ func TestServerTailnet_AgentConn_OK(t *testing.T) {
3538
defer cancel()
3639

3740
// Connect through the ServerTailnet
38-
agentID, _, serverTailnet := setupAgent(t, nil)
41+
agents, serverTailnet := setupServerTailnetAgent(t, 1)
42+
a := agents[0]
3943

40-
conn, release, err := serverTailnet.AgentConn(ctx, agentID)
44+
conn, release, err := serverTailnet.AgentConn(ctx, a.id)
4145
require.NoError(t, err)
4246
defer release()
4347

@@ -53,12 +57,13 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
5357
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
5458
defer cancel()
5559

56-
agentID, _, serverTailnet := setupAgent(t, nil)
60+
agents, serverTailnet := setupServerTailnetAgent(t, 1)
61+
a := agents[0]
5762

5863
u, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", codersdk.WorkspaceAgentHTTPAPIServerPort))
5964
require.NoError(t, err)
6065

61-
rp := serverTailnet.ReverseProxy(u, u, agentID)
66+
rp := serverTailnet.ReverseProxy(u, u, a.id)
6267

6368
rw := httptest.NewRecorder()
6469
req := httptest.NewRequest(
@@ -74,13 +79,147 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
7479
assert.Equal(t, http.StatusOK, res.StatusCode)
7580
})
7681

82+
t.Run("HostRewrite", func(t *testing.T) {
83+
t.Parallel()
84+
85+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
86+
defer cancel()
87+
88+
agents, serverTailnet := setupServerTailnetAgent(t, 1)
89+
a := agents[0]
90+
91+
u, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", codersdk.WorkspaceAgentHTTPAPIServerPort))
92+
require.NoError(t, err)
93+
94+
rp := serverTailnet.ReverseProxy(u, u, a.id)
95+
96+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
97+
require.NoError(t, err)
98+
99+
// Ensure the reverse proxy director rewrites the url host to the agent's IP.
100+
rp.Director(req)
101+
assert.Equal(t,
102+
fmt.Sprintf("[%s]:%d", tailnet.IPFromUUID(a.id).String(), codersdk.WorkspaceAgentHTTPAPIServerPort),
103+
req.URL.Host,
104+
)
105+
})
106+
107+
t.Run("CachesConnection", func(t *testing.T) {
108+
t.Parallel()
109+
110+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
111+
defer cancel()
112+
113+
agents, serverTailnet := setupServerTailnetAgent(t, 1)
114+
a := agents[0]
115+
port := ":4444"
116+
ln, err := a.TailnetConn().Listen("tcp", port)
117+
require.NoError(t, err)
118+
wln := &wrappedListener{Listener: ln}
119+
120+
serverClosed := make(chan struct{})
121+
go func() {
122+
defer close(serverClosed)
123+
//nolint:gosec
124+
_ = http.Serve(wln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
125+
w.WriteHeader(http.StatusOK)
126+
w.Write([]byte("hello from agent"))
127+
}))
128+
}()
129+
defer func() {
130+
// wait for server to close
131+
<-serverClosed
132+
}()
133+
134+
defer ln.Close()
135+
136+
u, err := url.Parse("http://127.0.0.1" + port)
137+
require.NoError(t, err)
138+
139+
rp := serverTailnet.ReverseProxy(u, u, a.id)
140+
141+
for i := 0; i < 5; i++ {
142+
rw := httptest.NewRecorder()
143+
req := httptest.NewRequest(
144+
http.MethodGet,
145+
u.String(),
146+
nil,
147+
).WithContext(ctx)
148+
149+
rp.ServeHTTP(rw, req)
150+
res := rw.Result()
151+
152+
_, _ = io.Copy(io.Discard, res.Body)
153+
res.Body.Close()
154+
assert.Equal(t, http.StatusOK, res.StatusCode)
155+
}
156+
157+
assert.Equal(t, 1, wln.getDials())
158+
})
159+
160+
t.Run("NotReusedBetweenAgents", func(t *testing.T) {
161+
t.Parallel()
162+
163+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
164+
defer cancel()
165+
166+
agents, serverTailnet := setupServerTailnetAgent(t, 2)
167+
port := ":4444"
168+
169+
for i, ag := range agents {
170+
i := i
171+
ln, err := ag.TailnetConn().Listen("tcp", port)
172+
require.NoError(t, err)
173+
wln := &wrappedListener{Listener: ln}
174+
175+
serverClosed := make(chan struct{})
176+
go func() {
177+
defer close(serverClosed)
178+
//nolint:gosec
179+
_ = http.Serve(wln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
180+
w.WriteHeader(http.StatusOK)
181+
w.Write([]byte(strconv.Itoa(i)))
182+
}))
183+
}()
184+
defer func() { //nolint:revive
185+
// wait for server to close
186+
<-serverClosed
187+
}()
188+
189+
defer ln.Close() //nolint:revive
190+
}
191+
192+
u, err := url.Parse("http://127.0.0.1" + port)
193+
require.NoError(t, err)
194+
195+
for i, ag := range agents {
196+
rp := serverTailnet.ReverseProxy(u, u, ag.id)
197+
198+
rw := httptest.NewRecorder()
199+
req := httptest.NewRequest(
200+
http.MethodGet,
201+
u.String(),
202+
nil,
203+
).WithContext(ctx)
204+
205+
rp.ServeHTTP(rw, req)
206+
res := rw.Result()
207+
208+
body, _ := io.ReadAll(res.Body)
209+
res.Body.Close()
210+
assert.Equal(t, http.StatusOK, res.StatusCode)
211+
assert.Equal(t, strconv.Itoa(i), string(body))
212+
}
213+
})
214+
77215
t.Run("HTTPSProxy", func(t *testing.T) {
78216
t.Parallel()
79217

80218
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
81219
defer cancel()
82220

83-
agentID, _, serverTailnet := setupAgent(t, nil)
221+
agents, serverTailnet := setupServerTailnetAgent(t, 1)
222+
a := agents[0]
84223

85224
const expectedResponseCode = 209
86225
// Test that we can proxy HTTPS traffic.
@@ -92,7 +231,7 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
92231
uri, err := url.Parse(s.URL)
93232
require.NoError(t, err)
94233

95-
rp := serverTailnet.ReverseProxy(uri, uri, agentID)
234+
rp := serverTailnet.ReverseProxy(uri, uri, a.id)
96235

97236
rw := httptest.NewRecorder()
98237
req := httptest.NewRequest(
@@ -109,44 +248,74 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
109248
})
110249
}
111250

112-
func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.Agent, *coderd.ServerTailnet) {
251+
type wrappedListener struct {
252+
net.Listener
253+
dials int32
254+
}
255+
256+
func (w *wrappedListener) Accept() (net.Conn, error) {
257+
conn, err := w.Listener.Accept()
258+
if err != nil {
259+
return nil, err
260+
}
261+
262+
atomic.AddInt32(&w.dials, 1)
263+
return conn, nil
264+
}
265+
266+
func (w *wrappedListener) getDials() int {
267+
return int(atomic.LoadInt32(&w.dials))
268+
}
269+
270+
type agentWithID struct {
271+
id uuid.UUID
272+
agent.Agent
273+
}
274+
275+
func setupServerTailnetAgent(t *testing.T, agentNum int) ([]agentWithID, *coderd.ServerTailnet) {
113276
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
114277
derpMap, derpServer := tailnettest.RunDERPAndSTUN(t)
115-
manifest := agentsdk.Manifest{
116-
AgentID: uuid.New(),
117-
DERPMap: derpMap,
118-
}
119278

120279
coord := tailnet.NewCoordinator(logger)
121280
t.Cleanup(func() {
122281
_ = coord.Close()
123282
})
124283

125-
c := agenttest.NewClient(t, logger, manifest.AgentID, manifest, make(chan *agentsdk.Stats, 50), coord)
126-
t.Cleanup(c.Close)
284+
agents := []agentWithID{}
127285

128-
options := agent.Options{
129-
Client: c,
130-
Filesystem: afero.NewMemMapFs(),
131-
Logger: logger.Named("agent"),
132-
Addresses: agentAddresses,
133-
}
286+
for i := 0; i < agentNum; i++ {
287+
manifest := agentsdk.Manifest{
288+
AgentID: uuid.New(),
289+
DERPMap: derpMap,
290+
}
134291

135-
ag := agent.New(options)
136-
t.Cleanup(func() {
137-
_ = ag.Close()
138-
})
292+
c := agenttest.NewClient(t, logger, manifest.AgentID, manifest, make(chan *agentsdk.Stats, 50), coord)
293+
t.Cleanup(c.Close)
294+
295+
options := agent.Options{
296+
Client: c,
297+
Filesystem: afero.NewMemMapFs(),
298+
Logger: logger.Named("agent"),
299+
}
139300

140-
// Wait for the agent to connect.
141-
require.Eventually(t, func() bool {
142-
return coord.Node(manifest.AgentID) != nil
143-
}, testutil.WaitShort, testutil.IntervalFast)
301+
ag := agent.New(options)
302+
t.Cleanup(func() {
303+
_ = ag.Close()
304+
})
305+
306+
// Wait for the agent to connect.
307+
require.Eventually(t, func() bool {
308+
return coord.Node(manifest.AgentID) != nil
309+
}, testutil.WaitShort, testutil.IntervalFast)
310+
311+
agents = append(agents, agentWithID{id: manifest.AgentID, Agent: ag})
312+
}
144313

145314
serverTailnet, err := coderd.NewServerTailnet(
146315
context.Background(),
147316
logger,
148317
derpServer,
149-
func() *tailcfg.DERPMap { return manifest.DERPMap },
318+
func() *tailcfg.DERPMap { return derpMap },
150319
false,
151320
func(context.Context) (tailnet.MultiAgentConn, error) { return coord.ServeMultiAgent(uuid.New()), nil },
152321
trace.NewNoopTracerProvider(),
@@ -157,5 +326,5 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A
157326
_ = serverTailnet.Close()
158327
})
159328

160-
return manifest.AgentID, ag, serverTailnet
329+
return agents, serverTailnet
161330
}

0 commit comments

Comments
 (0)