From 018b9e6cf9158a4ef161d5af84a6abea06bca25b Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 22 Nov 2023 20:33:22 +0400 Subject: [PATCH] fix: detect and retry reverse port forward on used port --- agent/agent_test.go | 374 ++++++++++--------------------------- agent/agentssh/agentssh.go | 2 +- 2 files changed, 103 insertions(+), 273 deletions(-) diff --git a/agent/agent_test.go b/agent/agent_test.go index b54f877fcdab9..31f1448f34018 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "io" + "math/rand" "net" "net/http" "net/http/httptest" @@ -17,7 +18,6 @@ import ( "path/filepath" "regexp" "runtime" - "strconv" "strings" "sync" "sync/atomic" @@ -25,7 +25,7 @@ import ( "testing" "time" - scp "github.com/bramvdbogaerde/go-scp" + "github.com/bramvdbogaerde/go-scp" "github.com/golang/mock/gomock" "github.com/google/uuid" "github.com/pion/udp" @@ -52,7 +52,6 @@ import ( "github.com/coder/coder/v2/agent/agenttest" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" - "github.com/coder/coder/v2/pty" "github.com/coder/coder/v2/pty/ptytest" "github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/tailnet/tailnettest" @@ -648,150 +647,57 @@ func TestAgent_Session_TTY_HugeOutputIsNotLost(t *testing.T) { } } -//nolint:paralleltest // This test reserves a port. func TestAgent_TCPLocalForwarding(t *testing.T) { - random, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - _ = random.Close() - tcpAddr, valid := random.Addr().(*net.TCPAddr) - require.True(t, valid) - randomPort := tcpAddr.Port + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) - local, err := net.Listen("tcp", "127.0.0.1:0") + rl, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) - defer local.Close() - tcpAddr, valid = local.Addr().(*net.TCPAddr) + defer rl.Close() + tcpAddr, valid := rl.Addr().(*net.TCPAddr) require.True(t, valid) remotePort := tcpAddr.Port - done := make(chan struct{}) - go func() { - defer close(done) - conn, err := local.Accept() - if !assert.NoError(t, err) { - return - } - defer conn.Close() - b := make([]byte, 4) - _, err = conn.Read(b) - if !assert.NoError(t, err) { - return - } - _, err = conn.Write(b) - if !assert.NoError(t, err) { - return - } - }() + go echoOnce(t, rl) - _, proc := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%d:127.0.0.1:%d", randomPort, remotePort)}, []string{"sleep", "5"}) + sshClient := setupAgentSSHClient(ctx, t) - go func() { - err := proc.Wait() - select { - case <-done: - default: - assert.NoError(t, err) - } - }() - - require.Eventually(t, func() bool { - conn, err := net.Dial("tcp", "127.0.0.1:"+strconv.Itoa(randomPort)) - if err != nil { - return false - } - defer conn.Close() - _, err = conn.Write([]byte("test")) - if !assert.NoError(t, err) { - return false - } - b := make([]byte, 4) - _, err = conn.Read(b) - if !assert.NoError(t, err) { - return false - } - if !assert.Equal(t, "test", string(b)) { - return false - } - - return true - }, testutil.WaitLong, testutil.IntervalSlow) - - <-done - - _ = proc.Kill() + conn, err := sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", remotePort)) + require.NoError(t, err) + defer conn.Close() + requireEcho(t, conn) } -//nolint:paralleltest // This test reserves a port. func TestAgent_TCPRemoteForwarding(t *testing.T) { - random, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - _ = random.Close() - tcpAddr, valid := random.Addr().(*net.TCPAddr) - require.True(t, valid) - randomPort := tcpAddr.Port - - l, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - defer l.Close() - tcpAddr, valid = l.Addr().(*net.TCPAddr) - require.True(t, valid) - localPort := tcpAddr.Port - - done := make(chan struct{}) - go func() { - defer close(done) - - conn, err := l.Accept() - if err != nil { - return - } - defer conn.Close() - b := make([]byte, 4) - _, err = conn.Read(b) - if !assert.NoError(t, err) { - return - } - _, err = conn.Write(b) - if !assert.NoError(t, err) { - return - } - }() - - _, proc := setupSSHCommand(t, []string{"-R", fmt.Sprintf("127.0.0.1:%d:127.0.0.1:%d", randomPort, localPort)}, []string{"sleep", "5"}) - - go func() { - err := proc.Wait() - select { - case <-done: - default: - assert.NoError(t, err) - } - }() - - require.Eventually(t, func() bool { - conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", randomPort)) + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + sshClient := setupAgentSSHClient(ctx, t) + + localhost := netip.MustParseAddr("127.0.0.1") + var randomPort uint16 + var ll net.Listener + var err error + for { + randomPort = pickRandomPort() + addr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(localhost, randomPort)) + ll, err = sshClient.ListenTCP(addr) if err != nil { - return false - } - defer conn.Close() - _, err = conn.Write([]byte("test")) - if !assert.NoError(t, err) { - return false - } - b := make([]byte, 4) - _, err = conn.Read(b) - if !assert.NoError(t, err) { - return false - } - if !assert.Equal(t, "test", string(b)) { - return false + t.Logf("error remote forwarding: %s", err.Error()) + select { + case <-ctx.Done(): + t.Fatal("timed out getting random listener") + default: + continue + } } + break + } + defer ll.Close() + go echoOnce(t, ll) - return true - }, testutil.WaitLong, testutil.IntervalSlow) - - <-done - - _ = proc.Kill() + conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", randomPort)) + require.NoError(t, err) + defer conn.Close() + requireEcho(t, conn) } func TestAgent_UnixLocalForwarding(t *testing.T) { @@ -799,52 +705,18 @@ func TestAgent_UnixLocalForwarding(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("unix domain sockets are not fully supported on Windows") } - + ctx := testutil.Context(t, testutil.WaitLong) tmpdir := tempDirUnixSocket(t) remoteSocketPath := filepath.Join(tmpdir, "remote-socket") - localSocketPath := filepath.Join(tmpdir, "local-socket") l, err := net.Listen("unix", remoteSocketPath) require.NoError(t, err) defer l.Close() + go echoOnce(t, l) - done := make(chan struct{}) - go func() { - defer close(done) - - conn, err := l.Accept() - if err != nil { - return - } - defer conn.Close() - b := make([]byte, 4) - _, err = conn.Read(b) - if !assert.NoError(t, err) { - return - } - _, err = conn.Write(b) - if !assert.NoError(t, err) { - return - } - }() - - _, proc := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%s:%s", localSocketPath, remoteSocketPath)}, []string{"sleep", "5"}) - - go func() { - err := proc.Wait() - select { - case <-done: - default: - assert.NoError(t, err) - } - }() - - require.Eventually(t, func() bool { - _, err := os.Stat(localSocketPath) - return err == nil - }, testutil.WaitLong, testutil.IntervalFast) + sshClient := setupAgentSSHClient(ctx, t) - conn, err := net.Dial("unix", localSocketPath) + conn, err := sshClient.Dial("unix", remoteSocketPath) require.NoError(t, err) defer conn.Close() _, err = conn.Write([]byte("test")) @@ -854,9 +726,6 @@ func TestAgent_UnixLocalForwarding(t *testing.T) { require.NoError(t, err) require.Equal(t, "test", string(b)) _ = conn.Close() - <-done - - _ = proc.Kill() } func TestAgent_UnixRemoteForwarding(t *testing.T) { @@ -867,66 +736,19 @@ func TestAgent_UnixRemoteForwarding(t *testing.T) { tmpdir := tempDirUnixSocket(t) remoteSocketPath := filepath.Join(tmpdir, "remote-socket") - localSocketPath := filepath.Join(tmpdir, "local-socket") - l, err := net.Listen("unix", localSocketPath) + ctx := testutil.Context(t, testutil.WaitLong) + sshClient := setupAgentSSHClient(ctx, t) + + l, err := sshClient.ListenUnix(remoteSocketPath) require.NoError(t, err) defer l.Close() + go echoOnce(t, l) - done := make(chan struct{}) - go func() { - defer close(done) - - conn, err := l.Accept() - if err != nil { - return - } - defer conn.Close() - b := make([]byte, 4) - _, err = conn.Read(b) - if !assert.NoError(t, err) { - return - } - _, err = conn.Write(b) - if !assert.NoError(t, err) { - return - } - }() - - _, proc := setupSSHCommand(t, []string{"-R", fmt.Sprintf("%s:%s", remoteSocketPath, localSocketPath)}, []string{"sleep", "5"}) - - go func() { - err := proc.Wait() - select { - case <-done: - default: - assert.NoError(t, err) - } - }() - - // It's possible that the socket is created but the server is not ready to - // accept connections yet. We need to retry until we can connect. - // - // Note that we wait long here because if the tailnet connection has trouble - // connecting, it could take 5 seconds or more to reconnect. - var conn net.Conn - require.Eventually(t, func() bool { - var err error - conn, err = net.Dial("unix", remoteSocketPath) - return err == nil - }, testutil.WaitLong, testutil.IntervalFast) - defer conn.Close() - _, err = conn.Write([]byte("test")) - require.NoError(t, err) - b := make([]byte, 4) - _, err = conn.Read(b) + conn, err := net.Dial("unix", remoteSocketPath) require.NoError(t, err) - require.Equal(t, "test", string(b)) - _ = conn.Close() - - <-done - - _ = proc.Kill() + defer conn.Close() + requireEcho(t, conn) } func TestAgent_SFTP(t *testing.T) { @@ -2063,50 +1885,14 @@ func TestAgent_DebugServer(t *testing.T) { }) } -func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) (*ptytest.PTYCmd, pty.Process) { - //nolint:dogsled +// setupAgentSSHClient creates an agent, dials it, and sets up an ssh.Client for it +func setupAgentSSHClient(ctx context.Context, t *testing.T) *ssh.Client { + //nolint: dogsled agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) - listener, err := net.Listen("tcp", "127.0.0.1:0") + sshClient, err := agentConn.SSHClient(ctx) require.NoError(t, err) - waitGroup := sync.WaitGroup{} - go func() { - defer listener.Close() - for { - conn, err := listener.Accept() - if err != nil { - return - } - - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - ssh, err := agentConn.SSH(ctx) - cancel() - if err != nil { - _ = conn.Close() - return - } - waitGroup.Add(1) - go func() { - agentssh.Bicopy(context.Background(), conn, ssh) - waitGroup.Done() - }() - } - }() - t.Cleanup(func() { - _ = listener.Close() - waitGroup.Wait() - }) - tcpAddr, valid := listener.Addr().(*net.TCPAddr) - require.True(t, valid) - args := append(beforeArgs, - "-o", "HostName "+tcpAddr.IP.String(), - "-o", "Port "+strconv.Itoa(tcpAddr.Port), - "-o", "StrictHostKeyChecking=no", - "-o", "UserKnownHostsFile=/dev/null", - "host", - ) - args = append(args, afterArgs...) - cmd := pty.Command("ssh", args...) - return ptytest.Start(t, cmd) + t.Cleanup(func() { sshClient.Close() }) + return sshClient } func setupSSHSession( @@ -2580,3 +2366,47 @@ func (s *syncWriter) Write(p []byte) (int, error) { defer s.mu.Unlock() return s.w.Write(p) } + +// pickRandomPort picks a random port number for the ephemeral range. We do this entirely randomly +// instead of opening a listener and closing it to find a port that is likely to be free, since +// sometimes the OS reallocates the port very quickly. +func pickRandomPort() uint16 { + const ( + // Overlap of windows, linux in https://en.wikipedia.org/wiki/Ephemeral_port + min = 49152 + max = 60999 + ) + n := max - min + x := rand.Intn(n) //nolint: gosec + return uint16(min + x) +} + +// echoOnce accepts a single connection, reads 4 bytes and echos them back +func echoOnce(t *testing.T, ll net.Listener) { + t.Helper() + conn, err := ll.Accept() + if err != nil { + return + } + defer conn.Close() + b := make([]byte, 4) + _, err = conn.Read(b) + if !assert.NoError(t, err) { + return + } + _, err = conn.Write(b) + if !assert.NoError(t, err) { + return + } +} + +// requireEcho sends 4 bytes and requires the read response to match what was sent. +func requireEcho(t *testing.T, conn net.Conn) { + t.Helper() + _, err := conn.Write([]byte("test")) + require.NoError(t, err) + b := make([]byte, 4) + _, err = conn.Read(b) + require.NoError(t, err) + require.Equal(t, "test", string(b)) +} diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index b0f9c11806235..28ca757749386 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -142,7 +142,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom }, ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool { // Allow reverse port forwarding all! - s.logger.Debug(ctx, "local port forward", + s.logger.Debug(ctx, "reverse port forward", slog.F("bind_host", bindHost), slog.F("bind_port", bindPort)) return true