Skip to content

Commit 5552b4f

Browse files
committed
chore: add resume token controller
1 parent 0b51bd5 commit 5552b4f

File tree

2 files changed

+312
-0
lines changed

2 files changed

+312
-0
lines changed

tailnet/controllers.go

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"fmt"
66
"io"
7+
"math"
78
"strings"
89
"sync"
910
"time"
@@ -16,6 +17,7 @@ import (
1617

1718
"cdr.dev/slog"
1819
"github.com/coder/coder/v2/tailnet/proto"
20+
"github.com/coder/quartz"
1921
)
2022

2123
// A Controller connects to the tailnet control plane, and then uses the control protocols to
@@ -523,3 +525,142 @@ func sendTelemetry(
523525
}
524526
return false
525527
}
528+
529+
type basicResumeTokenController struct {
530+
logger slog.Logger
531+
532+
sync.Mutex
533+
token *proto.RefreshResumeTokenResponse
534+
refresher *basicResumeTokenRefresher
535+
536+
// for testing
537+
clock quartz.Clock
538+
}
539+
540+
func (b *basicResumeTokenController) New(client ResumeTokenClient) CloserWaiter {
541+
b.Lock()
542+
defer b.Unlock()
543+
if b.refresher != nil {
544+
cErr := b.refresher.Close(context.Background())
545+
if cErr != nil {
546+
b.logger.Debug(context.Background(), "closed previous refresher", slog.Error(cErr))
547+
}
548+
}
549+
b.refresher = newBasicResumeTokenRefresher(b.logger, b.clock, b, client)
550+
return b.refresher
551+
}
552+
553+
func (b *basicResumeTokenController) Token() (string, bool) {
554+
b.Lock()
555+
defer b.Unlock()
556+
if b.token == nil {
557+
return "", false
558+
}
559+
if b.token.ExpiresAt.AsTime().Before(b.clock.Now()) {
560+
return "", false
561+
}
562+
return b.token.Token, true
563+
}
564+
565+
func NewBasicResumeTokenController(logger slog.Logger, clock quartz.Clock) ResumeTokenController {
566+
return &basicResumeTokenController{
567+
logger: logger,
568+
clock: clock,
569+
}
570+
}
571+
572+
type basicResumeTokenRefresher struct {
573+
logger slog.Logger
574+
ctx context.Context
575+
cancel context.CancelFunc
576+
ctrl *basicResumeTokenController
577+
client ResumeTokenClient
578+
errCh chan error
579+
580+
sync.Mutex
581+
closed bool
582+
timer *quartz.Timer
583+
}
584+
585+
func (r *basicResumeTokenRefresher) Close(_ context.Context) error {
586+
r.cancel()
587+
r.Lock()
588+
defer r.Unlock()
589+
if r.closed {
590+
return nil
591+
}
592+
r.closed = true
593+
r.timer.Stop()
594+
select {
595+
case r.errCh <- nil:
596+
default: // already have an error
597+
}
598+
return nil
599+
}
600+
601+
func (r *basicResumeTokenRefresher) Wait() <-chan error {
602+
return r.errCh
603+
}
604+
605+
// never is the largest representable duration. It is used as the initial
// delay for a timer that should not fire until it is explicitly Reset.
const never = time.Duration(math.MaxInt64)
606+
607+
func newBasicResumeTokenRefresher(
608+
logger slog.Logger, clock quartz.Clock,
609+
ctrl *basicResumeTokenController, client ResumeTokenClient,
610+
) *basicResumeTokenRefresher {
611+
r := &basicResumeTokenRefresher{
612+
logger: logger,
613+
ctrl: ctrl,
614+
client: client,
615+
errCh: make(chan error, 1),
616+
}
617+
r.ctx, r.cancel = context.WithCancel(context.Background())
618+
r.timer = clock.AfterFunc(never, r.refresh)
619+
go r.refresh()
620+
return r
621+
}
622+
623+
func (r *basicResumeTokenRefresher) refresh() {
624+
if r.ctx.Err() != nil {
625+
return // context done, no need to refresh
626+
}
627+
res, err := r.client.RefreshResumeToken(r.ctx, &proto.RefreshResumeTokenRequest{})
628+
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
629+
// these can only come from being closed, no need to log
630+
select {
631+
case r.errCh <- nil:
632+
default: // already have an error
633+
}
634+
return
635+
}
636+
if err != nil {
637+
r.logger.Error(r.ctx, "error refreshing coordinator resume token", slog.Error(err))
638+
select {
639+
case r.errCh <- err:
640+
default: // already have an error
641+
}
642+
return
643+
}
644+
r.logger.Debug(r.ctx, "refreshed coordinator resume token",
645+
slog.F("expires_at", res.GetExpiresAt()),
646+
slog.F("refresh_in", res.GetRefreshIn()),
647+
)
648+
r.ctrl.Lock()
649+
if r.ctrl.refresher == r { // don't overwrite if we're not the current refresher
650+
r.ctrl.token = res
651+
} else {
652+
r.logger.Debug(context.Background(), "not writing token because we have a new client")
653+
}
654+
r.ctrl.Unlock()
655+
dur := res.RefreshIn.AsDuration()
656+
if dur <= 0 {
657+
// A sensible delay to refresh again.
658+
dur = 30 * time.Minute
659+
}
660+
r.Lock()
661+
defer r.Unlock()
662+
if r.closed {
663+
return
664+
}
665+
r.timer.Reset(dur, "basicResumeTokenRefresher", "refresh")
666+
}

tailnet/controllers_test.go

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import (
1313
"github.com/stretchr/testify/require"
1414
"go.uber.org/mock/gomock"
1515
"golang.org/x/xerrors"
16+
"google.golang.org/protobuf/types/known/durationpb"
17+
"google.golang.org/protobuf/types/known/timestamppb"
1618
"storj.io/drpc"
1719
"storj.io/drpc/drpcerr"
1820
"tailscale.com/tailcfg"
@@ -24,6 +26,7 @@ import (
2426
"github.com/coder/coder/v2/tailnet/proto"
2527
"github.com/coder/coder/v2/tailnet/tailnettest"
2628
"github.com/coder/coder/v2/testutil"
29+
"github.com/coder/quartz"
2730
)
2831

2932
func TestInMemoryCoordination(t *testing.T) {
@@ -507,3 +510,171 @@ type fakeTelemetryCall struct {
507510
req *proto.TelemetryRequest
508511
errCh chan error
509512
}
513+
514+
func TestBasicResumeTokenController_Mainline(t *testing.T) {
515+
t.Parallel()
516+
ctx := testutil.Context(t, testutil.WaitShort)
517+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
518+
fr := newFakeResumeTokenClient(ctx)
519+
mClock := quartz.NewMock(t)
520+
trp := mClock.Trap().TimerReset("basicResumeTokenRefresher", "refresh")
521+
defer trp.Close()
522+
523+
uut := tailnet.NewBasicResumeTokenController(logger, mClock)
524+
_, ok := uut.Token()
525+
require.False(t, ok)
526+
527+
cwCh := make(chan tailnet.CloserWaiter, 1)
528+
go func() {
529+
cwCh <- uut.New(fr)
530+
}()
531+
call := testutil.RequireRecvCtx(ctx, t, fr.calls)
532+
testutil.RequireSendCtx(ctx, t, call.resp, &proto.RefreshResumeTokenResponse{
533+
Token: "test token 1",
534+
RefreshIn: durationpb.New(100 * time.Second),
535+
ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
536+
})
537+
trp.MustWait(ctx).Release() // initial refresh done
538+
token, ok := uut.Token()
539+
require.True(t, ok)
540+
require.Equal(t, "test token 1", token)
541+
cw := testutil.RequireRecvCtx(ctx, t, cwCh)
542+
543+
w := mClock.Advance(100 * time.Second)
544+
call = testutil.RequireRecvCtx(ctx, t, fr.calls)
545+
testutil.RequireSendCtx(ctx, t, call.resp, &proto.RefreshResumeTokenResponse{
546+
Token: "test token 2",
547+
RefreshIn: durationpb.New(50 * time.Second),
548+
ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
549+
})
550+
resetCall := trp.MustWait(ctx)
551+
require.Equal(t, resetCall.Duration, 50*time.Second)
552+
resetCall.Release()
553+
w.MustWait(ctx)
554+
token, ok = uut.Token()
555+
require.True(t, ok)
556+
require.Equal(t, "test token 2", token)
557+
558+
err := cw.Close(ctx)
559+
require.NoError(t, err)
560+
err = testutil.RequireRecvCtx(ctx, t, cw.Wait())
561+
require.NoError(t, err)
562+
563+
token, ok = uut.Token()
564+
require.True(t, ok)
565+
require.Equal(t, "test token 2", token)
566+
567+
mClock.Advance(201 * time.Second).MustWait(ctx)
568+
_, ok = uut.Token()
569+
require.False(t, ok)
570+
}
571+
572+
func TestBasicResumeTokenController_NewWhileRefreshing(t *testing.T) {
573+
t.Parallel()
574+
ctx := testutil.Context(t, testutil.WaitShort)
575+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
576+
mClock := quartz.NewMock(t)
577+
trp := mClock.Trap().TimerReset("basicResumeTokenRefresher", "refresh")
578+
defer trp.Close()
579+
580+
uut := tailnet.NewBasicResumeTokenController(logger, mClock)
581+
_, ok := uut.Token()
582+
require.False(t, ok)
583+
584+
fr1 := newFakeResumeTokenClient(ctx)
585+
cwCh1 := make(chan tailnet.CloserWaiter, 1)
586+
go func() {
587+
cwCh1 <- uut.New(fr1)
588+
}()
589+
call1 := testutil.RequireRecvCtx(ctx, t, fr1.calls)
590+
591+
fr2 := newFakeResumeTokenClient(ctx)
592+
cwCh2 := make(chan tailnet.CloserWaiter, 1)
593+
go func() {
594+
cwCh2 <- uut.New(fr2)
595+
}()
596+
call2 := testutil.RequireRecvCtx(ctx, t, fr2.calls)
597+
598+
testutil.RequireSendCtx(ctx, t, call2.resp, &proto.RefreshResumeTokenResponse{
599+
Token: "test token 2.0",
600+
RefreshIn: durationpb.New(102 * time.Second),
601+
ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
602+
})
603+
604+
cw2 := testutil.RequireRecvCtx(ctx, t, cwCh2) // this ensures Close was called on 1
605+
606+
testutil.RequireSendCtx(ctx, t, call1.resp, &proto.RefreshResumeTokenResponse{
607+
Token: "test token 1",
608+
RefreshIn: durationpb.New(101 * time.Second),
609+
ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
610+
})
611+
612+
trp.MustWait(ctx).Release()
613+
614+
token, ok := uut.Token()
615+
require.True(t, ok)
616+
require.Equal(t, "test token 2.0", token)
617+
618+
// refresher 1 should already be closed.
619+
cw1 := testutil.RequireRecvCtx(ctx, t, cwCh1)
620+
err := testutil.RequireRecvCtx(ctx, t, cw1.Wait())
621+
require.NoError(t, err)
622+
623+
w := mClock.Advance(102 * time.Second)
624+
call := testutil.RequireRecvCtx(ctx, t, fr2.calls)
625+
testutil.RequireSendCtx(ctx, t, call.resp, &proto.RefreshResumeTokenResponse{
626+
Token: "test token 2.1",
627+
RefreshIn: durationpb.New(50 * time.Second),
628+
ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
629+
})
630+
resetCall := trp.MustWait(ctx)
631+
require.Equal(t, resetCall.Duration, 50*time.Second)
632+
resetCall.Release()
633+
w.MustWait(ctx)
634+
token, ok = uut.Token()
635+
require.True(t, ok)
636+
require.Equal(t, "test token 2.1", token)
637+
638+
err = cw2.Close(ctx)
639+
require.NoError(t, err)
640+
err = testutil.RequireRecvCtx(ctx, t, cw2.Wait())
641+
require.NoError(t, err)
642+
}
643+
644+
func newFakeResumeTokenClient(ctx context.Context) *fakeResumeTokenClient {
645+
return &fakeResumeTokenClient{
646+
ctx: ctx,
647+
calls: make(chan *fakeResumeTokenCall),
648+
}
649+
}
650+
651+
type fakeResumeTokenClient struct {
652+
ctx context.Context
653+
calls chan *fakeResumeTokenCall
654+
}
655+
656+
func (f *fakeResumeTokenClient) RefreshResumeToken(_ context.Context, _ *proto.RefreshResumeTokenRequest) (*proto.RefreshResumeTokenResponse, error) {
657+
call := &fakeResumeTokenCall{
658+
resp: make(chan *proto.RefreshResumeTokenResponse),
659+
errCh: make(chan error),
660+
}
661+
select {
662+
case <-f.ctx.Done():
663+
return nil, f.ctx.Err()
664+
case f.calls <- call:
665+
// OK
666+
}
667+
select {
668+
case <-f.ctx.Done():
669+
return nil, f.ctx.Err()
670+
case err := <-call.errCh:
671+
return nil, err
672+
case resp := <-call.resp:
673+
return resp, nil
674+
}
675+
}
676+
677+
type fakeResumeTokenCall struct {
678+
resp chan *proto.RefreshResumeTokenResponse
679+
errCh chan error
680+
}

0 commit comments

Comments
 (0)