@@ -13,6 +13,8 @@ import (
13
13
"github.com/stretchr/testify/require"
14
14
"go.uber.org/mock/gomock"
15
15
"golang.org/x/xerrors"
16
+ "google.golang.org/protobuf/types/known/durationpb"
17
+ "google.golang.org/protobuf/types/known/timestamppb"
16
18
"storj.io/drpc"
17
19
"storj.io/drpc/drpcerr"
18
20
"tailscale.com/tailcfg"
@@ -24,6 +26,7 @@ import (
24
26
"github.com/coder/coder/v2/tailnet/proto"
25
27
"github.com/coder/coder/v2/tailnet/tailnettest"
26
28
"github.com/coder/coder/v2/testutil"
29
+ "github.com/coder/quartz"
27
30
)
28
31
29
32
func TestInMemoryCoordination (t * testing.T ) {
@@ -507,3 +510,169 @@ type fakeTelemetryCall struct {
507
510
req * proto.TelemetryRequest
508
511
errCh chan error
509
512
}
513
+
514
+ func TestBasicResumeTokenController_Mainline (t * testing.T ) {
515
+ ctx := testutil .Context (t , testutil .WaitShort )
516
+ logger := slogtest .Make (t , nil ).Leveled (slog .LevelDebug )
517
+ fr := newFakeResumeTokenClient (ctx )
518
+ mClock := quartz .NewMock (t )
519
+ trp := mClock .Trap ().TimerReset ("basicResumeTokenRefresher" , "refresh" )
520
+ defer trp .Close ()
521
+
522
+ uut := tailnet .NewBasicResumeTokenController (logger , mClock )
523
+ _ , ok := uut .Token ()
524
+ require .False (t , ok )
525
+
526
+ cwCh := make (chan tailnet.CloserWaiter , 1 )
527
+ go func () {
528
+ cwCh <- uut .New (fr )
529
+ }()
530
+ call := testutil .RequireRecvCtx (ctx , t , fr .calls )
531
+ testutil .RequireSendCtx (ctx , t , call .resp , & proto.RefreshResumeTokenResponse {
532
+ Token : "test token 1" ,
533
+ RefreshIn : durationpb .New (100 * time .Second ),
534
+ ExpiresAt : timestamppb .New (mClock .Now ().Add (200 * time .Second )),
535
+ })
536
+ trp .MustWait (ctx ).Release () // initial refresh done
537
+ token , ok := uut .Token ()
538
+ require .True (t , ok )
539
+ require .Equal (t , "test token 1" , token )
540
+ cw := testutil .RequireRecvCtx (ctx , t , cwCh )
541
+
542
+ w := mClock .Advance (100 * time .Second )
543
+ call = testutil .RequireRecvCtx (ctx , t , fr .calls )
544
+ testutil .RequireSendCtx (ctx , t , call .resp , & proto.RefreshResumeTokenResponse {
545
+ Token : "test token 2" ,
546
+ RefreshIn : durationpb .New (50 * time .Second ),
547
+ ExpiresAt : timestamppb .New (mClock .Now ().Add (200 * time .Second )),
548
+ })
549
+ resetCall := trp .MustWait (ctx )
550
+ require .Equal (t , resetCall .Duration , 50 * time .Second )
551
+ resetCall .Release ()
552
+ w .MustWait (ctx )
553
+ token , ok = uut .Token ()
554
+ require .True (t , ok )
555
+ require .Equal (t , "test token 2" , token )
556
+
557
+ err := cw .Close (ctx )
558
+ require .NoError (t , err )
559
+ err = testutil .RequireRecvCtx (ctx , t , cw .Wait ())
560
+ require .NoError (t , err )
561
+
562
+ token , ok = uut .Token ()
563
+ require .True (t , ok )
564
+ require .Equal (t , "test token 2" , token )
565
+
566
+ mClock .Advance (201 * time .Second ).MustWait (ctx )
567
+ token , ok = uut .Token ()
568
+ require .False (t , ok )
569
+ }
570
+
571
+ func TestBasicResumeTokenController_NewWhileRefreshing (t * testing.T ) {
572
+ ctx := testutil .Context (t , testutil .WaitShort )
573
+ logger := slogtest .Make (t , nil ).Leveled (slog .LevelDebug )
574
+ mClock := quartz .NewMock (t )
575
+ trp := mClock .Trap ().TimerReset ("basicResumeTokenRefresher" , "refresh" )
576
+ defer trp .Close ()
577
+
578
+ uut := tailnet .NewBasicResumeTokenController (logger , mClock )
579
+ _ , ok := uut .Token ()
580
+ require .False (t , ok )
581
+
582
+ fr1 := newFakeResumeTokenClient (ctx )
583
+ cwCh1 := make (chan tailnet.CloserWaiter , 1 )
584
+ go func () {
585
+ cwCh1 <- uut .New (fr1 )
586
+ }()
587
+ call1 := testutil .RequireRecvCtx (ctx , t , fr1 .calls )
588
+
589
+ fr2 := newFakeResumeTokenClient (ctx )
590
+ cwCh2 := make (chan tailnet.CloserWaiter , 1 )
591
+ go func () {
592
+ cwCh2 <- uut .New (fr2 )
593
+ }()
594
+ call2 := testutil .RequireRecvCtx (ctx , t , fr2 .calls )
595
+
596
+ testutil .RequireSendCtx (ctx , t , call2 .resp , & proto.RefreshResumeTokenResponse {
597
+ Token : "test token 2.0" ,
598
+ RefreshIn : durationpb .New (102 * time .Second ),
599
+ ExpiresAt : timestamppb .New (mClock .Now ().Add (200 * time .Second )),
600
+ })
601
+
602
+ cw2 := testutil .RequireRecvCtx (ctx , t , cwCh2 ) // this ensures Close was called on 1
603
+
604
+ testutil .RequireSendCtx (ctx , t , call1 .resp , & proto.RefreshResumeTokenResponse {
605
+ Token : "test token 1" ,
606
+ RefreshIn : durationpb .New (101 * time .Second ),
607
+ ExpiresAt : timestamppb .New (mClock .Now ().Add (200 * time .Second )),
608
+ })
609
+
610
+ trp .MustWait (ctx ).Release ()
611
+
612
+ token , ok := uut .Token ()
613
+ require .True (t , ok )
614
+ require .Equal (t , "test token 2.0" , token )
615
+
616
+ // refresher 1 should already be closed.
617
+ cw1 := testutil .RequireRecvCtx (ctx , t , cwCh1 )
618
+ err := testutil .RequireRecvCtx (ctx , t , cw1 .Wait ())
619
+ require .NoError (t , err )
620
+
621
+ w := mClock .Advance (102 * time .Second )
622
+ call := testutil .RequireRecvCtx (ctx , t , fr2 .calls )
623
+ testutil .RequireSendCtx (ctx , t , call .resp , & proto.RefreshResumeTokenResponse {
624
+ Token : "test token 2.1" ,
625
+ RefreshIn : durationpb .New (50 * time .Second ),
626
+ ExpiresAt : timestamppb .New (mClock .Now ().Add (200 * time .Second )),
627
+ })
628
+ resetCall := trp .MustWait (ctx )
629
+ require .Equal (t , resetCall .Duration , 50 * time .Second )
630
+ resetCall .Release ()
631
+ w .MustWait (ctx )
632
+ token , ok = uut .Token ()
633
+ require .True (t , ok )
634
+ require .Equal (t , "test token 2.1" , token )
635
+
636
+ err = cw2 .Close (ctx )
637
+ require .NoError (t , err )
638
+ err = testutil .RequireRecvCtx (ctx , t , cw2 .Wait ())
639
+ require .NoError (t , err )
640
+ }
641
+
642
+ func newFakeResumeTokenClient (ctx context.Context ) * fakeResumeTokenClient {
643
+ return & fakeResumeTokenClient {
644
+ ctx : ctx ,
645
+ calls : make (chan * fakeResumeTokenCall ),
646
+ }
647
+ }
648
+
649
+ type fakeResumeTokenClient struct {
650
+ ctx context.Context
651
+ calls chan * fakeResumeTokenCall
652
+ }
653
+
654
+ func (f * fakeResumeTokenClient ) RefreshResumeToken (_ context.Context , _ * proto.RefreshResumeTokenRequest ) (* proto.RefreshResumeTokenResponse , error ) {
655
+ call := & fakeResumeTokenCall {
656
+ resp : make (chan * proto.RefreshResumeTokenResponse ),
657
+ errCh : make (chan error ),
658
+ }
659
+ select {
660
+ case <- f .ctx .Done ():
661
+ return nil , f .ctx .Err ()
662
+ case f .calls <- call :
663
+ // OK
664
+ }
665
+ select {
666
+ case <- f .ctx .Done ():
667
+ return nil , f .ctx .Err ()
668
+ case err := <- call .errCh :
669
+ return nil , err
670
+ case resp := <- call .resp :
671
+ return resp , nil
672
+ }
673
+ }
674
+
675
+ type fakeResumeTokenCall struct {
676
+ resp chan * proto.RefreshResumeTokenResponse
677
+ errCh chan error
678
+ }
0 commit comments