Skip to content

feat: add reconnectingpty loadtest #5083

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,7 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.ReconnectingPTYInit, conn net.Conn) {
defer conn.Close()

connectionID := uuid.NewString()
var rpty *reconnectingPTY
rawRPTY, ok := a.reconnectingPTYs.Load(msg.ID)
if ok {
Expand Down Expand Up @@ -760,8 +761,12 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec
a.closeMutex.Unlock()
ctx, cancelFunc := context.WithCancel(ctx)
rpty = &reconnectingPTY{
activeConns: make(map[string]net.Conn),
ptty: ptty,
activeConns: map[string]net.Conn{
// We have to put the connection in the map instantly otherwise
// the connection won't be closed if the process instantly dies.
connectionID: conn,
},
ptty: ptty,
// Timeouts created with an after func can be reset!
timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancelFunc),
circularBuffer: circularBuffer,
Expand Down Expand Up @@ -827,7 +832,6 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec
a.logger.Warn(ctx, "write reconnecting pty buffer", slog.F("id", msg.ID), slog.Error(err))
return
}
connectionID := uuid.NewString()
// Multiple connections to the same TTY are permitted.
// This could easily be used for terminal sharing, but
// we do it because it's a nice user experience to
Expand Down
4 changes: 2 additions & 2 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func TestAgent(t *testing.T) {

conn, stats, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0)

ptyConn, err := conn.ReconnectingPTY(ctx, uuid.NewString(), 128, 128, "/bin/bash")
ptyConn, err := conn.ReconnectingPTY(ctx, uuid.New(), 128, 128, "/bin/bash")
require.NoError(t, err)
defer ptyConn.Close()

Expand Down Expand Up @@ -405,7 +405,7 @@ func TestAgent(t *testing.T) {
defer cancel()

conn, _, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0)
id := uuid.NewString()
id := uuid.New()
netConn, err := conn.ReconnectingPTY(ctx, id, 100, 100, "/bin/bash")
require.NoError(t, err)
bufRead := bufio.NewReader(netConn)
Expand Down
24 changes: 21 additions & 3 deletions cli/loadtestconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/coder/coder/loadtest/agentconn"
"github.com/coder/coder/loadtest/harness"
"github.com/coder/coder/loadtest/placebo"
"github.com/coder/coder/loadtest/reconnectingpty"
"github.com/coder/coder/loadtest/workspacebuild"
)

Expand Down Expand Up @@ -88,9 +89,10 @@ func (s LoadTestStrategy) ExecutionStrategy() harness.ExecutionStrategy {
type LoadTestType string

const (
LoadTestTypeAgentConn LoadTestType = "agentconn"
LoadTestTypePlacebo LoadTestType = "placebo"
LoadTestTypeWorkspaceBuild LoadTestType = "workspacebuild"
LoadTestTypeAgentConn LoadTestType = "agentconn"
LoadTestTypePlacebo LoadTestType = "placebo"
LoadTestTypeReconnectingPTY LoadTestType = "reconnectingpty"
LoadTestTypeWorkspaceBuild LoadTestType = "workspacebuild"
)

type LoadTest struct {
Expand All @@ -104,6 +106,8 @@ type LoadTest struct {
AgentConn *agentconn.Config `json:"agentconn,omitempty"`
// Placebo must be set if type == "placebo".
Placebo *placebo.Config `json:"placebo,omitempty"`
// ReconnectingPTY must be set if type == "reconnectingpty".
ReconnectingPTY *reconnectingpty.Config `json:"reconnectingpty,omitempty"`
// WorkspaceBuild must be set if type == "workspacebuild".
WorkspaceBuild *workspacebuild.Config `json:"workspacebuild,omitempty"`
}
Expand All @@ -120,6 +124,11 @@ func (t LoadTest) NewRunner(client *codersdk.Client) (harness.Runnable, error) {
return nil, xerrors.New("placebo config must be set")
}
return placebo.NewRunner(*t.Placebo), nil
case LoadTestTypeReconnectingPTY:
if t.ReconnectingPTY == nil {
return nil, xerrors.New("reconnectingpty config must be set")
}
return reconnectingpty.NewRunner(client, *t.ReconnectingPTY), nil
case LoadTestTypeWorkspaceBuild:
if t.WorkspaceBuild == nil {
return nil, xerrors.Errorf("workspacebuild config must be set")
Expand Down Expand Up @@ -185,6 +194,15 @@ func (t *LoadTest) Validate() error {
if err != nil {
return xerrors.Errorf("validate placebo: %w", err)
}
case LoadTestTypeReconnectingPTY:
if t.ReconnectingPTY == nil {
return xerrors.Errorf("reconnectingpty test type must specify reconnectingpty")
}

err := t.ReconnectingPTY.Validate()
if err != nil {
return xerrors.Errorf("validate reconnectingpty: %w", err)
}
case LoadTestTypeWorkspaceBuild:
if t.WorkspaceBuild == nil {
return xerrors.New("workspacebuild test type must specify workspacebuild")
Expand Down
10 changes: 3 additions & 7 deletions coderd/workspaceagents.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/netip"
Expand All @@ -26,6 +25,7 @@ import (
"tailscale.com/tailcfg"

"cdr.dev/slog"
"github.com/coder/coder/agent"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/gitauth"
"github.com/coder/coder/coderd/httpapi"
Expand Down Expand Up @@ -247,17 +247,13 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
return
}
defer release()
ptNetConn, err := agentConn.ReconnectingPTY(ctx, reconnect.String(), uint16(height), uint16(width), r.URL.Query().Get("command"))
ptNetConn, err := agentConn.ReconnectingPTY(ctx, reconnect, uint16(height), uint16(width), r.URL.Query().Get("command"))
if err != nil {
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial: %s", err))
return
}
defer ptNetConn.Close()
// Pipe the ends together!
go func() {
_, _ = io.Copy(wsNetConn, ptNetConn)
}()
_, _ = io.Copy(ptNetConn, wsNetConn)
agent.Bicopy(ctx, wsNetConn, ptNetConn)
}

func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Request) {
Expand Down
5 changes: 3 additions & 2 deletions codersdk/agentconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"strings"
"time"

"github.com/google/uuid"
"golang.org/x/crypto/ssh"
"golang.org/x/xerrors"
"tailscale.com/net/speedtest"
Expand Down Expand Up @@ -158,13 +159,13 @@ func (c *AgentConn) Close() error {

// @typescript-ignore ReconnectingPTYInit
type ReconnectingPTYInit struct {
ID string
ID uuid.UUID
Height uint16
Width uint16
Command string
}

func (c *AgentConn) ReconnectingPTY(ctx context.Context, id string, height, width uint16, command string) (net.Conn, error) {
func (c *AgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string) (net.Conn, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()

Expand Down
2 changes: 1 addition & 1 deletion codersdk/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ func readBodyAsError(res *http.Response) error {
return &Error{
statusCode: res.StatusCode,
Response: Response{
Message: "unexpected non-JSON response",
Message: fmt.Sprintf("unexpected non-JSON response %q", contentType),
Detail: string(resp),
},
Helper: helper,
Expand Down
11 changes: 9 additions & 2 deletions codersdk/workspaceagents.go
Original file line number Diff line number Diff line change
Expand Up @@ -501,11 +501,18 @@ func (c *Client) PostWorkspaceAgentVersion(ctx context.Context, version string)
// WorkspaceAgentReconnectingPTY spawns a PTY that reconnects using the token provided.
// It communicates using `agent.ReconnectingPTYRequest` marshaled as JSON.
// Responses are PTY output that can be rendered.
func (c *Client) WorkspaceAgentReconnectingPTY(ctx context.Context, agentID, reconnect uuid.UUID, height, width int, command string) (net.Conn, error) {
serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/workspaceagents/%s/pty?reconnect=%s&height=%d&width=%d&command=%s", agentID, reconnect, height, width, command))
func (c *Client) WorkspaceAgentReconnectingPTY(ctx context.Context, agentID, reconnect uuid.UUID, height, width uint16, command string) (net.Conn, error) {
serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/workspaceagents/%s/pty", agentID))
if err != nil {
return nil, xerrors.Errorf("parse url: %w", err)
}
q := serverURL.Query()
q.Set("reconnect", reconnect.String())
q.Set("height", strconv.Itoa(int(height)))
q.Set("width", strconv.Itoa(int(width)))
q.Set("command", command)
serverURL.RawQuery = q.Encode()

jar, err := cookiejar.New(nil)
if err != nil {
return nil, xerrors.Errorf("create cookie jar: %w", err)
Expand Down
52 changes: 52 additions & 0 deletions loadtest/reconnectingpty/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package reconnectingpty

import (
"time"

"github.com/google/uuid"
"golang.org/x/xerrors"

"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk"
)

const (
DefaultWidth = 80
DefaultHeight = 24
DefaultTimeout = httpapi.Duration(5 * time.Minute)
)

type Config struct {
// AgentID is the ID of the agent to run the command in.
AgentID uuid.UUID `json:"agent_id"`
// Init is the initial packet to send to the agent when launching the TTY.
// If the ID is not set, defaults to a random UUID. If the width or height
// is not set, defaults to 80x24. If the command is not set, defaults to
// opening a login shell. Command runs in the default shell.
Init codersdk.ReconnectingPTYInit `json:"init"`
// Timeout is the duration to wait for the command to exit. Defaults to
// 5 minutes.
Timeout httpapi.Duration `json:"timeout"`
// ExpectTimeout means we expect the timeout to be reached (i.e. the command
// doesn't exit within the given timeout).
ExpectTimeout bool `json:"expect_timeout"`
// ExpectOutput checks that the given string is present in the output. The
// string must be present on a single line.
ExpectOutput string `json:"expect_output"`
// LogOutput determines whether the output of the command should be logged.
// For commands that produce a lot of output this should be disabled to
// avoid loadtest OOMs. All log output is still read and discarded if this
// is false.
LogOutput bool `json:"log_output"`
}

func (c Config) Validate() error {
if c.AgentID == uuid.Nil {
return xerrors.New("agent_id must be set")
}
if c.Timeout < 0 {
return xerrors.New("timeout must be a positive value")
}

return nil
}
78 changes: 78 additions & 0 deletions loadtest/reconnectingpty/config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package reconnectingpty_test

import (
"testing"
"time"

"github.com/google/uuid"
"github.com/stretchr/testify/require"

"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/loadtest/reconnectingpty"
)

func Test_Config(t *testing.T) {
t.Parallel()

id := uuid.New()
cases := []struct {
name string
config reconnectingpty.Config
errContains string
}{
{
name: "OKBasic",
config: reconnectingpty.Config{
AgentID: id,
},
},
{
name: "OKFull",
config: reconnectingpty.Config{
AgentID: id,
Init: codersdk.ReconnectingPTYInit{
ID: id,
Width: 80,
Height: 24,
Command: "echo 'hello world'",
},
Timeout: httpapi.Duration(time.Minute),
ExpectTimeout: false,
ExpectOutput: "hello world",
LogOutput: true,
},
},
{
name: "NoAgentID",
config: reconnectingpty.Config{
AgentID: uuid.Nil,
},
errContains: "agent_id must be set",
},
{
name: "NegativeTimeout",
config: reconnectingpty.Config{
AgentID: id,
Timeout: httpapi.Duration(-time.Minute),
},
errContains: "timeout must be a positive value",
},
}

for _, c := range cases {
c := c

t.Run(c.name, func(t *testing.T) {
t.Parallel()

err := c.config.Validate()
if c.errContains != "" {
require.Error(t, err)
require.Contains(t, err.Error(), c.errContains)
} else {
require.NoError(t, err)
}
})
}
}
Loading