diff --git a/cli/remoteforward.go b/cli/remoteforward.go index 95daa46663ea5..2c4207583b289 100644 --- a/cli/remoteforward.go +++ b/cli/remoteforward.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net" + "os" "regexp" "strconv" @@ -23,15 +24,24 @@ type cookieAddr struct { // Format: // remote_port:local_address:local_port -var remoteForwardRegex = regexp.MustCompile(`^(\d+):(.+):(\d+)$`) +var remoteForwardRegexTCP = regexp.MustCompile(`^(\d+):(.+):(\d+)$`) -func validateRemoteForward(flag string) bool { - return remoteForwardRegex.MatchString(flag) +// remote_socket_path:local_socket_path (both absolute paths) +var remoteForwardRegexUnixSocket = regexp.MustCompile(`^(\/.+):(\/.+)$`) + +func isRemoteForwardTCP(flag string) bool { + return remoteForwardRegexTCP.MatchString(flag) } -func parseRemoteForward(flag string) (net.Addr, net.Addr, error) { - matches := remoteForwardRegex.FindStringSubmatch(flag) +func isRemoteForwardUnixSocket(flag string) bool { + return remoteForwardRegexUnixSocket.MatchString(flag) +} + +func validateRemoteForward(flag string) bool { + return isRemoteForwardTCP(flag) || isRemoteForwardUnixSocket(flag) +} +func parseRemoteForwardTCP(matches []string) (net.Addr, net.Addr, error) { remotePort, err := strconv.Atoi(matches[1]) if err != nil { return nil, nil, xerrors.Errorf("remote port is invalid: %w", err) @@ -57,6 +67,46 @@ func parseRemoteForward(flag string) (net.Addr, net.Addr, error) { return localAddr, remoteAddr, nil } +func parseRemoteForwardUnixSocket(matches []string) (net.Addr, net.Addr, error) { + remoteSocket := matches[1] + localSocket := matches[2] + + fileInfo, err := os.Stat(localSocket) + if err != nil { + return nil, nil, err + } + + if fileInfo.Mode()&os.ModeSocket == 0 { + return nil, nil, xerrors.New("File is not a Unix domain socket file") + } + + remoteAddr := &net.UnixAddr{ + Name: remoteSocket, + Net: "unix", + } + + localAddr := &net.UnixAddr{ + Name: localSocket, + Net: "unix", + } + return localAddr, remoteAddr, nil +} + +func parseRemoteForward(flag string) (net.Addr, net.Addr, error) { + tcpMatches := remoteForwardRegexTCP.FindStringSubmatch(flag) + + if len(tcpMatches) > 0 { + return parseRemoteForwardTCP(tcpMatches) + } + + unixSocketMatches := remoteForwardRegexUnixSocket.FindStringSubmatch(flag) + if len(unixSocketMatches) > 0 { + return parseRemoteForwardUnixSocket(unixSocketMatches) + } + + return nil, nil, xerrors.New("Could not match forward arguments") +} + // sshRemoteForward starts forwarding connections from a remote listener to a // local address via SSH in a goroutine. // diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 90b34ca9f4b70..0f5f00cbd01ba 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -428,6 +428,54 @@ func TestSSH(t *testing.T) { <-cmdDone }) + t.Run("RemoteForwardUnixSocket", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Test not supported on windows") + } + + t.Parallel() + + client, workspace, agentToken := setupWorkspaceForAgent(t, nil) + + _ = agenttest.New(t, client.URL, agentToken) + coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + tmpdir := tempDirUnixSocket(t) + agentSock := filepath.Join(tmpdir, "agent.sock") + l, err := net.Listen("unix", agentSock) + require.NoError(t, err) + defer l.Close() + + inv, root := clitest.New(t, + "ssh", + workspace.Name, + "--remote-forward", + "/tmp/test.sock:"+agentSock, + ) + clitest.SetupConfig(t, client, root) + pty := ptytest.New(t).Attach(inv) + inv.Stderr = pty.Output() + cmdDone := tGo(t, func() { + err := inv.WithContext(ctx).Run() + assert.NoError(t, err, "ssh command failed") + }) + + // Wait for the prompt or any output really to indicate the command has + // started and accepting input on stdin. + _ = pty.Peek(ctx, 1) + + // Download the test page + pty.WriteLine("ss -xl state listening src /tmp/test.sock | wc -l") + pty.ExpectMatch("2") + + // And we're done. + pty.WriteLine("exit") + <-cmdDone + }) + t.Run("FileLogging", func(t *testing.T) { t.Parallel()