Skip to content

Commit 9c030a8

Browse files
authored
fix: pty.Start respects context on Windows too (#7373)
* fix: pty.Start respects context on Windows too Signed-off-by: Spike Curtis <spike@coder.com> * Fix windows imports; rename ToExec -> AsExec Signed-off-by: Spike Curtis <spike@coder.com> * Fix import in windows test Signed-off-by: Spike Curtis <spike@coder.com> --------- Signed-off-by: Spike Curtis <spike@coder.com>
1 parent e6931d6 commit 9c030a8

13 files changed

+132
-48
lines changed

agent/agent.go

+4-12
Original file line numberDiff line numberDiff line change
@@ -216,11 +216,12 @@ func (a *agent) collectMetadata(ctx context.Context, md codersdk.WorkspaceAgentM
216216
// if it can guarantee the clocks are synchronized.
217217
CollectedAt: time.Now(),
218218
}
219-
cmd, err := a.sshServer.CreateCommand(ctx, md.Script, nil)
219+
cmdPty, err := a.sshServer.CreateCommand(ctx, md.Script, nil)
220220
if err != nil {
221221
result.Error = fmt.Sprintf("create cmd: %+v", err)
222222
return result
223223
}
224+
cmd := cmdPty.AsExec()
224225

225226
cmd.Stdout = &out
226227
cmd.Stderr = &out
@@ -842,10 +843,11 @@ func (a *agent) runScript(ctx context.Context, lifecycle, script string) error {
842843
}()
843844
}
844845

845-
cmd, err := a.sshServer.CreateCommand(ctx, script, nil)
846+
cmdPty, err := a.sshServer.CreateCommand(ctx, script, nil)
846847
if err != nil {
847848
return xerrors.Errorf("create command: %w", err)
848849
}
850+
cmd := cmdPty.AsExec()
849851
cmd.Stdout = writer
850852
cmd.Stderr = writer
851853
err = cmd.Run()
@@ -1044,16 +1046,6 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
10441046
circularBuffer: circularBuffer,
10451047
}
10461048
a.reconnectingPTYs.Store(msg.ID, rpty)
1047-
go func() {
1048-
// CommandContext isn't respected for Windows PTYs right now,
1049-
// so we need to manually track the lifecycle.
1050-
// When the context has been completed either:
1051-
// 1. The timeout completed.
1052-
// 2. The parent context was canceled.
1053-
<-ctx.Done()
1054-
logger.Debug(ctx, "context done", slog.Error(ctx.Err()))
1055-
_ = process.Kill()
1056-
}()
10571049
// We don't need to separately monitor for the process exiting.
10581050
// When it exits, our ptty.OutputReader() will return EOF after
10591051
// reading all process output.

agent/agent_test.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import (
1212
"net/http/httptest"
1313
"net/netip"
1414
"os"
15-
"os/exec"
1615
"os/user"
1716
"path"
1817
"path/filepath"
@@ -1697,7 +1696,7 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) (*pt
16971696
"host",
16981697
)
16991698
args = append(args, afterArgs...)
1700-
cmd := exec.Command("ssh", args...)
1699+
cmd := pty.Command("ssh", args...)
17011700
return ptytest.Start(t, cmd)
17021701
}
17031702

agent/agentssh/agentssh.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ func (s *Server) sessionStart(session ssh.Session, extraEnv []string) (retErr er
255255
if isPty {
256256
return s.startPTYSession(session, cmd, sshPty, windowSize)
257257
}
258-
return startNonPTYSession(session, cmd)
258+
return startNonPTYSession(session, cmd.AsExec())
259259
}
260260

261261
func startNonPTYSession(session ssh.Session, cmd *exec.Cmd) error {
@@ -287,7 +287,7 @@ type ptySession interface {
287287
RawCommand() string
288288
}
289289

290-
func (s *Server) startPTYSession(session ptySession, cmd *exec.Cmd, sshPty ssh.Pty, windowSize <-chan ssh.Window) (retErr error) {
290+
func (s *Server) startPTYSession(session ptySession, cmd *pty.Cmd, sshPty ssh.Pty, windowSize <-chan ssh.Window) (retErr error) {
291291
ctx := session.Context()
292292
// Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL).
293293
// See https://github.com/coder/coder/issues/3371.
@@ -413,7 +413,7 @@ func (s *Server) sftpHandler(session ssh.Session) {
413413
// CreateCommand processes raw command input with OpenSSH-like behavior.
414414
// If the script provided is empty, it will default to the users shell.
415415
// This injects environment variables specified by the user at launch too.
416-
func (s *Server) CreateCommand(ctx context.Context, script string, env []string) (*exec.Cmd, error) {
416+
func (s *Server) CreateCommand(ctx context.Context, script string, env []string) (*pty.Cmd, error) {
417417
currentUser, err := user.Current()
418418
if err != nil {
419419
return nil, xerrors.Errorf("get current user: %w", err)
@@ -449,7 +449,7 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
449449
}
450450
}
451451

452-
cmd := exec.CommandContext(ctx, shell, args...)
452+
cmd := pty.CommandContext(ctx, shell, args...)
453453
cmd.Dir = manifest.Directory
454454

455455
// If the metadata directory doesn't exist, we run the command

agent/agentssh/agentssh_internal_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@ import (
77
"context"
88
"io"
99
"net"
10-
"os/exec"
1110
"testing"
1211

1312
gliderssh "github.com/gliderlabs/ssh"
1413
"github.com/spf13/afero"
1514
"github.com/stretchr/testify/assert"
1615
"github.com/stretchr/testify/require"
1716

17+
"github.com/coder/coder/pty"
1818
"github.com/coder/coder/testutil"
1919

2020
"cdr.dev/slog/sloggers/slogtest"
@@ -52,7 +52,7 @@ func Test_sessionStart_orphan(t *testing.T) {
5252
close(windowSize)
5353
// the command gets the session context so that Go will terminate it when
5454
// the session expires.
55-
cmd := exec.CommandContext(sessionCtx, "sh", "-c", longScript)
55+
cmd := pty.CommandContext(sessionCtx, "sh", "-c", longScript)
5656

5757
done := make(chan struct{})
5858
go func() {

cli/ssh_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,7 @@ Expire-Date: 0
540540
require.NoError(t, err, "import ownertrust failed: %s", out)
541541

542542
// Start the GPG agent.
543-
agentCmd := exec.CommandContext(ctx, gpgAgentPath, "--no-detach", "--extra-socket", extraSocketPath)
543+
agentCmd := pty.CommandContext(ctx, gpgAgentPath, "--no-detach", "--extra-socket", extraSocketPath)
544544
agentCmd.Env = append(agentCmd.Env, "GNUPGHOME="+gnupgHomeClient)
545545
agentPTY, agentProc, err := pty.Start(agentCmd, pty.WithPTYOption(pty.WithGPGTTY()))
546546
require.NoError(t, err, "launch agent failed")

pty/pty_windows.go

+11
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
package pty
44

55
import (
6+
"context"
67
"io"
78
"os"
89
"os/exec"
@@ -214,3 +215,13 @@ func (p *windowsProcess) Wait() error {
214215
func (p *windowsProcess) Kill() error {
215216
return p.proc.Kill()
216217
}
218+
219+
// killOnContext waits for the context to be done and kills the process, unless it exits on its own first.
220+
func (p *windowsProcess) killOnContext(ctx context.Context) {
221+
select {
222+
case <-p.cmdDone:
223+
return
224+
case <-ctx.Done():
225+
p.Kill()
226+
}
227+
}

pty/ptytest/ptytest.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"context"
77
"fmt"
88
"io"
9-
"os/exec"
109
"runtime"
1110
"strings"
1211
"sync"
@@ -44,7 +43,7 @@ func New(t *testing.T, opts ...pty.Option) *PTY {
4443

4544
// Start starts a new process asynchronously and returns a PTYCmd and Process.
4645
// It kills the process and PTYCmd upon cleanup
47-
func Start(t *testing.T, cmd *exec.Cmd, opts ...pty.StartOption) (*PTYCmd, pty.Process) {
46+
func Start(t *testing.T, cmd *pty.Cmd, opts ...pty.StartOption) (*PTYCmd, pty.Process) {
4847
t.Helper()
4948

5049
ptty, ps, err := pty.Start(cmd, opts...)

pty/start.go

+36-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package pty
22

33
import (
4+
"context"
45
"os/exec"
56
)
67

@@ -18,8 +19,42 @@ func WithPTYOption(opts ...Option) StartOption {
1819
}
1920
}
2021

22+
// Cmd is a drop-in replacement for exec.Cmd with most of the same API, but
23+
// it exposes the context.Context to our PTY code so that we can still kill the
24+
// process when the Context expires. This is required because on Windows, we don't
25+
// start the command using the `exec` library, so we have to manage the context
26+
// ourselves.
27+
type Cmd struct {
28+
Context context.Context
29+
Path string
30+
Args []string
31+
Env []string
32+
Dir string
33+
}
34+
35+
func CommandContext(ctx context.Context, name string, arg ...string) *Cmd {
36+
return &Cmd{
37+
Context: ctx,
38+
Path: name,
39+
Args: append([]string{name}, arg...),
40+
Env: make([]string, 0),
41+
}
42+
}
43+
44+
func Command(name string, arg ...string) *Cmd {
45+
return CommandContext(context.Background(), name, arg...)
46+
}
47+
48+
func (c *Cmd) AsExec() *exec.Cmd {
49+
//nolint: gosec
50+
execCmd := exec.CommandContext(c.Context, c.Path, c.Args[1:]...)
51+
execCmd.Dir = c.Dir
52+
execCmd.Env = c.Env
53+
return execCmd
54+
}
55+
2156
// Start the command in a TTY. The calling code must not use cmd after passing it to the PTY, and
2257
// instead rely on the returned Process to manage the command/process.
23-
func Start(cmd *exec.Cmd, opt ...StartOption) (PTYCmd, Process, error) {
58+
func Start(cmd *Cmd, opt ...StartOption) (PTYCmd, Process, error) {
2459
return startPty(cmd, opt...)
2560
}

pty/start_other.go

+18-14
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@
33
package pty
44

55
import (
6+
"context"
67
"fmt"
7-
"os/exec"
88
"runtime"
99
"strings"
1010
"syscall"
1111

1212
"golang.org/x/xerrors"
1313
)
1414

15-
func startPty(cmd *exec.Cmd, opt ...StartOption) (retPTY *otherPty, proc Process, err error) {
15+
func startPty(cmdPty *Cmd, opt ...StartOption) (retPTY *otherPty, proc Process, err error) {
1616
var opts startOptions
1717
for _, o := range opt {
1818
o(&opts)
@@ -23,30 +23,34 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (retPTY *otherPty, proc Process
2323
return nil, nil, xerrors.Errorf("newPty failed: %w", err)
2424
}
2525

26-
origEnv := cmd.Env
26+
origEnv := cmdPty.Env
2727
if opty.opts.sshReq != nil {
28-
cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_TTY=%s", opty.Name()))
28+
cmdPty.Env = append(cmdPty.Env, fmt.Sprintf("SSH_TTY=%s", opty.Name()))
2929
}
3030
if opty.opts.setGPGTTY {
31-
cmd.Env = append(cmd.Env, fmt.Sprintf("GPG_TTY=%s", opty.Name()))
31+
cmdPty.Env = append(cmdPty.Env, fmt.Sprintf("GPG_TTY=%s", opty.Name()))
3232
}
33+
if cmdPty.Context == nil {
34+
cmdPty.Context = context.Background()
35+
}
36+
cmdExec := cmdPty.AsExec()
3337

34-
cmd.SysProcAttr = &syscall.SysProcAttr{
38+
cmdExec.SysProcAttr = &syscall.SysProcAttr{
3539
Setsid: true,
3640
Setctty: true,
3741
}
38-
cmd.Stdout = opty.tty
39-
cmd.Stderr = opty.tty
40-
cmd.Stdin = opty.tty
41-
err = cmd.Start()
42+
cmdExec.Stdout = opty.tty
43+
cmdExec.Stderr = opty.tty
44+
cmdExec.Stdin = opty.tty
45+
err = cmdExec.Start()
4246
if err != nil {
4347
_ = opty.Close()
4448
if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "bad file descriptor") {
4549
// macOS has an obscure issue where the PTY occasionally closes
4650
// before it's used. It's unknown why this is, but creating a new
4751
// TTY resolves it.
48-
cmd.Env = origEnv
49-
return startPty(cmd, opt...)
52+
cmdPty.Env = origEnv
53+
return startPty(cmdPty, opt...)
5054
}
5155
return nil, nil, xerrors.Errorf("start: %w", err)
5256
}
@@ -64,14 +68,14 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (retPTY *otherPty, proc Process
6468
// confirming this, but I did find a thread of someone else's
6569
// observations: https://developer.apple.com/forums/thread/663632
6670
if err := opty.tty.Close(); err != nil {
67-
_ = cmd.Process.Kill()
71+
_ = cmdExec.Process.Kill()
6872
return nil, nil, xerrors.Errorf("close tty: %w", err)
6973
}
7074
opty.tty = nil // remove so we don't attempt to close it again.
7175
}
7276
oProcess := &otherProcess{
7377
pty: opty.pty,
74-
cmd: cmd,
78+
cmd: cmdExec,
7579
cmdDone: make(chan any),
7680
}
7781
go oProcess.waitInternal()

pty/start_other_test.go

+9-3
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ func TestStart(t *testing.T) {
2424
t.Parallel()
2525
t.Run("Echo", func(t *testing.T) {
2626
t.Parallel()
27-
pty, ps := ptytest.Start(t, exec.Command("echo", "test"))
27+
pty, ps := ptytest.Start(t, pty.Command("echo", "test"))
2828

2929
pty.ExpectMatch("test")
3030
err := ps.Wait()
@@ -35,7 +35,7 @@ func TestStart(t *testing.T) {
3535

3636
t.Run("Kill", func(t *testing.T) {
3737
t.Parallel()
38-
pty, ps := ptytest.Start(t, exec.Command("sleep", "30"))
38+
pty, ps := ptytest.Start(t, pty.Command("sleep", "30"))
3939
err := ps.Kill()
4040
assert.NoError(t, err)
4141
err = ps.Wait()
@@ -54,7 +54,7 @@ func TestStart(t *testing.T) {
5454
Height: 24,
5555
},
5656
}))
57-
pty, ps := ptytest.Start(t, exec.Command("env"), opts)
57+
pty, ps := ptytest.Start(t, pty.Command("env"), opts)
5858
pty.ExpectMatch("SSH_TTY=/dev/")
5959
err := ps.Wait()
6060
require.NoError(t, err)
@@ -84,3 +84,9 @@ do
8484
echo "$i"
8585
done
8686
`}
87+
88+
// these constants/vars are used by Test_Start_cancel_context
89+
90+
const cmdSleep = "sleep"
91+
92+
var argSleep = []string{"30"}

0 commit comments

Comments
 (0)