Skip to content

Commit c8dde6e

Browse files
committed
Dev tunnel test uses local fake server; fixed port
Signed-off-by: Spike Curtis <spike@coder.com>
1 parent a66b852 commit c8dde6e

File tree

2 files changed

+168
-28
lines changed

2 files changed

+168
-28
lines changed

coderd/devtunnel/tunnel.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,19 @@ type Config struct {
3636
PublicKey device.NoisePublicKey `json:"public_key"`
3737

3838
Tunnel Node `json:"tunnel"`
39+
40+
// Used in testing. Normally this is nil, indicating to use DefaultClient.
41+
HTTPClient *http.Client `json:"-"`
3942
}
4043
type configExt struct {
4144
Version int `json:"-"`
4245
PrivateKey device.NoisePrivateKey `json:"-"`
4346
PublicKey device.NoisePublicKey `json:"public_key"`
4447

4548
Tunnel Node `json:"-"`
49+
50+
// Used in testing. Normally this is nil, indicating to use DefaultClient.
51+
HTTPClient *http.Client `json:"-"`
4652
}
4753

4854
// NewWithConfig calls New with the given config. For documentation, see New.
@@ -54,7 +60,7 @@ func NewWithConfig(ctx context.Context, logger slog.Logger, cfg Config) (*Tunnel
5460

5561
tun, tnet, err := netstack.CreateNetTUN(
5662
[]netip.Addr{server.ClientIP},
57-
[]netip.Addr{netip.AddrFrom4([4]byte{1, 1, 1, 1})},
63+
[]netip.Addr{},
5864
1280,
5965
)
6066
if err != nil {
@@ -69,7 +75,7 @@ func NewWithConfig(ctx context.Context, logger slog.Logger, cfg Config) (*Tunnel
6975
dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelSilent, ""))
7076
err = dev.IpcSet(fmt.Sprintf(`private_key=%s
7177
public_key=%s
72-
endpoint=%s:%d
78+
endpoint=[%s]:%d
7379
persistent_keepalive_interval=21
7480
allowed_ip=%s/128`,
7581
hex.EncodeToString(cfg.PrivateKey[:]),
@@ -174,7 +180,11 @@ func sendConfigToServer(ctx context.Context, cfg Config) (ServerResponse, error)
174180
return ServerResponse{}, xerrors.Errorf("new request: %w", err)
175181
}
176182

177-
res, err := http.DefaultClient.Do(req)
183+
client := http.DefaultClient
184+
if cfg.HTTPClient != nil {
185+
client = cfg.HTTPClient
186+
}
187+
res, err := client.Do(req)
178188
if err != nil {
179189
return ServerResponse{}, xerrors.Errorf("do request: %w", err)
180190
}

coderd/devtunnel/tunnel_test.go

Lines changed: 155 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,26 @@ package devtunnel_test
22

33
import (
44
"context"
5+
"encoding/hex"
6+
"encoding/json"
7+
"fmt"
58
"io"
69
"net"
710
"net/http"
11+
"net/http/httptest"
12+
"net/netip"
13+
"strings"
814
"testing"
915
"time"
1016

11-
"github.com/stretchr/testify/assert"
12-
"github.com/stretchr/testify/require"
13-
1417
"cdr.dev/slog/sloggers/slogtest"
1518
"github.com/coder/coder/coderd/devtunnel"
19+
"github.com/stretchr/testify/assert"
20+
"github.com/stretchr/testify/require"
21+
"golang.zx2c4.com/wireguard/conn"
22+
"golang.zx2c4.com/wireguard/device"
23+
"golang.zx2c4.com/wireguard/tun/netstack"
24+
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
1625
)
1726

1827
// The tunnel leaks a few goroutines that aren't impactful to production scenarios.
@@ -23,53 +32,42 @@ import (
2332
func TestTunnel(t *testing.T) {
2433
t.Parallel()
2534

26-
// It's not super useful for us to test this constantly, it'll only cause
27-
// flakes is the tunnel becomes unavailable for some reason.
28-
t.Skip()
29-
// if testing.Short() {
30-
// t.Skip()
31-
// return
32-
// }
33-
3435
ctx, cancelTun := context.WithCancel(context.Background())
3536
defer cancelTun()
3637

3738
server := http.Server{
3839
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
39-
w.WriteHeader(http.StatusOK)
40+
t.Log("got request for", r.URL)
41+
w.WriteHeader(http.StatusAccepted)
4042
}),
4143
BaseContext: func(_ net.Listener) context.Context {
4244
return ctx
4345
},
4446
}
4547

46-
cfg, err := devtunnel.GenerateConfig()
47-
require.NoError(t, err)
48+
fTunServer := newFakeTunnelServer(t)
49+
cfg := fTunServer.config()
4850

4951
tun, errCh, err := devtunnel.NewWithConfig(ctx, slogtest.Make(t, nil), cfg)
5052
require.NoError(t, err)
5153
t.Log(tun.URL)
5254

53-
go server.Serve(tun.Listener)
54-
defer tun.Listener.Close()
55-
56-
httpClient := &http.Client{
57-
Timeout: 10 * time.Second,
58-
}
55+
go func() {
56+
err := server.Serve(tun.Listener)
57+
assert.Equal(t, http.ErrServerClosed, err)
58+
}()
59+
t.Cleanup(func() { _ = server.Close() })
60+
t.Cleanup(func() { tun.Listener.Close() })
5961

6062
require.Eventually(t, func() bool {
61-
req, err := http.NewRequestWithContext(ctx, "GET", tun.URL, nil)
62-
require.NoError(t, err)
63-
64-
res, err := httpClient.Do(req)
63+
res, err := fTunServer.requestHttp()
6564
require.NoError(t, err)
6665
defer res.Body.Close()
6766
_, _ = io.Copy(io.Discard, res.Body)
6867

6968
return res.StatusCode == http.StatusOK
7069
}, time.Minute, time.Second)
7170

72-
httpClient.CloseIdleConnections()
7371
assert.NoError(t, server.Close())
7472
cancelTun()
7573

@@ -79,3 +77,135 @@ func TestTunnel(t *testing.T) {
7977
t.Error("tunnel did not close after 10 seconds")
8078
}
8179
}
80+
81+
// fakeTunnelServer is a fake version of the real dev tunnel server. It fakes 2 client interactions
82+
// that we want to test:
83+
// 1. Responding to a POST /tun from the client
84+
// 2. Sending an HTTP request down the wireguard connection
85+
//
86+
// Note that for 2, we don't implement a full proxy that accepts arbitrary requests, we just send
87+
// a test request over the Wireguard tunnel to make sure that we can listen. The proxy behavior is
88+
// outside of the scope of the dev tunnel client, which is what we are testing here.
89+
type fakeTunnelServer struct {
90+
t *testing.T
91+
pub device.NoisePublicKey
92+
priv device.NoisePrivateKey
93+
tnet *netstack.Net
94+
device *device.Device
95+
clients int
96+
server *httptest.Server
97+
}
98+
99+
const (
100+
ipByte1 = 0xfc
101+
ipByte2 = 0xca
102+
wgPort = 48732
103+
)
104+
105+
var (
106+
serverIp = netip.AddrFrom16([16]byte{ipByte1, ipByte2, 15: 0x1})
107+
dnsIp = netip.AddrFrom4([4]byte{1, 1, 1, 1})
108+
clientIp = netip.AddrFrom16([16]byte{ipByte1, ipByte2, 15: 0x2})
109+
)
110+
111+
func newFakeTunnelServer(t *testing.T) *fakeTunnelServer {
112+
priv, err := wgtypes.GeneratePrivateKey()
113+
privBytes := [32]byte(priv)
114+
require.NoError(t, err)
115+
pub := priv.PublicKey()
116+
pubBytes := [32]byte(pub)
117+
tun, tnet, err := netstack.CreateNetTUN(
118+
[]netip.Addr{serverIp},
119+
[]netip.Addr{dnsIp},
120+
1280,
121+
)
122+
require.NoError(t, err)
123+
dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
124+
err = dev.IpcSet(fmt.Sprintf(`private_key=%s
125+
listen_port=%d`,
126+
hex.EncodeToString(privBytes[:]),
127+
wgPort,
128+
))
129+
require.NoError(t, err)
130+
131+
err = dev.Up()
132+
require.NoError(t, err)
133+
134+
server := newFakeTunnelHttpsServer(t, pubBytes)
135+
136+
return &fakeTunnelServer{
137+
t: t,
138+
pub: device.NoisePublicKey(pub),
139+
priv: device.NoisePrivateKey(priv),
140+
tnet: tnet,
141+
device: dev,
142+
server: server,
143+
}
144+
}
145+
146+
func newFakeTunnelHttpsServer(t *testing.T, pubBytes [32]byte) *httptest.Server {
147+
handler := http.NewServeMux()
148+
handler.HandleFunc("/tun", func(writer http.ResponseWriter, request *http.Request) {
149+
assert.Equal(t, "POST", request.Method)
150+
151+
resp := devtunnel.ServerResponse{
152+
Hostname: fmt.Sprintf("[%s]", serverIp.String()),
153+
ServerIP: serverIp,
154+
ServerPublicKey: hex.EncodeToString(pubBytes[:]),
155+
ClientIP: clientIp,
156+
}
157+
b, err := json.Marshal(&resp)
158+
assert.NoError(t, err)
159+
writer.WriteHeader(200)
160+
_, err = writer.Write(b)
161+
assert.NoError(t, err)
162+
})
163+
164+
server := httptest.NewTLSServer(handler)
165+
t.Cleanup(func() {
166+
server.Close()
167+
})
168+
return server
169+
}
170+
171+
func (f *fakeTunnelServer) config() devtunnel.Config {
172+
priv, err := wgtypes.GeneratePrivateKey()
173+
require.NoError(f.t, err)
174+
pub := priv.PublicKey()
175+
f.clients++
176+
assert.Equal(f.t, 1, f.clients) // only allow one client as we hardcode the address
177+
178+
err = f.device.IpcSet(fmt.Sprintf(`public_key=%x
179+
allowed_ip=%s/128`,
180+
pub[:],
181+
clientIp.String(),
182+
))
183+
require.NoError(f.t, err)
184+
return devtunnel.Config{
185+
Version: 1,
186+
PrivateKey: device.NoisePrivateKey(priv),
187+
PublicKey: device.NoisePublicKey(pub),
188+
Tunnel: devtunnel.Node{
189+
HostnameHTTPS: strings.TrimPrefix(f.server.URL, "https://"),
190+
HostnameWireguard: "::1",
191+
WireguardPort: wgPort,
192+
},
193+
HTTPClient: f.server.Client(),
194+
}
195+
}
196+
197+
func (f *fakeTunnelServer) requestHttp() (*http.Response, error) {
198+
transport := &http.Transport{
199+
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
200+
f.t.Log("Dial", network, addr)
201+
nc, err := f.tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(clientIp, 8090))
202+
assert.NoError(f.t, err)
203+
return nc, err
204+
},
205+
}
206+
client := &http.Client{
207+
Transport: transport,
208+
Timeout: 10 * time.Second,
209+
}
210+
return client.Get(fmt.Sprintf("http://[%s]:8090", clientIp))
211+
}

0 commit comments

Comments
 (0)