Skip to content

Commit 6d7722e

Browse files
committed
add integration test
1 parent 689855c commit 6d7722e

File tree

2 files changed

+179
-36
lines changed

2 files changed

+179
-36
lines changed

coderd/files.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ import (
1212
"io"
1313
"net/http"
1414

15-
"cdr.dev/slog"
1615
"github.com/go-chi/chi/v5"
1716
"github.com/google/uuid"
1817

18+
"cdr.dev/slog"
1919
"github.com/coder/coder/v2/coderd/database"
2020
"github.com/coder/coder/v2/coderd/database/dbtime"
2121
"github.com/coder/coder/v2/coderd/httpapi"

coderd/tailnet_test.go

+178-35
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@ package coderd_test
33
import (
44
"context"
55
"fmt"
6+
"io"
7+
"net"
68
"net/http"
79
"net/http/httptest"
8-
"net/netip"
910
"net/url"
11+
"strconv"
12+
"sync/atomic"
1013
"testing"
1114

1215
"github.com/google/uuid"
@@ -35,9 +38,10 @@ func TestServerTailnet_AgentConn_OK(t *testing.T) {
3538
defer cancel()
3639

3740
// Connect through the ServerTailnet
38-
agentID, _, serverTailnet := setupAgent(t, nil)
41+
agents, serverTailnet := setupServerTailnetAgent(t, 1)
42+
a := agents[0]
3943

40-
conn, release, err := serverTailnet.AgentConn(ctx, agentID)
44+
conn, release, err := serverTailnet.AgentConn(ctx, a.id)
4145
require.NoError(t, err)
4246
defer release()
4347

@@ -53,12 +57,13 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
5357
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
5458
defer cancel()
5559

56-
agentID, _, serverTailnet := setupAgent(t, nil)
60+
agents, serverTailnet := setupServerTailnetAgent(t, 1)
61+
a := agents[0]
5762

5863
u, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", codersdk.WorkspaceAgentHTTPAPIServerPort))
5964
require.NoError(t, err)
6065

61-
rp := serverTailnet.ReverseProxy(u, u, agentID)
66+
rp := serverTailnet.ReverseProxy(u, u, a.id)
6267

6368
rw := httptest.NewRecorder()
6469
req := httptest.NewRequest(
@@ -80,33 +85,141 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
8085
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
8186
defer cancel()
8287

83-
agentID, _, serverTailnet := setupAgent(t, nil)
88+
agents, serverTailnet := setupServerTailnetAgent(t, 1)
89+
a := agents[0]
8490

8591
u, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", codersdk.WorkspaceAgentHTTPAPIServerPort))
8692
require.NoError(t, err)
8793

88-
rp, release, err := serverTailnet.ReverseProxy(u, u, agentID)
89-
require.NoError(t, err)
90-
defer release()
94+
rp := serverTailnet.ReverseProxy(u, u, a.id)
9195

9296
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
9397
require.NoError(t, err)
9498

9599
// Ensure the reverse proxy director rewrites the url host to the agent's IP.
96100
rp.Director(req)
97101
assert.Equal(t,
98-
fmt.Sprintf("[%s]:%d", tailnet.IPFromUUID(agentID).String(), codersdk.WorkspaceAgentHTTPAPIServerPort),
102+
fmt.Sprintf("[%s]:%d", tailnet.IPFromUUID(a.id).String(), codersdk.WorkspaceAgentHTTPAPIServerPort),
99103
req.URL.Host,
100104
)
101105
})
102106

107+
t.Run("CachesConnection", func(t *testing.T) {
108+
t.Parallel()
109+
110+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
111+
defer cancel()
112+
113+
agents, serverTailnet := setupServerTailnetAgent(t, 1)
114+
a := agents[0]
115+
port := ":4444"
116+
ln, err := a.TailnetConn().Listen("tcp", port)
117+
require.NoError(t, err)
118+
wln := &wrappedListener{Listener: ln}
119+
120+
serverClosed := make(chan struct{})
121+
go func() {
122+
defer close(serverClosed)
123+
//nolint:gosec
124+
_ = http.Serve(wln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
125+
w.WriteHeader(http.StatusOK)
126+
w.Write([]byte("hello from agent"))
127+
}))
128+
}()
129+
defer func() {
130+
// wait for server to close
131+
<-serverClosed
132+
}()
133+
134+
defer ln.Close()
135+
136+
u, err := url.Parse("http://127.0.0.1" + port)
137+
require.NoError(t, err)
138+
139+
rp := serverTailnet.ReverseProxy(u, u, a.id)
140+
141+
for i := 0; i < 5; i++ {
142+
rw := httptest.NewRecorder()
143+
req := httptest.NewRequest(
144+
http.MethodGet,
145+
u.String(),
146+
nil,
147+
).WithContext(ctx)
148+
149+
rp.ServeHTTP(rw, req)
150+
res := rw.Result()
151+
152+
_, _ = io.Copy(io.Discard, res.Body)
153+
res.Body.Close()
154+
assert.Equal(t, http.StatusOK, res.StatusCode)
155+
}
156+
157+
assert.Equal(t, 1, wln.getDials())
158+
})
159+
160+
t.Run("NotReusedBetweenAgents", func(t *testing.T) {
161+
t.Parallel()
162+
163+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
164+
defer cancel()
165+
166+
agents, serverTailnet := setupServerTailnetAgent(t, 2)
167+
port := ":4444"
168+
169+
for i, ag := range agents {
170+
i := i
171+
ln, err := ag.TailnetConn().Listen("tcp", port)
172+
require.NoError(t, err)
173+
wln := &wrappedListener{Listener: ln}
174+
175+
serverClosed := make(chan struct{})
176+
go func() {
177+
defer close(serverClosed)
178+
//nolint:gosec
179+
_ = http.Serve(wln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
180+
w.WriteHeader(http.StatusOK)
181+
w.Write([]byte(strconv.Itoa(i)))
182+
}))
183+
}()
184+
defer func() { //nolint:revive
185+
// wait for server to close
186+
<-serverClosed
187+
}()
188+
189+
defer ln.Close() //nolint:revive
190+
}
191+
192+
u, err := url.Parse("http://127.0.0.1" + port)
193+
require.NoError(t, err)
194+
195+
for i, ag := range agents {
196+
rp := serverTailnet.ReverseProxy(u, u, ag.id)
197+
198+
rw := httptest.NewRecorder()
199+
req := httptest.NewRequest(
200+
http.MethodGet,
201+
u.String(),
202+
nil,
203+
).WithContext(ctx)
204+
205+
rp.ServeHTTP(rw, req)
206+
res := rw.Result()
207+
208+
body, _ := io.ReadAll(res.Body)
209+
res.Body.Close()
210+
assert.Equal(t, http.StatusOK, res.StatusCode)
211+
assert.Equal(t, strconv.Itoa(i), string(body))
212+
}
213+
})
214+
103215
t.Run("HTTPSProxy", func(t *testing.T) {
104216
t.Parallel()
105217

106218
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
107219
defer cancel()
108220

109-
agentID, _, serverTailnet := setupAgent(t, nil)
221+
agents, serverTailnet := setupServerTailnetAgent(t, 1)
222+
a := agents[0]
110223

111224
const expectedResponseCode = 209
112225
// Test that we can proxy HTTPS traffic.
@@ -118,7 +231,7 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
118231
uri, err := url.Parse(s.URL)
119232
require.NoError(t, err)
120233

121-
rp := serverTailnet.ReverseProxy(uri, uri, agentID)
234+
rp := serverTailnet.ReverseProxy(uri, uri, a.id)
122235

123236
rw := httptest.NewRecorder()
124237
req := httptest.NewRequest(
@@ -135,44 +248,74 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
135248
})
136249
}
137250

138-
func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.Agent, *coderd.ServerTailnet) {
251+
type wrappedListener struct {
252+
net.Listener
253+
dials int32
254+
}
255+
256+
func (w *wrappedListener) Accept() (net.Conn, error) {
257+
conn, err := w.Listener.Accept()
258+
if err != nil {
259+
return nil, err
260+
}
261+
262+
atomic.AddInt32(&w.dials, 1)
263+
return conn, nil
264+
}
265+
266+
func (w *wrappedListener) getDials() int {
267+
return int(atomic.LoadInt32(&w.dials))
268+
}
269+
270+
type agentWithID struct {
271+
id uuid.UUID
272+
agent.Agent
273+
}
274+
275+
func setupServerTailnetAgent(t *testing.T, agentNum int) ([]agentWithID, *coderd.ServerTailnet) {
139276
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
140277
derpMap, derpServer := tailnettest.RunDERPAndSTUN(t)
141-
manifest := agentsdk.Manifest{
142-
AgentID: uuid.New(),
143-
DERPMap: derpMap,
144-
}
145278

146279
coord := tailnet.NewCoordinator(logger)
147280
t.Cleanup(func() {
148281
_ = coord.Close()
149282
})
150283

151-
c := agenttest.NewClient(t, logger, manifest.AgentID, manifest, make(chan *agentsdk.Stats, 50), coord)
152-
t.Cleanup(c.Close)
284+
agents := []agentWithID{}
153285

154-
options := agent.Options{
155-
Client: c,
156-
Filesystem: afero.NewMemMapFs(),
157-
Logger: logger.Named("agent"),
158-
Addresses: agentAddresses,
159-
}
286+
for i := 0; i < agentNum; i++ {
287+
manifest := agentsdk.Manifest{
288+
AgentID: uuid.New(),
289+
DERPMap: derpMap,
290+
}
160291

161-
ag := agent.New(options)
162-
t.Cleanup(func() {
163-
_ = ag.Close()
164-
})
292+
c := agenttest.NewClient(t, logger, manifest.AgentID, manifest, make(chan *agentsdk.Stats, 50), coord)
293+
t.Cleanup(c.Close)
294+
295+
options := agent.Options{
296+
Client: c,
297+
Filesystem: afero.NewMemMapFs(),
298+
Logger: logger.Named("agent"),
299+
}
300+
301+
ag := agent.New(options)
302+
t.Cleanup(func() {
303+
_ = ag.Close()
304+
})
165305

166-
// Wait for the agent to connect.
167-
require.Eventually(t, func() bool {
168-
return coord.Node(manifest.AgentID) != nil
169-
}, testutil.WaitShort, testutil.IntervalFast)
306+
// Wait for the agent to connect.
307+
require.Eventually(t, func() bool {
308+
return coord.Node(manifest.AgentID) != nil
309+
}, testutil.WaitShort, testutil.IntervalFast)
310+
311+
agents = append(agents, agentWithID{id: manifest.AgentID, Agent: ag})
312+
}
170313

171314
serverTailnet, err := coderd.NewServerTailnet(
172315
context.Background(),
173316
logger,
174317
derpServer,
175-
func() *tailcfg.DERPMap { return manifest.DERPMap },
318+
func() *tailcfg.DERPMap { return derpMap },
176319
false,
177320
func(context.Context) (tailnet.MultiAgentConn, error) { return coord.ServeMultiAgent(uuid.New()), nil },
178321
trace.NewNoopTracerProvider(),
@@ -183,5 +326,5 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A
183326
_ = serverTailnet.Close()
184327
})
185328

186-
return manifest.AgentID, ag, serverTailnet
329+
return agents, serverTailnet
187330
}

0 commit comments

Comments
 (0)