Skip to content

feat: add ssh support over wireguard #2642

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions agent/wireguard.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package agent

import (
"context"
"net"
"strconv"

"golang.org/x/xerrors"
"inet.af/netaddr"
Expand Down Expand Up @@ -58,6 +60,38 @@ func (a *agent) startWireguard(ctx context.Context, addrs []netaddr.IPPrefix) er
}
}()

a.startWireguardListeners(ctx, wg, []handlerPort{
{port: 12212, handler: a.sshServer.HandleConn},
})

a.network = wg
return nil
}

type handlerPort struct {
handler func(conn net.Conn)
port uint16
}

func (a *agent) startWireguardListeners(ctx context.Context, network *peerwg.Network, handlers []handlerPort) {
for _, h := range handlers {
go func(h handlerPort) {
a.logger.Debug(ctx, "starting wireguard listener", slog.F("port", h.port))

listener, err := network.Listen("tcp", net.JoinHostPort("", strconv.Itoa(int(h.port))))
if err != nil {
a.logger.Warn(ctx, "listen wireguard", slog.F("port", h.port), slog.Error(err))
return
}

for {
conn, err := listener.Accept()
if err != nil {
return
}

go h.handler(conn)
}
}(h)
}
}
9 changes: 8 additions & 1 deletion cli/configssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ func configSSH() *cobra.Command {
coderConfigFile string
dryRun bool
skipProxyCommand bool
wireguard bool
)
cmd := &cobra.Command{
Annotations: workspaceCommand,
Expand Down Expand Up @@ -287,7 +288,11 @@ func configSSH() *cobra.Command {
"\tLogLevel ERROR",
)
if !skipProxyCommand {
configOptions = append(configOptions, fmt.Sprintf("\tProxyCommand %q --global-config %q ssh --stdio %s", binaryFile, root, hostname))
if !wireguard {
configOptions = append(configOptions, fmt.Sprintf("\tProxyCommand %q --global-config %q ssh --stdio %s", binaryFile, root, hostname))
} else {
configOptions = append(configOptions, fmt.Sprintf("\tProxyCommand %q --global-config %q ssh --wireguard --stdio %s", binaryFile, root, hostname))
}
}

_, _ = buf.WriteString(strings.Join(configOptions, "\n"))
Expand Down Expand Up @@ -374,6 +379,8 @@ func configSSH() *cobra.Command {
cmd.Flags().BoolVarP(&skipProxyCommand, "skip-proxy-command", "", false, "Specifies whether the ProxyCommand option should be skipped. Useful for testing.")
_ = cmd.Flags().MarkHidden("skip-proxy-command")
cliflag.BoolVarP(cmd.Flags(), &usePreviousOpts, "use-previous-options", "", "CODER_SSH_USE_PREVIOUS_OPTIONS", false, "Specifies whether or not to keep options from previous run of config-ssh.")
cliflag.BoolVarP(cmd.Flags(), &wireguard, "wireguard", "", "CODER_CONFIG_SSH_WIREGUARD", false, "Whether to use Wireguard for SSH tunneling.")
_ = cmd.Flags().MarkHidden("wireguard")

// Deprecated: Remove after migration period.
cmd.Flags().StringVar(&coderConfigFile, "test.ssh-coder-config-file", sshDefaultCoderConfigFileName, "Specifies the path to an Coder SSH config file. Useful for testing.")
Expand Down
119 changes: 95 additions & 24 deletions cli/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,18 @@ import (
gosshagent "golang.org/x/crypto/ssh/agent"
"golang.org/x/term"
"golang.org/x/xerrors"
"inet.af/netaddr"
tslogger "tailscale.com/types/logger"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
"github.com/coder/coder/cli/cliflag"
"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/coderd/autobuild/notify"
"github.com/coder/coder/coderd/util/ptr"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/cryptorand"
"github.com/coder/coder/peer/peerwg"
)

var workspacePollInterval = time.Minute
Expand All @@ -37,6 +42,7 @@ func ssh() *cobra.Command {
forwardAgent bool
identityAgent string
wsPollInterval time.Duration
wireguard bool
)
cmd := &cobra.Command{
Annotations: workspaceCommand,
Expand All @@ -61,7 +67,7 @@ func ssh() *cobra.Command {
}
}

workspace, agent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], shuffle)
workspace, workspaceAgent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], shuffle)
if err != nil {
return err
}
Expand All @@ -71,41 +77,104 @@ func ssh() *cobra.Command {
err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{
WorkspaceName: workspace.Name,
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
return client.WorkspaceAgent(ctx, agent.ID)
return client.WorkspaceAgent(ctx, workspaceAgent.ID)
},
})
if err != nil {
return xerrors.Errorf("await agent: %w", err)
}

conn, err := client.DialWorkspaceAgent(cmd.Context(), agent.ID, nil)
if err != nil {
return err
}
defer conn.Close()
var (
sshClient *gossh.Client
sshSession *gossh.Session
)

stopPolling := tryPollWorkspaceAutostop(cmd.Context(), client, workspace)
defer stopPolling()
if !wireguard {
conn, err := client.DialWorkspaceAgent(cmd.Context(), workspaceAgent.ID, nil)
if err != nil {
return err
}
defer conn.Close()

if stdio {
rawSSH, err := conn.SSH()
stopPolling := tryPollWorkspaceAutostop(cmd.Context(), client, workspace)
defer stopPolling()

if stdio {
rawSSH, err := conn.SSH()
if err != nil {
return err
}
go func() {
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH)
}()
_, _ = io.Copy(rawSSH, cmd.InOrStdin())
return nil
}

sshClient, err = conn.SSHClient()
if err != nil {
return err
}
go func() {
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH)
}()
_, _ = io.Copy(rawSSH, cmd.InOrStdin())
return nil
}
sshClient, err := conn.SSHClient()
if err != nil {
return err
}

sshSession, err := sshClient.NewSession()
if err != nil {
return err
sshSession, err = sshClient.NewSession()
if err != nil {
return err
}
} else {
// TODO: more granual control of Tailscale logging.
peerwg.Logf = tslogger.Discard

ipv6 := peerwg.UUIDToNetaddr(uuid.New())
wgn, err := peerwg.New(
slog.Make(sloghuman.Sink(os.Stderr)),
[]netaddr.IPPrefix{netaddr.IPPrefixFrom(ipv6, 128)},
)
if err != nil {
return xerrors.Errorf("create wireguard network: %w", err)
}

err = client.PostWireguardPeer(cmd.Context(), workspace.ID, peerwg.Handshake{
Recipient: workspaceAgent.ID,
NodePublicKey: wgn.NodePrivateKey.Public(),
DiscoPublicKey: wgn.DiscoPublicKey,
IPv6: ipv6,
})
if err != nil {
return xerrors.Errorf("post wireguard peer: %w", err)
}

err = wgn.AddPeer(peerwg.Handshake{
Recipient: workspaceAgent.ID,
DiscoPublicKey: workspaceAgent.DiscoPublicKey,
NodePublicKey: workspaceAgent.WireguardPublicKey,
IPv6: workspaceAgent.IPv6.IP(),
})
if err != nil {
return xerrors.Errorf("add workspace agent as peer: %w", err)
}

if stdio {
rawSSH, err := wgn.SSH(cmd.Context(), workspaceAgent.IPv6.IP())
if err != nil {
return err
}

go func() {
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH)
}()
_, _ = io.Copy(rawSSH, cmd.InOrStdin())
return nil
}

sshClient, err = wgn.SSHClient(cmd.Context(), workspaceAgent.IPv6.IP())
if err != nil {
return err
}

sshSession, err = sshClient.NewSession()
if err != nil {
return err
}
}

if identityAgent == "" {
Expand Down Expand Up @@ -174,6 +243,8 @@ func ssh() *cobra.Command {
cliflag.BoolVarP(cmd.Flags(), &forwardAgent, "forward-agent", "A", "CODER_SSH_FORWARD_AGENT", false, "Specifies whether to forward the SSH agent specified in $SSH_AUTH_SOCK")
cliflag.StringVarP(cmd.Flags(), &identityAgent, "identity-agent", "", "CODER_SSH_IDENTITY_AGENT", "", "Specifies which identity agent to use (overrides $SSH_AUTH_SOCK), forward agent must also be enabled")
cliflag.DurationVarP(cmd.Flags(), &wsPollInterval, "workspace-poll-interval", "", "CODER_WORKSPACE_POLL_INTERVAL", workspacePollInterval, "Specifies how often to poll for workspace automated shutdown.")
cliflag.BoolVarP(cmd.Flags(), &wireguard, "wireguard", "", "CODER_SSH_WIREGUARD", false, "Whether to use Wireguard for SSH tunneling.")
_ = cmd.Flags().MarkHidden("wireguard")

return cmd
}
Expand Down
38 changes: 38 additions & 0 deletions peer/peerwg/ssh.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package peerwg

import (
"context"
"net"

"golang.org/x/crypto/ssh"
"golang.org/x/xerrors"
"inet.af/netaddr"
)

func (n *Network) SSH(ctx context.Context, ip netaddr.IP) (net.Conn, error) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine for now, but we should probably keep this in the agent package eventually to separate concerns of networking and handshakes over our specific protocol implementations.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True

netConn, err := n.Netstack.DialContextTCP(ctx, netaddr.IPPortFrom(ip, 12212))
if err != nil {
return nil, xerrors.Errorf("dial agent ssh: %w", err)
}

return netConn, nil
}

func (n *Network) SSHClient(ctx context.Context, ip netaddr.IP) (*ssh.Client, error) {
netConn, err := n.SSH(ctx, ip)
if err != nil {
return nil, xerrors.Errorf("ssh: %w", err)
}

sshConn, channels, requests, err := ssh.NewClientConn(netConn, "localhost:22", &ssh.ClientConfig{
// SSH host validation isn't helpful, because obtaining a peer
// connection already signifies user-intent to dial a workspace.
// #nosec
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
})
if err != nil {
return nil, xerrors.Errorf("new ssh client conn: %w", err)
}

return ssh.NewClient(sshConn, channels, requests), nil
}
22 changes: 14 additions & 8 deletions peer/peerwg/wireguard.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import (
"cdr.dev/slog"
)

var logf tslogger.Logf = log.Printf
var Logf tslogger.Logf = log.Printf

func init() {
// Globally disable network namespacing.
Expand Down Expand Up @@ -139,15 +139,15 @@ func New(logger slog.Logger, addresses []netaddr.IPPrefix) (*Network, error) {
DERP: DefaultDerpHome,
}

wgMonitor, err := monitor.New(logf)
wgMonitor, err := monitor.New(Logf)
if err != nil {
return nil, xerrors.Errorf("create link monitor: %w", err)
}

dialer := new(tsdial.Dialer)
dialer.Logf = logf
dialer.Logf = Logf
// Create a wireguard engine in userspace.
engine, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{
engine, err := wgengine.NewUserspaceEngine(Logf, wgengine.Config{
LinkMonitor: wgMonitor,
Dialer: dialer,
})
Expand All @@ -172,7 +172,7 @@ func New(logger slog.Logger, addresses []netaddr.IPPrefix) (*Network, error) {

// Create the networking stack.
// This is called to route connections.
netStack, err := netstack.Create(logf, tunDev, engine, magicConn, dialer, dnsManager)
netStack, err := netstack.Create(Logf, tunDev, engine, magicConn, dialer, dnsManager)
if err != nil {
return nil, xerrors.Errorf("create netstack: %w", err)
}
Expand All @@ -192,7 +192,7 @@ func New(logger slog.Logger, addresses []netaddr.IPPrefix) (*Network, error) {
engine = wgengine.NewWatchdog(engine)

// Update the wireguard configuration to allow traffic to flow.
cfg, err := nmcfg.WGCfg(netMap, logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, netMap.SelfNode.StableID)
cfg, err := nmcfg.WGCfg(netMap, Logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, netMap.SelfNode.StableID)
if err != nil {
return nil, xerrors.Errorf("create wgcfg: %w", err)
}
Expand All @@ -216,7 +216,7 @@ func New(logger slog.Logger, addresses []netaddr.IPPrefix) (*Network, error) {

iplb := netaddr.IPSetBuilder{}
ipl, _ := iplb.IPSet()
engine.SetFilter(filter.New(netMap.PacketFilter, ips, ipl, nil, logf))
engine.SetFilter(filter.New(netMap.PacketFilter, ips, ipl, nil, Logf))

wn := &Network{
logger: logger,
Expand Down Expand Up @@ -319,7 +319,7 @@ func (n *Network) AddPeer(handshake Handshake) error {

n.netMap.Peers = peers

cfg, err := nmcfg.WGCfg(n.netMap, logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, tailcfg.StableNodeID("nBBoJZ5CNTRL"))
cfg, err := nmcfg.WGCfg(n.netMap, Logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, tailcfg.StableNodeID("nBBoJZ5CNTRL"))
if err != nil {
return xerrors.Errorf("create wgcfg: %w", err)
}
Expand Down Expand Up @@ -375,6 +375,12 @@ func (n *Network) Listen(network, addr string) (net.Listener, error) {
}

func (n *Network) Close() error {
// Close all listeners.
for _, l := range n.listeners {
_ = l.Close()
}

// Close the Wireguard netstack and engine.
_ = n.Netstack.Close()
n.wgEngine.Close()

Expand Down