@@ -12,10 +12,9 @@ import (
12
12
"net/netip"
13
13
"strconv"
14
14
"strings"
15
+ "sync"
15
16
"time"
16
17
17
- "golang.org/x/sync/errgroup"
18
-
19
18
"github.com/google/uuid"
20
19
"golang.org/x/xerrors"
21
20
"nhooyr.io/websocket"
@@ -360,6 +359,15 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
360
359
return agentConn , nil
361
360
}
362
361
362
+ // tailnetConn is the subset of the tailnet.Conn methods that tailnetAPIConnector uses. It is
363
+ // included so that we can fake it in testing.
364
+ //
365
+ // @typescript-ignore tailnetConn
366
+ type tailnetConn interface {
367
+ tailnet.Coordinatee
368
+ SetDERPMap (derpMap * tailcfg.DERPMap )
369
+ }
370
+
363
371
// tailnetAPIConnector dials the tailnet API (v2+) and then uses the API with a tailnet.Conn to
364
372
//
365
373
// 1) run the Coordinate API and pass node information back and forth
@@ -370,13 +378,20 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
370
378
//
371
379
// @typescript-ignore tailnetAPIConnector
372
380
type tailnetAPIConnector struct {
373
- ctx context.Context
381
+ // We keep track of two contexts: the main context from the caller, and a "graceful" context
382
+ // that we keep open slightly longer than the main context to give a chance to send the
383
+ // Disconnect message to the coordinator. That tells the coordinator that we really meant to
384
+ // disconnect instead of just losing network connectivity.
385
+ ctx context.Context
386
+ gracefulCtx context.Context
387
+ cancelGracefulCtx context.CancelFunc
388
+
374
389
logger slog.Logger
375
390
376
391
agentID uuid.UUID
377
392
coordinateURL string
378
393
dialOptions * websocket.DialOptions
379
- conn * tailnet. Conn
394
+ conn tailnetConn
380
395
381
396
connected chan error
382
397
isFirst bool
@@ -387,7 +402,7 @@ type tailnetAPIConnector struct {
387
402
func runTailnetAPIConnector (
388
403
ctx context.Context , logger slog.Logger ,
389
404
agentID uuid.UUID , coordinateURL string , dialOptions * websocket.DialOptions ,
390
- conn * tailnet. Conn ,
405
+ conn tailnetConn ,
391
406
) * tailnetAPIConnector {
392
407
tac := & tailnetAPIConnector {
393
408
ctx : ctx ,
@@ -399,10 +414,23 @@ func runTailnetAPIConnector(
399
414
connected : make (chan error , 1 ),
400
415
closed : make (chan struct {}),
401
416
}
417
+ tac .gracefulCtx , tac .cancelGracefulCtx = context .WithCancel (context .Background ())
418
+ go tac .manageGracefulTimeout ()
402
419
go tac .run ()
403
420
return tac
404
421
}
405
422
423
+ // manageGracefulTimeout allows the gracefulContext to last 1 second longer than the main context
424
+ // to allow a graceful disconnect.
425
+ func (tac * tailnetAPIConnector ) manageGracefulTimeout () {
426
+ defer tac .cancelGracefulCtx ()
427
+ <- tac .ctx .Done ()
428
+ select {
429
+ case <- tac .closed :
430
+ case <- time .After (time .Second ):
431
+ }
432
+ }
433
+
406
434
func (tac * tailnetAPIConnector ) run () {
407
435
tac .isFirst = true
408
436
defer close (tac .closed )
@@ -437,7 +465,7 @@ func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
437
465
return nil , err
438
466
}
439
467
client , err := tailnet .NewDRPCClient (
440
- websocket .NetConn (tac .ctx , ws , websocket .MessageBinary ),
468
+ websocket .NetConn (tac .gracefulCtx , ws , websocket .MessageBinary ),
441
469
tac .logger ,
442
470
)
443
471
if err != nil {
@@ -464,65 +492,81 @@ func (tac *tailnetAPIConnector) coordinateAndDERPMap(client proto.DRPCTailnetCli
464
492
<- conn .Closed ()
465
493
}
466
494
}()
467
- eg , egCtx := errgroup .WithContext (tac .ctx )
468
- eg .Go (func () error {
469
- return tac .coordinate (egCtx , client )
470
- })
471
- eg .Go (func () error {
472
- return tac .derpMap (egCtx , client )
473
- })
474
- err := eg .Wait ()
475
- if err != nil &&
476
- ! xerrors .Is (err , io .EOF ) &&
477
- ! xerrors .Is (err , context .Canceled ) &&
478
- ! xerrors .Is (err , context .DeadlineExceeded ) {
479
- tac .logger .Error (tac .ctx , "error while connected to tailnet v2+ API" )
480
- }
495
+ wg := sync.WaitGroup {}
496
+ wg .Add (2 )
497
+ go func () {
498
+ defer wg .Done ()
499
+ tac .coordinate (client )
500
+ }()
501
+ go func () {
502
+ defer wg .Done ()
503
+ dErr := tac .derpMap (client )
504
+ if dErr != nil && tac .ctx .Err () == nil {
505
+ // The main context is still active, meaning that we want the tailnet data plane to stay
506
+ // up, even though we hit some error getting DERP maps on the control plane. That means
507
+ // we do NOT want to gracefully disconnect on the coordinate() routine. So, we'll just
508
+ // close the underlying connection. This will trigger a retry of the control plane in
509
+ // run().
510
+ client .DRPCConn ().Close ()
511
+ // Note that derpMap() logs its own errors, so we don't bother here.
512
+ }
513
+ }()
514
+ wg .Wait ()
481
515
}
482
516
483
- func (tac * tailnetAPIConnector ) coordinate (ctx context.Context , client proto.DRPCTailnetClient ) error {
484
- coord , err := client .Coordinate (ctx )
517
+ func (tac * tailnetAPIConnector ) coordinate (client proto.DRPCTailnetClient ) {
518
+ // we use the gracefulCtx here so that we'll have time to send the graceful disconnect
519
+ coord , err := client .Coordinate (tac .gracefulCtx )
485
520
if err != nil {
486
- return xerrors .Errorf ("failed to connect to Coordinate RPC: %w" , err )
521
+ tac .logger .Error (tac .ctx , "failed to connect to Coordinate RPC" , slog .Error (err ))
522
+ return
487
523
}
488
524
defer func () {
489
525
cErr := coord .Close ()
490
526
if cErr != nil {
491
- tac .logger .Debug (ctx , "error closing Coordinate RPC" , slog .Error (cErr ))
527
+ tac .logger .Debug (tac . ctx , "error closing Coordinate RPC" , slog .Error (cErr ))
492
528
}
493
529
}()
494
530
coordination := tailnet .NewRemoteCoordination (tac .logger , coord , tac .conn , tac .agentID )
495
- tac .logger .Debug (ctx , "serving coordinator" )
496
- err = <- coordination .Error ()
497
- if err != nil &&
498
- ! xerrors .Is (err , io .EOF ) &&
499
- ! xerrors .Is (err , context .Canceled ) &&
500
- ! xerrors .Is (err , context .DeadlineExceeded ) {
501
- return xerrors .Errorf ("remote coordination error: %w" , err )
531
+ tac .logger .Debug (tac .ctx , "serving coordinator" )
532
+ select {
533
+ case <- tac .ctx .Done ():
534
+ tac .logger .Debug (tac .ctx , "main context canceled; do graceful disconnect" )
535
+ crdErr := coordination .Close ()
536
+ if crdErr != nil {
537
+ tac .logger .Error (tac .ctx , "failed to close remote coordination" , slog .Error (err ))
538
+ }
539
+ case err = <- coordination .Error ():
540
+ if err != nil &&
541
+ ! xerrors .Is (err , io .EOF ) &&
542
+ ! xerrors .Is (err , context .Canceled ) &&
543
+ ! xerrors .Is (err , context .DeadlineExceeded ) {
544
+ tac .logger .Error (tac .ctx , "remote coordination error: %w" , err )
545
+ }
502
546
}
503
- return nil
504
547
}
505
548
506
- func (tac * tailnetAPIConnector ) derpMap (ctx context. Context , client proto.DRPCTailnetClient ) error {
507
- s , err := client .StreamDERPMaps (ctx , & proto.StreamDERPMapsRequest {})
549
+ func (tac * tailnetAPIConnector ) derpMap (client proto.DRPCTailnetClient ) error {
550
+ s , err := client .StreamDERPMaps (tac . ctx , & proto.StreamDERPMapsRequest {})
508
551
if err != nil {
509
552
return xerrors .Errorf ("failed to connect to StreamDERPMaps RPC: %w" , err )
510
553
}
511
554
defer func () {
512
555
cErr := s .Close ()
513
556
if cErr != nil {
514
- tac .logger .Debug (ctx , "error closing StreamDERPMaps RPC" , slog .Error (cErr ))
557
+ tac .logger .Debug (tac . ctx , "error closing StreamDERPMaps RPC" , slog .Error (cErr ))
515
558
}
516
559
}()
517
560
for {
518
561
dmp , err := s .Recv ()
519
562
if err != nil {
520
- if xerrors .Is (err , io . EOF ) || xerrors . Is ( err , context .Canceled ) || xerrors .Is (err , context .DeadlineExceeded ) {
563
+ if xerrors .Is (err , context .Canceled ) || xerrors .Is (err , context .DeadlineExceeded ) {
521
564
return nil
522
565
}
523
- return xerrors .Errorf ("error receiving DERP Map: %w" , err )
566
+ tac .logger .Error (tac .ctx , "error receiving DERP Map" , slog .Error (err ))
567
+ return err
524
568
}
525
- tac .logger .Debug (ctx , "got new DERP Map" , slog .F ("derp_map" , dmp ))
569
+ tac .logger .Debug (tac . ctx , "got new DERP Map" , slog .F ("derp_map" , dmp ))
526
570
dm := tailnet .DERPMapFromProto (dmp )
527
571
tac .conn .SetDERPMap (dm )
528
572
}
0 commit comments