Skip to content

Commit 4966ef0

Browse files
authored
feat(cli): add reverse tunnelling SSH support for unix sockets (#9976)
1 parent 465546e commit 4966ef0

File tree

2 files changed

+103
-5
lines changed

2 files changed

+103
-5
lines changed

cli/remoteforward.go

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"io"
77
"net"
8+
"os"
89
"regexp"
910
"strconv"
1011

@@ -23,15 +24,24 @@ type cookieAddr struct {
2324

2425
// Format:
2526
// remote_port:local_address:local_port
26-
var remoteForwardRegex = regexp.MustCompile(`^(\d+):(.+):(\d+)$`)
27+
var remoteForwardRegexTCP = regexp.MustCompile(`^(\d+):(.+):(\d+)$`)
2728

28-
func validateRemoteForward(flag string) bool {
29-
return remoteForwardRegex.MatchString(flag)
29+
// remote_socket_path:local_socket_path (both absolute paths)
30+
var remoteForwardRegexUnixSocket = regexp.MustCompile(`^(\/.+):(\/.+)$`)
31+
32+
func isRemoteForwardTCP(flag string) bool {
33+
return remoteForwardRegexTCP.MatchString(flag)
3034
}
3135

32-
func parseRemoteForward(flag string) (net.Addr, net.Addr, error) {
33-
matches := remoteForwardRegex.FindStringSubmatch(flag)
36+
func isRemoteForwardUnixSocket(flag string) bool {
37+
return remoteForwardRegexUnixSocket.MatchString(flag)
38+
}
39+
40+
func validateRemoteForward(flag string) bool {
41+
return isRemoteForwardTCP(flag) || isRemoteForwardUnixSocket(flag)
42+
}
3443

44+
func parseRemoteForwardTCP(matches []string) (net.Addr, net.Addr, error) {
3545
remotePort, err := strconv.Atoi(matches[1])
3646
if err != nil {
3747
return nil, nil, xerrors.Errorf("remote port is invalid: %w", err)
@@ -57,6 +67,46 @@ func parseRemoteForward(flag string) (net.Addr, net.Addr, error) {
5767
return localAddr, remoteAddr, nil
5868
}
5969

70+
func parseRemoteForwardUnixSocket(matches []string) (net.Addr, net.Addr, error) {
71+
remoteSocket := matches[1]
72+
localSocket := matches[2]
73+
74+
fileInfo, err := os.Stat(localSocket)
75+
if err != nil {
76+
return nil, nil, err
77+
}
78+
79+
if fileInfo.Mode()&os.ModeSocket == 0 {
80+
return nil, nil, xerrors.New("File is not a Unix domain socket file")
81+
}
82+
83+
remoteAddr := &net.UnixAddr{
84+
Name: remoteSocket,
85+
Net: "unix",
86+
}
87+
88+
localAddr := &net.UnixAddr{
89+
Name: localSocket,
90+
Net: "unix",
91+
}
92+
return localAddr, remoteAddr, nil
93+
}
94+
95+
func parseRemoteForward(flag string) (net.Addr, net.Addr, error) {
96+
tcpMatches := remoteForwardRegexTCP.FindStringSubmatch(flag)
97+
98+
if len(tcpMatches) > 0 {
99+
return parseRemoteForwardTCP(tcpMatches)
100+
}
101+
102+
unixSocketMatches := remoteForwardRegexUnixSocket.FindStringSubmatch(flag)
103+
if len(unixSocketMatches) > 0 {
104+
return parseRemoteForwardUnixSocket(unixSocketMatches)
105+
}
106+
107+
return nil, nil, xerrors.New("Could not match forward arguments")
108+
}
109+
60110
// sshRemoteForward starts forwarding connections from a remote listener to a
61111
// local address via SSH in a goroutine.
62112
//

cli/ssh_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,54 @@ func TestSSH(t *testing.T) {
428428
<-cmdDone
429429
})
430430

431+
t.Run("RemoteForwardUnixSocket", func(t *testing.T) {
432+
if runtime.GOOS == "windows" {
433+
t.Skip("Test not supported on windows")
434+
}
435+
436+
t.Parallel()
437+
438+
client, workspace, agentToken := setupWorkspaceForAgent(t, nil)
439+
440+
_ = agenttest.New(t, client.URL, agentToken)
441+
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
442+
443+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
444+
defer cancel()
445+
446+
tmpdir := tempDirUnixSocket(t)
447+
agentSock := filepath.Join(tmpdir, "agent.sock")
448+
l, err := net.Listen("unix", agentSock)
449+
require.NoError(t, err)
450+
defer l.Close()
451+
452+
inv, root := clitest.New(t,
453+
"ssh",
454+
workspace.Name,
455+
"--remote-forward",
456+
"/tmp/test.sock:"+agentSock,
457+
)
458+
clitest.SetupConfig(t, client, root)
459+
pty := ptytest.New(t).Attach(inv)
460+
inv.Stderr = pty.Output()
461+
cmdDone := tGo(t, func() {
462+
err := inv.WithContext(ctx).Run()
463+
assert.NoError(t, err, "ssh command failed")
464+
})
465+
466+
// Wait for the prompt or any output really to indicate the command has
467+
// started and accepting input on stdin.
468+
_ = pty.Peek(ctx, 1)
469+
470+
// Download the test page
471+
pty.WriteLine("ss -xl state listening src /tmp/test.sock | wc -l")
472+
pty.ExpectMatch("2")
473+
474+
// And we're done.
475+
pty.WriteLine("exit")
476+
<-cmdDone
477+
})
478+
431479
t.Run("FileLogging", func(t *testing.T) {
432480
t.Parallel()
433481

0 commit comments

Comments
 (0)