diff --git a/cli/remoteforward.go b/cli/remoteforward.go
index 2c4207583b289..bffc50694c061 100644
--- a/cli/remoteforward.go
+++ b/cli/remoteforward.go
@@ -5,7 +5,6 @@ import (
"fmt"
"io"
"net"
- "os"
"regexp"
"strconv"
@@ -67,19 +66,13 @@ func parseRemoteForwardTCP(matches []string) (net.Addr, net.Addr, error) {
return localAddr, remoteAddr, nil
}
+// parseRemoteForwardUnixSocket parses a remote forward flag. Note that
+// we don't verify that the local socket path exists because the user
+// may create it later. This behavior matches OpenSSH.
func parseRemoteForwardUnixSocket(matches []string) (net.Addr, net.Addr, error) {
remoteSocket := matches[1]
localSocket := matches[2]
- fileInfo, err := os.Stat(localSocket)
- if err != nil {
- return nil, nil, err
- }
-
- if fileInfo.Mode()&os.ModeSocket == 0 {
- return nil, nil, xerrors.New("File is not a Unix domain socket file")
- }
-
remoteAddr := &net.UnixAddr{
Name: remoteSocket,
Net: "unix",
diff --git a/cli/ssh.go b/cli/ssh.go
index b3fc79d51df73..b11f48b9b1780 100644
--- a/cli/ssh.go
+++ b/cli/ssh.go
@@ -53,7 +53,7 @@ func (r *RootCmd) ssh() *clibase.Cmd {
waitEnum string
noWait bool
logDirPath string
- remoteForward string
+ remoteForwards []string
disableAutostart bool
)
client := new(codersdk.Client)
@@ -135,13 +135,15 @@ func (r *RootCmd) ssh() *clibase.Cmd {
stack := newCloserStack(ctx, logger)
defer stack.close(nil)
- if remoteForward != "" {
- isValid := validateRemoteForward(remoteForward)
- if !isValid {
- return xerrors.Errorf(`invalid format of remote-forward, expected: remote_port:local_address:local_port`)
- }
- if isValid && stdio {
- return xerrors.Errorf(`remote-forward can't be enabled in the stdio mode`)
+ if len(remoteForwards) > 0 {
+ for _, remoteForward := range remoteForwards {
+ isValid := validateRemoteForward(remoteForward)
+ if !isValid {
+ return xerrors.Errorf(`invalid format of remote-forward, expected: remote_port:local_address:local_port`)
+ }
+ if isValid && stdio {
+ return xerrors.Errorf(`remote-forward can't be enabled in the stdio mode`)
+ }
}
}
@@ -311,18 +313,20 @@ func (r *RootCmd) ssh() *clibase.Cmd {
}
}
- if remoteForward != "" {
- localAddr, remoteAddr, err := parseRemoteForward(remoteForward)
- if err != nil {
- return err
- }
+ if len(remoteForwards) > 0 {
+ for _, remoteForward := range remoteForwards {
+ localAddr, remoteAddr, err := parseRemoteForward(remoteForward)
+ if err != nil {
+ return err
+ }
- closer, err := sshRemoteForward(ctx, inv.Stderr, sshClient, localAddr, remoteAddr)
- if err != nil {
- return xerrors.Errorf("ssh remote forward: %w", err)
- }
- if err = stack.push("sshRemoteForward", closer); err != nil {
- return err
+ closer, err := sshRemoteForward(ctx, inv.Stderr, sshClient, localAddr, remoteAddr)
+ if err != nil {
+ return xerrors.Errorf("ssh remote forward: %w", err)
+ }
+ if err = stack.push("sshRemoteForward", closer); err != nil {
+ return err
+ }
}
}
@@ -460,7 +464,7 @@ func (r *RootCmd) ssh() *clibase.Cmd {
Description: "Enable remote port forwarding (remote_port:local_address:local_port).",
Env: "CODER_SSH_REMOTE_FORWARD",
FlagShorthand: "R",
- Value: clibase.StringOf(&remoteForward),
+ Value: clibase.StringArrayOf(&remoteForwards),
},
sshDisableAutostartOption(clibase.BoolOf(&disableAutostart)),
}
diff --git a/cli/ssh_test.go b/cli/ssh_test.go
index 684e8700c1f50..fdde064ce9cf7 100644
--- a/cli/ssh_test.go
+++ b/cli/ssh_test.go
@@ -883,6 +883,104 @@ func TestSSH(t *testing.T) {
require.NoError(t, err)
})
+ // Test that we can remote forward multiple sockets, whether or not the
+ // local sockets exists at the time of establishing xthe SSH connection.
+ t.Run("RemoteForwardMultipleUnixSockets", func(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("Test not supported on windows")
+ }
+
+ t.Parallel()
+
+ client, workspace, agentToken := setupWorkspaceForAgent(t)
+
+ _ = agenttest.New(t, client.URL, agentToken)
+ coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
+
+ // Wait super long so this doesn't flake on -race test.
+ ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
+ defer cancel()
+
+ tmpdir := tempDirUnixSocket(t)
+
+ type testSocket struct {
+ local string
+ remote string
+ }
+
+ args := []string{"ssh", workspace.Name}
+ var sockets []testSocket
+ for i := 0; i < 2; i++ {
+ localSock := filepath.Join(tmpdir, fmt.Sprintf("local-%d.sock", i))
+ remoteSock := filepath.Join(tmpdir, fmt.Sprintf("remote-%d.sock", i))
+ sockets = append(sockets, testSocket{
+ local: localSock,
+ remote: remoteSock,
+ })
+ args = append(args, "--remote-forward", fmt.Sprintf("%s:%s", remoteSock, localSock))
+ }
+
+ inv, root := clitest.New(t, args...)
+ clitest.SetupConfig(t, client, root)
+ pty := ptytest.New(t).Attach(inv)
+ inv.Stderr = pty.Output()
+
+ w := clitest.StartWithWaiter(t, inv.WithContext(ctx))
+ defer w.Wait() // We don't care about any exit error (exit code 255: SSH connection ended unexpectedly).
+
+ // Since something was output, it should be safe to write input.
+ // This could show a prompt or "running startup scripts", so it's
+ // not indicative of the SSH connection being ready.
+ _ = pty.Peek(ctx, 1)
+
+ // Ensure the SSH connection is ready by testing the shell
+ // input/output.
+ pty.WriteLine("echo ping' 'pong")
+ pty.ExpectMatchContext(ctx, "ping pong")
+
+ for i, sock := range sockets {
+ i := i
+ // Start the listener on the "local machine".
+ l, err := net.Listen("unix", sock.local)
+ require.NoError(t, err)
+ defer l.Close() //nolint:revive // Defer is fine in this loop, we only run it twice.
+ testutil.Go(t, func() {
+ for {
+ fd, err := l.Accept()
+ if err != nil {
+ if !errors.Is(err, net.ErrClosed) {
+ assert.NoError(t, err, "listener accept failed", i)
+ }
+ return
+ }
+
+ testutil.Go(t, func() {
+ defer fd.Close()
+ agentssh.Bicopy(ctx, fd, fd)
+ })
+ }
+ })
+
+ // Dial the forwarded socket on the "remote machine".
+ d := &net.Dialer{}
+ fd, err := d.DialContext(ctx, "unix", sock.remote)
+ require.NoError(t, err, i)
+ defer fd.Close() //nolint:revive // Defer is fine in this loop, we only run it twice.
+
+ // Ping / pong to ensure the socket is working.
+ _, err = fd.Write([]byte("hello world"))
+ require.NoError(t, err, i)
+
+ buf := make([]byte, 11)
+ _, err = fd.Read(buf)
+ require.NoError(t, err, i)
+ require.Equal(t, "hello world", string(buf), i)
+ }
+
+ // And we're done.
+ pty.WriteLine("exit")
+ })
+
t.Run("FileLogging", func(t *testing.T) {
t.Parallel()
diff --git a/cli/testdata/coder_ssh_--help.golden b/cli/testdata/coder_ssh_--help.golden
index b76e56a8abafd..ce53948c70f47 100644
--- a/cli/testdata/coder_ssh_--help.golden
+++ b/cli/testdata/coder_ssh_--help.golden
@@ -33,7 +33,7 @@ OPTIONS:
behavior as non-blocking.
DEPRECATED: Use --wait instead.
- -R, --remote-forward string, $CODER_SSH_REMOTE_FORWARD
+ -R, --remote-forward string-array, $CODER_SSH_REMOTE_FORWARD
Enable remote port forwarding (remote_port:local_address:local_port).
--stdio bool, $CODER_SSH_STDIO
diff --git a/docs/cli/ssh.md b/docs/cli/ssh.md
index b3416f3307950..34762d5b2bd59 100644
--- a/docs/cli/ssh.md
+++ b/docs/cli/ssh.md
@@ -71,7 +71,7 @@ Enter workspace immediately after the agent has connected. This is the default i
| | |
| ----------- | -------------------------------------- |
-| Type | string
|
+| Type | string-array
|
| Environment | $CODER_SSH_REMOTE_FORWARD
|
Enable remote port forwarding (remote_port:local_address:local_port).