Skip to content

Commit 7c198a6

Browse files
committed
fix cancellation
1 parent 5d4dc96 commit 7c198a6

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

scaletest/trafficgen/run.go

+10-7
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,12 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
5656
bytesPerTick = r.cfg.BytesPerSecond / r.cfg.TicksPerSecond
5757
)
5858

59+
// Set a deadline for stopping the text.
60+
start := time.Now()
61+
deadlineCtx, cancel := context.WithDeadline(ctx, start.Add(r.cfg.Duration))
62+
defer cancel()
5963
logger.Debug(ctx, "connect to workspace agent", slog.F("agent_id", agentID))
64+
6065
conn, err := r.client.WorkspaceAgentReconnectingPTY(ctx, codersdk.WorkspaceAgentReconnectingPTYOpts{
6166
AgentID: agentID,
6267
Reconnect: reconnect,
@@ -74,11 +79,6 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
7479
_ = conn.Close()
7580
}()
7681

77-
// Set a deadline for stopping the text.
78-
start := time.Now()
79-
deadlineCtx, cancel := context.WithDeadline(ctx, start.Add(r.cfg.Duration))
80-
defer cancel()
81-
8282
// Wrap the conn in a countReadWriter so we can monitor bytes sent/rcvd.
8383
crw := countReadWriter{ReadWriter: conn, ctx: deadlineCtx}
8484

@@ -112,7 +112,7 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
112112
close(wch)
113113
}()
114114

115-
// Wait for both our reads and writes to be finished.
115+
// Write until the context is canceled.
116116
if wErr := <-wch; wErr != nil {
117117
return xerrors.Errorf("write to pty: %w", wErr)
118118
}
@@ -138,7 +138,7 @@ func (*Runner) Cleanup(context.Context, string) error {
138138

139139
// drainContext drains from src until it returns io.EOF or ctx times out.
140140
func drainContext(ctx context.Context, src io.Reader, bufSize int64) error {
141-
errCh := make(chan error)
141+
errCh := make(chan error, 1)
142142
done := make(chan struct{})
143143
go func() {
144144
tmp := make([]byte, bufSize)
@@ -149,6 +149,9 @@ func drainContext(ctx context.Context, src io.Reader, bufSize int64) error {
149149
return
150150
default:
151151
_, err := io.CopyN(buf, src, 1)
152+
if ctx.Err() != nil {
153+
return // context canceled while we were copying.
154+
}
152155
if err != nil {
153156
errCh <- err
154157
close(errCh)

0 commit comments

Comments
 (0)