Skip to content

Commit 6862409

Browse files
authored
feat(agent/reconnectingpty): allow selecting backend type (coder#17011)
agent/reconnectingpty: allow specifying backend type cli: exp rpty: automatically select backend based on command
1 parent 0cd254f commit 6862409

File tree

7 files changed

+78
-21
lines changed

7 files changed

+78
-21
lines changed

agent/reconnectingpty/reconnectingpty.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ type Options struct {
3232
Timeout time.Duration
3333
// Metrics tracks various error counters.
3434
Metrics *prometheus.CounterVec
35+
// BackendType specifies the ReconnectingPTY backend to use.
36+
BackendType string
3537
}
3638

3739
// ReconnectingPTY is a pty that can be reconnected within a timeout and to
@@ -64,13 +66,20 @@ func New(ctx context.Context, logger slog.Logger, execer agentexec.Execer, cmd *
6466
// runs) but in CI screen often incorrectly claims the session name does not
6567
// exist even though screen -list shows it. For now, restrict screen to
6668
// Linux.
67-
backendType := "buffered"
69+
autoBackendType := "buffered"
6870
if runtime.GOOS == "linux" {
6971
_, err := exec.LookPath("screen")
7072
if err == nil {
71-
backendType = "screen"
73+
autoBackendType = "screen"
7274
}
7375
}
76+
var backendType string
77+
switch options.BackendType {
78+
case "":
79+
backendType = autoBackendType
80+
default:
81+
backendType = options.BackendType
82+
}
7483

7584
logger.Info(ctx, "start reconnecting pty", slog.F("backend_type", backendType))
7685

agent/reconnectingpty/server.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,9 @@ func (s *Server) handleConn(ctx context.Context, logger slog.Logger, conn net.Co
207207
s.commandCreator.Execer,
208208
cmd,
209209
&Options{
210-
Timeout: s.timeout,
211-
Metrics: s.errorsTotal,
210+
Timeout: s.timeout,
211+
Metrics: s.errorsTotal,
212+
BackendType: msg.BackendType,
212213
},
213214
)
214215

cli/exp_rpty.go

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"bufio"
55
"context"
66
"encoding/json"
7-
"fmt"
87
"io"
98
"os"
109
"strings"
@@ -15,6 +14,7 @@ import (
1514
"golang.org/x/xerrors"
1615

1716
"github.com/coder/coder/v2/cli/cliui"
17+
"github.com/coder/coder/v2/coderd/util/slice"
1818
"github.com/coder/coder/v2/codersdk"
1919
"github.com/coder/coder/v2/codersdk/workspacesdk"
2020
"github.com/coder/coder/v2/pty"
@@ -96,6 +96,7 @@ func handleRPTY(inv *serpent.Invocation, client *codersdk.Client, args handleRPT
9696
} else {
9797
reconnectID = uuid.New()
9898
}
99+
99100
ws, agt, err := getWorkspaceAndAgent(ctx, inv, client, true, args.NamedWorkspace)
100101
if err != nil {
101102
return err
@@ -118,14 +119,6 @@ func handleRPTY(inv *serpent.Invocation, client *codersdk.Client, args handleRPT
118119
}
119120
}
120121

121-
if err := cliui.Agent(ctx, inv.Stderr, agt.ID, cliui.AgentOptions{
122-
FetchInterval: 0,
123-
Fetch: client.WorkspaceAgent,
124-
Wait: false,
125-
}); err != nil {
126-
return err
127-
}
128-
129122
// Get the width and height of the terminal.
130123
var termWidth, termHeight uint16
131124
stdoutFile, validOut := inv.Stdout.(*os.File)
@@ -149,6 +142,15 @@ func handleRPTY(inv *serpent.Invocation, client *codersdk.Client, args handleRPT
149142
}()
150143
}
151144

145+
// If a user does not specify a command, we'll assume they intend to open an
146+
// interactive shell.
147+
var backend string
148+
if isOneShotCommand(args.Command) {
149+
// If the user specified a command, we'll prefer to use the buffered method.
150+
// The screen backend is not well suited for one-shot commands.
151+
backend = "buffered"
152+
}
153+
152154
conn, err := workspacesdk.New(client).AgentReconnectingPTY(ctx, workspacesdk.WorkspaceAgentReconnectingPTYOpts{
153155
AgentID: agt.ID,
154156
Reconnect: reconnectID,
@@ -157,14 +159,13 @@ func handleRPTY(inv *serpent.Invocation, client *codersdk.Client, args handleRPT
157159
ContainerUser: args.ContainerUser,
158160
Width: termWidth,
159161
Height: termHeight,
162+
BackendType: backend,
160163
})
161164
if err != nil {
162165
return xerrors.Errorf("open reconnecting PTY: %w", err)
163166
}
164167
defer conn.Close()
165168

166-
cliui.Infof(inv.Stderr, "Connected to %s (agent id: %s)", args.NamedWorkspace, agt.ID)
167-
cliui.Infof(inv.Stderr, "Reconnect ID: %s", reconnectID)
168169
closeUsage := client.UpdateWorkspaceUsageWithBodyContext(ctx, ws.ID, codersdk.PostWorkspaceUsageRequest{
169170
AgentID: agt.ID,
170171
AppName: codersdk.UsageAppNameReconnectingPty,
@@ -210,7 +211,21 @@ func handleRPTY(inv *serpent.Invocation, client *codersdk.Client, args handleRPT
210211
_, _ = io.Copy(inv.Stdout, conn)
211212
cancel()
212213
_ = conn.Close()
213-
_, _ = fmt.Fprintf(inv.Stderr, "Connection closed\n")
214214

215215
return nil
216216
}
217+
218+
var knownShells = []string{"ash", "bash", "csh", "dash", "fish", "ksh", "powershell", "pwsh", "zsh"}
219+
220+
func isOneShotCommand(cmd []string) bool {
221+
// If the command is empty, we'll assume the user wants to open a shell.
222+
if len(cmd) == 0 {
223+
return false
224+
}
225+
// If the command is a single word, and that word is a known shell, we'll
226+
// assume the user wants to open a shell.
227+
if len(cmd) == 1 && slice.Contains(knownShells, cmd[0]) {
228+
return false
229+
}
230+
return true
231+
}

cli/exp_rpty_test.go

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
package cli_test
22

33
import (
4-
"fmt"
54
"runtime"
65
"testing"
76

7+
"github.com/google/uuid"
88
"github.com/ory/dockertest/v3"
99
"github.com/ory/dockertest/v3/docker"
1010

@@ -23,7 +23,7 @@ import (
2323
func TestExpRpty(t *testing.T) {
2424
t.Parallel()
2525

26-
t.Run("OK", func(t *testing.T) {
26+
t.Run("DefaultCommand", func(t *testing.T) {
2727
t.Parallel()
2828

2929
client, workspace, agentToken := setupWorkspaceForAgent(t)
@@ -41,11 +41,33 @@ func TestExpRpty(t *testing.T) {
4141
_ = agenttest.New(t, client.URL, agentToken)
4242
_ = coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait()
4343

44-
pty.ExpectMatch(fmt.Sprintf("Connected to %s", workspace.Name))
4544
pty.WriteLine("exit")
4645
<-cmdDone
4746
})
4847

48+
t.Run("Command", func(t *testing.T) {
49+
t.Parallel()
50+
51+
client, workspace, agentToken := setupWorkspaceForAgent(t)
52+
randStr := uuid.NewString()
53+
inv, root := clitest.New(t, "exp", "rpty", workspace.Name, "echo", randStr)
54+
clitest.SetupConfig(t, client, root)
55+
pty := ptytest.New(t).Attach(inv)
56+
57+
ctx := testutil.Context(t, testutil.WaitLong)
58+
59+
cmdDone := tGo(t, func() {
60+
err := inv.WithContext(ctx).Run()
61+
assert.NoError(t, err)
62+
})
63+
64+
_ = agenttest.New(t, client.URL, agentToken)
65+
_ = coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait()
66+
67+
pty.ExpectMatch(randStr)
68+
<-cmdDone
69+
})
70+
4971
t.Run("NotFound", func(t *testing.T) {
5072
t.Parallel()
5173

@@ -103,8 +125,6 @@ func TestExpRpty(t *testing.T) {
103125
assert.NoError(t, err)
104126
})
105127

106-
pty.ExpectMatch(fmt.Sprintf("Connected to %s", workspace.Name))
107-
pty.ExpectMatch("Reconnect ID: ")
108128
pty.ExpectMatch(" #")
109129
pty.WriteLine("hostname")
110130
pty.ExpectMatch(ct.Container.Config.Hostname)

coderd/workspaceapps/proxy.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,7 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
655655
width := parser.UInt(values, 80, "width")
656656
container := parser.String(values, "", "container")
657657
containerUser := parser.String(values, "", "container_user")
658+
backendType := parser.String(values, "", "backend_type")
658659
if len(parser.Errors) > 0 {
659660
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
660661
Message: "Invalid query parameters.",
@@ -695,6 +696,7 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
695696
ptNetConn, err := agentConn.ReconnectingPTY(ctx, reconnect, uint16(height), uint16(width), r.URL.Query().Get("command"), func(arp *workspacesdk.AgentReconnectingPTYInit) {
696697
arp.Container = container
697698
arp.ContainerUser = containerUser
699+
arp.BackendType = backendType
698700
})
699701
if err != nil {
700702
log.Debug(ctx, "dial reconnecting pty server in workspace agent", slog.Error(err))

codersdk/workspacesdk/agentconn.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ type AgentReconnectingPTYInit struct {
100100
// This can be a username or UID, depending on the underlying implementation.
101101
// This is ignored if Container is not set.
102102
ContainerUser string
103+
104+
BackendType string
103105
}
104106

105107
// AgentReconnectingPTYInitOption is a functional option for AgentReconnectingPTYInit.

codersdk/workspacesdk/workspacesdk.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,11 @@ type WorkspaceAgentReconnectingPTYOpts struct {
318318
// CODER_AGENT_DEVCONTAINERS_ENABLE set to "true".
319319
Container string
320320
ContainerUser string
321+
322+
// BackendType is the type of backend to use for the PTY. If not set, the
323+
// workspace agent will attempt to determine the preferred backend type.
324+
// Supported values are "screen" and "buffered".
325+
BackendType string
321326
}
322327

323328
// AgentReconnectingPTY spawns a PTY that reconnects using the token provided.
@@ -339,6 +344,9 @@ func (c *Client) AgentReconnectingPTY(ctx context.Context, opts WorkspaceAgentRe
339344
if opts.ContainerUser != "" {
340345
q.Set("container_user", opts.ContainerUser)
341346
}
347+
if opts.BackendType != "" {
348+
q.Set("backend_type", opts.BackendType)
349+
}
342350
// If we're using a signed token, set the query parameter.
343351
if opts.SignedToken != "" {
344352
q.Set(codersdk.SignedAppTokenQueryParameter, opts.SignedToken)

0 commit comments

Comments
 (0)