diff --git a/agent/agent_test.go b/agent/agent_test.go index ec76aa1b0b6b9..6b9fe28fa312f 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -480,10 +480,24 @@ func TestAgent_TCPLocalForwarding(t *testing.T) { } }() + pty := ptytest.New(t) + cmd := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%d:127.0.0.1:%d", randomPort, remotePort)}, []string{"sleep", "5"}) + cmd.Stdin = pty.Input() + cmd.Stdout = pty.Output() + cmd.Stderr = pty.Output() err = cmd.Start() require.NoError(t, err) + go func() { + err := cmd.Wait() + select { + case <-done: + default: + assert.NoError(t, err) + } + }() + require.Eventually(t, func() bool { conn, err := net.Dial("tcp", "127.0.0.1:"+strconv.Itoa(randomPort)) if err != nil { @@ -547,10 +561,24 @@ func TestAgent_TCPRemoteForwarding(t *testing.T) { } }() + pty := ptytest.New(t) + cmd := setupSSHCommand(t, []string{"-R", fmt.Sprintf("127.0.0.1:%d:127.0.0.1:%d", randomPort, localPort)}, []string{"sleep", "5"}) + cmd.Stdin = pty.Input() + cmd.Stdout = pty.Output() + cmd.Stderr = pty.Output() err = cmd.Start() require.NoError(t, err) + go func() { + err := cmd.Wait() + select { + case <-done: + default: + assert.NoError(t, err) + } + }() + require.Eventually(t, func() bool { conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", randomPort)) if err != nil { @@ -612,10 +640,24 @@ func TestAgent_UnixLocalForwarding(t *testing.T) { } }() + pty := ptytest.New(t) + cmd := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%s:%s", localSocketPath, remoteSocketPath)}, []string{"sleep", "5"}) + cmd.Stdin = pty.Input() + cmd.Stdout = pty.Output() + cmd.Stderr = pty.Output() err = cmd.Start() require.NoError(t, err) + go func() { + err := cmd.Wait() + select { + case <-done: + default: + assert.NoError(t, err) + } + }() + require.Eventually(t, func() bool { _, err := os.Stat(localSocketPath) return err == nil @@ -670,10 +712,24 @@ func TestAgent_UnixRemoteForwarding(t *testing.T) { } }() + pty := ptytest.New(t) + cmd := setupSSHCommand(t, []string{"-R", fmt.Sprintf("%s:%s", remoteSocketPath, localSocketPath)}, []string{"sleep", "5"}) + cmd.Stdin = pty.Input() + cmd.Stdout = pty.Output() + cmd.Stderr = pty.Output() err = cmd.Start() require.NoError(t, err) + go func() { + err := cmd.Wait() + select { + case <-done: + default: + assert.NoError(t, err) + } + }() + // It's possible that the socket is created but the server is not ready to // accept connections yet. We need to retry until we can connect. var conn net.Conn