@@ -51,6 +51,9 @@ func ssh() *cobra.Command {
51
51
Short : "SSH into a workspace" ,
52
52
Args : cobra .ArbitraryArgs ,
53
53
RunE : func (cmd * cobra.Command , args []string ) error {
54
+ ctx , cancel := context .WithCancel (cmd .Context ())
55
+ defer cancel ()
56
+
54
57
client , err := createClient (cmd )
55
58
if err != nil {
56
59
return err
@@ -68,14 +71,14 @@ func ssh() *cobra.Command {
68
71
}
69
72
}
70
73
71
- workspace , workspaceAgent , err := getWorkspaceAndAgent (cmd , client , codersdk .Me , args [0 ], shuffle )
74
+ workspace , workspaceAgent , err := getWorkspaceAndAgent (ctx , cmd , client , codersdk .Me , args [0 ], shuffle )
72
75
if err != nil {
73
76
return err
74
77
}
75
78
76
79
// OpenSSH passes stderr directly to the calling TTY.
77
80
// This is required in "stdio" mode so a connecting indicator can be displayed.
78
- err = cliui .Agent (cmd . Context () , cmd .ErrOrStderr (), cliui.AgentOptions {
81
+ err = cliui .Agent (ctx , cmd .ErrOrStderr (), cliui.AgentOptions {
79
82
WorkspaceName : workspace .Name ,
80
83
Fetch : func (ctx context.Context ) (codersdk.WorkspaceAgent , error ) {
81
84
return client .WorkspaceAgent (ctx , workspaceAgent .ID )
@@ -85,42 +88,33 @@ func ssh() *cobra.Command {
85
88
return xerrors .Errorf ("await agent: %w" , err )
86
89
}
87
90
88
- var (
89
- sshClient * gossh.Client
90
- sshSession * gossh.Session
91
- )
91
+ var newSSHClient func () (* gossh.Client , error )
92
92
93
93
if ! wireguard {
94
- conn , err := client .DialWorkspaceAgent (cmd . Context () , workspaceAgent .ID , nil )
94
+ conn , err := client .DialWorkspaceAgent (ctx , workspaceAgent .ID , nil )
95
95
if err != nil {
96
96
return err
97
97
}
98
98
defer conn .Close ()
99
99
100
- stopPolling := tryPollWorkspaceAutostop (cmd . Context () , client , workspace )
100
+ stopPolling := tryPollWorkspaceAutostop (ctx , client , workspace )
101
101
defer stopPolling ()
102
102
103
103
if stdio {
104
104
rawSSH , err := conn .SSH ()
105
105
if err != nil {
106
106
return err
107
107
}
108
+ defer rawSSH .Close ()
109
+
108
110
go func () {
109
111
_ , _ = io .Copy (cmd .OutOrStdout (), rawSSH )
110
112
}()
111
113
_ , _ = io .Copy (rawSSH , cmd .InOrStdin ())
112
114
return nil
113
115
}
114
116
115
- sshClient , err = conn .SSHClient ()
116
- if err != nil {
117
- return err
118
- }
119
-
120
- sshSession , err = sshClient .NewSession ()
121
- if err != nil {
122
- return err
123
- }
117
+ newSSHClient = conn .SSHClient
124
118
} else {
125
119
// TODO: more granual control of Tailscale logging.
126
120
peerwg .Logf = tslogger .Discard
@@ -133,8 +127,9 @@ func ssh() *cobra.Command {
133
127
if err != nil {
134
128
return xerrors .Errorf ("create wireguard network: %w" , err )
135
129
}
130
+ defer wgn .Close ()
136
131
137
- err = client .PostWireguardPeer (cmd . Context () , workspace .ID , peerwg.Handshake {
132
+ err = client .PostWireguardPeer (ctx , workspace .ID , peerwg.Handshake {
138
133
Recipient : workspaceAgent .ID ,
139
134
NodePublicKey : wgn .NodePrivateKey .Public (),
140
135
DiscoPublicKey : wgn .DiscoPublicKey ,
@@ -155,10 +150,11 @@ func ssh() *cobra.Command {
155
150
}
156
151
157
152
if stdio {
158
- rawSSH , err := wgn .SSH (cmd . Context () , workspaceAgent .IPv6 .IP ())
153
+ rawSSH , err := wgn .SSH (ctx , workspaceAgent .IPv6 .IP ())
159
154
if err != nil {
160
155
return err
161
156
}
157
+ defer rawSSH .Close ()
162
158
163
159
go func () {
164
160
_ , _ = io .Copy (cmd .OutOrStdout (), rawSSH )
@@ -167,16 +163,29 @@ func ssh() *cobra.Command {
167
163
return nil
168
164
}
169
165
170
- sshClient , err = wgn .SSHClient (cmd .Context (), workspaceAgent .IPv6 .IP ())
171
- if err != nil {
172
- return err
166
+ newSSHClient = func () (* gossh.Client , error ) {
167
+ return wgn .SSHClient (ctx , workspaceAgent .IPv6 .IP ())
173
168
}
169
+ }
174
170
175
- sshSession , err = sshClient .NewSession ()
176
- if err != nil {
177
- return err
178
- }
171
+ sshClient , err := newSSHClient ()
172
+ if err != nil {
173
+ return err
174
+ }
175
+ defer sshClient .Close ()
176
+
177
+ sshSession , err := sshClient .NewSession ()
178
+ if err != nil {
179
+ return err
179
180
}
181
+ defer sshSession .Close ()
182
+
183
+ // Ensure context cancellation is propagated to the
184
+ // SSH session, e.g. to cancel `Wait()` at the end.
185
+ go func () {
186
+ <- ctx .Done ()
187
+ _ = sshSession .Close ()
188
+ }()
180
189
181
190
if identityAgent == "" {
182
191
identityAgent = os .Getenv ("SSH_AUTH_SOCK" )
@@ -203,15 +212,18 @@ func ssh() *cobra.Command {
203
212
_ = term .Restore (int (stdinFile .Fd ()), state )
204
213
}()
205
214
206
- windowChange := listenWindowSize (cmd . Context () )
215
+ windowChange := listenWindowSize (ctx )
207
216
go func () {
208
217
for {
209
218
select {
210
- case <- cmd . Context () .Done ():
219
+ case <- ctx .Done ():
211
220
return
212
221
case <- windowChange :
213
222
}
214
- width , height , _ := term .GetSize (int (stdoutFile .Fd ()))
223
+ width , height , err := term .GetSize (int (stdoutFile .Fd ()))
224
+ if err != nil {
225
+ continue
226
+ }
215
227
_ = sshSession .WindowChange (height , width )
216
228
}
217
229
}()
@@ -231,6 +243,10 @@ func ssh() *cobra.Command {
231
243
return err
232
244
}
233
245
246
+ // Put cancel at the top of the defer stack to initiate
247
+ // shutdown of services.
248
+ defer cancel ()
249
+
234
250
err = sshSession .Wait ()
235
251
if err != nil {
236
252
// If the connection drops unexpectedly, we get an ExitMissingError but no other
@@ -259,16 +275,14 @@ func ssh() *cobra.Command {
259
275
// getWorkspaceAgent returns the workspace and agent selected using either the
260
276
// `<workspace>[.<agent>]` syntax via `in` or picks a random workspace and agent
261
277
// if `shuffle` is true.
262
- func getWorkspaceAndAgent (cmd * cobra.Command , client * codersdk.Client , userID string , in string , shuffle bool ) (codersdk.Workspace , codersdk.WorkspaceAgent , error ) { //nolint:revive
263
- ctx := cmd .Context ()
264
-
278
+ func getWorkspaceAndAgent (ctx context.Context , cmd * cobra.Command , client * codersdk.Client , userID string , in string , shuffle bool ) (codersdk.Workspace , codersdk.WorkspaceAgent , error ) { //nolint:revive
265
279
var (
266
280
workspace codersdk.Workspace
267
281
workspaceParts = strings .Split (in , "." )
268
282
err error
269
283
)
270
284
if shuffle {
271
- workspaces , err := client .Workspaces (cmd . Context () , codersdk.WorkspaceFilter {
285
+ workspaces , err := client .Workspaces (ctx , codersdk.WorkspaceFilter {
272
286
Owner : codersdk .Me ,
273
287
})
274
288
if err != nil {
0 commit comments