-
Notifications
You must be signed in to change notification settings - Fork 998
refactor: PTY & SSH #7100
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
refactor: PTY & SSH #7100
Changes from 15 commits
0075d7d
a491d4f
2cf357a
28c0646
872e357
3f21e30
e83ff6e
b610579
90bfe94
d6e131c
2c9c6ef
e39e885
8ec3d1f
df424e6
50e3fec
c09083e
439107d
c6a3229
aa94546
0f07cb9
eaabb3a
50060d8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -191,7 +191,7 @@ func (s *Server) sessionHandler(session ssh.Session) { | |
_ = session.Exit(0) | ||
} | ||
|
||
func (s *Server) sessionStart(session ssh.Session) (retErr error) { | ||
func (s *Server) sessionStart(session ssh.Session) error { | ||
ctx := session.Context() | ||
env := session.Environ() | ||
var magicType string | ||
|
@@ -233,102 +233,12 @@ func (s *Server) sessionStart(session ssh.Session) (retErr error) { | |
|
||
sshPty, windowSize, isPty := session.Pty() | ||
if isPty { | ||
// Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL). | ||
// See https://github.com/coder/coder/issues/3371. | ||
session.DisablePTYEmulation() | ||
|
||
if !isQuietLogin(session.RawCommand()) { | ||
manifest := s.Manifest.Load() | ||
if manifest != nil { | ||
err = showMOTD(session, manifest.MOTDFile) | ||
if err != nil { | ||
s.logger.Error(ctx, "show MOTD", slog.Error(err)) | ||
} | ||
} else { | ||
s.logger.Warn(ctx, "metadata lookup failed, unable to show MOTD") | ||
} | ||
} | ||
|
||
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term)) | ||
|
||
// The pty package sets `SSH_TTY` on supported platforms. | ||
ptty, process, err := pty.Start(cmd, pty.WithPTYOption( | ||
pty.WithSSHRequest(sshPty), | ||
pty.WithLogger(slog.Stdlib(ctx, s.logger, slog.LevelInfo)), | ||
)) | ||
if err != nil { | ||
return xerrors.Errorf("start command: %w", err) | ||
} | ||
var wg sync.WaitGroup | ||
defer func() { | ||
defer wg.Wait() | ||
closeErr := ptty.Close() | ||
if closeErr != nil { | ||
s.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr)) | ||
if retErr == nil { | ||
retErr = closeErr | ||
} | ||
} | ||
}() | ||
go func() { | ||
for win := range windowSize { | ||
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width)) | ||
// If the pty is closed, then command has exited, no need to log. | ||
if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) { | ||
s.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr)) | ||
} | ||
} | ||
}() | ||
// We don't add input copy to wait group because | ||
// it won't return until the session is closed. | ||
go func() { | ||
_, _ = io.Copy(ptty.Input(), session) | ||
}() | ||
|
||
// In low parallelism scenarios, the command may exit and we may close | ||
// the pty before the output copy has started. This can result in the | ||
// output being lost. To avoid this, we wait for the output copy to | ||
// start before waiting for the command to exit. This ensures that the | ||
// output copy goroutine will be scheduled before calling close on the | ||
// pty. This shouldn't be needed because of `pty.Dup()` below, but it | ||
// may not be supported on all platforms. | ||
outputCopyStarted := make(chan struct{}) | ||
ptyOutput := func() io.ReadCloser { | ||
defer close(outputCopyStarted) | ||
// Try to dup so we can separate stdin and stdout closure. | ||
// Once the original pty is closed, the dup will return | ||
// input/output error once the buffered data has been read. | ||
stdout, err := ptty.Dup() | ||
if err == nil { | ||
return stdout | ||
} | ||
// If we can't dup, we shouldn't close | ||
// the fd since it's tied to stdin. | ||
return readNopCloser{ptty.Output()} | ||
} | ||
wg.Add(1) | ||
go func() { | ||
// Ensure data is flushed to session on command exit, if we | ||
// close the session too soon, we might lose data. | ||
defer wg.Done() | ||
|
||
stdout := ptyOutput() | ||
defer stdout.Close() | ||
|
||
_, _ = io.Copy(session, stdout) | ||
}() | ||
<-outputCopyStarted | ||
|
||
err = process.Wait() | ||
var exitErr *exec.ExitError | ||
// ExitErrors just mean the command we run returned a non-zero exit code, which is normal | ||
// and not something to be concerned about. But, if it's something else, we should log it. | ||
if err != nil && !xerrors.As(err, &exitErr) { | ||
s.logger.Warn(ctx, "wait error", slog.Error(err)) | ||
} | ||
return err | ||
return s.startPTYSession(session, cmd, sshPty, windowSize) | ||
} | ||
return startNonPTYSession(session, cmd) | ||
} | ||
|
||
func startNonPTYSession(session ssh.Session, cmd *exec.Cmd) error { | ||
cmd.Stdout = session | ||
cmd.Stderr = session.Stderr() | ||
// This blocks forever until stdin is received if we don't | ||
|
@@ -348,10 +258,94 @@ func (s *Server) sessionStart(session ssh.Session) (retErr error) { | |
return cmd.Wait() | ||
} | ||
|
||
type readNopCloser struct{ io.Reader } | ||
// ptySession is the interface to the ssh.Session that startPTYSession uses | ||
// we use an interface here so that we can fake it in tests. | ||
type ptySession interface { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a reason you chose a stripped down interface vs using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I couldn't see any way to directly instantiate any of the concrete implementations in like My first stab at the test I wanted to do created the whole server like the other ssh tests do, but I found the even after calling There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's a bit annoying that I have to mock out the context methods, but unfortunately, Unfortunately the go type-checker isn't smart enough to understand that |
||
io.ReadWriter | ||
Context() ssh.Context | ||
DisablePTYEmulation() | ||
RawCommand() string | ||
} | ||
|
||
func (s *Server) startPTYSession(session ptySession, cmd *exec.Cmd, sshPty ssh.Pty, windowSize <-chan ssh.Window) (retErr error) { | ||
ctx := session.Context() | ||
// Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL). | ||
// See https://github.com/coder/coder/issues/3371. | ||
session.DisablePTYEmulation() | ||
|
||
if !isQuietLogin(session.RawCommand()) { | ||
manifest := s.Manifest.Load() | ||
if manifest != nil { | ||
err := showMOTD(session, manifest.MOTDFile) | ||
if err != nil { | ||
s.logger.Error(ctx, "show MOTD", slog.Error(err)) | ||
} | ||
} else { | ||
s.logger.Warn(ctx, "metadata lookup failed, unable to show MOTD") | ||
} | ||
} | ||
|
||
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term)) | ||
|
||
// The pty package sets `SSH_TTY` on supported platforms. | ||
ptty, process, err := pty.Start(cmd, pty.WithPTYOption( | ||
pty.WithSSHRequest(sshPty), | ||
pty.WithLogger(slog.Stdlib(ctx, s.logger, slog.LevelInfo)), | ||
)) | ||
if err != nil { | ||
return xerrors.Errorf("start command: %w", err) | ||
} | ||
defer func() { | ||
closeErr := ptty.Close() | ||
if closeErr != nil { | ||
s.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr)) | ||
if retErr == nil { | ||
retErr = closeErr | ||
} | ||
} | ||
}() | ||
go func() { | ||
for win := range windowSize { | ||
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width)) | ||
// If the pty is closed, then command has exited, no need to log. | ||
if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) { | ||
s.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr)) | ||
} | ||
} | ||
}() | ||
|
||
go func() { | ||
_, _ = io.Copy(ptty.InputWriter(), session) | ||
}() | ||
|
||
// Close implements io.Closer. | ||
func (readNopCloser) Close() error { return nil } | ||
// We need to wait for the command output to finish copying. It's safe to | ||
// just do this copy on the main handler goroutine because one of two things | ||
// will happen: | ||
// | ||
// 1. The command completes & closes the TTY, which then triggers an error | ||
// after we've Read() all the buffered data from the PTY. | ||
// 2. The client hangs up, which cancels the command's Context, and go will | ||
// kill the command's process. This then has the same effect as (1). | ||
n, err := io.Copy(session, ptty.OutputReader()) | ||
s.logger.Debug(ctx, "copy output done", slog.F("bytes", n), slog.Error(err)) | ||
if err != nil { | ||
return xerrors.Errorf("copy error: %w", err) | ||
} | ||
// We've gotten all the output, but we need to wait for the process to | ||
// complete so that we can get the exit code. This returns | ||
// immediately if the TTY was closed as part of the command exiting. | ||
err = process.Wait() | ||
var exitErr *exec.ExitError | ||
// ExitErrors just mean the command we run returned a non-zero exit code, which is normal | ||
// and not something to be concerned about. But, if it's something else, we should log it. | ||
if err != nil && !xerrors.As(err, &exitErr) { | ||
s.logger.Warn(ctx, "wait error", slog.Error(err)) | ||
} | ||
if err != nil { | ||
return xerrors.Errorf("process wait: %w", err) | ||
} | ||
return nil | ||
} | ||
|
||
func (s *Server) sftpHandler(session ssh.Session) { | ||
ctx := session.Context() | ||
|
Uh oh!
There was an error while loading. Please reload this page.