Skip to content

Commit ba21ba8

Browse files
committed
fix(tailnet): prevent redial after Coord graceful restart
1 parent c3c23ed commit ba21ba8

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

tailnet/controllers.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,10 +1163,17 @@ func (c *Controller) Run(ctx context.Context) {
11631163
// Sadly retry doesn't support quartz.Clock yet so this is not
11641164
// influenced by the configured clock.
11651165
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(c.ctx); {
1166+
// Check the context again before dialing, since `retrier.Wait()` could return true
1167+
// if the delay is 0, even if the context was canceled. This ensures we don't redial
1168+
// after a graceful shutdown.
1169+
if c.ctx.Err() != nil {
1170+
return
1171+
}
1172+
11661173
tailnetClients, err := c.Dialer.Dial(c.ctx, c.ResumeTokenCtrl)
11671174
if err != nil {
1168-
if xerrors.Is(err, context.Canceled) {
1169-
continue
1175+
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
1176+
return
11701177
}
11711178
errF := slog.Error(err)
11721179
var sdkErr *codersdk.Error

tailnet/controllers_test.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ import (
1414

1515
"github.com/google/uuid"
1616
"github.com/hashicorp/yamux"
17-
"github.com/stretchr/testify/assert"
1817
"github.com/stretchr/testify/require"
1918
"go.uber.org/mock/gomock"
2019
"golang.org/x/xerrors"
@@ -1071,6 +1070,7 @@ func TestController_Disconnects(t *testing.T) {
10711070
close(call.Resps)
10721071

10731072
_ = testutil.RequireRecvCtx(testCtx, t, peersLost)
1073+
_ = testutil.RequireRecvCtx(testCtx, t, uut.Closed())
10741074
}
10751075

10761076
func TestController_TelemetrySuccess(t *testing.T) {
@@ -1210,24 +1210,28 @@ type pipeDialer struct {
12101210

12111211
func (p *pipeDialer) Dial(_ context.Context, _ tailnet.ResumeTokenController) (tailnet.ControlProtocolClients, error) {
12121212
s, c := net.Pipe()
1213+
p.t.Cleanup(func() {
1214+
_ = s.Close()
1215+
_ = c.Close()
1216+
})
12131217
go func() {
12141218
err := p.svc.ServeConnV2(p.ctx, s, p.streamID)
12151219
p.logger.Debug(p.ctx, "piped tailnet service complete", slog.Error(err))
12161220
}()
12171221
client, err := tailnet.NewDRPCClient(c, p.logger)
1218-
if !assert.NoError(p.t, err) {
1222+
if err != nil {
12191223
_ = c.Close()
12201224
return tailnet.ControlProtocolClients{}, err
12211225
}
12221226
coord, err := client.Coordinate(context.Background())
1223-
if !assert.NoError(p.t, err) {
1227+
if err != nil {
12241228
_ = c.Close()
12251229
return tailnet.ControlProtocolClients{}, err
12261230
}
12271231

12281232
derps := &tailnet.DERPFromDRPCWrapper{}
12291233
derps.Client, err = client.StreamDERPMaps(context.Background(), &proto.StreamDERPMapsRequest{})
1230-
if !assert.NoError(p.t, err) {
1234+
if err != nil {
12311235
_ = c.Close()
12321236
return tailnet.ControlProtocolClients{}, err
12331237
}

0 commit comments

Comments
 (0)