@@ -18,13 +18,18 @@ import (
18
18
gosshagent "golang.org/x/crypto/ssh/agent"
19
19
"golang.org/x/term"
20
20
"golang.org/x/xerrors"
21
+ "inet.af/netaddr"
22
+ tslogger "tailscale.com/types/logger"
21
23
24
+ "cdr.dev/slog"
25
+ "cdr.dev/slog/sloggers/sloghuman"
22
26
"github.com/coder/coder/cli/cliflag"
23
27
"github.com/coder/coder/cli/cliui"
24
28
"github.com/coder/coder/coderd/autobuild/notify"
25
29
"github.com/coder/coder/coderd/util/ptr"
26
30
"github.com/coder/coder/codersdk"
27
31
"github.com/coder/coder/cryptorand"
32
+ "github.com/coder/coder/peer/peerwg"
28
33
)
29
34
30
35
var workspacePollInterval = time .Minute
@@ -37,6 +42,7 @@ func ssh() *cobra.Command {
37
42
forwardAgent bool
38
43
identityAgent string
39
44
wsPollInterval time.Duration
45
+ wireguard bool
40
46
)
41
47
cmd := & cobra.Command {
42
48
Annotations : workspaceCommand ,
@@ -61,7 +67,7 @@ func ssh() *cobra.Command {
61
67
}
62
68
}
63
69
64
- workspace , agent , err := getWorkspaceAndAgent (cmd , client , codersdk .Me , args [0 ], shuffle )
70
+ workspace , workspaceAgent , err := getWorkspaceAndAgent (cmd , client , codersdk .Me , args [0 ], shuffle )
65
71
if err != nil {
66
72
return err
67
73
}
@@ -71,41 +77,104 @@ func ssh() *cobra.Command {
71
77
err = cliui .Agent (cmd .Context (), cmd .ErrOrStderr (), cliui.AgentOptions {
72
78
WorkspaceName : workspace .Name ,
73
79
Fetch : func (ctx context.Context ) (codersdk.WorkspaceAgent , error ) {
74
- return client .WorkspaceAgent (ctx , agent .ID )
80
+ return client .WorkspaceAgent (ctx , workspaceAgent .ID )
75
81
},
76
82
})
77
83
if err != nil {
78
84
return xerrors .Errorf ("await agent: %w" , err )
79
85
}
80
86
81
- conn , err := client .DialWorkspaceAgent (cmd .Context (), agent .ID , nil )
82
- if err != nil {
83
- return err
84
- }
85
- defer conn .Close ()
87
+ var (
88
+ sshClient * gossh.Client
89
+ sshSession * gossh.Session
90
+ )
86
91
87
- stopPolling := tryPollWorkspaceAutostop (cmd .Context (), client , workspace )
88
- defer stopPolling ()
92
+ if ! wireguard {
93
+ conn , err := client .DialWorkspaceAgent (cmd .Context (), workspaceAgent .ID , nil )
94
+ if err != nil {
95
+ return err
96
+ }
97
+ defer conn .Close ()
89
98
90
- if stdio {
91
- rawSSH , err := conn .SSH ()
99
+ stopPolling := tryPollWorkspaceAutostop (cmd .Context (), client , workspace )
100
+ defer stopPolling ()
101
+
102
+ if stdio {
103
+ rawSSH , err := conn .SSH ()
104
+ if err != nil {
105
+ return err
106
+ }
107
+ go func () {
108
+ _ , _ = io .Copy (cmd .OutOrStdout (), rawSSH )
109
+ }()
110
+ _ , _ = io .Copy (rawSSH , cmd .InOrStdin ())
111
+ return nil
112
+ }
113
+
114
+ sshClient , err = conn .SSHClient ()
92
115
if err != nil {
93
116
return err
94
117
}
95
- go func () {
96
- _ , _ = io .Copy (cmd .OutOrStdout (), rawSSH )
97
- }()
98
- _ , _ = io .Copy (rawSSH , cmd .InOrStdin ())
99
- return nil
100
- }
101
- sshClient , err := conn .SSHClient ()
102
- if err != nil {
103
- return err
104
- }
105
118
106
- sshSession , err := sshClient .NewSession ()
107
- if err != nil {
108
- return err
119
+ sshSession , err = sshClient .NewSession ()
120
+ if err != nil {
121
+ return err
122
+ }
123
+ } else {
124
+ // TODO: more granual control of Tailscale logging.
125
+ peerwg .Logf = tslogger .Discard
126
+
127
+ ipv6 := peerwg .UUIDToNetaddr (uuid .New ())
128
+ wgn , err := peerwg .New (
129
+ slog .Make (sloghuman .Sink (os .Stderr )),
130
+ []netaddr.IPPrefix {netaddr .IPPrefixFrom (ipv6 , 128 )},
131
+ )
132
+ if err != nil {
133
+ return xerrors .Errorf ("create wireguard network: %w" , err )
134
+ }
135
+
136
+ err = client .PostWireguardPeer (cmd .Context (), workspace .ID , peerwg.Handshake {
137
+ Recipient : workspaceAgent .ID ,
138
+ NodePublicKey : wgn .NodePrivateKey .Public (),
139
+ DiscoPublicKey : wgn .DiscoPublicKey ,
140
+ IPv6 : ipv6 ,
141
+ })
142
+ if err != nil {
143
+ return xerrors .Errorf ("post wireguard peer: %w" , err )
144
+ }
145
+
146
+ err = wgn .AddPeer (peerwg.Handshake {
147
+ Recipient : workspaceAgent .ID ,
148
+ DiscoPublicKey : workspaceAgent .DiscoPublicKey ,
149
+ NodePublicKey : workspaceAgent .WireguardPublicKey ,
150
+ IPv6 : workspaceAgent .IPv6 .IP (),
151
+ })
152
+ if err != nil {
153
+ return xerrors .Errorf ("add workspace agent as peer: %w" , err )
154
+ }
155
+
156
+ if stdio {
157
+ rawSSH , err := wgn .SSH (cmd .Context (), workspaceAgent .IPv6 .IP ())
158
+ if err != nil {
159
+ return err
160
+ }
161
+
162
+ go func () {
163
+ _ , _ = io .Copy (cmd .OutOrStdout (), rawSSH )
164
+ }()
165
+ _ , _ = io .Copy (rawSSH , cmd .InOrStdin ())
166
+ return nil
167
+ }
168
+
169
+ sshClient , err = wgn .SSHClient (cmd .Context (), workspaceAgent .IPv6 .IP ())
170
+ if err != nil {
171
+ return err
172
+ }
173
+
174
+ sshSession , err = sshClient .NewSession ()
175
+ if err != nil {
176
+ return err
177
+ }
109
178
}
110
179
111
180
if identityAgent == "" {
@@ -174,6 +243,7 @@ func ssh() *cobra.Command {
174
243
cliflag .BoolVarP (cmd .Flags (), & forwardAgent , "forward-agent" , "A" , "CODER_SSH_FORWARD_AGENT" , false , "Specifies whether to forward the SSH agent specified in $SSH_AUTH_SOCK" )
175
244
cliflag .StringVarP (cmd .Flags (), & identityAgent , "identity-agent" , "" , "CODER_SSH_IDENTITY_AGENT" , "" , "Specifies which identity agent to use (overrides $SSH_AUTH_SOCK), forward agent must also be enabled" )
176
245
cliflag .DurationVarP (cmd .Flags (), & wsPollInterval , "workspace-poll-interval" , "" , "CODER_WORKSPACE_POLL_INTERVAL" , workspacePollInterval , "Specifies how often to poll for workspace automated shutdown." )
246
+ cliflag .BoolVarP (cmd .Flags (), & wireguard , "wireguard" , "" , "CODER_SSH_WIREGUARD" , true , "Whether to use Wireguard for SSH tunneling." )
177
247
178
248
return cmd
179
249
}
0 commit comments