Skip to content

Commit 819fd9a

Browse files
committed
add tests and check for socket file
1 parent add7b45 commit 819fd9a

File tree

2 files changed

+58
-0
lines changed

2 files changed

+58
-0
lines changed

cli/remoteforward.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"io"
77
"net"
8+
"os"
89
"regexp"
910
"strconv"
1011

@@ -70,6 +71,15 @@ func parseRemoteForwardUnixSocket(matches []string) (net.Addr, net.Addr, error)
7071
remoteSocket := matches[1]
7172
localSocket := matches[2]
7273

74+
fileInfo, err := os.Stat(localSocket)
75+
if err != nil {
76+
return nil, nil, err
77+
}
78+
79+
if fileInfo.Mode()&os.ModeSocket == 0 {
80+
return nil, nil, xerrors.New("File is not a Unix domain socket file")
81+
}
82+
7383
remoteAddr := &net.UnixAddr{
7484
Name: remoteSocket,
7585
Net: "unix",

cli/ssh_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,54 @@ func TestSSH(t *testing.T) {
428428
<-cmdDone
429429
})
430430

431+
t.Run("RemoteForwardUnixSocket", func(t *testing.T) {
432+
if runtime.GOOS == "windows" {
433+
t.Skip("Test not supported on windows")
434+
}
435+
436+
t.Parallel()
437+
438+
client, workspace, agentToken := setupWorkspaceForAgent(t, nil)
439+
440+
_ = agenttest.New(t, client.URL, agentToken)
441+
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
442+
443+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
444+
defer cancel()
445+
446+
tmpdir := tempDirUnixSocket(t)
447+
agentSock := filepath.Join(tmpdir, "agent.sock")
448+
l, err := net.Listen("unix", agentSock)
449+
require.NoError(t, err)
450+
defer l.Close()
451+
452+
inv, root := clitest.New(t,
453+
"ssh",
454+
workspace.Name,
455+
"--remote-forward",
456+
agentSock+":/tmp/test.sock",
457+
)
458+
clitest.SetupConfig(t, client, root)
459+
pty := ptytest.New(t).Attach(inv)
460+
inv.Stderr = pty.Output()
461+
cmdDone := tGo(t, func() {
462+
err := inv.WithContext(ctx).Run()
463+
assert.NoError(t, err, "ssh command failed")
464+
})
465+
466+
// Wait for the prompt or any output really to indicate the command has
467+
// started and accepting input on stdin.
468+
_ = pty.Peek(ctx, 1)
469+
470+
// Download the test page
471+
pty.WriteLine("ss -xl state listening src /var/roo/daemon.sock | wc -l")
472+
pty.ExpectMatch("2")
473+
474+
// And we're done.
475+
pty.WriteLine("exit")
476+
<-cmdDone
477+
})
478+
431479
t.Run("FileLogging", func(t *testing.T) {
432480
t.Parallel()
433481

0 commit comments

Comments
 (0)