From ba21ba87ba2209fad3c9f4bb131d7de1fc0e58be Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Tue, 19 Nov 2024 11:18:07 +0400 Subject: [PATCH] fix(tailnet): prevent redial after Coord graceful restart --- tailnet/controllers.go | 11 +++++++++-- tailnet/controllers_test.go | 12 ++++++++---- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/tailnet/controllers.go b/tailnet/controllers.go index 0afe74efb837e..2cb03691503be 100644 --- a/tailnet/controllers.go +++ b/tailnet/controllers.go @@ -1163,10 +1163,17 @@ func (c *Controller) Run(ctx context.Context) { // Sadly retry doesn't support quartz.Clock yet so this is not // influenced by the configured clock. for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(c.ctx); { + // Check the context again before dialing, since `retrier.Wait()` could return true + // if the delay is 0, even if the context was canceled. This ensures we don't redial + // after a graceful shutdown. + if c.ctx.Err() != nil { + return + } + tailnetClients, err := c.Dialer.Dial(c.ctx, c.ResumeTokenCtrl) if err != nil { - if xerrors.Is(err, context.Canceled) { - continue + if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) { + return } errF := slog.Error(err) var sdkErr *codersdk.Error diff --git a/tailnet/controllers_test.go b/tailnet/controllers_test.go index 53ffe005825df..11ace3e073b26 100644 --- a/tailnet/controllers_test.go +++ b/tailnet/controllers_test.go @@ -14,7 +14,6 @@ import ( "github.com/google/uuid" "github.com/hashicorp/yamux" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" "golang.org/x/xerrors" @@ -1071,6 +1070,7 @@ func TestController_Disconnects(t *testing.T) { close(call.Resps) _ = testutil.RequireRecvCtx(testCtx, t, peersLost) + _ = testutil.RequireRecvCtx(testCtx, t, uut.Closed()) } func TestController_TelemetrySuccess(t *testing.T) { @@ -1210,24 +1210,28 @@ type pipeDialer struct { func (p *pipeDialer) Dial(_ context.Context, _ tailnet.ResumeTokenController) (tailnet.ControlProtocolClients, error) { s, c := net.Pipe() + p.t.Cleanup(func() { + _ = s.Close() + _ = c.Close() + }) go func() { err := p.svc.ServeConnV2(p.ctx, s, p.streamID) p.logger.Debug(p.ctx, "piped tailnet service complete", slog.Error(err)) }() client, err := tailnet.NewDRPCClient(c, p.logger) - if !assert.NoError(p.t, err) { + if err != nil { _ = c.Close() return tailnet.ControlProtocolClients{}, err } coord, err := client.Coordinate(context.Background()) - if !assert.NoError(p.t, err) { + if err != nil { _ = c.Close() return tailnet.ControlProtocolClients{}, err } derps := &tailnet.DERPFromDRPCWrapper{} derps.Client, err = client.StreamDERPMaps(context.Background(), &proto.StreamDERPMapsRequest{}) - if !assert.NoError(p.t, err) { + if err != nil { _ = c.Close() return tailnet.ControlProtocolClients{}, err }