@@ -273,7 +273,7 @@ func TestAgent_Session_TTY_Hushlogin(t *testing.T) {
273
273
}
274
274
275
275
//nolint:paralleltest // This test reserves a port.
276
- func TestAgent_LocalForwarding (t * testing.T ) {
276
+ func TestAgent_TCPLocalForwarding (t * testing.T ) {
277
277
random , err := net .Listen ("tcp" , "127.0.0.1:0" )
278
278
require .NoError (t , err )
279
279
_ = random .Close ()
@@ -286,24 +286,239 @@ func TestAgent_LocalForwarding(t *testing.T) {
286
286
defer local .Close ()
287
287
tcpAddr , valid = local .Addr ().(* net.TCPAddr )
288
288
require .True (t , valid )
289
- localPort := tcpAddr .Port
289
+ remotePort := tcpAddr .Port
290
290
done := make (chan struct {})
291
291
go func () {
292
292
defer close (done )
293
293
conn , err := local .Accept ()
294
294
if ! assert .NoError (t , err ) {
295
295
return
296
296
}
297
- _ = conn .Close ()
297
+ defer conn .Close ()
298
+ b := make ([]byte , 4 )
299
+ _ , err = conn .Read (b )
300
+ if ! assert .NoError (t , err ) {
301
+ return
302
+ }
303
+ _ , err = conn .Write (b )
304
+ if ! assert .NoError (t , err ) {
305
+ return
306
+ }
298
307
}()
299
308
300
- err = setupSSHCommand (t , []string {"-L" , fmt .Sprintf ("%d:127.0.0.1:%d" , randomPort , localPort )}, []string {"echo" , "test" }).Start ()
309
+ cmd := setupSSHCommand (t , []string {"-L" , fmt .Sprintf ("%d:127.0.0.1:%d" , randomPort , remotePort )}, []string {"sleep" , "10" })
310
+ err = cmd .Start ()
311
+ require .NoError (t , err )
312
+
313
+ require .Eventually (t , func () bool {
314
+ conn , err := net .Dial ("tcp" , "127.0.0.1:" + strconv .Itoa (randomPort ))
315
+ if err != nil {
316
+ return false
317
+ }
318
+ defer conn .Close ()
319
+ _ , err = conn .Write ([]byte ("test" ))
320
+ if ! assert .NoError (t , err ) {
321
+ return false
322
+ }
323
+ b := make ([]byte , 4 )
324
+ _ , err = conn .Read (b )
325
+ if ! assert .NoError (t , err ) {
326
+ return false
327
+ }
328
+ if ! assert .Equal (t , "test" , string (b )) {
329
+ return false
330
+ }
331
+
332
+ return true
333
+ }, testutil .WaitLong , testutil .IntervalSlow )
334
+
335
+ <- done
336
+
337
+ _ = cmd .Process .Kill ()
338
+ }
339
+
340
+ //nolint:paralleltest // This test reserves a port.
341
+ func TestAgent_TCPRemoteForwarding (t * testing.T ) {
342
+ random , err := net .Listen ("tcp" , "127.0.0.1:0" )
301
343
require .NoError (t , err )
344
+ _ = random .Close ()
345
+ tcpAddr , valid := random .Addr ().(* net.TCPAddr )
346
+ require .True (t , valid )
347
+ randomPort := tcpAddr .Port
302
348
303
- conn , err := net .Dial ("tcp" , "127.0.0.1:" + strconv . Itoa ( localPort ) )
349
+ l , err := net .Listen ("tcp" , "127.0.0.1:0" )
304
350
require .NoError (t , err )
305
- conn .Close ()
351
+ defer l .Close ()
352
+ tcpAddr , valid = l .Addr ().(* net.TCPAddr )
353
+ require .True (t , valid )
354
+ localPort := tcpAddr .Port
355
+
356
+ done := make (chan struct {})
357
+ go func () {
358
+ defer close (done )
359
+
360
+ conn , err := l .Accept ()
361
+ if err != nil {
362
+ return
363
+ }
364
+ defer conn .Close ()
365
+ b := make ([]byte , 4 )
366
+ _ , err = conn .Read (b )
367
+ if ! assert .NoError (t , err ) {
368
+ return
369
+ }
370
+ _ , err = conn .Write (b )
371
+ if ! assert .NoError (t , err ) {
372
+ return
373
+ }
374
+ }()
375
+
376
+ cmd := setupSSHCommand (t , []string {"-R" , fmt .Sprintf ("127.0.0.1:%d:127.0.0.1:%d" , randomPort , localPort )}, []string {"sleep" , "10" })
377
+ err = cmd .Start ()
378
+ require .NoError (t , err )
379
+
380
+ require .Eventually (t , func () bool {
381
+ conn , err := net .Dial ("tcp" , fmt .Sprintf ("127.0.0.1:%d" , randomPort ))
382
+ if err != nil {
383
+ return false
384
+ }
385
+ defer conn .Close ()
386
+ _ , err = conn .Write ([]byte ("test" ))
387
+ if ! assert .NoError (t , err ) {
388
+ return false
389
+ }
390
+ b := make ([]byte , 4 )
391
+ _ , err = conn .Read (b )
392
+ if ! assert .NoError (t , err ) {
393
+ return false
394
+ }
395
+ if ! assert .Equal (t , "test" , string (b )) {
396
+ return false
397
+ }
398
+
399
+ return true
400
+ }, testutil .WaitLong , testutil .IntervalSlow )
401
+
306
402
<- done
403
+
404
+ _ = cmd .Process .Kill ()
405
+ }
406
+
407
+ func TestAgent_UnixLocalForwarding (t * testing.T ) {
408
+ t .Parallel ()
409
+ if runtime .GOOS == "windows" {
410
+ t .Skip ("unix domain sockets are not fully supported on Windows" )
411
+ }
412
+
413
+ tmpdir := tempDirUnixSocket (t )
414
+ remoteSocketPath := filepath .Join (tmpdir , "remote-socket" )
415
+ localSocketPath := filepath .Join (tmpdir , "local-socket" )
416
+
417
+ l , err := net .Listen ("unix" , remoteSocketPath )
418
+ require .NoError (t , err )
419
+ defer l .Close ()
420
+
421
+ done := make (chan struct {})
422
+ go func () {
423
+ defer close (done )
424
+
425
+ conn , err := l .Accept ()
426
+ if err != nil {
427
+ return
428
+ }
429
+ defer conn .Close ()
430
+ b := make ([]byte , 4 )
431
+ _ , err = conn .Read (b )
432
+ if ! assert .NoError (t , err ) {
433
+ return
434
+ }
435
+ _ , err = conn .Write (b )
436
+ if ! assert .NoError (t , err ) {
437
+ return
438
+ }
439
+ }()
440
+
441
+ cmd := setupSSHCommand (t , []string {"-L" , fmt .Sprintf ("%s:%s" , localSocketPath , remoteSocketPath )}, []string {"sleep" , "10" })
442
+ err = cmd .Start ()
443
+ require .NoError (t , err )
444
+
445
+ require .Eventually (t , func () bool {
446
+ _ , err := os .Stat (localSocketPath )
447
+ return err == nil
448
+ }, testutil .WaitLong , testutil .IntervalFast )
449
+
450
+ conn , err := net .Dial ("unix" , localSocketPath )
451
+ require .NoError (t , err )
452
+ defer conn .Close ()
453
+ _ , err = conn .Write ([]byte ("test" ))
454
+ require .NoError (t , err )
455
+ b := make ([]byte , 4 )
456
+ _ , err = conn .Read (b )
457
+ require .NoError (t , err )
458
+ require .Equal (t , "test" , string (b ))
459
+ _ = conn .Close ()
460
+ <- done
461
+
462
+ _ = cmd .Process .Kill ()
463
+ }
464
+
465
+ func TestAgent_UnixRemoteForwarding (t * testing.T ) {
466
+ t .Parallel ()
467
+ if runtime .GOOS == "windows" {
468
+ t .Skip ("unix domain sockets are not fully supported on Windows" )
469
+ }
470
+
471
+ tmpdir := tempDirUnixSocket (t )
472
+ remoteSocketPath := filepath .Join (tmpdir , "remote-socket" )
473
+ localSocketPath := filepath .Join (tmpdir , "local-socket" )
474
+
475
+ l , err := net .Listen ("unix" , localSocketPath )
476
+ require .NoError (t , err )
477
+ defer l .Close ()
478
+
479
+ done := make (chan struct {})
480
+ go func () {
481
+ defer close (done )
482
+
483
+ conn , err := l .Accept ()
484
+ if err != nil {
485
+ return
486
+ }
487
+ defer conn .Close ()
488
+ b := make ([]byte , 4 )
489
+ _ , err = conn .Read (b )
490
+ if ! assert .NoError (t , err ) {
491
+ return
492
+ }
493
+ _ , err = conn .Write (b )
494
+ if ! assert .NoError (t , err ) {
495
+ return
496
+ }
497
+ }()
498
+
499
+ cmd := setupSSHCommand (t , []string {"-R" , fmt .Sprintf ("%s:%s" , remoteSocketPath , localSocketPath )}, []string {"sleep" , "10" })
500
+ err = cmd .Start ()
501
+ require .NoError (t , err )
502
+
503
+ require .Eventually (t , func () bool {
504
+ _ , err := os .Stat (remoteSocketPath )
505
+ return err == nil
506
+ }, testutil .WaitLong , testutil .IntervalFast )
507
+
508
+ conn , err := net .Dial ("unix" , remoteSocketPath )
509
+ require .NoError (t , err )
510
+ defer conn .Close ()
511
+ _ , err = conn .Write ([]byte ("test" ))
512
+ require .NoError (t , err )
513
+ b := make ([]byte , 4 )
514
+ _ , err = conn .Read (b )
515
+ require .NoError (t , err )
516
+ require .Equal (t , "test" , string (b ))
517
+ _ = conn .Close ()
518
+
519
+ <- done
520
+
521
+ _ = cmd .Process .Kill ()
307
522
}
308
523
309
524
func TestAgent_SFTP (t * testing.T ) {
@@ -733,7 +948,10 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe
733
948
args := append (beforeArgs ,
734
949
"-o" , "HostName " + tcpAddr .IP .String (),
735
950
"-o" , "Port " + strconv .Itoa (tcpAddr .Port ),
736
- "-o" , "StrictHostKeyChecking=no" , "host" )
951
+ "-o" , "StrictHostKeyChecking=no" ,
952
+ "-o" , "UserKnownHostsFile=/dev/null" ,
953
+ "host" ,
954
+ )
737
955
args = append (args , afterArgs ... )
738
956
return exec .Command ("ssh" , args ... )
739
957
}
@@ -919,3 +1137,26 @@ func (*client) PostWorkspaceAgentAppHealth(_ context.Context, _ codersdk.PostWor
919
1137
func (* client ) PostWorkspaceAgentVersion (_ context.Context , _ string ) error {
920
1138
return nil
921
1139
}
1140
+
1141
+ // tempDirUnixSocket returns a temporary directory that can safely hold unix
1142
+ // sockets (probably).
1143
+ //
1144
+ // During tests on darwin we hit the max path length limit for unix sockets
1145
+ // pretty easily in the default location, so this function uses /tmp instead to
1146
+ // get shorter paths.
1147
+ func tempDirUnixSocket (t * testing.T ) string {
1148
+ t .Helper ()
1149
+ if runtime .GOOS == "darwin" {
1150
+ testName := strings .ReplaceAll (t .Name (), "/" , "_" )
1151
+ dir , err := os .MkdirTemp ("/tmp" , fmt .Sprintf ("coder-test-%s-" , testName ))
1152
+ require .NoError (t , err , "create temp dir for gpg test" )
1153
+
1154
+ t .Cleanup (func () {
1155
+ err := os .RemoveAll (dir )
1156
+ assert .NoError (t , err , "remove temp dir" , dir )
1157
+ })
1158
+ return dir
1159
+ }
1160
+
1161
+ return t .TempDir ()
1162
+ }
0 commit comments