Skip to content

Commit 200a87e

Browse files
authored
feat(cli/ssh): allow multiple remote forwards and allow missing local file (#11648)
1 parent 73e6bbf commit 200a87e

File tree

5 files changed

+127
-32
lines changed

5 files changed

+127
-32
lines changed

cli/remoteforward.go

+3-10
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"fmt"
66
"io"
77
"net"
8-
"os"
98
"regexp"
109
"strconv"
1110

@@ -67,19 +66,13 @@ func parseRemoteForwardTCP(matches []string) (net.Addr, net.Addr, error) {
6766
return localAddr, remoteAddr, nil
6867
}
6968

69+
// parseRemoteForwardUnixSocket parses a remote forward flag. Note that
70+
// we don't verify that the local socket path exists because the user
71+
// may create it later. This behavior matches OpenSSH.
7072
func parseRemoteForwardUnixSocket(matches []string) (net.Addr, net.Addr, error) {
7173
remoteSocket := matches[1]
7274
localSocket := matches[2]
7375

74-
fileInfo, err := os.Stat(localSocket)
75-
if err != nil {
76-
return nil, nil, err
77-
}
78-
79-
if fileInfo.Mode()&os.ModeSocket == 0 {
80-
return nil, nil, xerrors.New("File is not a Unix domain socket file")
81-
}
82-
8376
remoteAddr := &net.UnixAddr{
8477
Name: remoteSocket,
8578
Net: "unix",

cli/ssh.go

+24-20
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ func (r *RootCmd) ssh() *clibase.Cmd {
5353
waitEnum string
5454
noWait bool
5555
logDirPath string
56-
remoteForward string
56+
remoteForwards []string
5757
disableAutostart bool
5858
)
5959
client := new(codersdk.Client)
@@ -135,13 +135,15 @@ func (r *RootCmd) ssh() *clibase.Cmd {
135135
stack := newCloserStack(ctx, logger)
136136
defer stack.close(nil)
137137

138-
if remoteForward != "" {
139-
isValid := validateRemoteForward(remoteForward)
140-
if !isValid {
141-
return xerrors.Errorf(`invalid format of remote-forward, expected: remote_port:local_address:local_port`)
142-
}
143-
if isValid && stdio {
144-
return xerrors.Errorf(`remote-forward can't be enabled in the stdio mode`)
138+
if len(remoteForwards) > 0 {
139+
for _, remoteForward := range remoteForwards {
140+
isValid := validateRemoteForward(remoteForward)
141+
if !isValid {
142+
return xerrors.Errorf(`invalid format of remote-forward, expected: remote_port:local_address:local_port`)
143+
}
144+
if isValid && stdio {
145+
return xerrors.Errorf(`remote-forward can't be enabled in the stdio mode`)
146+
}
145147
}
146148
}
147149

@@ -311,18 +313,20 @@ func (r *RootCmd) ssh() *clibase.Cmd {
311313
}
312314
}
313315

314-
if remoteForward != "" {
315-
localAddr, remoteAddr, err := parseRemoteForward(remoteForward)
316-
if err != nil {
317-
return err
318-
}
316+
if len(remoteForwards) > 0 {
317+
for _, remoteForward := range remoteForwards {
318+
localAddr, remoteAddr, err := parseRemoteForward(remoteForward)
319+
if err != nil {
320+
return err
321+
}
319322

320-
closer, err := sshRemoteForward(ctx, inv.Stderr, sshClient, localAddr, remoteAddr)
321-
if err != nil {
322-
return xerrors.Errorf("ssh remote forward: %w", err)
323-
}
324-
if err = stack.push("sshRemoteForward", closer); err != nil {
325-
return err
323+
closer, err := sshRemoteForward(ctx, inv.Stderr, sshClient, localAddr, remoteAddr)
324+
if err != nil {
325+
return xerrors.Errorf("ssh remote forward: %w", err)
326+
}
327+
if err = stack.push("sshRemoteForward", closer); err != nil {
328+
return err
329+
}
326330
}
327331
}
328332

@@ -460,7 +464,7 @@ func (r *RootCmd) ssh() *clibase.Cmd {
460464
Description: "Enable remote port forwarding (remote_port:local_address:local_port).",
461465
Env: "CODER_SSH_REMOTE_FORWARD",
462466
FlagShorthand: "R",
463-
Value: clibase.StringOf(&remoteForward),
467+
Value: clibase.StringArrayOf(&remoteForwards),
464468
},
465469
sshDisableAutostartOption(clibase.BoolOf(&disableAutostart)),
466470
}

cli/ssh_test.go

+98
Original file line numberDiff line numberDiff line change
@@ -883,6 +883,104 @@ func TestSSH(t *testing.T) {
883883
require.NoError(t, err)
884884
})
885885

886+
// Test that we can remote forward multiple sockets, whether or not the
887+
// local sockets exists at the time of establishing xthe SSH connection.
888+
t.Run("RemoteForwardMultipleUnixSockets", func(t *testing.T) {
889+
if runtime.GOOS == "windows" {
890+
t.Skip("Test not supported on windows")
891+
}
892+
893+
t.Parallel()
894+
895+
client, workspace, agentToken := setupWorkspaceForAgent(t)
896+
897+
_ = agenttest.New(t, client.URL, agentToken)
898+
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
899+
900+
// Wait super long so this doesn't flake on -race test.
901+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
902+
defer cancel()
903+
904+
tmpdir := tempDirUnixSocket(t)
905+
906+
type testSocket struct {
907+
local string
908+
remote string
909+
}
910+
911+
args := []string{"ssh", workspace.Name}
912+
var sockets []testSocket
913+
for i := 0; i < 2; i++ {
914+
localSock := filepath.Join(tmpdir, fmt.Sprintf("local-%d.sock", i))
915+
remoteSock := filepath.Join(tmpdir, fmt.Sprintf("remote-%d.sock", i))
916+
sockets = append(sockets, testSocket{
917+
local: localSock,
918+
remote: remoteSock,
919+
})
920+
args = append(args, "--remote-forward", fmt.Sprintf("%s:%s", remoteSock, localSock))
921+
}
922+
923+
inv, root := clitest.New(t, args...)
924+
clitest.SetupConfig(t, client, root)
925+
pty := ptytest.New(t).Attach(inv)
926+
inv.Stderr = pty.Output()
927+
928+
w := clitest.StartWithWaiter(t, inv.WithContext(ctx))
929+
defer w.Wait() // We don't care about any exit error (exit code 255: SSH connection ended unexpectedly).
930+
931+
// Since something was output, it should be safe to write input.
932+
// This could show a prompt or "running startup scripts", so it's
933+
// not indicative of the SSH connection being ready.
934+
_ = pty.Peek(ctx, 1)
935+
936+
// Ensure the SSH connection is ready by testing the shell
937+
// input/output.
938+
pty.WriteLine("echo ping' 'pong")
939+
pty.ExpectMatchContext(ctx, "ping pong")
940+
941+
for i, sock := range sockets {
942+
i := i
943+
// Start the listener on the "local machine".
944+
l, err := net.Listen("unix", sock.local)
945+
require.NoError(t, err)
946+
defer l.Close() //nolint:revive // Defer is fine in this loop, we only run it twice.
947+
testutil.Go(t, func() {
948+
for {
949+
fd, err := l.Accept()
950+
if err != nil {
951+
if !errors.Is(err, net.ErrClosed) {
952+
assert.NoError(t, err, "listener accept failed", i)
953+
}
954+
return
955+
}
956+
957+
testutil.Go(t, func() {
958+
defer fd.Close()
959+
agentssh.Bicopy(ctx, fd, fd)
960+
})
961+
}
962+
})
963+
964+
// Dial the forwarded socket on the "remote machine".
965+
d := &net.Dialer{}
966+
fd, err := d.DialContext(ctx, "unix", sock.remote)
967+
require.NoError(t, err, i)
968+
defer fd.Close() //nolint:revive // Defer is fine in this loop, we only run it twice.
969+
970+
// Ping / pong to ensure the socket is working.
971+
_, err = fd.Write([]byte("hello world"))
972+
require.NoError(t, err, i)
973+
974+
buf := make([]byte, 11)
975+
_, err = fd.Read(buf)
976+
require.NoError(t, err, i)
977+
require.Equal(t, "hello world", string(buf), i)
978+
}
979+
980+
// And we're done.
981+
pty.WriteLine("exit")
982+
})
983+
886984
t.Run("FileLogging", func(t *testing.T) {
887985
t.Parallel()
888986

cli/testdata/coder_ssh_--help.golden

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ OPTIONS:
3333
behavior as non-blocking.
3434
DEPRECATED: Use --wait instead.
3535

36-
-R, --remote-forward string, $CODER_SSH_REMOTE_FORWARD
36+
-R, --remote-forward string-array, $CODER_SSH_REMOTE_FORWARD
3737
Enable remote port forwarding (remote_port:local_address:local_port).
3838

3939
--stdio bool, $CODER_SSH_STDIO

docs/cli/ssh.md

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)