@@ -12,15 +12,32 @@ import (
12
12
"golang.org/x/xerrors"
13
13
14
14
"github.com/coder/coder/cli/clibase"
15
+ "github.com/coder/coder/cli/cliui"
15
16
"github.com/coder/coder/codersdk"
16
17
"github.com/coder/coder/cryptorand"
17
18
)
18
19
20
+ type trafficGenOutput struct {
21
+ DurationSeconds float64 `json:"duration_s"`
22
+ SentBytes int64 `json:"sent_bytes"`
23
+ RcvdBytes int64 `json:"rcvd_bytes"`
24
+ }
25
+
26
+ func (o trafficGenOutput ) String () string {
27
+ return fmt .Sprintf ("Duration: %.2fs\n " , o .DurationSeconds ) +
28
+ fmt .Sprintf ("Sent: %dB\n " , o .SentBytes ) +
29
+ fmt .Sprintf ("Rcvd: %dB" , o .RcvdBytes )
30
+ }
31
+
19
32
func (r * RootCmd ) trafficGen () * clibase.Cmd {
20
33
var (
21
- duration time.Duration
22
- bps int64
23
- client = new (codersdk.Client )
34
+ duration time.Duration
35
+ formatter = cliui .NewOutputFormatter (
36
+ cliui .TextFormat (),
37
+ cliui .JSONFormat (),
38
+ )
39
+ bps int64
40
+ client = new (codersdk.Client )
24
41
)
25
42
26
43
cmd := & clibase.Cmd {
@@ -32,7 +49,10 @@ func (r *RootCmd) trafficGen() *clibase.Cmd {
32
49
r .InitClient (client ),
33
50
),
34
51
Handler : func (inv * clibase.Invocation ) error {
35
- var agentName string
52
+ var (
53
+ agentName string
54
+ tickInterval = 100 * time .Millisecond
55
+ )
36
56
ws , err := namedWorkspace (inv .Context (), client , inv .Args [0 ])
37
57
if err != nil {
38
58
return err
@@ -53,6 +73,7 @@ func (r *RootCmd) trafficGen() *clibase.Cmd {
53
73
return xerrors .Errorf ("no agent found for workspace %s" , ws .Name )
54
74
}
55
75
76
+ // Setup our workspace agent connection.
56
77
reconnect := uuid .New ()
57
78
conn , err := client .WorkspaceAgentReconnectingPTY (inv .Context (), codersdk.WorkspaceAgentReconnectingPTYOpts {
58
79
AgentID : agentID ,
@@ -68,46 +89,60 @@ func (r *RootCmd) trafficGen() *clibase.Cmd {
68
89
defer func () {
69
90
_ = conn .Close ()
70
91
}()
92
+
93
+ // Wrap the conn in a countReadWriter so we can monitor bytes sent/rcvd.
94
+ crw := countReadWriter {ReadWriter : conn }
95
+
96
+ // Set a deadline for stopping the text.
71
97
start := time .Now ()
72
- ctx , cancel := context .WithDeadline (inv .Context (), start .Add (duration ))
98
+ deadlineCtx , cancel := context .WithDeadline (inv .Context (), start .Add (duration ))
73
99
defer cancel ()
74
- crw := countReadWriter {ReadWriter : conn }
75
- // First, write a comment to the pty so we don't execute anything.
76
- data , err := json .Marshal (codersdk.ReconnectingPTYRequest {
77
- Data : "#" ,
78
- })
79
- if err != nil {
80
- return xerrors .Errorf ("serialize request: %w" , err )
81
- }
82
- _ , err = crw .Write (data )
83
- if err != nil {
84
- return xerrors .Errorf ("write comment to pty: %w" , err )
85
- }
100
+
101
+ // Create a ticker for sending data to the PTY.
102
+ tick := time .NewTicker (tickInterval )
103
+ defer tick .Stop ()
104
+
86
105
// Now we begin writing random data to the pty.
87
106
writeSize := int (bps / 10 )
88
107
rch := make (chan error )
89
108
wch := make (chan error )
109
+
110
+ // Read forever in the background.
90
111
go func () {
91
- rch <- readForever (ctx , & crw )
112
+ rch <- readContext (deadlineCtx , & crw , writeSize * 2 )
113
+ conn .Close ()
92
114
close (rch )
93
115
}()
116
+
117
+ // Write random data to the PTY every tick.
94
118
go func () {
95
- wch <- writeRandomData (ctx , & crw , writeSize , 100 * time . Millisecond )
119
+ wch <- writeRandomData (deadlineCtx , & crw , writeSize , tick . C )
96
120
close (wch )
97
121
}()
98
122
123
+ // Wait for both our reads and writes to be finished.
99
124
if wErr := <- wch ; wErr != nil {
100
125
return xerrors .Errorf ("write to pty: %w" , wErr )
101
126
}
102
127
if rErr := <- rch ; rErr != nil {
103
128
return xerrors .Errorf ("read from pty: %w" , rErr )
104
129
}
105
130
106
- _ , _ = fmt .Fprintf (inv .Stdout , "Test results:\n " )
107
- _ , _ = fmt .Fprintf (inv .Stdout , "Took: %.2fs\n " , time .Since (start ).Seconds ())
108
- _ , _ = fmt .Fprintf (inv .Stdout , "Sent: %d bytes\n " , crw .BytesWritten ())
109
- _ , _ = fmt .Fprintf (inv .Stdout , "Rcvd: %d bytes\n " , crw .BytesRead ())
110
- return nil
131
+ duration := time .Since (start )
132
+
133
+ results := trafficGenOutput {
134
+ DurationSeconds : duration .Seconds (),
135
+ SentBytes : crw .BytesWritten (),
136
+ RcvdBytes : crw .BytesRead (),
137
+ }
138
+
139
+ out , err := formatter .Format (inv .Context (), results )
140
+ if err != nil {
141
+ return err
142
+ }
143
+
144
+ _ , err = fmt .Fprintln (inv .Stdout , out )
145
+ return err
111
146
},
112
147
}
113
148
@@ -128,66 +163,78 @@ func (r *RootCmd) trafficGen() *clibase.Cmd {
128
163
},
129
164
}
130
165
166
+ formatter .AttachOptions (& cmd .Options )
131
167
return cmd
132
168
}
133
169
134
- func readForever (ctx context.Context , src io.Reader ) error {
135
- buf := make ([]byte , 1024 )
170
+ func readContext (ctx context.Context , src io.Reader , bufSize int ) error {
171
+ buf := make ([]byte , bufSize )
136
172
for {
137
173
select {
138
174
case <- ctx .Done ():
139
175
return nil
140
176
default :
177
+ if ctx .Err () != nil {
178
+ return nil
179
+ }
141
180
_ , err := src .Read (buf )
142
- if err != nil && err != io .EOF {
181
+ if err != nil {
182
+ if xerrors .Is (err , io .EOF ) {
183
+ return nil
184
+ }
143
185
return err
144
186
}
145
187
}
146
188
}
147
189
}
148
190
149
- func writeRandomData (ctx context.Context , dst io.Writer , size int , period time.Duration ) error {
150
- tick := time .NewTicker (period )
151
- defer tick .Stop ()
191
+ func writeRandomData (ctx context.Context , dst io.Writer , size int , tick <- chan time.Time ) error {
152
192
for {
153
193
select {
154
194
case <- ctx .Done ():
155
195
return nil
156
- case <- tick .C :
157
- randStr , err := cryptorand .String (size )
158
- if err != nil {
159
- return err
160
- }
196
+ case <- tick :
197
+ payload := "#" + mustRandStr (size - 1 )
161
198
data , err := json .Marshal (codersdk.ReconnectingPTYRequest {
162
- Data : randStr ,
199
+ Data : payload ,
163
200
})
164
201
if err != nil {
165
202
return err
166
203
}
167
- err = copyContext (ctx , dst , data )
168
- if err != nil {
204
+ if _ , err := copyContext (ctx , dst , data ); err != nil {
169
205
return err
170
206
}
171
207
}
172
208
}
173
209
}
174
210
175
- func copyContext (ctx context.Context , dst io.Writer , src []byte ) error {
176
- for idx := range src {
211
+ // copyContext copies from src to dst until ctx is canceled.
212
+ func copyContext (ctx context.Context , dst io.Writer , src []byte ) (int , error ) {
213
+ var count int
214
+ for {
177
215
select {
178
216
case <- ctx .Done ():
179
- return nil
217
+ return count , nil
180
218
default :
181
- _ , err := dst .Write (src [idx : idx + 1 ])
219
+ if ctx .Err () != nil {
220
+ return count , nil
221
+ }
222
+ n , err := dst .Write (src )
182
223
if err != nil {
183
224
if xerrors .Is (err , io .EOF ) {
184
- return nil
225
+ // On an EOF, assume that all of src was consumed.
226
+ return len (src ), nil
185
227
}
186
- return err
228
+ return count , err
229
+ }
230
+ count += n
231
+ if n == len (src ) {
232
+ return count , nil
187
233
}
234
+ // Not all of src was consumed. Update src and retry.
235
+ src = src [n :]
188
236
}
189
237
}
190
- return nil
191
238
}
192
239
193
240
type countReadWriter struct {
@@ -219,3 +266,11 @@ func (w *countReadWriter) BytesRead() int64 {
219
266
func (w * countReadWriter ) BytesWritten () int64 {
220
267
return w .bytesWritten .Load ()
221
268
}
269
+
270
+ func mustRandStr (len int ) string {
271
+ randStr , err := cryptorand .String (len )
272
+ if err != nil {
273
+ panic (err )
274
+ }
275
+ return randStr
276
+ }
0 commit comments