Skip to content

Commit 5c3b3a4

Browse files
committed
chore: add resume token controller
1 parent 7cdbc31 commit 5c3b3a4

File tree

2 files changed

+307
-0
lines changed

2 files changed

+307
-0
lines changed

tailnet/controllers.go

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"fmt"
66
"io"
7+
"math"
78
"strings"
89
"sync"
910
"time"
@@ -16,6 +17,7 @@ import (
1617

1718
"cdr.dev/slog"
1819
"github.com/coder/coder/v2/tailnet/proto"
20+
"github.com/coder/quartz"
1921
)
2022

2123
// A Controller connects to the tailnet control plane, and then uses the control protocols to
@@ -523,3 +525,139 @@ func sendTelemetry(
523525
}
524526
return false
525527
}
528+
529+
type basicResumeTokenController struct {
530+
logger slog.Logger
531+
532+
sync.Mutex
533+
token *proto.RefreshResumeTokenResponse
534+
refresher *basicResumeTokenRefresher
535+
536+
// for testing
537+
clock quartz.Clock
538+
}
539+
540+
func (b *basicResumeTokenController) New(client ResumeTokenClient) CloserWaiter {
541+
b.Lock()
542+
defer b.Unlock()
543+
if b.refresher != nil {
544+
b.refresher.Close(context.Background())
545+
}
546+
b.refresher = newBasicResumeTokenRefresher(b.logger, b.clock, b, client)
547+
return b.refresher
548+
}
549+
550+
func (b *basicResumeTokenController) Token() (string, bool) {
551+
b.Lock()
552+
defer b.Unlock()
553+
if b.token == nil {
554+
return "", false
555+
}
556+
if b.token.ExpiresAt.AsTime().Before(b.clock.Now()) {
557+
return "", false
558+
}
559+
return b.token.Token, true
560+
}
561+
562+
func NewBasicResumeTokenController(logger slog.Logger, clock quartz.Clock) ResumeTokenController {
563+
return &basicResumeTokenController{
564+
logger: logger,
565+
clock: clock,
566+
}
567+
}
568+
569+
type basicResumeTokenRefresher struct {
570+
logger slog.Logger
571+
ctx context.Context
572+
cancel context.CancelFunc
573+
ctrl *basicResumeTokenController
574+
client ResumeTokenClient
575+
errCh chan error
576+
577+
sync.Mutex
578+
closed bool
579+
timer *quartz.Timer
580+
}
581+
582+
func (r *basicResumeTokenRefresher) Close(_ context.Context) error {
583+
r.cancel()
584+
r.Lock()
585+
defer r.Unlock()
586+
if r.closed {
587+
return nil
588+
}
589+
r.closed = true
590+
r.timer.Stop()
591+
select {
592+
case r.errCh <- nil:
593+
default: // already have an error
594+
}
595+
return nil
596+
}
597+
598+
func (r *basicResumeTokenRefresher) Wait() <-chan error {
599+
return r.errCh
600+
}
601+
602+
// never is the largest representable duration. It is the initial delay of
// the refresh timer, which therefore does not fire until it is reset.
const never = time.Duration(math.MaxInt64)
603+
604+
func newBasicResumeTokenRefresher(
605+
logger slog.Logger, clock quartz.Clock,
606+
ctrl *basicResumeTokenController, client ResumeTokenClient,
607+
) *basicResumeTokenRefresher {
608+
r := &basicResumeTokenRefresher{
609+
logger: logger,
610+
ctrl: ctrl,
611+
client: client,
612+
errCh: make(chan error, 1),
613+
}
614+
r.ctx, r.cancel = context.WithCancel(context.Background())
615+
r.timer = clock.AfterFunc(never, r.refresh)
616+
go r.refresh()
617+
return r
618+
}
619+
620+
func (r *basicResumeTokenRefresher) refresh() {
621+
if r.ctx.Err() != nil {
622+
return // context done, no need to refresh
623+
}
624+
res, err := r.client.RefreshResumeToken(r.ctx, &proto.RefreshResumeTokenRequest{})
625+
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
626+
// these can only come from being closed, no need to log
627+
select {
628+
case r.errCh <- nil:
629+
default: // already have an error
630+
}
631+
return
632+
}
633+
if err != nil {
634+
r.logger.Error(r.ctx, "error refreshing coordinator resume token", slog.Error(err))
635+
select {
636+
case r.errCh <- err:
637+
default: // already have an error
638+
}
639+
return
640+
}
641+
r.logger.Debug(r.ctx, "refreshed coordinator resume token",
642+
slog.F("expires_at", res.GetExpiresAt()),
643+
slog.F("refresh_in", res.GetRefreshIn()),
644+
)
645+
r.ctrl.Lock()
646+
if r.ctrl.refresher == r { // don't overwrite if we're not the current refresher
647+
r.ctrl.token = res
648+
} else {
649+
r.logger.Debug(context.Background(), "not writing token because we have a new client")
650+
}
651+
r.ctrl.Unlock()
652+
dur := res.RefreshIn.AsDuration()
653+
if dur <= 0 {
654+
// A sensible delay to refresh again.
655+
dur = 30 * time.Minute
656+
}
657+
r.Lock()
658+
defer r.Unlock()
659+
if r.closed {
660+
return
661+
}
662+
r.timer.Reset(dur, "basicResumeTokenRefresher", "refresh")
663+
}

tailnet/controllers_test.go

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import (
1313
"github.com/stretchr/testify/require"
1414
"go.uber.org/mock/gomock"
1515
"golang.org/x/xerrors"
16+
"google.golang.org/protobuf/types/known/durationpb"
17+
"google.golang.org/protobuf/types/known/timestamppb"
1618
"storj.io/drpc"
1719
"storj.io/drpc/drpcerr"
1820
"tailscale.com/tailcfg"
@@ -24,6 +26,7 @@ import (
2426
"github.com/coder/coder/v2/tailnet/proto"
2527
"github.com/coder/coder/v2/tailnet/tailnettest"
2628
"github.com/coder/coder/v2/testutil"
29+
"github.com/coder/quartz"
2730
)
2831

2932
func TestInMemoryCoordination(t *testing.T) {
@@ -507,3 +510,169 @@ type fakeTelemetryCall struct {
507510
req *proto.TelemetryRequest
508511
errCh chan error
509512
}
513+
514+
func TestBasicResumeTokenController_Mainline(t *testing.T) {
515+
ctx := testutil.Context(t, testutil.WaitShort)
516+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
517+
fr := newFakeResumeTokenClient(ctx)
518+
mClock := quartz.NewMock(t)
519+
trp := mClock.Trap().TimerReset("basicResumeTokenRefresher", "refresh")
520+
defer trp.Close()
521+
522+
uut := tailnet.NewBasicResumeTokenController(logger, mClock)
523+
_, ok := uut.Token()
524+
require.False(t, ok)
525+
526+
cwCh := make(chan tailnet.CloserWaiter, 1)
527+
go func() {
528+
cwCh <- uut.New(fr)
529+
}()
530+
call := testutil.RequireRecvCtx(ctx, t, fr.calls)
531+
testutil.RequireSendCtx(ctx, t, call.resp, &proto.RefreshResumeTokenResponse{
532+
Token: "test token 1",
533+
RefreshIn: durationpb.New(100 * time.Second),
534+
ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
535+
})
536+
trp.MustWait(ctx).Release() // initial refresh done
537+
token, ok := uut.Token()
538+
require.True(t, ok)
539+
require.Equal(t, "test token 1", token)
540+
cw := testutil.RequireRecvCtx(ctx, t, cwCh)
541+
542+
w := mClock.Advance(100 * time.Second)
543+
call = testutil.RequireRecvCtx(ctx, t, fr.calls)
544+
testutil.RequireSendCtx(ctx, t, call.resp, &proto.RefreshResumeTokenResponse{
545+
Token: "test token 2",
546+
RefreshIn: durationpb.New(50 * time.Second),
547+
ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
548+
})
549+
resetCall := trp.MustWait(ctx)
550+
require.Equal(t, resetCall.Duration, 50*time.Second)
551+
resetCall.Release()
552+
w.MustWait(ctx)
553+
token, ok = uut.Token()
554+
require.True(t, ok)
555+
require.Equal(t, "test token 2", token)
556+
557+
err := cw.Close(ctx)
558+
require.NoError(t, err)
559+
err = testutil.RequireRecvCtx(ctx, t, cw.Wait())
560+
require.NoError(t, err)
561+
562+
token, ok = uut.Token()
563+
require.True(t, ok)
564+
require.Equal(t, "test token 2", token)
565+
566+
mClock.Advance(201 * time.Second).MustWait(ctx)
567+
token, ok = uut.Token()
568+
require.False(t, ok)
569+
}
570+
571+
func TestBasicResumeTokenController_NewWhileRefreshing(t *testing.T) {
572+
ctx := testutil.Context(t, testutil.WaitShort)
573+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
574+
mClock := quartz.NewMock(t)
575+
trp := mClock.Trap().TimerReset("basicResumeTokenRefresher", "refresh")
576+
defer trp.Close()
577+
578+
uut := tailnet.NewBasicResumeTokenController(logger, mClock)
579+
_, ok := uut.Token()
580+
require.False(t, ok)
581+
582+
fr1 := newFakeResumeTokenClient(ctx)
583+
cwCh1 := make(chan tailnet.CloserWaiter, 1)
584+
go func() {
585+
cwCh1 <- uut.New(fr1)
586+
}()
587+
call1 := testutil.RequireRecvCtx(ctx, t, fr1.calls)
588+
589+
fr2 := newFakeResumeTokenClient(ctx)
590+
cwCh2 := make(chan tailnet.CloserWaiter, 1)
591+
go func() {
592+
cwCh2 <- uut.New(fr2)
593+
}()
594+
call2 := testutil.RequireRecvCtx(ctx, t, fr2.calls)
595+
596+
testutil.RequireSendCtx(ctx, t, call2.resp, &proto.RefreshResumeTokenResponse{
597+
Token: "test token 2.0",
598+
RefreshIn: durationpb.New(102 * time.Second),
599+
ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
600+
})
601+
602+
cw2 := testutil.RequireRecvCtx(ctx, t, cwCh2) // this ensures Close was called on 1
603+
604+
testutil.RequireSendCtx(ctx, t, call1.resp, &proto.RefreshResumeTokenResponse{
605+
Token: "test token 1",
606+
RefreshIn: durationpb.New(101 * time.Second),
607+
ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
608+
})
609+
610+
trp.MustWait(ctx).Release()
611+
612+
token, ok := uut.Token()
613+
require.True(t, ok)
614+
require.Equal(t, "test token 2.0", token)
615+
616+
// refresher 1 should already be closed.
617+
cw1 := testutil.RequireRecvCtx(ctx, t, cwCh1)
618+
err := testutil.RequireRecvCtx(ctx, t, cw1.Wait())
619+
require.NoError(t, err)
620+
621+
w := mClock.Advance(102 * time.Second)
622+
call := testutil.RequireRecvCtx(ctx, t, fr2.calls)
623+
testutil.RequireSendCtx(ctx, t, call.resp, &proto.RefreshResumeTokenResponse{
624+
Token: "test token 2.1",
625+
RefreshIn: durationpb.New(50 * time.Second),
626+
ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
627+
})
628+
resetCall := trp.MustWait(ctx)
629+
require.Equal(t, resetCall.Duration, 50*time.Second)
630+
resetCall.Release()
631+
w.MustWait(ctx)
632+
token, ok = uut.Token()
633+
require.True(t, ok)
634+
require.Equal(t, "test token 2.1", token)
635+
636+
err = cw2.Close(ctx)
637+
require.NoError(t, err)
638+
err = testutil.RequireRecvCtx(ctx, t, cw2.Wait())
639+
require.NoError(t, err)
640+
}
641+
642+
func newFakeResumeTokenClient(ctx context.Context) *fakeResumeTokenClient {
643+
return &fakeResumeTokenClient{
644+
ctx: ctx,
645+
calls: make(chan *fakeResumeTokenCall),
646+
}
647+
}
648+
649+
type fakeResumeTokenClient struct {
650+
ctx context.Context
651+
calls chan *fakeResumeTokenCall
652+
}
653+
654+
func (f *fakeResumeTokenClient) RefreshResumeToken(_ context.Context, _ *proto.RefreshResumeTokenRequest) (*proto.RefreshResumeTokenResponse, error) {
655+
call := &fakeResumeTokenCall{
656+
resp: make(chan *proto.RefreshResumeTokenResponse),
657+
errCh: make(chan error),
658+
}
659+
select {
660+
case <-f.ctx.Done():
661+
return nil, f.ctx.Err()
662+
case f.calls <- call:
663+
// OK
664+
}
665+
select {
666+
case <-f.ctx.Done():
667+
return nil, f.ctx.Err()
668+
case err := <-call.errCh:
669+
return nil, err
670+
case resp := <-call.resp:
671+
return resp, nil
672+
}
673+
}
674+
675+
type fakeResumeTokenCall struct {
676+
resp chan *proto.RefreshResumeTokenResponse
677+
errCh chan error
678+
}

0 commit comments

Comments
 (0)