@@ -253,102 +253,12 @@ func (s *Server) sessionStart(session ssh.Session, extraEnv []string) (retErr er
253
253
254
254
sshPty , windowSize , isPty := session .Pty ()
255
255
if isPty {
256
- // Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL).
257
- // See https://github.com/coder/coder/issues/3371.
258
- session .DisablePTYEmulation ()
259
-
260
- if ! isQuietLogin (session .RawCommand ()) {
261
- manifest := s .Manifest .Load ()
262
- if manifest != nil {
263
- err = showMOTD (session , manifest .MOTDFile )
264
- if err != nil {
265
- s .logger .Error (ctx , "show MOTD" , slog .Error (err ))
266
- }
267
- } else {
268
- s .logger .Warn (ctx , "metadata lookup failed, unable to show MOTD" )
269
- }
270
- }
271
-
272
- cmd .Env = append (cmd .Env , fmt .Sprintf ("TERM=%s" , sshPty .Term ))
273
-
274
- // The pty package sets `SSH_TTY` on supported platforms.
275
- ptty , process , err := pty .Start (cmd , pty .WithPTYOption (
276
- pty .WithSSHRequest (sshPty ),
277
- pty .WithLogger (slog .Stdlib (ctx , s .logger , slog .LevelInfo )),
278
- ))
279
- if err != nil {
280
- return xerrors .Errorf ("start command: %w" , err )
281
- }
282
- var wg sync.WaitGroup
283
- defer func () {
284
- defer wg .Wait ()
285
- closeErr := ptty .Close ()
286
- if closeErr != nil {
287
- s .logger .Warn (ctx , "failed to close tty" , slog .Error (closeErr ))
288
- if retErr == nil {
289
- retErr = closeErr
290
- }
291
- }
292
- }()
293
- go func () {
294
- for win := range windowSize {
295
- resizeErr := ptty .Resize (uint16 (win .Height ), uint16 (win .Width ))
296
- // If the pty is closed, then command has exited, no need to log.
297
- if resizeErr != nil && ! errors .Is (resizeErr , pty .ErrClosed ) {
298
- s .logger .Warn (ctx , "failed to resize tty" , slog .Error (resizeErr ))
299
- }
300
- }
301
- }()
302
- // We don't add input copy to wait group because
303
- // it won't return until the session is closed.
304
- go func () {
305
- _ , _ = io .Copy (ptty .Input (), session )
306
- }()
307
-
308
- // In low parallelism scenarios, the command may exit and we may close
309
- // the pty before the output copy has started. This can result in the
310
- // output being lost. To avoid this, we wait for the output copy to
311
- // start before waiting for the command to exit. This ensures that the
312
- // output copy goroutine will be scheduled before calling close on the
313
- // pty. This shouldn't be needed because of `pty.Dup()` below, but it
314
- // may not be supported on all platforms.
315
- outputCopyStarted := make (chan struct {})
316
- ptyOutput := func () io.ReadCloser {
317
- defer close (outputCopyStarted )
318
- // Try to dup so we can separate stdin and stdout closure.
319
- // Once the original pty is closed, the dup will return
320
- // input/output error once the buffered data has been read.
321
- stdout , err := ptty .Dup ()
322
- if err == nil {
323
- return stdout
324
- }
325
- // If we can't dup, we shouldn't close
326
- // the fd since it's tied to stdin.
327
- return readNopCloser {ptty .Output ()}
328
- }
329
- wg .Add (1 )
330
- go func () {
331
- // Ensure data is flushed to session on command exit, if we
332
- // close the session too soon, we might lose data.
333
- defer wg .Done ()
334
-
335
- stdout := ptyOutput ()
336
- defer stdout .Close ()
337
-
338
- _ , _ = io .Copy (session , stdout )
339
- }()
340
- <- outputCopyStarted
341
-
342
- err = process .Wait ()
343
- var exitErr * exec.ExitError
344
- // ExitErrors just mean the command we run returned a non-zero exit code, which is normal
345
- // and not something to be concerned about. But, if it's something else, we should log it.
346
- if err != nil && ! xerrors .As (err , & exitErr ) {
347
- s .logger .Warn (ctx , "wait error" , slog .Error (err ))
348
- }
349
- return err
256
+ return s .startPTYSession (session , cmd , sshPty , windowSize )
350
257
}
258
+ return startNonPTYSession (session , cmd )
259
+ }
351
260
261
+ func startNonPTYSession (session ssh.Session , cmd * exec.Cmd ) error {
352
262
cmd .Stdout = session
353
263
cmd .Stderr = session .Stderr ()
354
264
// This blocks forever until stdin is received if we don't
@@ -368,10 +278,94 @@ func (s *Server) sessionStart(session ssh.Session, extraEnv []string) (retErr er
368
278
return cmd .Wait ()
369
279
}
370
280
371
- type readNopCloser struct { io.Reader }
281
+ // ptySession is the interface to the ssh.Session that startPTYSession uses
282
+ // we use an interface here so that we can fake it in tests.
283
+ type ptySession interface {
284
+ io.ReadWriter
285
+ Context () ssh.Context
286
+ DisablePTYEmulation ()
287
+ RawCommand () string
288
+ }
289
+
290
+ func (s * Server ) startPTYSession (session ptySession , cmd * exec.Cmd , sshPty ssh.Pty , windowSize <- chan ssh.Window ) (retErr error ) {
291
+ ctx := session .Context ()
292
+ // Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL).
293
+ // See https://github.com/coder/coder/issues/3371.
294
+ session .DisablePTYEmulation ()
295
+
296
+ if ! isQuietLogin (session .RawCommand ()) {
297
+ manifest := s .Manifest .Load ()
298
+ if manifest != nil {
299
+ err := showMOTD (session , manifest .MOTDFile )
300
+ if err != nil {
301
+ s .logger .Error (ctx , "show MOTD" , slog .Error (err ))
302
+ }
303
+ } else {
304
+ s .logger .Warn (ctx , "metadata lookup failed, unable to show MOTD" )
305
+ }
306
+ }
307
+
308
+ cmd .Env = append (cmd .Env , fmt .Sprintf ("TERM=%s" , sshPty .Term ))
309
+
310
+ // The pty package sets `SSH_TTY` on supported platforms.
311
+ ptty , process , err := pty .Start (cmd , pty .WithPTYOption (
312
+ pty .WithSSHRequest (sshPty ),
313
+ pty .WithLogger (slog .Stdlib (ctx , s .logger , slog .LevelInfo )),
314
+ ))
315
+ if err != nil {
316
+ return xerrors .Errorf ("start command: %w" , err )
317
+ }
318
+ defer func () {
319
+ closeErr := ptty .Close ()
320
+ if closeErr != nil {
321
+ s .logger .Warn (ctx , "failed to close tty" , slog .Error (closeErr ))
322
+ if retErr == nil {
323
+ retErr = closeErr
324
+ }
325
+ }
326
+ }()
327
+ go func () {
328
+ for win := range windowSize {
329
+ resizeErr := ptty .Resize (uint16 (win .Height ), uint16 (win .Width ))
330
+ // If the pty is closed, then command has exited, no need to log.
331
+ if resizeErr != nil && ! errors .Is (resizeErr , pty .ErrClosed ) {
332
+ s .logger .Warn (ctx , "failed to resize tty" , slog .Error (resizeErr ))
333
+ }
334
+ }
335
+ }()
336
+
337
+ go func () {
338
+ _ , _ = io .Copy (ptty .InputWriter (), session )
339
+ }()
372
340
373
- // Close implements io.Closer.
374
- func (readNopCloser ) Close () error { return nil }
341
+ // We need to wait for the command output to finish copying. It's safe to
342
+ // just do this copy on the main handler goroutine because one of two things
343
+ // will happen:
344
+ //
345
+ // 1. The command completes & closes the TTY, which then triggers an error
346
+ // after we've Read() all the buffered data from the PTY.
347
+ // 2. The client hangs up, which cancels the command's Context, and go will
348
+ // kill the command's process. This then has the same effect as (1).
349
+ n , err := io .Copy (session , ptty .OutputReader ())
350
+ s .logger .Debug (ctx , "copy output done" , slog .F ("bytes" , n ), slog .Error (err ))
351
+ if err != nil {
352
+ return xerrors .Errorf ("copy error: %w" , err )
353
+ }
354
+ // We've gotten all the output, but we need to wait for the process to
355
+ // complete so that we can get the exit code. This returns
356
+ // immediately if the TTY was closed as part of the command exiting.
357
+ err = process .Wait ()
358
+ var exitErr * exec.ExitError
359
+ // ExitErrors just mean the command we run returned a non-zero exit code, which is normal
360
+ // and not something to be concerned about. But, if it's something else, we should log it.
361
+ if err != nil && ! xerrors .As (err , & exitErr ) {
362
+ s .logger .Warn (ctx , "wait error" , slog .Error (err ))
363
+ }
364
+ if err != nil {
365
+ return xerrors .Errorf ("process wait: %w" , err )
366
+ }
367
+ return nil
368
+ }
375
369
376
370
func (s * Server ) sftpHandler (session ssh.Session ) {
377
371
ctx := session .Context ()
0 commit comments