@@ -14,30 +14,33 @@ import (
14
14
"testing"
15
15
"time"
16
16
17
- "cdr.dev/slog/sloggers/slogtest"
18
- "github.com/coder/coder/coderd/devtunnel"
19
17
"github.com/stretchr/testify/assert"
20
18
"github.com/stretchr/testify/require"
21
19
"golang.zx2c4.com/wireguard/conn"
22
20
"golang.zx2c4.com/wireguard/device"
23
21
"golang.zx2c4.com/wireguard/tun/netstack"
24
22
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
23
+
24
+ "cdr.dev/slog/sloggers/slogtest"
25
+ "github.com/coder/coder/coderd/devtunnel"
25
26
)
26
27
27
28
// The tunnel leaks a few goroutines that aren't impactful to production scenarios.
28
29
// func TestMain(m *testing.M) {
29
30
// goleak.VerifyTestMain(m)
30
31
// }
31
32
33
+ // TestTunnel cannot run in parallel because we hardcode the UDP port used by the wireguard server.
34
+ // nolint: tparallel
32
35
func TestTunnel (t * testing.T ) {
33
- t .Parallel ()
34
-
35
36
ctx , cancelTun := context .WithCancel (context .Background ())
36
37
defer cancelTun ()
37
38
38
39
server := http.Server {
39
40
Handler : http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
40
41
t .Log ("got request for" , r .URL )
42
+ // Going to use something _slightly_ exotic so that we can't accidentally get some
43
+ // default behavior creating a false positive on the test
41
44
w .WriteHeader (http .StatusAccepted )
42
45
}),
43
46
BaseContext : func (_ net.Listener ) context.Context {
@@ -60,12 +63,12 @@ func TestTunnel(t *testing.T) {
60
63
t .Cleanup (func () { tun .Listener .Close () })
61
64
62
65
require .Eventually (t , func () bool {
63
- res , err := fTunServer .requestHttp ()
66
+ res , err := fTunServer .requestHTTP ()
64
67
require .NoError (t , err )
65
68
defer res .Body .Close ()
66
69
_ , _ = io .Copy (io .Discard , res .Body )
67
70
68
- return res .StatusCode == http .StatusOK
71
+ return res .StatusCode == http .StatusAccepted
69
72
}, time .Minute , time .Second )
70
73
71
74
assert .NoError (t , server .Close ())
@@ -103,9 +106,9 @@ const (
103
106
)
104
107
105
108
var (
106
- serverIp = netip .AddrFrom16 ([16 ]byte {ipByte1 , ipByte2 , 15 : 0x1 })
107
- dnsIp = netip .AddrFrom4 ([4 ]byte {1 , 1 , 1 , 1 })
108
- clientIp = netip .AddrFrom16 ([16 ]byte {ipByte1 , ipByte2 , 15 : 0x2 })
109
+ serverIP = netip .AddrFrom16 ([16 ]byte {ipByte1 , ipByte2 , 15 : 0x1 })
110
+ dnsIP = netip .AddrFrom4 ([4 ]byte {1 , 1 , 1 , 1 })
111
+ clientIP = netip .AddrFrom16 ([16 ]byte {ipByte1 , ipByte2 , 15 : 0x2 })
109
112
)
110
113
111
114
func newFakeTunnelServer (t * testing.T ) * fakeTunnelServer {
@@ -115,8 +118,8 @@ func newFakeTunnelServer(t *testing.T) *fakeTunnelServer {
115
118
pub := priv .PublicKey ()
116
119
pubBytes := [32 ]byte (pub )
117
120
tun , tnet , err := netstack .CreateNetTUN (
118
- []netip.Addr {serverIp },
119
- []netip.Addr {dnsIp },
121
+ []netip.Addr {serverIP },
122
+ []netip.Addr {dnsIP },
120
123
1280 ,
121
124
)
122
125
require .NoError (t , err )
@@ -127,11 +130,14 @@ listen_port=%d`,
127
130
wgPort ,
128
131
))
129
132
require .NoError (t , err )
133
+ t .Cleanup (func () {
134
+ dev .Close ()
135
+ })
130
136
131
137
err = dev .Up ()
132
138
require .NoError (t , err )
133
139
134
- server := newFakeTunnelHttpsServer (t , pubBytes )
140
+ server := newFakeTunnelHTTPSServer (t , pubBytes )
135
141
136
142
return & fakeTunnelServer {
137
143
t : t ,
@@ -143,16 +149,16 @@ listen_port=%d`,
143
149
}
144
150
}
145
151
146
- func newFakeTunnelHttpsServer (t * testing.T , pubBytes [32 ]byte ) * httptest.Server {
152
+ func newFakeTunnelHTTPSServer (t * testing.T , pubBytes [32 ]byte ) * httptest.Server {
147
153
handler := http .NewServeMux ()
148
154
handler .HandleFunc ("/tun" , func (writer http.ResponseWriter , request * http.Request ) {
149
155
assert .Equal (t , "POST" , request .Method )
150
156
151
157
resp := devtunnel.ServerResponse {
152
- Hostname : fmt .Sprintf ("[%s]" , serverIp .String ()),
153
- ServerIP : serverIp ,
158
+ Hostname : fmt .Sprintf ("[%s]" , serverIP .String ()),
159
+ ServerIP : serverIP ,
154
160
ServerPublicKey : hex .EncodeToString (pubBytes [:]),
155
- ClientIP : clientIp ,
161
+ ClientIP : clientIP ,
156
162
}
157
163
b , err := json .Marshal (& resp )
158
164
assert .NoError (t , err )
@@ -178,7 +184,7 @@ func (f *fakeTunnelServer) config() devtunnel.Config {
178
184
err = f .device .IpcSet (fmt .Sprintf (`public_key=%x
179
185
allowed_ip=%s/128` ,
180
186
pub [:],
181
- clientIp .String (),
187
+ clientIP .String (),
182
188
))
183
189
require .NoError (f .t , err )
184
190
return devtunnel.Config {
@@ -194,11 +200,11 @@ allowed_ip=%s/128`,
194
200
}
195
201
}
196
202
197
- func (f * fakeTunnelServer ) requestHttp () (* http.Response , error ) {
203
+ func (f * fakeTunnelServer ) requestHTTP () (* http.Response , error ) {
198
204
transport := & http.Transport {
199
205
DialContext : func (ctx context.Context , network , addr string ) (net.Conn , error ) {
200
206
f .t .Log ("Dial" , network , addr )
201
- nc , err := f .tnet .DialContextTCPAddrPort (ctx , netip .AddrPortFrom (clientIp , 8090 ))
207
+ nc , err := f .tnet .DialContextTCPAddrPort (ctx , netip .AddrPortFrom (clientIP , 8090 ))
202
208
assert .NoError (f .t , err )
203
209
return nc , err
204
210
},
@@ -207,5 +213,5 @@ func (f *fakeTunnelServer) requestHttp() (*http.Response, error) {
207
213
Transport : transport ,
208
214
Timeout : 10 * time .Second ,
209
215
}
210
- return client .Get (fmt .Sprintf ("http://[%s]:8090" , clientIp ))
216
+ return client .Get (fmt .Sprintf ("http://[%s]:8090" , clientIP ))
211
217
}
0 commit comments