@@ -2,122 +2,261 @@ package workspacetraffic
2
2
3
3
import (
4
4
"context"
5
+ "encoding/json"
6
+ "errors"
5
7
"io"
6
8
"sync"
9
+ "time"
7
10
8
11
"github.com/coder/coder/v2/codersdk"
9
12
10
13
"github.com/google/uuid"
11
- "github.com/hashicorp/go-multierror"
12
14
gossh "golang.org/x/crypto/ssh"
13
15
"golang.org/x/xerrors"
14
16
)
15
17
16
- func connectPTY (ctx context.Context , client * codersdk.Client , agentID , reconnect uuid.UUID ) (* countReadWriteCloser , error ) {
18
+ const (
19
+ // Set a timeout for graceful close of the connection.
20
+ connCloseTimeout = 30 * time .Second
21
+ // Set a timeout for waiting for the connection to close.
22
+ waitCloseTimeout = connCloseTimeout + 5 * time .Second
23
+
24
+ // In theory, we can send larger payloads to push bandwidth, but we need to
25
+ // be careful not to send too much data at once or the server will close the
26
+ // connection. We see this more readily as our JSON payloads approach 28KB.
27
+ //
28
+ // failed to write frame: WebSocket closed: received close frame: status = StatusMessageTooBig and reason = "read limited at 32769 bytes"
29
+ //
30
+ // Since we can't control fragmentation/buffer sizes, we keep it simple and
31
+ // match the conservative payload size used by agent/reconnectingpty (1024).
32
+ rptyJSONMaxDataSize = 1024
33
+ )
34
+
35
+ func connectRPTY (ctx context.Context , client * codersdk.Client , agentID , reconnect uuid.UUID , cmd string ) (* countReadWriteCloser , error ) {
36
+ width , height := 80 , 25
17
37
conn , err := client .WorkspaceAgentReconnectingPTY (ctx , codersdk.WorkspaceAgentReconnectingPTYOpts {
18
38
AgentID : agentID ,
19
39
Reconnect : reconnect ,
20
- Height : 25 ,
21
- Width : 80 ,
22
- Command : "sh" ,
40
+ Width : uint16 ( width ) ,
41
+ Height : uint16 ( height ) ,
42
+ Command : cmd ,
23
43
})
24
44
if err != nil {
25
45
return nil , xerrors .Errorf ("connect pty: %w" , err )
26
46
}
27
47
28
48
// Wrap the conn in a countReadWriteCloser so we can monitor bytes sent/rcvd.
29
- crw := countReadWriteCloser {ctx : ctx , rwc : conn }
49
+ crw := countReadWriteCloser {rwc : newPTYConn ( conn ) }
30
50
return & crw , nil
31
51
}
32
52
33
- func connectSSH (ctx context.Context , client * codersdk.Client , agentID uuid.UUID ) (* countReadWriteCloser , error ) {
53
+ type rptyConn struct {
54
+ conn io.ReadWriteCloser
55
+ wenc * json.Encoder
56
+
57
+ readOnce sync.Once
58
+ readErr chan error
59
+
60
+ mu sync.Mutex // Protects following.
61
+ closed bool
62
+ }
63
+
64
+ func newPTYConn (conn io.ReadWriteCloser ) * rptyConn {
65
+ rc := & rptyConn {
66
+ conn : conn ,
67
+ wenc : json .NewEncoder (conn ),
68
+ readErr : make (chan error , 1 ),
69
+ }
70
+ return rc
71
+ }
72
+
73
+ func (c * rptyConn ) Read (p []byte ) (int , error ) {
74
+ n , err := c .conn .Read (p )
75
+ if err != nil {
76
+ c .readOnce .Do (func () {
77
+ c .readErr <- err
78
+ close (c .readErr )
79
+ })
80
+ return n , err
81
+ }
82
+ return n , nil
83
+ }
84
+
85
+ func (c * rptyConn ) Write (p []byte ) (int , error ) {
86
+ c .mu .Lock ()
87
+ defer c .mu .Unlock ()
88
+
89
+ // Early exit in case we're closing, this is to let call write Ctrl+C
90
+ // without a flood of other writes.
91
+ if c .closed {
92
+ return 0 , io .EOF
93
+ }
94
+
95
+ return c .writeNoLock (p )
96
+ }
97
+
98
+ func (c * rptyConn ) writeNoLock (p []byte ) (n int , err error ) {
99
+ // If we try to send more than the max payload size, the server will close the connection.
100
+ for len (p ) > 0 {
101
+ pp := p
102
+ if len (pp ) > rptyJSONMaxDataSize {
103
+ pp = p [:rptyJSONMaxDataSize ]
104
+ }
105
+ p = p [len (pp ):]
106
+ req := codersdk.ReconnectingPTYRequest {Data : string (pp )}
107
+ if err := c .wenc .Encode (req ); err != nil {
108
+ return n , xerrors .Errorf ("encode pty request: %w" , err )
109
+ }
110
+ n += len (pp )
111
+ }
112
+ return n , nil
113
+ }
114
+
115
+ func (c * rptyConn ) Close () (err error ) {
116
+ c .mu .Lock ()
117
+ if c .closed {
118
+ c .mu .Unlock ()
119
+ return nil
120
+ }
121
+ c .closed = true
122
+ c .mu .Unlock ()
123
+
124
+ defer c .conn .Close ()
125
+
126
+ // Send Ctrl+C to interrupt the command.
127
+ _ , err = c .writeNoLock ([]byte ("\u0003 " ))
128
+ if err != nil {
129
+ return xerrors .Errorf ("write ctrl+c: %w" , err )
130
+ }
131
+ select {
132
+ case <- time .After (connCloseTimeout ):
133
+ return xerrors .Errorf ("timeout waiting for read to finish" )
134
+ case err = <- c .readErr :
135
+ if errors .Is (err , io .EOF ) {
136
+ return nil
137
+ }
138
+ return err
139
+ }
140
+ }
141
+
142
+ //nolint:revive // Ignore requestPTY control flag.
143
+ func connectSSH (ctx context.Context , client * codersdk.Client , agentID uuid.UUID , cmd string , requestPTY bool ) (rwc * countReadWriteCloser , err error ) {
144
+ var closers []func () error
145
+ defer func () {
146
+ if err != nil {
147
+ for _ , c := range closers {
148
+ if err2 := c (); err2 != nil {
149
+ err = errors .Join (err , err2 )
150
+ }
151
+ }
152
+ }
153
+ }()
154
+
34
155
agentConn , err := client .DialWorkspaceAgent (ctx , agentID , & codersdk.DialWorkspaceAgentOptions {})
35
156
if err != nil {
36
157
return nil , xerrors .Errorf ("dial workspace agent: %w" , err )
37
158
}
38
- agentConn .AwaitReachable (ctx )
159
+ closers = append (closers , agentConn .Close )
160
+
39
161
sshClient , err := agentConn .SSHClient (ctx )
40
162
if err != nil {
41
163
return nil , xerrors .Errorf ("get ssh client: %w" , err )
42
164
}
165
+ closers = append (closers , sshClient .Close )
166
+
43
167
sshSession , err := sshClient .NewSession ()
44
168
if err != nil {
45
- _ = agentConn .Close ()
46
169
return nil , xerrors .Errorf ("new ssh session: %w" , err )
47
170
}
48
- wrappedConn := & wrappedSSHConn {ctx : ctx }
171
+ closers = append (closers , sshSession .Close )
172
+
173
+ wrappedConn := & wrappedSSHConn {}
174
+
49
175
// Do some plumbing to hook up the wrappedConn
50
176
pr1 , pw1 := io .Pipe ()
177
+ closers = append (closers , pr1 .Close , pw1 .Close )
51
178
wrappedConn .stdout = pr1
52
179
sshSession .Stdout = pw1
180
+
53
181
pr2 , pw2 := io .Pipe ()
182
+ closers = append (closers , pr2 .Close , pw2 .Close )
54
183
sshSession .Stdin = pr2
55
184
wrappedConn .stdin = pw2
56
- err = sshSession .RequestPty ("xterm" , 25 , 80 , gossh.TerminalModes {})
57
- if err != nil {
58
- _ = pr1 .Close ()
59
- _ = pr2 .Close ()
60
- _ = pw1 .Close ()
61
- _ = pw2 .Close ()
62
- _ = sshSession .Close ()
63
- _ = agentConn .Close ()
64
- return nil , xerrors .Errorf ("request pty: %w" , err )
185
+
186
+ if requestPTY {
187
+ err = sshSession .RequestPty ("xterm" , 25 , 80 , gossh.TerminalModes {})
188
+ if err != nil {
189
+ return nil , xerrors .Errorf ("request pty: %w" , err )
190
+ }
65
191
}
66
- err = sshSession .Shell ( )
192
+ err = sshSession .Start ( cmd )
67
193
if err != nil {
68
- _ = sshSession .Close ()
69
- _ = agentConn .Close ()
70
194
return nil , xerrors .Errorf ("shell: %w" , err )
71
195
}
196
+ waitErr := make (chan error , 1 )
197
+ go func () {
198
+ waitErr <- sshSession .Wait ()
199
+ }()
72
200
73
201
closeFn := func () error {
74
- var merr error
75
- if err := sshSession .Close (); err != nil {
76
- merr = multierror .Append (merr , err )
202
+ // Start by closing stdin so we stop writing to the ssh session.
203
+ merr := pw2 .Close ()
204
+ if err := sshSession .Signal (gossh .SIGHUP ); err != nil {
205
+ merr = errors .Join (merr , err )
77
206
}
78
- if err := agentConn .Close (); err != nil {
79
- merr = multierror .Append (merr , err )
207
+ select {
208
+ case <- time .After (connCloseTimeout ):
209
+ merr = errors .Join (merr , xerrors .Errorf ("timeout waiting for ssh session to close" ))
210
+ case err := <- waitErr :
211
+ if err != nil {
212
+ var exitErr * gossh.ExitError
213
+ if xerrors .As (err , & exitErr ) {
214
+ // The exit status is 255 when the command is
215
+ // interrupted by a signal. This is expected.
216
+ if exitErr .ExitStatus () != 255 {
217
+ merr = errors .Join (merr , xerrors .Errorf ("ssh session exited with unexpected status: %d" , int32 (exitErr .ExitStatus ())))
218
+ }
219
+ } else {
220
+ merr = errors .Join (merr , err )
221
+ }
222
+ }
223
+ }
224
+ for _ , c := range closers {
225
+ if err := c (); err != nil {
226
+ if ! errors .Is (err , io .EOF ) {
227
+ merr = errors .Join (merr , err )
228
+ }
229
+ }
80
230
}
81
231
return merr
82
232
}
83
233
wrappedConn .close = closeFn
84
234
85
- crw := & countReadWriteCloser {ctx : ctx , rwc : wrappedConn }
235
+ crw := & countReadWriteCloser {rwc : wrappedConn }
236
+
86
237
return crw , nil
87
238
}
88
239
89
240
// wrappedSSHConn wraps an ssh.Session to implement io.ReadWriteCloser.
90
241
type wrappedSSHConn struct {
91
- ctx context.Context
92
242
stdout io.Reader
93
- stdin io.Writer
243
+ stdin io.WriteCloser
94
244
closeOnce sync.Once
95
245
closeErr error
96
246
close func () error
97
247
}
98
248
99
249
func (w * wrappedSSHConn ) Close () error {
100
250
w .closeOnce .Do (func () {
101
- _ , _ = w .stdin .Write ([]byte ("exit\n " ))
102
251
w .closeErr = w .close ()
103
252
})
104
253
return w .closeErr
105
254
}
106
255
107
256
func (w * wrappedSSHConn ) Read (p []byte ) (n int , err error ) {
108
- select {
109
- case <- w .ctx .Done ():
110
- return 0 , xerrors .Errorf ("read: %w" , w .ctx .Err ())
111
- default :
112
- return w .stdout .Read (p )
113
- }
257
+ return w .stdout .Read (p )
114
258
}
115
259
116
260
func (w * wrappedSSHConn ) Write (p []byte ) (n int , err error ) {
117
- select {
118
- case <- w .ctx .Done ():
119
- return 0 , xerrors .Errorf ("write: %w" , w .ctx .Err ())
120
- default :
121
- return w .stdin .Write (p )
122
- }
261
+ return w .stdin .Write (p )
123
262
}
0 commit comments