Skip to content

Commit f752bb6

Browse files
committed
add test case for multiple socket forwards
1 parent 8ba3c82 commit f752bb6

File tree

1 file changed

+100
-3
lines changed

1 file changed

+100
-3
lines changed

cli/ssh_test.go

+100-3
Original file line numberDiff line numberDiff line change
@@ -773,9 +773,8 @@ func TestSSH(t *testing.T) {
773773
<-cmdDone
774774
})
775775

776-
// Test that we can forward a local unix socket to a remote unix socket and
777-
// that new SSH sessions take over the socket without closing active socket
778-
// connections.
776+
// Test that we can remote forward multiple sockets, whether or not the
777+
// local sockets exists at the time of establishing the SSH connection.
779778
t.Run("RemoteForwardUnixSocketMultipleSessionsOverwrite", func(t *testing.T) {
780779
if runtime.GOOS == "windows" {
781780
t.Skip("Test not supported on windows")
@@ -883,6 +882,104 @@ func TestSSH(t *testing.T) {
883882
require.NoError(t, err)
884883
})
885884

885+
// Test that we can forward a local unix socket to a remote unix socket and
886+
// that new SSH sessions take over the socket without closing active socket
887+
// connections.
888+
t.Run("RemoteForwardMultipleUnixSockets", func(t *testing.T) {
889+
if runtime.GOOS == "windows" {
890+
t.Skip("Test not supported on windows")
891+
}
892+
893+
t.Parallel()
894+
895+
client, workspace, agentToken := setupWorkspaceForAgent(t)
896+
897+
_ = agenttest.New(t, client.URL, agentToken)
898+
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
899+
900+
// Wait super long so this doesn't flake on -race test.
901+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
902+
defer cancel()
903+
904+
tmpdir := tempDirUnixSocket(t)
905+
906+
type testSocket struct {
907+
local string
908+
remote string
909+
}
910+
911+
args := []string{"ssh", workspace.Name}
912+
var sockets []testSocket
913+
for i := 0; i < 2; i++ {
914+
localSock := filepath.Join(tmpdir, fmt.Sprintf("local-%d.sock", i))
915+
remoteSock := filepath.Join(tmpdir, fmt.Sprintf("remote-%d.sock", i))
916+
sockets = append(sockets, testSocket{
917+
local: localSock,
918+
remote: remoteSock,
919+
})
920+
args = append(args, "--remote-forward", fmt.Sprintf("%s:%s", remoteSock, localSock))
921+
}
922+
923+
inv, root := clitest.New(t, args...)
924+
clitest.SetupConfig(t, client, root)
925+
pty := ptytest.New(t).Attach(inv)
926+
inv.Stderr = pty.Output()
927+
928+
clitest.Start(t, inv.WithContext(ctx))
929+
930+
// Since something was output, it should be safe to write input.
931+
// This could show a prompt or "running startup scripts", so it's
932+
// not indicative of the SSH connection being ready.
933+
_ = pty.Peek(ctx, 1)
934+
935+
// Ensure the SSH connection is ready by testing the shell
936+
// input/output.
937+
pty.WriteLine("echo ping' 'pong")
938+
pty.ExpectMatchContext(ctx, "ping pong")
939+
940+
for i, sock := range sockets {
941+
i := i
942+
// Start the listener on the "local machine".
943+
l, err := net.Listen("unix", sock.local)
944+
require.NoError(t, err)
945+
defer l.Close() //nolint:revive // Defer is fine in this loop, we only run it twice.
946+
testutil.Go(t, func() {
947+
for {
948+
fd, err := l.Accept()
949+
if err != nil {
950+
if !errors.Is(err, net.ErrClosed) {
951+
assert.NoError(t, err, "listener accept failed", i)
952+
}
953+
return
954+
}
955+
956+
testutil.Go(t, func() {
957+
defer fd.Close()
958+
agentssh.Bicopy(ctx, fd, fd)
959+
})
960+
}
961+
})
962+
963+
// Dial the forwarded socket on the "remote machine".
964+
d := &net.Dialer{}
965+
fd, err := d.DialContext(ctx, "unix", sock.remote)
966+
require.NoError(t, err, i)
967+
defer fd.Close() //nolint:revive // Defer is fine in this loop, we only run it twice.
968+
969+
// Ping / pong to ensure the socket is working.
970+
_, err = fd.Write([]byte("hello world"))
971+
require.NoError(t, err, i)
972+
973+
buf := make([]byte, 11)
974+
_, err = fd.Read(buf)
975+
require.NoError(t, err, i)
976+
require.Equal(t, "hello world", string(buf), i)
977+
}
978+
979+
// And we're done.
980+
pty.WriteLine("exit")
981+
})
982+
886983
t.Run("FileLogging", func(t *testing.T) {
887984
t.Parallel()
888985

0 commit comments

Comments
 (0)