Skip to content

Commit 6aed58f

Browse files
authored
feat: add ssh support over wireguard (coder#2642)
1 parent 26e85b0 commit 6aed58f

File tree

5 files changed

+189
-33
lines changed

5 files changed

+189
-33
lines changed

agent/wireguard.go

+34
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package agent
22

33
import (
44
"context"
5+
"net"
6+
"strconv"
57

68
"golang.org/x/xerrors"
79
"inet.af/netaddr"
@@ -58,6 +60,38 @@ func (a *agent) startWireguard(ctx context.Context, addrs []netaddr.IPPrefix) er
5860
}
5961
}()
6062

63+
a.startWireguardListeners(ctx, wg, []handlerPort{
64+
{port: 12212, handler: a.sshServer.HandleConn},
65+
})
66+
6167
a.network = wg
6268
return nil
6369
}
70+
71+
type handlerPort struct {
72+
handler func(conn net.Conn)
73+
port uint16
74+
}
75+
76+
func (a *agent) startWireguardListeners(ctx context.Context, network *peerwg.Network, handlers []handlerPort) {
77+
for _, h := range handlers {
78+
go func(h handlerPort) {
79+
a.logger.Debug(ctx, "starting wireguard listener", slog.F("port", h.port))
80+
81+
listener, err := network.Listen("tcp", net.JoinHostPort("", strconv.Itoa(int(h.port))))
82+
if err != nil {
83+
a.logger.Warn(ctx, "listen wireguard", slog.F("port", h.port), slog.Error(err))
84+
return
85+
}
86+
87+
for {
88+
conn, err := listener.Accept()
89+
if err != nil {
90+
return
91+
}
92+
93+
go h.handler(conn)
94+
}
95+
}(h)
96+
}
97+
}

cli/configssh.go

+8-1
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ func configSSH() *cobra.Command {
135135
coderConfigFile string
136136
dryRun bool
137137
skipProxyCommand bool
138+
wireguard bool
138139
)
139140
cmd := &cobra.Command{
140141
Annotations: workspaceCommand,
@@ -287,7 +288,11 @@ func configSSH() *cobra.Command {
287288
"\tLogLevel ERROR",
288289
)
289290
if !skipProxyCommand {
290-
configOptions = append(configOptions, fmt.Sprintf("\tProxyCommand %q --global-config %q ssh --stdio %s", binaryFile, root, hostname))
291+
if !wireguard {
292+
configOptions = append(configOptions, fmt.Sprintf("\tProxyCommand %q --global-config %q ssh --stdio %s", binaryFile, root, hostname))
293+
} else {
294+
configOptions = append(configOptions, fmt.Sprintf("\tProxyCommand %q --global-config %q ssh --wireguard --stdio %s", binaryFile, root, hostname))
295+
}
291296
}
292297

293298
_, _ = buf.WriteString(strings.Join(configOptions, "\n"))
@@ -374,6 +379,8 @@ func configSSH() *cobra.Command {
374379
cmd.Flags().BoolVarP(&skipProxyCommand, "skip-proxy-command", "", false, "Specifies whether the ProxyCommand option should be skipped. Useful for testing.")
375380
_ = cmd.Flags().MarkHidden("skip-proxy-command")
376381
cliflag.BoolVarP(cmd.Flags(), &usePreviousOpts, "use-previous-options", "", "CODER_SSH_USE_PREVIOUS_OPTIONS", false, "Specifies whether or not to keep options from previous run of config-ssh.")
382+
cliflag.BoolVarP(cmd.Flags(), &wireguard, "wireguard", "", "CODER_CONFIG_SSH_WIREGUARD", false, "Whether to use Wireguard for SSH tunneling.")
383+
_ = cmd.Flags().MarkHidden("wireguard")
377384

378385
// Deprecated: Remove after migration period.
379386
cmd.Flags().StringVar(&coderConfigFile, "test.ssh-coder-config-file", sshDefaultCoderConfigFileName, "Specifies the path to an Coder SSH config file. Useful for testing.")

cli/ssh.go

+95-24
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,18 @@ import (
1818
gosshagent "golang.org/x/crypto/ssh/agent"
1919
"golang.org/x/term"
2020
"golang.org/x/xerrors"
21+
"inet.af/netaddr"
22+
tslogger "tailscale.com/types/logger"
2123

24+
"cdr.dev/slog"
25+
"cdr.dev/slog/sloggers/sloghuman"
2226
"github.com/coder/coder/cli/cliflag"
2327
"github.com/coder/coder/cli/cliui"
2428
"github.com/coder/coder/coderd/autobuild/notify"
2529
"github.com/coder/coder/coderd/util/ptr"
2630
"github.com/coder/coder/codersdk"
2731
"github.com/coder/coder/cryptorand"
32+
"github.com/coder/coder/peer/peerwg"
2833
)
2934

3035
var workspacePollInterval = time.Minute
@@ -37,6 +42,7 @@ func ssh() *cobra.Command {
3742
forwardAgent bool
3843
identityAgent string
3944
wsPollInterval time.Duration
45+
wireguard bool
4046
)
4147
cmd := &cobra.Command{
4248
Annotations: workspaceCommand,
@@ -61,7 +67,7 @@ func ssh() *cobra.Command {
6167
}
6268
}
6369

64-
workspace, agent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], shuffle)
70+
workspace, workspaceAgent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], shuffle)
6571
if err != nil {
6672
return err
6773
}
@@ -71,41 +77,104 @@ func ssh() *cobra.Command {
7177
err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{
7278
WorkspaceName: workspace.Name,
7379
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
74-
return client.WorkspaceAgent(ctx, agent.ID)
80+
return client.WorkspaceAgent(ctx, workspaceAgent.ID)
7581
},
7682
})
7783
if err != nil {
7884
return xerrors.Errorf("await agent: %w", err)
7985
}
8086

81-
conn, err := client.DialWorkspaceAgent(cmd.Context(), agent.ID, nil)
82-
if err != nil {
83-
return err
84-
}
85-
defer conn.Close()
87+
var (
88+
sshClient *gossh.Client
89+
sshSession *gossh.Session
90+
)
8691

87-
stopPolling := tryPollWorkspaceAutostop(cmd.Context(), client, workspace)
88-
defer stopPolling()
92+
if !wireguard {
93+
conn, err := client.DialWorkspaceAgent(cmd.Context(), workspaceAgent.ID, nil)
94+
if err != nil {
95+
return err
96+
}
97+
defer conn.Close()
8998

90-
if stdio {
91-
rawSSH, err := conn.SSH()
99+
stopPolling := tryPollWorkspaceAutostop(cmd.Context(), client, workspace)
100+
defer stopPolling()
101+
102+
if stdio {
103+
rawSSH, err := conn.SSH()
104+
if err != nil {
105+
return err
106+
}
107+
go func() {
108+
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH)
109+
}()
110+
_, _ = io.Copy(rawSSH, cmd.InOrStdin())
111+
return nil
112+
}
113+
114+
sshClient, err = conn.SSHClient()
92115
if err != nil {
93116
return err
94117
}
95-
go func() {
96-
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH)
97-
}()
98-
_, _ = io.Copy(rawSSH, cmd.InOrStdin())
99-
return nil
100-
}
101-
sshClient, err := conn.SSHClient()
102-
if err != nil {
103-
return err
104-
}
105118

106-
sshSession, err := sshClient.NewSession()
107-
if err != nil {
108-
return err
119+
sshSession, err = sshClient.NewSession()
120+
if err != nil {
121+
return err
122+
}
123+
} else {
124+
// TODO: more granual control of Tailscale logging.
125+
peerwg.Logf = tslogger.Discard
126+
127+
ipv6 := peerwg.UUIDToNetaddr(uuid.New())
128+
wgn, err := peerwg.New(
129+
slog.Make(sloghuman.Sink(os.Stderr)),
130+
[]netaddr.IPPrefix{netaddr.IPPrefixFrom(ipv6, 128)},
131+
)
132+
if err != nil {
133+
return xerrors.Errorf("create wireguard network: %w", err)
134+
}
135+
136+
err = client.PostWireguardPeer(cmd.Context(), workspace.ID, peerwg.Handshake{
137+
Recipient: workspaceAgent.ID,
138+
NodePublicKey: wgn.NodePrivateKey.Public(),
139+
DiscoPublicKey: wgn.DiscoPublicKey,
140+
IPv6: ipv6,
141+
})
142+
if err != nil {
143+
return xerrors.Errorf("post wireguard peer: %w", err)
144+
}
145+
146+
err = wgn.AddPeer(peerwg.Handshake{
147+
Recipient: workspaceAgent.ID,
148+
DiscoPublicKey: workspaceAgent.DiscoPublicKey,
149+
NodePublicKey: workspaceAgent.WireguardPublicKey,
150+
IPv6: workspaceAgent.IPv6.IP(),
151+
})
152+
if err != nil {
153+
return xerrors.Errorf("add workspace agent as peer: %w", err)
154+
}
155+
156+
if stdio {
157+
rawSSH, err := wgn.SSH(cmd.Context(), workspaceAgent.IPv6.IP())
158+
if err != nil {
159+
return err
160+
}
161+
162+
go func() {
163+
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH)
164+
}()
165+
_, _ = io.Copy(rawSSH, cmd.InOrStdin())
166+
return nil
167+
}
168+
169+
sshClient, err = wgn.SSHClient(cmd.Context(), workspaceAgent.IPv6.IP())
170+
if err != nil {
171+
return err
172+
}
173+
174+
sshSession, err = sshClient.NewSession()
175+
if err != nil {
176+
return err
177+
}
109178
}
110179

111180
if identityAgent == "" {
@@ -174,6 +243,8 @@ func ssh() *cobra.Command {
174243
cliflag.BoolVarP(cmd.Flags(), &forwardAgent, "forward-agent", "A", "CODER_SSH_FORWARD_AGENT", false, "Specifies whether to forward the SSH agent specified in $SSH_AUTH_SOCK")
175244
cliflag.StringVarP(cmd.Flags(), &identityAgent, "identity-agent", "", "CODER_SSH_IDENTITY_AGENT", "", "Specifies which identity agent to use (overrides $SSH_AUTH_SOCK), forward agent must also be enabled")
176245
cliflag.DurationVarP(cmd.Flags(), &wsPollInterval, "workspace-poll-interval", "", "CODER_WORKSPACE_POLL_INTERVAL", workspacePollInterval, "Specifies how often to poll for workspace automated shutdown.")
246+
cliflag.BoolVarP(cmd.Flags(), &wireguard, "wireguard", "", "CODER_SSH_WIREGUARD", false, "Whether to use Wireguard for SSH tunneling.")
247+
_ = cmd.Flags().MarkHidden("wireguard")
177248

178249
return cmd
179250
}

peer/peerwg/ssh.go

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package peerwg
2+
3+
import (
4+
"context"
5+
"net"
6+
7+
"golang.org/x/crypto/ssh"
8+
"golang.org/x/xerrors"
9+
"inet.af/netaddr"
10+
)
11+
12+
func (n *Network) SSH(ctx context.Context, ip netaddr.IP) (net.Conn, error) {
13+
netConn, err := n.Netstack.DialContextTCP(ctx, netaddr.IPPortFrom(ip, 12212))
14+
if err != nil {
15+
return nil, xerrors.Errorf("dial agent ssh: %w", err)
16+
}
17+
18+
return netConn, nil
19+
}
20+
21+
func (n *Network) SSHClient(ctx context.Context, ip netaddr.IP) (*ssh.Client, error) {
22+
netConn, err := n.SSH(ctx, ip)
23+
if err != nil {
24+
return nil, xerrors.Errorf("ssh: %w", err)
25+
}
26+
27+
sshConn, channels, requests, err := ssh.NewClientConn(netConn, "localhost:22", &ssh.ClientConfig{
28+
// SSH host validation isn't helpful, because obtaining a peer
29+
// connection already signifies user-intent to dial a workspace.
30+
// #nosec
31+
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
32+
})
33+
if err != nil {
34+
return nil, xerrors.Errorf("new ssh client conn: %w", err)
35+
}
36+
37+
return ssh.NewClient(sshConn, channels, requests), nil
38+
}

peer/peerwg/wireguard.go

+14-8
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ import (
3535
"cdr.dev/slog"
3636
)
3737

38-
var logf tslogger.Logf = log.Printf
38+
var Logf tslogger.Logf = log.Printf
3939

4040
func init() {
4141
// Globally disable network namespacing.
@@ -139,15 +139,15 @@ func New(logger slog.Logger, addresses []netaddr.IPPrefix) (*Network, error) {
139139
DERP: DefaultDerpHome,
140140
}
141141

142-
wgMonitor, err := monitor.New(logf)
142+
wgMonitor, err := monitor.New(Logf)
143143
if err != nil {
144144
return nil, xerrors.Errorf("create link monitor: %w", err)
145145
}
146146

147147
dialer := new(tsdial.Dialer)
148-
dialer.Logf = logf
148+
dialer.Logf = Logf
149149
// Create a wireguard engine in userspace.
150-
engine, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{
150+
engine, err := wgengine.NewUserspaceEngine(Logf, wgengine.Config{
151151
LinkMonitor: wgMonitor,
152152
Dialer: dialer,
153153
})
@@ -172,7 +172,7 @@ func New(logger slog.Logger, addresses []netaddr.IPPrefix) (*Network, error) {
172172

173173
// Create the networking stack.
174174
// This is called to route connections.
175-
netStack, err := netstack.Create(logf, tunDev, engine, magicConn, dialer, dnsManager)
175+
netStack, err := netstack.Create(Logf, tunDev, engine, magicConn, dialer, dnsManager)
176176
if err != nil {
177177
return nil, xerrors.Errorf("create netstack: %w", err)
178178
}
@@ -192,7 +192,7 @@ func New(logger slog.Logger, addresses []netaddr.IPPrefix) (*Network, error) {
192192
engine = wgengine.NewWatchdog(engine)
193193

194194
// Update the wireguard configuration to allow traffic to flow.
195-
cfg, err := nmcfg.WGCfg(netMap, logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, netMap.SelfNode.StableID)
195+
cfg, err := nmcfg.WGCfg(netMap, Logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, netMap.SelfNode.StableID)
196196
if err != nil {
197197
return nil, xerrors.Errorf("create wgcfg: %w", err)
198198
}
@@ -216,7 +216,7 @@ func New(logger slog.Logger, addresses []netaddr.IPPrefix) (*Network, error) {
216216

217217
iplb := netaddr.IPSetBuilder{}
218218
ipl, _ := iplb.IPSet()
219-
engine.SetFilter(filter.New(netMap.PacketFilter, ips, ipl, nil, logf))
219+
engine.SetFilter(filter.New(netMap.PacketFilter, ips, ipl, nil, Logf))
220220

221221
wn := &Network{
222222
logger: logger,
@@ -319,7 +319,7 @@ func (n *Network) AddPeer(handshake Handshake) error {
319319

320320
n.netMap.Peers = peers
321321

322-
cfg, err := nmcfg.WGCfg(n.netMap, logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, tailcfg.StableNodeID("nBBoJZ5CNTRL"))
322+
cfg, err := nmcfg.WGCfg(n.netMap, Logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, tailcfg.StableNodeID("nBBoJZ5CNTRL"))
323323
if err != nil {
324324
return xerrors.Errorf("create wgcfg: %w", err)
325325
}
@@ -375,6 +375,12 @@ func (n *Network) Listen(network, addr string) (net.Listener, error) {
375375
}
376376

377377
func (n *Network) Close() error {
378+
// Close all listeners.
379+
for _, l := range n.listeners {
380+
_ = l.Close()
381+
}
382+
383+
// Close the Wireguard netstack and engine.
378384
_ = n.Netstack.Close()
379385
n.wgEngine.Close()
380386

0 commit comments

Comments
 (0)