Skip to content

Commit b610579

Browse files
committed
WIP Windows pty handling
Signed-off-by: Spike Curtis <spike@coder.com>
1 parent e83ff6e commit b610579

File tree

3 files changed

+121
-4
lines changed

3 files changed

+121
-4
lines changed

pty/pty_windows.go

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ type windowsProcess struct {
8989
cmdDone chan any
9090
cmdErr error
9191
proc *os.Process
92+
pw *ptyWindows
9293
}
9394

9495
// Name returns the TTY name on Windows.
@@ -140,9 +141,12 @@ func (p *ptyWindows) Close() error {
140141
}
141142
p.closed = true
142143

143-
ret, _, err := procClosePseudoConsole.Call(uintptr(p.console))
144-
if ret < 0 {
145-
return xerrors.Errorf("close pseudo console: %w", err)
144+
if p.console != windows.InvalidHandle {
145+
ret, _, err := procClosePseudoConsole.Call(uintptr(p.console))
146+
if ret < 0 {
147+
return xerrors.Errorf("close pseudo console: %w", err)
148+
}
149+
p.console = windows.InvalidHandle
146150
}
147151

148152
// We always have these files
@@ -159,6 +163,19 @@ func (p *ptyWindows) Close() error {
159163
}
160164

161165
func (p *windowsProcess) waitInternal() {
166+
defer func() {
167+
// close the pseudoconsole handle when the process exits, if it hasn't already been closed.
168+
p.pw.closeMutex.Lock()
169+
defer p.pw.closeMutex.Unlock()
170+
if p.pw.console != windows.InvalidHandle {
171+
ret, _, err := procClosePseudoConsole.Call(uintptr(p.pw.console))
172+
if ret < 0 {
173+
// not much we can do here...
174+
panic(err)
175+
}
176+
p.pw.console = windows.InvalidHandle
177+
}
178+
}()
162179
defer close(p.cmdDone)
163180
state, err := p.proc.Wait()
164181
if err != nil {

pty/start_windows.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (PTY, Process, error) {
111111
wp := &windowsProcess{
112112
cmdDone: make(chan any),
113113
proc: process,
114+
pw: winPty,
114115
}
115116
go wp.waitInternal()
116117
return pty, wp, nil

pty/start_windows_test.go

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@ package pty_test
66
import (
77
"bytes"
88
"context"
9+
"fmt"
910
"io"
1011
"os/exec"
12+
"strings"
1113
"testing"
1214
"time"
1315

@@ -77,7 +79,7 @@ func Test_Start_copy(t *testing.T) {
7779
case <-ctx.Done():
7880
t.Error("read timed out")
7981
}
80-
assert.Equal(t, "test", b.String())
82+
assert.Contains(t, b.String(), "test")
8183

8284
cmdDone := make(chan error)
8385
go func() {
@@ -91,3 +93,100 @@ func Test_Start_copy(t *testing.T) {
9193
t.Error("cmd.Wait() timed out")
9294
}
9395
}
96+
97+
const countEnd = 1000
98+
99+
func Test_Start_trucation(t *testing.T) {
100+
t.Parallel()
101+
102+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*1000)
103+
defer cancel()
104+
105+
pc, cmd, err := pty.Start(exec.CommandContext(ctx,
106+
"cmd.exe", "/c",
107+
fmt.Sprintf("for /L %%n in (1,1,%d) do @echo %%n", countEnd)))
108+
require.NoError(t, err)
109+
readDone := make(chan struct{})
110+
go func() {
111+
defer close(readDone)
112+
// avoid buffered IO so that we can precisely control how many bytes to read.
113+
n := 1
114+
for n < countEnd-25 {
115+
want := fmt.Sprintf("%d\r\n", n)
116+
// the output also contains virtual terminal sequences
117+
// so just read until we see the number we want.
118+
err := readUntil(ctx, want, pc.OutputReader())
119+
require.NoError(t, err, "want: %s", want)
120+
n++
121+
}
122+
}()
123+
124+
select {
125+
case <-readDone:
126+
// OK!
127+
case <-ctx.Done():
128+
t.Error("read timed out")
129+
}
130+
131+
cmdDone := make(chan error)
132+
go func() {
133+
cmdDone <- cmd.Wait()
134+
}()
135+
136+
select {
137+
case err := <-cmdDone:
138+
require.NoError(t, err)
139+
case <-ctx.Done():
140+
t.Error("cmd.Wait() timed out")
141+
}
142+
143+
// do our final 25 reads, to make sure the output wasn't lost
144+
readDone = make(chan struct{})
145+
go func() {
146+
defer close(readDone)
147+
// avoid buffered IO so that we can precisely control how many bytes to read.
148+
n := countEnd - 25
149+
for n <= countEnd {
150+
want := fmt.Sprintf("%d\r\n", n)
151+
err := readUntil(ctx, want, pc.OutputReader())
152+
if n < countEnd {
153+
require.NoError(t, err, "want: %s", want)
154+
} else {
155+
require.ErrorIs(t, err, io.EOF)
156+
}
157+
n++
158+
}
159+
}()
160+
161+
select {
162+
case <-readDone:
163+
// OK!
164+
case <-ctx.Done():
165+
t.Error("read timed out")
166+
}
167+
}
168+
169+
// readUntil reads one byte at a time until we either see the string we want, or the context expires
170+
func readUntil(ctx context.Context, want string, r io.Reader) error {
171+
got := ""
172+
readErrs := make(chan error, 1)
173+
for {
174+
b := make([]byte, 1)
175+
go func() {
176+
_, err := r.Read(b)
177+
readErrs <- err
178+
}()
179+
select {
180+
case err := <-readErrs:
181+
if err != nil {
182+
return err
183+
}
184+
got = got + string(b)
185+
case <-ctx.Done():
186+
return ctx.Err()
187+
}
188+
if strings.Contains(got, want) {
189+
return nil
190+
}
191+
}
192+
}

0 commit comments

Comments
 (0)