Skip to content

restore devtunnel test #3050

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions coderd/devtunnel/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,19 @@ type Config struct {
PublicKey device.NoisePublicKey `json:"public_key"`

Tunnel Node `json:"tunnel"`

// Used in testing. Normally this is nil, indicating to use DefaultClient.
HTTPClient *http.Client `json:"-"`
}
type configExt struct {
Version int `json:"-"`
PrivateKey device.NoisePrivateKey `json:"-"`
PublicKey device.NoisePublicKey `json:"public_key"`

Tunnel Node `json:"-"`

// Used in testing. Normally this is nil, indicating to use DefaultClient.
HTTPClient *http.Client `json:"-"`
}

// NewWithConfig calls New with the given config. For documentation, see New.
Expand All @@ -65,17 +71,23 @@ func NewWithConfig(ctx context.Context, logger slog.Logger, cfg Config) (*Tunnel
if err != nil {
return nil, nil, xerrors.Errorf("resolve endpoint: %w", err)
}
// In IPv6, we need to enclose the address to in [] before passing to wireguard's endpoint key, like
// [2001:abcd::1]:8888. We'll use netip.AddrPort to correctly handle this.
wgAddr, err := netip.ParseAddr(wgip.String())
if err != nil {
return nil, nil, xerrors.Errorf("parse address: %w", err)
}
wgEndpoint := netip.AddrPortFrom(wgAddr, cfg.Tunnel.WireguardPort)

dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelSilent, ""))
dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelError, "devtunnel "))
err = dev.IpcSet(fmt.Sprintf(`private_key=%s
public_key=%s
endpoint=%s:%d
endpoint=%s
persistent_keepalive_interval=21
allowed_ip=%s/128`,
hex.EncodeToString(cfg.PrivateKey[:]),
server.ServerPublicKey,
wgip.IP.String(),
cfg.Tunnel.WireguardPort,
wgEndpoint.String(),
server.ServerIP.String(),
))
if err != nil {
Expand All @@ -97,6 +109,9 @@ allowed_ip=%s/128`,
select {
case <-ctx.Done():
_ = wgListen.Close()
// We need to remove peers before closing to avoid a race condition between dev.Close() and the peer
// goroutines which results in segfault.
dev.RemoveAllPeers()
dev.Close()
<-routineEnd
close(ch)
Expand Down Expand Up @@ -174,7 +189,11 @@ func sendConfigToServer(ctx context.Context, cfg Config) (ServerResponse, error)
return ServerResponse{}, xerrors.Errorf("new request: %w", err)
}

res, err := http.DefaultClient.Do(req)
client := http.DefaultClient
if cfg.HTTPClient != nil {
client = cfg.HTTPClient
}
res, err := client.Do(req)
if err != nil {
return ServerResponse{}, xerrors.Errorf("do request: %w", err)
}
Expand Down
191 changes: 165 additions & 26 deletions coderd/devtunnel/tunnel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,74 +2,89 @@ package devtunnel_test

import (
"context"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/netip"
"strings"
"testing"
"time"

"cdr.dev/slog"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"

"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/coderd/devtunnel"
)

const (
ipByte1 = 0xfc
ipByte2 = 0xca
wgPort = 48732
)

var (
serverIP = netip.AddrFrom16([16]byte{ipByte1, ipByte2, 15: 0x1})
dnsIP = netip.AddrFrom4([4]byte{1, 1, 1, 1})
clientIP = netip.AddrFrom16([16]byte{ipByte1, ipByte2, 15: 0x2})
)

// The tunnel leaks a few goroutines that aren't impactful to production scenarios.
// func TestMain(m *testing.M) {
// goleak.VerifyTestMain(m)
// }

// TestTunnel cannot run in parallel because we hardcode the UDP port used by the wireguard server.
// nolint: paralleltest
func TestTunnel(t *testing.T) {
t.Parallel()

// It's not super useful for us to test this constantly, it'll only cause
// flakes is the tunnel becomes unavailable for some reason.
t.Skip()
// if testing.Short() {
// t.Skip()
// return
// }

ctx, cancelTun := context.WithCancel(context.Background())
defer cancelTun()

server := http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
t.Log("got request for", r.URL)
// Going to use something _slightly_ exotic so that we can't accidentally get some
// default behavior creating a false positive on the test
w.WriteHeader(http.StatusAccepted)
}),
BaseContext: func(_ net.Listener) context.Context {
return ctx
},
}

cfg, err := devtunnel.GenerateConfig()
require.NoError(t, err)
fTunServer := newFakeTunnelServer(t)
cfg := fTunServer.config()

tun, errCh, err := devtunnel.NewWithConfig(ctx, slogtest.Make(t, nil), cfg)
tun, errCh, err := devtunnel.NewWithConfig(ctx, slogtest.Make(t, nil).Leveled(slog.LevelDebug), cfg)
require.NoError(t, err)
t.Log(tun.URL)

go server.Serve(tun.Listener)
defer tun.Listener.Close()

httpClient := &http.Client{
Timeout: 10 * time.Second,
}
go func() {
err := server.Serve(tun.Listener)
assert.Equal(t, http.ErrServerClosed, err)
}()
t.Cleanup(func() { _ = server.Close() })
t.Cleanup(func() { tun.Listener.Close() })

require.Eventually(t, func() bool {
req, err := http.NewRequestWithContext(ctx, "GET", tun.URL, nil)
require.NoError(t, err)

res, err := httpClient.Do(req)
res, err := fTunServer.requestHTTP()
require.NoError(t, err)
defer res.Body.Close()
_, _ = io.Copy(io.Discard, res.Body)

return res.StatusCode == http.StatusOK
return res.StatusCode == http.StatusAccepted
}, time.Minute, time.Second)

httpClient.CloseIdleConnections()
assert.NoError(t, server.Close())
cancelTun()

Expand All @@ -79,3 +94,127 @@ func TestTunnel(t *testing.T) {
t.Error("tunnel did not close after 10 seconds")
}
}

// fakeTunnelServer is a fake version of the real dev tunnel server. It fakes 2 client interactions
// that we want to test:
// 1. Responding to a POST /tun from the client
// 2. Sending an HTTP request down the wireguard connection
//
// Note that for 2, we don't implement a full proxy that accepts arbitrary requests, we just send
// a test request over the Wireguard tunnel to make sure that we can listen. The proxy behavior is
// outside of the scope of the dev tunnel client, which is what we are testing here.
type fakeTunnelServer struct {
t *testing.T
pub device.NoisePublicKey
priv device.NoisePrivateKey
tnet *netstack.Net
device *device.Device
clients int
server *httptest.Server
}

func newFakeTunnelServer(t *testing.T) *fakeTunnelServer {
priv, err := wgtypes.GeneratePrivateKey()
privBytes := [32]byte(priv)
require.NoError(t, err)
pub := priv.PublicKey()
pubBytes := [32]byte(pub)
tun, tnet, err := netstack.CreateNetTUN(
[]netip.Addr{serverIP},
[]netip.Addr{dnsIP},
1280,
)
require.NoError(t, err)
dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "server "))
err = dev.IpcSet(fmt.Sprintf(`private_key=%s
listen_port=%d`,
hex.EncodeToString(privBytes[:]),
wgPort,
))
require.NoError(t, err)
t.Cleanup(func() {
dev.RemoveAllPeers()
dev.Close()
})

err = dev.Up()
require.NoError(t, err)

server := newFakeTunnelHTTPSServer(t, pubBytes)

return &fakeTunnelServer{
t: t,
pub: device.NoisePublicKey(pub),
priv: device.NoisePrivateKey(priv),
tnet: tnet,
device: dev,
server: server,
}
}

func newFakeTunnelHTTPSServer(t *testing.T, pubBytes [32]byte) *httptest.Server {
handler := http.NewServeMux()
handler.HandleFunc("/tun", func(writer http.ResponseWriter, request *http.Request) {
assert.Equal(t, "POST", request.Method)

resp := devtunnel.ServerResponse{
Hostname: fmt.Sprintf("[%s]", serverIP.String()),
ServerIP: serverIP,
ServerPublicKey: hex.EncodeToString(pubBytes[:]),
ClientIP: clientIP,
}
b, err := json.Marshal(&resp)
assert.NoError(t, err)
writer.WriteHeader(200)
_, err = writer.Write(b)
assert.NoError(t, err)
})

server := httptest.NewTLSServer(handler)
t.Cleanup(func() {
server.Close()
})
return server
}

func (f *fakeTunnelServer) config() devtunnel.Config {
priv, err := wgtypes.GeneratePrivateKey()
require.NoError(f.t, err)
pub := priv.PublicKey()
f.clients++
assert.Equal(f.t, 1, f.clients) // only allow one client as we hardcode the address

err = f.device.IpcSet(fmt.Sprintf(`public_key=%x
allowed_ip=%s/128`,
pub[:],
clientIP.String(),
))
require.NoError(f.t, err)
return devtunnel.Config{
Version: 1,
PrivateKey: device.NoisePrivateKey(priv),
PublicKey: device.NoisePublicKey(pub),
Tunnel: devtunnel.Node{
HostnameHTTPS: strings.TrimPrefix(f.server.URL, "https://"),
HostnameWireguard: "localhost",
WireguardPort: wgPort,
},
HTTPClient: f.server.Client(),
}
}

func (f *fakeTunnelServer) requestHTTP() (*http.Response, error) {
transport := &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
f.t.Log("Dial", network, addr)
nc, err := f.tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(clientIP, 8090))
assert.NoError(f.t, err)
return nc, err
},
}
client := &http.Client{
Transport: transport,
Timeout: 10 * time.Second,
}
return client.Get(fmt.Sprintf("http://[%s]:8090", clientIP))
}