Skip to content

Commit fa4361d

Browse files
authored
restore devtunnel test (coder#3050)
* Dev tunnel test uses local fake server; fixed port Signed-off-by: Spike Curtis <spike@coder.com> * Remove parallel for test Signed-off-by: Spike Curtis <spike@coder.com> * Fix segfault
1 parent 882ee55 commit fa4361d

File tree

2 files changed

+189
-31
lines changed

2 files changed

+189
-31
lines changed

coderd/devtunnel/tunnel.go

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,19 @@ type Config struct {
3636
PublicKey device.NoisePublicKey `json:"public_key"`
3737

3838
Tunnel Node `json:"tunnel"`
39+
40+
// Used in testing. Normally this is nil, indicating to use DefaultClient.
41+
HTTPClient *http.Client `json:"-"`
3942
}
4043
type configExt struct {
4144
Version int `json:"-"`
4245
PrivateKey device.NoisePrivateKey `json:"-"`
4346
PublicKey device.NoisePublicKey `json:"public_key"`
4447

4548
Tunnel Node `json:"-"`
49+
50+
// Used in testing. Normally this is nil, indicating to use DefaultClient.
51+
HTTPClient *http.Client `json:"-"`
4652
}
4753

4854
// NewWithConfig calls New with the given config. For documentation, see New.
@@ -65,17 +71,23 @@ func NewWithConfig(ctx context.Context, logger slog.Logger, cfg Config) (*Tunnel
6571
if err != nil {
6672
return nil, nil, xerrors.Errorf("resolve endpoint: %w", err)
6773
}
74+
// In IPv6, we need to enclose the address to in [] before passing to wireguard's endpoint key, like
75+
// [2001:abcd::1]:8888. We'll use netip.AddrPort to correctly handle this.
76+
wgAddr, err := netip.ParseAddr(wgip.String())
77+
if err != nil {
78+
return nil, nil, xerrors.Errorf("parse address: %w", err)
79+
}
80+
wgEndpoint := netip.AddrPortFrom(wgAddr, cfg.Tunnel.WireguardPort)
6881

69-
dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelSilent, ""))
82+
dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelError, "devtunnel "))
7083
err = dev.IpcSet(fmt.Sprintf(`private_key=%s
7184
public_key=%s
72-
endpoint=%s:%d
85+
endpoint=%s
7386
persistent_keepalive_interval=21
7487
allowed_ip=%s/128`,
7588
hex.EncodeToString(cfg.PrivateKey[:]),
7689
server.ServerPublicKey,
77-
wgip.IP.String(),
78-
cfg.Tunnel.WireguardPort,
90+
wgEndpoint.String(),
7991
server.ServerIP.String(),
8092
))
8193
if err != nil {
@@ -97,6 +109,9 @@ allowed_ip=%s/128`,
97109
select {
98110
case <-ctx.Done():
99111
_ = wgListen.Close()
112+
// We need to remove peers before closing to avoid a race condition between dev.Close() and the peer
113+
// goroutines which results in segfault.
114+
dev.RemoveAllPeers()
100115
dev.Close()
101116
<-routineEnd
102117
close(ch)
@@ -174,7 +189,11 @@ func sendConfigToServer(ctx context.Context, cfg Config) (ServerResponse, error)
174189
return ServerResponse{}, xerrors.Errorf("new request: %w", err)
175190
}
176191

177-
res, err := http.DefaultClient.Do(req)
192+
client := http.DefaultClient
193+
if cfg.HTTPClient != nil {
194+
client = cfg.HTTPClient
195+
}
196+
res, err := client.Do(req)
178197
if err != nil {
179198
return ServerResponse{}, xerrors.Errorf("do request: %w", err)
180199
}

coderd/devtunnel/tunnel_test.go

Lines changed: 165 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,74 +2,89 @@ package devtunnel_test
22

33
import (
44
"context"
5+
"encoding/hex"
6+
"encoding/json"
7+
"fmt"
58
"io"
69
"net"
710
"net/http"
11+
"net/http/httptest"
12+
"net/netip"
13+
"strings"
814
"testing"
915
"time"
1016

17+
"cdr.dev/slog"
18+
1119
"github.com/stretchr/testify/assert"
1220
"github.com/stretchr/testify/require"
21+
"golang.zx2c4.com/wireguard/conn"
22+
"golang.zx2c4.com/wireguard/device"
23+
"golang.zx2c4.com/wireguard/tun/netstack"
24+
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
1325

1426
"cdr.dev/slog/sloggers/slogtest"
1527
"github.com/coder/coder/coderd/devtunnel"
1628
)
1729

30+
const (
31+
ipByte1 = 0xfc
32+
ipByte2 = 0xca
33+
wgPort = 48732
34+
)
35+
36+
var (
37+
serverIP = netip.AddrFrom16([16]byte{ipByte1, ipByte2, 15: 0x1})
38+
dnsIP = netip.AddrFrom4([4]byte{1, 1, 1, 1})
39+
clientIP = netip.AddrFrom16([16]byte{ipByte1, ipByte2, 15: 0x2})
40+
)
41+
1842
// The tunnel leaks a few goroutines that aren't impactful to production scenarios.
1943
// func TestMain(m *testing.M) {
2044
// goleak.VerifyTestMain(m)
2145
// }
2246

47+
// TestTunnel cannot run in parallel because we hardcode the UDP port used by the wireguard server.
48+
// nolint: paralleltest
2349
func TestTunnel(t *testing.T) {
24-
t.Parallel()
25-
26-
// It's not super useful for us to test this constantly, it'll only cause
27-
// flakes is the tunnel becomes unavailable for some reason.
28-
t.Skip()
29-
// if testing.Short() {
30-
// t.Skip()
31-
// return
32-
// }
33-
3450
ctx, cancelTun := context.WithCancel(context.Background())
3551
defer cancelTun()
3652

3753
server := http.Server{
3854
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
39-
w.WriteHeader(http.StatusOK)
55+
t.Log("got request for", r.URL)
56+
// Going to use something _slightly_ exotic so that we can't accidentally get some
57+
// default behavior creating a false positive on the test
58+
w.WriteHeader(http.StatusAccepted)
4059
}),
4160
BaseContext: func(_ net.Listener) context.Context {
4261
return ctx
4362
},
4463
}
4564

46-
cfg, err := devtunnel.GenerateConfig()
47-
require.NoError(t, err)
65+
fTunServer := newFakeTunnelServer(t)
66+
cfg := fTunServer.config()
4867

49-
tun, errCh, err := devtunnel.NewWithConfig(ctx, slogtest.Make(t, nil), cfg)
68+
tun, errCh, err := devtunnel.NewWithConfig(ctx, slogtest.Make(t, nil).Leveled(slog.LevelDebug), cfg)
5069
require.NoError(t, err)
5170
t.Log(tun.URL)
5271

53-
go server.Serve(tun.Listener)
54-
defer tun.Listener.Close()
55-
56-
httpClient := &http.Client{
57-
Timeout: 10 * time.Second,
58-
}
72+
go func() {
73+
err := server.Serve(tun.Listener)
74+
assert.Equal(t, http.ErrServerClosed, err)
75+
}()
76+
t.Cleanup(func() { _ = server.Close() })
77+
t.Cleanup(func() { tun.Listener.Close() })
5978

6079
require.Eventually(t, func() bool {
61-
req, err := http.NewRequestWithContext(ctx, "GET", tun.URL, nil)
62-
require.NoError(t, err)
63-
64-
res, err := httpClient.Do(req)
80+
res, err := fTunServer.requestHTTP()
6581
require.NoError(t, err)
6682
defer res.Body.Close()
6783
_, _ = io.Copy(io.Discard, res.Body)
6884

69-
return res.StatusCode == http.StatusOK
85+
return res.StatusCode == http.StatusAccepted
7086
}, time.Minute, time.Second)
7187

72-
httpClient.CloseIdleConnections()
7388
assert.NoError(t, server.Close())
7489
cancelTun()
7590

@@ -79,3 +94,127 @@ func TestTunnel(t *testing.T) {
7994
t.Error("tunnel did not close after 10 seconds")
8095
}
8196
}
97+
98+
// fakeTunnelServer is a fake version of the real dev tunnel server. It fakes 2 client interactions
99+
// that we want to test:
100+
// 1. Responding to a POST /tun from the client
101+
// 2. Sending an HTTP request down the wireguard connection
102+
//
103+
// Note that for 2, we don't implement a full proxy that accepts arbitrary requests, we just send
104+
// a test request over the Wireguard tunnel to make sure that we can listen. The proxy behavior is
105+
// outside of the scope of the dev tunnel client, which is what we are testing here.
106+
type fakeTunnelServer struct {
107+
t *testing.T
108+
pub device.NoisePublicKey
109+
priv device.NoisePrivateKey
110+
tnet *netstack.Net
111+
device *device.Device
112+
clients int
113+
server *httptest.Server
114+
}
115+
116+
func newFakeTunnelServer(t *testing.T) *fakeTunnelServer {
117+
priv, err := wgtypes.GeneratePrivateKey()
118+
privBytes := [32]byte(priv)
119+
require.NoError(t, err)
120+
pub := priv.PublicKey()
121+
pubBytes := [32]byte(pub)
122+
tun, tnet, err := netstack.CreateNetTUN(
123+
[]netip.Addr{serverIP},
124+
[]netip.Addr{dnsIP},
125+
1280,
126+
)
127+
require.NoError(t, err)
128+
dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "server "))
129+
err = dev.IpcSet(fmt.Sprintf(`private_key=%s
130+
listen_port=%d`,
131+
hex.EncodeToString(privBytes[:]),
132+
wgPort,
133+
))
134+
require.NoError(t, err)
135+
t.Cleanup(func() {
136+
dev.RemoveAllPeers()
137+
dev.Close()
138+
})
139+
140+
err = dev.Up()
141+
require.NoError(t, err)
142+
143+
server := newFakeTunnelHTTPSServer(t, pubBytes)
144+
145+
return &fakeTunnelServer{
146+
t: t,
147+
pub: device.NoisePublicKey(pub),
148+
priv: device.NoisePrivateKey(priv),
149+
tnet: tnet,
150+
device: dev,
151+
server: server,
152+
}
153+
}
154+
155+
func newFakeTunnelHTTPSServer(t *testing.T, pubBytes [32]byte) *httptest.Server {
156+
handler := http.NewServeMux()
157+
handler.HandleFunc("/tun", func(writer http.ResponseWriter, request *http.Request) {
158+
assert.Equal(t, "POST", request.Method)
159+
160+
resp := devtunnel.ServerResponse{
161+
Hostname: fmt.Sprintf("[%s]", serverIP.String()),
162+
ServerIP: serverIP,
163+
ServerPublicKey: hex.EncodeToString(pubBytes[:]),
164+
ClientIP: clientIP,
165+
}
166+
b, err := json.Marshal(&resp)
167+
assert.NoError(t, err)
168+
writer.WriteHeader(200)
169+
_, err = writer.Write(b)
170+
assert.NoError(t, err)
171+
})
172+
173+
server := httptest.NewTLSServer(handler)
174+
t.Cleanup(func() {
175+
server.Close()
176+
})
177+
return server
178+
}
179+
180+
func (f *fakeTunnelServer) config() devtunnel.Config {
181+
priv, err := wgtypes.GeneratePrivateKey()
182+
require.NoError(f.t, err)
183+
pub := priv.PublicKey()
184+
f.clients++
185+
assert.Equal(f.t, 1, f.clients) // only allow one client as we hardcode the address
186+
187+
err = f.device.IpcSet(fmt.Sprintf(`public_key=%x
188+
allowed_ip=%s/128`,
189+
pub[:],
190+
clientIP.String(),
191+
))
192+
require.NoError(f.t, err)
193+
return devtunnel.Config{
194+
Version: 1,
195+
PrivateKey: device.NoisePrivateKey(priv),
196+
PublicKey: device.NoisePublicKey(pub),
197+
Tunnel: devtunnel.Node{
198+
HostnameHTTPS: strings.TrimPrefix(f.server.URL, "https://"),
199+
HostnameWireguard: "localhost",
200+
WireguardPort: wgPort,
201+
},
202+
HTTPClient: f.server.Client(),
203+
}
204+
}
205+
206+
func (f *fakeTunnelServer) requestHTTP() (*http.Response, error) {
207+
transport := &http.Transport{
208+
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
209+
f.t.Log("Dial", network, addr)
210+
nc, err := f.tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(clientIP, 8090))
211+
assert.NoError(f.t, err)
212+
return nc, err
213+
},
214+
}
215+
client := &http.Client{
216+
Transport: transport,
217+
Timeout: 10 * time.Second,
218+
}
219+
return client.Get(fmt.Sprintf("http://[%s]:8090", clientIP))
220+
}

0 commit comments

Comments
 (0)