Skip to content

chore: switch to new wgtunnel via tunnelsdk #6489

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Mar 22, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
chore: switch to new wgtunnel via tunnelsdk
  • Loading branch information
deansheather committed Mar 7, 2023
commit 6658eee7273c38bd831fc00d77ae4b52586f36cc
31 changes: 10 additions & 21 deletions cli/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ import (
"github.com/coder/coder/provisionersdk"
sdkproto "github.com/coder/coder/provisionersdk/proto"
"github.com/coder/coder/tailnet"
"github.com/coder/wgtunnel/tunnelsdk"
)

// nolint:gocyclo
Expand Down Expand Up @@ -347,31 +348,21 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co
return xerrors.Errorf("configure http client: %w", err)
}

var (
ctxTunnel, closeTunnel = context.WithCancel(ctx)
tunnel *devtunnel.Tunnel
tunnelErr <-chan error
)
defer closeTunnel()

// If the access URL is empty, we attempt to run a reverse-proxy
// tunnel to make the initial setup really simple.
var tunnel *tunnelsdk.Tunnel
if cfg.AccessURL.Value == "" {
cmd.Printf("Opening tunnel so workspaces can connect to your deployment. For production scenarios, specify an external access URL\n")
tunnel, tunnelErr, err = devtunnel.New(ctxTunnel, logger.Named("devtunnel"))
tunnel, err = devtunnel.New(ctx, logger.Named("devtunnel"))
if err != nil {
return xerrors.Errorf("create tunnel: %w", err)
}
cfg.AccessURL.Value = tunnel.URL
defer tunnel.Close()
cfg.AccessURL.Value = tunnel.URL.String()

if cfg.WildcardAccessURL.Value == "" {
u, err := parseURL(tunnel.URL)
if err != nil {
return xerrors.Errorf("parse tunnel url: %w", err)
}

// Suffixed wildcard access URL.
cfg.WildcardAccessURL.Value = fmt.Sprintf("*--%s", u.Hostname())
cfg.WildcardAccessURL.Value = fmt.Sprintf("*--%s", tunnel.URL.Hostname())
}
}

Expand Down Expand Up @@ -824,10 +815,8 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co
_, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Bold.Render(
"Interrupt caught, gracefully exiting. Use ctrl+\\ to force quit",
))
case exitErr = <-tunnelErr:
if exitErr == nil {
exitErr = xerrors.New("dev tunnel closed unexpectedly")
}
case <-tunnel.Wait():
exitErr = xerrors.New("dev tunnel closed unexpectedly")
case exitErr = <-errCh:
}
if exitErr != nil && !xerrors.Is(exitErr, context.Canceled) {
Expand Down Expand Up @@ -896,8 +885,8 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co
// Close tunnel after we no longer have in-flight connections.
if tunnel != nil {
cmd.Println("Waiting for tunnel to close...")
closeTunnel()
<-tunnelErr
_ = tunnel.Close()
<-tunnel.Wait()
cmd.Println("Done waiting for tunnel")
}

Expand Down
22 changes: 20 additions & 2 deletions coderd/devtunnel/servers.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/go-ping/ping"
"golang.org/x/exp/slices"
"golang.org/x/sync/errgroup"
"golang.org/x/xerrors"

"github.com/coder/coder/cryptorand"
)
Expand Down Expand Up @@ -44,18 +45,35 @@ var Regions = []Region{
},
}

func FindClosestNode() (Node, error) {
// Nodes returns a list of nodes to use for the tunnel. It will pick a random
// node from each region.
func Nodes() ([]Node, error) {
nodes := []Node{}

for _, region := range Regions {
// Pick a random node from each region.
i, err := cryptorand.Intn(len(region.Nodes))
if err != nil {
return Node{}, err
return []Node{}, err
}
nodes = append(nodes, region.Nodes[i])
}

return nodes, nil
}

// FindClosestNode pings each node and returns the one with the lowest latency.
func FindClosestNode(nodes []Node) (Node, error) {
if len(nodes) == 0 {
return Node{}, xerrors.New("no wgtunnel nodes")
}
if len(nodes) == 1 {
return nodes[0], nil
}

// Copy the nodes so we don't mutate the original.
nodes = append([]Node{}, nodes...)

var (
nodesMu sync.Mutex
eg = errgroup.Group{}
Expand Down
207 changes: 31 additions & 176 deletions coderd/devtunnel/tunnel.go
Original file line number Diff line number Diff line change
@@ -1,217 +1,68 @@
package devtunnel

import (
"bytes"
"context"
"encoding/hex"
"encoding/json"
"fmt"
"net"
"net/http"
"net/netip"
"net/url"
"os"
"path/filepath"
"time"

"github.com/briandowns/spinner"
"golang.org/x/xerrors"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"

"cdr.dev/slog"
"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/cryptorand"
"github.com/coder/wgtunnel/tunnelsdk"
)

type Tunnel struct {
URL string
Listener net.Listener
}

type Config struct {
Version int `json:"version"`
PrivateKey device.NoisePrivateKey `json:"private_key"`
PublicKey device.NoisePublicKey `json:"public_key"`
Version tunnelsdk.TunnelVersion `json:"version"`
PrivateKey device.NoisePrivateKey `json:"private_key"`
PublicKey device.NoisePublicKey `json:"public_key"`

Tunnel Node `json:"tunnel"`

// Used in testing. Normally this is nil, indicating to use DefaultClient.
HTTPClient *http.Client `json:"-"`
}
type configExt struct {
Version int `json:"-"`
PrivateKey device.NoisePrivateKey `json:"-"`
PublicKey device.NoisePublicKey `json:"public_key"`

Tunnel Node `json:"-"`

// Used in testing. Normally this is nil, indicating to use DefaultClient.
HTTPClient *http.Client `json:"-"`
}

// NewWithConfig calls New with the given config. For documentation, see New.
func NewWithConfig(ctx context.Context, logger slog.Logger, cfg Config) (*Tunnel, <-chan error, error) {
server, routineEnd, err := startUpdateRoutine(ctx, logger, cfg)
if err != nil {
return nil, nil, xerrors.Errorf("start update routine: %w", err)
}

tun, tnet, err := netstack.CreateNetTUN(
[]netip.Addr{server.ClientIP},
[]netip.Addr{netip.AddrFrom4([4]byte{1, 1, 1, 1})},
1280,
)
if err != nil {
return nil, nil, xerrors.Errorf("create net TUN: %w", err)
}

wgip, err := net.ResolveIPAddr("ip", cfg.Tunnel.HostnameWireguard)
if err != nil {
return nil, nil, xerrors.Errorf("resolve endpoint: %w", err)
}
// In IPv6, we need to enclose the address to in [] before passing to wireguard's endpoint key, like
// [2001:abcd::1]:8888. We'll use netip.AddrPort to correctly handle this.
wgAddr, err := netip.ParseAddr(wgip.String())
if err != nil {
return nil, nil, xerrors.Errorf("parse address: %w", err)
}
wgEndpoint := netip.AddrPortFrom(wgAddr, cfg.Tunnel.WireguardPort)

dlog := &device.Logger{
Verbosef: slog.Stdlib(ctx, logger, slog.LevelDebug).Printf,
Errorf: slog.Stdlib(ctx, logger, slog.LevelError).Printf,
}
dev := device.NewDevice(tun, conn.NewDefaultBind(), dlog)
err = dev.IpcSet(fmt.Sprintf(`private_key=%s
public_key=%s
endpoint=%s
persistent_keepalive_interval=21
allowed_ip=%s/128`,
hex.EncodeToString(cfg.PrivateKey[:]),
server.ServerPublicKey,
wgEndpoint.String(),
server.ServerIP.String(),
))
if err != nil {
return nil, nil, xerrors.Errorf("configure wireguard ipc: %w", err)
}

err = dev.Up()
if err != nil {
return nil, nil, xerrors.Errorf("wireguard device up: %w", err)
}

wgListen, err := tnet.ListenTCP(&net.TCPAddr{Port: 8090})
if err != nil {
return nil, nil, xerrors.Errorf("wireguard device listen: %w", err)
}

ch := make(chan error, 1)
go func() {
select {
case <-ctx.Done():
_ = wgListen.Close()
// We need to remove peers before closing to avoid a race condition between dev.Close() and the peer
// goroutines which results in segfault.
dev.RemoveAllPeers()
dev.Close()
<-routineEnd
close(ch)

case <-dev.Wait():
close(ch)
}
}()

return &Tunnel{
URL: fmt.Sprintf("https://%s", server.Hostname),
Listener: wgListen,
}, ch, nil
func NewWithConfig(ctx context.Context, logger slog.Logger, cfg Config) (*tunnelsdk.Tunnel, error) {
u := &url.URL{
Scheme: "https",
Host: cfg.Tunnel.HostnameHTTPS,
}

c := tunnelsdk.New(u)
return c.LaunchTunnel(ctx, tunnelsdk.TunnelConfig{
Log: logger,
Version: cfg.Version,
PrivateKey: tunnelsdk.FromNoisePrivateKey(cfg.PrivateKey),
})
}

// New creates a tunnel with a public URL and returns a listener for incoming
// connections on that URL. Connections are made over the wireguard protocol.
// Tunnel configuration is cached in the user's config directory. Successive
// calls to New will always use the same URL. If multiple public URLs in
// parallel are required, use NewWithConfig.
func New(ctx context.Context, logger slog.Logger) (*Tunnel, <-chan error, error) {
//
// This uses https://github.com/coder/wgtunnel as the server and client
// implementation.
func New(ctx context.Context, logger slog.Logger) (*tunnelsdk.Tunnel, error) {
cfg, err := readOrGenerateConfig()
if err != nil {
return nil, nil, xerrors.Errorf("read or generate config: %w", err)
return nil, xerrors.Errorf("read or generate config: %w", err)
}

return NewWithConfig(ctx, logger, cfg)
}

func startUpdateRoutine(ctx context.Context, logger slog.Logger, cfg Config) (ServerResponse, <-chan struct{}, error) {
// Ensure we send the first config before spawning in the background.
res, err := sendConfigToServer(ctx, cfg)
if err != nil {
return ServerResponse{}, nil, xerrors.Errorf("send config to server: %w", err)
}

endCh := make(chan struct{})
go func() {
defer close(endCh)
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()

for {
select {
case <-ctx.Done():
return

case <-ticker.C:
}

_, err := sendConfigToServer(ctx, cfg)
if err != nil {
logger.Debug(ctx, "send tunnel config to server", slog.Error(err))
}
}
}()
return res, endCh, nil
}

type ServerResponse struct {
Hostname string `json:"hostname"`
ServerIP netip.Addr `json:"server_ip"`
ServerPublicKey string `json:"server_public_key"` // hex
ClientIP netip.Addr `json:"client_ip"`
}

func sendConfigToServer(ctx context.Context, cfg Config) (ServerResponse, error) {
raw, err := json.Marshal(configExt(cfg))
if err != nil {
return ServerResponse{}, xerrors.Errorf("marshal config: %w", err)
}

req, err := http.NewRequestWithContext(ctx, "POST", "https://"+cfg.Tunnel.HostnameHTTPS+"/tun", bytes.NewReader(raw))
if err != nil {
return ServerResponse{}, xerrors.Errorf("new request: %w", err)
}

client := http.DefaultClient
if cfg.HTTPClient != nil {
client = cfg.HTTPClient
}
res, err := client.Do(req)
if err != nil {
return ServerResponse{}, xerrors.Errorf("do request: %w", err)
}
defer res.Body.Close()

var resp ServerResponse
err = json.NewDecoder(res.Body).Decode(&resp)
if err != nil {
return ServerResponse{}, xerrors.Errorf("decode response: %w", err)
}

return resp, nil
}

func cfgPath() (string, error) {
cfgDir, err := os.UserConfigDir()
if err != nil {
Expand Down Expand Up @@ -281,11 +132,15 @@ func readOrGenerateConfig() (Config, error) {
}

func GenerateConfig() (Config, error) {
priv, err := wgtypes.GeneratePrivateKey()
priv, err := tunnelsdk.GeneratePrivateKey()
if err != nil {
return Config{}, xerrors.Errorf("generate private key: %w", err)
}
pub := priv.PublicKey()
privNoisePublicKey, err := priv.NoisePrivateKey()
if err != nil {
return Config{}, xerrors.Errorf("generate noise private key: %w", err)
}
pubNoisePublicKey := priv.NoisePublicKey()

spin := spinner.New(spinner.CharSets[39], 350*time.Millisecond)
spin.Suffix = " Finding the closest tunnel region..."
Expand All @@ -309,9 +164,9 @@ func GenerateConfig() (Config, error) {
)

return Config{
Version: 1,
PrivateKey: device.NoisePrivateKey(priv),
PublicKey: device.NoisePublicKey(pub),
Version: tunnelsdk.TunnelVersion2,
PrivateKey: privNoisePublicKey,
PublicKey: pubNoisePublicKey,
Tunnel: node,
}, nil
}
Expand Down
2 changes: 1 addition & 1 deletion coderd/devtunnel/tunnel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func TestTunnel(t *testing.T) {
fTunServer := newFakeTunnelServer(t)
cfg := fTunServer.config()

tun, errCh, err := devtunnel.NewWithConfig(ctx, slogtest.Make(t, nil).Leveled(slog.LevelDebug), cfg)
tun, errCh, err := devtunnel.New(ctx, slogtest.Make(t, nil).Leveled(slog.LevelDebug), cfg)
require.NoError(t, err)
t.Log(tun.URL)

Expand Down
Loading