Skip to content

chore: get TUN/DNS working on Windows for CoderVPN #16310

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions cli/vpndaemon_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ func (r *RootCmd) vpnDaemonRun() *serpent.Command {
},
Handler: func(inv *serpent.Invocation) error {
ctx := inv.Context()
logger := inv.Logger.AppendSinks(sloghuman.Sink(inv.Stderr)).Leveled(slog.LevelDebug)
sinks := []slog.Sink{
sloghuman.Sink(inv.Stderr),
}
logger := inv.Logger.AppendSinks(sinks...).Leveled(slog.LevelDebug)

if rpcReadHandleInt < 0 || rpcWriteHandleInt < 0 {
return xerrors.Errorf("rpc-read-handle (%v) and rpc-write-handle (%v) must be positive", rpcReadHandleInt, rpcWriteHandleInt)
Expand All @@ -60,7 +63,11 @@ func (r *RootCmd) vpnDaemonRun() *serpent.Command {
defer pipe.Close()

logger.Info(ctx, "starting tunnel")
tunnel, err := vpn.NewTunnel(ctx, logger, pipe, vpn.NewClient())
tunnel, err := vpn.NewTunnel(ctx, logger, pipe, vpn.NewClient(),
vpn.UseOSNetworkingStack(),
vpn.UseAsLogger(),
vpn.UseCustomLogSinks(sinks...),
)
if err != nil {
return xerrors.Errorf("create new tunnel for client: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ require (
go.opentelemetry.io/proto/otlp v1.4.0 // indirect
go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect
golang.org/x/time v0.9.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 // indirect
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
google.golang.org/appengine v1.6.8 // indirect
Expand Down
17 changes: 11 additions & 6 deletions tailnet/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ type Options struct {
Router router.Router
// TUNDev is optional, and is passed to the underlying wireguard engine.
TUNDev tun.Device
// WireguardMonitor is optional, and is passed to the underlying wireguard
// engine.
WireguardMonitor *netmon.Monitor
}

// TelemetrySink allows tailnet.Conn to send network telemetry to the Coder
Expand Down Expand Up @@ -171,13 +174,15 @@ func NewConn(options *Options) (conn *Conn, err error) {
nodeID = tailcfg.NodeID(uid)
}

wireguardMonitor, err := netmon.New(Logger(options.Logger.Named("net.wgmonitor")))
if err != nil {
return nil, xerrors.Errorf("create wireguard link monitor: %w", err)
if options.WireguardMonitor == nil {
options.WireguardMonitor, err = netmon.New(Logger(options.Logger.Named("net.wgmonitor")))
if err != nil {
return nil, xerrors.Errorf("create wireguard link monitor: %w", err)
}
}
defer func() {
if err != nil {
wireguardMonitor.Close()
options.WireguardMonitor.Close()
}
}()

Expand All @@ -186,7 +191,7 @@ func NewConn(options *Options) (conn *Conn, err error) {
}
sys := new(tsd.System)
wireguardEngine, err := wgengine.NewUserspaceEngine(Logger(options.Logger.Named("net.wgengine")), wgengine.Config{
NetMon: wireguardMonitor,
NetMon: options.WireguardMonitor,
Dialer: dialer,
ListenPort: options.ListenPort,
SetSubsystem: sys.Set,
Expand Down Expand Up @@ -293,7 +298,7 @@ func NewConn(options *Options) (conn *Conn, err error) {
listeners: map[listenKey]*listener{},
tunDevice: sys.Tun.Get(),
netStack: netStack,
wireguardMonitor: wireguardMonitor,
wireguardMonitor: options.WireguardMonitor,
wireguardRouter: &router.Config{
LocalAddrs: options.Addresses,
},
Expand Down
26 changes: 10 additions & 16 deletions vpn/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"golang.org/x/xerrors"
"tailscale.com/net/dns"
"tailscale.com/net/netmon"
"tailscale.com/wgengine/router"

"github.com/google/uuid"
Expand Down Expand Up @@ -57,12 +58,13 @@ func NewClient() Client {
}

type Options struct {
Headers http.Header
Logger slog.Logger
DNSConfigurator dns.OSConfigurator
Router router.Router
TUNFileDescriptor *int
UpdateHandler tailnet.UpdatesHandler
Headers http.Header
Logger slog.Logger
DNSConfigurator dns.OSConfigurator
Router router.Router
TUNDevice tun.Device
WireguardMonitor *netmon.Monitor
UpdateHandler tailnet.UpdatesHandler
}

func (*client) NewConn(initCtx context.Context, serverURL *url.URL, token string, options *Options) (vpnC Conn, err error) {
Expand All @@ -74,15 +76,6 @@ func (*client) NewConn(initCtx context.Context, serverURL *url.URL, token string
options.Headers = http.Header{}
}

var dev tun.Device
if options.TUNFileDescriptor != nil {
// No-op on non-Darwin platforms.
dev, err = makeTUN(*options.TUNFileDescriptor)
if err != nil {
return nil, xerrors.Errorf("make TUN: %w", err)
}
}

headers := options.Headers
sdk := codersdk.New(serverURL)
sdk.SetSessionToken(token)
Expand Down Expand Up @@ -134,7 +127,8 @@ func (*client) NewConn(initCtx context.Context, serverURL *url.URL, token string
BlockEndpoints: connInfo.DisableDirectConnections,
DNSConfigurator: options.DNSConfigurator,
Router: options.Router,
TUNDev: dev,
TUNDev: options.TUNDevice,
WireguardMonitor: options.WireguardMonitor,
})
if err != nil {
return nil, xerrors.Errorf("create tailnet: %w", err)
Expand Down
3 changes: 1 addition & 2 deletions vpn/dylib/lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ func OpenTunnel(cReadFD, cWriteFD int32) int32 {
}

_, err = vpn.NewTunnel(ctx, slog.Make(), conn, vpn.NewClient(),
vpn.UseAsDNSConfig(),
vpn.UseAsRouter(),
vpn.UseOSNetworkingStack(),
vpn.UseAsLogger(),
)
if err != nil {
Expand Down
10 changes: 5 additions & 5 deletions vpn/tun.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
//go:build !darwin
//go:build !darwin && !windows

package vpn

import "github.com/tailscale/wireguard-go/tun"
import "cdr.dev/slog"

// This is a no-op on non-Darwin platforms.
func makeTUN(int) (tun.Device, error) {
return nil, nil
// This is a no-op on every platform except Darwin and Windows.
func GetNetworkingStack(_ *Tunnel, _ *StartRequest, _ slog.Logger) (NetworkStack, error) {
return NetworkStack{}, nil
}
20 changes: 14 additions & 6 deletions vpn/tun_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,34 @@ package vpn
import (
"os"

"cdr.dev/slog"
"github.com/tailscale/wireguard-go/tun"
"golang.org/x/sys/unix"
"golang.org/x/xerrors"
)

func makeTUN(tunFD int) (tun.Device, error) {
dupTunFd, err := unix.Dup(tunFD)
func GetNetworkingStack(t *Tunnel, req *StartRequest, _ slog.Logger) (NetworkStack, error) {
tunFd := int(req.GetTunnelFileDescriptor())
dupTunFd, err := unix.Dup(tunFd)
if err != nil {
return nil, xerrors.Errorf("dup tun fd: %w", err)
return NetworkStack{}, xerrors.Errorf("dup tun fd: %w", err)
}

err = unix.SetNonblock(dupTunFd, true)
if err != nil {
unix.Close(dupTunFd)
return nil, xerrors.Errorf("set nonblock: %w", err)
return NetworkStack{}, xerrors.Errorf("set nonblock: %w", err)
}
fileTun, err := tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0)
if err != nil {
unix.Close(dupTunFd)
return nil, xerrors.Errorf("create TUN from File: %w", err)
return NetworkStack{}, xerrors.Errorf("create TUN from File: %w", err)
}
return fileTun, nil

return NetworkStack{
WireguardMonitor: nil, // default is fine
TUNDevice: fileTun,
Router: NewRouter(t),
DNSConfigurator: NewDNSConfigurator(t),
}, nil
}
115 changes: 115 additions & 0 deletions vpn/tun_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
//go:build windows

package vpn

import (
"context"
"errors"
"time"

"github.com/coder/retry"
"github.com/tailscale/wireguard-go/tun"
"golang.org/x/sys/windows"
"golang.org/x/xerrors"
"golang.zx2c4.com/wintun"
"tailscale.com/net/dns"
"tailscale.com/net/netmon"
"tailscale.com/net/tstun"
"tailscale.com/types/logger"
"tailscale.com/util/winutil"
"tailscale.com/wgengine/router"

"cdr.dev/slog"
"github.com/coder/coder/v2/tailnet"
)

const tunName = "Coder"

func GetNetworkingStack(t *Tunnel, _ *StartRequest, logger slog.Logger) (NetworkStack, error) {
tun.WintunTunnelType = tunName
guid, err := windows.GUIDFromString("{0ed1515d-04a4-4c46-abae-11ad07cf0e6d}")
if err != nil {
panic(err)
}
tun.WintunStaticRequestedGUID = &guid

tunDev, tunName, err := tstunNewWithWindowsRetries(tailnet.Logger(logger.Named("net.tun.device")), tunName)
if err != nil {
return NetworkStack{}, xerrors.Errorf("create tun device: %w", err)
}
logger.Info(context.Background(), "tun created", slog.F("name", tunName))

wireguardMonitor, err := netmon.New(tailnet.Logger(logger.Named("net.wgmonitor")))

coderRouter, err := router.New(tailnet.Logger(logger.Named("net.router")), tunDev, wireguardMonitor)
if err != nil {
return NetworkStack{}, xerrors.Errorf("create router: %w", err)
}

dnsConfigurator, err := dns.NewOSConfigurator(tailnet.Logger(logger.Named("net.dns")), tunName)
if err != nil {
return NetworkStack{}, xerrors.Errorf("create dns configurator: %w", err)
}

return NetworkStack{
WireguardMonitor: nil, // default is fine
TUNDevice: tunDev,
Router: coderRouter,
DNSConfigurator: dnsConfigurator,
}, nil
}

// tstunNewOrRetry is a wrapper around tstun.New that retries on Windows for certain
// errors.
//
// This is taken from Tailscale:
// https://github.com/tailscale/tailscale/blob/3abfbf50aebbe3ba57dc749165edb56be6715c0a/cmd/tailscaled/tailscaled_windows.go#L107
func tstunNewWithWindowsRetries(logf logger.Logf, tunName string) (_ tun.Device, devName string, _ error) {
r := retry.New(250*time.Millisecond, 10*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
for r.Wait(ctx) {
dev, devName, err := tstun.New(logf, tunName)
if err == nil {
return dev, devName, err
}
if errors.Is(err, windows.ERROR_DEVICE_NOT_AVAILABLE) || windowsUptime() < 10*time.Minute {
// Wintun is not installing correctly. Dump the state of NetSetupSvc
// (which is a user-mode service that must be active for network devices
// to install) and its dependencies to the log.
winutil.LogSvcState(logf, "NetSetupSvc")
}
}

return nil, "", ctx.Err()
}

var (
kernel32 = windows.NewLazySystemDLL("kernel32.dll")
getTickCount64Proc = kernel32.NewProc("GetTickCount64")
)

func windowsUptime() time.Duration {
r, _, _ := getTickCount64Proc.Call()
return time.Duration(int64(r)) * time.Millisecond
}

// TODO(@dean): implement a way to install/uninstall the wintun driver, most
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// likely as a CLI command
//
// This is taken from Tailscale:
// https://github.com/tailscale/tailscale/blob/3abfbf50aebbe3ba57dc749165edb56be6715c0a/cmd/tailscaled/tailscaled_windows.go#L543
func uninstallWinTun(logf logger.Logf) {
dll := windows.NewLazyDLL("wintun.dll")
if err := dll.Load(); err != nil {
logf("Cannot load wintun.dll for uninstall: %v", err)
return
}

logf("Removing wintun driver...")
err := wintun.Uninstall()
logf("Uninstall: %v", err)
}

// TODO(@dean): remove
var _ = uninstallWinTun
Loading
Loading