@@ -20,6 +20,7 @@ import (
20
20
"regexp"
21
21
"runtime"
22
22
"strings"
23
+ "sync"
23
24
"testing"
24
25
"time"
25
26
@@ -1318,9 +1319,6 @@ func TestSSH(t *testing.T) {
1318
1319
1319
1320
tmpdir := tempDirUnixSocket (t )
1320
1321
localSock := filepath .Join (tmpdir , "local.sock" )
1321
- l , err := net .Listen ("unix" , localSock )
1322
- require .NoError (t , err )
1323
- defer l .Close ()
1324
1322
remoteSock := filepath .Join (tmpdir , "remote.sock" )
1325
1323
1326
1324
inv , root := clitest .New (t ,
@@ -1332,23 +1330,62 @@ func TestSSH(t *testing.T) {
1332
1330
clitest .SetupConfig (t , client , root )
1333
1331
pty := ptytest .New (t ).Attach (inv )
1334
1332
inv .Stderr = pty .Output ()
1335
- cmdDone := tGo (t , func () {
1336
- err := inv .WithContext (ctx ).Run ()
1337
- assert .NoError (t , err , "ssh command failed" )
1338
- })
1339
1333
1340
- // Wait for the prompt or any output really to indicate the command has
1341
- // started and accepting input on stdin.
1334
+ w := clitest .StartWithWaiter (t , inv .WithContext (ctx ))
1335
+ defer w .Wait () // We don't care about any exit error (exit code 255: SSH connection ended unexpectedly).
1336
+
1337
+ // Since something was output, it should be safe to write input.
1338
+ // This could show a prompt or "running startup scripts", so it's
1339
+ // not indicative of the SSH connection being ready.
1342
1340
_ = pty .Peek (ctx , 1 )
1343
1341
1344
- // This needs to support most shells on Linux or macOS
1345
- // We can't include exactly what's expected in the input, as that will always be matched
1346
- pty .WriteLine (fmt .Sprintf (`echo "results: $(netstat -an | grep %s | wc -l | tr -d ' ')"` , remoteSock ))
1347
- pty .ExpectMatchContext (ctx , "results: 1" )
1342
+ // Ensure the SSH connection is ready by testing the shell
1343
+ // input/output.
1344
+ pty .WriteLine ("echo ping' 'pong" )
1345
+ pty .ExpectMatchContext (ctx , "ping pong" )
1346
+
1347
+ // Start the listener on the "local machine".
1348
+ l , err := net .Listen ("unix" , localSock )
1349
+ require .NoError (t , err )
1350
+ defer l .Close ()
1351
+ testutil .Go (t , func () {
1352
+ var wg sync.WaitGroup
1353
+ defer wg .Wait ()
1354
+ for {
1355
+ fd , err := l .Accept ()
1356
+ if err != nil {
1357
+ if ! errors .Is (err , net .ErrClosed ) {
1358
+ assert .NoError (t , err , "listener accept failed" )
1359
+ }
1360
+ return
1361
+ }
1362
+
1363
+ wg .Add (1 )
1364
+ go func () {
1365
+ defer wg .Done ()
1366
+ defer fd .Close ()
1367
+ agentssh .Bicopy (ctx , fd , fd )
1368
+ }()
1369
+ }
1370
+ })
1371
+
1372
+ // Dial the forwarded socket on the "remote machine".
1373
+ d := & net.Dialer {}
1374
+ fd , err := d .DialContext (ctx , "unix" , remoteSock )
1375
+ require .NoError (t , err )
1376
+ defer fd .Close ()
1377
+
1378
+ // Ping / pong to ensure the socket is working.
1379
+ _ , err = fd .Write ([]byte ("hello world" ))
1380
+ require .NoError (t , err )
1381
+
1382
+ buf := make ([]byte , 11 )
1383
+ _ , err = fd .Read (buf )
1384
+ require .NoError (t , err )
1385
+ require .Equal (t , "hello world" , string (buf ))
1348
1386
1349
1387
// And we're done.
1350
1388
pty .WriteLine ("exit" )
1351
- <- cmdDone
1352
1389
})
1353
1390
1354
1391
// Test that we can forward a local unix socket to a remote unix socket and
@@ -1377,6 +1414,8 @@ func TestSSH(t *testing.T) {
1377
1414
require .NoError (t , err )
1378
1415
defer l .Close ()
1379
1416
testutil .Go (t , func () {
1417
+ var wg sync.WaitGroup
1418
+ defer wg .Wait ()
1380
1419
for {
1381
1420
fd , err := l .Accept ()
1382
1421
if err != nil {
@@ -1386,10 +1425,12 @@ func TestSSH(t *testing.T) {
1386
1425
return
1387
1426
}
1388
1427
1389
- testutil .Go (t , func () {
1428
+ wg .Add (1 )
1429
+ go func () {
1430
+ defer wg .Done ()
1390
1431
defer fd .Close ()
1391
1432
agentssh .Bicopy (ctx , fd , fd )
1392
- })
1433
+ }( )
1393
1434
}
1394
1435
})
1395
1436
@@ -1522,6 +1563,8 @@ func TestSSH(t *testing.T) {
1522
1563
require .NoError (t , err )
1523
1564
defer l .Close () //nolint:revive // Defer is fine in this loop, we only run it twice.
1524
1565
testutil .Go (t , func () {
1566
+ var wg sync.WaitGroup
1567
+ defer wg .Wait ()
1525
1568
for {
1526
1569
fd , err := l .Accept ()
1527
1570
if err != nil {
@@ -1531,10 +1574,12 @@ func TestSSH(t *testing.T) {
1531
1574
return
1532
1575
}
1533
1576
1534
- testutil .Go (t , func () {
1577
+ wg .Add (1 )
1578
+ go func () {
1579
+ defer wg .Done ()
1535
1580
defer fd .Close ()
1536
1581
agentssh .Bicopy (ctx , fd , fd )
1537
- })
1582
+ }( )
1538
1583
}
1539
1584
})
1540
1585
0 commit comments