@@ -14,29 +14,36 @@ import (
14
14
"github.com/google/uuid"
15
15
"github.com/mattn/go-isatty"
16
16
"github.com/spf13/cobra"
17
+ "golang.org/x/crypto/ssh"
17
18
gossh "golang.org/x/crypto/ssh"
18
19
gosshagent "golang.org/x/crypto/ssh/agent"
19
20
"golang.org/x/term"
20
21
"golang.org/x/xerrors"
22
+ "inet.af/netaddr"
23
+ tslogger "tailscale.com/types/logger"
21
24
25
+ "cdr.dev/slog"
26
+ "cdr.dev/slog/sloggers/sloghuman"
22
27
"github.com/coder/coder/cli/cliflag"
23
28
"github.com/coder/coder/cli/cliui"
24
29
"github.com/coder/coder/coderd/autobuild/notify"
25
30
"github.com/coder/coder/coderd/util/ptr"
26
31
"github.com/coder/coder/codersdk"
27
32
"github.com/coder/coder/cryptorand"
33
+ "github.com/coder/coder/peer/peerwg"
28
34
)
29
35
30
36
var workspacePollInterval = time .Minute
31
37
var autostopNotifyCountdown = []time.Duration {30 * time .Minute }
32
38
33
- func ssh () * cobra.Command {
39
+ func sshCmd () * cobra.Command {
34
40
var (
35
41
stdio bool
36
42
shuffle bool
37
43
forwardAgent bool
38
44
identityAgent string
39
45
wsPollInterval time.Duration
46
+ wireguard bool
40
47
)
41
48
cmd := & cobra.Command {
42
49
Annotations : workspaceCommand ,
@@ -61,7 +68,7 @@ func ssh() *cobra.Command {
61
68
}
62
69
}
63
70
64
- workspace , agent , err := getWorkspaceAndAgent (cmd , client , codersdk .Me , args [0 ], shuffle )
71
+ workspace , workspaceAgent , err := getWorkspaceAndAgent (cmd , client , codersdk .Me , args [0 ], shuffle )
65
72
if err != nil {
66
73
return err
67
74
}
@@ -71,41 +78,111 @@ func ssh() *cobra.Command {
71
78
err = cliui .Agent (cmd .Context (), cmd .ErrOrStderr (), cliui.AgentOptions {
72
79
WorkspaceName : workspace .Name ,
73
80
Fetch : func (ctx context.Context ) (codersdk.WorkspaceAgent , error ) {
74
- return client .WorkspaceAgent (ctx , agent .ID )
81
+ return client .WorkspaceAgent (ctx , workspaceAgent .ID )
75
82
},
76
83
})
77
84
if err != nil {
78
85
return xerrors .Errorf ("await agent: %w" , err )
79
86
}
80
87
81
- conn , err := client .DialWorkspaceAgent (cmd .Context (), agent .ID , nil )
82
- if err != nil {
83
- return err
84
- }
85
- defer conn .Close ()
88
+ var (
89
+ sshClient * gossh.Client
90
+ sshSession * gossh.Session
91
+ )
86
92
87
- stopPolling := tryPollWorkspaceAutostop (cmd .Context (), client , workspace )
88
- defer stopPolling ()
93
+ if ! wireguard {
94
+ conn , err := client .DialWorkspaceAgent (cmd .Context (), workspaceAgent .ID , nil )
95
+ if err != nil {
96
+ return err
97
+ }
98
+ defer conn .Close ()
89
99
90
- if stdio {
91
- rawSSH , err := conn .SSH ()
100
+ stopPolling := tryPollWorkspaceAutostop (cmd .Context (), client , workspace )
101
+ defer stopPolling ()
102
+
103
+ if stdio {
104
+ rawSSH , err := conn .SSH ()
105
+ if err != nil {
106
+ return err
107
+ }
108
+ go func () {
109
+ _ , _ = io .Copy (cmd .OutOrStdout (), rawSSH )
110
+ }()
111
+ _ , _ = io .Copy (rawSSH , cmd .InOrStdin ())
112
+ return nil
113
+ }
114
+
115
+ sshClient , err = conn .SSHClient ()
92
116
if err != nil {
93
117
return err
94
118
}
95
- go func () {
96
- _ , _ = io .Copy (cmd .OutOrStdout (), rawSSH )
97
- }()
98
- _ , _ = io .Copy (rawSSH , cmd .InOrStdin ())
99
- return nil
100
- }
101
- sshClient , err := conn .SSHClient ()
102
- if err != nil {
103
- return err
104
- }
105
119
106
- sshSession , err := sshClient .NewSession ()
107
- if err != nil {
108
- return err
120
+ sshSession , err = sshClient .NewSession ()
121
+ if err != nil {
122
+ return err
123
+ }
124
+ } else {
125
+ // TODO: more granual control of Tailscale logging.
126
+ peerwg .Logf = tslogger .Discard
127
+
128
+ ipv6 := peerwg .UUIDToNetaddr (uuid .New ())
129
+ wgn , err := peerwg .New (
130
+ slog .Make (sloghuman .Sink (os .Stderr )),
131
+ []netaddr.IPPrefix {netaddr .IPPrefixFrom (ipv6 , 128 )},
132
+ )
133
+ if err != nil {
134
+ return xerrors .Errorf ("create wireguard network: %w" , err )
135
+ }
136
+
137
+ err = client .PostWireguardPeer (cmd .Context (), workspace .ID , peerwg.Handshake {
138
+ Recipient : workspaceAgent .ID ,
139
+ NodePublicKey : wgn .NodePrivateKey .Public (),
140
+ DiscoPublicKey : wgn .DiscoPublicKey ,
141
+ IPv6 : ipv6 ,
142
+ })
143
+ if err != nil {
144
+ return xerrors .Errorf ("post wireguard peer: %w" , err )
145
+ }
146
+
147
+ err = wgn .AddPeer (peerwg.Handshake {
148
+ Recipient : workspaceAgent .ID ,
149
+ DiscoPublicKey : workspaceAgent .DiscoPublicKey ,
150
+ NodePublicKey : workspaceAgent .WireguardPublicKey ,
151
+ IPv6 : workspaceAgent .IPv6 .IP (),
152
+ })
153
+ if err != nil {
154
+ return xerrors .Errorf ("add workspace agent as peer: %w" , err )
155
+ }
156
+
157
+ netConn , err := wgn .Netstack .DialContextTCP (cmd .Context (), netaddr .IPPortFrom (workspaceAgent .IPv6 .IP (), 12212 ))
158
+ if err != nil {
159
+ return xerrors .Errorf ("add workspace agent ssh: %w" , err )
160
+ }
161
+
162
+ if stdio {
163
+ go func () {
164
+ _ , _ = io .Copy (cmd .OutOrStdout (), netConn )
165
+ }()
166
+ _ , _ = io .Copy (netConn , cmd .InOrStdin ())
167
+ return nil
168
+ }
169
+
170
+ sshConn , channels , requests , err := ssh .NewClientConn (netConn , "localhost:22" , & ssh.ClientConfig {
171
+ // SSH host validation isn't helpful, because obtaining a peer
172
+ // connection already signifies user-intent to dial a workspace.
173
+ // #nosec
174
+ HostKeyCallback : ssh .InsecureIgnoreHostKey (),
175
+ })
176
+ if err != nil {
177
+ return xerrors .Errorf ("ssh conn: %w" , err )
178
+ }
179
+
180
+ sshClient = ssh .NewClient (sshConn , channels , requests )
181
+
182
+ sshSession , err = sshClient .NewSession ()
183
+ if err != nil {
184
+ return err
185
+ }
109
186
}
110
187
111
188
if identityAgent == "" {
@@ -174,6 +251,7 @@ func ssh() *cobra.Command {
174
251
cliflag .BoolVarP (cmd .Flags (), & forwardAgent , "forward-agent" , "A" , "CODER_SSH_FORWARD_AGENT" , false , "Specifies whether to forward the SSH agent specified in $SSH_AUTH_SOCK" )
175
252
cliflag .StringVarP (cmd .Flags (), & identityAgent , "identity-agent" , "" , "CODER_SSH_IDENTITY_AGENT" , "" , "Specifies which identity agent to use (overrides $SSH_AUTH_SOCK), forward agent must also be enabled" )
176
253
cliflag .DurationVarP (cmd .Flags (), & wsPollInterval , "workspace-poll-interval" , "" , "CODER_WORKSPACE_POLL_INTERVAL" , workspacePollInterval , "Specifies how often to poll for workspace automated shutdown." )
254
+ cliflag .BoolVarP (cmd .Flags (), & wireguard , "wireguard" , "" , "CODER_SSH_WIREGUARD" , true , "Whether to use Wireguard for SSH tunneling." )
177
255
178
256
return cmd
179
257
}
0 commit comments