From 0010f590924a1c715a98ea4b44e38089d786070b Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Fri, 24 Jun 2022 15:24:03 -0500 Subject: [PATCH 1/2] feat: add ssh support over wireguard --- agent/wireguard.go | 34 +++++++++++ cli/ssh.go | 118 +++++++++++++++++++++++++++++++-------- peer/peerwg/ssh.go | 38 +++++++++++++ peer/peerwg/wireguard.go | 22 +++++--- 4 files changed, 180 insertions(+), 32 deletions(-) create mode 100644 peer/peerwg/ssh.go diff --git a/agent/wireguard.go b/agent/wireguard.go index 3b213bf34c004..603b5616e4740 100644 --- a/agent/wireguard.go +++ b/agent/wireguard.go @@ -2,6 +2,8 @@ package agent import ( "context" + "net" + "strconv" "golang.org/x/xerrors" "inet.af/netaddr" @@ -58,6 +60,38 @@ func (a *agent) startWireguard(ctx context.Context, addrs []netaddr.IPPrefix) er } }() + a.startWireguardListeners(ctx, wg, []handlerPort{ + {port: 12212, handler: a.sshServer.HandleConn}, + }) + a.network = wg return nil } + +type handlerPort struct { + handler func(conn net.Conn) + port uint16 +} + +func (a *agent) startWireguardListeners(ctx context.Context, network *peerwg.Network, handlers []handlerPort) { + for _, h := range handlers { + go func(h handlerPort) { + a.logger.Debug(ctx, "starting wireguard listener", slog.F("port", h.port)) + + listener, err := network.Listen("tcp", net.JoinHostPort("", strconv.Itoa(int(h.port)))) + if err != nil { + a.logger.Warn(ctx, "listen wireguard", slog.F("port", h.port), slog.Error(err)) + return + } + + for { + conn, err := listener.Accept() + if err != nil { + return + } + + go h.handler(conn) + } + }(h) + } +} diff --git a/cli/ssh.go b/cli/ssh.go index 10736901433d9..9d9a2af5652d2 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -18,13 +18,18 @@ import ( gosshagent "golang.org/x/crypto/ssh/agent" "golang.org/x/term" "golang.org/x/xerrors" + "inet.af/netaddr" + tslogger "tailscale.com/types/logger" + "cdr.dev/slog" + "cdr.dev/slog/sloggers/sloghuman" "github.com/coder/coder/cli/cliflag" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/coderd/autobuild/notify" "github.com/coder/coder/coderd/util/ptr" "github.com/coder/coder/codersdk" "github.com/coder/coder/cryptorand" + "github.com/coder/coder/peer/peerwg" ) var workspacePollInterval = time.Minute @@ -37,6 +42,7 @@ func ssh() *cobra.Command { forwardAgent bool identityAgent string wsPollInterval time.Duration + wireguard bool ) cmd := &cobra.Command{ Annotations: workspaceCommand, @@ -61,7 +67,7 @@ func ssh() *cobra.Command { } } - workspace, agent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], shuffle) + workspace, workspaceAgent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], shuffle) if err != nil { return err } @@ -71,41 +77,104 @@ func ssh() *cobra.Command { err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{ WorkspaceName: workspace.Name, Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) { - return client.WorkspaceAgent(ctx, agent.ID) + return client.WorkspaceAgent(ctx, workspaceAgent.ID) }, }) if err != nil { return xerrors.Errorf("await agent: %w", err) } - conn, err := client.DialWorkspaceAgent(cmd.Context(), agent.ID, nil) - if err != nil { - return err - } - defer conn.Close() + var ( + sshClient *gossh.Client + sshSession *gossh.Session + ) - stopPolling := tryPollWorkspaceAutostop(cmd.Context(), client, workspace) - defer stopPolling() + if !wireguard { + conn, err := client.DialWorkspaceAgent(cmd.Context(), workspaceAgent.ID, nil) + if err != nil { + return err + } + defer conn.Close() - if stdio { - rawSSH, err := conn.SSH() + stopPolling := tryPollWorkspaceAutostop(cmd.Context(), client, workspace) + defer stopPolling() + + if stdio { + rawSSH, err := conn.SSH() + if err != nil { + return err + } + go func() { + _, _ = io.Copy(cmd.OutOrStdout(), rawSSH) + }() + _, _ = io.Copy(rawSSH, cmd.InOrStdin()) + return nil + } + + sshClient, err = conn.SSHClient() if err != nil { return err } - go func() { - _, _ = io.Copy(cmd.OutOrStdout(), rawSSH) - }() - _, _ = io.Copy(rawSSH, cmd.InOrStdin()) - return nil - } - sshClient, err := conn.SSHClient() - if err != nil { - return err - } - sshSession, err := sshClient.NewSession() - if err != nil { - return err + sshSession, err = sshClient.NewSession() + if err != nil { + return err + } + } else { + // TODO: more granual control of Tailscale logging. + peerwg.Logf = tslogger.Discard + + ipv6 := peerwg.UUIDToNetaddr(uuid.New()) + wgn, err := peerwg.New( + slog.Make(sloghuman.Sink(os.Stderr)), + []netaddr.IPPrefix{netaddr.IPPrefixFrom(ipv6, 128)}, + ) + if err != nil { + return xerrors.Errorf("create wireguard network: %w", err) + } + + err = client.PostWireguardPeer(cmd.Context(), workspace.ID, peerwg.Handshake{ + Recipient: workspaceAgent.ID, + NodePublicKey: wgn.NodePrivateKey.Public(), + DiscoPublicKey: wgn.DiscoPublicKey, + IPv6: ipv6, + }) + if err != nil { + return xerrors.Errorf("post wireguard peer: %w", err) + } + + err = wgn.AddPeer(peerwg.Handshake{ + Recipient: workspaceAgent.ID, + DiscoPublicKey: workspaceAgent.DiscoPublicKey, + NodePublicKey: workspaceAgent.WireguardPublicKey, + IPv6: workspaceAgent.IPv6.IP(), + }) + if err != nil { + return xerrors.Errorf("add workspace agent as peer: %w", err) + } + + if stdio { + rawSSH, err := wgn.SSH(cmd.Context(), workspaceAgent.IPv6.IP()) + if err != nil { + return err + } + + go func() { + _, _ = io.Copy(cmd.OutOrStdout(), rawSSH) + }() + _, _ = io.Copy(rawSSH, cmd.InOrStdin()) + return nil + } + + sshClient, err = wgn.SSHClient(cmd.Context(), workspaceAgent.IPv6.IP()) + if err != nil { + return err + } + + sshSession, err = sshClient.NewSession() + if err != nil { + return err + } } if identityAgent == "" { @@ -174,6 +243,7 @@ func ssh() *cobra.Command { cliflag.BoolVarP(cmd.Flags(), &forwardAgent, "forward-agent", "A", "CODER_SSH_FORWARD_AGENT", false, "Specifies whether to forward the SSH agent specified in $SSH_AUTH_SOCK") cliflag.StringVarP(cmd.Flags(), &identityAgent, "identity-agent", "", "CODER_SSH_IDENTITY_AGENT", "", "Specifies which identity agent to use (overrides $SSH_AUTH_SOCK), forward agent must also be enabled") cliflag.DurationVarP(cmd.Flags(), &wsPollInterval, "workspace-poll-interval", "", "CODER_WORKSPACE_POLL_INTERVAL", workspacePollInterval, "Specifies how often to poll for workspace automated shutdown.") + cliflag.BoolVarP(cmd.Flags(), &wireguard, "wireguard", "", "CODER_SSH_WIREGUARD", true, "Whether to use Wireguard for SSH tunneling.") return cmd } diff --git a/peer/peerwg/ssh.go b/peer/peerwg/ssh.go new file mode 100644 index 0000000000000..9ffe8cc92c816 --- /dev/null +++ b/peer/peerwg/ssh.go @@ -0,0 +1,38 @@ +package peerwg + +import ( + "context" + "net" + + "golang.org/x/crypto/ssh" + "golang.org/x/xerrors" + "inet.af/netaddr" +) + +func (n *Network) SSH(ctx context.Context, ip netaddr.IP) (net.Conn, error) { + netConn, err := n.Netstack.DialContextTCP(ctx, netaddr.IPPortFrom(ip, 12212)) + if err != nil { + return nil, xerrors.Errorf("dial agent ssh: %w", err) + } + + return netConn, nil +} + +func (n *Network) SSHClient(ctx context.Context, ip netaddr.IP) (*ssh.Client, error) { + netConn, err := n.SSH(ctx, ip) + if err != nil { + return nil, xerrors.Errorf("ssh: %w", err) + } + + sshConn, channels, requests, err := ssh.NewClientConn(netConn, "localhost:22", &ssh.ClientConfig{ + // SSH host validation isn't helpful, because obtaining a peer + // connection already signifies user-intent to dial a workspace. + // #nosec + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + }) + if err != nil { + return nil, xerrors.Errorf("new ssh client conn: %w", err) + } + + return ssh.NewClient(sshConn, channels, requests), nil +} diff --git a/peer/peerwg/wireguard.go b/peer/peerwg/wireguard.go index d63a7b039f634..b210b2b70dadc 100644 --- a/peer/peerwg/wireguard.go +++ b/peer/peerwg/wireguard.go @@ -35,7 +35,7 @@ import ( "cdr.dev/slog" ) -var logf tslogger.Logf = log.Printf +var Logf tslogger.Logf = log.Printf func init() { // Globally disable network namespacing. @@ -139,15 +139,15 @@ func New(logger slog.Logger, addresses []netaddr.IPPrefix) (*Network, error) { DERP: DefaultDerpHome, } - wgMonitor, err := monitor.New(logf) + wgMonitor, err := monitor.New(Logf) if err != nil { return nil, xerrors.Errorf("create link monitor: %w", err) } dialer := new(tsdial.Dialer) - dialer.Logf = logf + dialer.Logf = Logf // Create a wireguard engine in userspace. - engine, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{ + engine, err := wgengine.NewUserspaceEngine(Logf, wgengine.Config{ LinkMonitor: wgMonitor, Dialer: dialer, }) @@ -172,7 +172,7 @@ func New(logger slog.Logger, addresses []netaddr.IPPrefix) (*Network, error) { // Create the networking stack. // This is called to route connections. - netStack, err := netstack.Create(logf, tunDev, engine, magicConn, dialer, dnsManager) + netStack, err := netstack.Create(Logf, tunDev, engine, magicConn, dialer, dnsManager) if err != nil { return nil, xerrors.Errorf("create netstack: %w", err) } @@ -192,7 +192,7 @@ func New(logger slog.Logger, addresses []netaddr.IPPrefix) (*Network, error) { engine = wgengine.NewWatchdog(engine) // Update the wireguard configuration to allow traffic to flow. - cfg, err := nmcfg.WGCfg(netMap, logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, netMap.SelfNode.StableID) + cfg, err := nmcfg.WGCfg(netMap, Logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, netMap.SelfNode.StableID) if err != nil { return nil, xerrors.Errorf("create wgcfg: %w", err) } @@ -216,7 +216,7 @@ func New(logger slog.Logger, addresses []netaddr.IPPrefix) (*Network, error) { iplb := netaddr.IPSetBuilder{} ipl, _ := iplb.IPSet() - engine.SetFilter(filter.New(netMap.PacketFilter, ips, ipl, nil, logf)) + engine.SetFilter(filter.New(netMap.PacketFilter, ips, ipl, nil, Logf)) wn := &Network{ logger: logger, @@ -319,7 +319,7 @@ func (n *Network) AddPeer(handshake Handshake) error { n.netMap.Peers = peers - cfg, err := nmcfg.WGCfg(n.netMap, logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, tailcfg.StableNodeID("nBBoJZ5CNTRL")) + cfg, err := nmcfg.WGCfg(n.netMap, Logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, tailcfg.StableNodeID("nBBoJZ5CNTRL")) if err != nil { return xerrors.Errorf("create wgcfg: %w", err) } @@ -375,6 +375,12 @@ func (n *Network) Listen(network, addr string) (net.Listener, error) { } func (n *Network) Close() error { + // Close all listeners. + for _, l := range n.listeners { + _ = l.Close() + } + + // Close the Wireguard netstack and engine. _ = n.Netstack.Close() n.wgEngine.Close() From 2a361dbf99d78ea6c97bfd5f5b4fc0357ad5deb3 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Fri, 24 Jun 2022 15:43:21 -0500 Subject: [PATCH 2/2] config-ssh wireguard --- cli/configssh.go | 9 ++++++++- cli/ssh.go | 3 ++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/cli/configssh.go b/cli/configssh.go index 238fcf592f9e8..725e634e3e989 100644 --- a/cli/configssh.go +++ b/cli/configssh.go @@ -135,6 +135,7 @@ func configSSH() *cobra.Command { coderConfigFile string dryRun bool skipProxyCommand bool + wireguard bool ) cmd := &cobra.Command{ Annotations: workspaceCommand, @@ -287,7 +288,11 @@ func configSSH() *cobra.Command { "\tLogLevel ERROR", ) if !skipProxyCommand { - configOptions = append(configOptions, fmt.Sprintf("\tProxyCommand %q --global-config %q ssh --stdio %s", binaryFile, root, hostname)) + if !wireguard { + configOptions = append(configOptions, fmt.Sprintf("\tProxyCommand %q --global-config %q ssh --stdio %s", binaryFile, root, hostname)) + } else { + configOptions = append(configOptions, fmt.Sprintf("\tProxyCommand %q --global-config %q ssh --wireguard --stdio %s", binaryFile, root, hostname)) + } } _, _ = buf.WriteString(strings.Join(configOptions, "\n")) @@ -374,6 +379,8 @@ func configSSH() *cobra.Command { cmd.Flags().BoolVarP(&skipProxyCommand, "skip-proxy-command", "", false, "Specifies whether the ProxyCommand option should be skipped. Useful for testing.") _ = cmd.Flags().MarkHidden("skip-proxy-command") cliflag.BoolVarP(cmd.Flags(), &usePreviousOpts, "use-previous-options", "", "CODER_SSH_USE_PREVIOUS_OPTIONS", false, "Specifies whether or not to keep options from previous run of config-ssh.") + cliflag.BoolVarP(cmd.Flags(), &wireguard, "wireguard", "", "CODER_CONFIG_SSH_WIREGUARD", false, "Whether to use Wireguard for SSH tunneling.") + _ = cmd.Flags().MarkHidden("wireguard") // Deprecated: Remove after migration period. cmd.Flags().StringVar(&coderConfigFile, "test.ssh-coder-config-file", sshDefaultCoderConfigFileName, "Specifies the path to an Coder SSH config file. Useful for testing.") diff --git a/cli/ssh.go b/cli/ssh.go index 9d9a2af5652d2..81eacc33f0899 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -243,7 +243,8 @@ func ssh() *cobra.Command { cliflag.BoolVarP(cmd.Flags(), &forwardAgent, "forward-agent", "A", "CODER_SSH_FORWARD_AGENT", false, "Specifies whether to forward the SSH agent specified in $SSH_AUTH_SOCK") cliflag.StringVarP(cmd.Flags(), &identityAgent, "identity-agent", "", "CODER_SSH_IDENTITY_AGENT", "", "Specifies which identity agent to use (overrides $SSH_AUTH_SOCK), forward agent must also be enabled") cliflag.DurationVarP(cmd.Flags(), &wsPollInterval, "workspace-poll-interval", "", "CODER_WORKSPACE_POLL_INTERVAL", workspacePollInterval, "Specifies how often to poll for workspace automated shutdown.") - cliflag.BoolVarP(cmd.Flags(), &wireguard, "wireguard", "", "CODER_SSH_WIREGUARD", true, "Whether to use Wireguard for SSH tunneling.") + cliflag.BoolVarP(cmd.Flags(), &wireguard, "wireguard", "", "CODER_SSH_WIREGUARD", false, "Whether to use Wireguard for SSH tunneling.") + _ = cmd.Flags().MarkHidden("wireguard") return cmd }