diff --git a/coderd/devtunnel/tunnel.go b/coderd/devtunnel/tunnel.go index bdcd66f266625..b177e6fb88141 100644 --- a/coderd/devtunnel/tunnel.go +++ b/coderd/devtunnel/tunnel.go @@ -36,6 +36,9 @@ type Config struct { PublicKey device.NoisePublicKey `json:"public_key"` Tunnel Node `json:"tunnel"` + + // Used in testing. Normally this is nil, indicating to use DefaultClient. + HTTPClient *http.Client `json:"-"` } type configExt struct { Version int `json:"-"` @@ -43,6 +46,9 @@ type configExt struct { PublicKey device.NoisePublicKey `json:"public_key"` Tunnel Node `json:"-"` + + // Used in testing. Normally this is nil, indicating to use DefaultClient. + HTTPClient *http.Client `json:"-"` } // NewWithConfig calls New with the given config. For documentation, see New. @@ -65,17 +71,23 @@ func NewWithConfig(ctx context.Context, logger slog.Logger, cfg Config) (*Tunnel if err != nil { return nil, nil, xerrors.Errorf("resolve endpoint: %w", err) } + // In IPv6, we need to enclose the address to in [] before passing to wireguard's endpoint key, like + // [2001:abcd::1]:8888. We'll use netip.AddrPort to correctly handle this. + wgAddr, err := netip.ParseAddr(wgip.String()) + if err != nil { + return nil, nil, xerrors.Errorf("parse address: %w", err) + } + wgEndpoint := netip.AddrPortFrom(wgAddr, cfg.Tunnel.WireguardPort) - dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelSilent, "")) + dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelError, "devtunnel ")) err = dev.IpcSet(fmt.Sprintf(`private_key=%s public_key=%s -endpoint=%s:%d +endpoint=%s persistent_keepalive_interval=21 allowed_ip=%s/128`, hex.EncodeToString(cfg.PrivateKey[:]), server.ServerPublicKey, - wgip.IP.String(), - cfg.Tunnel.WireguardPort, + wgEndpoint.String(), server.ServerIP.String(), )) if err != nil { @@ -97,6 +109,9 @@ allowed_ip=%s/128`, select { case <-ctx.Done(): _ = wgListen.Close() + // We need to remove peers before closing to avoid a race condition between dev.Close() and the peer + // goroutines which results in segfault. + dev.RemoveAllPeers() dev.Close() <-routineEnd close(ch) @@ -174,7 +189,11 @@ func sendConfigToServer(ctx context.Context, cfg Config) (ServerResponse, error) return ServerResponse{}, xerrors.Errorf("new request: %w", err) } - res, err := http.DefaultClient.Do(req) + client := http.DefaultClient + if cfg.HTTPClient != nil { + client = cfg.HTTPClient + } + res, err := client.Do(req) if err != nil { return ServerResponse{}, xerrors.Errorf("do request: %w", err) } diff --git a/coderd/devtunnel/tunnel_test.go b/coderd/devtunnel/tunnel_test.go index 8ecde68f4fe60..a012922f801fd 100644 --- a/coderd/devtunnel/tunnel_test.go +++ b/coderd/devtunnel/tunnel_test.go @@ -2,74 +2,89 @@ package devtunnel_test import ( "context" + "encoding/hex" + "encoding/json" + "fmt" "io" "net" "net/http" + "net/http/httptest" + "net/netip" + "strings" "testing" "time" + "cdr.dev/slog" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun/netstack" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/coderd/devtunnel" ) +const ( + ipByte1 = 0xfc + ipByte2 = 0xca + wgPort = 48732 +) + +var ( + serverIP = netip.AddrFrom16([16]byte{ipByte1, ipByte2, 15: 0x1}) + dnsIP = netip.AddrFrom4([4]byte{1, 1, 1, 1}) + clientIP = netip.AddrFrom16([16]byte{ipByte1, ipByte2, 15: 0x2}) +) + // The tunnel leaks a few goroutines that aren't impactful to production scenarios. // func TestMain(m *testing.M) { // goleak.VerifyTestMain(m) // } +// TestTunnel cannot run in parallel because we hardcode the UDP port used by the wireguard server. +// nolint: paralleltest func TestTunnel(t *testing.T) { - t.Parallel() - - // It's not super useful for us to test this constantly, it'll only cause - // flakes is the tunnel becomes unavailable for some reason. - t.Skip() - // if testing.Short() { - // t.Skip() - // return - // } - ctx, cancelTun := context.WithCancel(context.Background()) defer cancelTun() server := http.Server{ Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) + t.Log("got request for", r.URL) + // Going to use something _slightly_ exotic so that we can't accidentally get some + // default behavior creating a false positive on the test + w.WriteHeader(http.StatusAccepted) }), BaseContext: func(_ net.Listener) context.Context { return ctx }, } - cfg, err := devtunnel.GenerateConfig() - require.NoError(t, err) + fTunServer := newFakeTunnelServer(t) + cfg := fTunServer.config() - tun, errCh, err := devtunnel.NewWithConfig(ctx, slogtest.Make(t, nil), cfg) + tun, errCh, err := devtunnel.NewWithConfig(ctx, slogtest.Make(t, nil).Leveled(slog.LevelDebug), cfg) require.NoError(t, err) t.Log(tun.URL) - go server.Serve(tun.Listener) - defer tun.Listener.Close() - - httpClient := &http.Client{ - Timeout: 10 * time.Second, - } + go func() { + err := server.Serve(tun.Listener) + assert.Equal(t, http.ErrServerClosed, err) + }() + t.Cleanup(func() { _ = server.Close() }) + t.Cleanup(func() { tun.Listener.Close() }) require.Eventually(t, func() bool { - req, err := http.NewRequestWithContext(ctx, "GET", tun.URL, nil) - require.NoError(t, err) - - res, err := httpClient.Do(req) + res, err := fTunServer.requestHTTP() require.NoError(t, err) defer res.Body.Close() _, _ = io.Copy(io.Discard, res.Body) - return res.StatusCode == http.StatusOK + return res.StatusCode == http.StatusAccepted }, time.Minute, time.Second) - httpClient.CloseIdleConnections() assert.NoError(t, server.Close()) cancelTun() @@ -79,3 +94,127 @@ func TestTunnel(t *testing.T) { t.Error("tunnel did not close after 10 seconds") } } + +// fakeTunnelServer is a fake version of the real dev tunnel server. It fakes 2 client interactions +// that we want to test: +// 1. Responding to a POST /tun from the client +// 2. Sending an HTTP request down the wireguard connection +// +// Note that for 2, we don't implement a full proxy that accepts arbitrary requests, we just send +// a test request over the Wireguard tunnel to make sure that we can listen. The proxy behavior is +// outside of the scope of the dev tunnel client, which is what we are testing here. +type fakeTunnelServer struct { + t *testing.T + pub device.NoisePublicKey + priv device.NoisePrivateKey + tnet *netstack.Net + device *device.Device + clients int + server *httptest.Server +} + +func newFakeTunnelServer(t *testing.T) *fakeTunnelServer { + priv, err := wgtypes.GeneratePrivateKey() + privBytes := [32]byte(priv) + require.NoError(t, err) + pub := priv.PublicKey() + pubBytes := [32]byte(pub) + tun, tnet, err := netstack.CreateNetTUN( + []netip.Addr{serverIP}, + []netip.Addr{dnsIP}, + 1280, + ) + require.NoError(t, err) + dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "server ")) + err = dev.IpcSet(fmt.Sprintf(`private_key=%s +listen_port=%d`, + hex.EncodeToString(privBytes[:]), + wgPort, + )) + require.NoError(t, err) + t.Cleanup(func() { + dev.RemoveAllPeers() + dev.Close() + }) + + err = dev.Up() + require.NoError(t, err) + + server := newFakeTunnelHTTPSServer(t, pubBytes) + + return &fakeTunnelServer{ + t: t, + pub: device.NoisePublicKey(pub), + priv: device.NoisePrivateKey(priv), + tnet: tnet, + device: dev, + server: server, + } +} + +func newFakeTunnelHTTPSServer(t *testing.T, pubBytes [32]byte) *httptest.Server { + handler := http.NewServeMux() + handler.HandleFunc("/tun", func(writer http.ResponseWriter, request *http.Request) { + assert.Equal(t, "POST", request.Method) + + resp := devtunnel.ServerResponse{ + Hostname: fmt.Sprintf("[%s]", serverIP.String()), + ServerIP: serverIP, + ServerPublicKey: hex.EncodeToString(pubBytes[:]), + ClientIP: clientIP, + } + b, err := json.Marshal(&resp) + assert.NoError(t, err) + writer.WriteHeader(200) + _, err = writer.Write(b) + assert.NoError(t, err) + }) + + server := httptest.NewTLSServer(handler) + t.Cleanup(func() { + server.Close() + }) + return server +} + +func (f *fakeTunnelServer) config() devtunnel.Config { + priv, err := wgtypes.GeneratePrivateKey() + require.NoError(f.t, err) + pub := priv.PublicKey() + f.clients++ + assert.Equal(f.t, 1, f.clients) // only allow one client as we hardcode the address + + err = f.device.IpcSet(fmt.Sprintf(`public_key=%x +allowed_ip=%s/128`, + pub[:], + clientIP.String(), + )) + require.NoError(f.t, err) + return devtunnel.Config{ + Version: 1, + PrivateKey: device.NoisePrivateKey(priv), + PublicKey: device.NoisePublicKey(pub), + Tunnel: devtunnel.Node{ + HostnameHTTPS: strings.TrimPrefix(f.server.URL, "https://"), + HostnameWireguard: "localhost", + WireguardPort: wgPort, + }, + HTTPClient: f.server.Client(), + } +} + +func (f *fakeTunnelServer) requestHTTP() (*http.Response, error) { + transport := &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + f.t.Log("Dial", network, addr) + nc, err := f.tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(clientIP, 8090)) + assert.NoError(f.t, err) + return nc, err + }, + } + client := &http.Client{ + Transport: transport, + Timeout: 10 * time.Second, + } + return client.Get(fmt.Sprintf("http://[%s]:8090", clientIP)) +}