Skip to content

Commit b0a641d

Browse files
committed
close connection, add output formatting
1 parent 04bc551 commit b0a641d

File tree

2 files changed

+112
-56
lines changed

2 files changed

+112
-56
lines changed

cli/trafficgen.go

+100-45
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,32 @@ import (
1212
"golang.org/x/xerrors"
1313

1414
"github.com/coder/coder/cli/clibase"
15+
"github.com/coder/coder/cli/cliui"
1516
"github.com/coder/coder/codersdk"
1617
"github.com/coder/coder/cryptorand"
1718
)
1819

20+
type trafficGenOutput struct {
21+
DurationSeconds float64 `json:"duration_s"`
22+
SentBytes int64 `json:"sent_bytes"`
23+
RcvdBytes int64 `json:"rcvd_bytes"`
24+
}
25+
26+
func (o trafficGenOutput) String() string {
27+
return fmt.Sprintf("Duration: %.2fs\n", o.DurationSeconds) +
28+
fmt.Sprintf("Sent: %dB\n", o.SentBytes) +
29+
fmt.Sprintf("Rcvd: %dB", o.RcvdBytes)
30+
}
31+
1932
func (r *RootCmd) trafficGen() *clibase.Cmd {
2033
var (
21-
duration time.Duration
22-
bps int64
23-
client = new(codersdk.Client)
34+
duration time.Duration
35+
formatter = cliui.NewOutputFormatter(
36+
cliui.TextFormat(),
37+
cliui.JSONFormat(),
38+
)
39+
bps int64
40+
client = new(codersdk.Client)
2441
)
2542

2643
cmd := &clibase.Cmd{
@@ -32,7 +49,10 @@ func (r *RootCmd) trafficGen() *clibase.Cmd {
3249
r.InitClient(client),
3350
),
3451
Handler: func(inv *clibase.Invocation) error {
35-
var agentName string
52+
var (
53+
agentName string
54+
tickInterval = 100 * time.Millisecond
55+
)
3656
ws, err := namedWorkspace(inv.Context(), client, inv.Args[0])
3757
if err != nil {
3858
return err
@@ -53,6 +73,7 @@ func (r *RootCmd) trafficGen() *clibase.Cmd {
5373
return xerrors.Errorf("no agent found for workspace %s", ws.Name)
5474
}
5575

76+
// Setup our workspace agent connection.
5677
reconnect := uuid.New()
5778
conn, err := client.WorkspaceAgentReconnectingPTY(inv.Context(), codersdk.WorkspaceAgentReconnectingPTYOpts{
5879
AgentID: agentID,
@@ -68,46 +89,60 @@ func (r *RootCmd) trafficGen() *clibase.Cmd {
6889
defer func() {
6990
_ = conn.Close()
7091
}()
92+
93+
// Wrap the conn in a countReadWriter so we can monitor bytes sent/rcvd.
94+
crw := countReadWriter{ReadWriter: conn}
95+
96+
// Set a deadline for stopping the text.
7197
start := time.Now()
72-
ctx, cancel := context.WithDeadline(inv.Context(), start.Add(duration))
98+
deadlineCtx, cancel := context.WithDeadline(inv.Context(), start.Add(duration))
7399
defer cancel()
74-
crw := countReadWriter{ReadWriter: conn}
75-
// First, write a comment to the pty so we don't execute anything.
76-
data, err := json.Marshal(codersdk.ReconnectingPTYRequest{
77-
Data: "#",
78-
})
79-
if err != nil {
80-
return xerrors.Errorf("serialize request: %w", err)
81-
}
82-
_, err = crw.Write(data)
83-
if err != nil {
84-
return xerrors.Errorf("write comment to pty: %w", err)
85-
}
100+
101+
// Create a ticker for sending data to the PTY.
102+
tick := time.NewTicker(tickInterval)
103+
defer tick.Stop()
104+
86105
// Now we begin writing random data to the pty.
87106
writeSize := int(bps / 10)
88107
rch := make(chan error)
89108
wch := make(chan error)
109+
110+
// Read forever in the background.
90111
go func() {
91-
rch <- readForever(ctx, &crw)
112+
rch <- readContext(deadlineCtx, &crw, writeSize*2)
113+
conn.Close()
92114
close(rch)
93115
}()
116+
117+
// Write random data to the PTY every tick.
94118
go func() {
95-
wch <- writeRandomData(ctx, &crw, writeSize, 100*time.Millisecond)
119+
wch <- writeRandomData(deadlineCtx, &crw, writeSize, tick.C)
96120
close(wch)
97121
}()
98122

123+
// Wait for both our reads and writes to be finished.
99124
if wErr := <-wch; wErr != nil {
100125
return xerrors.Errorf("write to pty: %w", wErr)
101126
}
102127
if rErr := <-rch; rErr != nil {
103128
return xerrors.Errorf("read from pty: %w", rErr)
104129
}
105130

106-
_, _ = fmt.Fprintf(inv.Stdout, "Test results:\n")
107-
_, _ = fmt.Fprintf(inv.Stdout, "Took: %.2fs\n", time.Since(start).Seconds())
108-
_, _ = fmt.Fprintf(inv.Stdout, "Sent: %d bytes\n", crw.BytesWritten())
109-
_, _ = fmt.Fprintf(inv.Stdout, "Rcvd: %d bytes\n", crw.BytesRead())
110-
return nil
131+
duration := time.Since(start)
132+
133+
results := trafficGenOutput{
134+
DurationSeconds: duration.Seconds(),
135+
SentBytes: crw.BytesWritten(),
136+
RcvdBytes: crw.BytesRead(),
137+
}
138+
139+
out, err := formatter.Format(inv.Context(), results)
140+
if err != nil {
141+
return err
142+
}
143+
144+
_, err = fmt.Fprintln(inv.Stdout, out)
145+
return err
111146
},
112147
}
113148

@@ -128,66 +163,78 @@ func (r *RootCmd) trafficGen() *clibase.Cmd {
128163
},
129164
}
130165

166+
formatter.AttachOptions(&cmd.Options)
131167
return cmd
132168
}
133169

134-
func readForever(ctx context.Context, src io.Reader) error {
135-
buf := make([]byte, 1024)
170+
func readContext(ctx context.Context, src io.Reader, bufSize int) error {
171+
buf := make([]byte, bufSize)
136172
for {
137173
select {
138174
case <-ctx.Done():
139175
return nil
140176
default:
177+
if ctx.Err() != nil {
178+
return nil
179+
}
141180
_, err := src.Read(buf)
142-
if err != nil && err != io.EOF {
181+
if err != nil {
182+
if xerrors.Is(err, io.EOF) {
183+
return nil
184+
}
143185
return err
144186
}
145187
}
146188
}
147189
}
148190

149-
func writeRandomData(ctx context.Context, dst io.Writer, size int, period time.Duration) error {
150-
tick := time.NewTicker(period)
151-
defer tick.Stop()
191+
func writeRandomData(ctx context.Context, dst io.Writer, size int, tick <-chan time.Time) error {
152192
for {
153193
select {
154194
case <-ctx.Done():
155195
return nil
156-
case <-tick.C:
157-
randStr, err := cryptorand.String(size)
158-
if err != nil {
159-
return err
160-
}
196+
case <-tick:
197+
payload := "#" + mustRandStr(size-1)
161198
data, err := json.Marshal(codersdk.ReconnectingPTYRequest{
162-
Data: randStr,
199+
Data: payload,
163200
})
164201
if err != nil {
165202
return err
166203
}
167-
err = copyContext(ctx, dst, data)
168-
if err != nil {
204+
if _, err := copyContext(ctx, dst, data); err != nil {
169205
return err
170206
}
171207
}
172208
}
173209
}
174210

175-
func copyContext(ctx context.Context, dst io.Writer, src []byte) error {
176-
for idx := range src {
211+
// copyContext copies from src to dst until ctx is canceled.
212+
func copyContext(ctx context.Context, dst io.Writer, src []byte) (int, error) {
213+
var count int
214+
for {
177215
select {
178216
case <-ctx.Done():
179-
return nil
217+
return count, nil
180218
default:
181-
_, err := dst.Write(src[idx : idx+1])
219+
if ctx.Err() != nil {
220+
return count, nil
221+
}
222+
n, err := dst.Write(src)
182223
if err != nil {
183224
if xerrors.Is(err, io.EOF) {
184-
return nil
225+
// On an EOF, assume that all of src was consumed.
226+
return len(src), nil
185227
}
186-
return err
228+
return count, err
229+
}
230+
count += n
231+
if n == len(src) {
232+
return count, nil
187233
}
234+
// Not all of src was consumed. Update src and retry.
235+
src = src[n:]
188236
}
189237
}
190-
return nil
191238
}
192239

193240
type countReadWriter struct {
@@ -219,3 +266,11 @@ func (w *countReadWriter) BytesRead() int64 {
219266
func (w *countReadWriter) BytesWritten() int64 {
220267
return w.bytesWritten.Load()
221268
}
269+
270+
func mustRandStr(len int) string {
271+
randStr, err := cryptorand.String(len)
272+
if err != nil {
273+
panic(err)
274+
}
275+
return randStr
276+
}

cli/trafficgen_test.go

+12-11
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package cli_test
33
import (
44
"bytes"
55
"context"
6-
"strings"
6+
"encoding/json"
77
"testing"
88

99
"github.com/google/uuid"
@@ -23,7 +23,6 @@ import (
2323
// We do not perform any cleanup.
2424
func TestTrafficGen(t *testing.T) {
2525
t.Parallel()
26-
t.Skip("TODO: this hangs in a unit test but works in the real world.")
2726

2827
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitMedium)
2928
defer cancelFunc()
@@ -74,20 +73,22 @@ func TestTrafficGen(t *testing.T) {
7473
inv, root := clitest.New(t, "trafficgen", ws.Name,
7574
"--duration", "1s",
7675
"--bps", "100",
76+
"-o", "json",
7777
)
7878
clitest.SetupConfig(t, client, root)
7979
var stdout, stderr bytes.Buffer
8080
inv.Stdout = &stdout
8181
inv.Stderr = &stderr
8282
err := inv.WithContext(ctx).Run()
8383
require.NoError(t, err)
84-
stdoutStr := stdout.String()
85-
stderrStr := stderr.String()
86-
require.Empty(t, stderrStr)
87-
lines := strings.Split(strings.TrimSpace(stdoutStr), "\n")
88-
require.Len(t, lines, 4)
89-
require.Equal(t, "Test results:", lines[0])
90-
require.Regexp(t, `Took:\s+\d+\.\d+s`, lines[1])
91-
require.Regexp(t, `Sent:\s+\d+ bytes`, lines[2])
92-
require.Regexp(t, `Rcvd:\s+\d+ bytes`, lines[3])
84+
// TODO: this struct is currently unexported. Put it somewhere better.
85+
var output struct {
86+
DurationSeconds float64 `json:"duration_s"`
87+
SentBytes int64 `json:"sent_bytes"`
88+
RcvdBytes int64 `json:"rcvd_bytes"`
89+
}
90+
require.NoError(t, json.Unmarshal(stdout.Bytes(), &output))
91+
require.NotZero(t, output.DurationSeconds)
92+
require.NotZero(t, output.SentBytes)
93+
require.NotZero(t, output.RcvdBytes)
9394
}

0 commit comments

Comments
 (0)