Skip to content

Commit 65c6d88

Browse files
committed
move back under scaletest cmd
1 parent fafca95 commit 65c6d88

File tree

5 files changed

+343
-373
lines changed

5 files changed

+343
-373
lines changed

cli/root.go

-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,6 @@ func (r *RootCmd) Core() []*clibase.Cmd {
105105
// Hidden
106106
r.gitssh(),
107107
r.scaletest(),
108-
r.trafficGen(),
109108
r.vscodeSSH(),
110109
r.workspaceAgent(),
111110
}

cli/scaletest.go

+260
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"strconv"
1111
"strings"
1212
"sync"
13+
"sync/atomic"
1314
"syscall"
1415
"time"
1516

@@ -42,6 +43,7 @@ func (r *RootCmd) scaletest() *clibase.Cmd {
4243
Children: []*clibase.Cmd{
4344
r.scaletestCleanup(),
4445
r.scaletestCreateWorkspaces(),
46+
r.scaletestTrafficGen(),
4547
},
4648
}
4749

@@ -947,6 +949,156 @@ func (r *RootCmd) scaletestCreateWorkspaces() *clibase.Cmd {
947949
return cmd
948950
}
949951

952+
type trafficGenOutput struct {
953+
DurationSeconds float64 `json:"duration_s"`
954+
SentBytes int64 `json:"sent_bytes"`
955+
RcvdBytes int64 `json:"rcvd_bytes"`
956+
}
957+
958+
func (o trafficGenOutput) String() string {
959+
return fmt.Sprintf("Duration: %.2fs\n", o.DurationSeconds) +
960+
fmt.Sprintf("Sent: %dB\n", o.SentBytes) +
961+
fmt.Sprintf("Rcvd: %dB", o.RcvdBytes)
962+
}
963+
964+
func (r *RootCmd) scaletestTrafficGen() *clibase.Cmd {
965+
var (
966+
duration time.Duration
967+
formatter = cliui.NewOutputFormatter(
968+
cliui.TextFormat(),
969+
cliui.JSONFormat(),
970+
)
971+
bps int64
972+
client = new(codersdk.Client)
973+
)
974+
975+
cmd := &clibase.Cmd{
976+
Use: "trafficgen",
977+
Hidden: true,
978+
Short: "Generate traffic to a Coder workspace",
979+
Middleware: clibase.Chain(
980+
clibase.RequireRangeArgs(1, 2),
981+
r.InitClient(client),
982+
),
983+
Handler: func(inv *clibase.Invocation) error {
984+
var (
985+
agentName string
986+
tickInterval = 100 * time.Millisecond
987+
)
988+
ws, err := namedWorkspace(inv.Context(), client, inv.Args[0])
989+
if err != nil {
990+
return err
991+
}
992+
993+
var agentID uuid.UUID
994+
for _, res := range ws.LatestBuild.Resources {
995+
if len(res.Agents) == 0 {
996+
continue
997+
}
998+
if agentName != "" && agentName != res.Agents[0].Name {
999+
continue
1000+
}
1001+
agentID = res.Agents[0].ID
1002+
}
1003+
1004+
if agentID == uuid.Nil {
1005+
return xerrors.Errorf("no agent found for workspace %s", ws.Name)
1006+
}
1007+
1008+
// Setup our workspace agent connection.
1009+
reconnect := uuid.New()
1010+
conn, err := client.WorkspaceAgentReconnectingPTY(inv.Context(), codersdk.WorkspaceAgentReconnectingPTYOpts{
1011+
AgentID: agentID,
1012+
Reconnect: reconnect,
1013+
Height: 65535,
1014+
Width: 65535,
1015+
Command: "/bin/sh",
1016+
})
1017+
if err != nil {
1018+
return xerrors.Errorf("connect to workspace: %w", err)
1019+
}
1020+
1021+
defer func() {
1022+
_ = conn.Close()
1023+
}()
1024+
1025+
// Wrap the conn in a countReadWriter so we can monitor bytes sent/rcvd.
1026+
crw := countReadWriter{ReadWriter: conn}
1027+
1028+
// Set a deadline for stopping the text.
1029+
start := time.Now()
1030+
deadlineCtx, cancel := context.WithDeadline(inv.Context(), start.Add(duration))
1031+
defer cancel()
1032+
1033+
// Create a ticker for sending data to the PTY.
1034+
tick := time.NewTicker(tickInterval)
1035+
defer tick.Stop()
1036+
1037+
// Now we begin writing random data to the pty.
1038+
writeSize := int(bps / 10)
1039+
rch := make(chan error)
1040+
wch := make(chan error)
1041+
1042+
// Read forever in the background.
1043+
go func() {
1044+
rch <- readContext(deadlineCtx, &crw, writeSize*2)
1045+
conn.Close()
1046+
close(rch)
1047+
}()
1048+
1049+
// Write random data to the PTY every tick.
1050+
go func() {
1051+
wch <- writeRandomData(deadlineCtx, &crw, writeSize, tick.C)
1052+
close(wch)
1053+
}()
1054+
1055+
// Wait for both our reads and writes to be finished.
1056+
if wErr := <-wch; wErr != nil {
1057+
return xerrors.Errorf("write to pty: %w", wErr)
1058+
}
1059+
if rErr := <-rch; rErr != nil {
1060+
return xerrors.Errorf("read from pty: %w", rErr)
1061+
}
1062+
1063+
duration := time.Since(start)
1064+
1065+
results := trafficGenOutput{
1066+
DurationSeconds: duration.Seconds(),
1067+
SentBytes: crw.BytesWritten(),
1068+
RcvdBytes: crw.BytesRead(),
1069+
}
1070+
1071+
out, err := formatter.Format(inv.Context(), results)
1072+
if err != nil {
1073+
return err
1074+
}
1075+
1076+
_, err = fmt.Fprintln(inv.Stdout, out)
1077+
return err
1078+
},
1079+
}
1080+
1081+
cmd.Options = []clibase.Option{
1082+
{
1083+
Flag: "duration",
1084+
Env: "CODER_SCALETEST_TRAFFICGEN_DURATION",
1085+
Default: "10s",
1086+
Description: "How long to generate traffic for.",
1087+
Value: clibase.DurationOf(&duration),
1088+
},
1089+
{
1090+
Flag: "bps",
1091+
Env: "CODER_SCALETEST_TRAFFICGEN_BPS",
1092+
Default: "1024",
1093+
Description: "How much traffic to generate in bytes per second.",
1094+
Value: clibase.Int64Of(&bps),
1095+
},
1096+
}
1097+
1098+
formatter.AttachOptions(&cmd.Options)
1099+
return cmd
1100+
}
1101+
9501102
type runnableTraceWrapper struct {
9511103
tracer trace.Tracer
9521104
spanName string
@@ -1023,3 +1175,111 @@ func isScaleTestWorkspace(workspace codersdk.Workspace) bool {
10231175
return strings.HasPrefix(workspace.OwnerName, "scaletest-") ||
10241176
strings.HasPrefix(workspace.Name, "scaletest-")
10251177
}
1178+
1179+
func readContext(ctx context.Context, src io.Reader, bufSize int) error {
1180+
buf := make([]byte, bufSize)
1181+
for {
1182+
select {
1183+
case <-ctx.Done():
1184+
return nil
1185+
default:
1186+
if ctx.Err() != nil {
1187+
return nil
1188+
}
1189+
_, err := src.Read(buf)
1190+
if err != nil {
1191+
if xerrors.Is(err, io.EOF) {
1192+
return nil
1193+
}
1194+
return err
1195+
}
1196+
}
1197+
}
1198+
}
1199+
1200+
func writeRandomData(ctx context.Context, dst io.Writer, size int, tick <-chan time.Time) error {
1201+
for {
1202+
select {
1203+
case <-ctx.Done():
1204+
return nil
1205+
case <-tick:
1206+
payload := "#" + mustRandStr(size-1)
1207+
data, err := json.Marshal(codersdk.ReconnectingPTYRequest{
1208+
Data: payload,
1209+
})
1210+
if err != nil {
1211+
return err
1212+
}
1213+
if _, err := copyContext(ctx, dst, data); err != nil {
1214+
return err
1215+
}
1216+
}
1217+
}
1218+
}
1219+
1220+
// copyContext copies from src to dst until ctx is canceled.
1221+
func copyContext(ctx context.Context, dst io.Writer, src []byte) (int, error) {
1222+
var count int
1223+
for {
1224+
select {
1225+
case <-ctx.Done():
1226+
return count, nil
1227+
default:
1228+
if ctx.Err() != nil {
1229+
return count, nil
1230+
}
1231+
n, err := dst.Write(src)
1232+
if err != nil {
1233+
if xerrors.Is(err, io.EOF) {
1234+
// On an EOF, assume that all of src was consumed.
1235+
return len(src), nil
1236+
}
1237+
return count, err
1238+
}
1239+
count += n
1240+
if n == len(src) {
1241+
return count, nil
1242+
}
1243+
// Not all of src was consumed. Update src and retry.
1244+
src = src[n:]
1245+
}
1246+
}
1247+
}
1248+
1249+
type countReadWriter struct {
1250+
io.ReadWriter
1251+
bytesRead atomic.Int64
1252+
bytesWritten atomic.Int64
1253+
}
1254+
1255+
func (w *countReadWriter) Read(p []byte) (int, error) {
1256+
n, err := w.ReadWriter.Read(p)
1257+
if err == nil {
1258+
w.bytesRead.Add(int64(n))
1259+
}
1260+
return n, err
1261+
}
1262+
1263+
func (w *countReadWriter) Write(p []byte) (int, error) {
1264+
n, err := w.ReadWriter.Write(p)
1265+
if err == nil {
1266+
w.bytesWritten.Add(int64(n))
1267+
}
1268+
return n, err
1269+
}
1270+
1271+
func (w *countReadWriter) BytesRead() int64 {
1272+
return w.bytesRead.Load()
1273+
}
1274+
1275+
func (w *countReadWriter) BytesWritten() int64 {
1276+
return w.bytesWritten.Load()
1277+
}
1278+
1279+
func mustRandStr(len int) string {
1280+
randStr, err := cryptorand.String(len)
1281+
if err != nil {
1282+
panic(err)
1283+
}
1284+
return randStr
1285+
}

0 commit comments

Comments
 (0)