Skip to content

Commit 88c35d3

Browse files
authored
fix(pty): close output writer before reader on Windows to unblock close (#8299)
1 parent 59246e0 commit 88c35d3

File tree

2 files changed

+74
-27
lines changed

2 files changed

+74
-27
lines changed

pty/pty_windows.go

Lines changed: 60 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,13 @@ func newPty(opt ...Option) (*ptyWindows, error) {
6565
0,
6666
uintptr(unsafe.Pointer(&pty.console)),
6767
)
68-
if int32(ret) < 0 {
68+
// CreatePseudoConsole returns S_OK on success, as per:
69+
// https://learn.microsoft.com/en-us/windows/console/createpseudoconsole
70+
if windows.Handle(ret) != windows.S_OK {
6971
_ = pty.Close()
7072
return nil, xerrors.Errorf("create pseudo console (%d): %w", int32(ret), err)
7173
}
74+
7275
return pty, nil
7376
}
7477

@@ -134,39 +137,65 @@ func (p *ptyWindows) Resize(height uint16, width uint16) error {
134137
Y: int16(height),
135138
X: int16(width),
136139
})))))
137-
if ret != 0 {
140+
if windows.Handle(ret) != windows.S_OK {
138141
return err
139142
}
140143
return nil
141144
}
142145

143-
func (p *ptyWindows) Close() error {
144-
p.closeMutex.Lock()
145-
defer p.closeMutex.Unlock()
146-
if p.closed {
147-
return nil
148-
}
149-
p.closed = true
150-
146+
// closeConsoleNoLock closes the console handle, and sets it to
147+
// windows.InvalidHandle. It must be called with p.closeMutex held.
148+
func (p *ptyWindows) closeConsoleNoLock() error {
151149
// if we are running a command in the PTY, the corresponding *windowsProcess
152150
// may have already closed the PseudoConsole when the command exited, so that
153151
// output reads can get to EOF. In that case, we don't need to close it
154152
// again here.
155153
if p.console != windows.InvalidHandle {
154+
// ClosePseudoConsole has no return value and typically the syscall
155+
// returns S_FALSE (a success value). We could ignore the return value
156+
// and error here but we handle anyway, it just in case.
157+
//
158+
// Note that ClosePseudoConsole is a blocking system call and may write
159+
// a final frame to the output buffer (p.outputWrite), so there must be
160+
// a consumer (p.outputRead) to ensure we don't block here indefinitely.
161+
//
162+
// https://docs.microsoft.com/en-us/windows/console/closepseudoconsole
156163
ret, _, err := procClosePseudoConsole.Call(uintptr(p.console))
157-
if ret < 0 {
158-
return xerrors.Errorf("close pseudo console: %w", err)
164+
if winerrorFailed(ret) {
165+
return xerrors.Errorf("close pseudo console (%d): %w", ret, err)
159166
}
160167
p.console = windows.InvalidHandle
161168
}
162169

163-
// We always have these files
164-
_ = p.outputRead.Close()
165-
_ = p.inputWrite.Close()
166-
// These get closed & unset if we Start() a new process.
170+
return nil
171+
}
172+
173+
func (p *ptyWindows) Close() error {
174+
p.closeMutex.Lock()
175+
defer p.closeMutex.Unlock()
176+
if p.closed {
177+
return nil
178+
}
179+
180+
// Close the pseudo console, this will also terminate the process attached
181+
// to this pty. If it was created via Start(), this also unblocks close of
182+
// the readers below.
183+
err := p.closeConsoleNoLock()
184+
if err != nil {
185+
return err
186+
}
187+
188+
// Only set closed after the console has been successfully closed.
189+
p.closed = true
190+
191+
// Close the pipes ensuring that the writer is closed before the respective
192+
// reader, otherwise closing the reader may block indefinitely. Note that
193+
// outputWrite and inputRead are unset when we Start() a new process.
167194
if p.outputWrite != nil {
168195
_ = p.outputWrite.Close()
169196
}
197+
_ = p.outputRead.Close()
198+
_ = p.inputWrite.Close()
170199
if p.inputRead != nil {
171200
_ = p.inputRead.Close()
172201
}
@@ -184,15 +213,13 @@ func (p *windowsProcess) waitInternal() {
184213
// c.f. https://devblogs.microsoft.com/commandline/windows-command-line-introducing-the-windows-pseudo-console-conpty/
185214
p.pw.closeMutex.Lock()
186215
defer p.pw.closeMutex.Unlock()
187-
if p.pw.console != windows.InvalidHandle {
188-
ret, _, err := procClosePseudoConsole.Call(uintptr(p.pw.console))
189-
if ret < 0 && p.cmdErr == nil {
190-
// if we already have an error from the command, prefer that error
191-
// but if the command succeeded and closing the PseudoConsole fails
192-
// then record that error so that we have a chance to see it
193-
p.cmdErr = err
194-
}
195-
p.pw.console = windows.InvalidHandle
216+
217+
err := p.pw.closeConsoleNoLock()
218+
// if we already have an error from the command, prefer that error
219+
// but if the command succeeded and closing the PseudoConsole fails
220+
// then record that error so that we have a chance to see it
221+
if err != nil && p.cmdErr == nil {
222+
p.cmdErr = err
196223
}
197224
}()
198225

@@ -225,3 +252,11 @@ func (p *windowsProcess) killOnContext(ctx context.Context) {
225252
p.Kill()
226253
}
227254
}
255+
256+
// winerrorFailed returns true if the syscall failed, this function
257+
// assumes the return value is a 32-bit integer, like HRESULT.
258+
//
259+
// https://learn.microsoft.com/en-us/windows/win32/api/winerror/nf-winerror-failed
260+
func winerrorFailed(r1 uintptr) bool {
261+
return int32(r1) < 0
262+
}

pty/ptytest/ptytest.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,13 @@ type PTY struct {
354354
func (p *PTY) Close() error {
355355
p.t.Helper()
356356
pErr := p.PTY.Close()
357-
eErr := p.outExpecter.close("close")
357+
if pErr != nil {
358+
p.logf("PTY: Close failed: %v", pErr)
359+
}
360+
eErr := p.outExpecter.close("PTY close")
361+
if eErr != nil {
362+
p.logf("PTY: close expecter failed: %v", eErr)
363+
}
358364
if pErr != nil {
359365
return pErr
360366
}
@@ -398,7 +404,13 @@ type PTYCmd struct {
398404
func (p *PTYCmd) Close() error {
399405
p.t.Helper()
400406
pErr := p.PTYCmd.Close()
401-
eErr := p.outExpecter.close("close")
407+
if pErr != nil {
408+
p.logf("PTYCmd: Close failed: %v", pErr)
409+
}
410+
eErr := p.outExpecter.close("PTYCmd close")
411+
if eErr != nil {
412+
p.logf("PTYCmd: close expecter failed: %v", eErr)
413+
}
402414
if pErr != nil {
403415
return pErr
404416
}

0 commit comments

Comments
 (0)