From 49baee86fe6282ab9c453caf3b5a02c9d18e1027 Mon Sep 17 00:00:00 2001
From: Spike Curtis
Date: Mon, 28 Oct 2024 13:48:27 +0400
Subject: [PATCH] chore: add resume token controller

---
 tailnet/controllers.go      | 141 +++++++++++++++++++++++++++++
 tailnet/controllers_test.go | 171 ++++++++++++++++++++++++++++++++++++
 2 files changed, 312 insertions(+)

diff --git a/tailnet/controllers.go b/tailnet/controllers.go
index 3176d70129a86..7a3e23e2e216d 100644
--- a/tailnet/controllers.go
+++ b/tailnet/controllers.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"fmt"
 	"io"
+	"math"
 	"strings"
 	"sync"
 	"time"
@@ -16,6 +17,7 @@ import (
 	"cdr.dev/slog"
 
 	"github.com/coder/coder/v2/tailnet/proto"
+	"github.com/coder/quartz"
 )
 
 // A Controller connects to the tailnet control plane, and then uses the control protocols to
@@ -523,3 +525,142 @@ func sendTelemetry(
 	}
 	return false
 }
+
+type basicResumeTokenController struct {
+	logger slog.Logger
+
+	sync.Mutex
+	token     *proto.RefreshResumeTokenResponse
+	refresher *basicResumeTokenRefresher
+
+	// for testing
+	clock quartz.Clock
+}
+
+func (b *basicResumeTokenController) New(client ResumeTokenClient) CloserWaiter {
+	b.Lock()
+	defer b.Unlock()
+	if b.refresher != nil {
+		cErr := b.refresher.Close(context.Background())
+		if cErr != nil {
+			b.logger.Debug(context.Background(), "error closing previous refresher", slog.Error(cErr))
+		}
+	}
+	b.refresher = newBasicResumeTokenRefresher(b.logger, b.clock, b, client)
+	return b.refresher
+}
+
+func (b *basicResumeTokenController) Token() (string, bool) {
+	b.Lock()
+	defer b.Unlock()
+	if b.token == nil {
+		return "", false
+	}
+	if b.token.ExpiresAt.AsTime().Before(b.clock.Now()) {
+		return "", false
+	}
+	return b.token.Token, true
+}
+
+func NewBasicResumeTokenController(logger slog.Logger, clock quartz.Clock) ResumeTokenController {
+	return &basicResumeTokenController{
+		logger: logger,
+		clock:  clock,
+	}
+}
+
+type basicResumeTokenRefresher struct {
+	logger slog.Logger
+	ctx    context.Context
+	cancel context.CancelFunc
+	ctrl   *basicResumeTokenController
+	client ResumeTokenClient
+	errCh  chan error
+
+	sync.Mutex
+	closed bool
+	timer  *quartz.Timer
+}
+
+func (r *basicResumeTokenRefresher) Close(_ context.Context) error {
+	r.cancel()
+	r.Lock()
+	defer r.Unlock()
+	if r.closed {
+		return nil
+	}
+	r.closed = true
+	r.timer.Stop()
+	select {
+	case r.errCh <- nil:
+	default: // already have an error
+	}
+	return nil
+}
+
+func (r *basicResumeTokenRefresher) Wait() <-chan error {
+	return r.errCh
+}
+
+const never time.Duration = math.MaxInt64
+
+func newBasicResumeTokenRefresher(
+	logger slog.Logger, clock quartz.Clock,
+	ctrl *basicResumeTokenController, client ResumeTokenClient,
+) *basicResumeTokenRefresher {
+	r := &basicResumeTokenRefresher{
+		logger: logger,
+		ctrl:   ctrl,
+		client: client,
+		errCh:  make(chan error, 1),
+	}
+	r.ctx, r.cancel = context.WithCancel(context.Background())
+	r.timer = clock.AfterFunc(never, r.refresh)
+	go r.refresh()
+	return r
+}
+
+func (r *basicResumeTokenRefresher) refresh() {
+	if r.ctx.Err() != nil {
+		return // context done, no need to refresh
+	}
+	res, err := r.client.RefreshResumeToken(r.ctx, &proto.RefreshResumeTokenRequest{})
+	if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
+		// these can only come from being closed, no need to log
+		select {
+		case r.errCh <- nil:
+		default: // already have an error
+		}
+		return
+	}
+	if err != nil {
+		r.logger.Error(r.ctx, "error refreshing coordinator resume token", slog.Error(err))
+		select {
+		case r.errCh <- err:
+		default: // already have an error
+		}
+		return
+	}
+	r.logger.Debug(r.ctx, "refreshed coordinator resume token",
+		slog.F("expires_at", res.GetExpiresAt()),
+		slog.F("refresh_in", res.GetRefreshIn()),
+	)
+	r.ctrl.Lock()
+	if r.ctrl.refresher == r { // don't overwrite if we're not the current refresher
+		r.ctrl.token = res
+	} else {
+		r.logger.Debug(context.Background(), "not writing token because we have a new client")
+	}
+	r.ctrl.Unlock()
+	dur := res.RefreshIn.AsDuration()
+	if dur <= 0 {
+		// the server did not give us a valid refresh interval; fall back to 30 minutes
+		dur = 30 * time.Minute
+	}
+	r.Lock()
+	defer r.Unlock()
+	if r.closed {
+		return
+	}
+	r.timer.Reset(dur, "basicResumeTokenRefresher", "refresh")
+}
diff --git a/tailnet/controllers_test.go b/tailnet/controllers_test.go
index 7c810af0c0077..d3f88ad23cae3 100644
--- a/tailnet/controllers_test.go
+++ b/tailnet/controllers_test.go
@@ -13,6 +13,8 @@ import (
 	"github.com/stretchr/testify/require"
 	"go.uber.org/mock/gomock"
 	"golang.org/x/xerrors"
+	"google.golang.org/protobuf/types/known/durationpb"
+	"google.golang.org/protobuf/types/known/timestamppb"
 	"storj.io/drpc"
 	"storj.io/drpc/drpcerr"
 	"tailscale.com/tailcfg"
@@ -24,6 +26,7 @@ import (
 	"github.com/coder/coder/v2/tailnet/proto"
 	"github.com/coder/coder/v2/tailnet/tailnettest"
 	"github.com/coder/coder/v2/testutil"
+	"github.com/coder/quartz"
 )
 
 func TestInMemoryCoordination(t *testing.T) {
@@ -507,3 +510,171 @@ type fakeTelemetryCall struct {
 	req   *proto.TelemetryRequest
 	errCh chan error
 }
+
+func TestBasicResumeTokenController_Mainline(t *testing.T) {
+	t.Parallel()
+	ctx := testutil.Context(t, testutil.WaitShort)
+	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+	fr := newFakeResumeTokenClient(ctx)
+	mClock := quartz.NewMock(t)
+	trp := mClock.Trap().TimerReset("basicResumeTokenRefresher", "refresh")
+	defer trp.Close()
+
+	uut := tailnet.NewBasicResumeTokenController(logger, mClock)
+	_, ok := uut.Token()
+	require.False(t, ok)
+
+	cwCh := make(chan tailnet.CloserWaiter, 1)
+	go func() {
+		cwCh <- uut.New(fr)
+	}()
+	call := testutil.RequireRecvCtx(ctx, t, fr.calls)
+	testutil.RequireSendCtx(ctx, t, call.resp, &proto.RefreshResumeTokenResponse{
+		Token:     "test token 1",
+		RefreshIn: durationpb.New(100 * time.Second),
+		ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
+	})
+	trp.MustWait(ctx).Release() // initial refresh done
+	token, ok := uut.Token()
+	require.True(t, ok)
+	require.Equal(t, "test token 1", token)
+	cw := testutil.RequireRecvCtx(ctx, t, cwCh)
+
+	w := mClock.Advance(100 * time.Second)
+	call = testutil.RequireRecvCtx(ctx, t, fr.calls)
+	testutil.RequireSendCtx(ctx, t, call.resp, &proto.RefreshResumeTokenResponse{
+		Token:     "test token 2",
+		RefreshIn: durationpb.New(50 * time.Second),
+		ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
+	})
+	resetCall := trp.MustWait(ctx)
+	require.Equal(t, 50*time.Second, resetCall.Duration)
+	resetCall.Release()
+	w.MustWait(ctx)
+	token, ok = uut.Token()
+	require.True(t, ok)
+	require.Equal(t, "test token 2", token)
+
+	err := cw.Close(ctx)
+	require.NoError(t, err)
+	err = testutil.RequireRecvCtx(ctx, t, cw.Wait())
+	require.NoError(t, err)
+
+	token, ok = uut.Token()
+	require.True(t, ok)
+	require.Equal(t, "test token 2", token)
+
+	mClock.Advance(201 * time.Second).MustWait(ctx)
+	_, ok = uut.Token()
+	require.False(t, ok)
+}
+
+func TestBasicResumeTokenController_NewWhileRefreshing(t *testing.T) {
+	t.Parallel()
+	ctx := testutil.Context(t, testutil.WaitShort)
+	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+	mClock := quartz.NewMock(t)
+	trp := mClock.Trap().TimerReset("basicResumeTokenRefresher", "refresh")
+	defer trp.Close()
+
+	uut := tailnet.NewBasicResumeTokenController(logger, mClock)
+	_, ok := uut.Token()
+	require.False(t, ok)
+
+	fr1 := newFakeResumeTokenClient(ctx)
+	cwCh1 := make(chan tailnet.CloserWaiter, 1)
+	go func() {
+		cwCh1 <- uut.New(fr1)
+	}()
+	call1 := testutil.RequireRecvCtx(ctx, t, fr1.calls)
+
+	fr2 := newFakeResumeTokenClient(ctx)
+	cwCh2 := make(chan tailnet.CloserWaiter, 1)
+	go func() {
+		cwCh2 <- uut.New(fr2)
+	}()
+	call2 := testutil.RequireRecvCtx(ctx, t, fr2.calls)
+
+	testutil.RequireSendCtx(ctx, t, call2.resp, &proto.RefreshResumeTokenResponse{
+		Token:     "test token 2.0",
+		RefreshIn: durationpb.New(102 * time.Second),
+		ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
+	})
+
+	cw2 := testutil.RequireRecvCtx(ctx, t, cwCh2) // receiving cw2 means New(fr2) returned, so Close was called on refresher 1
+
+	testutil.RequireSendCtx(ctx, t, call1.resp, &proto.RefreshResumeTokenResponse{
+		Token:     "test token 1",
+		RefreshIn: durationpb.New(101 * time.Second),
+		ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
+	})
+
+	trp.MustWait(ctx).Release()
+
+	token, ok := uut.Token()
+	require.True(t, ok)
+	require.Equal(t, "test token 2.0", token)
+
+	// refresher 1 should already be closed.
+	cw1 := testutil.RequireRecvCtx(ctx, t, cwCh1)
+	err := testutil.RequireRecvCtx(ctx, t, cw1.Wait())
+	require.NoError(t, err)
+
+	w := mClock.Advance(102 * time.Second)
+	call := testutil.RequireRecvCtx(ctx, t, fr2.calls)
+	testutil.RequireSendCtx(ctx, t, call.resp, &proto.RefreshResumeTokenResponse{
+		Token:     "test token 2.1",
+		RefreshIn: durationpb.New(50 * time.Second),
+		ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
+	})
+	resetCall := trp.MustWait(ctx)
+	require.Equal(t, 50*time.Second, resetCall.Duration)
+	resetCall.Release()
+	w.MustWait(ctx)
+	token, ok = uut.Token()
+	require.True(t, ok)
+	require.Equal(t, "test token 2.1", token)
+
+	err = cw2.Close(ctx)
+	require.NoError(t, err)
+	err = testutil.RequireRecvCtx(ctx, t, cw2.Wait())
+	require.NoError(t, err)
+}
+
+func newFakeResumeTokenClient(ctx context.Context) *fakeResumeTokenClient {
+	return &fakeResumeTokenClient{
+		ctx:   ctx,
+		calls: make(chan *fakeResumeTokenCall),
+	}
+}
+
+type fakeResumeTokenClient struct {
+	ctx   context.Context
+	calls chan *fakeResumeTokenCall
+}
+
+func (f *fakeResumeTokenClient) RefreshResumeToken(_ context.Context, _ *proto.RefreshResumeTokenRequest) (*proto.RefreshResumeTokenResponse, error) {
+	call := &fakeResumeTokenCall{
+		resp:  make(chan *proto.RefreshResumeTokenResponse),
+		errCh: make(chan error),
+	}
+	select {
+	case <-f.ctx.Done():
+		return nil, f.ctx.Err()
+	case f.calls <- call:
+		// OK
+	}
+	select {
+	case <-f.ctx.Done():
+		return nil, f.ctx.Err()
+	case err := <-call.errCh:
+		return nil, err
+	case resp := <-call.resp:
+		return resp, nil
+	}
+}
+
+type fakeResumeTokenCall struct {
+	resp  chan *proto.RefreshResumeTokenResponse
+	errCh chan error
+}
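
A minimal sketch of the intended call pattern, for reference alongside the patch: one long-lived controller, one refresher per connected client. This is not part of the change; runResumeTokenRefresh, its arguments, and the package clause are illustrative, and it assumes the ResumeTokenClient/CloserWaiter types referenced above (defined elsewhere in the tailnet package) and the real quartz clock (quartz.NewReal).

package example // illustrative; not part of this patch

import (
	"context"

	"cdr.dev/slog"

	"github.com/coder/coder/v2/tailnet"
	"github.com/coder/quartz"
)

// runResumeTokenRefresh sketches how a caller might wire the controller
// into a connect/reconnect loop.
func runResumeTokenRefresh(ctx context.Context, logger slog.Logger, client tailnet.ResumeTokenClient) error {
	ctrl := tailnet.NewBasicResumeTokenController(logger, quartz.NewReal())

	// Hand each (re)connected client to the controller; New closes any
	// previous refresher, so at most one refresh loop runs at a time.
	cw := ctrl.New(client)
	defer func() { _ = cw.Close(ctx) }()

	// When dialing again after a disconnect, present the cached token if we
	// still hold an unexpired one; ok is false if none was ever fetched or
	// it has expired.
	if token, ok := ctrl.Token(); ok {
		_ = token // e.g. attach as a query parameter on the dial URL (illustrative)
	}

	// Wait reports when the refresher stops: Close, context cancellation, or
	// an error from the RefreshResumeToken RPC.
	return <-cw.Wait()
}

The controller deliberately outlives any single connection, so the last good token survives a disconnect and can be presented on the next dial, which is what lets the control plane resume the previous session.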