@@ -387,16 +387,17 @@ func (c *Conn) Closed() <-chan struct{} {
387
387
// Close shuts down the Wireguard connection.
388
388
func (c * Conn ) Close () error {
389
389
c .mutex .Lock ()
390
- defer c .mutex .Unlock ()
391
390
select {
392
391
case <- c .closed :
392
+ c .mutex .Unlock ()
393
393
return nil
394
394
default :
395
395
}
396
+ close (c .closed )
396
397
for _ , l := range c .listeners {
397
398
_ = l .closeNoLock ()
398
399
}
399
- close ( c . closed )
400
+ c . mutex . Unlock ( )
400
401
_ = c .dialer .Close ()
401
402
_ = c .magicConn .Close ()
402
403
_ = c .netStack .Close ()
@@ -406,6 +407,15 @@ func (c *Conn) Close() error {
406
407
return nil
407
408
}
408
409
410
+ func (c * Conn ) isClosed () bool {
411
+ select {
412
+ case <- c .closed :
413
+ return true
414
+ default :
415
+ return false
416
+ }
417
+ }
418
+
409
419
// This and below is taken _mostly_ verbatim from Tailscale:
410
420
// https://github.com/tailscale/tailscale/blob/c88bd53b1b7b2fcf7ba302f2e53dd1ce8c32dad4/tsnet/tsnet.go#L459-L494
411
421
@@ -422,9 +432,14 @@ func (c *Conn) Listen(network, addr string) (net.Listener, error) {
422
432
key : lk ,
423
433
addr : addr ,
424
434
425
- conn : make (chan net.Conn ),
435
+ closed : make (chan struct {}),
436
+ conn : make (chan net.Conn ),
426
437
}
427
438
c .mutex .Lock ()
439
+ if c .isClosed () {
440
+ c .mutex .Unlock ()
441
+ return nil , xerrors .New ("closed" )
442
+ }
428
443
if c .listeners == nil {
429
444
c .listeners = map [listenKey ]* listener {}
430
445
}
@@ -460,9 +475,12 @@ func (c *Conn) forwardTCP(conn net.Conn, port uint16) {
460
475
defer t .Stop ()
461
476
select {
462
477
case ln .conn <- conn :
478
+ return
479
+ case <- ln .closed :
480
+ case <- c .closed :
463
481
case <- t .C :
464
- _ = conn .Close ()
465
482
}
483
+ _ = conn .Close ()
466
484
}
467
485
468
486
func (c * Conn ) forwardTCPToLocal (conn net.Conn , port uint16 ) {
@@ -506,15 +524,18 @@ type listenKey struct {
506
524
}
507
525
508
526
type listener struct {
509
- s * Conn
510
- key listenKey
511
- addr string
512
- conn chan net.Conn
527
+ s * Conn
528
+ key listenKey
529
+ addr string
530
+ conn chan net.Conn
531
+ closed chan struct {}
513
532
}
514
533
515
534
func (ln * listener ) Accept () (net.Conn , error ) {
516
- c , ok := <- ln .conn
517
- if ! ok {
535
+ var c net.Conn
536
+ select {
537
+ case c = <- ln .conn :
538
+ case <- ln .closed :
518
539
return nil , xerrors .Errorf ("wgnet: %w" , net .ErrClosed )
519
540
}
520
541
return c , nil
@@ -530,7 +551,7 @@ func (ln *listener) Close() error {
530
551
func (ln * listener ) closeNoLock () error {
531
552
if v , ok := ln .s .listeners [ln .key ]; ok && v == ln {
532
553
delete (ln .s .listeners , ln .key )
533
- close (ln .conn )
554
+ close (ln .closed )
534
555
}
535
556
return nil
536
557
}
0 commit comments