Skip to content

Commit 8da5d2c

Browse files
committed
add test case for multiple socket forwards
1 parent 8ba3c82 commit 8da5d2c

File tree

1 file changed

+98
-0
lines changed

1 file changed

+98
-0
lines changed

cli/ssh_test.go

+98
Original file line numberDiff line numberDiff line change
@@ -883,6 +883,104 @@ func TestSSH(t *testing.T) {
883883
require.NoError(t, err)
884884
})
885885

886+
// Test that we can forward a local unix socket to a remote unix socket and
887+
// that new SSH sessions take over the socket without closing active socket
888+
// connections.
889+
t.Run("RemoteForwardMultipleUnixSockets", func(t *testing.T) {
890+
if runtime.GOOS == "windows" {
891+
t.Skip("Test not supported on windows")
892+
}
893+
894+
t.Parallel()
895+
896+
client, workspace, agentToken := setupWorkspaceForAgent(t)
897+
898+
_ = agenttest.New(t, client.URL, agentToken)
899+
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
900+
901+
// Wait super long so this doesn't flake on -race test.
902+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
903+
defer cancel()
904+
905+
tmpdir := tempDirUnixSocket(t)
906+
907+
type testSocket struct {
908+
local string
909+
remote string
910+
}
911+
912+
args := []string{"ssh", workspace.Name}
913+
var sockets []testSocket
914+
for i := 0; i < 2; i++ {
915+
localSock := filepath.Join(tmpdir, fmt.Sprintf("local-%d.sock", i))
916+
remoteSock := filepath.Join(tmpdir, fmt.Sprintf("remote-%d.sock", i))
917+
sockets = append(sockets, testSocket{
918+
local: localSock,
919+
remote: remoteSock,
920+
})
921+
args = append(args, "--remote-forward", fmt.Sprintf("%s:%s", remoteSock, localSock))
922+
}
923+
924+
inv, root := clitest.New(t, args...)
925+
clitest.SetupConfig(t, client, root)
926+
pty := ptytest.New(t).Attach(inv)
927+
inv.Stderr = pty.Output()
928+
929+
clitest.Start(t, inv.WithContext(ctx))
930+
931+
// Since something was output, it should be safe to write input.
932+
// This could show a prompt or "running startup scripts", so it's
933+
// not indicative of the SSH connection being ready.
934+
_ = pty.Peek(ctx, 1)
935+
936+
// Ensure the SSH connection is ready by testing the shell
937+
// input/output.
938+
pty.WriteLine("echo ping' 'pong")
939+
pty.ExpectMatchContext(ctx, "ping pong")
940+
941+
for i, sock := range sockets {
942+
i := i
943+
// Start the listener on the "local machine".
944+
l, err := net.Listen("unix", sock.local)
945+
require.NoError(t, err)
946+
defer l.Close() //nolint:revive // Defer is fine in this loop, we only run it twice.
947+
testutil.Go(t, func() {
948+
for {
949+
fd, err := l.Accept()
950+
if err != nil {
951+
if !errors.Is(err, net.ErrClosed) {
952+
assert.NoError(t, err, "listener accept failed", i)
953+
}
954+
return
955+
}
956+
957+
testutil.Go(t, func() {
958+
defer fd.Close()
959+
agentssh.Bicopy(ctx, fd, fd)
960+
})
961+
}
962+
})
963+
964+
// Dial the forwarded socket on the "remote machine".
965+
d := &net.Dialer{}
966+
fd, err := d.DialContext(ctx, "unix", sock.remote)
967+
require.NoError(t, err, i)
968+
defer fd.Close() //nolint:revive // Defer is fine in this loop, we only run it twice.
969+
970+
// Ping / pong to ensure the socket is working.
971+
_, err = fd.Write([]byte("hello world"))
972+
require.NoError(t, err, i)
973+
974+
buf := make([]byte, 11)
975+
_, err = fd.Read(buf)
976+
require.NoError(t, err, i)
977+
require.Equal(t, "hello world", string(buf), i)
978+
}
979+
980+
// And we're done.
981+
pty.WriteLine("exit")
982+
})
983+
886984
t.Run("FileLogging", func(t *testing.T) {
887985
t.Parallel()
888986

0 commit comments

Comments
 (0)