Skip to content

Commit da935a2

Browse files
committed
drain connection async
1 parent 0bfa9f6 commit da935a2

File tree

2 files changed

+58
-22
lines changed

2 files changed

+58
-22
lines changed

cli/scaletest.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -975,7 +975,8 @@ func (r *RootCmd) scaletestTrafficGen() *clibase.Cmd {
975975
}
976976

977977
if agentID == uuid.Nil {
978-
return xerrors.Errorf("no agent found for workspace %s", ws.Name)
978+
_, _ = fmt.Fprintf(inv.Stderr, "WARN: skipping workspace %s: no agent\n", ws.Name)
979+
continue
979980
}
980981

981982
// Setup our workspace agent connection.

scaletest/trafficgen/run.go

+56-21
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package trafficgen
22

33
import (
4+
"bytes"
45
"context"
56
"encoding/json"
67
"io"
@@ -12,6 +13,7 @@ import (
1213

1314
"cdr.dev/slog"
1415
"cdr.dev/slog/sloggers/sloghuman"
16+
1517
"github.com/coder/coder/coderd/tracing"
1618
"github.com/coder/coder/codersdk"
1719
"github.com/coder/coder/cryptorand"
@@ -72,14 +74,14 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
7274
_ = conn.Close()
7375
}()
7476

75-
// Wrap the conn in a countReadWriter so we can monitor bytes sent/rcvd.
76-
crw := countReadWriter{ReadWriter: conn}
77-
7877
// Set a deadline for stopping the text.
7978
start := time.Now()
8079
deadlineCtx, cancel := context.WithDeadline(ctx, start.Add(r.cfg.Duration))
8180
defer cancel()
8281

82+
// Wrap the conn in a countReadWriter so we can monitor bytes sent/rcvd.
83+
crw := countReadWriter{ReadWriter: conn, ctx: deadlineCtx}
84+
8385
// Create a ticker for sending data to the PTY.
8486
tick := time.NewTicker(time.Duration(tickInterval))
8587
defer tick.Stop()
@@ -88,10 +90,15 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
8890
rch := make(chan error)
8991
wch := make(chan error)
9092

93+
go func() {
94+
<-deadlineCtx.Done()
95+
logger.Debug(ctx, "context deadline reached", slog.F("duration", time.Since(start)))
96+
}()
97+
9198
// Read forever in the background.
9299
go func() {
93100
logger.Debug(ctx, "reading from agent", slog.F("agent_id", agentID))
94-
rch <- readContext(deadlineCtx, &crw, bytesPerTick*2)
101+
rch <- drainContext(deadlineCtx, &crw, bytesPerTick*2)
95102
logger.Debug(ctx, "done reading from agent", slog.F("agent_id", agentID))
96103
conn.Close()
97104
close(rch)
@@ -115,7 +122,7 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
115122

116123
duration := time.Since(start)
117124

118-
logger.Info(ctx, "trafficgen result",
125+
logger.Info(ctx, "results",
119126
slog.F("duration", duration),
120127
slog.F("sent", crw.BytesWritten()),
121128
slog.F("rcvd", crw.BytesRead()),
@@ -129,14 +136,33 @@ func (*Runner) Cleanup(context.Context, string) error {
129136
return nil
130137
}
131138

132-
func readContext(ctx context.Context, src io.Reader, bufSize int64) error {
133-
buf := make([]byte, bufSize)
139+
// drainContext drains from src until it returns io.EOF or ctx times out.
140+
func drainContext(ctx context.Context, src io.Reader, bufSize int64) error {
141+
errCh := make(chan error)
142+
done := make(chan struct{})
143+
go func() {
144+
tmp := make([]byte, bufSize)
145+
buf := bytes.NewBuffer(tmp)
146+
for {
147+
select {
148+
case <-done:
149+
return
150+
default:
151+
_, err := io.CopyN(buf, src, 1)
152+
if err != nil {
153+
errCh <- err
154+
close(errCh)
155+
return
156+
}
157+
}
158+
}
159+
}()
134160
for {
135161
select {
136162
case <-ctx.Done():
163+
close(done)
137164
return nil
138-
default:
139-
_, err := src.Read(buf)
165+
case err := <-errCh:
140166
if err != nil {
141167
if xerrors.Is(err, io.EOF) {
142168
return nil
@@ -175,31 +201,37 @@ func copyContext(ctx context.Context, dst io.Writer, src []byte) (int, error) {
175201
case <-ctx.Done():
176202
return count, nil
177203
default:
178-
n, err := dst.Write(src)
179-
if err != nil {
180-
if xerrors.Is(err, io.EOF) {
181-
// On an EOF, assume that all of src was consumed.
182-
return len(src), nil
204+
for idx := range src {
205+
n, err := dst.Write(src[idx : idx+1])
206+
if err != nil {
207+
if xerrors.Is(err, io.EOF) {
208+
return count, nil
209+
}
210+
if xerrors.Is(err, context.DeadlineExceeded) {
211+
// It's OK if we reach the deadline before writing the full payload.
212+
return count, nil
213+
}
214+
return count, err
183215
}
184-
return count, err
185-
}
186-
count += n
187-
if n == len(src) {
188-
return count, nil
216+
count += n
189217
}
190-
// Not all of src was consumed. Update src and retry.
191-
src = src[n:]
218+
return count, nil
192219
}
193220
}
194221
}
195222

223+
// countReadWriter wraps an io.ReadWriter and counts the number of bytes read and written.
196224
type countReadWriter struct {
225+
ctx context.Context
197226
io.ReadWriter
198227
bytesRead atomic.Int64
199228
bytesWritten atomic.Int64
200229
}
201230

202231
func (w *countReadWriter) Read(p []byte) (int, error) {
232+
if err := w.ctx.Err(); err != nil {
233+
return 0, err
234+
}
203235
n, err := w.ReadWriter.Read(p)
204236
if err == nil {
205237
w.bytesRead.Add(int64(n))
@@ -208,6 +240,9 @@ func (w *countReadWriter) Read(p []byte) (int, error) {
208240
}
209241

210242
func (w *countReadWriter) Write(p []byte) (int, error) {
243+
if err := w.ctx.Err(); err != nil {
244+
return 0, err
245+
}
211246
n, err := w.ReadWriter.Write(p)
212247
if err == nil {
213248
w.bytesWritten.Add(int64(n))

0 commit comments

Comments
 (0)