From a86247287cbc7e35a473f39b9d8698b7393a82cc Mon Sep 17 00:00:00 2001 From: Ethan Dickson Date: Mon, 28 Apr 2025 07:42:46 +0000 Subject: [PATCH 01/10] feat(cli): use coder connect in `coder ssh`, if available --- cli/cliutil/stdioconn.go | 36 +++++ cli/ssh.go | 232 ++++++++++++++++++++++----- cli/ssh_internal_test.go | 118 ++++++++++++++ cli/ssh_test.go | 86 +++++----- cli/testdata/coder_ssh_--help.golden | 4 + codersdk/workspacesdk/agentconn.go | 4 +- docs/reference/cli/ssh.md | 8 + 7 files changed, 411 insertions(+), 77 deletions(-) create mode 100644 cli/cliutil/stdioconn.go diff --git a/cli/cliutil/stdioconn.go b/cli/cliutil/stdioconn.go new file mode 100644 index 0000000000000..7f919dbf9d456 --- /dev/null +++ b/cli/cliutil/stdioconn.go @@ -0,0 +1,36 @@ +package cliutil + +import ( + "io" + "net" + "time" +) + +type StdioConn struct { + io.Reader + io.Writer +} + +func (*StdioConn) Close() (err error) { + return nil +} + +func (*StdioConn) LocalAddr() net.Addr { + return nil +} + +func (*StdioConn) RemoteAddr() net.Addr { + return nil +} + +func (*StdioConn) SetDeadline(_ time.Time) error { + return nil +} + +func (*StdioConn) SetReadDeadline(_ time.Time) error { + return nil +} + +func (*StdioConn) SetWriteDeadline(_ time.Time) error { + return nil +} diff --git a/cli/ssh.go b/cli/ssh.go index e02443e7032c6..82cab0aee1219 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "log" + "net" "net/http" "net/url" "os" @@ -66,6 +67,7 @@ func (r *RootCmd) ssh() *serpent.Command { stdio bool hostPrefix string hostnameSuffix string + forceTunnel bool forwardAgent bool forwardGPG bool identityAgent string @@ -85,6 +87,7 @@ func (r *RootCmd) ssh() *serpent.Command { containerUser string ) client := new(codersdk.Client) + wsClient := workspacesdk.New(client) cmd := &serpent.Command{ Annotations: workspaceCommand, Use: "ssh ", @@ -203,14 +206,14 @@ func (r *RootCmd) ssh() *serpent.Command { parsedEnv = append(parsedEnv, [2]string{k, v}) } - deploymentSSHConfig := codersdk.SSHConfigResponse{ + cliConfig := codersdk.SSHConfigResponse{ HostnamePrefix: hostPrefix, HostnameSuffix: hostnameSuffix, } workspace, workspaceAgent, err := findWorkspaceAndAgentByHostname( ctx, inv, client, - inv.Args[0], deploymentSSHConfig, disableAutostart) + inv.Args[0], cliConfig, disableAutostart) if err != nil { return err } @@ -275,10 +278,34 @@ func (r *RootCmd) ssh() *serpent.Command { return err } + // See if we can use the Coder Connect tunnel + if !forceTunnel { + connInfo, err := wsClient.AgentConnectionInfoGeneric(ctx) + if err != nil { + return xerrors.Errorf("get agent connection info: %w", err) + } + + coderConnectHost := fmt.Sprintf("%s.%s.%s.%s", + workspaceAgent.Name, workspace.Name, workspace.OwnerName, connInfo.HostnameSuffix) + exists, _ := workspacesdk.ExistsViaCoderConnect(ctx, coderConnectHost) + if exists { + _, _ = fmt.Fprintln(inv.Stderr, "Connecting to workspace via Coder Connect...") + defer cancel() + addr := fmt.Sprintf("%s:22", coderConnectHost) + if stdio { + if err := writeCoderConnectNetInfo(ctx, networkInfoDir); err != nil { + logger.Error(ctx, "failed to write coder connect net info file", slog.Error(err)) + } + return runCoderConnectStdio(ctx, addr, stdioReader, stdioWriter, stack) + } + return runCoderConnectPTY(ctx, addr, inv.Stdin, inv.Stdout, inv.Stderr, stack) + } + } + if r.disableDirect { _, _ = fmt.Fprintln(inv.Stderr, "Direct connections disabled.") } - conn, err := workspacesdk.New(client). + conn, err := wsClient. DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{ Logger: logger, BlockEndpoints: r.disableDirect, @@ -454,36 +481,11 @@ func (r *RootCmd) ssh() *serpent.Command { stdinFile, validIn := inv.Stdin.(*os.File) stdoutFile, validOut := inv.Stdout.(*os.File) if validIn && validOut && isatty.IsTerminal(stdinFile.Fd()) && isatty.IsTerminal(stdoutFile.Fd()) { - inState, err := pty.MakeInputRaw(stdinFile.Fd()) - if err != nil { - return err - } - defer func() { - _ = pty.RestoreTerminal(stdinFile.Fd(), inState) - }() - outState, err := pty.MakeOutputRaw(stdoutFile.Fd()) + restorePtyFn, err := configurePTY(ctx, stdinFile, stdoutFile, sshSession) + defer restorePtyFn() if err != nil { - return err + return xerrors.Errorf("configure pty: %w", err) } - defer func() { - _ = pty.RestoreTerminal(stdoutFile.Fd(), outState) - }() - - windowChange := listenWindowSize(ctx) - go func() { - for { - select { - case <-ctx.Done(): - return - case <-windowChange: - } - width, height, err := term.GetSize(int(stdoutFile.Fd())) - if err != nil { - continue - } - _ = sshSession.WindowChange(height, width) - } - }() } for _, kv := range parsedEnv { @@ -662,11 +664,51 @@ func (r *RootCmd) ssh() *serpent.Command { Value: serpent.StringOf(&containerUser), Hidden: true, // Hidden until this features is at least in beta. }, + { + Flag: "force-tunnel", + Description: "Force the use of a new tunnel to the workspace, even if the Coder Connect tunnel is available.", + Value: serpent.BoolOf(&forceTunnel), + }, sshDisableAutostartOption(serpent.BoolOf(&disableAutostart)), } return cmd } +func configurePTY(ctx context.Context, stdinFile *os.File, stdoutFile *os.File, sshSession *gossh.Session) (restoreFn func(), err error) { + inState, err := pty.MakeInputRaw(stdinFile.Fd()) + if err != nil { + return restoreFn, err + } + restoreFn = func() { + _ = pty.RestoreTerminal(stdinFile.Fd(), inState) + } + outState, err := pty.MakeOutputRaw(stdoutFile.Fd()) + if err != nil { + return restoreFn, err + } + restoreFn = func() { + _ = pty.RestoreTerminal(stdinFile.Fd(), inState) + _ = pty.RestoreTerminal(stdoutFile.Fd(), outState) + } + + windowChange := listenWindowSize(ctx) + go func() { + for { + select { + case <-ctx.Done(): + return + case <-windowChange: + } + width, height, err := term.GetSize(int(stdoutFile.Fd())) + if err != nil { + continue + } + _ = sshSession.WindowChange(height, width) + } + }() + return restoreFn, nil +} + // findWorkspaceAndAgentByHostname parses the hostname from the commandline and finds the workspace and agent it // corresponds to, taking into account any name prefixes or suffixes configured (e.g. myworkspace.coder, or // vscode-coder--myusername--myworkspace). @@ -1374,12 +1416,13 @@ func setStatsCallback( } type sshNetworkStats struct { - P2P bool `json:"p2p"` - Latency float64 `json:"latency"` - PreferredDERP string `json:"preferred_derp"` - DERPLatency map[string]float64 `json:"derp_latency"` - UploadBytesSec int64 `json:"upload_bytes_sec"` - DownloadBytesSec int64 `json:"download_bytes_sec"` + P2P bool `json:"p2p"` + Latency float64 `json:"latency"` + PreferredDERP string `json:"preferred_derp"` + DERPLatency map[string]float64 `json:"derp_latency"` + UploadBytesSec int64 `json:"upload_bytes_sec"` + DownloadBytesSec int64 `json:"download_bytes_sec"` + UsingCoderConnect bool `json:"using_coder_connect"` } func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) { @@ -1450,6 +1493,121 @@ func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn, }, nil } +func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, stdout io.Writer, stack *closerStack) error { + conn, err := net.Dial("tcp", addr) + if err != nil { + return xerrors.Errorf("dial coder connect host: %w", err) + } + if err := stack.push("tcp conn", conn); err != nil { + return err + } + + agentssh.Bicopy(ctx, conn, &cliutil.StdioConn{ + Reader: stdin, + Writer: stdout, + }) + + return nil +} + +func runCoderConnectPTY(ctx context.Context, addr string, stdin io.Reader, stdout io.Writer, stderr io.Writer, stack *closerStack) error { + client, err := gossh.Dial("tcp", addr, &gossh.ClientConfig{ + // We've already checked the agent's address + // is within the Coder service prefix. + // #nosec + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + }) + if err != nil { + return xerrors.Errorf("dial coder connect host: %w", err) + } + if err := stack.push("ssh client", client); err != nil { + return err + } + + session, err := client.NewSession() + if err != nil { + return xerrors.Errorf("create ssh session: %w", err) + } + if err := stack.push("ssh session", session); err != nil { + return err + } + + stdinFile, validIn := stdin.(*os.File) + stdoutFile, validOut := stdout.(*os.File) + if validIn && validOut && isatty.IsTerminal(stdinFile.Fd()) && isatty.IsTerminal(stdoutFile.Fd()) { + restorePtyFn, err := configurePTY(ctx, stdinFile, stdoutFile, session) + defer restorePtyFn() + if err != nil { + return xerrors.Errorf("configure pty: %w", err) + } + } + + session.Stdin = stdin + session.Stdout = stdout + session.Stderr = stderr + + err = session.RequestPty("xterm-256color", 80, 24, gossh.TerminalModes{}) + if err != nil { + return xerrors.Errorf("request pty: %w", err) + } + + err = session.Shell() + if err != nil { + return xerrors.Errorf("start shell: %w", err) + } + + if validOut { + // Set initial window size. + width, height, err := term.GetSize(int(stdoutFile.Fd())) + if err == nil { + _ = session.WindowChange(height, width) + } + } + + err = session.Wait() + if err != nil { + if exitErr := (&gossh.ExitError{}); errors.As(err, &exitErr) { + // Clear the error since it's not useful beyond + // reporting status. + return ExitError(exitErr.ExitStatus(), nil) + } + // If the connection drops unexpectedly, we get an + // ExitMissingError but no other error details, so try to at + // least give the user a better message + if errors.Is(err, &gossh.ExitMissingError{}) { + return ExitError(255, xerrors.New("SSH connection ended unexpectedly")) + } + return xerrors.Errorf("session ended: %w", err) + } + + return nil +} + +func writeCoderConnectNetInfo(ctx context.Context, networkInfoDir string) error { + fs, ok := ctx.Value("fs").(afero.Fs) + if !ok { + fs = afero.NewOsFs() + } + // The VS Code extension obtains the PID of the SSH process to + // find the log file associated with a SSH session. + // + // We get the parent PID because it's assumed `ssh` is calling this + // command via the ProxyCommand SSH option. + networkInfoFilePath := filepath.Join(networkInfoDir, fmt.Sprintf("%d.json", os.Getppid())) + stats := &sshNetworkStats{ + UsingCoderConnect: true, + } + rawStats, err := json.Marshal(stats) + if err != nil { + return xerrors.Errorf("marshal network stats: %w", err) + } + err = afero.WriteFile(fs, networkInfoFilePath, rawStats, 0o600) + if err != nil { + return xerrors.Errorf("write network stats: %w", err) + } + return nil +} + // Converts workspace name input to owner/workspace.agent format // Possible valid input formats: // workspace diff --git a/cli/ssh_internal_test.go b/cli/ssh_internal_test.go index d5e4c049347b2..6de6bd9ea24bf 100644 --- a/cli/ssh_internal_test.go +++ b/cli/ssh_internal_test.go @@ -3,20 +3,26 @@ package cli import ( "context" "fmt" + "io" + "net" "net/url" "sync" "testing" "time" + gliderssh "github.com/gliderlabs/ssh" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" "golang.org/x/xerrors" "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/quartz" + "github.com/coder/coder/v2/cli/cliutil" "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/pty/ptytest" "github.com/coder/coder/v2/testutil" ) @@ -220,6 +226,118 @@ func TestCloserStack_Timeout(t *testing.T) { testutil.TryReceive(ctx, t, closed) } +func TestCoderConnectPTY(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + stack := newCloserStack(ctx, logger, quartz.NewMock(t)) + + server := newSSHServer("127.0.0.1:0") + ln, err := net.Listen("tcp", server.server.Addr) + require.NoError(t, err) + + go func() { + _ = server.Serve(ln) + }() + t.Cleanup(func() { + _ = server.Close() + }) + + ptty := ptytest.New(t) + ptyDone := make(chan struct{}) + go func() { + err := runCoderConnectPTY(ctx, ln.Addr().String(), ptty.Output(), ptty.Input(), ptty.Output(), stack) + assert.NoError(t, err) + close(ptyDone) + }() + ptty.ExpectMatch("Connected!") + // Shells on Mac, Windows, and Linux all exit shells with the "exit" command. + ptty.WriteLine("exit") + <-ptyDone +} + +func TestCoderConnectStdio(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + stack := newCloserStack(ctx, logger, quartz.NewMock(t)) + + clientOutput, clientInput := io.Pipe() + serverOutput, serverInput := io.Pipe() + defer func() { + for _, c := range []io.Closer{clientOutput, clientInput, serverOutput, serverInput} { + _ = c.Close() + } + }() + + server := newSSHServer("127.0.0.1:0") + ln, err := net.Listen("tcp", server.server.Addr) + require.NoError(t, err) + + go func() { + _ = server.Serve(ln) + }() + t.Cleanup(func() { + _ = server.Close() + }) + + stdioDone := make(chan struct{}) + go func() { + err = runCoderConnectStdio(ctx, ln.Addr().String(), clientOutput, serverInput, stack) + assert.NoError(t, err) + close(stdioDone) + }() + + conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ + Reader: serverOutput, + Writer: clientInput, + }, "", &ssh.ClientConfig{ + // #nosec + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + }) + require.NoError(t, err) + defer conn.Close() + + sshClient := ssh.NewClient(conn, channels, requests) + session, err := sshClient.NewSession() + require.NoError(t, err) + defer session.Close() + + // Shells on Mac, Windows, and Linux all exit shells with the "exit" command. + err = session.Run("exit") + require.NoError(t, err) + err = sshClient.Close() + require.NoError(t, err) + _ = clientOutput.Close() + + <-stdioDone +} + +type sshServer struct { + server *gliderssh.Server +} + +func newSSHServer(addr string) *sshServer { + return &sshServer{ + server: &gliderssh.Server{ + Addr: addr, + Handler: func(s gliderssh.Session) { + _, _ = io.WriteString(s, "Connected!") + }, + }, + } +} + +func (s *sshServer) Serve(ln net.Listener) error { + return s.server.Serve(ln) +} + +func (s *sshServer) Close() error { + return s.server.Close() +} + type fakeCloser struct { closes *[]*fakeCloser err error diff --git a/cli/ssh_test.go b/cli/ssh_test.go index c8ad072270169..e9dd7c4bc42b2 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -43,6 +43,7 @@ import ( agentproto "github.com/coder/coder/v2/agent/proto" "github.com/coder/coder/v2/cli/clitest" "github.com/coder/coder/v2/cli/cliui" + "github.com/coder/coder/v2/cli/cliutil" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbfake" @@ -473,7 +474,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -542,7 +543,7 @@ func TestSSH(t *testing.T) { signer, err := agentssh.CoderSigner(keySeed) assert.NoError(t, err) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -605,7 +606,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -773,7 +774,7 @@ func TestSSH(t *testing.T) { // have access to the shell. _ = agenttest.New(t, client.URL, authToken) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ Reader: proxyCommandStdoutR, Writer: clientStdinW, }, "", &ssh.ClientConfig{ @@ -835,7 +836,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -894,7 +895,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -1082,7 +1083,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -1741,7 +1742,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -2110,6 +2111,46 @@ func TestSSH_Container(t *testing.T) { }) } +func TestSSH_CoderConnect(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + + fs := afero.NewMemMapFs() + //nolint:revive,staticcheck + ctx = context.WithValue(ctx, "fs", fs) + + client, workspace, agentToken := setupWorkspaceForAgent(t) + inv, root := clitest.New(t, "ssh", workspace.Name, "--network-info-dir", "/net", "--stdio") + clitest.SetupConfig(t, client, root) + _ = ptytest.New(t).Attach(inv) + + errCh := make(chan error, 1) + tGo(t, func() { + err := inv.WithContext(withCoderConnectRunning(ctx)).Run() + errCh <- err + }) + + _ = agenttest.New(t, client.URL, agentToken) + coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) + + err := testutil.TryReceive(ctx, t, errCh) + // Making an SSH server available here is difficult, so we'll just check + // the command attempts to dial it. + require.ErrorContains(t, err, "dial coder connect host") + require.ErrorContains(t, err, "dev.myworkspace.myuser.coder") + + // The network info file should be created since we passed `--stdio` + assert.Eventually(t, func() bool { + entries, err := afero.ReadDir(fs, "/net") + if err != nil { + return false + } + return len(entries) > 0 + }, testutil.WaitLong, testutil.IntervalFast) +} + // tGoContext runs fn in a goroutine passing a context that will be // canceled on test completion and wait until fn has finished executing. // Done and cancel are returned for optionally waiting until completion @@ -2153,35 +2194,6 @@ func tGo(t *testing.T, fn func()) (done <-chan struct{}) { return doneC } -type stdioConn struct { - io.Reader - io.Writer -} - -func (*stdioConn) Close() (err error) { - return nil -} - -func (*stdioConn) LocalAddr() net.Addr { - return nil -} - -func (*stdioConn) RemoteAddr() net.Addr { - return nil -} - -func (*stdioConn) SetDeadline(_ time.Time) error { - return nil -} - -func (*stdioConn) SetReadDeadline(_ time.Time) error { - return nil -} - -func (*stdioConn) SetWriteDeadline(_ time.Time) error { - return nil -} - // tempDirUnixSocket returns a temporary directory that can safely hold unix // sockets (probably). // diff --git a/cli/testdata/coder_ssh_--help.golden b/cli/testdata/coder_ssh_--help.golden index 1f7122dd655a2..9aefb24145596 100644 --- a/cli/testdata/coder_ssh_--help.golden +++ b/cli/testdata/coder_ssh_--help.golden @@ -12,6 +12,10 @@ OPTIONS: -e, --env string-array, $CODER_SSH_ENV Set environment variable(s) for session (key1=value1,key2=value2,...). + --force-tunnel bool + Force the use of a new tunnel to the workspace, even if the Coder + Connect tunnel is available. + -A, --forward-agent bool, $CODER_SSH_FORWARD_AGENT Specifies whether to forward the SSH agent specified in $SSH_AUTH_SOCK. diff --git a/codersdk/workspacesdk/agentconn.go b/codersdk/workspacesdk/agentconn.go index fa569080f7dd2..97b4268c68780 100644 --- a/codersdk/workspacesdk/agentconn.go +++ b/codersdk/workspacesdk/agentconn.go @@ -185,14 +185,12 @@ func (c *AgentConn) SSHOnPort(ctx context.Context, port uint16) (*gonet.TCPConn, return c.DialContextTCP(ctx, netip.AddrPortFrom(c.agentAddress(), port)) } -// SSHClient calls SSH to create a client that uses a weak cipher -// to improve throughput. +// SSHClient calls SSH to create a client func (c *AgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) { return c.SSHClientOnPort(ctx, AgentSSHPort) } // SSHClientOnPort calls SSH to create a client on a specific port -// that uses a weak cipher to improve throughput. func (c *AgentConn) SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Client, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() diff --git a/docs/reference/cli/ssh.md b/docs/reference/cli/ssh.md index c5bae755c8419..b9d76afd1452f 100644 --- a/docs/reference/cli/ssh.md +++ b/docs/reference/cli/ssh.md @@ -138,6 +138,14 @@ Specifies a directory to write network information periodically. Specifies the interval to update network information. +### --force-tunnel + +| | | +|------|-------------------| +| Type | bool | + +Force the use of a new tunnel to the workspace, even if the Coder Connect tunnel is available. + ### --disable-autostart | | | From d8e1c90fa6f33852d687d9741b4f851c676979de Mon Sep 17 00:00:00 2001 From: Ethan Dickson Date: Mon, 28 Apr 2025 11:06:29 +0000 Subject: [PATCH 02/10] fix windows tests --- cli/ssh_internal_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cli/ssh_internal_test.go b/cli/ssh_internal_test.go index 6de6bd9ea24bf..82659e1495afd 100644 --- a/cli/ssh_internal_test.go +++ b/cli/ssh_internal_test.go @@ -324,7 +324,7 @@ func newSSHServer(addr string) *sshServer { server: &gliderssh.Server{ Addr: addr, Handler: func(s gliderssh.Session) { - _, _ = io.WriteString(s, "Connected!") + _, _ = io.WriteString(s.Stderr(), "Connected!") }, }, } From 578ebb0b744d5db9cde72101412a4686dfcd931a Mon Sep 17 00:00:00 2001 From: Ethan Dickson Date: Mon, 28 Apr 2025 12:46:42 +0000 Subject: [PATCH 03/10] rename flag, extra test --- cli/ssh.go | 4 +- cli/ssh_test.go | 88 +++++++++++++++++++--------- cli/testdata/coder_ssh_--help.golden | 4 +- docs/reference/cli/ssh.md | 4 +- 4 files changed, 65 insertions(+), 35 deletions(-) diff --git a/cli/ssh.go b/cli/ssh.go index 82cab0aee1219..82c66eb939964 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -665,8 +665,8 @@ func (r *RootCmd) ssh() *serpent.Command { Hidden: true, // Hidden until this features is at least in beta. }, { - Flag: "force-tunnel", - Description: "Force the use of a new tunnel to the workspace, even if the Coder Connect tunnel is available.", + Flag: "force-new-tunnel", + Description: "Force the creation of a new tunnel to the workspace, even if the Coder Connect tunnel is available.", Value: serpent.BoolOf(&forceTunnel), }, sshDisableAutostartOption(serpent.BoolOf(&disableAutostart)), diff --git a/cli/ssh_test.go b/cli/ssh_test.go index e9dd7c4bc42b2..1a2d2aa37425f 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -2114,41 +2114,71 @@ func TestSSH_Container(t *testing.T) { func TestSSH_CoderConnect(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() + t.Run("Enabled", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() - fs := afero.NewMemMapFs() - //nolint:revive,staticcheck - ctx = context.WithValue(ctx, "fs", fs) + fs := afero.NewMemMapFs() + //nolint:revive,staticcheck + ctx = context.WithValue(ctx, "fs", fs) - client, workspace, agentToken := setupWorkspaceForAgent(t) - inv, root := clitest.New(t, "ssh", workspace.Name, "--network-info-dir", "/net", "--stdio") - clitest.SetupConfig(t, client, root) - _ = ptytest.New(t).Attach(inv) + client, workspace, agentToken := setupWorkspaceForAgent(t) + inv, root := clitest.New(t, "ssh", workspace.Name, "--network-info-dir", "/net", "--stdio") + clitest.SetupConfig(t, client, root) + _ = ptytest.New(t).Attach(inv) - errCh := make(chan error, 1) - tGo(t, func() { - err := inv.WithContext(withCoderConnectRunning(ctx)).Run() - errCh <- err + errCh := make(chan error, 1) + tGo(t, func() { + err := inv.WithContext(withCoderConnectRunning(ctx)).Run() + errCh <- err + }) + + _ = agenttest.New(t, client.URL, agentToken) + coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) + + err := testutil.TryReceive(ctx, t, errCh) + // Making an SSH server available here is difficult, so we'll just check + // the command attempts to dial it. + require.ErrorContains(t, err, "dial coder connect host") + require.ErrorContains(t, err, "dev.myworkspace.myuser.coder") + + // The network info file should be created since we passed `--stdio` + assert.Eventually(t, func() bool { + entries, err := afero.ReadDir(fs, "/net") + if err != nil { + return false + } + return len(entries) > 0 + }, testutil.WaitLong, testutil.IntervalFast) }) - _ = agenttest.New(t, client.URL, agentToken) - coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) + t.Run("Disabled", func(t *testing.T) { + t.Parallel() - err := testutil.TryReceive(ctx, t, errCh) - // Making an SSH server available here is difficult, so we'll just check - // the command attempts to dial it. - require.ErrorContains(t, err, "dial coder connect host") - require.ErrorContains(t, err, "dev.myworkspace.myuser.coder") - - // The network info file should be created since we passed `--stdio` - assert.Eventually(t, func() bool { - entries, err := afero.ReadDir(fs, "/net") - if err != nil { - return false - } - return len(entries) > 0 - }, testutil.WaitLong, testutil.IntervalFast) + client, workspace, agentToken := setupWorkspaceForAgent(t) + inv, root := clitest.New(t, "ssh", workspace.Name, "--force-new-tunnel") + clitest.SetupConfig(t, client, root) + pty := ptytest.New(t).Attach(inv) + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + cmdDone := tGo(t, func() { + err := inv.WithContext(withCoderConnectRunning(ctx)).Run() + assert.NoError(t, err) + }) + // Shouldn't fail to dial the coder connect host `--force-new-tunnel` + // is passed. + pty.ExpectMatch("Waiting") + + _ = agenttest.New(t, client.URL, agentToken) + coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) + + // Shells on Mac, Windows, and Linux all exit shells with the "exit" command. + pty.WriteLine("exit") + <-cmdDone + }) } // tGoContext runs fn in a goroutine passing a context that will be diff --git a/cli/testdata/coder_ssh_--help.golden b/cli/testdata/coder_ssh_--help.golden index 9aefb24145596..12e70e03682de 100644 --- a/cli/testdata/coder_ssh_--help.golden +++ b/cli/testdata/coder_ssh_--help.golden @@ -12,8 +12,8 @@ OPTIONS: -e, --env string-array, $CODER_SSH_ENV Set environment variable(s) for session (key1=value1,key2=value2,...). - --force-tunnel bool - Force the use of a new tunnel to the workspace, even if the Coder + --force-new-tunnel bool + Force the creation of a new tunnel to the workspace, even if the Coder Connect tunnel is available. -A, --forward-agent bool, $CODER_SSH_FORWARD_AGENT diff --git a/docs/reference/cli/ssh.md b/docs/reference/cli/ssh.md index b9d76afd1452f..e7d1b75a616c6 100644 --- a/docs/reference/cli/ssh.md +++ b/docs/reference/cli/ssh.md @@ -138,13 +138,13 @@ Specifies a directory to write network information periodically. Specifies the interval to update network information. -### --force-tunnel +### --force-new-tunnel | | | |------|-------------------| | Type | bool | -Force the use of a new tunnel to the workspace, even if the Coder Connect tunnel is available. +Force the creation of a new tunnel to the workspace, even if the Coder Connect tunnel is available. ### --disable-autostart From be118e600ff4731af7996925557be39d40ecfde8 Mon Sep 17 00:00:00 2001 From: Ethan Dickson Date: Tue, 29 Apr 2025 13:25:26 +0000 Subject: [PATCH 04/10] reduce scope --- cli/cliutil/stdioconn.go | 14 +-- cli/ssh.go | 171 ++++++++------------------- cli/ssh_internal_test.go | 34 +----- cli/ssh_test.go | 16 +-- cli/testdata/coder_ssh_--help.golden | 4 - docs/reference/cli/ssh.md | 8 -- 6 files changed, 66 insertions(+), 181 deletions(-) diff --git a/cli/cliutil/stdioconn.go b/cli/cliutil/stdioconn.go index 7f919dbf9d456..ed87fe552cbd5 100644 --- a/cli/cliutil/stdioconn.go +++ b/cli/cliutil/stdioconn.go @@ -6,31 +6,31 @@ import ( "time" ) -type StdioConn struct { +type ReaderWriterConn struct { io.Reader io.Writer } -func (*StdioConn) Close() (err error) { +func (*ReaderWriterConn) Close() (err error) { return nil } -func (*StdioConn) LocalAddr() net.Addr { +func (*ReaderWriterConn) LocalAddr() net.Addr { return nil } -func (*StdioConn) RemoteAddr() net.Addr { +func (*ReaderWriterConn) RemoteAddr() net.Addr { return nil } -func (*StdioConn) SetDeadline(_ time.Time) error { +func (*ReaderWriterConn) SetDeadline(_ time.Time) error { return nil } -func (*StdioConn) SetReadDeadline(_ time.Time) error { +func (*ReaderWriterConn) SetReadDeadline(_ time.Time) error { return nil } -func (*StdioConn) SetWriteDeadline(_ time.Time) error { +func (*ReaderWriterConn) SetWriteDeadline(_ time.Time) error { return nil } diff --git a/cli/ssh.go b/cli/ssh.go index 82c66eb939964..365f46ab07db9 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -67,7 +67,7 @@ func (r *RootCmd) ssh() *serpent.Command { stdio bool hostPrefix string hostnameSuffix string - forceTunnel bool + forceNewTunnel bool forwardAgent bool forwardGPG bool identityAgent string @@ -278,27 +278,38 @@ func (r *RootCmd) ssh() *serpent.Command { return err } - // See if we can use the Coder Connect tunnel - if !forceTunnel { + // If we're in stdio mode, check to see if we can use Coder Connect. + // We don't support Coder Connect over non-stdio coder ssh yet. + if stdio && !forceNewTunnel { connInfo, err := wsClient.AgentConnectionInfoGeneric(ctx) if err != nil { return xerrors.Errorf("get agent connection info: %w", err) } - coderConnectHost := fmt.Sprintf("%s.%s.%s.%s", workspaceAgent.Name, workspace.Name, workspace.OwnerName, connInfo.HostnameSuffix) exists, _ := workspacesdk.ExistsViaCoderConnect(ctx, coderConnectHost) if exists { _, _ = fmt.Fprintln(inv.Stderr, "Connecting to workspace via Coder Connect...") defer cancel() - addr := fmt.Sprintf("%s:22", coderConnectHost) - if stdio { + + if networkInfoDir != "" { if err := writeCoderConnectNetInfo(ctx, networkInfoDir); err != nil { logger.Error(ctx, "failed to write coder connect net info file", slog.Error(err)) } - return runCoderConnectStdio(ctx, addr, stdioReader, stdioWriter, stack) } - return runCoderConnectPTY(ctx, addr, inv.Stdin, inv.Stdout, inv.Stderr, stack) + + stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace) + defer stopPolling() + + usageAppName := getUsageAppName(usageApp) + if usageAppName != "" { + closeUsage := client.UpdateWorkspaceUsageWithBodyContext(ctx, workspace.ID, codersdk.PostWorkspaceUsageRequest{ + AgentID: workspaceAgent.ID, + AppName: usageAppName, + }) + defer closeUsage() + } + return runCoderConnectStdio(ctx, fmt.Sprintf("%s:22", coderConnectHost), stdioReader, stdioWriter, stack) } } @@ -481,11 +492,36 @@ func (r *RootCmd) ssh() *serpent.Command { stdinFile, validIn := inv.Stdin.(*os.File) stdoutFile, validOut := inv.Stdout.(*os.File) if validIn && validOut && isatty.IsTerminal(stdinFile.Fd()) && isatty.IsTerminal(stdoutFile.Fd()) { - restorePtyFn, err := configurePTY(ctx, stdinFile, stdoutFile, sshSession) - defer restorePtyFn() + inState, err := pty.MakeInputRaw(stdinFile.Fd()) + if err != nil { + return err + } + defer func() { + _ = pty.RestoreTerminal(stdinFile.Fd(), inState) + }() + outState, err := pty.MakeOutputRaw(stdoutFile.Fd()) if err != nil { - return xerrors.Errorf("configure pty: %w", err) + return err } + defer func() { + _ = pty.RestoreTerminal(stdoutFile.Fd(), outState) + }() + + windowChange := listenWindowSize(ctx) + go func() { + for { + select { + case <-ctx.Done(): + return + case <-windowChange: + } + width, height, err := term.GetSize(int(stdoutFile.Fd())) + if err != nil { + continue + } + _ = sshSession.WindowChange(height, width) + } + }() } for _, kv := range parsedEnv { @@ -667,48 +703,14 @@ func (r *RootCmd) ssh() *serpent.Command { { Flag: "force-new-tunnel", Description: "Force the creation of a new tunnel to the workspace, even if the Coder Connect tunnel is available.", - Value: serpent.BoolOf(&forceTunnel), + Value: serpent.BoolOf(&forceNewTunnel), + Hidden: true, }, sshDisableAutostartOption(serpent.BoolOf(&disableAutostart)), } return cmd } -func configurePTY(ctx context.Context, stdinFile *os.File, stdoutFile *os.File, sshSession *gossh.Session) (restoreFn func(), err error) { - inState, err := pty.MakeInputRaw(stdinFile.Fd()) - if err != nil { - return restoreFn, err - } - restoreFn = func() { - _ = pty.RestoreTerminal(stdinFile.Fd(), inState) - } - outState, err := pty.MakeOutputRaw(stdoutFile.Fd()) - if err != nil { - return restoreFn, err - } - restoreFn = func() { - _ = pty.RestoreTerminal(stdinFile.Fd(), inState) - _ = pty.RestoreTerminal(stdoutFile.Fd(), outState) - } - - windowChange := listenWindowSize(ctx) - go func() { - for { - select { - case <-ctx.Done(): - return - case <-windowChange: - } - width, height, err := term.GetSize(int(stdoutFile.Fd())) - if err != nil { - continue - } - _ = sshSession.WindowChange(height, width) - } - }() - return restoreFn, nil -} - // findWorkspaceAndAgentByHostname parses the hostname from the commandline and finds the workspace and agent it // corresponds to, taking into account any name prefixes or suffixes configured (e.g. myworkspace.coder, or // vscode-coder--myusername--myworkspace). @@ -1502,7 +1504,7 @@ func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, std return err } - agentssh.Bicopy(ctx, conn, &cliutil.StdioConn{ + agentssh.Bicopy(ctx, conn, &cliutil.ReaderWriterConn{ Reader: stdin, Writer: stdout, }) @@ -1510,79 +1512,6 @@ func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, std return nil } -func runCoderConnectPTY(ctx context.Context, addr string, stdin io.Reader, stdout io.Writer, stderr io.Writer, stack *closerStack) error { - client, err := gossh.Dial("tcp", addr, &gossh.ClientConfig{ - // We've already checked the agent's address - // is within the Coder service prefix. - // #nosec - HostKeyCallback: gossh.InsecureIgnoreHostKey(), - }) - if err != nil { - return xerrors.Errorf("dial coder connect host: %w", err) - } - if err := stack.push("ssh client", client); err != nil { - return err - } - - session, err := client.NewSession() - if err != nil { - return xerrors.Errorf("create ssh session: %w", err) - } - if err := stack.push("ssh session", session); err != nil { - return err - } - - stdinFile, validIn := stdin.(*os.File) - stdoutFile, validOut := stdout.(*os.File) - if validIn && validOut && isatty.IsTerminal(stdinFile.Fd()) && isatty.IsTerminal(stdoutFile.Fd()) { - restorePtyFn, err := configurePTY(ctx, stdinFile, stdoutFile, session) - defer restorePtyFn() - if err != nil { - return xerrors.Errorf("configure pty: %w", err) - } - } - - session.Stdin = stdin - session.Stdout = stdout - session.Stderr = stderr - - err = session.RequestPty("xterm-256color", 80, 24, gossh.TerminalModes{}) - if err != nil { - return xerrors.Errorf("request pty: %w", err) - } - - err = session.Shell() - if err != nil { - return xerrors.Errorf("start shell: %w", err) - } - - if validOut { - // Set initial window size. - width, height, err := term.GetSize(int(stdoutFile.Fd())) - if err == nil { - _ = session.WindowChange(height, width) - } - } - - err = session.Wait() - if err != nil { - if exitErr := (&gossh.ExitError{}); errors.As(err, &exitErr) { - // Clear the error since it's not useful beyond - // reporting status. - return ExitError(exitErr.ExitStatus(), nil) - } - // If the connection drops unexpectedly, we get an - // ExitMissingError but no other error details, so try to at - // least give the user a better message - if errors.Is(err, &gossh.ExitMissingError{}) { - return ExitError(255, xerrors.New("SSH connection ended unexpectedly")) - } - return xerrors.Errorf("session ended: %w", err) - } - - return nil -} - func writeCoderConnectNetInfo(ctx context.Context, networkInfoDir string) error { fs, ok := ctx.Value("fs").(afero.Fs) if !ok { diff --git a/cli/ssh_internal_test.go b/cli/ssh_internal_test.go index 82659e1495afd..ea0e3f1534713 100644 --- a/cli/ssh_internal_test.go +++ b/cli/ssh_internal_test.go @@ -22,7 +22,6 @@ import ( "github.com/coder/coder/v2/cli/cliutil" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/pty/ptytest" "github.com/coder/coder/v2/testutil" ) @@ -226,37 +225,6 @@ func TestCloserStack_Timeout(t *testing.T) { testutil.TryReceive(ctx, t, closed) } -func TestCoderConnectPTY(t *testing.T) { - t.Parallel() - - ctx := testutil.Context(t, testutil.WaitShort) - logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - stack := newCloserStack(ctx, logger, quartz.NewMock(t)) - - server := newSSHServer("127.0.0.1:0") - ln, err := net.Listen("tcp", server.server.Addr) - require.NoError(t, err) - - go func() { - _ = server.Serve(ln) - }() - t.Cleanup(func() { - _ = server.Close() - }) - - ptty := ptytest.New(t) - ptyDone := make(chan struct{}) - go func() { - err := runCoderConnectPTY(ctx, ln.Addr().String(), ptty.Output(), ptty.Input(), ptty.Output(), stack) - assert.NoError(t, err) - close(ptyDone) - }() - ptty.ExpectMatch("Connected!") - // Shells on Mac, Windows, and Linux all exit shells with the "exit" command. - ptty.WriteLine("exit") - <-ptyDone -} - func TestCoderConnectStdio(t *testing.T) { t.Parallel() @@ -290,7 +258,7 @@ func TestCoderConnectStdio(t *testing.T) { close(stdioDone) }() - conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 1a2d2aa37425f..d76633f27c858 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -474,7 +474,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -543,7 +543,7 @@ func TestSSH(t *testing.T) { signer, err := agentssh.CoderSigner(keySeed) assert.NoError(t, err) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -606,7 +606,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -774,7 +774,7 @@ func TestSSH(t *testing.T) { // have access to the shell. _ = agenttest.New(t, client.URL, authToken) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ Reader: proxyCommandStdoutR, Writer: clientStdinW, }, "", &ssh.ClientConfig{ @@ -836,7 +836,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -895,7 +895,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -1083,7 +1083,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -1742,7 +1742,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ diff --git a/cli/testdata/coder_ssh_--help.golden b/cli/testdata/coder_ssh_--help.golden index 12e70e03682de..1f7122dd655a2 100644 --- a/cli/testdata/coder_ssh_--help.golden +++ b/cli/testdata/coder_ssh_--help.golden @@ -12,10 +12,6 @@ OPTIONS: -e, --env string-array, $CODER_SSH_ENV Set environment variable(s) for session (key1=value1,key2=value2,...). - --force-new-tunnel bool - Force the creation of a new tunnel to the workspace, even if the Coder - Connect tunnel is available. - -A, --forward-agent bool, $CODER_SSH_FORWARD_AGENT Specifies whether to forward the SSH agent specified in $SSH_AUTH_SOCK. diff --git a/docs/reference/cli/ssh.md b/docs/reference/cli/ssh.md index e7d1b75a616c6..c5bae755c8419 100644 --- a/docs/reference/cli/ssh.md +++ b/docs/reference/cli/ssh.md @@ -138,14 +138,6 @@ Specifies a directory to write network information periodically. Specifies the interval to update network information. -### --force-new-tunnel - -| | | -|------|-------------------| -| Type | bool | - -Force the creation of a new tunnel to the workspace, even if the Coder Connect tunnel is available. - ### --disable-autostart | | | From bb75fa27deedc26d326a9907be347338cfe3659b Mon Sep 17 00:00:00 2001 From: Ethan Dickson Date: Tue, 29 Apr 2025 13:40:39 +0000 Subject: [PATCH 05/10] review --- cli/ssh.go | 11 ++++++++++- cli/ssh_internal_test.go | 3 +-- cli/ssh_test.go | 17 ++++++++--------- cli/cliutil/stdioconn.go => testutil/rwconn.go | 2 +- 4 files changed, 20 insertions(+), 13 deletions(-) rename cli/cliutil/stdioconn.go => testutil/rwconn.go (96%) diff --git a/cli/ssh.go b/cli/ssh.go index 365f46ab07db9..bb04649535de3 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -1504,7 +1504,7 @@ func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, std return err } - agentssh.Bicopy(ctx, conn, &cliutil.ReaderWriterConn{ + agentssh.Bicopy(ctx, conn, &StdioRwc{ Reader: stdin, Writer: stdout, }) @@ -1512,6 +1512,15 @@ func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, std return nil } +type StdioRwc struct { + io.Reader + io.Writer +} + +func (*StdioRwc) Close() error { + return nil +} + func writeCoderConnectNetInfo(ctx context.Context, networkInfoDir string) error { fs, ok := ctx.Value("fs").(afero.Fs) if !ok { diff --git a/cli/ssh_internal_test.go b/cli/ssh_internal_test.go index ea0e3f1534713..d76ff1881680c 100644 --- a/cli/ssh_internal_test.go +++ b/cli/ssh_internal_test.go @@ -20,7 +20,6 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/quartz" - "github.com/coder/coder/v2/cli/cliutil" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/testutil" ) @@ -258,7 +257,7 @@ func TestCoderConnectStdio(t *testing.T) { close(stdioDone) }() - conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ diff --git a/cli/ssh_test.go b/cli/ssh_test.go index d76633f27c858..6f1703fe92236 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -43,7 +43,6 @@ import ( agentproto "github.com/coder/coder/v2/agent/proto" "github.com/coder/coder/v2/cli/clitest" "github.com/coder/coder/v2/cli/cliui" - "github.com/coder/coder/v2/cli/cliutil" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbfake" @@ -474,7 +473,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -543,7 +542,7 @@ func TestSSH(t *testing.T) { signer, err := agentssh.CoderSigner(keySeed) assert.NoError(t, err) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -606,7 +605,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -774,7 +773,7 @@ func TestSSH(t *testing.T) { // have access to the shell. _ = agenttest.New(t, client.URL, authToken) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: proxyCommandStdoutR, Writer: clientStdinW, }, "", &ssh.ClientConfig{ @@ -836,7 +835,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -895,7 +894,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -1083,7 +1082,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -1742,7 +1741,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ diff --git a/cli/cliutil/stdioconn.go b/testutil/rwconn.go similarity index 96% rename from cli/cliutil/stdioconn.go rename to testutil/rwconn.go index ed87fe552cbd5..a731e9c3c0ab0 100644 --- a/cli/cliutil/stdioconn.go +++ b/testutil/rwconn.go @@ -1,4 +1,4 @@ -package cliutil +package testutil import ( "io" From 00f18afef5d5474f33bc3f7507fb4bd2619962f2 Mon Sep 17 00:00:00 2001 From: Ethan Dickson Date: Tue, 29 Apr 2025 13:50:57 +0000 Subject: [PATCH 06/10] typo --- cli/ssh_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 6f1703fe92236..15663dd4bbe9c 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -2167,8 +2167,8 @@ func TestSSH_CoderConnect(t *testing.T) { err := inv.WithContext(withCoderConnectRunning(ctx)).Run() assert.NoError(t, err) }) - // Shouldn't fail to dial the coder connect host `--force-new-tunnel` - // is passed. + // Shouldn't fail to dial the coder connect host since + // `--force-new-tunnel` is passed. pty.ExpectMatch("Waiting") _ = agenttest.New(t, client.URL, agentToken) From 4ce57b72acc732609cfa629701be33757df398d6 Mon Sep 17 00:00:00 2001 From: Ethan Dickson Date: Tue, 29 Apr 2025 14:12:43 +0000 Subject: [PATCH 07/10] fix tests --- cli/ssh_test.go | 51 ++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 40 insertions(+), 11 deletions(-) diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 15663dd4bbe9c..7c0f299414953 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -2154,28 +2154,57 @@ func TestSSH_CoderConnect(t *testing.T) { t.Run("Disabled", func(t *testing.T) { t.Parallel() - client, workspace, agentToken := setupWorkspaceForAgent(t) - inv, root := clitest.New(t, "ssh", workspace.Name, "--force-new-tunnel") - clitest.SetupConfig(t, client, root) - pty := ptytest.New(t).Attach(inv) + + _ = agenttest.New(t, client.URL, agentToken) + coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) + + clientOutput, clientInput := io.Pipe() + serverOutput, serverInput := io.Pipe() + defer func() { + for _, c := range []io.Closer{clientOutput, clientInput, serverOutput, serverInput} { + _ = c.Close() + } + }() ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() + inv, root := clitest.New(t, "ssh", "--force-new-tunnel", "--stdio", workspace.Name) + clitest.SetupConfig(t, client, root) + inv.Stdin = clientOutput + inv.Stdout = serverInput + inv.Stderr = io.Discard + cmdDone := tGo(t, func() { - err := inv.WithContext(withCoderConnectRunning(ctx)).Run() + err := inv.WithContext(ctx).Run() + // Shouldn't fail to dial the Coder Connect host + // since `--force-new-tunnel` was passed assert.NoError(t, err) }) - // Shouldn't fail to dial the coder connect host since - // `--force-new-tunnel` is passed. - pty.ExpectMatch("Waiting") - _ = agenttest.New(t, client.URL, agentToken) - coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ + Reader: serverOutput, + Writer: clientInput, + }, "", &ssh.ClientConfig{ + // #nosec + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + }) + require.NoError(t, err) + defer conn.Close() + + sshClient := ssh.NewClient(conn, channels, requests) + session, err := sshClient.NewSession() + require.NoError(t, err) + defer session.Close() // Shells on Mac, Windows, and Linux all exit shells with the "exit" command. - pty.WriteLine("exit") + err = session.Run("exit") + require.NoError(t, err) + err = sshClient.Close() + require.NoError(t, err) + _ = clientOutput.Close() + <-cmdDone }) } From e46a08405a720a219e5845bb2e973cf5bc571820 Mon Sep 17 00:00:00 2001 From: Ethan Dickson Date: Wed, 30 Apr 2025 02:01:35 +0000 Subject: [PATCH 08/10] fixup --- cli/ssh.go | 1 - cli/ssh_test.go | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/cli/ssh.go b/cli/ssh.go index bb04649535de3..ffa85f8690eab 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -289,7 +289,6 @@ func (r *RootCmd) ssh() *serpent.Command { workspaceAgent.Name, workspace.Name, workspace.OwnerName, connInfo.HostnameSuffix) exists, _ := workspacesdk.ExistsViaCoderConnect(ctx, coderConnectHost) if exists { - _, _ = fmt.Fprintln(inv.Stderr, "Connecting to workspace via Coder Connect...") defer cancel() if networkInfoDir != "" { diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 7c0f299414953..2ec33f98a8437 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -2177,7 +2177,7 @@ func TestSSH_CoderConnect(t *testing.T) { inv.Stderr = io.Discard cmdDone := tGo(t, func() { - err := inv.WithContext(ctx).Run() + err := inv.WithContext(withCoderConnectRunning(ctx)).Run() // Shouldn't fail to dial the Coder Connect host // since `--force-new-tunnel` was passed assert.NoError(t, err) From 76603ed54c7439d02c684085ecd327b3756a7a00 Mon Sep 17 00:00:00 2001 From: Ethan Dickson Date: Wed, 30 Apr 2025 03:54:04 +0000 Subject: [PATCH 09/10] review --- cli/ssh.go | 21 ++++++++++++++++++++- cli/ssh_internal_test.go | 4 ++-- cli/ssh_test.go | 34 +++++++++++++++++++++------------- 3 files changed, 43 insertions(+), 16 deletions(-) diff --git a/cli/ssh.go b/cli/ssh.go index ffa85f8690eab..f93fa79656858 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -1494,8 +1494,27 @@ func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn, }, nil } +type coderConnectDialerContextKey struct{} + +type coderConnectDialer interface { + DialContext(ctx context.Context, network, addr string) (net.Conn, error) +} + +func WithTestOnlyCoderConnectDialer(ctx context.Context, dialer coderConnectDialer) context.Context { + return context.WithValue(ctx, coderConnectDialerContextKey{}, dialer) +} + +func testOrDefaultDialer(ctx context.Context) coderConnectDialer { + dialer, ok := ctx.Value(coderConnectDialerContextKey{}).(coderConnectDialer) + if !ok || dialer == nil { + return &net.Dialer{} + } + return dialer +} + func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, stdout io.Writer, stack *closerStack) error { - conn, err := net.Dial("tcp", addr) + dialer := testOrDefaultDialer(ctx) + conn, err := dialer.DialContext(ctx, "tcp", addr) if err != nil { return xerrors.Errorf("dial coder connect host: %w", err) } diff --git a/cli/ssh_internal_test.go b/cli/ssh_internal_test.go index d76ff1881680c..caee1ec25b710 100644 --- a/cli/ssh_internal_test.go +++ b/cli/ssh_internal_test.go @@ -272,8 +272,8 @@ func TestCoderConnectStdio(t *testing.T) { require.NoError(t, err) defer session.Close() - // Shells on Mac, Windows, and Linux all exit shells with the "exit" command. - err = session.Run("exit") + // We're not connected to a real shell + err = session.Run("") require.NoError(t, err) err = sshClient.Close() require.NoError(t, err) diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 2ec33f98a8437..90d36f57dbb81 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -41,6 +41,7 @@ import ( "github.com/coder/coder/v2/agent/agentssh" "github.com/coder/coder/v2/agent/agenttest" agentproto "github.com/coder/coder/v2/agent/proto" + "github.com/coder/coder/v2/cli" "github.com/coder/coder/v2/cli/clitest" "github.com/coder/coder/v2/cli/cliui" "github.com/coder/coder/v2/coderd/coderdtest" @@ -2127,9 +2128,12 @@ func TestSSH_CoderConnect(t *testing.T) { clitest.SetupConfig(t, client, root) _ = ptytest.New(t).Attach(inv) + ctx = cli.WithTestOnlyCoderConnectDialer(ctx, &fakeCoderConnectDialer{}) + ctx = withCoderConnectRunning(ctx) + errCh := make(chan error, 1) tGo(t, func() { - err := inv.WithContext(withCoderConnectRunning(ctx)).Run() + err := inv.WithContext(ctx).Run() errCh <- err }) @@ -2137,19 +2141,14 @@ func TestSSH_CoderConnect(t *testing.T) { coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) err := testutil.TryReceive(ctx, t, errCh) - // Making an SSH server available here is difficult, so we'll just check - // the command attempts to dial it. - require.ErrorContains(t, err, "dial coder connect host") - require.ErrorContains(t, err, "dev.myworkspace.myuser.coder") + // Our mock dialer will always fail with this error, if it was called + require.ErrorContains(t, err, "dial coder connect host \"dev.myworkspace.myuser.coder:22\" over tcp") // The network info file should be created since we passed `--stdio` - assert.Eventually(t, func() bool { - entries, err := afero.ReadDir(fs, "/net") - if err != nil { - return false - } - return len(entries) > 0 - }, testutil.WaitLong, testutil.IntervalFast) + entries, err := afero.ReadDir(fs, "/net") + require.NoError(t, err) + require.True(t, len(entries) > 0) + }) t.Run("Disabled", func(t *testing.T) { @@ -2176,8 +2175,11 @@ func TestSSH_CoderConnect(t *testing.T) { inv.Stdout = serverInput inv.Stderr = io.Discard + ctx = cli.WithTestOnlyCoderConnectDialer(ctx, &fakeCoderConnectDialer{}) + ctx = withCoderConnectRunning(ctx) + cmdDone := tGo(t, func() { - err := inv.WithContext(withCoderConnectRunning(ctx)).Run() + err := inv.WithContext(ctx).Run() // Shouldn't fail to dial the Coder Connect host // since `--force-new-tunnel` was passed assert.NoError(t, err) @@ -2209,6 +2211,12 @@ func TestSSH_CoderConnect(t *testing.T) { }) } +type fakeCoderConnectDialer struct{} + +func (*fakeCoderConnectDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, xerrors.Errorf("dial coder connect host %q over %s", addr, network) +} + // tGoContext runs fn in a goroutine passing a context that will be // canceled on test completion and wait until fn has finished executing. // Done and cancel are returned for optionally waiting until completion From 8c050235ed4c610ed1cf2f2d56529431a3602563 Mon Sep 17 00:00:00 2001 From: Ethan Dickson Date: Wed, 30 Apr 2025 04:29:43 +0000 Subject: [PATCH 10/10] fmt --- cli/ssh_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 90d36f57dbb81..ab754626c54fa 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -2148,7 +2148,6 @@ func TestSSH_CoderConnect(t *testing.T) { entries, err := afero.ReadDir(fs, "/net") require.NoError(t, err) require.True(t, len(entries) > 0) - }) t.Run("Disabled", func(t *testing.T) {