Skip to content

Commit a862472

Browse files
committed
feat(cli): use coder connect in coder ssh, if available
1 parent 02b2de9 commit a862472

File tree

7 files changed

+411
-77
lines changed

7 files changed

+411
-77
lines changed

cli/cliutil/stdioconn.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package cliutil
2+
3+
import (
4+
"io"
5+
"net"
6+
"time"
7+
)
8+
9+
type StdioConn struct {
10+
io.Reader
11+
io.Writer
12+
}
13+
14+
func (*StdioConn) Close() (err error) {
15+
return nil
16+
}
17+
18+
func (*StdioConn) LocalAddr() net.Addr {
19+
return nil
20+
}
21+
22+
func (*StdioConn) RemoteAddr() net.Addr {
23+
return nil
24+
}
25+
26+
func (*StdioConn) SetDeadline(_ time.Time) error {
27+
return nil
28+
}
29+
30+
func (*StdioConn) SetReadDeadline(_ time.Time) error {
31+
return nil
32+
}
33+
34+
func (*StdioConn) SetWriteDeadline(_ time.Time) error {
35+
return nil
36+
}

cli/ssh.go

Lines changed: 195 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"fmt"
99
"io"
1010
"log"
11+
"net"
1112
"net/http"
1213
"net/url"
1314
"os"
@@ -66,6 +67,7 @@ func (r *RootCmd) ssh() *serpent.Command {
6667
stdio bool
6768
hostPrefix string
6869
hostnameSuffix string
70+
forceTunnel bool
6971
forwardAgent bool
7072
forwardGPG bool
7173
identityAgent string
@@ -85,6 +87,7 @@ func (r *RootCmd) ssh() *serpent.Command {
8587
containerUser string
8688
)
8789
client := new(codersdk.Client)
90+
wsClient := workspacesdk.New(client)
8891
cmd := &serpent.Command{
8992
Annotations: workspaceCommand,
9093
Use: "ssh <workspace>",
@@ -203,14 +206,14 @@ func (r *RootCmd) ssh() *serpent.Command {
203206
parsedEnv = append(parsedEnv, [2]string{k, v})
204207
}
205208

206-
deploymentSSHConfig := codersdk.SSHConfigResponse{
209+
cliConfig := codersdk.SSHConfigResponse{
207210
HostnamePrefix: hostPrefix,
208211
HostnameSuffix: hostnameSuffix,
209212
}
210213

211214
workspace, workspaceAgent, err := findWorkspaceAndAgentByHostname(
212215
ctx, inv, client,
213-
inv.Args[0], deploymentSSHConfig, disableAutostart)
216+
inv.Args[0], cliConfig, disableAutostart)
214217
if err != nil {
215218
return err
216219
}
@@ -275,10 +278,34 @@ func (r *RootCmd) ssh() *serpent.Command {
275278
return err
276279
}
277280

281+
// See if we can use the Coder Connect tunnel
282+
if !forceTunnel {
283+
connInfo, err := wsClient.AgentConnectionInfoGeneric(ctx)
284+
if err != nil {
285+
return xerrors.Errorf("get agent connection info: %w", err)
286+
}
287+
288+
coderConnectHost := fmt.Sprintf("%s.%s.%s.%s",
289+
workspaceAgent.Name, workspace.Name, workspace.OwnerName, connInfo.HostnameSuffix)
290+
exists, _ := workspacesdk.ExistsViaCoderConnect(ctx, coderConnectHost)
291+
if exists {
292+
_, _ = fmt.Fprintln(inv.Stderr, "Connecting to workspace via Coder Connect...")
293+
defer cancel()
294+
addr := fmt.Sprintf("%s:22", coderConnectHost)
295+
if stdio {
296+
if err := writeCoderConnectNetInfo(ctx, networkInfoDir); err != nil {
297+
logger.Error(ctx, "failed to write coder connect net info file", slog.Error(err))
298+
}
299+
return runCoderConnectStdio(ctx, addr, stdioReader, stdioWriter, stack)
300+
}
301+
return runCoderConnectPTY(ctx, addr, inv.Stdin, inv.Stdout, inv.Stderr, stack)
302+
}
303+
}
304+
278305
if r.disableDirect {
279306
_, _ = fmt.Fprintln(inv.Stderr, "Direct connections disabled.")
280307
}
281-
conn, err := workspacesdk.New(client).
308+
conn, err := wsClient.
282309
DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{
283310
Logger: logger,
284311
BlockEndpoints: r.disableDirect,
@@ -454,36 +481,11 @@ func (r *RootCmd) ssh() *serpent.Command {
454481
stdinFile, validIn := inv.Stdin.(*os.File)
455482
stdoutFile, validOut := inv.Stdout.(*os.File)
456483
if validIn && validOut && isatty.IsTerminal(stdinFile.Fd()) && isatty.IsTerminal(stdoutFile.Fd()) {
457-
inState, err := pty.MakeInputRaw(stdinFile.Fd())
458-
if err != nil {
459-
return err
460-
}
461-
defer func() {
462-
_ = pty.RestoreTerminal(stdinFile.Fd(), inState)
463-
}()
464-
outState, err := pty.MakeOutputRaw(stdoutFile.Fd())
484+
restorePtyFn, err := configurePTY(ctx, stdinFile, stdoutFile, sshSession)
485+
defer restorePtyFn()
465486
if err != nil {
466-
return err
487+
return xerrors.Errorf("configure pty: %w", err)
467488
}
468-
defer func() {
469-
_ = pty.RestoreTerminal(stdoutFile.Fd(), outState)
470-
}()
471-
472-
windowChange := listenWindowSize(ctx)
473-
go func() {
474-
for {
475-
select {
476-
case <-ctx.Done():
477-
return
478-
case <-windowChange:
479-
}
480-
width, height, err := term.GetSize(int(stdoutFile.Fd()))
481-
if err != nil {
482-
continue
483-
}
484-
_ = sshSession.WindowChange(height, width)
485-
}
486-
}()
487489
}
488490

489491
for _, kv := range parsedEnv {
@@ -662,11 +664,51 @@ func (r *RootCmd) ssh() *serpent.Command {
662664
Value: serpent.StringOf(&containerUser),
663665
Hidden: true, // Hidden until this features is at least in beta.
664666
},
667+
{
668+
Flag: "force-tunnel",
669+
Description: "Force the use of a new tunnel to the workspace, even if the Coder Connect tunnel is available.",
670+
Value: serpent.BoolOf(&forceTunnel),
671+
},
665672
sshDisableAutostartOption(serpent.BoolOf(&disableAutostart)),
666673
}
667674
return cmd
668675
}
669676

677+
func configurePTY(ctx context.Context, stdinFile *os.File, stdoutFile *os.File, sshSession *gossh.Session) (restoreFn func(), err error) {
678+
inState, err := pty.MakeInputRaw(stdinFile.Fd())
679+
if err != nil {
680+
return restoreFn, err
681+
}
682+
restoreFn = func() {
683+
_ = pty.RestoreTerminal(stdinFile.Fd(), inState)
684+
}
685+
outState, err := pty.MakeOutputRaw(stdoutFile.Fd())
686+
if err != nil {
687+
return restoreFn, err
688+
}
689+
restoreFn = func() {
690+
_ = pty.RestoreTerminal(stdinFile.Fd(), inState)
691+
_ = pty.RestoreTerminal(stdoutFile.Fd(), outState)
692+
}
693+
694+
windowChange := listenWindowSize(ctx)
695+
go func() {
696+
for {
697+
select {
698+
case <-ctx.Done():
699+
return
700+
case <-windowChange:
701+
}
702+
width, height, err := term.GetSize(int(stdoutFile.Fd()))
703+
if err != nil {
704+
continue
705+
}
706+
_ = sshSession.WindowChange(height, width)
707+
}
708+
}()
709+
return restoreFn, nil
710+
}
711+
670712
// findWorkspaceAndAgentByHostname parses the hostname from the commandline and finds the workspace and agent it
671713
// corresponds to, taking into account any name prefixes or suffixes configured (e.g. myworkspace.coder, or
672714
// vscode-coder--myusername--myworkspace).
@@ -1374,12 +1416,13 @@ func setStatsCallback(
13741416
}
13751417

13761418
type sshNetworkStats struct {
1377-
P2P bool `json:"p2p"`
1378-
Latency float64 `json:"latency"`
1379-
PreferredDERP string `json:"preferred_derp"`
1380-
DERPLatency map[string]float64 `json:"derp_latency"`
1381-
UploadBytesSec int64 `json:"upload_bytes_sec"`
1382-
DownloadBytesSec int64 `json:"download_bytes_sec"`
1419+
P2P bool `json:"p2p"`
1420+
Latency float64 `json:"latency"`
1421+
PreferredDERP string `json:"preferred_derp"`
1422+
DERPLatency map[string]float64 `json:"derp_latency"`
1423+
UploadBytesSec int64 `json:"upload_bytes_sec"`
1424+
DownloadBytesSec int64 `json:"download_bytes_sec"`
1425+
UsingCoderConnect bool `json:"using_coder_connect"`
13831426
}
13841427

13851428
func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) {
@@ -1450,6 +1493,121 @@ func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn,
14501493
}, nil
14511494
}
14521495

1496+
func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, stdout io.Writer, stack *closerStack) error {
1497+
conn, err := net.Dial("tcp", addr)
1498+
if err != nil {
1499+
return xerrors.Errorf("dial coder connect host: %w", err)
1500+
}
1501+
if err := stack.push("tcp conn", conn); err != nil {
1502+
return err
1503+
}
1504+
1505+
agentssh.Bicopy(ctx, conn, &cliutil.StdioConn{
1506+
Reader: stdin,
1507+
Writer: stdout,
1508+
})
1509+
1510+
return nil
1511+
}
1512+
1513+
func runCoderConnectPTY(ctx context.Context, addr string, stdin io.Reader, stdout io.Writer, stderr io.Writer, stack *closerStack) error {
1514+
client, err := gossh.Dial("tcp", addr, &gossh.ClientConfig{
1515+
// We've already checked the agent's address
1516+
// is within the Coder service prefix.
1517+
// #nosec
1518+
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
1519+
})
1520+
if err != nil {
1521+
return xerrors.Errorf("dial coder connect host: %w", err)
1522+
}
1523+
if err := stack.push("ssh client", client); err != nil {
1524+
return err
1525+
}
1526+
1527+
session, err := client.NewSession()
1528+
if err != nil {
1529+
return xerrors.Errorf("create ssh session: %w", err)
1530+
}
1531+
if err := stack.push("ssh session", session); err != nil {
1532+
return err
1533+
}
1534+
1535+
stdinFile, validIn := stdin.(*os.File)
1536+
stdoutFile, validOut := stdout.(*os.File)
1537+
if validIn && validOut && isatty.IsTerminal(stdinFile.Fd()) && isatty.IsTerminal(stdoutFile.Fd()) {
1538+
restorePtyFn, err := configurePTY(ctx, stdinFile, stdoutFile, session)
1539+
defer restorePtyFn()
1540+
if err != nil {
1541+
return xerrors.Errorf("configure pty: %w", err)
1542+
}
1543+
}
1544+
1545+
session.Stdin = stdin
1546+
session.Stdout = stdout
1547+
session.Stderr = stderr
1548+
1549+
err = session.RequestPty("xterm-256color", 80, 24, gossh.TerminalModes{})
1550+
if err != nil {
1551+
return xerrors.Errorf("request pty: %w", err)
1552+
}
1553+
1554+
err = session.Shell()
1555+
if err != nil {
1556+
return xerrors.Errorf("start shell: %w", err)
1557+
}
1558+
1559+
if validOut {
1560+
// Set initial window size.
1561+
width, height, err := term.GetSize(int(stdoutFile.Fd()))
1562+
if err == nil {
1563+
_ = session.WindowChange(height, width)
1564+
}
1565+
}
1566+
1567+
err = session.Wait()
1568+
if err != nil {
1569+
if exitErr := (&gossh.ExitError{}); errors.As(err, &exitErr) {
1570+
// Clear the error since it's not useful beyond
1571+
// reporting status.
1572+
return ExitError(exitErr.ExitStatus(), nil)
1573+
}
1574+
// If the connection drops unexpectedly, we get an
1575+
// ExitMissingError but no other error details, so try to at
1576+
// least give the user a better message
1577+
if errors.Is(err, &gossh.ExitMissingError{}) {
1578+
return ExitError(255, xerrors.New("SSH connection ended unexpectedly"))
1579+
}
1580+
return xerrors.Errorf("session ended: %w", err)
1581+
}
1582+
1583+
return nil
1584+
}
1585+
1586+
func writeCoderConnectNetInfo(ctx context.Context, networkInfoDir string) error {
1587+
fs, ok := ctx.Value("fs").(afero.Fs)
1588+
if !ok {
1589+
fs = afero.NewOsFs()
1590+
}
1591+
// The VS Code extension obtains the PID of the SSH process to
1592+
// find the log file associated with a SSH session.
1593+
//
1594+
// We get the parent PID because it's assumed `ssh` is calling this
1595+
// command via the ProxyCommand SSH option.
1596+
networkInfoFilePath := filepath.Join(networkInfoDir, fmt.Sprintf("%d.json", os.Getppid()))
1597+
stats := &sshNetworkStats{
1598+
UsingCoderConnect: true,
1599+
}
1600+
rawStats, err := json.Marshal(stats)
1601+
if err != nil {
1602+
return xerrors.Errorf("marshal network stats: %w", err)
1603+
}
1604+
err = afero.WriteFile(fs, networkInfoFilePath, rawStats, 0o600)
1605+
if err != nil {
1606+
return xerrors.Errorf("write network stats: %w", err)
1607+
}
1608+
return nil
1609+
}
1610+
14531611
// Converts workspace name input to owner/workspace.agent format
14541612
// Possible valid input formats:
14551613
// workspace

0 commit comments

Comments
 (0)