Skip to content

Commit 278527c

Browse files
authored
feat(scaletest): add option to send traffic over SSH (coder#8521)
- Refactors the metrics logic to avoid needing to pass in a whole prometheus registry - Adds an --ssh option to the workspace-traffic command to send SSH traffic Fixes coder#8242
1 parent ab54008 commit 278527c

File tree

8 files changed

+613
-251
lines changed

8 files changed

+613
-251
lines changed

cli/exp_scaletest.go

+16-9
Original file line numberDiff line numberDiff line change
@@ -848,6 +848,7 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *clibase.Cmd {
848848
var (
849849
tickInterval time.Duration
850850
bytesPerTick int64
851+
ssh bool
851852
scaletestPrometheusAddress string
852853
scaletestPrometheusWait time.Duration
853854

@@ -938,20 +939,19 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *clibase.Cmd {
938939

939940
// Setup our workspace agent connection.
940941
config := workspacetraffic.Config{
941-
AgentID: agentID,
942-
AgentName: agentName,
943-
BytesPerTick: bytesPerTick,
944-
Duration: strategy.timeout,
945-
TickInterval: tickInterval,
946-
WorkspaceName: ws.Name,
947-
WorkspaceOwner: ws.OwnerName,
948-
Registry: reg,
942+
AgentID: agentID,
943+
BytesPerTick: bytesPerTick,
944+
Duration: strategy.timeout,
945+
TickInterval: tickInterval,
946+
ReadMetrics: metrics.ReadMetrics(ws.OwnerName, ws.Name, agentName),
947+
WriteMetrics: metrics.WriteMetrics(ws.OwnerName, ws.Name, agentName),
948+
SSH: ssh,
949949
}
950950

951951
if err := config.Validate(); err != nil {
952952
return xerrors.Errorf("validate config: %w", err)
953953
}
954-
var runner harness.Runnable = workspacetraffic.NewRunner(client, config, metrics)
954+
var runner harness.Runnable = workspacetraffic.NewRunner(client, config)
955955
if tracingEnabled {
956956
runner = &runnableTraceWrapper{
957957
tracer: tracer,
@@ -1002,6 +1002,13 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *clibase.Cmd {
10021002
Description: "How often to send traffic.",
10031003
Value: clibase.DurationOf(&tickInterval),
10041004
},
1005+
{
1006+
Flag: "ssh",
1007+
Env: "CODER_SCALETEST_WORKSPACE_TRAFFIC_SSH",
1008+
Default: "",
1009+
Description: "Send traffic over SSH.",
1010+
Value: clibase.BoolOf(&ssh),
1011+
},
10051012
{
10061013
Flag: "scaletest-prometheus-address",
10071014
Env: "CODER_SCALETEST_PROMETHEUS_ADDRESS",

cli/exp_scaletest_test.go

+1
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ func TestScaleTestWorkspaceTraffic(t *testing.T) {
6969
"--tick-interval", "100ms",
7070
"--scaletest-prometheus-address", "127.0.0.1:0",
7171
"--scaletest-prometheus-wait", "0s",
72+
"--ssh",
7273
)
7374
clitest.SetupConfig(t, client, root)
7475
var stdout, stderr bytes.Buffer

scaletest/workspacetraffic/config.go

+4-12
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,13 @@ import (
44
"time"
55

66
"github.com/google/uuid"
7-
"github.com/prometheus/client_golang/prometheus"
87
"golang.org/x/xerrors"
98
)
109

1110
type Config struct {
1211
// AgentID is the workspace agent ID to which to connect.
1312
AgentID uuid.UUID `json:"agent_id"`
1413

15-
// AgentName is the name of the agent. Used for metrics.
16-
AgentName string `json:"agent_name"`
17-
18-
// WorkspaceName is the name of the workspace. Used for metrics.
19-
WorkspaceName string `json:"workspace_name"`
20-
21-
// WorkspaceOwner is the owner of the workspace. Used for metrics.
22-
WorkspaceOwner string `json:"workspace_owner"`
23-
2414
// BytesPerTick is the number of bytes to send to the agent per tick.
2515
BytesPerTick int64 `json:"bytes_per_tick"`
2616

@@ -31,8 +21,10 @@ type Config struct {
3121
// send data to workspace agents).
3222
TickInterval time.Duration `json:"tick_interval"`
3323

34-
// Registry is a prometheus.Registerer for logging metrics
35-
Registry prometheus.Registerer
24+
ReadMetrics ConnMetrics `json:"-"`
25+
WriteMetrics ConnMetrics `json:"-"`
26+
27+
SSH bool `json:"ssh"`
3628
}
3729

3830
func (c Config) Validate() error {

scaletest/workspacetraffic/conn.go

+123
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
package workspacetraffic
2+
3+
import (
4+
"context"
5+
"io"
6+
"sync"
7+
8+
"github.com/coder/coder/codersdk"
9+
10+
"github.com/google/uuid"
11+
"github.com/hashicorp/go-multierror"
12+
gossh "golang.org/x/crypto/ssh"
13+
"golang.org/x/xerrors"
14+
)
15+
16+
func connectPTY(ctx context.Context, client *codersdk.Client, agentID, reconnect uuid.UUID) (*countReadWriteCloser, error) {
17+
conn, err := client.WorkspaceAgentReconnectingPTY(ctx, codersdk.WorkspaceAgentReconnectingPTYOpts{
18+
AgentID: agentID,
19+
Reconnect: reconnect,
20+
Height: 25,
21+
Width: 80,
22+
Command: "/bin/sh",
23+
})
24+
if err != nil {
25+
return nil, xerrors.Errorf("connect pty: %w", err)
26+
}
27+
28+
// Wrap the conn in a countReadWriteCloser so we can monitor bytes sent/rcvd.
29+
crw := countReadWriteCloser{ctx: ctx, rwc: conn}
30+
return &crw, nil
31+
}
32+
33+
func connectSSH(ctx context.Context, client *codersdk.Client, agentID uuid.UUID) (*countReadWriteCloser, error) {
34+
agentConn, err := client.DialWorkspaceAgent(ctx, agentID, &codersdk.DialWorkspaceAgentOptions{})
35+
if err != nil {
36+
return nil, xerrors.Errorf("dial workspace agent: %w", err)
37+
}
38+
agentConn.AwaitReachable(ctx)
39+
sshClient, err := agentConn.SSHClient(ctx)
40+
if err != nil {
41+
return nil, xerrors.Errorf("get ssh client: %w", err)
42+
}
43+
sshSession, err := sshClient.NewSession()
44+
if err != nil {
45+
_ = agentConn.Close()
46+
return nil, xerrors.Errorf("new ssh session: %w", err)
47+
}
48+
wrappedConn := &wrappedSSHConn{ctx: ctx}
49+
// Do some plumbing to hook up the wrappedConn
50+
pr1, pw1 := io.Pipe()
51+
wrappedConn.stdout = pr1
52+
sshSession.Stdout = pw1
53+
pr2, pw2 := io.Pipe()
54+
sshSession.Stdin = pr2
55+
wrappedConn.stdin = pw2
56+
err = sshSession.RequestPty("xterm", 25, 80, gossh.TerminalModes{})
57+
if err != nil {
58+
_ = pr1.Close()
59+
_ = pr2.Close()
60+
_ = pw1.Close()
61+
_ = pw2.Close()
62+
_ = sshSession.Close()
63+
_ = agentConn.Close()
64+
return nil, xerrors.Errorf("request pty: %w", err)
65+
}
66+
err = sshSession.Shell()
67+
if err != nil {
68+
_ = sshSession.Close()
69+
_ = agentConn.Close()
70+
return nil, xerrors.Errorf("shell: %w", err)
71+
}
72+
73+
closeFn := func() error {
74+
var merr error
75+
if err := sshSession.Close(); err != nil {
76+
merr = multierror.Append(merr, err)
77+
}
78+
if err := agentConn.Close(); err != nil {
79+
merr = multierror.Append(merr, err)
80+
}
81+
return merr
82+
}
83+
wrappedConn.close = closeFn
84+
85+
crw := &countReadWriteCloser{ctx: ctx, rwc: wrappedConn}
86+
return crw, nil
87+
}
88+
89+
// wrappedSSHConn wraps an ssh.Session to implement io.ReadWriteCloser.
90+
type wrappedSSHConn struct {
91+
ctx context.Context
92+
stdout io.Reader
93+
stdin io.Writer
94+
closeOnce sync.Once
95+
closeErr error
96+
close func() error
97+
}
98+
99+
func (w *wrappedSSHConn) Close() error {
100+
w.closeOnce.Do(func() {
101+
_, _ = w.stdin.Write([]byte("exit\n"))
102+
w.closeErr = w.close()
103+
})
104+
return w.closeErr
105+
}
106+
107+
func (w *wrappedSSHConn) Read(p []byte) (n int, err error) {
108+
select {
109+
case <-w.ctx.Done():
110+
return 0, xerrors.Errorf("read: %w", w.ctx.Err())
111+
default:
112+
return w.stdout.Read(p)
113+
}
114+
}
115+
116+
func (w *wrappedSSHConn) Write(p []byte) (n int, err error) {
117+
select {
118+
case <-w.ctx.Done():
119+
return 0, xerrors.Errorf("write: %w", w.ctx.Err())
120+
default:
121+
return w.stdin.Write(p)
122+
}
123+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package workspacetraffic
2+
3+
import (
4+
"context"
5+
"errors"
6+
"io"
7+
"time"
8+
9+
"golang.org/x/xerrors"
10+
11+
"nhooyr.io/websocket"
12+
)
13+
14+
// countReadWriteCloser wraps an io.ReadWriteCloser and counts the number of bytes read and written.
15+
type countReadWriteCloser struct {
16+
ctx context.Context
17+
rwc io.ReadWriteCloser
18+
readMetrics ConnMetrics
19+
writeMetrics ConnMetrics
20+
}
21+
22+
func (w *countReadWriteCloser) Close() error {
23+
return w.rwc.Close()
24+
}
25+
26+
func (w *countReadWriteCloser) Read(p []byte) (int, error) {
27+
start := time.Now()
28+
n, err := w.rwc.Read(p)
29+
took := time.Since(start).Seconds()
30+
if reportableErr(err) {
31+
w.readMetrics.AddError(1)
32+
}
33+
w.readMetrics.ObserveLatency(took)
34+
if n > 0 {
35+
w.readMetrics.AddTotal(float64(n))
36+
}
37+
return n, err
38+
}
39+
40+
func (w *countReadWriteCloser) Write(p []byte) (int, error) {
41+
start := time.Now()
42+
n, err := w.rwc.Write(p)
43+
took := time.Since(start).Seconds()
44+
if reportableErr(err) {
45+
w.writeMetrics.AddError(1)
46+
}
47+
w.writeMetrics.ObserveLatency(took)
48+
if n > 0 {
49+
w.writeMetrics.AddTotal(float64(n))
50+
}
51+
return n, err
52+
}
53+
54+
// some errors we want to report in metrics; others we want to ignore
55+
// such as websocket.StatusNormalClosure or context.Canceled
56+
func reportableErr(err error) bool {
57+
if err == nil {
58+
return false
59+
}
60+
if xerrors.Is(err, io.EOF) {
61+
return false
62+
}
63+
if xerrors.Is(err, context.Canceled) {
64+
return false
65+
}
66+
var wsErr websocket.CloseError
67+
if errors.As(err, &wsErr) {
68+
return wsErr.Code != websocket.StatusNormalClosure
69+
}
70+
return false
71+
}

scaletest/workspacetraffic/metrics.go

+40
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,43 @@ func NewMetrics(reg prometheus.Registerer, labelNames ...string) *Metrics {
5454
reg.MustRegister(m.WriteLatencySeconds)
5555
return m
5656
}
57+
58+
func (m *Metrics) ReadMetrics(lvs ...string) ConnMetrics {
59+
return &connMetrics{
60+
addError: m.ReadErrorsTotal.WithLabelValues(lvs...).Add,
61+
observeLatency: m.ReadLatencySeconds.WithLabelValues(lvs...).Observe,
62+
addTotal: m.BytesReadTotal.WithLabelValues(lvs...).Add,
63+
}
64+
}
65+
66+
func (m *Metrics) WriteMetrics(lvs ...string) ConnMetrics {
67+
return &connMetrics{
68+
addError: m.WriteErrorsTotal.WithLabelValues(lvs...).Add,
69+
observeLatency: m.WriteLatencySeconds.WithLabelValues(lvs...).Observe,
70+
addTotal: m.BytesWrittenTotal.WithLabelValues(lvs...).Add,
71+
}
72+
}
73+
74+
type ConnMetrics interface {
75+
AddError(float64)
76+
ObserveLatency(float64)
77+
AddTotal(float64)
78+
}
79+
80+
type connMetrics struct {
81+
addError func(float64)
82+
observeLatency func(float64)
83+
addTotal func(float64)
84+
}
85+
86+
func (c *connMetrics) AddError(f float64) {
87+
c.addError(f)
88+
}
89+
90+
func (c *connMetrics) ObserveLatency(f float64) {
91+
c.observeLatency(f)
92+
}
93+
94+
func (c *connMetrics) AddTotal(f float64) {
95+
c.addTotal(f)
96+
}

0 commit comments

Comments
 (0)