Skip to content

Commit be118e6

Browse files
committed
reduce scope
1 parent 578ebb0 commit be118e6

File tree

6 files changed

+66
-181
lines changed

6 files changed

+66
-181
lines changed

cli/cliutil/stdioconn.go

+7-7
Original file line numberDiff line numberDiff line change
@@ -6,31 +6,31 @@ import (
66
"time"
77
)
88

9-
type StdioConn struct {
9+
type ReaderWriterConn struct {
1010
io.Reader
1111
io.Writer
1212
}
1313

14-
func (*StdioConn) Close() (err error) {
14+
func (*ReaderWriterConn) Close() (err error) {
1515
return nil
1616
}
1717

18-
func (*StdioConn) LocalAddr() net.Addr {
18+
func (*ReaderWriterConn) LocalAddr() net.Addr {
1919
return nil
2020
}
2121

22-
func (*StdioConn) RemoteAddr() net.Addr {
22+
func (*ReaderWriterConn) RemoteAddr() net.Addr {
2323
return nil
2424
}
2525

26-
func (*StdioConn) SetDeadline(_ time.Time) error {
26+
func (*ReaderWriterConn) SetDeadline(_ time.Time) error {
2727
return nil
2828
}
2929

30-
func (*StdioConn) SetReadDeadline(_ time.Time) error {
30+
func (*ReaderWriterConn) SetReadDeadline(_ time.Time) error {
3131
return nil
3232
}
3333

34-
func (*StdioConn) SetWriteDeadline(_ time.Time) error {
34+
func (*ReaderWriterConn) SetWriteDeadline(_ time.Time) error {
3535
return nil
3636
}

cli/ssh.go

+50-121
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ func (r *RootCmd) ssh() *serpent.Command {
6767
stdio bool
6868
hostPrefix string
6969
hostnameSuffix string
70-
forceTunnel bool
70+
forceNewTunnel bool
7171
forwardAgent bool
7272
forwardGPG bool
7373
identityAgent string
@@ -278,27 +278,38 @@ func (r *RootCmd) ssh() *serpent.Command {
278278
return err
279279
}
280280

281-
// See if we can use the Coder Connect tunnel
282-
if !forceTunnel {
281+
// If we're in stdio mode, check to see if we can use Coder Connect.
282+
// We don't support Coder Connect over non-stdio coder ssh yet.
283+
if stdio && !forceNewTunnel {
283284
connInfo, err := wsClient.AgentConnectionInfoGeneric(ctx)
284285
if err != nil {
285286
return xerrors.Errorf("get agent connection info: %w", err)
286287
}
287-
288288
coderConnectHost := fmt.Sprintf("%s.%s.%s.%s",
289289
workspaceAgent.Name, workspace.Name, workspace.OwnerName, connInfo.HostnameSuffix)
290290
exists, _ := workspacesdk.ExistsViaCoderConnect(ctx, coderConnectHost)
291291
if exists {
292292
_, _ = fmt.Fprintln(inv.Stderr, "Connecting to workspace via Coder Connect...")
293293
defer cancel()
294-
addr := fmt.Sprintf("%s:22", coderConnectHost)
295-
if stdio {
294+
295+
if networkInfoDir != "" {
296296
if err := writeCoderConnectNetInfo(ctx, networkInfoDir); err != nil {
297297
logger.Error(ctx, "failed to write coder connect net info file", slog.Error(err))
298298
}
299-
return runCoderConnectStdio(ctx, addr, stdioReader, stdioWriter, stack)
300299
}
301-
return runCoderConnectPTY(ctx, addr, inv.Stdin, inv.Stdout, inv.Stderr, stack)
300+
301+
stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace)
302+
defer stopPolling()
303+
304+
usageAppName := getUsageAppName(usageApp)
305+
if usageAppName != "" {
306+
closeUsage := client.UpdateWorkspaceUsageWithBodyContext(ctx, workspace.ID, codersdk.PostWorkspaceUsageRequest{
307+
AgentID: workspaceAgent.ID,
308+
AppName: usageAppName,
309+
})
310+
defer closeUsage()
311+
}
312+
return runCoderConnectStdio(ctx, fmt.Sprintf("%s:22", coderConnectHost), stdioReader, stdioWriter, stack)
302313
}
303314
}
304315

@@ -481,11 +492,36 @@ func (r *RootCmd) ssh() *serpent.Command {
481492
stdinFile, validIn := inv.Stdin.(*os.File)
482493
stdoutFile, validOut := inv.Stdout.(*os.File)
483494
if validIn && validOut && isatty.IsTerminal(stdinFile.Fd()) && isatty.IsTerminal(stdoutFile.Fd()) {
484-
restorePtyFn, err := configurePTY(ctx, stdinFile, stdoutFile, sshSession)
485-
defer restorePtyFn()
495+
inState, err := pty.MakeInputRaw(stdinFile.Fd())
496+
if err != nil {
497+
return err
498+
}
499+
defer func() {
500+
_ = pty.RestoreTerminal(stdinFile.Fd(), inState)
501+
}()
502+
outState, err := pty.MakeOutputRaw(stdoutFile.Fd())
486503
if err != nil {
487-
return xerrors.Errorf("configure pty: %w", err)
504+
return err
488505
}
506+
defer func() {
507+
_ = pty.RestoreTerminal(stdoutFile.Fd(), outState)
508+
}()
509+
510+
windowChange := listenWindowSize(ctx)
511+
go func() {
512+
for {
513+
select {
514+
case <-ctx.Done():
515+
return
516+
case <-windowChange:
517+
}
518+
width, height, err := term.GetSize(int(stdoutFile.Fd()))
519+
if err != nil {
520+
continue
521+
}
522+
_ = sshSession.WindowChange(height, width)
523+
}
524+
}()
489525
}
490526

491527
for _, kv := range parsedEnv {
@@ -667,48 +703,14 @@ func (r *RootCmd) ssh() *serpent.Command {
667703
{
668704
Flag: "force-new-tunnel",
669705
Description: "Force the creation of a new tunnel to the workspace, even if the Coder Connect tunnel is available.",
670-
Value: serpent.BoolOf(&forceTunnel),
706+
Value: serpent.BoolOf(&forceNewTunnel),
707+
Hidden: true,
671708
},
672709
sshDisableAutostartOption(serpent.BoolOf(&disableAutostart)),
673710
}
674711
return cmd
675712
}
676713

677-
func configurePTY(ctx context.Context, stdinFile *os.File, stdoutFile *os.File, sshSession *gossh.Session) (restoreFn func(), err error) {
678-
inState, err := pty.MakeInputRaw(stdinFile.Fd())
679-
if err != nil {
680-
return restoreFn, err
681-
}
682-
restoreFn = func() {
683-
_ = pty.RestoreTerminal(stdinFile.Fd(), inState)
684-
}
685-
outState, err := pty.MakeOutputRaw(stdoutFile.Fd())
686-
if err != nil {
687-
return restoreFn, err
688-
}
689-
restoreFn = func() {
690-
_ = pty.RestoreTerminal(stdinFile.Fd(), inState)
691-
_ = pty.RestoreTerminal(stdoutFile.Fd(), outState)
692-
}
693-
694-
windowChange := listenWindowSize(ctx)
695-
go func() {
696-
for {
697-
select {
698-
case <-ctx.Done():
699-
return
700-
case <-windowChange:
701-
}
702-
width, height, err := term.GetSize(int(stdoutFile.Fd()))
703-
if err != nil {
704-
continue
705-
}
706-
_ = sshSession.WindowChange(height, width)
707-
}
708-
}()
709-
return restoreFn, nil
710-
}
711-
712714
// findWorkspaceAndAgentByHostname parses the hostname from the commandline and finds the workspace and agent it
713715
// corresponds to, taking into account any name prefixes or suffixes configured (e.g. myworkspace.coder, or
714716
// vscode-coder--myusername--myworkspace).
@@ -1502,87 +1504,14 @@ func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, std
15021504
return err
15031505
}
15041506

1505-
agentssh.Bicopy(ctx, conn, &cliutil.StdioConn{
1507+
agentssh.Bicopy(ctx, conn, &cliutil.ReaderWriterConn{
15061508
Reader: stdin,
15071509
Writer: stdout,
15081510
})
15091511

15101512
return nil
15111513
}
15121514

1513-
func runCoderConnectPTY(ctx context.Context, addr string, stdin io.Reader, stdout io.Writer, stderr io.Writer, stack *closerStack) error {
1514-
client, err := gossh.Dial("tcp", addr, &gossh.ClientConfig{
1515-
// We've already checked the agent's address
1516-
// is within the Coder service prefix.
1517-
// #nosec
1518-
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
1519-
})
1520-
if err != nil {
1521-
return xerrors.Errorf("dial coder connect host: %w", err)
1522-
}
1523-
if err := stack.push("ssh client", client); err != nil {
1524-
return err
1525-
}
1526-
1527-
session, err := client.NewSession()
1528-
if err != nil {
1529-
return xerrors.Errorf("create ssh session: %w", err)
1530-
}
1531-
if err := stack.push("ssh session", session); err != nil {
1532-
return err
1533-
}
1534-
1535-
stdinFile, validIn := stdin.(*os.File)
1536-
stdoutFile, validOut := stdout.(*os.File)
1537-
if validIn && validOut && isatty.IsTerminal(stdinFile.Fd()) && isatty.IsTerminal(stdoutFile.Fd()) {
1538-
restorePtyFn, err := configurePTY(ctx, stdinFile, stdoutFile, session)
1539-
defer restorePtyFn()
1540-
if err != nil {
1541-
return xerrors.Errorf("configure pty: %w", err)
1542-
}
1543-
}
1544-
1545-
session.Stdin = stdin
1546-
session.Stdout = stdout
1547-
session.Stderr = stderr
1548-
1549-
err = session.RequestPty("xterm-256color", 80, 24, gossh.TerminalModes{})
1550-
if err != nil {
1551-
return xerrors.Errorf("request pty: %w", err)
1552-
}
1553-
1554-
err = session.Shell()
1555-
if err != nil {
1556-
return xerrors.Errorf("start shell: %w", err)
1557-
}
1558-
1559-
if validOut {
1560-
// Set initial window size.
1561-
width, height, err := term.GetSize(int(stdoutFile.Fd()))
1562-
if err == nil {
1563-
_ = session.WindowChange(height, width)
1564-
}
1565-
}
1566-
1567-
err = session.Wait()
1568-
if err != nil {
1569-
if exitErr := (&gossh.ExitError{}); errors.As(err, &exitErr) {
1570-
// Clear the error since it's not useful beyond
1571-
// reporting status.
1572-
return ExitError(exitErr.ExitStatus(), nil)
1573-
}
1574-
// If the connection drops unexpectedly, we get an
1575-
// ExitMissingError but no other error details, so try to at
1576-
// least give the user a better message
1577-
if errors.Is(err, &gossh.ExitMissingError{}) {
1578-
return ExitError(255, xerrors.New("SSH connection ended unexpectedly"))
1579-
}
1580-
return xerrors.Errorf("session ended: %w", err)
1581-
}
1582-
1583-
return nil
1584-
}
1585-
15861515
func writeCoderConnectNetInfo(ctx context.Context, networkInfoDir string) error {
15871516
fs, ok := ctx.Value("fs").(afero.Fs)
15881517
if !ok {

cli/ssh_internal_test.go

+1-33
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import (
2222

2323
"github.com/coder/coder/v2/cli/cliutil"
2424
"github.com/coder/coder/v2/codersdk"
25-
"github.com/coder/coder/v2/pty/ptytest"
2625
"github.com/coder/coder/v2/testutil"
2726
)
2827

@@ -226,37 +225,6 @@ func TestCloserStack_Timeout(t *testing.T) {
226225
testutil.TryReceive(ctx, t, closed)
227226
}
228227

229-
func TestCoderConnectPTY(t *testing.T) {
230-
t.Parallel()
231-
232-
ctx := testutil.Context(t, testutil.WaitShort)
233-
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
234-
stack := newCloserStack(ctx, logger, quartz.NewMock(t))
235-
236-
server := newSSHServer("127.0.0.1:0")
237-
ln, err := net.Listen("tcp", server.server.Addr)
238-
require.NoError(t, err)
239-
240-
go func() {
241-
_ = server.Serve(ln)
242-
}()
243-
t.Cleanup(func() {
244-
_ = server.Close()
245-
})
246-
247-
ptty := ptytest.New(t)
248-
ptyDone := make(chan struct{})
249-
go func() {
250-
err := runCoderConnectPTY(ctx, ln.Addr().String(), ptty.Output(), ptty.Input(), ptty.Output(), stack)
251-
assert.NoError(t, err)
252-
close(ptyDone)
253-
}()
254-
ptty.ExpectMatch("Connected!")
255-
// Shells on Mac, Windows, and Linux all exit shells with the "exit" command.
256-
ptty.WriteLine("exit")
257-
<-ptyDone
258-
}
259-
260228
func TestCoderConnectStdio(t *testing.T) {
261229
t.Parallel()
262230

@@ -290,7 +258,7 @@ func TestCoderConnectStdio(t *testing.T) {
290258
close(stdioDone)
291259
}()
292260

293-
conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{
261+
conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{
294262
Reader: serverOutput,
295263
Writer: clientInput,
296264
}, "", &ssh.ClientConfig{

0 commit comments

Comments
 (0)