diff --git a/cli/exp_scaletest.go b/cli/exp_scaletest.go
index 9c88272e951a0..b1bafbdbb6c77 100644
--- a/cli/exp_scaletest.go
+++ b/cli/exp_scaletest.go
@@ -21,6 +21,7 @@ import (
 	"github.com/prometheus/client_golang/prometheus"
 	"github.com/prometheus/client_golang/prometheus/promhttp"
 	"go.opentelemetry.io/otel/trace"
+	"golang.org/x/exp/slices"
 	"golang.org/x/xerrors"
 
 	"cdr.dev/slog"
@@ -859,6 +860,7 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *clibase.Cmd {
 		tickInterval time.Duration
 		bytesPerTick int64
 		ssh          bool
+		app          string
 		template     string
 
 		client = &codersdk.Client{}
@@ -911,6 +913,11 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *clibase.Cmd {
 				}
 			}
 
+			appHost, err := client.AppHost(ctx)
+			if err != nil {
+				return xerrors.Errorf("get app host: %w", err)
+			}
+
 			workspaces, err := getScaletestWorkspaces(inv.Context(), client, template)
 			if err != nil {
 				return err
@@ -945,35 +952,39 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *clibase.Cmd {
 			th := harness.NewTestHarness(strategy.toStrategy(), cleanupStrategy.toStrategy())
 			for idx, ws := range workspaces {
 				var (
-					agentID   uuid.UUID
-					agentName string
-					name      = "workspace-traffic"
-					id        = strconv.Itoa(idx)
+					agent codersdk.WorkspaceAgent
+					name  = "workspace-traffic"
+					id    = strconv.Itoa(idx)
 				)
 				for _, res := range ws.LatestBuild.Resources {
 					if len(res.Agents) == 0 {
 						continue
 					}
-					agentID = res.Agents[0].ID
-					agentName = res.Agents[0].Name
+					agent = res.Agents[0]
 				}
 
-				if agentID == uuid.Nil {
+				if agent.ID == uuid.Nil {
 					_, _ = fmt.Fprintf(inv.Stderr, "WARN: skipping workspace %s: no agent\n", ws.Name)
 					continue
 				}
 
+				appConfig, err := createWorkspaceAppConfig(client, appHost.Host, app, ws, agent)
+				if err != nil {
+					return xerrors.Errorf("configure workspace app: %w", err)
+				}
+
 				// Set up our workspace agent connection.
 				config := workspacetraffic.Config{
-					AgentID:      agentID,
+					AgentID:      agent.ID,
 					BytesPerTick: bytesPerTick,
 					Duration:     strategy.timeout,
 					TickInterval: tickInterval,
-					ReadMetrics:  metrics.ReadMetrics(ws.OwnerName, ws.Name, agentName),
-					WriteMetrics: metrics.WriteMetrics(ws.OwnerName, ws.Name, agentName),
+					ReadMetrics:  metrics.ReadMetrics(ws.OwnerName, ws.Name, agent.Name),
+					WriteMetrics: metrics.WriteMetrics(ws.OwnerName, ws.Name, agent.Name),
 					SSH:          ssh,
 					Echo:         ssh,
+					App:          appConfig,
 				}
 
 				if err := config.Validate(); err != nil {
@@ -1046,9 +1057,16 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *clibase.Cmd {
 			Flag:        "ssh",
 			Env:         "CODER_SCALETEST_WORKSPACE_TRAFFIC_SSH",
 			Default:     "",
-			Description: "Send traffic over SSH.",
+			Description: "Send traffic over SSH, cannot be used with --app.",
 			Value:       clibase.BoolOf(&ssh),
 		},
+		{
+			Flag:        "app",
+			Env:         "CODER_SCALETEST_WORKSPACE_TRAFFIC_APP",
+			Default:     "",
+			Description: "Send WebSocket traffic to a workspace app (proxied via coderd), cannot be used with --ssh.",
+			Value:       clibase.StringOf(&app),
+		},
 	}
 
 	tracingFlags.attach(&cmd.Options)
@@ -1411,3 +1429,29 @@ func parseTemplate(ctx context.Context, client *codersdk.Client, organizationIDs
 
 	return tpl, nil
 }
+
+func createWorkspaceAppConfig(client *codersdk.Client, appHost, app string, workspace codersdk.Workspace, agent codersdk.WorkspaceAgent) (workspacetraffic.AppConfig, error) {
+	if app == "" {
+		return workspacetraffic.AppConfig{}, nil
+	}
+
+	i := slices.IndexFunc(agent.Apps, func(a codersdk.WorkspaceApp) bool { return a.Slug == app })
+	if i == -1 {
+		return workspacetraffic.AppConfig{}, xerrors.Errorf("app %q not found in workspace %q", app, workspace.Name)
+	}
+
+	c := workspacetraffic.AppConfig{
+		Name: agent.Apps[i].Slug,
+	}
+	if agent.Apps[i].Subdomain {
+		if appHost == "" {
+			return workspacetraffic.AppConfig{}, xerrors.Errorf("app %q is a subdomain app but no app host is configured", app)
+		}
+
+		c.URL = fmt.Sprintf("%s://%s", client.URL.Scheme, strings.Replace(appHost, "*", agent.Apps[i].SubdomainName, 1))
+	} else {
+		c.URL = fmt.Sprintf("%s/@%s/%s.%s/apps/%s", client.URL.String(), workspace.OwnerName, workspace.Name, agent.Name, agent.Apps[i].Slug)
+	}
+
+	return c, nil
+}
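A note on the two URL shapes `createWorkspaceAppConfig` produces: subdomain apps substitute the agent's subdomain name into the wildcard app host, while path-based apps are proxied under the workspace path on the main deployment URL. A standalone sketch of both constructions, mirroring the `fmt.Sprintf` calls above but using hypothetical values (`apps.example.com`, `code--dev--main--alice`, etc.) rather than real codersdk fields:

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	// Assumed example values; the real ones come from codersdk types.
	scheme := "https"
	appHost := "*.apps.example.com"      // wildcard app hostname, as returned by client.AppHost
	subdomainName := "code--dev--main--alice" // hypothetical WorkspaceApp.SubdomainName

	// Subdomain app: replace the wildcard in the app host with the app's
	// subdomain name.
	subdomainURL := fmt.Sprintf("%s://%s", scheme, strings.Replace(appHost, "*", subdomainName, 1))
	fmt.Println(subdomainURL) // https://code--dev--main--alice.apps.example.com

	// Path-based app: proxied through coderd under the workspace path.
	deploymentURL := "https://coder.example.com"
	pathURL := fmt.Sprintf("%s/@%s/%s.%s/apps/%s", deploymentURL, "alice", "main", "dev", "code")
	fmt.Println(pathURL) // https://coder.example.com/@alice/main.dev/apps/code
}
```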
diff --git a/scaletest/workspacetraffic/config.go b/scaletest/workspacetraffic/config.go
index 46c7a94b4ed29..71134a454a411 100644
--- a/scaletest/workspacetraffic/config.go
+++ b/scaletest/workspacetraffic/config.go
@@ -31,6 +31,8 @@ type Config struct {
 	// to true will double the amount of data read from the agent for
 	// PTYs (e.g. reconnecting pty or SSH connections that request PTY).
 	Echo bool `json:"echo"`
+
+	App AppConfig `json:"app"`
 }
 
 func (c Config) Validate() error {
@@ -50,5 +52,14 @@ func (c Config) Validate() error {
 		return xerrors.Errorf("validate tick_interval: must be greater than zero")
 	}
 
+	if c.SSH && c.App.Name != "" {
+		return xerrors.Errorf("validate ssh: must be false when app is used")
+	}
+
 	return nil
 }
+
+type AppConfig struct {
+	Name string `json:"name"`
+	URL  string `json:"url"`
+}
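The new `Validate` clause makes SSH traffic and app traffic mutually exclusive, since each selects a different connection path in the runner. A minimal sketch of the same guard in isolation (simplified structs; `fmt.Errorf` standing in for `xerrors.Errorf`):

```go
package main

import "fmt"

type AppConfig struct {
	Name string
	URL  string
}

type Config struct {
	SSH bool
	App AppConfig
}

// validateMode rejects configs that request both SSH and app traffic,
// matching the rule added to Config.Validate above.
func (c Config) validateMode() error {
	if c.SSH && c.App.Name != "" {
		return fmt.Errorf("validate ssh: must be false when app is used")
	}
	return nil
}

func main() {
	fmt.Println(Config{SSH: true, App: AppConfig{Name: "echo"}}.validateMode()) // validate ssh: must be false when app is used
	fmt.Println(Config{SSH: true}.validateMode())                               // <nil>
}
```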
diff --git a/scaletest/workspacetraffic/conn.go b/scaletest/workspacetraffic/conn.go
index c7b3daf6c7c73..31dfaf99c76bd 100644
--- a/scaletest/workspacetraffic/conn.go
+++ b/scaletest/workspacetraffic/conn.go
@@ -5,9 +5,13 @@ import (
 	"encoding/json"
 	"errors"
 	"io"
+	"net"
+	"net/http"
 	"sync"
 	"time"
 
+	"nhooyr.io/websocket"
+
 	"github.com/coder/coder/v2/codersdk"
 
 	"github.com/google/uuid"
@@ -260,3 +264,118 @@ func (w *wrappedSSHConn) Read(p []byte) (n int, err error) {
 func (w *wrappedSSHConn) Write(p []byte) (n int, err error) {
 	return w.stdin.Write(p)
 }
+
+func appClientConn(ctx context.Context, client *codersdk.Client, url string) (*countReadWriteCloser, error) {
+	headers := http.Header{}
+	tokenHeader := codersdk.SessionTokenHeader
+	if client.SessionTokenHeader != "" {
+		tokenHeader = client.SessionTokenHeader
+	}
+	headers.Set(tokenHeader, client.SessionToken())
+
+	//nolint:bodyclose // The websocket conn manages the body.
+	conn, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{
+		HTTPClient: client.HTTPClient,
+		HTTPHeader: headers,
+	})
+	if err != nil {
+		return nil, xerrors.Errorf("websocket dial: %w", err)
+	}
+
+	netConn := websocketNetConn(conn, websocket.MessageBinary)
+
+	// Wrap the conn in a countReadWriteCloser so we can monitor bytes sent/rcvd.
+	crw := &countReadWriteCloser{rwc: netConn}
+	return crw, nil
+}
+
+// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
+// is called if a read or write error is encountered.
+type wsNetConn struct {
+	net.Conn
+
+	writeMu sync.Mutex
+	readMu  sync.Mutex
+
+	cancel  context.CancelFunc
+	closeMu sync.Mutex
+	closed  bool
+}
+
+func (c *wsNetConn) Read(b []byte) (n int, err error) {
+	c.readMu.Lock()
+	defer c.readMu.Unlock()
+	if c.isClosed() {
+		return 0, io.EOF
+	}
+	n, err = c.Conn.Read(b)
+	if err != nil {
+		if c.isClosed() {
+			return n, io.EOF
+		}
+		return n, err
+	}
+	return n, nil
+}
+
+func (c *wsNetConn) Write(b []byte) (n int, err error) {
+	c.writeMu.Lock()
+	defer c.writeMu.Unlock()
+	if c.isClosed() {
+		return 0, io.EOF
+	}
+
+	for len(b) > 0 {
+		bb := b
+		if len(bb) > rptyJSONMaxDataSize {
+			bb = b[:rptyJSONMaxDataSize]
+		}
+		b = b[len(bb):]
+		nn, err := c.Conn.Write(bb)
+		n += nn
+		if err != nil {
+			if c.isClosed() {
+				return n, io.EOF
+			}
+			return n, err
+		}
+	}
+	return n, nil
+}
+
+func (c *wsNetConn) isClosed() bool {
+	c.closeMu.Lock()
+	defer c.closeMu.Unlock()
+	return c.closed
+}
+
+func (c *wsNetConn) Close() error {
+	c.closeMu.Lock()
+	closed := c.closed
+	c.closed = true
+	c.closeMu.Unlock()
+
+	if closed {
+		return nil
+	}
+
+	// Cancel before acquiring locks to speed up teardown.
+	c.cancel()
+
+	c.readMu.Lock()
+	defer c.readMu.Unlock()
+	c.writeMu.Lock()
+	defer c.writeMu.Unlock()
+
+	_ = c.Conn.Close()
+	return nil
+}
+
+func websocketNetConn(conn *websocket.Conn, msgType websocket.MessageType) net.Conn {
+	// Since `websocket.NetConn` binds to a context for the lifetime of the
+	// connection, we need to create a new context that can be canceled when
+	// the connection is closed.
+	ctx, cancel := context.WithCancel(context.Background())
+	nc := websocket.NetConn(ctx, conn, msgType)
+	return &wsNetConn{cancel: cancel, Conn: nc}
+}
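`websocket.NetConn` binds the returned `net.Conn` to the context passed in, which is why `websocketNetConn` above allocates a fresh cancelable context per connection instead of reusing the (possibly short-lived) dial context. A minimal sketch of that pattern against a hypothetical echo endpoint:

```go
package main

import (
	"context"
	"log"
	"net"

	"nhooyr.io/websocket"
)

// dialNetConn returns a net.Conn whose lifetime is controlled by the
// returned cancel func rather than by the dial context.
func dialNetConn(dialCtx context.Context, url string) (net.Conn, context.CancelFunc, error) {
	conn, _, err := websocket.Dial(dialCtx, url, nil)
	if err != nil {
		return nil, nil, err
	}
	// A separate background context: canceling it tears down reads and
	// writes on the net.Conn even after dialCtx has expired.
	connCtx, cancel := context.WithCancel(context.Background())
	return websocket.NetConn(connCtx, conn, websocket.MessageBinary), cancel, nil
}

func main() {
	// Hypothetical endpoint for illustration only.
	nc, cancel, err := dialNetConn(context.Background(), "ws://127.0.0.1:3000/echo")
	if err != nil {
		log.Fatal(err)
	}
	defer cancel()
	_, _ = nc.Write([]byte("ping"))
	_ = nc.Close()
}
```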
diff --git a/scaletest/workspacetraffic/run.go b/scaletest/workspacetraffic/run.go
index 27a81f2da7d75..c683536461bbc 100644
--- a/scaletest/workspacetraffic/run.go
+++ b/scaletest/workspacetraffic/run.go
@@ -91,7 +91,16 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) (err error)
 	command := fmt.Sprintf("dd if=/dev/stdin of=%s bs=%d status=none", output, bytesPerTick)
 
 	var conn *countReadWriteCloser
-	if r.cfg.SSH {
+	switch {
+	case r.cfg.App.Name != "":
+		logger.Info(ctx, "sending traffic to workspace app", slog.F("app", r.cfg.App.Name))
+		conn, err = appClientConn(ctx, r.client, r.cfg.App.URL)
+		if err != nil {
+			logger.Error(ctx, "connect to workspace app", slog.Error(err))
+			return xerrors.Errorf("connect to workspace app: %w", err)
+		}
+
+	case r.cfg.SSH:
 		logger.Info(ctx, "connecting to workspace agent", slog.F("method", "ssh"))
 		// If echo is enabled, disable PTY to avoid double echo and
 		// reduce CPU usage.
@@ -101,7 +110,8 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) (err error)
 			logger.Error(ctx, "connect to workspace agent via ssh", slog.Error(err))
 			return xerrors.Errorf("connect to workspace via ssh: %w", err)
 		}
-	} else {
+
+	default:
 		logger.Info(ctx, "connecting to workspace agent", slog.F("method", "reconnectingpty"))
 		conn, err = connectRPTY(ctx, r.client, agentID, reconnect, command)
 		if err != nil {
@@ -114,8 +124,8 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) (err error)
 	closeConn := func() error {
 		closeOnce.Do(func() {
 			closeErr = conn.Close()
-			if err != nil {
-				logger.Error(ctx, "close agent connection", slog.Error(err))
+			if closeErr != nil {
+				logger.Error(ctx, "close agent connection", slog.Error(closeErr))
 			}
 		})
 		return closeErr
@@ -142,7 +152,6 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) (err error)
 
 	// Read until connection is closed.
 	go func() {
-		rch := rch // Shadowed for reassignment.
 		logger.Debug(ctx, "reading from agent")
 		rch <- drain(conn)
 		logger.Debug(ctx, "done reading from agent")
@@ -151,7 +160,6 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) (err error)
 
 	// Write random data to the conn every tick.
 	go func() {
-		wch := wch // Shadowed for reassignment.
 		logger.Debug(ctx, "writing to agent")
 		wch <- writeRandomData(conn, bytesPerTick, tick.C)
 		logger.Debug(ctx, "done writing to agent")
@@ -160,16 +168,17 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) (err error)
 	var waitCloseTimeoutCh <-chan struct{}
 	deadlineCtxCh := deadlineCtx.Done()
+	wchRef, rchRef := wch, rch
 	for {
-		if wch == nil && rch == nil {
+		if wchRef == nil && rchRef == nil {
 			return nil
 		}
 
 		select {
 		case <-waitCloseTimeoutCh:
 			logger.Warn(ctx, "timed out waiting for read/write to complete",
-				slog.F("write_done", wch == nil),
-				slog.F("read_done", rch == nil),
+				slog.F("write_done", wchRef == nil),
+				slog.F("read_done", rchRef == nil),
 			)
 			return xerrors.Errorf("timed out waiting for read/write to complete: %w", ctx.Err())
 		case <-deadlineCtxCh:
@@ -181,16 +190,16 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) (err error)
 			waitCtx, cancel := context.WithTimeout(context.Background(), waitCloseTimeout)
 			defer cancel() //nolint:revive // Only called once.
 			waitCloseTimeoutCh = waitCtx.Done()
-		case err = <-wch:
+		case err = <-wchRef:
 			if err != nil {
 				return xerrors.Errorf("write to agent: %w", err)
 			}
-			wch = nil
-		case err = <-rch:
+			wchRef = nil
+		case err = <-rchRef:
 			if err != nil {
 				return xerrors.Errorf("read from agent: %w", err)
 			}
-			rch = nil
+			rchRef = nil
 		}
 	}
 }
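The rewritten wait loop relies on a standard Go idiom: receiving from a nil channel blocks forever, so setting a local channel reference to nil after it delivers permanently disables that select case, and the loop exits once both goroutines have reported. A self-contained sketch of the idiom:

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	rch := make(chan error, 1)
	wch := make(chan error, 1)

	go func() { time.Sleep(10 * time.Millisecond); rch <- nil }()
	go func() { time.Sleep(20 * time.Millisecond); wch <- nil }()

	// Local copies: nil-ing a channel variable makes its select case block
	// forever, so each case can fire at most once.
	rchRef, wchRef := rch, wch
	for rchRef != nil || wchRef != nil {
		select {
		case err := <-rchRef:
			fmt.Println("read done:", err)
			rchRef = nil
		case err := <-wchRef:
			fmt.Println("write done:", err)
			wchRef = nil
		}
	}
	fmt.Println("all done")
}
```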
diff --git a/scaletest/workspacetraffic/run_test.go b/scaletest/workspacetraffic/run_test.go
index a2f9d609a5e54..a177390f9fd96 100644
--- a/scaletest/workspacetraffic/run_test.go
+++ b/scaletest/workspacetraffic/run_test.go
@@ -2,6 +2,10 @@ package workspacetraffic_test
 
 import (
 	"context"
+	"errors"
+	"io"
+	"net/http"
+	"net/http/httptest"
 	"runtime"
 	"strings"
 	"sync"
@@ -9,6 +13,7 @@ import (
 	"time"
 
 	"golang.org/x/exp/slices"
+	"nhooyr.io/websocket"
 
 	"github.com/coder/coder/v2/agent/agenttest"
 	"github.com/coder/coder/v2/coderd/coderdtest"
@@ -138,13 +143,11 @@ func TestRun(t *testing.T) {
 		t.Logf("bytes read total: %.0f\n", readMetrics.Total())
 		t.Logf("bytes written total: %.0f\n", writeMetrics.Total())
 
-		// We want to ensure the metrics are somewhat accurate.
-		// TODO: https://github.com/coder/coder/issues/11175
-		// assert.InDelta(t, bytesPerTick, writeMetrics.Total(), 0.1)
-
-		// Read is highly variable, depending on how far we read before stopping.
-		// Just ensure it's not zero.
+		// Ensure something was both read and written.
 		assert.NotZero(t, readMetrics.Total())
+		assert.NotZero(t, writeMetrics.Total())
+		// We want to ensure the metrics are somewhat accurate.
+		assert.InDelta(t, writeMetrics.Total(), readMetrics.Total(), float64(bytesPerTick)*10)
 		// Latency should report non-zero values.
 		assert.NotEmpty(t, readMetrics.Latencies())
 		assert.NotEmpty(t, writeMetrics.Latencies())
@@ -260,13 +263,106 @@ func TestRun(t *testing.T) {
 		t.Logf("bytes read total: %.0f\n", readMetrics.Total())
 		t.Logf("bytes written total: %.0f\n", writeMetrics.Total())
 
+		// Ensure something was both read and written.
+		assert.NotZero(t, readMetrics.Total())
+		assert.NotZero(t, writeMetrics.Total())
 		// We want to ensure the metrics are somewhat accurate.
-		// TODO: https://github.com/coder/coder/issues/11175
-		// assert.InDelta(t, bytesPerTick, writeMetrics.Total(), 0.1)
+		assert.InDelta(t, writeMetrics.Total(), readMetrics.Total(), float64(bytesPerTick)*10)
+		// Latency should report non-zero values.
+		assert.NotEmpty(t, readMetrics.Latencies())
+		assert.NotEmpty(t, writeMetrics.Latencies())
+		// Should not report any errors!
+		assert.Zero(t, readMetrics.Errors())
+		assert.Zero(t, writeMetrics.Errors())
+	})
+
+	t.Run("App", func(t *testing.T) {
+		t.Parallel()
+
+		// Start a test server that will echo back the request body; this skips
+		// the roundtrip to coderd/agent and simply tests the http request conn
+		// directly.
+		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			c, err := websocket.Accept(w, r, &websocket.AcceptOptions{})
+			if err != nil {
+				t.Error(err)
+				return
+			}
+
+			nc := websocket.NetConn(context.Background(), c, websocket.MessageBinary)
+			defer nc.Close()
+
+			_, err = io.Copy(nc, nc)
+			if err == nil || errors.Is(err, io.EOF) {
+				return
+			}
+			t.Error(err)
+		}))
+		defer srv.Close()
+
+		// Now we can start the runner.
+		var (
+			bytesPerTick = 1024
+			tickInterval = 1000 * time.Millisecond
+			readMetrics  = &testMetrics{}
+			writeMetrics = &testMetrics{}
+		)
+		client := &codersdk.Client{
+			HTTPClient: &http.Client{},
+		}
+		runner := workspacetraffic.NewRunner(client, workspacetraffic.Config{
+			BytesPerTick: int64(bytesPerTick),
+			TickInterval: tickInterval,
+			Duration:     testutil.WaitLong,
+			ReadMetrics:  readMetrics,
+			WriteMetrics: writeMetrics,
+			App: workspacetraffic.AppConfig{
+				Name: "echo",
+				URL:  srv.URL,
+			},
+		})
+
+		ctx, cancel := context.WithCancel(context.Background())
+		defer cancel()
+
+		var logs strings.Builder
+
+		runDone := make(chan struct{})
+		go func() {
+			defer close(runDone)
+			err := runner.Run(ctx, "", &logs)
+			assert.NoError(t, err, "unexpected error calling Run()")
+		}()
 
-		// Read is highly variable, depending on how far we read before stopping.
-		// Just ensure it's not zero.
+		gotMetrics := make(chan struct{})
+		go func() {
+			defer close(gotMetrics)
+			// Wait until we get some non-zero metrics before canceling.
+			assert.Eventually(t, func() bool {
+				readLatencies := readMetrics.Latencies()
+				writeLatencies := writeMetrics.Latencies()
+				return len(readLatencies) > 0 &&
+					len(writeLatencies) > 0 &&
+					slices.ContainsFunc(readLatencies, func(f float64) bool { return f > 0.0 }) &&
+					slices.ContainsFunc(writeLatencies, func(f float64) bool { return f > 0.0 })
+			}, testutil.WaitLong, testutil.IntervalMedium, "expected non-zero metrics")
+		}()
+
+		// Stop the test after we get some non-zero metrics.
+		<-gotMetrics
+		cancel()
+		<-runDone
+
+		t.Logf("read errors: %.0f\n", readMetrics.Errors())
+		t.Logf("write errors: %.0f\n", writeMetrics.Errors())
+		t.Logf("bytes read total: %.0f\n", readMetrics.Total())
+		t.Logf("bytes written total: %.0f\n", writeMetrics.Total())
+
+		// Ensure something was both read and written.
 		assert.NotZero(t, readMetrics.Total())
+		assert.NotZero(t, writeMetrics.Total())
+		// We want to ensure the metrics are somewhat accurate.
+		assert.InDelta(t, writeMetrics.Total(), readMetrics.Total(), float64(bytesPerTick)*10)
 		// Latency should report non-zero values.
 		assert.NotEmpty(t, readMetrics.Latencies())
 		assert.NotEmpty(t, writeMetrics.Latencies())
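`appClientConn` wraps the dialed connection in a `countReadWriteCloser`, whose definition lies outside this diff; totals like `readMetrics.Total()` are fed from such a wrapper. A hypothetical minimal sketch of a byte-counting `io.ReadWriteCloser` (names and fields assumed, not the actual implementation):

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"sync/atomic"
)

// countingRWC is a hypothetical stand-in for countReadWriteCloser: it
// forwards to the wrapped ReadWriteCloser and tallies bytes per direction.
type countingRWC struct {
	rwc           io.ReadWriteCloser
	read, written int64
}

func (c *countingRWC) Read(p []byte) (int, error) {
	n, err := c.rwc.Read(p)
	atomic.AddInt64(&c.read, int64(n))
	return n, err
}

func (c *countingRWC) Write(p []byte) (int, error) {
	n, err := c.rwc.Write(p)
	atomic.AddInt64(&c.written, int64(n))
	return n, err
}

func (c *countingRWC) Close() error { return c.rwc.Close() }

// nopCloserBuf adapts a bytes.Buffer into an io.ReadWriteCloser for the demo.
type nopCloserBuf struct{ *bytes.Buffer }

func (nopCloserBuf) Close() error { return nil }

func main() {
	c := &countingRWC{rwc: nopCloserBuf{&bytes.Buffer{}}}
	_, _ = c.Write([]byte("hello"))
	buf := make([]byte, 5)
	_, _ = c.Read(buf)
	fmt.Println(atomic.LoadInt64(&c.written), atomic.LoadInt64(&c.read)) // 5 5
}
```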