1
1
package trafficgen
2
2
3
3
import (
4
+ "bytes"
4
5
"context"
5
6
"encoding/json"
6
7
"io"
@@ -12,6 +13,7 @@ import (
12
13
13
14
"cdr.dev/slog"
14
15
"cdr.dev/slog/sloggers/sloghuman"
16
+
15
17
"github.com/coder/coder/coderd/tracing"
16
18
"github.com/coder/coder/codersdk"
17
19
"github.com/coder/coder/cryptorand"
@@ -72,14 +74,14 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
72
74
_ = conn .Close ()
73
75
}()
74
76
75
- // Wrap the conn in a countReadWriter so we can monitor bytes sent/rcvd.
76
- crw := countReadWriter {ReadWriter : conn }
77
-
78
77
// Set a deadline for stopping the text.
79
78
start := time .Now ()
80
79
deadlineCtx , cancel := context .WithDeadline (ctx , start .Add (r .cfg .Duration ))
81
80
defer cancel ()
82
81
82
+ // Wrap the conn in a countReadWriter so we can monitor bytes sent/rcvd.
83
+ crw := countReadWriter {ReadWriter : conn , ctx : deadlineCtx }
84
+
83
85
// Create a ticker for sending data to the PTY.
84
86
tick := time .NewTicker (time .Duration (tickInterval ))
85
87
defer tick .Stop ()
@@ -88,10 +90,15 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
88
90
rch := make (chan error )
89
91
wch := make (chan error )
90
92
93
+ go func () {
94
+ <- deadlineCtx .Done ()
95
+ logger .Debug (ctx , "context deadline reached" , slog .F ("duration" , time .Since (start )))
96
+ }()
97
+
91
98
// Read forever in the background.
92
99
go func () {
93
100
logger .Debug (ctx , "reading from agent" , slog .F ("agent_id" , agentID ))
94
- rch <- readContext (deadlineCtx , & crw , bytesPerTick * 2 )
101
+ rch <- drainContext (deadlineCtx , & crw , bytesPerTick * 2 )
95
102
logger .Debug (ctx , "done reading from agent" , slog .F ("agent_id" , agentID ))
96
103
conn .Close ()
97
104
close (rch )
@@ -115,7 +122,7 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
115
122
116
123
duration := time .Since (start )
117
124
118
- logger .Info (ctx , "trafficgen result " ,
125
+ logger .Info (ctx , "results " ,
119
126
slog .F ("duration" , duration ),
120
127
slog .F ("sent" , crw .BytesWritten ()),
121
128
slog .F ("rcvd" , crw .BytesRead ()),
@@ -129,14 +136,33 @@ func (*Runner) Cleanup(context.Context, string) error {
129
136
return nil
130
137
}
131
138
132
- func readContext (ctx context.Context , src io.Reader , bufSize int64 ) error {
133
- buf := make ([]byte , bufSize )
139
+ // drainContext drains from src until it returns io.EOF or ctx times out.
140
+ func drainContext (ctx context.Context , src io.Reader , bufSize int64 ) error {
141
+ errCh := make (chan error )
142
+ done := make (chan struct {})
143
+ go func () {
144
+ tmp := make ([]byte , bufSize )
145
+ buf := bytes .NewBuffer (tmp )
146
+ for {
147
+ select {
148
+ case <- done :
149
+ return
150
+ default :
151
+ _ , err := io .CopyN (buf , src , 1 )
152
+ if err != nil {
153
+ errCh <- err
154
+ close (errCh )
155
+ return
156
+ }
157
+ }
158
+ }
159
+ }()
134
160
for {
135
161
select {
136
162
case <- ctx .Done ():
163
+ close (done )
137
164
return nil
138
- default :
139
- _ , err := src .Read (buf )
165
+ case err := <- errCh :
140
166
if err != nil {
141
167
if xerrors .Is (err , io .EOF ) {
142
168
return nil
@@ -175,31 +201,37 @@ func copyContext(ctx context.Context, dst io.Writer, src []byte) (int, error) {
175
201
case <- ctx .Done ():
176
202
return count , nil
177
203
default :
178
- n , err := dst .Write (src )
179
- if err != nil {
180
- if xerrors .Is (err , io .EOF ) {
181
- // On an EOF, assume that all of src was consumed.
182
- return len (src ), nil
204
+ for idx := range src {
205
+ n , err := dst .Write (src [idx : idx + 1 ])
206
+ if err != nil {
207
+ if xerrors .Is (err , io .EOF ) {
208
+ return count , nil
209
+ }
210
+ if xerrors .Is (err , context .DeadlineExceeded ) {
211
+ // It's OK if we reach the deadline before writing the full payload.
212
+ return count , nil
213
+ }
214
+ return count , err
183
215
}
184
- return count , err
185
- }
186
- count += n
187
- if n == len (src ) {
188
- return count , nil
216
+ count += n
189
217
}
190
- // Not all of src was consumed. Update src and retry.
191
- src = src [n :]
218
+ return count , nil
192
219
}
193
220
}
194
221
}
195
222
223
+ // countReadWriter wraps an io.ReadWriter and counts the number of bytes read and written.
196
224
type countReadWriter struct {
225
+ ctx context.Context
197
226
io.ReadWriter
198
227
bytesRead atomic.Int64
199
228
bytesWritten atomic.Int64
200
229
}
201
230
202
231
func (w * countReadWriter ) Read (p []byte ) (int , error ) {
232
+ if err := w .ctx .Err (); err != nil {
233
+ return 0 , err
234
+ }
203
235
n , err := w .ReadWriter .Read (p )
204
236
if err == nil {
205
237
w .bytesRead .Add (int64 (n ))
@@ -208,6 +240,9 @@ func (w *countReadWriter) Read(p []byte) (int, error) {
208
240
}
209
241
210
242
func (w * countReadWriter ) Write (p []byte ) (int , error ) {
243
+ if err := w .ctx .Err (); err != nil {
244
+ return 0 , err
245
+ }
211
246
n , err := w .ReadWriter .Write (p )
212
247
if err == nil {
213
248
w .bytesWritten .Add (int64 (n ))
0 commit comments