|
9 | 9 |
|
10 | 10 | "github.com/google/uuid"
|
11 | 11 | "golang.org/x/xerrors"
|
| 12 | + "nhooyr.io/websocket" |
12 | 13 |
|
13 | 14 | "cdr.dev/slog"
|
14 | 15 | "cdr.dev/slog/sloggers/sloghuman"
|
@@ -101,22 +102,22 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
|
101 | 102 |
|
102 | 103 | go func() {
|
103 | 104 | <-deadlineCtx.Done()
|
104 |
| - logger.Debug(ctx, "context deadline reached", slog.F("duration", time.Since(start))) |
| 105 | + logger.Debug(ctx, "closing agent connection") |
| 106 | + conn.Close() |
105 | 107 | }()
|
106 | 108 |
|
107 | 109 | // Read forever in the background.
|
108 | 110 | go func() {
|
109 | 111 | logger.Debug(ctx, "reading from agent", slog.F("agent_id", agentID))
|
110 |
| - rch <- drainContext(deadlineCtx, &crw) |
| 112 | + rch <- drain(&crw) |
111 | 113 | logger.Debug(ctx, "done reading from agent", slog.F("agent_id", agentID))
|
112 |
| - conn.Close() |
113 | 114 | close(rch)
|
114 | 115 | }()
|
115 | 116 |
|
116 | 117 | // Write random data to the PTY every tick.
|
117 | 118 | go func() {
|
118 | 119 | logger.Debug(ctx, "writing to agent", slog.F("agent_id", agentID))
|
119 |
| - wch <- writeRandomData(deadlineCtx, &crw, bytesPerTick, tick.C) |
| 120 | + wch <- writeRandomData(&crw, bytesPerTick, tick.C) |
120 | 121 | logger.Debug(ctx, "done writing to agent", slog.F("agent_id", agentID))
|
121 | 122 | close(wch)
|
122 | 123 | }()
|
@@ -145,93 +146,33 @@ func (*Runner) Cleanup(context.Context, string) error {
|
145 | 146 | return nil
|
146 | 147 | }
|
147 | 148 |
|
148 |
| -// drainContext drains from src until it returns io.EOF or ctx times out. |
149 |
| -func drainContext(ctx context.Context, src io.Reader) error { |
150 |
| - errCh := make(chan error, 1) |
151 |
| - done := make(chan struct{}) |
152 |
| - go func() { |
153 |
| - for { |
154 |
| - select { |
155 |
| - case <-done: |
156 |
| - return |
157 |
| - default: |
158 |
| - _, err := io.CopyN(io.Discard, src, 1) |
159 |
| - if ctx.Err() != nil { |
160 |
| - return // context canceled while we were copying. |
161 |
| - } |
162 |
| - if err != nil { |
163 |
| - errCh <- err |
164 |
| - close(errCh) |
165 |
| - return |
166 |
| - } |
167 |
| - } |
168 |
| - } |
169 |
| - }() |
170 |
| - for { |
171 |
| - select { |
172 |
| - case <-ctx.Done(): |
173 |
| - close(done) |
174 |
| - return nil |
175 |
| - case err := <-errCh: |
176 |
| - if err != nil { |
177 |
| - if xerrors.Is(err, io.EOF) { |
178 |
| - return nil |
179 |
| - } |
180 |
| - // It's OK if the context is canceled. |
181 |
| - if xerrors.Is(err, context.DeadlineExceeded) { |
182 |
| - return nil |
183 |
| - } |
184 |
| - return err |
185 |
| - } |
186 |
| - } |
187 |
| - } |
188 |
| -} |
189 |
| - |
190 |
| -func writeRandomData(ctx context.Context, dst io.Writer, size int64, tick <-chan time.Time) error { |
191 |
| - for { |
192 |
| - select { |
193 |
| - case <-ctx.Done(): |
| 149 | +// drain drains from src until it returns io.EOF or ctx times out. |
| 150 | +func drain(src io.Reader) error { |
| 151 | + if _, err := io.Copy(io.Discard, src); err != nil { |
| 152 | + if xerrors.Is(err, context.DeadlineExceeded) || xerrors.Is(err, websocket.CloseError{}) { |
194 | 153 | return nil
|
195 |
| - case <-tick: |
196 |
| - payload := "#" + mustRandStr(size-1) |
197 |
| - data, err := json.Marshal(codersdk.ReconnectingPTYRequest{ |
198 |
| - Data: payload, |
199 |
| - }) |
200 |
| - if err != nil { |
201 |
| - return err |
202 |
| - } |
203 |
| - if _, err := copyContext(ctx, dst, data); err != nil { |
204 |
| - return err |
205 |
| - } |
206 | 154 | }
|
| 155 | + return err |
207 | 156 | }
|
| 157 | + return nil |
208 | 158 | }
|
209 | 159 |
|
210 |
| -// copyContext copies from src to dst until ctx is canceled. |
211 |
| -func copyContext(ctx context.Context, dst io.Writer, src []byte) (int, error) { |
212 |
| - var count int |
213 |
| - for { |
214 |
| - select { |
215 |
| - case <-ctx.Done(): |
216 |
| - return count, nil |
217 |
| - default: |
218 |
| - for idx := range src { |
219 |
| - n, err := dst.Write(src[idx : idx+1]) |
220 |
| - if err != nil { |
221 |
| - if xerrors.Is(err, io.EOF) { |
222 |
| - return count, nil |
223 |
| - } |
224 |
| - if xerrors.Is(err, context.DeadlineExceeded) { |
225 |
| - // It's OK if we reach the deadline before writing the full payload. |
226 |
| - return count, nil |
227 |
| - } |
228 |
| - return count, err |
229 |
| - } |
230 |
| - count += n |
| 160 | +func writeRandomData(dst io.Writer, size int64, tick <-chan time.Time) error { |
| 161 | + var ( |
| 162 | + enc = json.NewEncoder(dst) |
| 163 | + ptyReq = codersdk.ReconnectingPTYRequest{} |
| 164 | + ) |
| 165 | + for range tick { |
| 166 | + payload := "#" + mustRandStr(size-1) |
| 167 | + ptyReq.Data = payload |
| 168 | + if err := enc.Encode(ptyReq); err != nil { |
| 169 | + if xerrors.Is(err, context.DeadlineExceeded) || xerrors.Is(err, websocket.CloseError{}) { |
| 170 | + return nil |
231 | 171 | }
|
232 |
| - return count, nil |
| 172 | + return err |
233 | 173 | }
|
234 | 174 | }
|
| 175 | + return nil |
235 | 176 | }
|
236 | 177 |
|
237 | 178 | // countReadWriter wraps an io.ReadWriter and counts the number of bytes read and written.
|
|
0 commit comments