@@ -3,10 +3,13 @@ package coderd_test
3
3
import (
4
4
"context"
5
5
"fmt"
6
+ "io"
7
+ "net"
6
8
"net/http"
7
9
"net/http/httptest"
8
- "net/netip"
9
10
"net/url"
11
+ "strconv"
12
+ "sync/atomic"
10
13
"testing"
11
14
12
15
"github.com/google/uuid"
@@ -35,9 +38,10 @@ func TestServerTailnet_AgentConn_OK(t *testing.T) {
35
38
defer cancel ()
36
39
37
40
// Connect through the ServerTailnet
38
- agentID , _ , serverTailnet := setupAgent (t , nil )
41
+ agents , serverTailnet := setupServerTailnetAgent (t , 1 )
42
+ a := agents [0 ]
39
43
40
- conn , release , err := serverTailnet .AgentConn (ctx , agentID )
44
+ conn , release , err := serverTailnet .AgentConn (ctx , a . id )
41
45
require .NoError (t , err )
42
46
defer release ()
43
47
@@ -53,12 +57,13 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
53
57
ctx , cancel := context .WithTimeout (context .Background (), testutil .WaitLong )
54
58
defer cancel ()
55
59
56
- agentID , _ , serverTailnet := setupAgent (t , nil )
60
+ agents , serverTailnet := setupServerTailnetAgent (t , 1 )
61
+ a := agents [0 ]
57
62
58
63
u , err := url .Parse (fmt .Sprintf ("http://127.0.0.1:%d" , codersdk .WorkspaceAgentHTTPAPIServerPort ))
59
64
require .NoError (t , err )
60
65
61
- rp := serverTailnet .ReverseProxy (u , u , agentID )
66
+ rp := serverTailnet .ReverseProxy (u , u , a . id )
62
67
63
68
rw := httptest .NewRecorder ()
64
69
req := httptest .NewRequest (
@@ -80,33 +85,141 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
80
85
ctx , cancel := context .WithTimeout (context .Background (), testutil .WaitLong )
81
86
defer cancel ()
82
87
83
- agentID , _ , serverTailnet := setupAgent (t , nil )
88
+ agents , serverTailnet := setupServerTailnetAgent (t , 1 )
89
+ a := agents [0 ]
84
90
85
91
u , err := url .Parse (fmt .Sprintf ("http://127.0.0.1:%d" , codersdk .WorkspaceAgentHTTPAPIServerPort ))
86
92
require .NoError (t , err )
87
93
88
- rp , release , err := serverTailnet .ReverseProxy (u , u , agentID )
89
- require .NoError (t , err )
90
- defer release ()
94
+ rp := serverTailnet .ReverseProxy (u , u , a .id )
91
95
92
96
req , err := http .NewRequestWithContext (ctx , http .MethodGet , u .String (), nil )
93
97
require .NoError (t , err )
94
98
95
99
// Ensure the reverse proxy director rewrites the url host to the agent's IP.
96
100
rp .Director (req )
97
101
assert .Equal (t ,
98
- fmt .Sprintf ("[%s]:%d" , tailnet .IPFromUUID (agentID ).String (), codersdk .WorkspaceAgentHTTPAPIServerPort ),
102
+ fmt .Sprintf ("[%s]:%d" , tailnet .IPFromUUID (a . id ).String (), codersdk .WorkspaceAgentHTTPAPIServerPort ),
99
103
req .URL .Host ,
100
104
)
101
105
})
102
106
107
+ t .Run ("CachesConnection" , func (t * testing.T ) {
108
+ t .Parallel ()
109
+
110
+ ctx , cancel := context .WithTimeout (context .Background (), testutil .WaitLong )
111
+ defer cancel ()
112
+
113
+ agents , serverTailnet := setupServerTailnetAgent (t , 1 )
114
+ a := agents [0 ]
115
+ port := ":4444"
116
+ ln , err := a .TailnetConn ().Listen ("tcp" , port )
117
+ require .NoError (t , err )
118
+ wln := & wrappedListener {Listener : ln }
119
+
120
+ serverClosed := make (chan struct {})
121
+ go func () {
122
+ defer close (serverClosed )
123
+ //nolint:gosec
124
+ _ = http .Serve (wln , http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
125
+ w .WriteHeader (http .StatusOK )
126
+ w .Write ([]byte ("hello from agent" ))
127
+ }))
128
+ }()
129
+ defer func () {
130
+ // wait for server to close
131
+ <- serverClosed
132
+ }()
133
+
134
+ defer ln .Close ()
135
+
136
+ u , err := url .Parse ("http://127.0.0.1" + port )
137
+ require .NoError (t , err )
138
+
139
+ rp := serverTailnet .ReverseProxy (u , u , a .id )
140
+
141
+ for i := 0 ; i < 5 ; i ++ {
142
+ rw := httptest .NewRecorder ()
143
+ req := httptest .NewRequest (
144
+ http .MethodGet ,
145
+ u .String (),
146
+ nil ,
147
+ ).WithContext (ctx )
148
+
149
+ rp .ServeHTTP (rw , req )
150
+ res := rw .Result ()
151
+
152
+ _ , _ = io .Copy (io .Discard , res .Body )
153
+ res .Body .Close ()
154
+ assert .Equal (t , http .StatusOK , res .StatusCode )
155
+ }
156
+
157
+ assert .Equal (t , 1 , wln .getDials ())
158
+ })
159
+
160
+ t .Run ("NotReusedBetweenAgents" , func (t * testing.T ) {
161
+ t .Parallel ()
162
+
163
+ ctx , cancel := context .WithTimeout (context .Background (), testutil .WaitLong )
164
+ defer cancel ()
165
+
166
+ agents , serverTailnet := setupServerTailnetAgent (t , 2 )
167
+ port := ":4444"
168
+
169
+ for i , ag := range agents {
170
+ i := i
171
+ ln , err := ag .TailnetConn ().Listen ("tcp" , port )
172
+ require .NoError (t , err )
173
+ wln := & wrappedListener {Listener : ln }
174
+
175
+ serverClosed := make (chan struct {})
176
+ go func () {
177
+ defer close (serverClosed )
178
+ //nolint:gosec
179
+ _ = http .Serve (wln , http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
180
+ w .WriteHeader (http .StatusOK )
181
+ w .Write ([]byte (strconv .Itoa (i )))
182
+ }))
183
+ }()
184
+ defer func () { //nolint:revive
185
+ // wait for server to close
186
+ <- serverClosed
187
+ }()
188
+
189
+ defer ln .Close () //nolint:revive
190
+ }
191
+
192
+ u , err := url .Parse ("http://127.0.0.1" + port )
193
+ require .NoError (t , err )
194
+
195
+ for i , ag := range agents {
196
+ rp := serverTailnet .ReverseProxy (u , u , ag .id )
197
+
198
+ rw := httptest .NewRecorder ()
199
+ req := httptest .NewRequest (
200
+ http .MethodGet ,
201
+ u .String (),
202
+ nil ,
203
+ ).WithContext (ctx )
204
+
205
+ rp .ServeHTTP (rw , req )
206
+ res := rw .Result ()
207
+
208
+ body , _ := io .ReadAll (res .Body )
209
+ res .Body .Close ()
210
+ assert .Equal (t , http .StatusOK , res .StatusCode )
211
+ assert .Equal (t , strconv .Itoa (i ), string (body ))
212
+ }
213
+ })
214
+
103
215
t .Run ("HTTPSProxy" , func (t * testing.T ) {
104
216
t .Parallel ()
105
217
106
218
ctx , cancel := context .WithTimeout (context .Background (), testutil .WaitLong )
107
219
defer cancel ()
108
220
109
- agentID , _ , serverTailnet := setupAgent (t , nil )
221
+ agents , serverTailnet := setupServerTailnetAgent (t , 1 )
222
+ a := agents [0 ]
110
223
111
224
const expectedResponseCode = 209
112
225
// Test that we can proxy HTTPS traffic.
@@ -118,7 +231,7 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
118
231
uri , err := url .Parse (s .URL )
119
232
require .NoError (t , err )
120
233
121
- rp := serverTailnet .ReverseProxy (uri , uri , agentID )
234
+ rp := serverTailnet .ReverseProxy (uri , uri , a . id )
122
235
123
236
rw := httptest .NewRecorder ()
124
237
req := httptest .NewRequest (
@@ -135,44 +248,74 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
135
248
})
136
249
}
137
250
138
- func setupAgent (t * testing.T , agentAddresses []netip.Prefix ) (uuid.UUID , agent.Agent , * coderd.ServerTailnet ) {
251
+ type wrappedListener struct {
252
+ net.Listener
253
+ dials int32
254
+ }
255
+
256
+ func (w * wrappedListener ) Accept () (net.Conn , error ) {
257
+ conn , err := w .Listener .Accept ()
258
+ if err != nil {
259
+ return nil , err
260
+ }
261
+
262
+ atomic .AddInt32 (& w .dials , 1 )
263
+ return conn , nil
264
+ }
265
+
266
+ func (w * wrappedListener ) getDials () int {
267
+ return int (atomic .LoadInt32 (& w .dials ))
268
+ }
269
+
270
+ type agentWithID struct {
271
+ id uuid.UUID
272
+ agent.Agent
273
+ }
274
+
275
+ func setupServerTailnetAgent (t * testing.T , agentNum int ) ([]agentWithID , * coderd.ServerTailnet ) {
139
276
logger := slogtest .Make (t , nil ).Leveled (slog .LevelDebug )
140
277
derpMap , derpServer := tailnettest .RunDERPAndSTUN (t )
141
- manifest := agentsdk.Manifest {
142
- AgentID : uuid .New (),
143
- DERPMap : derpMap ,
144
- }
145
278
146
279
coord := tailnet .NewCoordinator (logger )
147
280
t .Cleanup (func () {
148
281
_ = coord .Close ()
149
282
})
150
283
151
- c := agenttest .NewClient (t , logger , manifest .AgentID , manifest , make (chan * agentsdk.Stats , 50 ), coord )
152
- t .Cleanup (c .Close )
284
+ agents := []agentWithID {}
153
285
154
- options := agent.Options {
155
- Client : c ,
156
- Filesystem : afero .NewMemMapFs (),
157
- Logger : logger .Named ("agent" ),
158
- Addresses : agentAddresses ,
159
- }
286
+ for i := 0 ; i < agentNum ; i ++ {
287
+ manifest := agentsdk.Manifest {
288
+ AgentID : uuid .New (),
289
+ DERPMap : derpMap ,
290
+ }
160
291
161
- ag := agent .New (options )
162
- t .Cleanup (func () {
163
- _ = ag .Close ()
164
- })
292
+ c := agenttest .NewClient (t , logger , manifest .AgentID , manifest , make (chan * agentsdk.Stats , 50 ), coord )
293
+ t .Cleanup (c .Close )
294
+
295
+ options := agent.Options {
296
+ Client : c ,
297
+ Filesystem : afero .NewMemMapFs (),
298
+ Logger : logger .Named ("agent" ),
299
+ }
300
+
301
+ ag := agent .New (options )
302
+ t .Cleanup (func () {
303
+ _ = ag .Close ()
304
+ })
165
305
166
- // Wait for the agent to connect.
167
- require .Eventually (t , func () bool {
168
- return coord .Node (manifest .AgentID ) != nil
169
- }, testutil .WaitShort , testutil .IntervalFast )
306
+ // Wait for the agent to connect.
307
+ require .Eventually (t , func () bool {
308
+ return coord .Node (manifest .AgentID ) != nil
309
+ }, testutil .WaitShort , testutil .IntervalFast )
310
+
311
+ agents = append (agents , agentWithID {id : manifest .AgentID , Agent : ag })
312
+ }
170
313
171
314
serverTailnet , err := coderd .NewServerTailnet (
172
315
context .Background (),
173
316
logger ,
174
317
derpServer ,
175
- func () * tailcfg.DERPMap { return manifest . DERPMap },
318
+ func () * tailcfg.DERPMap { return derpMap },
176
319
false ,
177
320
func (context.Context ) (tailnet.MultiAgentConn , error ) { return coord .ServeMultiAgent (uuid .New ()), nil },
178
321
trace .NewNoopTracerProvider (),
@@ -183,5 +326,5 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A
183
326
_ = serverTailnet .Close ()
184
327
})
185
328
186
- return manifest . AgentID , ag , serverTailnet
329
+ return agents , serverTailnet
187
330
}
0 commit comments