Skip to content

fix: pty.Start respects context on Windows too #7373

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 4 additions & 12 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,11 +216,12 @@ func (a *agent) collectMetadata(ctx context.Context, md codersdk.WorkspaceAgentM
// if it can guarantee the clocks are synchronized.
CollectedAt: time.Now(),
}
cmd, err := a.sshServer.CreateCommand(ctx, md.Script, nil)
cmdPty, err := a.sshServer.CreateCommand(ctx, md.Script, nil)
if err != nil {
result.Error = fmt.Sprintf("create cmd: %+v", err)
return result
}
cmd := cmdPty.AsExec()

cmd.Stdout = &out
cmd.Stderr = &out
Expand Down Expand Up @@ -842,10 +843,11 @@ func (a *agent) runScript(ctx context.Context, lifecycle, script string) error {
}()
}

cmd, err := a.sshServer.CreateCommand(ctx, script, nil)
cmdPty, err := a.sshServer.CreateCommand(ctx, script, nil)
if err != nil {
return xerrors.Errorf("create command: %w", err)
}
cmd := cmdPty.AsExec()
cmd.Stdout = writer
cmd.Stderr = writer
err = cmd.Run()
Expand Down Expand Up @@ -1044,16 +1046,6 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
circularBuffer: circularBuffer,
}
a.reconnectingPTYs.Store(msg.ID, rpty)
go func() {
// CommandContext isn't respected for Windows PTYs right now,
// so we need to manually track the lifecycle.
// When the context has been completed either:
// 1. The timeout completed.
// 2. The parent context was canceled.
<-ctx.Done()
logger.Debug(ctx, "context done", slog.Error(ctx.Err()))
_ = process.Kill()
}()
// We don't need to separately monitor for the process exiting.
// When it exits, our ptty.OutputReader() will return EOF after
// reading all process output.
Expand Down
3 changes: 1 addition & 2 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"net/http/httptest"
"net/netip"
"os"
"os/exec"
"os/user"
"path"
"path/filepath"
Expand Down Expand Up @@ -1697,7 +1696,7 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) (*pt
"host",
)
args = append(args, afterArgs...)
cmd := exec.Command("ssh", args...)
cmd := pty.Command("ssh", args...)
return ptytest.Start(t, cmd)
}

Expand Down
8 changes: 4 additions & 4 deletions agent/agentssh/agentssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ func (s *Server) sessionStart(session ssh.Session, extraEnv []string) (retErr er
if isPty {
return s.startPTYSession(session, cmd, sshPty, windowSize)
}
return startNonPTYSession(session, cmd)
return startNonPTYSession(session, cmd.AsExec())
}

func startNonPTYSession(session ssh.Session, cmd *exec.Cmd) error {
Expand Down Expand Up @@ -287,7 +287,7 @@ type ptySession interface {
RawCommand() string
}

func (s *Server) startPTYSession(session ptySession, cmd *exec.Cmd, sshPty ssh.Pty, windowSize <-chan ssh.Window) (retErr error) {
func (s *Server) startPTYSession(session ptySession, cmd *pty.Cmd, sshPty ssh.Pty, windowSize <-chan ssh.Window) (retErr error) {
ctx := session.Context()
// Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL).
// See https://github.com/coder/coder/issues/3371.
Expand Down Expand Up @@ -413,7 +413,7 @@ func (s *Server) sftpHandler(session ssh.Session) {
// CreateCommand processes raw command input with OpenSSH-like behavior.
// If the script provided is empty, it will default to the users shell.
// This injects environment variables specified by the user at launch too.
func (s *Server) CreateCommand(ctx context.Context, script string, env []string) (*exec.Cmd, error) {
func (s *Server) CreateCommand(ctx context.Context, script string, env []string) (*pty.Cmd, error) {
currentUser, err := user.Current()
if err != nil {
return nil, xerrors.Errorf("get current user: %w", err)
Expand Down Expand Up @@ -449,7 +449,7 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
}
}

cmd := exec.CommandContext(ctx, shell, args...)
cmd := pty.CommandContext(ctx, shell, args...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: Perhaps we should discourage the use of non-context commands and rename this to pty.Command? The main reason the stdlib has both is that CommandContext was a later addition.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, but then it wouldn't be a drop-in replacement.

I like keeping the same API as the standard library so people will find it familiar, even though it's tempting to "improve" the API.

cmd.Dir = manifest.Directory

// If the metadata directory doesn't exist, we run the command
Expand Down
4 changes: 2 additions & 2 deletions agent/agentssh/agentssh_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ import (
"context"
"io"
"net"
"os/exec"
"testing"

gliderssh "github.com/gliderlabs/ssh"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/coder/coder/pty"
"github.com/coder/coder/testutil"

"cdr.dev/slog/sloggers/slogtest"
Expand Down Expand Up @@ -52,7 +52,7 @@ func Test_sessionStart_orphan(t *testing.T) {
close(windowSize)
// the command gets the session context so that Go will terminate it when
// the session expires.
cmd := exec.CommandContext(sessionCtx, "sh", "-c", longScript)
cmd := pty.CommandContext(sessionCtx, "sh", "-c", longScript)

done := make(chan struct{})
go func() {
Expand Down
2 changes: 1 addition & 1 deletion cli/ssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ Expire-Date: 0
require.NoError(t, err, "import ownertrust failed: %s", out)

// Start the GPG agent.
agentCmd := exec.CommandContext(ctx, gpgAgentPath, "--no-detach", "--extra-socket", extraSocketPath)
agentCmd := pty.CommandContext(ctx, gpgAgentPath, "--no-detach", "--extra-socket", extraSocketPath)
agentCmd.Env = append(agentCmd.Env, "GNUPGHOME="+gnupgHomeClient)
agentPTY, agentProc, err := pty.Start(agentCmd, pty.WithPTYOption(pty.WithGPGTTY()))
require.NoError(t, err, "launch agent failed")
Expand Down
11 changes: 11 additions & 0 deletions pty/pty_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package pty

import (
"context"
"io"
"os"
"os/exec"
Expand Down Expand Up @@ -214,3 +215,13 @@ func (p *windowsProcess) Wait() error {
func (p *windowsProcess) Kill() error {
// Delegate to the underlying process handle (p.proc) and return its
// error to the caller unchanged.
return p.proc.Kill()
}

// killOnContext blocks until either the process exits on its own (signaled
// via p.cmdDone) or ctx is done, whichever happens first. If ctx finishes
// first, the process is killed; if the process exits first, nothing is done.
func (p *windowsProcess) killOnContext(ctx context.Context) {
	select {
	case <-p.cmdDone:
		// The process already exited on its own; there is nothing to kill.
	case <-ctx.Done():
		// Best effort: this function has no return value to report through,
		// so the error from Kill is deliberately discarded.
		_ = p.Kill()
	}
}
3 changes: 1 addition & 2 deletions pty/ptytest/ptytest.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"context"
"fmt"
"io"
"os/exec"
"runtime"
"strings"
"sync"
Expand Down Expand Up @@ -44,7 +43,7 @@ func New(t *testing.T, opts ...pty.Option) *PTY {

// Start starts a new process asynchronously and returns a PTYCmd and Process.
// It kills the process and PTYCmd upon cleanup.
func Start(t *testing.T, cmd *exec.Cmd, opts ...pty.StartOption) (*PTYCmd, pty.Process) {
func Start(t *testing.T, cmd *pty.Cmd, opts ...pty.StartOption) (*PTYCmd, pty.Process) {
t.Helper()

ptty, ps, err := pty.Start(cmd, opts...)
Expand Down
37 changes: 36 additions & 1 deletion pty/start.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pty

import (
"context"
"os/exec"
)

Expand All @@ -18,8 +19,42 @@ func WithPTYOption(opts ...Option) StartOption {
}
}

// Cmd is a drop-in replacement for exec.Cmd with most of the same API, but
// it exposes the context.Context to our PTY code so that we can still kill the
// process when the Context expires. This is required because on Windows, we don't
// start the command using the `exec` library, so we have to manage the context
// ourselves.
type Cmd struct {
	// Context bounds the lifetime of the process. It may be nil when a Cmd is
	// built as a struct literal; AsExec treats nil as context.Background().
	Context context.Context
	// Path is the program to run. AsExec hands it to exec.CommandContext as
	// the command name.
	Path string
	// Args holds the full argument vector, with the command name at index 0.
	Args []string
	// Env is copied verbatim onto the exec.Cmd produced by AsExec.
	Env []string
	// Dir is the working directory of the command.
	Dir string
}

// CommandContext returns a Cmd that runs name with the given arguments and is
// bound to ctx. It mirrors the signature of exec.CommandContext.
//
// NOTE(review): Env is initialized to an empty, non-nil slice. exec.Cmd only
// inherits the parent environment when Env is nil, so a Cmd created here runs
// with exactly the variables the caller adds afterwards. Confirm this
// difference from exec.CommandContext is intended.
func CommandContext(ctx context.Context, name string, arg ...string) *Cmd {
	return &Cmd{
		Context: ctx,
		Path:    name,
		Args:    append([]string{name}, arg...),
		Env:     make([]string, 0),
	}
}

// Command is like CommandContext, but uses context.Background().
func Command(name string, arg ...string) *Cmd {
	return CommandContext(context.Background(), name, arg...)
}

// AsExec converts c into a new *exec.Cmd carrying the same program, arguments,
// working directory and environment. The receiver is not modified.
//
// A Cmd constructed as a struct literal may have a nil Context or an empty
// Args slice. Both are tolerated here: a nil Context is treated as
// context.Background(), because exec.CommandContext panics on nil, and a
// missing Args slice means "no extra arguments".
func (c *Cmd) AsExec() *exec.Cmd {
	ctx := c.Context
	if ctx == nil {
		ctx = context.Background()
	}

	// Args[0] is the command name, which exec.CommandContext re-adds from
	// Path, so only the remainder is forwarded.
	var args []string
	if len(c.Args) > 1 {
		args = c.Args[1:]
	}

	//nolint: gosec
	execCmd := exec.CommandContext(ctx, c.Path, args...)
	execCmd.Dir = c.Dir
	execCmd.Env = c.Env
	return execCmd
}

// Start starts the command in a TTY. The calling code must not use cmd after passing it to Start,
// and must instead rely on the returned Process to manage the command/process.
func Start(cmd *exec.Cmd, opt ...StartOption) (PTYCmd, Process, error) {
func Start(cmd *Cmd, opt ...StartOption) (PTYCmd, Process, error) {
return startPty(cmd, opt...)
}
32 changes: 18 additions & 14 deletions pty/start_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
package pty

import (
"context"
"fmt"
"os/exec"
"runtime"
"strings"
"syscall"

"golang.org/x/xerrors"
)

func startPty(cmd *exec.Cmd, opt ...StartOption) (retPTY *otherPty, proc Process, err error) {
func startPty(cmdPty *Cmd, opt ...StartOption) (retPTY *otherPty, proc Process, err error) {
var opts startOptions
for _, o := range opt {
o(&opts)
Expand All @@ -23,30 +23,34 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (retPTY *otherPty, proc Process
return nil, nil, xerrors.Errorf("newPty failed: %w", err)
}

origEnv := cmd.Env
origEnv := cmdPty.Env
if opty.opts.sshReq != nil {
cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_TTY=%s", opty.Name()))
cmdPty.Env = append(cmdPty.Env, fmt.Sprintf("SSH_TTY=%s", opty.Name()))
}
if opty.opts.setGPGTTY {
cmd.Env = append(cmd.Env, fmt.Sprintf("GPG_TTY=%s", opty.Name()))
cmdPty.Env = append(cmdPty.Env, fmt.Sprintf("GPG_TTY=%s", opty.Name()))
}
if cmdPty.Context == nil {
cmdPty.Context = context.Background()
}
cmdExec := cmdPty.AsExec()

cmd.SysProcAttr = &syscall.SysProcAttr{
cmdExec.SysProcAttr = &syscall.SysProcAttr{
Setsid: true,
Setctty: true,
}
cmd.Stdout = opty.tty
cmd.Stderr = opty.tty
cmd.Stdin = opty.tty
err = cmd.Start()
cmdExec.Stdout = opty.tty
cmdExec.Stderr = opty.tty
cmdExec.Stdin = opty.tty
err = cmdExec.Start()
if err != nil {
_ = opty.Close()
if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "bad file descriptor") {
// macOS has an obscure issue where the PTY occasionally closes
// before it's used. It's unknown why this is, but creating a new
// TTY resolves it.
cmd.Env = origEnv
return startPty(cmd, opt...)
cmdPty.Env = origEnv
return startPty(cmdPty, opt...)
}
return nil, nil, xerrors.Errorf("start: %w", err)
}
Expand All @@ -64,14 +68,14 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (retPTY *otherPty, proc Process
// confirming this, but I did find a thread of someone else's
// observations: https://developer.apple.com/forums/thread/663632
if err := opty.tty.Close(); err != nil {
_ = cmd.Process.Kill()
_ = cmdExec.Process.Kill()
return nil, nil, xerrors.Errorf("close tty: %w", err)
}
opty.tty = nil // remove so we don't attempt to close it again.
}
oProcess := &otherProcess{
pty: opty.pty,
cmd: cmd,
cmd: cmdExec,
cmdDone: make(chan any),
}
go oProcess.waitInternal()
Expand Down
12 changes: 9 additions & 3 deletions pty/start_other_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func TestStart(t *testing.T) {
t.Parallel()
t.Run("Echo", func(t *testing.T) {
t.Parallel()
pty, ps := ptytest.Start(t, exec.Command("echo", "test"))
pty, ps := ptytest.Start(t, pty.Command("echo", "test"))

pty.ExpectMatch("test")
err := ps.Wait()
Expand All @@ -35,7 +35,7 @@ func TestStart(t *testing.T) {

t.Run("Kill", func(t *testing.T) {
t.Parallel()
pty, ps := ptytest.Start(t, exec.Command("sleep", "30"))
pty, ps := ptytest.Start(t, pty.Command("sleep", "30"))
err := ps.Kill()
assert.NoError(t, err)
err = ps.Wait()
Expand All @@ -54,7 +54,7 @@ func TestStart(t *testing.T) {
Height: 24,
},
}))
pty, ps := ptytest.Start(t, exec.Command("env"), opts)
pty, ps := ptytest.Start(t, pty.Command("env"), opts)
pty.ExpectMatch("SSH_TTY=/dev/")
err := ps.Wait()
require.NoError(t, err)
Expand Down Expand Up @@ -84,3 +84,9 @@ do
echo "$i"
done
`}

// these constants/vars are used by Test_Start_cancel_context

const cmdSleep = "sleep"

var argSleep = []string{"30"}
Loading