@@ -13,6 +13,8 @@ import (
13
13
"github.com/stretchr/testify/require"
14
14
"go.uber.org/mock/gomock"
15
15
"golang.org/x/xerrors"
16
+ "google.golang.org/protobuf/types/known/durationpb"
17
+ "google.golang.org/protobuf/types/known/timestamppb"
16
18
"storj.io/drpc"
17
19
"storj.io/drpc/drpcerr"
18
20
"tailscale.com/tailcfg"
@@ -24,6 +26,7 @@ import (
24
26
"github.com/coder/coder/v2/tailnet/proto"
25
27
"github.com/coder/coder/v2/tailnet/tailnettest"
26
28
"github.com/coder/coder/v2/testutil"
29
+ "github.com/coder/quartz"
27
30
)
28
31
29
32
func TestInMemoryCoordination (t * testing.T ) {
@@ -507,3 +510,171 @@ type fakeTelemetryCall struct {
507
510
req * proto.TelemetryRequest
508
511
errCh chan error
509
512
}
513
+
514
+ func TestBasicResumeTokenController_Mainline (t * testing.T ) {
515
+ t .Parallel ()
516
+ ctx := testutil .Context (t , testutil .WaitShort )
517
+ logger := slogtest .Make (t , nil ).Leveled (slog .LevelDebug )
518
+ fr := newFakeResumeTokenClient (ctx )
519
+ mClock := quartz .NewMock (t )
520
+ trp := mClock .Trap ().TimerReset ("basicResumeTokenRefresher" , "refresh" )
521
+ defer trp .Close ()
522
+
523
+ uut := tailnet .NewBasicResumeTokenController (logger , mClock )
524
+ _ , ok := uut .Token ()
525
+ require .False (t , ok )
526
+
527
+ cwCh := make (chan tailnet.CloserWaiter , 1 )
528
+ go func () {
529
+ cwCh <- uut .New (fr )
530
+ }()
531
+ call := testutil .RequireRecvCtx (ctx , t , fr .calls )
532
+ testutil .RequireSendCtx (ctx , t , call .resp , & proto.RefreshResumeTokenResponse {
533
+ Token : "test token 1" ,
534
+ RefreshIn : durationpb .New (100 * time .Second ),
535
+ ExpiresAt : timestamppb .New (mClock .Now ().Add (200 * time .Second )),
536
+ })
537
+ trp .MustWait (ctx ).Release () // initial refresh done
538
+ token , ok := uut .Token ()
539
+ require .True (t , ok )
540
+ require .Equal (t , "test token 1" , token )
541
+ cw := testutil .RequireRecvCtx (ctx , t , cwCh )
542
+
543
+ w := mClock .Advance (100 * time .Second )
544
+ call = testutil .RequireRecvCtx (ctx , t , fr .calls )
545
+ testutil .RequireSendCtx (ctx , t , call .resp , & proto.RefreshResumeTokenResponse {
546
+ Token : "test token 2" ,
547
+ RefreshIn : durationpb .New (50 * time .Second ),
548
+ ExpiresAt : timestamppb .New (mClock .Now ().Add (200 * time .Second )),
549
+ })
550
+ resetCall := trp .MustWait (ctx )
551
+ require .Equal (t , resetCall .Duration , 50 * time .Second )
552
+ resetCall .Release ()
553
+ w .MustWait (ctx )
554
+ token , ok = uut .Token ()
555
+ require .True (t , ok )
556
+ require .Equal (t , "test token 2" , token )
557
+
558
+ err := cw .Close (ctx )
559
+ require .NoError (t , err )
560
+ err = testutil .RequireRecvCtx (ctx , t , cw .Wait ())
561
+ require .NoError (t , err )
562
+
563
+ token , ok = uut .Token ()
564
+ require .True (t , ok )
565
+ require .Equal (t , "test token 2" , token )
566
+
567
+ mClock .Advance (201 * time .Second ).MustWait (ctx )
568
+ _ , ok = uut .Token ()
569
+ require .False (t , ok )
570
+ }
571
+
572
+ func TestBasicResumeTokenController_NewWhileRefreshing (t * testing.T ) {
573
+ t .Parallel ()
574
+ ctx := testutil .Context (t , testutil .WaitShort )
575
+ logger := slogtest .Make (t , nil ).Leveled (slog .LevelDebug )
576
+ mClock := quartz .NewMock (t )
577
+ trp := mClock .Trap ().TimerReset ("basicResumeTokenRefresher" , "refresh" )
578
+ defer trp .Close ()
579
+
580
+ uut := tailnet .NewBasicResumeTokenController (logger , mClock )
581
+ _ , ok := uut .Token ()
582
+ require .False (t , ok )
583
+
584
+ fr1 := newFakeResumeTokenClient (ctx )
585
+ cwCh1 := make (chan tailnet.CloserWaiter , 1 )
586
+ go func () {
587
+ cwCh1 <- uut .New (fr1 )
588
+ }()
589
+ call1 := testutil .RequireRecvCtx (ctx , t , fr1 .calls )
590
+
591
+ fr2 := newFakeResumeTokenClient (ctx )
592
+ cwCh2 := make (chan tailnet.CloserWaiter , 1 )
593
+ go func () {
594
+ cwCh2 <- uut .New (fr2 )
595
+ }()
596
+ call2 := testutil .RequireRecvCtx (ctx , t , fr2 .calls )
597
+
598
+ testutil .RequireSendCtx (ctx , t , call2 .resp , & proto.RefreshResumeTokenResponse {
599
+ Token : "test token 2.0" ,
600
+ RefreshIn : durationpb .New (102 * time .Second ),
601
+ ExpiresAt : timestamppb .New (mClock .Now ().Add (200 * time .Second )),
602
+ })
603
+
604
+ cw2 := testutil .RequireRecvCtx (ctx , t , cwCh2 ) // this ensures Close was called on 1
605
+
606
+ testutil .RequireSendCtx (ctx , t , call1 .resp , & proto.RefreshResumeTokenResponse {
607
+ Token : "test token 1" ,
608
+ RefreshIn : durationpb .New (101 * time .Second ),
609
+ ExpiresAt : timestamppb .New (mClock .Now ().Add (200 * time .Second )),
610
+ })
611
+
612
+ trp .MustWait (ctx ).Release ()
613
+
614
+ token , ok := uut .Token ()
615
+ require .True (t , ok )
616
+ require .Equal (t , "test token 2.0" , token )
617
+
618
+ // refresher 1 should already be closed.
619
+ cw1 := testutil .RequireRecvCtx (ctx , t , cwCh1 )
620
+ err := testutil .RequireRecvCtx (ctx , t , cw1 .Wait ())
621
+ require .NoError (t , err )
622
+
623
+ w := mClock .Advance (102 * time .Second )
624
+ call := testutil .RequireRecvCtx (ctx , t , fr2 .calls )
625
+ testutil .RequireSendCtx (ctx , t , call .resp , & proto.RefreshResumeTokenResponse {
626
+ Token : "test token 2.1" ,
627
+ RefreshIn : durationpb .New (50 * time .Second ),
628
+ ExpiresAt : timestamppb .New (mClock .Now ().Add (200 * time .Second )),
629
+ })
630
+ resetCall := trp .MustWait (ctx )
631
+ require .Equal (t , resetCall .Duration , 50 * time .Second )
632
+ resetCall .Release ()
633
+ w .MustWait (ctx )
634
+ token , ok = uut .Token ()
635
+ require .True (t , ok )
636
+ require .Equal (t , "test token 2.1" , token )
637
+
638
+ err = cw2 .Close (ctx )
639
+ require .NoError (t , err )
640
+ err = testutil .RequireRecvCtx (ctx , t , cw2 .Wait ())
641
+ require .NoError (t , err )
642
+ }
643
+
644
+ func newFakeResumeTokenClient (ctx context.Context ) * fakeResumeTokenClient {
645
+ return & fakeResumeTokenClient {
646
+ ctx : ctx ,
647
+ calls : make (chan * fakeResumeTokenCall ),
648
+ }
649
+ }
650
+
651
+ type fakeResumeTokenClient struct {
652
+ ctx context.Context
653
+ calls chan * fakeResumeTokenCall
654
+ }
655
+
656
+ func (f * fakeResumeTokenClient ) RefreshResumeToken (_ context.Context , _ * proto.RefreshResumeTokenRequest ) (* proto.RefreshResumeTokenResponse , error ) {
657
+ call := & fakeResumeTokenCall {
658
+ resp : make (chan * proto.RefreshResumeTokenResponse ),
659
+ errCh : make (chan error ),
660
+ }
661
+ select {
662
+ case <- f .ctx .Done ():
663
+ return nil , f .ctx .Err ()
664
+ case f .calls <- call :
665
+ // OK
666
+ }
667
+ select {
668
+ case <- f .ctx .Done ():
669
+ return nil , f .ctx .Err ()
670
+ case err := <- call .errCh :
671
+ return nil , err
672
+ case resp := <- call .resp :
673
+ return resp , nil
674
+ }
675
+ }
676
+
677
+ type fakeResumeTokenCall struct {
678
+ resp chan * proto.RefreshResumeTokenResponse
679
+ errCh chan error
680
+ }
0 commit comments