Skip to content

Commit e080b14

Browse files
committed
feat: add ssh support over wireguard
1 parent 1157303 commit e080b14

File tree

4 files changed

+152
-34
lines changed

4 files changed

+152
-34
lines changed

agent/wireguard.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package agent
22

33
import (
44
"context"
5+
"net"
6+
"strconv"
57

68
"golang.org/x/xerrors"
79
"inet.af/netaddr"
@@ -58,6 +60,38 @@ func (a *agent) startWireguard(ctx context.Context, addrs []netaddr.IPPrefix) er
5860
}
5961
}()
6062

63+
a.startWireguardListeners(ctx, wg, []handlerPort{
64+
{port: 12212, handler: a.sshServer.HandleConn},
65+
})
66+
6167
a.network = wg
6268
return nil
6369
}
70+
71+
type handlerPort struct {
72+
handler func(conn net.Conn)
73+
port uint16
74+
}
75+
76+
func (a *agent) startWireguardListeners(ctx context.Context, network *peerwg.Network, handlers []handlerPort) {
77+
for _, h := range handlers {
78+
go func(h handlerPort) {
79+
a.logger.Debug(ctx, "starting wireguard listener", slog.F("port", h.port))
80+
81+
listener, err := network.Listen("tcp", net.JoinHostPort("", strconv.Itoa(int(h.port))))
82+
if err != nil {
83+
a.logger.Warn(ctx, "listen wireguard", slog.F("port", h.port), slog.Error(err))
84+
return
85+
}
86+
87+
for {
88+
conn, err := listener.Accept()
89+
if err != nil {
90+
return
91+
}
92+
93+
go h.handler(conn)
94+
}
95+
}(h)
96+
}
97+
}

cli/root.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ func Root() *cobra.Command {
8080
schedules(),
8181
server(),
8282
show(),
83-
ssh(),
83+
sshCmd(),
8484
start(),
8585
state(),
8686
stop(),

cli/ssh.go

Lines changed: 103 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,29 +14,36 @@ import (
1414
"github.com/google/uuid"
1515
"github.com/mattn/go-isatty"
1616
"github.com/spf13/cobra"
17+
"golang.org/x/crypto/ssh"
1718
gossh "golang.org/x/crypto/ssh"
1819
gosshagent "golang.org/x/crypto/ssh/agent"
1920
"golang.org/x/term"
2021
"golang.org/x/xerrors"
22+
"inet.af/netaddr"
23+
tslogger "tailscale.com/types/logger"
2124

25+
"cdr.dev/slog"
26+
"cdr.dev/slog/sloggers/sloghuman"
2227
"github.com/coder/coder/cli/cliflag"
2328
"github.com/coder/coder/cli/cliui"
2429
"github.com/coder/coder/coderd/autobuild/notify"
2530
"github.com/coder/coder/coderd/util/ptr"
2631
"github.com/coder/coder/codersdk"
2732
"github.com/coder/coder/cryptorand"
33+
"github.com/coder/coder/peer/peerwg"
2834
)
2935

3036
var workspacePollInterval = time.Minute
3137
var autostopNotifyCountdown = []time.Duration{30 * time.Minute}
3238

33-
func ssh() *cobra.Command {
39+
func sshCmd() *cobra.Command {
3440
var (
3541
stdio bool
3642
shuffle bool
3743
forwardAgent bool
3844
identityAgent string
3945
wsPollInterval time.Duration
46+
wireguard bool
4047
)
4148
cmd := &cobra.Command{
4249
Annotations: workspaceCommand,
@@ -61,7 +68,7 @@ func ssh() *cobra.Command {
6168
}
6269
}
6370

64-
workspace, agent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], shuffle)
71+
workspace, workspaceAgent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], shuffle)
6572
if err != nil {
6673
return err
6774
}
@@ -71,41 +78,111 @@ func ssh() *cobra.Command {
7178
err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{
7279
WorkspaceName: workspace.Name,
7380
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
74-
return client.WorkspaceAgent(ctx, agent.ID)
81+
return client.WorkspaceAgent(ctx, workspaceAgent.ID)
7582
},
7683
})
7784
if err != nil {
7885
return xerrors.Errorf("await agent: %w", err)
7986
}
8087

81-
conn, err := client.DialWorkspaceAgent(cmd.Context(), agent.ID, nil)
82-
if err != nil {
83-
return err
84-
}
85-
defer conn.Close()
88+
var (
89+
sshClient *gossh.Client
90+
sshSession *gossh.Session
91+
)
8692

87-
stopPolling := tryPollWorkspaceAutostop(cmd.Context(), client, workspace)
88-
defer stopPolling()
93+
if !wireguard {
94+
conn, err := client.DialWorkspaceAgent(cmd.Context(), workspaceAgent.ID, nil)
95+
if err != nil {
96+
return err
97+
}
98+
defer conn.Close()
8999

90-
if stdio {
91-
rawSSH, err := conn.SSH()
100+
stopPolling := tryPollWorkspaceAutostop(cmd.Context(), client, workspace)
101+
defer stopPolling()
102+
103+
if stdio {
104+
rawSSH, err := conn.SSH()
105+
if err != nil {
106+
return err
107+
}
108+
go func() {
109+
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH)
110+
}()
111+
_, _ = io.Copy(rawSSH, cmd.InOrStdin())
112+
return nil
113+
}
114+
115+
sshClient, err = conn.SSHClient()
92116
if err != nil {
93117
return err
94118
}
95-
go func() {
96-
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH)
97-
}()
98-
_, _ = io.Copy(rawSSH, cmd.InOrStdin())
99-
return nil
100-
}
101-
sshClient, err := conn.SSHClient()
102-
if err != nil {
103-
return err
104-
}
105119

106-
sshSession, err := sshClient.NewSession()
107-
if err != nil {
108-
return err
120+
sshSession, err = sshClient.NewSession()
121+
if err != nil {
122+
return err
123+
}
124+
} else {
125+
// TODO: more granual control of Tailscale logging.
126+
peerwg.Logf = tslogger.Discard
127+
128+
ipv6 := peerwg.UUIDToNetaddr(uuid.New())
129+
wgn, err := peerwg.New(
130+
slog.Make(sloghuman.Sink(os.Stderr)),
131+
[]netaddr.IPPrefix{netaddr.IPPrefixFrom(ipv6, 128)},
132+
)
133+
if err != nil {
134+
return xerrors.Errorf("create wireguard network: %w", err)
135+
}
136+
137+
err = client.PostWireguardPeer(cmd.Context(), workspace.ID, peerwg.Handshake{
138+
Recipient: workspaceAgent.ID,
139+
NodePublicKey: wgn.NodePrivateKey.Public(),
140+
DiscoPublicKey: wgn.DiscoPublicKey,
141+
IPv6: ipv6,
142+
})
143+
if err != nil {
144+
return xerrors.Errorf("post wireguard peer: %w", err)
145+
}
146+
147+
err = wgn.AddPeer(peerwg.Handshake{
148+
Recipient: workspaceAgent.ID,
149+
DiscoPublicKey: workspaceAgent.DiscoPublicKey,
150+
NodePublicKey: workspaceAgent.WireguardPublicKey,
151+
IPv6: workspaceAgent.IPv6.IP(),
152+
})
153+
if err != nil {
154+
return xerrors.Errorf("add workspace agent as peer: %w", err)
155+
}
156+
157+
netConn, err := wgn.Netstack.DialContextTCP(cmd.Context(), netaddr.IPPortFrom(workspaceAgent.IPv6.IP(), 12212))
158+
if err != nil {
159+
return xerrors.Errorf("add workspace agent ssh: %w", err)
160+
}
161+
162+
if stdio {
163+
go func() {
164+
_, _ = io.Copy(cmd.OutOrStdout(), netConn)
165+
}()
166+
_, _ = io.Copy(netConn, cmd.InOrStdin())
167+
return nil
168+
}
169+
170+
sshConn, channels, requests, err := ssh.NewClientConn(netConn, "localhost:22", &ssh.ClientConfig{
171+
// SSH host validation isn't helpful, because obtaining a peer
172+
// connection already signifies user-intent to dial a workspace.
173+
// #nosec
174+
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
175+
})
176+
if err != nil {
177+
return xerrors.Errorf("ssh conn: %w", err)
178+
}
179+
180+
sshClient = ssh.NewClient(sshConn, channels, requests)
181+
182+
sshSession, err = sshClient.NewSession()
183+
if err != nil {
184+
return err
185+
}
109186
}
110187

111188
if identityAgent == "" {
@@ -174,6 +251,7 @@ func ssh() *cobra.Command {
174251
cliflag.BoolVarP(cmd.Flags(), &forwardAgent, "forward-agent", "A", "CODER_SSH_FORWARD_AGENT", false, "Specifies whether to forward the SSH agent specified in $SSH_AUTH_SOCK")
175252
cliflag.StringVarP(cmd.Flags(), &identityAgent, "identity-agent", "", "CODER_SSH_IDENTITY_AGENT", "", "Specifies which identity agent to use (overrides $SSH_AUTH_SOCK), forward agent must also be enabled")
176253
cliflag.DurationVarP(cmd.Flags(), &wsPollInterval, "workspace-poll-interval", "", "CODER_WORKSPACE_POLL_INTERVAL", workspacePollInterval, "Specifies how often to poll for workspace automated shutdown.")
254+
cliflag.BoolVarP(cmd.Flags(), &wireguard, "wireguard", "", "CODER_SSH_WIREGUARD", true, "Whether to use Wireguard for SSH tunneling.")
177255

178256
return cmd
179257
}

peer/peerwg/wireguard.go

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ import (
3535
"cdr.dev/slog"
3636
)
3737

38-
var logf tslogger.Logf = log.Printf
38+
var Logf tslogger.Logf = log.Printf
3939

4040
func init() {
4141
// Globally disable network namespacing.
@@ -139,15 +139,15 @@ func New(logger slog.Logger, addresses []netaddr.IPPrefix) (*Network, error) {
139139
DERP: DefaultDerpHome,
140140
}
141141

142-
wgMonitor, err := monitor.New(logf)
142+
wgMonitor, err := monitor.New(Logf)
143143
if err != nil {
144144
return nil, xerrors.Errorf("create link monitor: %w", err)
145145
}
146146

147147
dialer := new(tsdial.Dialer)
148-
dialer.Logf = logf
148+
dialer.Logf = Logf
149149
// Create a wireguard engine in userspace.
150-
engine, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{
150+
engine, err := wgengine.NewUserspaceEngine(Logf, wgengine.Config{
151151
LinkMonitor: wgMonitor,
152152
Dialer: dialer,
153153
})
@@ -172,7 +172,7 @@ func New(logger slog.Logger, addresses []netaddr.IPPrefix) (*Network, error) {
172172

173173
// Create the networking stack.
174174
// This is called to route connections.
175-
netStack, err := netstack.Create(logf, tunDev, engine, magicConn, dialer, dnsManager)
175+
netStack, err := netstack.Create(Logf, tunDev, engine, magicConn, dialer, dnsManager)
176176
if err != nil {
177177
return nil, xerrors.Errorf("create netstack: %w", err)
178178
}
@@ -192,7 +192,7 @@ func New(logger slog.Logger, addresses []netaddr.IPPrefix) (*Network, error) {
192192
engine = wgengine.NewWatchdog(engine)
193193

194194
// Update the wireguard configuration to allow traffic to flow.
195-
cfg, err := nmcfg.WGCfg(netMap, logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, netMap.SelfNode.StableID)
195+
cfg, err := nmcfg.WGCfg(netMap, Logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, netMap.SelfNode.StableID)
196196
if err != nil {
197197
return nil, xerrors.Errorf("create wgcfg: %w", err)
198198
}
@@ -216,7 +216,7 @@ func New(logger slog.Logger, addresses []netaddr.IPPrefix) (*Network, error) {
216216

217217
iplb := netaddr.IPSetBuilder{}
218218
ipl, _ := iplb.IPSet()
219-
engine.SetFilter(filter.New(netMap.PacketFilter, ips, ipl, nil, logf))
219+
engine.SetFilter(filter.New(netMap.PacketFilter, ips, ipl, nil, Logf))
220220

221221
wn := &Network{
222222
logger: logger,
@@ -319,7 +319,7 @@ func (n *Network) AddPeer(handshake Handshake) error {
319319

320320
n.netMap.Peers = peers
321321

322-
cfg, err := nmcfg.WGCfg(n.netMap, logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, tailcfg.StableNodeID("nBBoJZ5CNTRL"))
322+
cfg, err := nmcfg.WGCfg(n.netMap, Logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, tailcfg.StableNodeID("nBBoJZ5CNTRL"))
323323
if err != nil {
324324
return xerrors.Errorf("create wgcfg: %w", err)
325325
}
@@ -375,6 +375,12 @@ func (n *Network) Listen(network, addr string) (net.Listener, error) {
375375
}
376376

377377
func (n *Network) Close() error {
378+
// Close all listeners.
379+
for _, l := range n.listeners {
380+
_ = l.Close()
381+
}
382+
383+
// Close the Wireguard netstack and engine.
378384
_ = n.Netstack.Close()
379385
n.wgEngine.Close()
380386

0 commit comments

Comments
 (0)