Skip to content

Commit 3d408db

Browse files
authored
Merge pull request #27 from coder/spike/rsv-bits
Fix remoteProcess race between Close and Read
2 parents ce30d9a + 69ea3a8 commit 3d408db

File tree

3 files changed

+286
-135
lines changed

3 files changed

+286
-135
lines changed

client.go

Lines changed: 151 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ import (
66
"encoding/json"
77
"io"
88
"net"
9+
"strings"
910

1011
"cdr.dev/wsep/internal/proto"
11-
"golang.org/x/sync/errgroup"
1212
"golang.org/x/xerrors"
1313
"nhooyr.io/websocket"
1414
)
@@ -39,6 +39,9 @@ type Command struct {
3939
WorkingDir string
4040
}
4141

42+
// Start runs the command on the remote. Once a command is started, callers should
43+
// not read from, write to, or close the websocket. Closing the returned Process will
44+
// also close the websocket.
4245
func (r remoteExec) Start(ctx context.Context, c Command) (Process, error) {
4346
header := proto.ClientStartHeader{
4447
ID: c.ID,
@@ -73,30 +76,42 @@ func (r remoteExec) Start(ctx context.Context, c Command) (Process, error) {
7376
stdin = disabledStdinWriter{}
7477
}
7578

76-
rp := remoteProcess{
77-
ctx: ctx,
78-
conn: r.conn,
79-
cmd: c,
80-
pid: pidHeader.Pid,
81-
done: make(chan error, 1),
82-
stderr: newPipe(),
83-
stdout: newPipe(),
84-
stdin: stdin,
79+
listenCtx, cancelListen := context.WithCancel(ctx)
80+
rp := &remoteProcess{
81+
ctx: ctx,
82+
conn: r.conn,
83+
cmd: c,
84+
pid: pidHeader.Pid,
85+
done: make(chan struct{}),
86+
stderr: newPipe(),
87+
stderrData: make(chan []byte),
88+
stdout: newPipe(),
89+
stdoutData: make(chan []byte),
90+
stdin: stdin,
91+
cancelListen: cancelListen,
8592
}
8693

87-
go rp.listen(ctx)
94+
go rp.listen(listenCtx)
8895
return rp, nil
8996
}
9097

9198
type remoteProcess struct {
92-
ctx context.Context
93-
cmd Command
94-
conn *websocket.Conn
95-
pid int
96-
done chan error
97-
stdin io.WriteCloser
98-
stdout pipe
99-
stderr pipe
99+
ctx context.Context
100+
cancelListen func()
101+
cmd Command
102+
conn *websocket.Conn
103+
pid int
104+
done chan struct{}
105+
closeErr error
106+
exitCode *int
107+
readErr error
108+
stdin io.WriteCloser
109+
stdout pipe
110+
stdoutErr error
111+
stdoutData chan []byte
112+
stderr pipe
113+
stderrErr error
114+
stderrData chan []byte
100115
}
101116

102117
type remoteStdin struct {
@@ -143,99 +158,143 @@ func (r remoteStdin) Close() error {
143158
}
144159

145160
type pipe struct {
146-
r *io.PipeReader
147-
w *io.PipeWriter
161+
r *io.PipeReader
162+
w *io.PipeWriter
163+
d chan []byte
164+
e chan error
165+
buf []byte
148166
}
149167

150168
func newPipe() pipe {
151169
pr, pw := io.Pipe()
152170
return pipe{
153-
r: pr,
154-
w: pw,
171+
r: pr,
172+
w: pw,
173+
d: make(chan []byte),
174+
e: make(chan error),
175+
buf: make([]byte, maxMessageSize),
155176
}
156177
}
157178

158-
func (r remoteProcess) listen(ctx context.Context) {
159-
defer r.conn.Close(websocket.StatusNormalClosure, "normal closure")
160-
defer close(r.done)
179+
// writeCtx writes data to the pipe, or returns if the context is canceled.
180+
func (p *pipe) writeCtx(ctx context.Context, data []byte) error {
181+
// actually do the copy on another goroutine so that we can return if context
182+
// is canceled
183+
go func() {
184+
var err error
185+
select {
186+
case <-ctx.Done():
187+
return
188+
case body := <-p.d:
189+
_, err = io.CopyBuffer(p.w, bytes.NewReader(body), p.buf)
190+
}
191+
select {
192+
case <-ctx.Done():
193+
return
194+
case p.e <- err:
195+
return
196+
}
197+
}()
198+
199+
select {
200+
case <-ctx.Done():
201+
return ctx.Err()
202+
case p.d <- data:
203+
// data being written.
204+
}
205+
select {
206+
case <-ctx.Done():
207+
return ctx.Err()
208+
case err := <-p.e:
209+
return err
210+
}
211+
}
161212

162-
exitCode := make(chan int, 1)
163-
var eg errgroup.Group
213+
func (r *remoteProcess) listen(ctx context.Context) {
214+
defer func() {
215+
r.stdoutErr = r.stdout.w.Close()
216+
r.stderrErr = r.stderr.w.Close()
217+
218+
r.closeErr = r.conn.Close(websocket.StatusNormalClosure, "normal closure")
219+
// If we were in r.conn.Read() we cancel the ctx, the websocket library closes
220+
// the websocket before we have a chance to. This is a normal closure.
221+
if r.closeErr != nil && strings.Contains(r.closeErr.Error(), "already wrote close") &&
222+
r.readErr != nil && strings.Contains(r.readErr.Error(), "context canceled") {
223+
r.closeErr = nil
224+
}
225+
close(r.done)
226+
}()
164227

165-
eg.Go(func() error {
166-
defer r.stdout.w.Close()
167-
defer r.stderr.w.Close()
228+
for ctx.Err() == nil {
229+
_, payload, err := r.conn.Read(ctx)
230+
if err != nil {
231+
r.readErr = err
232+
return
233+
}
234+
headerByt, body := proto.SplitMessage(payload)
168235

169-
buf := make([]byte, maxMessageSize) // max size of one websocket message
170-
for ctx.Err() == nil {
171-
_, payload, err := r.conn.Read(ctx)
236+
var header proto.Header
237+
err = json.Unmarshal(headerByt, &header)
238+
if err != nil {
239+
r.readErr = err
240+
return
241+
}
242+
243+
switch header.Type {
244+
case proto.TypeStderr:
245+
err = r.stderr.writeCtx(ctx, body)
172246
if err != nil {
173-
return err
247+
r.readErr = err
248+
return
174249
}
175-
headerByt, body := proto.SplitMessage(payload)
176-
177-
var header proto.Header
178-
err = json.Unmarshal(headerByt, &header)
250+
case proto.TypeStdout:
251+
err = r.stdout.writeCtx(ctx, body)
179252
if err != nil {
180-
continue
253+
r.readErr = err
254+
return
181255
}
182-
183-
switch header.Type {
184-
case proto.TypeStderr:
185-
_, err = io.CopyBuffer(r.stderr.w, bytes.NewReader(body), buf)
186-
if err != nil {
187-
return err
188-
}
189-
case proto.TypeStdout:
190-
_, err = io.CopyBuffer(r.stdout.w, bytes.NewReader(body), buf)
191-
if err != nil {
192-
return err
193-
}
194-
case proto.TypeExitCode:
195-
var exitMsg proto.ServerExitCodeHeader
196-
err = json.Unmarshal(headerByt, &exitMsg)
197-
if err != nil {
198-
continue
199-
}
200-
201-
exitCode <- exitMsg.ExitCode
202-
return nil
256+
case proto.TypeExitCode:
257+
var exitMsg proto.ServerExitCodeHeader
258+
err = json.Unmarshal(headerByt, &exitMsg)
259+
if err != nil {
260+
r.readErr = err
261+
return
203262
}
204-
}
205-
return ctx.Err()
206-
})
207263

208-
err := eg.Wait()
209-
select {
210-
case exitCode := <-exitCode:
211-
if exitCode != 0 {
212-
r.done <- ExitError{Code: exitCode}
264+
r.exitCode = &exitMsg.ExitCode
265+
return
213266
}
214-
default:
215-
r.done <- err
216267
}
268+
// if we get here, the context is done, so use that as the read error
269+
r.readErr = ctx.Err()
217270
}
218271

219-
func (r remoteProcess) Pid() int {
272+
func (r *remoteProcess) Pid() int {
220273
return r.pid
221274
}
222275

223-
func (r remoteProcess) Stdin() io.WriteCloser {
276+
func (r *remoteProcess) Stdin() io.WriteCloser {
224277
if !r.cmd.Stdin {
225278
return disabledStdinWriter{}
226279
}
227280
return r.stdin
228281
}
229282

230-
func (r remoteProcess) Stdout() io.Reader {
283+
// Stdout returns a reader for standard out from the process. You MUST read from
284+
// this reader even if you don't care about the data to avoid blocking the
285+
// websocket.
286+
func (r *remoteProcess) Stdout() io.Reader {
231287
return r.stdout.r
232288
}
233289

234-
func (r remoteProcess) Stderr() io.Reader {
290+
// Stdout returns a reader for standard error from the process. You MUST read from
291+
// this reader even if you don't care about the data to avoid blocking the
292+
// websocket.
293+
func (r *remoteProcess) Stderr() io.Reader {
235294
return r.stderr.r
236295
}
237296

238-
func (r remoteProcess) Resize(ctx context.Context, rows, cols uint16) error {
297+
func (r *remoteProcess) Resize(ctx context.Context, rows, cols uint16) error {
239298
header := proto.ClientResizeHeader{
240299
Type: proto.TypeResize,
241300
Cols: cols,
@@ -248,20 +307,25 @@ func (r remoteProcess) Resize(ctx context.Context, rows, cols uint16) error {
248307
return r.conn.Write(ctx, websocket.MessageBinary, payload)
249308
}
250309

251-
func (r remoteProcess) Wait() error {
252-
select {
253-
case err := <-r.done:
254-
return err
255-
case <-r.ctx.Done():
256-
return r.ctx.Err()
310+
func (r *remoteProcess) Wait() error {
311+
<-r.done
312+
if r.readErr != nil {
313+
return r.readErr
314+
}
315+
// when listen() closes r.done, either there must be a read error
316+
// or exitCode is set non-nil, so it's safe to dereference the pointer
317+
// here
318+
if *r.exitCode != 0 {
319+
return ExitError{Code: *r.exitCode}
257320
}
321+
return nil
258322
}
259323

260-
func (r remoteProcess) Close() error {
261-
err := r.conn.Close(websocket.StatusNormalClosure, "")
262-
err1 := r.stderr.w.Close()
263-
err2 := r.stdout.w.Close()
264-
return joinErrs(err, err1, err2)
324+
func (r *remoteProcess) Close() error {
325+
r.cancelListen()
326+
<-r.done
327+
closeErr := r.closeErr
328+
return joinErrs(closeErr, r.stdoutErr, r.stderrErr)
265329
}
266330

267331
func joinErrs(errs ...error) error {

0 commit comments

Comments
 (0)