diff --git a/cli/portforward_test.go b/cli/portforward_test.go index 2260a4a32cdf4..04602db81f358 100644 --- a/cli/portforward_test.go +++ b/cli/portforward_test.go @@ -144,20 +144,19 @@ func TestPortForward(t *testing.T) { for _, c := range cases { //nolint:paralleltest // the `c := c` confuses the linter c := c + // Avoid parallel test here because setupLocal reserves + // a free open port which is not guaranteed to be free + // after the listener closes. + //nolint:paralleltest t.Run(c.name, func(t *testing.T) { - t.Parallel() - + //nolint:paralleltest t.Run("OnePort", func(t *testing.T) { - t.Parallel() var ( client = coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true}) user = coderdtest.CreateFirstUser(t, client) _, workspace = runAgent(t, client, user.UserID) - l1, p1 = setupTestListener(t, c.setupRemote(t)) + p1 = setupTestListener(t, c.setupRemote(t)) ) - t.Cleanup(func() { - _ = l1.Close() - }) // Create a flag that forwards from local to listener 1. localAddress, localFlag := c.setupLocal(t) @@ -171,9 +170,9 @@ func TestPortForward(t *testing.T) { cmd.SetOut(io.MultiWriter(buf, os.Stderr)) ctx, cancel := context.WithCancel(context.Background()) defer cancel() + errC := make(chan error) go func() { - err := cmd.ExecuteContext(ctx) - assert.ErrorIs(t, err, context.Canceled) + errC <- cmd.ExecuteContext(ctx) }() waitForPortForwardReady(t, buf) @@ -188,21 +187,21 @@ func TestPortForward(t *testing.T) { defer c2.Close() testDial(t, c2) testDial(t, c1) + + cancel() + err = <-errC + require.ErrorIs(t, err, context.Canceled) }) + //nolint:paralleltest t.Run("TwoPorts", func(t *testing.T) { - t.Parallel() var ( client = coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true}) user = coderdtest.CreateFirstUser(t, client) _, workspace = runAgent(t, client, user.UserID) - l1, p1 = setupTestListener(t, c.setupRemote(t)) - l2, p2 = setupTestListener(t, c.setupRemote(t)) + p1 = setupTestListener(t, c.setupRemote(t)) + p2 = setupTestListener(t, c.setupRemote(t)) ) - t.Cleanup(func() { - _ = l1.Close() - _ = l2.Close() - }) // Create a flags for listener 1 and listener 2. localAddress1, localFlag1 := c.setupLocal(t) @@ -218,9 +217,9 @@ func TestPortForward(t *testing.T) { cmd.SetOut(io.MultiWriter(buf, os.Stderr)) ctx, cancel := context.WithCancel(context.Background()) defer cancel() + errC := make(chan error) go func() { - err := cmd.ExecuteContext(ctx) - assert.ErrorIs(t, err, context.Canceled) + errC <- cmd.ExecuteContext(ctx) }() waitForPortForwardReady(t, buf) @@ -235,13 +234,17 @@ func TestPortForward(t *testing.T) { defer c2.Close() testDial(t, c2) testDial(t, c1) + + cancel() + err = <-errC + require.ErrorIs(t, err, context.Canceled) }) }) } // Test doing a TCP -> Unix forward. + //nolint:paralleltest t.Run("TCP2Unix", func(t *testing.T) { - t.Parallel() var ( client = coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true}) user = coderdtest.CreateFirstUser(t, client) @@ -253,11 +256,8 @@ func TestPortForward(t *testing.T) { unixCase = cases[2] // Setup remote Unix listener. - l1, p1 = setupTestListener(t, unixCase.setupRemote(t)) + p1 = setupTestListener(t, unixCase.setupRemote(t)) ) - t.Cleanup(func() { - _ = l1.Close() - }) // Create a flag that forwards from local TCP to Unix listener 1. // Notably this is a --unix flag. @@ -272,9 +272,9 @@ func TestPortForward(t *testing.T) { cmd.SetOut(io.MultiWriter(buf, os.Stderr)) ctx, cancel := context.WithCancel(context.Background()) defer cancel() + errC := make(chan error) go func() { - err := cmd.ExecuteContext(ctx) - assert.ErrorIs(t, err, context.Canceled) + errC <- cmd.ExecuteContext(ctx) }() waitForPortForwardReady(t, buf) @@ -289,11 +289,15 @@ func TestPortForward(t *testing.T) { defer c2.Close() testDial(t, c2) testDial(t, c1) + + cancel() + err = <-errC + require.ErrorIs(t, err, context.Canceled) }) // Test doing TCP, UDP and Unix at the same time. + //nolint:paralleltest t.Run("All", func(t *testing.T) { - t.Parallel() var ( client = coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true}) user = coderdtest.CreateFirstUser(t, client) @@ -311,10 +315,7 @@ func TestPortForward(t *testing.T) { continue } - l, p := setupTestListener(t, c.setupRemote(t)) - t.Cleanup(func() { - _ = l.Close() - }) + p := setupTestListener(t, c.setupRemote(t)) localAddress, localFlag := c.setupLocal(t) dials = append(dials, addr{ @@ -332,10 +333,9 @@ func TestPortForward(t *testing.T) { cmd.SetOut(io.MultiWriter(buf, os.Stderr)) ctx, cancel := context.WithCancel(context.Background()) defer cancel() + errC := make(chan error) go func() { - err := cmd.ExecuteContext(ctx) - assert.Error(t, err) - assert.ErrorIs(t, err, context.Canceled) + errC <- cmd.ExecuteContext(ctx) }() waitForPortForwardReady(t, buf) @@ -357,6 +357,10 @@ func TestPortForward(t *testing.T) { for i := len(conns) - 1; i >= 0; i-- { testDial(t, conns[i]) } + + cancel() + err := <-errC + require.ErrorIs(t, err, context.Canceled) }) } @@ -400,11 +404,15 @@ func runAgent(t *testing.T, client *codersdk.Client, userID uuid.UUID) ([]coders // Start workspace agent in a goroutine cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String()) clitest.SetupConfig(t, client, root) + errC := make(chan error) agentCtx, agentCancel := context.WithCancel(ctx) - t.Cleanup(agentCancel) + t.Cleanup(func() { + agentCancel() + err := <-errC + require.NoError(t, err) + }) go func() { - err := cmd.ExecuteContext(agentCtx) - assert.NoError(t, err) + errC <- cmd.ExecuteContext(agentCtx) }() coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID) @@ -416,18 +424,30 @@ func runAgent(t *testing.T, client *codersdk.Client, userID uuid.UUID) ([]coders // setupTestListener starts accepting connections and echoing a single packet. // Returns the listener and the listen port or Unix path. -func setupTestListener(t *testing.T, l net.Listener) (net.Listener, string) { +func setupTestListener(t *testing.T, l net.Listener) string { + // Wait for listener to completely exit before releasing. + done := make(chan struct{}) t.Cleanup(func() { _ = l.Close() + <-done }) go func() { + defer close(done) + // Guard against testAccept running require after test completion. + var wg sync.WaitGroup + defer wg.Wait() + for { c, err := l.Accept() if err != nil { return } - go testAccept(t, c) + wg.Add(1) + go func() { + testAccept(t, c) + wg.Done() + }() } }() @@ -438,7 +458,7 @@ func setupTestListener(t *testing.T, l net.Listener) (net.Listener, string) { addr = port } - return l, addr + return addr } var dialTestPayload = []byte("dean-was-here123") @@ -502,8 +522,10 @@ func newThreadSafeBuffer() *threadSafeBuffer { } } -var _ io.Reader = &threadSafeBuffer{} -var _ io.Writer = &threadSafeBuffer{} +var ( + _ io.Reader = &threadSafeBuffer{} + _ io.Writer = &threadSafeBuffer{} +) // Read implements io.Reader. func (b *threadSafeBuffer) Read(p []byte) (int, error) {