@@ -2,17 +2,26 @@ package devtunnel_test
2
2
3
3
import (
4
4
"context"
5
+ "encoding/hex"
6
+ "encoding/json"
7
+ "fmt"
5
8
"io"
6
9
"net"
7
10
"net/http"
11
+ "net/http/httptest"
12
+ "net/netip"
13
+ "strings"
8
14
"testing"
9
15
"time"
10
16
11
- "github.com/stretchr/testify/assert"
12
- "github.com/stretchr/testify/require"
13
-
14
17
"cdr.dev/slog/sloggers/slogtest"
15
18
"github.com/coder/coder/coderd/devtunnel"
19
+ "github.com/stretchr/testify/assert"
20
+ "github.com/stretchr/testify/require"
21
+ "golang.zx2c4.com/wireguard/conn"
22
+ "golang.zx2c4.com/wireguard/device"
23
+ "golang.zx2c4.com/wireguard/tun/netstack"
24
+ "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
16
25
)
17
26
18
27
// The tunnel leaks a few goroutines that aren't impactful to production scenarios.
@@ -23,53 +32,42 @@ import (
23
32
func TestTunnel (t * testing.T ) {
24
33
t .Parallel ()
25
34
26
- // It's not super useful for us to test this constantly, it'll only cause
27
- // flakes is the tunnel becomes unavailable for some reason.
28
- t .Skip ()
29
- // if testing.Short() {
30
- // t.Skip()
31
- // return
32
- // }
33
-
34
35
ctx , cancelTun := context .WithCancel (context .Background ())
35
36
defer cancelTun ()
36
37
37
38
server := http.Server {
38
39
Handler : http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
39
- w .WriteHeader (http .StatusOK )
40
+ t .Log ("got request for" , r .URL )
41
+ w .WriteHeader (http .StatusAccepted )
40
42
}),
41
43
BaseContext : func (_ net.Listener ) context.Context {
42
44
return ctx
43
45
},
44
46
}
45
47
46
- cfg , err := devtunnel . GenerateConfig ( )
47
- require . NoError ( t , err )
48
+ fTunServer := newFakeTunnelServer ( t )
49
+ cfg := fTunServer . config ( )
48
50
49
51
tun , errCh , err := devtunnel .NewWithConfig (ctx , slogtest .Make (t , nil ), cfg )
50
52
require .NoError (t , err )
51
53
t .Log (tun .URL )
52
54
53
- go server . Serve ( tun . Listener )
54
- defer tun .Listener . Close ( )
55
-
56
- httpClient := & http. Client {
57
- Timeout : 10 * time . Second ,
58
- }
55
+ go func () {
56
+ err := server . Serve ( tun .Listener )
57
+ assert . Equal ( t , http . ErrServerClosed , err )
58
+ }()
59
+ t . Cleanup ( func () { _ = server . Close () })
60
+ t . Cleanup ( func () { tun . Listener . Close () })
59
61
60
62
require .Eventually (t , func () bool {
61
- req , err := http .NewRequestWithContext (ctx , "GET" , tun .URL , nil )
62
- require .NoError (t , err )
63
-
64
- res , err := httpClient .Do (req )
63
+ res , err := fTunServer .requestHttp ()
65
64
require .NoError (t , err )
66
65
defer res .Body .Close ()
67
66
_ , _ = io .Copy (io .Discard , res .Body )
68
67
69
68
return res .StatusCode == http .StatusOK
70
69
}, time .Minute , time .Second )
71
70
72
- httpClient .CloseIdleConnections ()
73
71
assert .NoError (t , server .Close ())
74
72
cancelTun ()
75
73
@@ -79,3 +77,135 @@ func TestTunnel(t *testing.T) {
79
77
t .Error ("tunnel did not close after 10 seconds" )
80
78
}
81
79
}
80
+
81
+ // fakeTunnelServer is a fake version of the real dev tunnel server. It fakes 2 client interactions
82
+ // that we want to test:
83
+ // 1. Responding to a POST /tun from the client
84
+ // 2. Sending an HTTP request down the wireguard connection
85
+ //
86
+ // Note that for 2, we don't implement a full proxy that accepts arbitrary requests, we just send
87
+ // a test request over the Wireguard tunnel to make sure that we can listen. The proxy behavior is
88
+ // outside of the scope of the dev tunnel client, which is what we are testing here.
89
+ type fakeTunnelServer struct {
90
+ t * testing.T
91
+ pub device.NoisePublicKey
92
+ priv device.NoisePrivateKey
93
+ tnet * netstack.Net
94
+ device * device.Device
95
+ clients int
96
+ server * httptest.Server
97
+ }
98
+
99
+ const (
100
+ ipByte1 = 0xfc
101
+ ipByte2 = 0xca
102
+ wgPort = 48732
103
+ )
104
+
105
+ var (
106
+ serverIp = netip .AddrFrom16 ([16 ]byte {ipByte1 , ipByte2 , 15 : 0x1 })
107
+ dnsIp = netip .AddrFrom4 ([4 ]byte {1 , 1 , 1 , 1 })
108
+ clientIp = netip .AddrFrom16 ([16 ]byte {ipByte1 , ipByte2 , 15 : 0x2 })
109
+ )
110
+
111
+ func newFakeTunnelServer (t * testing.T ) * fakeTunnelServer {
112
+ priv , err := wgtypes .GeneratePrivateKey ()
113
+ privBytes := [32 ]byte (priv )
114
+ require .NoError (t , err )
115
+ pub := priv .PublicKey ()
116
+ pubBytes := [32 ]byte (pub )
117
+ tun , tnet , err := netstack .CreateNetTUN (
118
+ []netip.Addr {serverIp },
119
+ []netip.Addr {dnsIp },
120
+ 1280 ,
121
+ )
122
+ require .NoError (t , err )
123
+ dev := device .NewDevice (tun , conn .NewDefaultBind (), device .NewLogger (device .LogLevelVerbose , "" ))
124
+ err = dev .IpcSet (fmt .Sprintf (`private_key=%s
125
+ listen_port=%d` ,
126
+ hex .EncodeToString (privBytes [:]),
127
+ wgPort ,
128
+ ))
129
+ require .NoError (t , err )
130
+
131
+ err = dev .Up ()
132
+ require .NoError (t , err )
133
+
134
+ server := newFakeTunnelHttpsServer (t , pubBytes )
135
+
136
+ return & fakeTunnelServer {
137
+ t : t ,
138
+ pub : device .NoisePublicKey (pub ),
139
+ priv : device .NoisePrivateKey (priv ),
140
+ tnet : tnet ,
141
+ device : dev ,
142
+ server : server ,
143
+ }
144
+ }
145
+
146
+ func newFakeTunnelHttpsServer (t * testing.T , pubBytes [32 ]byte ) * httptest.Server {
147
+ handler := http .NewServeMux ()
148
+ handler .HandleFunc ("/tun" , func (writer http.ResponseWriter , request * http.Request ) {
149
+ assert .Equal (t , "POST" , request .Method )
150
+
151
+ resp := devtunnel.ServerResponse {
152
+ Hostname : fmt .Sprintf ("[%s]" , serverIp .String ()),
153
+ ServerIP : serverIp ,
154
+ ServerPublicKey : hex .EncodeToString (pubBytes [:]),
155
+ ClientIP : clientIp ,
156
+ }
157
+ b , err := json .Marshal (& resp )
158
+ assert .NoError (t , err )
159
+ writer .WriteHeader (200 )
160
+ _ , err = writer .Write (b )
161
+ assert .NoError (t , err )
162
+ })
163
+
164
+ server := httptest .NewTLSServer (handler )
165
+ t .Cleanup (func () {
166
+ server .Close ()
167
+ })
168
+ return server
169
+ }
170
+
171
+ func (f * fakeTunnelServer ) config () devtunnel.Config {
172
+ priv , err := wgtypes .GeneratePrivateKey ()
173
+ require .NoError (f .t , err )
174
+ pub := priv .PublicKey ()
175
+ f .clients ++
176
+ assert .Equal (f .t , 1 , f .clients ) // only allow one client as we hardcode the address
177
+
178
+ err = f .device .IpcSet (fmt .Sprintf (`public_key=%x
179
+ allowed_ip=%s/128` ,
180
+ pub [:],
181
+ clientIp .String (),
182
+ ))
183
+ require .NoError (f .t , err )
184
+ return devtunnel.Config {
185
+ Version : 1 ,
186
+ PrivateKey : device .NoisePrivateKey (priv ),
187
+ PublicKey : device .NoisePublicKey (pub ),
188
+ Tunnel : devtunnel.Node {
189
+ HostnameHTTPS : strings .TrimPrefix (f .server .URL , "https://" ),
190
+ HostnameWireguard : "::1" ,
191
+ WireguardPort : wgPort ,
192
+ },
193
+ HTTPClient : f .server .Client (),
194
+ }
195
+ }
196
+
197
+ func (f * fakeTunnelServer ) requestHttp () (* http.Response , error ) {
198
+ transport := & http.Transport {
199
+ DialContext : func (ctx context.Context , network , addr string ) (net.Conn , error ) {
200
+ f .t .Log ("Dial" , network , addr )
201
+ nc , err := f .tnet .DialContextTCPAddrPort (ctx , netip .AddrPortFrom (clientIp , 8090 ))
202
+ assert .NoError (f .t , err )
203
+ return nc , err
204
+ },
205
+ }
206
+ client := & http.Client {
207
+ Transport : transport ,
208
+ Timeout : 10 * time .Second ,
209
+ }
210
+ return client .Get (fmt .Sprintf ("http://[%s]:8090" , clientIp ))
211
+ }
0 commit comments