6
6
"encoding/json"
7
7
"io"
8
8
"net"
9
+ "strings"
9
10
10
11
"cdr.dev/wsep/internal/proto"
11
- "golang.org/x/sync/errgroup"
12
12
"golang.org/x/xerrors"
13
13
"nhooyr.io/websocket"
14
14
)
@@ -39,6 +39,9 @@ type Command struct {
39
39
WorkingDir string
40
40
}
41
41
42
+ // Start runs the command on the remote. Once a command is started, callers should
43
+ // not read from, write to, or close the websocket. Closing the returned Process will
44
+ // also close the websocket.
42
45
func (r remoteExec ) Start (ctx context.Context , c Command ) (Process , error ) {
43
46
header := proto.ClientStartHeader {
44
47
ID : c .ID ,
@@ -73,30 +76,42 @@ func (r remoteExec) Start(ctx context.Context, c Command) (Process, error) {
73
76
stdin = disabledStdinWriter {}
74
77
}
75
78
76
- rp := remoteProcess {
77
- ctx : ctx ,
78
- conn : r .conn ,
79
- cmd : c ,
80
- pid : pidHeader .Pid ,
81
- done : make (chan error , 1 ),
82
- stderr : newPipe (),
83
- stdout : newPipe (),
84
- stdin : stdin ,
79
+ listenCtx , cancelListen := context .WithCancel (ctx )
80
+ rp := & remoteProcess {
81
+ ctx : ctx ,
82
+ conn : r .conn ,
83
+ cmd : c ,
84
+ pid : pidHeader .Pid ,
85
+ done : make (chan struct {}),
86
+ stderr : newPipe (),
87
+ stderrData : make (chan []byte ),
88
+ stdout : newPipe (),
89
+ stdoutData : make (chan []byte ),
90
+ stdin : stdin ,
91
+ cancelListen : cancelListen ,
85
92
}
86
93
87
- go rp .listen (ctx )
94
+ go rp .listen (listenCtx )
88
95
return rp , nil
89
96
}
90
97
91
98
type remoteProcess struct {
92
- ctx context.Context
93
- cmd Command
94
- conn * websocket.Conn
95
- pid int
96
- done chan error
97
- stdin io.WriteCloser
98
- stdout pipe
99
- stderr pipe
99
+ ctx context.Context
100
+ cancelListen func ()
101
+ cmd Command
102
+ conn * websocket.Conn
103
+ pid int
104
+ done chan struct {}
105
+ closeErr error
106
+ exitCode * int
107
+ readErr error
108
+ stdin io.WriteCloser
109
+ stdout pipe
110
+ stdoutErr error
111
+ stdoutData chan []byte
112
+ stderr pipe
113
+ stderrErr error
114
+ stderrData chan []byte
100
115
}
101
116
102
117
type remoteStdin struct {
@@ -143,99 +158,143 @@ func (r remoteStdin) Close() error {
143
158
}
144
159
145
160
type pipe struct {
146
- r * io.PipeReader
147
- w * io.PipeWriter
161
+ r * io.PipeReader
162
+ w * io.PipeWriter
163
+ d chan []byte
164
+ e chan error
165
+ buf []byte
148
166
}
149
167
150
168
func newPipe () pipe {
151
169
pr , pw := io .Pipe ()
152
170
return pipe {
153
- r : pr ,
154
- w : pw ,
171
+ r : pr ,
172
+ w : pw ,
173
+ d : make (chan []byte ),
174
+ e : make (chan error ),
175
+ buf : make ([]byte , maxMessageSize ),
155
176
}
156
177
}
157
178
158
- func (r remoteProcess ) listen (ctx context.Context ) {
159
- defer r .conn .Close (websocket .StatusNormalClosure , "normal closure" )
160
- defer close (r .done )
179
+ // writeCtx writes data to the pipe, or returns if the context is canceled.
180
+ func (p * pipe ) writeCtx (ctx context.Context , data []byte ) error {
181
+ // actually do the copy on another goroutine so that we can return if context
182
+ // is canceled
183
+ go func () {
184
+ var err error
185
+ select {
186
+ case <- ctx .Done ():
187
+ return
188
+ case body := <- p .d :
189
+ _ , err = io .CopyBuffer (p .w , bytes .NewReader (body ), p .buf )
190
+ }
191
+ select {
192
+ case <- ctx .Done ():
193
+ return
194
+ case p .e <- err :
195
+ return
196
+ }
197
+ }()
198
+
199
+ select {
200
+ case <- ctx .Done ():
201
+ return ctx .Err ()
202
+ case p .d <- data :
203
+ // data being written.
204
+ }
205
+ select {
206
+ case <- ctx .Done ():
207
+ return ctx .Err ()
208
+ case err := <- p .e :
209
+ return err
210
+ }
211
+ }
161
212
162
- exitCode := make (chan int , 1 )
163
- var eg errgroup.Group
213
+ func (r * remoteProcess ) listen (ctx context.Context ) {
214
+ defer func () {
215
+ r .stdoutErr = r .stdout .w .Close ()
216
+ r .stderrErr = r .stderr .w .Close ()
217
+
218
+ r .closeErr = r .conn .Close (websocket .StatusNormalClosure , "normal closure" )
219
+ // If we were in r.conn.Read() we cancel the ctx, the websocket library closes
220
+ // the websocket before we have a chance to. This is a normal closure.
221
+ if r .closeErr != nil && strings .Contains (r .closeErr .Error (), "already wrote close" ) &&
222
+ r .readErr != nil && strings .Contains (r .readErr .Error (), "context canceled" ) {
223
+ r .closeErr = nil
224
+ }
225
+ close (r .done )
226
+ }()
164
227
165
- eg .Go (func () error {
166
- defer r .stdout .w .Close ()
167
- defer r .stderr .w .Close ()
228
+ for ctx .Err () == nil {
229
+ _ , payload , err := r .conn .Read (ctx )
230
+ if err != nil {
231
+ r .readErr = err
232
+ return
233
+ }
234
+ headerByt , body := proto .SplitMessage (payload )
168
235
169
- buf := make ([]byte , maxMessageSize ) // max size of one websocket message
170
- for ctx .Err () == nil {
171
- _ , payload , err := r .conn .Read (ctx )
236
+ var header proto.Header
237
+ err = json .Unmarshal (headerByt , & header )
238
+ if err != nil {
239
+ r .readErr = err
240
+ return
241
+ }
242
+
243
+ switch header .Type {
244
+ case proto .TypeStderr :
245
+ err = r .stderr .writeCtx (ctx , body )
172
246
if err != nil {
173
- return err
247
+ r .readErr = err
248
+ return
174
249
}
175
- headerByt , body := proto .SplitMessage (payload )
176
-
177
- var header proto.Header
178
- err = json .Unmarshal (headerByt , & header )
250
+ case proto .TypeStdout :
251
+ err = r .stdout .writeCtx (ctx , body )
179
252
if err != nil {
180
- continue
253
+ r .readErr = err
254
+ return
181
255
}
182
-
183
- switch header .Type {
184
- case proto .TypeStderr :
185
- _ , err = io .CopyBuffer (r .stderr .w , bytes .NewReader (body ), buf )
186
- if err != nil {
187
- return err
188
- }
189
- case proto .TypeStdout :
190
- _ , err = io .CopyBuffer (r .stdout .w , bytes .NewReader (body ), buf )
191
- if err != nil {
192
- return err
193
- }
194
- case proto .TypeExitCode :
195
- var exitMsg proto.ServerExitCodeHeader
196
- err = json .Unmarshal (headerByt , & exitMsg )
197
- if err != nil {
198
- continue
199
- }
200
-
201
- exitCode <- exitMsg .ExitCode
202
- return nil
256
+ case proto .TypeExitCode :
257
+ var exitMsg proto.ServerExitCodeHeader
258
+ err = json .Unmarshal (headerByt , & exitMsg )
259
+ if err != nil {
260
+ r .readErr = err
261
+ return
203
262
}
204
- }
205
- return ctx .Err ()
206
- })
207
263
208
- err := eg .Wait ()
209
- select {
210
- case exitCode := <- exitCode :
211
- if exitCode != 0 {
212
- r .done <- ExitError {Code : exitCode }
264
+ r .exitCode = & exitMsg .ExitCode
265
+ return
213
266
}
214
- default :
215
- r .done <- err
216
267
}
268
+ // if we get here, the context is done, so use that as the read error
269
+ r .readErr = ctx .Err ()
217
270
}
218
271
219
- func (r remoteProcess ) Pid () int {
272
+ func (r * remoteProcess ) Pid () int {
220
273
return r .pid
221
274
}
222
275
223
- func (r remoteProcess ) Stdin () io.WriteCloser {
276
+ func (r * remoteProcess ) Stdin () io.WriteCloser {
224
277
if ! r .cmd .Stdin {
225
278
return disabledStdinWriter {}
226
279
}
227
280
return r .stdin
228
281
}
229
282
230
- func (r remoteProcess ) Stdout () io.Reader {
283
+ // Stdout returns a reader for standard out from the process. You MUST read from
284
+ // this reader even if you don't care about the data to avoid blocking the
285
+ // websocket.
286
+ func (r * remoteProcess ) Stdout () io.Reader {
231
287
return r .stdout .r
232
288
}
233
289
234
- func (r remoteProcess ) Stderr () io.Reader {
290
+ // Stdout returns a reader for standard error from the process. You MUST read from
291
+ // this reader even if you don't care about the data to avoid blocking the
292
+ // websocket.
293
+ func (r * remoteProcess ) Stderr () io.Reader {
235
294
return r .stderr .r
236
295
}
237
296
238
- func (r remoteProcess ) Resize (ctx context.Context , rows , cols uint16 ) error {
297
+ func (r * remoteProcess ) Resize (ctx context.Context , rows , cols uint16 ) error {
239
298
header := proto.ClientResizeHeader {
240
299
Type : proto .TypeResize ,
241
300
Cols : cols ,
@@ -248,20 +307,25 @@ func (r remoteProcess) Resize(ctx context.Context, rows, cols uint16) error {
248
307
return r .conn .Write (ctx , websocket .MessageBinary , payload )
249
308
}
250
309
251
- func (r remoteProcess ) Wait () error {
252
- select {
253
- case err := <- r .done :
254
- return err
255
- case <- r .ctx .Done ():
256
- return r .ctx .Err ()
310
+ func (r * remoteProcess ) Wait () error {
311
+ <- r .done
312
+ if r .readErr != nil {
313
+ return r .readErr
314
+ }
315
+ // when listen() closes r.done, either there must be a read error
316
+ // or exitCode is set non-nil, so it's safe to dereference the pointer
317
+ // here
318
+ if * r .exitCode != 0 {
319
+ return ExitError {Code : * r .exitCode }
257
320
}
321
+ return nil
258
322
}
259
323
260
- func (r remoteProcess ) Close () error {
261
- err := r . conn . Close ( websocket . StatusNormalClosure , "" )
262
- err1 := r . stderr . w . Close ()
263
- err2 := r .stdout . w . Close ()
264
- return joinErrs (err , err1 , err2 )
324
+ func (r * remoteProcess ) Close () error {
325
+ r . cancelListen ( )
326
+ <- r . done
327
+ closeErr := r .closeErr
328
+ return joinErrs (closeErr , r . stdoutErr , r . stderrErr )
265
329
}
266
330
267
331
func joinErrs (errs ... error ) error {
0 commit comments