@@ -3,10 +3,13 @@ package coderd_test
3
3
import (
4
4
"context"
5
5
"fmt"
6
+ "io"
7
+ "net"
6
8
"net/http"
7
9
"net/http/httptest"
8
- "net/netip"
9
10
"net/url"
11
+ "strconv"
12
+ "sync/atomic"
10
13
"testing"
11
14
12
15
"github.com/google/uuid"
@@ -35,9 +38,10 @@ func TestServerTailnet_AgentConn_OK(t *testing.T) {
35
38
defer cancel ()
36
39
37
40
// Connect through the ServerTailnet
38
- agentID , _ , serverTailnet := setupAgent (t , nil )
41
+ agents , serverTailnet := setupServerTailnetAgent (t , 1 )
42
+ a := agents [0 ]
39
43
40
- conn , release , err := serverTailnet .AgentConn (ctx , agentID )
44
+ conn , release , err := serverTailnet .AgentConn (ctx , a . id )
41
45
require .NoError (t , err )
42
46
defer release ()
43
47
@@ -53,12 +57,13 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
53
57
ctx , cancel := context .WithTimeout (context .Background (), testutil .WaitLong )
54
58
defer cancel ()
55
59
56
- agentID , _ , serverTailnet := setupAgent (t , nil )
60
+ agents , serverTailnet := setupServerTailnetAgent (t , 1 )
61
+ a := agents [0 ]
57
62
58
63
u , err := url .Parse (fmt .Sprintf ("http://127.0.0.1:%d" , codersdk .WorkspaceAgentHTTPAPIServerPort ))
59
64
require .NoError (t , err )
60
65
61
- rp := serverTailnet .ReverseProxy (u , u , agentID )
66
+ rp := serverTailnet .ReverseProxy (u , u , a . id )
62
67
63
68
rw := httptest .NewRecorder ()
64
69
req := httptest .NewRequest (
@@ -74,13 +79,147 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
74
79
assert .Equal (t , http .StatusOK , res .StatusCode )
75
80
})
76
81
82
+ t .Run ("HostRewrite" , func (t * testing.T ) {
83
+ t .Parallel ()
84
+
85
+ ctx , cancel := context .WithTimeout (context .Background (), testutil .WaitLong )
86
+ defer cancel ()
87
+
88
+ agents , serverTailnet := setupServerTailnetAgent (t , 1 )
89
+ a := agents [0 ]
90
+
91
+ u , err := url .Parse (fmt .Sprintf ("http://127.0.0.1:%d" , codersdk .WorkspaceAgentHTTPAPIServerPort ))
92
+ require .NoError (t , err )
93
+
94
+ rp := serverTailnet .ReverseProxy (u , u , a .id )
95
+
96
+ req , err := http .NewRequestWithContext (ctx , http .MethodGet , u .String (), nil )
97
+ require .NoError (t , err )
98
+
99
+ // Ensure the reverse proxy director rewrites the url host to the agent's IP.
100
+ rp .Director (req )
101
+ assert .Equal (t ,
102
+ fmt .Sprintf ("[%s]:%d" , tailnet .IPFromUUID (a .id ).String (), codersdk .WorkspaceAgentHTTPAPIServerPort ),
103
+ req .URL .Host ,
104
+ )
105
+ })
106
+
107
+ t .Run ("CachesConnection" , func (t * testing.T ) {
108
+ t .Parallel ()
109
+
110
+ ctx , cancel := context .WithTimeout (context .Background (), testutil .WaitLong )
111
+ defer cancel ()
112
+
113
+ agents , serverTailnet := setupServerTailnetAgent (t , 1 )
114
+ a := agents [0 ]
115
+ port := ":4444"
116
+ ln , err := a .TailnetConn ().Listen ("tcp" , port )
117
+ require .NoError (t , err )
118
+ wln := & wrappedListener {Listener : ln }
119
+
120
+ serverClosed := make (chan struct {})
121
+ go func () {
122
+ defer close (serverClosed )
123
+ //nolint:gosec
124
+ _ = http .Serve (wln , http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
125
+ w .WriteHeader (http .StatusOK )
126
+ w .Write ([]byte ("hello from agent" ))
127
+ }))
128
+ }()
129
+ defer func () {
130
+ // wait for server to close
131
+ <- serverClosed
132
+ }()
133
+
134
+ defer ln .Close ()
135
+
136
+ u , err := url .Parse ("http://127.0.0.1" + port )
137
+ require .NoError (t , err )
138
+
139
+ rp := serverTailnet .ReverseProxy (u , u , a .id )
140
+
141
+ for i := 0 ; i < 5 ; i ++ {
142
+ rw := httptest .NewRecorder ()
143
+ req := httptest .NewRequest (
144
+ http .MethodGet ,
145
+ u .String (),
146
+ nil ,
147
+ ).WithContext (ctx )
148
+
149
+ rp .ServeHTTP (rw , req )
150
+ res := rw .Result ()
151
+
152
+ _ , _ = io .Copy (io .Discard , res .Body )
153
+ res .Body .Close ()
154
+ assert .Equal (t , http .StatusOK , res .StatusCode )
155
+ }
156
+
157
+ assert .Equal (t , 1 , wln .getDials ())
158
+ })
159
+
160
+ t .Run ("NotReusedBetweenAgents" , func (t * testing.T ) {
161
+ t .Parallel ()
162
+
163
+ ctx , cancel := context .WithTimeout (context .Background (), testutil .WaitLong )
164
+ defer cancel ()
165
+
166
+ agents , serverTailnet := setupServerTailnetAgent (t , 2 )
167
+ port := ":4444"
168
+
169
+ for i , ag := range agents {
170
+ i := i
171
+ ln , err := ag .TailnetConn ().Listen ("tcp" , port )
172
+ require .NoError (t , err )
173
+ wln := & wrappedListener {Listener : ln }
174
+
175
+ serverClosed := make (chan struct {})
176
+ go func () {
177
+ defer close (serverClosed )
178
+ //nolint:gosec
179
+ _ = http .Serve (wln , http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
180
+ w .WriteHeader (http .StatusOK )
181
+ w .Write ([]byte (strconv .Itoa (i )))
182
+ }))
183
+ }()
184
+ defer func () { //nolint:revive
185
+ // wait for server to close
186
+ <- serverClosed
187
+ }()
188
+
189
+ defer ln .Close () //nolint:revive
190
+ }
191
+
192
+ u , err := url .Parse ("http://127.0.0.1" + port )
193
+ require .NoError (t , err )
194
+
195
+ for i , ag := range agents {
196
+ rp := serverTailnet .ReverseProxy (u , u , ag .id )
197
+
198
+ rw := httptest .NewRecorder ()
199
+ req := httptest .NewRequest (
200
+ http .MethodGet ,
201
+ u .String (),
202
+ nil ,
203
+ ).WithContext (ctx )
204
+
205
+ rp .ServeHTTP (rw , req )
206
+ res := rw .Result ()
207
+
208
+ body , _ := io .ReadAll (res .Body )
209
+ res .Body .Close ()
210
+ assert .Equal (t , http .StatusOK , res .StatusCode )
211
+ assert .Equal (t , strconv .Itoa (i ), string (body ))
212
+ }
213
+ })
214
+
77
215
t .Run ("HTTPSProxy" , func (t * testing.T ) {
78
216
t .Parallel ()
79
217
80
218
ctx , cancel := context .WithTimeout (context .Background (), testutil .WaitLong )
81
219
defer cancel ()
82
220
83
- agentID , _ , serverTailnet := setupAgent (t , nil )
221
+ agents , serverTailnet := setupServerTailnetAgent (t , 1 )
222
+ a := agents [0 ]
84
223
85
224
const expectedResponseCode = 209
86
225
// Test that we can proxy HTTPS traffic.
@@ -92,7 +231,7 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
92
231
uri , err := url .Parse (s .URL )
93
232
require .NoError (t , err )
94
233
95
- rp := serverTailnet .ReverseProxy (uri , uri , agentID )
234
+ rp := serverTailnet .ReverseProxy (uri , uri , a . id )
96
235
97
236
rw := httptest .NewRecorder ()
98
237
req := httptest .NewRequest (
@@ -109,44 +248,74 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
109
248
})
110
249
}
111
250
112
- func setupAgent (t * testing.T , agentAddresses []netip.Prefix ) (uuid.UUID , agent.Agent , * coderd.ServerTailnet ) {
251
+ type wrappedListener struct {
252
+ net.Listener
253
+ dials int32
254
+ }
255
+
256
+ func (w * wrappedListener ) Accept () (net.Conn , error ) {
257
+ conn , err := w .Listener .Accept ()
258
+ if err != nil {
259
+ return nil , err
260
+ }
261
+
262
+ atomic .AddInt32 (& w .dials , 1 )
263
+ return conn , nil
264
+ }
265
+
266
+ func (w * wrappedListener ) getDials () int {
267
+ return int (atomic .LoadInt32 (& w .dials ))
268
+ }
269
+
270
+ type agentWithID struct {
271
+ id uuid.UUID
272
+ agent.Agent
273
+ }
274
+
275
+ func setupServerTailnetAgent (t * testing.T , agentNum int ) ([]agentWithID , * coderd.ServerTailnet ) {
113
276
logger := slogtest .Make (t , nil ).Leveled (slog .LevelDebug )
114
277
derpMap , derpServer := tailnettest .RunDERPAndSTUN (t )
115
- manifest := agentsdk.Manifest {
116
- AgentID : uuid .New (),
117
- DERPMap : derpMap ,
118
- }
119
278
120
279
coord := tailnet .NewCoordinator (logger )
121
280
t .Cleanup (func () {
122
281
_ = coord .Close ()
123
282
})
124
283
125
- c := agenttest .NewClient (t , logger , manifest .AgentID , manifest , make (chan * agentsdk.Stats , 50 ), coord )
126
- t .Cleanup (c .Close )
284
+ agents := []agentWithID {}
127
285
128
- options := agent.Options {
129
- Client : c ,
130
- Filesystem : afero .NewMemMapFs (),
131
- Logger : logger .Named ("agent" ),
132
- Addresses : agentAddresses ,
133
- }
286
+ for i := 0 ; i < agentNum ; i ++ {
287
+ manifest := agentsdk.Manifest {
288
+ AgentID : uuid .New (),
289
+ DERPMap : derpMap ,
290
+ }
134
291
135
- ag := agent .New (options )
136
- t .Cleanup (func () {
137
- _ = ag .Close ()
138
- })
292
+ c := agenttest .NewClient (t , logger , manifest .AgentID , manifest , make (chan * agentsdk.Stats , 50 ), coord )
293
+ t .Cleanup (c .Close )
294
+
295
+ options := agent.Options {
296
+ Client : c ,
297
+ Filesystem : afero .NewMemMapFs (),
298
+ Logger : logger .Named ("agent" ),
299
+ }
139
300
140
- // Wait for the agent to connect.
141
- require .Eventually (t , func () bool {
142
- return coord .Node (manifest .AgentID ) != nil
143
- }, testutil .WaitShort , testutil .IntervalFast )
301
+ ag := agent .New (options )
302
+ t .Cleanup (func () {
303
+ _ = ag .Close ()
304
+ })
305
+
306
+ // Wait for the agent to connect.
307
+ require .Eventually (t , func () bool {
308
+ return coord .Node (manifest .AgentID ) != nil
309
+ }, testutil .WaitShort , testutil .IntervalFast )
310
+
311
+ agents = append (agents , agentWithID {id : manifest .AgentID , Agent : ag })
312
+ }
144
313
145
314
serverTailnet , err := coderd .NewServerTailnet (
146
315
context .Background (),
147
316
logger ,
148
317
derpServer ,
149
- func () * tailcfg.DERPMap { return manifest . DERPMap },
318
+ func () * tailcfg.DERPMap { return derpMap },
150
319
false ,
151
320
func (context.Context ) (tailnet.MultiAgentConn , error ) { return coord .ServeMultiAgent (uuid .New ()), nil },
152
321
trace .NewNoopTracerProvider (),
@@ -157,5 +326,5 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A
157
326
_ = serverTailnet .Close ()
158
327
})
159
328
160
- return manifest . AgentID , ag , serverTailnet
329
+ return agents , serverTailnet
161
330
}
0 commit comments