Skip to content

Commit 731b4db

Browse files
committed
remove unnecessary context-based I/O
1 parent b105e67 commit 731b4db

File tree

1 file changed

+24
-83
lines changed

1 file changed

+24
-83
lines changed

scaletest/trafficgen/run.go

Lines changed: 24 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99

1010
"github.com/google/uuid"
1111
"golang.org/x/xerrors"
12+
"nhooyr.io/websocket"
1213

1314
"cdr.dev/slog"
1415
"cdr.dev/slog/sloggers/sloghuman"
@@ -101,22 +102,22 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
101102

102103
go func() {
103104
<-deadlineCtx.Done()
104-
logger.Debug(ctx, "context deadline reached", slog.F("duration", time.Since(start)))
105+
logger.Debug(ctx, "closing agent connection")
106+
conn.Close()
105107
}()
106108

107109
// Read forever in the background.
108110
go func() {
109111
logger.Debug(ctx, "reading from agent", slog.F("agent_id", agentID))
110-
rch <- drainContext(deadlineCtx, &crw)
112+
rch <- drain(&crw)
111113
logger.Debug(ctx, "done reading from agent", slog.F("agent_id", agentID))
112-
conn.Close()
113114
close(rch)
114115
}()
115116

116117
// Write random data to the PTY every tick.
117118
go func() {
118119
logger.Debug(ctx, "writing to agent", slog.F("agent_id", agentID))
119-
wch <- writeRandomData(deadlineCtx, &crw, bytesPerTick, tick.C)
120+
wch <- writeRandomData(&crw, bytesPerTick, tick.C)
120121
logger.Debug(ctx, "done writing to agent", slog.F("agent_id", agentID))
121122
close(wch)
122123
}()
@@ -145,93 +146,33 @@ func (*Runner) Cleanup(context.Context, string) error {
145146
return nil
146147
}
147148

148-
// drainContext drains from src until it returns io.EOF or ctx times out.
149-
func drainContext(ctx context.Context, src io.Reader) error {
150-
errCh := make(chan error, 1)
151-
done := make(chan struct{})
152-
go func() {
153-
for {
154-
select {
155-
case <-done:
156-
return
157-
default:
158-
_, err := io.CopyN(io.Discard, src, 1)
159-
if ctx.Err() != nil {
160-
return // context canceled while we were copying.
161-
}
162-
if err != nil {
163-
errCh <- err
164-
close(errCh)
165-
return
166-
}
167-
}
168-
}
169-
}()
170-
for {
171-
select {
172-
case <-ctx.Done():
173-
close(done)
174-
return nil
175-
case err := <-errCh:
176-
if err != nil {
177-
if xerrors.Is(err, io.EOF) {
178-
return nil
179-
}
180-
// It's OK if the context is canceled.
181-
if xerrors.Is(err, context.DeadlineExceeded) {
182-
return nil
183-
}
184-
return err
185-
}
186-
}
187-
}
188-
}
189-
190-
func writeRandomData(ctx context.Context, dst io.Writer, size int64, tick <-chan time.Time) error {
191-
for {
192-
select {
193-
case <-ctx.Done():
149+
// drain drains from src until it returns io.EOF or ctx times out.
150+
func drain(src io.Reader) error {
151+
if _, err := io.Copy(io.Discard, src); err != nil {
152+
if xerrors.Is(err, context.DeadlineExceeded) || xerrors.Is(err, websocket.CloseError{}) {
194153
return nil
195-
case <-tick:
196-
payload := "#" + mustRandStr(size-1)
197-
data, err := json.Marshal(codersdk.ReconnectingPTYRequest{
198-
Data: payload,
199-
})
200-
if err != nil {
201-
return err
202-
}
203-
if _, err := copyContext(ctx, dst, data); err != nil {
204-
return err
205-
}
206154
}
155+
return err
207156
}
157+
return nil
208158
}
209159

210-
// copyContext copies from src to dst until ctx is canceled.
211-
func copyContext(ctx context.Context, dst io.Writer, src []byte) (int, error) {
212-
var count int
213-
for {
214-
select {
215-
case <-ctx.Done():
216-
return count, nil
217-
default:
218-
for idx := range src {
219-
n, err := dst.Write(src[idx : idx+1])
220-
if err != nil {
221-
if xerrors.Is(err, io.EOF) {
222-
return count, nil
223-
}
224-
if xerrors.Is(err, context.DeadlineExceeded) {
225-
// It's OK if we reach the deadline before writing the full payload.
226-
return count, nil
227-
}
228-
return count, err
229-
}
230-
count += n
160+
func writeRandomData(dst io.Writer, size int64, tick <-chan time.Time) error {
161+
var (
162+
enc = json.NewEncoder(dst)
163+
ptyReq = codersdk.ReconnectingPTYRequest{}
164+
)
165+
for range tick {
166+
payload := "#" + mustRandStr(size-1)
167+
ptyReq.Data = payload
168+
if err := enc.Encode(ptyReq); err != nil {
169+
if xerrors.Is(err, context.DeadlineExceeded) || xerrors.Is(err, websocket.CloseError{}) {
170+
return nil
231171
}
232-
return count, nil
172+
return err
233173
}
234174
}
175+
return nil
235176
}
236177

237178
// countReadWriter wraps an io.ReadWriter and counts the number of bytes read and written.

0 commit comments

Comments
 (0)