diff --git a/cli/root.go b/cli/root.go index f0bae8ff75adb..9f9028c072423 100644 --- a/cli/root.go +++ b/cli/root.go @@ -125,6 +125,7 @@ func (r *RootCmd) CoreSubcommands() []*serpent.Command { r.expCmd(), r.gitssh(), r.support(), + r.vpnDaemon(), r.vscodeSSH(), r.workspaceAgent(), } diff --git a/cli/vpndaemon.go b/cli/vpndaemon.go new file mode 100644 index 0000000000000..eb6a1e2223c5d --- /dev/null +++ b/cli/vpndaemon.go @@ -0,0 +1,21 @@ +package cli + +import ( + "github.com/coder/serpent" +) + +func (r *RootCmd) vpnDaemon() *serpent.Command { + cmd := &serpent.Command{ + Use: "vpn-daemon [subcommand]", + Short: "VPN daemon commands used by Coder Desktop.", + Hidden: true, + Handler: func(inv *serpent.Invocation) error { + return inv.Command.HelpHandler(inv) + }, + Children: []*serpent.Command{ + r.vpnDaemonRun(), + }, + } + + return cmd +} diff --git a/cli/vpndaemon_other.go b/cli/vpndaemon_other.go new file mode 100644 index 0000000000000..2e3e39b1b99ba --- /dev/null +++ b/cli/vpndaemon_other.go @@ -0,0 +1,24 @@ +//go:build !windows + +package cli + +import ( + "golang.org/x/xerrors" + + "github.com/coder/serpent" +) + +func (*RootCmd) vpnDaemonRun() *serpent.Command { + cmd := &serpent.Command{ + Use: "run", + Short: "Run the VPN daemon on Windows.", + Middleware: serpent.Chain( + serpent.RequireNArgs(0), + ), + Handler: func(_ *serpent.Invocation) error { + return xerrors.New("vpn-daemon subcommand is not supported on this platform") + }, + } + + return cmd +} diff --git a/cli/vpndaemon_windows.go b/cli/vpndaemon_windows.go new file mode 100644 index 0000000000000..004fb6493b0c1 --- /dev/null +++ b/cli/vpndaemon_windows.go @@ -0,0 +1,75 @@ +//go:build windows + +package cli + +import ( + "golang.org/x/xerrors" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/sloghuman" + "github.com/coder/coder/v2/vpn" + "github.com/coder/serpent" +) + +func (r *RootCmd) vpnDaemonRun() *serpent.Command { + var ( + rpcReadHandleInt int64 + rpcWriteHandleInt int64 + ) + + cmd := &serpent.Command{ + Use: "run", + Short: "Run the VPN daemon on Windows.", + Middleware: serpent.Chain( + serpent.RequireNArgs(0), + ), + Options: serpent.OptionSet{ + { + Flag: "rpc-read-handle", + Env: "CODER_VPN_DAEMON_RPC_READ_HANDLE", + Description: "The handle for the pipe to read from the RPC connection.", + Value: serpent.Int64Of(&rpcReadHandleInt), + Required: true, + }, + { + Flag: "rpc-write-handle", + Env: "CODER_VPN_DAEMON_RPC_WRITE_HANDLE", + Description: "The handle for the pipe to write to the RPC connection.", + Value: serpent.Int64Of(&rpcWriteHandleInt), + Required: true, + }, + }, + Handler: func(inv *serpent.Invocation) error { + ctx := inv.Context() + logger := inv.Logger.AppendSinks(sloghuman.Sink(inv.Stderr)).Leveled(slog.LevelDebug) + + if rpcReadHandleInt < 0 || rpcWriteHandleInt < 0 { + return xerrors.Errorf("rpc-read-handle (%v) and rpc-write-handle (%v) must be positive", rpcReadHandleInt, rpcWriteHandleInt) + } + if rpcReadHandleInt == rpcWriteHandleInt { + return xerrors.Errorf("rpc-read-handle (%v) and rpc-write-handle (%v) must be different", rpcReadHandleInt, rpcWriteHandleInt) + } + + // We don't need to worry about duplicating the handles on Windows, + // which is different from Unix. + logger.Info(ctx, "opening bidirectional RPC pipe", slog.F("rpc_read_handle", rpcReadHandleInt), slog.F("rpc_write_handle", rpcWriteHandleInt)) + pipe, err := vpn.NewBidirectionalPipe(uintptr(rpcReadHandleInt), uintptr(rpcWriteHandleInt)) + if err != nil { + return xerrors.Errorf("create bidirectional RPC pipe: %w", err) + } + defer pipe.Close() + + logger.Info(ctx, "starting tunnel") + tunnel, err := vpn.NewTunnel(ctx, logger, pipe) + if err != nil { + return xerrors.Errorf("create new tunnel for client: %w", err) + } + defer tunnel.Close() + + <-ctx.Done() + return nil + }, + } + + return cmd +} diff --git a/cli/vpndaemon_windows_test.go b/cli/vpndaemon_windows_test.go new file mode 100644 index 0000000000000..98c63277d4fac --- /dev/null +++ b/cli/vpndaemon_windows_test.go @@ -0,0 +1,93 @@ +//go:build windows + +package cli_test + +import ( + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/cli/clitest" + "github.com/coder/coder/v2/testutil" +) + +func TestVPNDaemonRun(t *testing.T) { + t.Parallel() + + t.Run("InvalidFlags", func(t *testing.T) { + t.Parallel() + + cases := []struct { + Name string + Args []string + ErrorContains string + }{ + { + Name: "NoReadHandle", + Args: []string{"--rpc-write-handle", "10"}, + ErrorContains: "rpc-read-handle", + }, + { + Name: "NoWriteHandle", + Args: []string{"--rpc-read-handle", "10"}, + ErrorContains: "rpc-write-handle", + }, + { + Name: "NegativeReadHandle", + Args: []string{"--rpc-read-handle", "-1", "--rpc-write-handle", "10"}, + ErrorContains: "rpc-read-handle", + }, + { + Name: "NegativeWriteHandle", + Args: []string{"--rpc-read-handle", "10", "--rpc-write-handle", "-1"}, + ErrorContains: "rpc-write-handle", + }, + { + Name: "SameHandles", + Args: []string{"--rpc-read-handle", "10", "--rpc-write-handle", "10"}, + ErrorContains: "rpc-read-handle", + }, + } + + for _, c := range cases { + c := c + t.Run(c.Name, func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + inv, _ := clitest.New(t, append([]string{"vpn-daemon", "run"}, c.Args...)...) + err := inv.WithContext(ctx).Run() + require.ErrorContains(t, err, c.ErrorContains) + }) + } + }) + + t.Run("StartsTunnel", func(t *testing.T) { + t.Parallel() + + r1, w1, err := os.Pipe() + require.NoError(t, err) + defer r1.Close() + defer w1.Close() + r2, w2, err := os.Pipe() + require.NoError(t, err) + defer r2.Close() + defer w2.Close() + + ctx := testutil.Context(t, testutil.WaitLong) + inv, _ := clitest.New(t, "vpn-daemon", "run", "--rpc-read-handle", fmt.Sprint(r1.Fd()), "--rpc-write-handle", fmt.Sprint(w2.Fd())) + waiter := clitest.StartWithWaiter(t, inv.WithContext(ctx)) + + // Send garbage which should cause the handshake to fail and the daemon + // to exit. + _, err = w1.Write([]byte("garbage")) + require.NoError(t, err) + waiter.Cancel() + err = waiter.Wait() + require.ErrorContains(t, err, "handshake failed") + }) + + // TODO: once the VPN tunnel functionality is implemented, add tests that + // actually try to instantiate a tunnel to a workspace +} diff --git a/vpn/pipe.go b/vpn/pipe.go new file mode 100644 index 0000000000000..bdf06828cb6a2 --- /dev/null +++ b/vpn/pipe.go @@ -0,0 +1,69 @@ +package vpn + +import ( + "io" + "os" + + "github.com/hashicorp/go-multierror" + "golang.org/x/xerrors" +) + +// BidirectionalPipe combines a pair of files that can be used for bidirectional +// communication. +type BidirectionalPipe struct { + read *os.File + write *os.File +} + +var _ io.ReadWriteCloser = BidirectionalPipe{} + +// NewBidirectionalPipe creates a new BidirectionalPipe from the given file +// descriptors. +func NewBidirectionalPipe(readFd, writeFd uintptr) (BidirectionalPipe, error) { + read := os.NewFile(readFd, "pipe_read") + _, err := read.Stat() + if err != nil { + return BidirectionalPipe{}, xerrors.Errorf("stat pipe_read (fd=%v): %w", readFd, err) + } + write := os.NewFile(writeFd, "pipe_write") + _, err = write.Stat() + if err != nil { + return BidirectionalPipe{}, xerrors.Errorf("stat pipe_write (fd=%v): %w", writeFd, err) + } + return BidirectionalPipe{ + read: read, + write: write, + }, nil +} + +// Read implements io.Reader. Data is read from the read pipe. +func (b BidirectionalPipe) Read(p []byte) (int, error) { + n, err := b.read.Read(p) + if err != nil { + return n, xerrors.Errorf("read from pipe_read (fd=%v): %w", b.read.Fd(), err) + } + return n, nil +} + +// Write implements io.Writer. Data is written to the write pipe. +func (b BidirectionalPipe) Write(p []byte) (n int, err error) { + n, err = b.write.Write(p) + if err != nil { + return n, xerrors.Errorf("write to pipe_write (fd=%v): %w", b.write.Fd(), err) + } + return n, nil +} + +// Close implements io.Closer. Both the read and write pipes are closed. +func (b BidirectionalPipe) Close() error { + var err error + rErr := b.read.Close() + if rErr != nil { + err = multierror.Append(err, xerrors.Errorf("close pipe_read (fd=%v): %w", b.read.Fd(), rErr)) + } + wErr := b.write.Close() + if err != nil { + err = multierror.Append(err, xerrors.Errorf("close pipe_write (fd=%v): %w", b.write.Fd(), wErr)) + } + return err +}