Skip to content

Commit 9419c9a

Browse files
mafredrikylecarbs
authored andcommitted
fix: Try to fix cli portforward test flakes (#1650)
* fix: Try to fix cli portforward test flakes * fix: Guard against agent exit outside test func * fix: Improve test teardown in setupTestListener, cleanup
1 parent c11caf3 commit 9419c9a

File tree

1 file changed

+63
-41
lines changed

1 file changed

+63
-41
lines changed

cli/portforward_test.go

+63-41
Original file line numberDiff line numberDiff line change
@@ -144,20 +144,19 @@ func TestPortForward(t *testing.T) {
144144

145145
for _, c := range cases { //nolint:paralleltest // the `c := c` confuses the linter
146146
c := c
147+
// Avoid parallel test here because setupLocal reserves
148+
// a free open port which is not guaranteed to be free
149+
// after the listener closes.
150+
//nolint:paralleltest
147151
t.Run(c.name, func(t *testing.T) {
148-
t.Parallel()
149-
152+
//nolint:paralleltest
150153
t.Run("OnePort", func(t *testing.T) {
151-
t.Parallel()
152154
var (
153155
client = coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true})
154156
user = coderdtest.CreateFirstUser(t, client)
155157
_, workspace = runAgent(t, client, user.UserID)
156-
l1, p1 = setupTestListener(t, c.setupRemote(t))
158+
p1 = setupTestListener(t, c.setupRemote(t))
157159
)
158-
t.Cleanup(func() {
159-
_ = l1.Close()
160-
})
161160

162161
// Create a flag that forwards from local to listener 1.
163162
localAddress, localFlag := c.setupLocal(t)
@@ -171,9 +170,9 @@ func TestPortForward(t *testing.T) {
171170
cmd.SetOut(io.MultiWriter(buf, os.Stderr))
172171
ctx, cancel := context.WithCancel(context.Background())
173172
defer cancel()
173+
errC := make(chan error)
174174
go func() {
175-
err := cmd.ExecuteContext(ctx)
176-
assert.ErrorIs(t, err, context.Canceled)
175+
errC <- cmd.ExecuteContext(ctx)
177176
}()
178177
waitForPortForwardReady(t, buf)
179178

@@ -188,21 +187,21 @@ func TestPortForward(t *testing.T) {
188187
defer c2.Close()
189188
testDial(t, c2)
190189
testDial(t, c1)
190+
191+
cancel()
192+
err = <-errC
193+
require.ErrorIs(t, err, context.Canceled)
191194
})
192195

196+
//nolint:paralleltest
193197
t.Run("TwoPorts", func(t *testing.T) {
194-
t.Parallel()
195198
var (
196199
client = coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true})
197200
user = coderdtest.CreateFirstUser(t, client)
198201
_, workspace = runAgent(t, client, user.UserID)
199-
l1, p1 = setupTestListener(t, c.setupRemote(t))
200-
l2, p2 = setupTestListener(t, c.setupRemote(t))
202+
p1 = setupTestListener(t, c.setupRemote(t))
203+
p2 = setupTestListener(t, c.setupRemote(t))
201204
)
202-
t.Cleanup(func() {
203-
_ = l1.Close()
204-
_ = l2.Close()
205-
})
206205

207206
// Create a flags for listener 1 and listener 2.
208207
localAddress1, localFlag1 := c.setupLocal(t)
@@ -218,9 +217,9 @@ func TestPortForward(t *testing.T) {
218217
cmd.SetOut(io.MultiWriter(buf, os.Stderr))
219218
ctx, cancel := context.WithCancel(context.Background())
220219
defer cancel()
220+
errC := make(chan error)
221221
go func() {
222-
err := cmd.ExecuteContext(ctx)
223-
assert.ErrorIs(t, err, context.Canceled)
222+
errC <- cmd.ExecuteContext(ctx)
224223
}()
225224
waitForPortForwardReady(t, buf)
226225

@@ -235,13 +234,17 @@ func TestPortForward(t *testing.T) {
235234
defer c2.Close()
236235
testDial(t, c2)
237236
testDial(t, c1)
237+
238+
cancel()
239+
err = <-errC
240+
require.ErrorIs(t, err, context.Canceled)
238241
})
239242
})
240243
}
241244

242245
// Test doing a TCP -> Unix forward.
246+
//nolint:paralleltest
243247
t.Run("TCP2Unix", func(t *testing.T) {
244-
t.Parallel()
245248
var (
246249
client = coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true})
247250
user = coderdtest.CreateFirstUser(t, client)
@@ -253,11 +256,8 @@ func TestPortForward(t *testing.T) {
253256
unixCase = cases[2]
254257

255258
// Setup remote Unix listener.
256-
l1, p1 = setupTestListener(t, unixCase.setupRemote(t))
259+
p1 = setupTestListener(t, unixCase.setupRemote(t))
257260
)
258-
t.Cleanup(func() {
259-
_ = l1.Close()
260-
})
261261

262262
// Create a flag that forwards from local TCP to Unix listener 1.
263263
// Notably this is a --unix flag.
@@ -272,9 +272,9 @@ func TestPortForward(t *testing.T) {
272272
cmd.SetOut(io.MultiWriter(buf, os.Stderr))
273273
ctx, cancel := context.WithCancel(context.Background())
274274
defer cancel()
275+
errC := make(chan error)
275276
go func() {
276-
err := cmd.ExecuteContext(ctx)
277-
assert.ErrorIs(t, err, context.Canceled)
277+
errC <- cmd.ExecuteContext(ctx)
278278
}()
279279
waitForPortForwardReady(t, buf)
280280

@@ -289,11 +289,15 @@ func TestPortForward(t *testing.T) {
289289
defer c2.Close()
290290
testDial(t, c2)
291291
testDial(t, c1)
292+
293+
cancel()
294+
err = <-errC
295+
require.ErrorIs(t, err, context.Canceled)
292296
})
293297

294298
// Test doing TCP, UDP and Unix at the same time.
299+
//nolint:paralleltest
295300
t.Run("All", func(t *testing.T) {
296-
t.Parallel()
297301
var (
298302
client = coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true})
299303
user = coderdtest.CreateFirstUser(t, client)
@@ -311,10 +315,7 @@ func TestPortForward(t *testing.T) {
311315
continue
312316
}
313317

314-
l, p := setupTestListener(t, c.setupRemote(t))
315-
t.Cleanup(func() {
316-
_ = l.Close()
317-
})
318+
p := setupTestListener(t, c.setupRemote(t))
318319

319320
localAddress, localFlag := c.setupLocal(t)
320321
dials = append(dials, addr{
@@ -332,10 +333,9 @@ func TestPortForward(t *testing.T) {
332333
cmd.SetOut(io.MultiWriter(buf, os.Stderr))
333334
ctx, cancel := context.WithCancel(context.Background())
334335
defer cancel()
336+
errC := make(chan error)
335337
go func() {
336-
err := cmd.ExecuteContext(ctx)
337-
assert.Error(t, err)
338-
assert.ErrorIs(t, err, context.Canceled)
338+
errC <- cmd.ExecuteContext(ctx)
339339
}()
340340
waitForPortForwardReady(t, buf)
341341

@@ -357,6 +357,10 @@ func TestPortForward(t *testing.T) {
357357
for i := len(conns) - 1; i >= 0; i-- {
358358
testDial(t, conns[i])
359359
}
360+
361+
cancel()
362+
err := <-errC
363+
require.ErrorIs(t, err, context.Canceled)
360364
})
361365
}
362366

@@ -400,11 +404,15 @@ func runAgent(t *testing.T, client *codersdk.Client, userID uuid.UUID) ([]coders
400404
// Start workspace agent in a goroutine
401405
cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String())
402406
clitest.SetupConfig(t, client, root)
407+
errC := make(chan error)
403408
agentCtx, agentCancel := context.WithCancel(ctx)
404-
t.Cleanup(agentCancel)
409+
t.Cleanup(func() {
410+
agentCancel()
411+
err := <-errC
412+
require.NoError(t, err)
413+
})
405414
go func() {
406-
err := cmd.ExecuteContext(agentCtx)
407-
assert.NoError(t, err)
415+
errC <- cmd.ExecuteContext(agentCtx)
408416
}()
409417

410418
coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
@@ -416,18 +424,30 @@ func runAgent(t *testing.T, client *codersdk.Client, userID uuid.UUID) ([]coders
416424

417425
// setupTestListener starts accepting connections and echoing a single packet.
418426
// Returns the listener and the listen port or Unix path.
419-
func setupTestListener(t *testing.T, l net.Listener) (net.Listener, string) {
427+
func setupTestListener(t *testing.T, l net.Listener) string {
428+
// Wait for listener to completely exit before releasing.
429+
done := make(chan struct{})
420430
t.Cleanup(func() {
421431
_ = l.Close()
432+
<-done
422433
})
423434
go func() {
435+
defer close(done)
436+
// Guard against testAccept running require after test completion.
437+
var wg sync.WaitGroup
438+
defer wg.Wait()
439+
424440
for {
425441
c, err := l.Accept()
426442
if err != nil {
427443
return
428444
}
429445

430-
go testAccept(t, c)
446+
wg.Add(1)
447+
go func() {
448+
testAccept(t, c)
449+
wg.Done()
450+
}()
431451
}
432452
}()
433453

@@ -438,7 +458,7 @@ func setupTestListener(t *testing.T, l net.Listener) (net.Listener, string) {
438458
addr = port
439459
}
440460

441-
return l, addr
461+
return addr
442462
}
443463

444464
var dialTestPayload = []byte("dean-was-here123")
@@ -502,8 +522,10 @@ func newThreadSafeBuffer() *threadSafeBuffer {
502522
}
503523
}
504524

505-
var _ io.Reader = &threadSafeBuffer{}
506-
var _ io.Writer = &threadSafeBuffer{}
525+
var (
526+
_ io.Reader = &threadSafeBuffer{}
527+
_ io.Writer = &threadSafeBuffer{}
528+
)
507529

508530
// Read implements io.Reader.
509531
func (b *threadSafeBuffer) Read(p []byte) (int, error) {

0 commit comments

Comments
 (0)