@@ -12,10 +12,9 @@ import (
12
12
"net/netip"
13
13
"strconv"
14
14
"strings"
15
+ "sync"
15
16
"time"
16
17
17
- "golang.org/x/sync/errgroup"
18
-
19
18
"github.com/google/uuid"
20
19
"golang.org/x/xerrors"
21
20
"nhooyr.io/websocket"
@@ -360,6 +359,13 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
360
359
return agentConn , nil
361
360
}
362
361
362
+ // tailnetConn is the subset of the tailnet.Conn methods that tailnetAPIConnector uses. It is
363
+ // included so that we can fake it in testing.
364
+ type tailnetConn interface {
365
+ tailnet.Coordinatee
366
+ SetDERPMap (derpMap * tailcfg.DERPMap )
367
+ }
368
+
363
369
// tailnetAPIConnector dials the tailnet API (v2+) and then uses the API with a tailnet.Conn to
364
370
//
365
371
// 1) run the Coordinate API and pass node information back and forth
@@ -370,13 +376,20 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
370
376
//
371
377
// @typescript-ignore tailnetAPIConnector
372
378
type tailnetAPIConnector struct {
373
- ctx context.Context
379
+ // We keep track of two contexts: the main context from the caller, and a "graceful" context
380
+ // that we keep open slightly longer than the main context to give a chance to send the
381
+ // Disconnect message to the coordinator. That tells the coordinator that we really meant to
382
+ // disconnect instead of just losing network connectivity.
383
+ ctx context.Context
384
+ gracefulCtx context.Context
385
+ cancelGracefulCtx context.CancelFunc
386
+
374
387
logger slog.Logger
375
388
376
389
agentID uuid.UUID
377
390
coordinateURL string
378
391
dialOptions * websocket.DialOptions
379
- conn * tailnet. Conn
392
+ conn tailnetConn
380
393
381
394
connected chan error
382
395
isFirst bool
@@ -387,7 +400,7 @@ type tailnetAPIConnector struct {
387
400
func runTailnetAPIConnector (
388
401
ctx context.Context , logger slog.Logger ,
389
402
agentID uuid.UUID , coordinateURL string , dialOptions * websocket.DialOptions ,
390
- conn * tailnet. Conn ,
403
+ conn tailnetConn ,
391
404
) * tailnetAPIConnector {
392
405
tac := & tailnetAPIConnector {
393
406
ctx : ctx ,
@@ -399,10 +412,23 @@ func runTailnetAPIConnector(
399
412
connected : make (chan error , 1 ),
400
413
closed : make (chan struct {}),
401
414
}
415
+ tac .gracefulCtx , tac .cancelGracefulCtx = context .WithCancel (context .Background ())
416
+ go tac .manageGracefulTimeout ()
402
417
go tac .run ()
403
418
return tac
404
419
}
405
420
421
+ // manageGracefulTimeout allows the gracefulContext to last 1 second longer than the main context
422
+ // to allow a graceful disconnect.
423
+ func (tac * tailnetAPIConnector ) manageGracefulTimeout () {
424
+ defer tac .cancelGracefulCtx ()
425
+ <- tac .ctx .Done ()
426
+ select {
427
+ case <- tac .closed :
428
+ case <- time .After (time .Second ):
429
+ }
430
+ }
431
+
406
432
func (tac * tailnetAPIConnector ) run () {
407
433
tac .isFirst = true
408
434
defer close (tac .closed )
@@ -437,7 +463,7 @@ func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
437
463
return nil , err
438
464
}
439
465
client , err := tailnet .NewDRPCClient (
440
- websocket .NetConn (tac .ctx , ws , websocket .MessageBinary ),
466
+ websocket .NetConn (tac .gracefulCtx , ws , websocket .MessageBinary ),
441
467
tac .logger ,
442
468
)
443
469
if err != nil {
@@ -464,65 +490,81 @@ func (tac *tailnetAPIConnector) coordinateAndDERPMap(client proto.DRPCTailnetCli
464
490
<- conn .Closed ()
465
491
}
466
492
}()
467
- eg , egCtx := errgroup .WithContext (tac .ctx )
468
- eg .Go (func () error {
469
- return tac .coordinate (egCtx , client )
470
- })
471
- eg .Go (func () error {
472
- return tac .derpMap (egCtx , client )
473
- })
474
- err := eg .Wait ()
475
- if err != nil &&
476
- ! xerrors .Is (err , io .EOF ) &&
477
- ! xerrors .Is (err , context .Canceled ) &&
478
- ! xerrors .Is (err , context .DeadlineExceeded ) {
479
- tac .logger .Error (tac .ctx , "error while connected to tailnet v2+ API" )
480
- }
493
+ wg := sync.WaitGroup {}
494
+ wg .Add (2 )
495
+ go func () {
496
+ defer wg .Done ()
497
+ tac .coordinate (client )
498
+ }()
499
+ go func () {
500
+ defer wg .Done ()
501
+ dErr := tac .derpMap (client )
502
+ if dErr != nil && tac .ctx .Err () == nil {
503
+ // The main context is still active, meaning that we want the tailnet data plane to stay
504
+ // up, even though we hit some error getting DERP maps on the control plane. That means
505
+ // we do NOT want to gracefully disconnect on the coordinate() routine. So, we'll just
506
+ // close the underlying connection. This will trigger a retry of the control plane in
507
+ // run().
508
+ client .DRPCConn ().Close ()
509
+ // Note that derpMap() logs it own errors, we don't bother here.
510
+ }
511
+ }()
512
+ wg .Wait ()
481
513
}
482
514
483
- func (tac * tailnetAPIConnector ) coordinate (ctx context.Context , client proto.DRPCTailnetClient ) error {
484
- coord , err := client .Coordinate (ctx )
515
+ func (tac * tailnetAPIConnector ) coordinate (client proto.DRPCTailnetClient ) {
516
+ // we use the gracefulCtx here so that we'll have time to send the graceful disconnect
517
+ coord , err := client .Coordinate (tac .gracefulCtx )
485
518
if err != nil {
486
- return xerrors .Errorf ("failed to connect to Coordinate RPC: %w" , err )
519
+ tac .logger .Error (tac .ctx , "failed to connect to Coordinate RPC" , slog .Error (err ))
520
+ return
487
521
}
488
522
defer func () {
489
523
cErr := coord .Close ()
490
524
if cErr != nil {
491
- tac .logger .Debug (ctx , "error closing Coordinate RPC" , slog .Error (cErr ))
525
+ tac .logger .Debug (tac . ctx , "error closing Coordinate RPC" , slog .Error (cErr ))
492
526
}
493
527
}()
494
528
coordination := tailnet .NewRemoteCoordination (tac .logger , coord , tac .conn , tac .agentID )
495
- tac .logger .Debug (ctx , "serving coordinator" )
496
- err = <- coordination .Error ()
497
- if err != nil &&
498
- ! xerrors .Is (err , io .EOF ) &&
499
- ! xerrors .Is (err , context .Canceled ) &&
500
- ! xerrors .Is (err , context .DeadlineExceeded ) {
501
- return xerrors .Errorf ("remote coordination error: %w" , err )
529
+ tac .logger .Debug (tac .ctx , "serving coordinator" )
530
+ select {
531
+ case <- tac .ctx .Done ():
532
+ // main context canceled; do graceful disconnect
533
+ crdErr := coordination .Close ()
534
+ if crdErr != nil {
535
+ tac .logger .Error (tac .ctx , "failed to close remote coordination" , slog .Error (err ))
536
+ }
537
+ case err = <- coordination .Error ():
538
+ if err != nil &&
539
+ ! xerrors .Is (err , io .EOF ) &&
540
+ ! xerrors .Is (err , context .Canceled ) &&
541
+ ! xerrors .Is (err , context .DeadlineExceeded ) {
542
+ tac .logger .Error (tac .ctx , "remote coordination error: %w" , err )
543
+ }
502
544
}
503
- return nil
504
545
}
505
546
506
- func (tac * tailnetAPIConnector ) derpMap (ctx context. Context , client proto.DRPCTailnetClient ) error {
507
- s , err := client .StreamDERPMaps (ctx , & proto.StreamDERPMapsRequest {})
547
+ func (tac * tailnetAPIConnector ) derpMap (client proto.DRPCTailnetClient ) error {
548
+ s , err := client .StreamDERPMaps (tac . ctx , & proto.StreamDERPMapsRequest {})
508
549
if err != nil {
509
550
return xerrors .Errorf ("failed to connect to StreamDERPMaps RPC: %w" , err )
510
551
}
511
552
defer func () {
512
553
cErr := s .Close ()
513
554
if cErr != nil {
514
- tac .logger .Debug (ctx , "error closing StreamDERPMaps RPC" , slog .Error (cErr ))
555
+ tac .logger .Debug (tac . ctx , "error closing StreamDERPMaps RPC" , slog .Error (cErr ))
515
556
}
516
557
}()
517
558
for {
518
559
dmp , err := s .Recv ()
519
560
if err != nil {
520
- if xerrors .Is (err , io . EOF ) || xerrors . Is ( err , context .Canceled ) || xerrors .Is (err , context .DeadlineExceeded ) {
561
+ if xerrors .Is (err , context .Canceled ) || xerrors .Is (err , context .DeadlineExceeded ) {
521
562
return nil
522
563
}
523
- return xerrors .Errorf ("error receiving DERP Map: %w" , err )
564
+ tac .logger .Error (tac .ctx , "error receiving DERP Map" , slog .Error (err ))
565
+ return err
524
566
}
525
- tac .logger .Debug (ctx , "got new DERP Map" , slog .F ("derp_map" , dmp ))
567
+ tac .logger .Debug (tac . ctx , "got new DERP Map" , slog .F ("derp_map" , dmp ))
526
568
dm := tailnet .DERPMapFromProto (dmp )
527
569
tac .conn .SetDERPMap (dm )
528
570
}
0 commit comments