
fix: avoid deleting peers on graceful close #14165

Merged
merged 14 commits on Aug 14, 2024
36 changes: 21 additions & 15 deletions enterprise/tailnet/pgcoord.go
@@ -26,18 +26,19 @@ import (
)

const (
EventHeartbeats = "tailnet_coordinator_heartbeat"
eventPeerUpdate = "tailnet_peer_update"
eventTunnelUpdate = "tailnet_tunnel_update"
eventReadyForHandshake = "tailnet_ready_for_handshake"
HeartbeatPeriod = time.Second * 2
MissedHeartbeats = 3
numQuerierWorkers = 10
numBinderWorkers = 10
numTunnelerWorkers = 10
numHandshakerWorkers = 5
dbMaxBackoff = 10 * time.Second
cleanupPeriod = time.Hour
EventHeartbeats = "tailnet_coordinator_heartbeat"
eventPeerUpdate = "tailnet_peer_update"
eventTunnelUpdate = "tailnet_tunnel_update"
eventReadyForHandshake = "tailnet_ready_for_handshake"
HeartbeatPeriod = time.Second * 2
MissedHeartbeats = 3
unhealthyHeartbeatThreshold = 3
numQuerierWorkers = 10
numBinderWorkers = 10
numTunnelerWorkers = 10
numHandshakerWorkers = 5
dbMaxBackoff = 10 * time.Second
cleanupPeriod = time.Hour
)

// pgCoord is a postgres-backed coordinator
@@ -1646,13 +1647,18 @@ func (h *heartbeats) sendBeats() {
// send an initial heartbeat so that other coordinators can start using our bindings right away.
h.sendBeat()
close(h.firstHeartbeat) // signal binder it can start writing
defer h.sendDelete()
tkr := h.clock.TickerFunc(h.ctx, HeartbeatPeriod, func() error {
h.sendBeat()
return nil
}, "heartbeats", "sendBeats")
err := tkr.Wait()
h.logger.Debug(h.ctx, "ending heartbeats", slog.Error(err))
// This is unlikely to succeed if we're unhealthy but
// we give it our best effort.
if h.failedHeartbeats >= unhealthyHeartbeatThreshold {
h.logger.Debug(h.ctx, "coordinator detected unhealthy, deleting self", slog.Error(err))
h.sendDelete()
}
}

func (h *heartbeats) sendBeat() {
@@ -1663,14 +1669,14 @@
if err != nil {
h.logger.Error(h.ctx, "failed to send heartbeat", slog.Error(err))
h.failedHeartbeats++
if h.failedHeartbeats == 3 {
if h.failedHeartbeats == unhealthyHeartbeatThreshold {
h.logger.Error(h.ctx, "coordinator failed 3 heartbeats and is unhealthy")
_ = agpl.SendCtx(h.ctx, h.update, hbUpdate{health: healthUpdateUnhealthy})
}
return
}
h.logger.Debug(h.ctx, "sent heartbeat")
if h.failedHeartbeats >= 3 {
if h.failedHeartbeats >= unhealthyHeartbeatThreshold {
h.logger.Info(h.ctx, "coordinator sent heartbeat and is healthy")
_ = agpl.SendCtx(h.ctx, h.update, hbUpdate{health: healthUpdateHealthy})
}
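Taken together, the hunks above change the shutdown path of the heartbeats goroutine: the old code always ran `sendDelete` via a `defer` when `sendBeats` returned, which removed the coordinator row even on a graceful `Close` and is what led to peers being deleted on graceful close (the behavior the PR title calls out). The new code deletes the row only when at least `unhealthyHeartbeatThreshold` consecutive heartbeats have already failed. Below is a minimal, self-contained sketch of that pattern; the names and constants mirror `pgcoord.go`, but the ticker, prints, and function-valued database calls are illustrative stand-ins, not the real implementation.

```go
package main

import (
	"context"
	"fmt"
	"time"
)

const (
	heartbeatPeriod             = 500 * time.Millisecond
	unhealthyHeartbeatThreshold = 3 // mirrors the new constant in pgcoord.go
)

// heartbeats is a stand-in for the pgcoord heartbeats goroutine. It counts
// consecutive failed beats and, on exit, deletes its own coordinator row only
// if it has already crossed the unhealthy threshold.
type heartbeats struct {
	failedHeartbeats int
	writeBeat        func(context.Context) error // stand-in for UpsertTailnetCoordinator
	deleteSelf       func(context.Context) error // stand-in for DeleteCoordinator
}

func (h *heartbeats) sendBeat(ctx context.Context) {
	if err := h.writeBeat(ctx); err != nil {
		h.failedHeartbeats++
		if h.failedHeartbeats == unhealthyHeartbeatThreshold {
			fmt.Println("unhealthy: failed", h.failedHeartbeats, "heartbeats") // would publish healthUpdateUnhealthy
		}
		return
	}
	if h.failedHeartbeats >= unhealthyHeartbeatThreshold {
		fmt.Println("healthy again") // would publish healthUpdateHealthy
	}
	h.failedHeartbeats = 0
}

// sendBeats loops until the context is canceled. The old behavior was an
// unconditional `defer h.sendDelete()`; now the delete is best-effort and
// only happens if the coordinator is already considered unhealthy.
func (h *heartbeats) sendBeats(ctx context.Context) {
	h.sendBeat(ctx) // initial beat so other coordinators can use our bindings right away
	ticker := time.NewTicker(heartbeatPeriod)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			if h.failedHeartbeats >= unhealthyHeartbeatThreshold {
				_ = h.deleteSelf(context.Background()) // best effort; unlikely to succeed if the DB is unreachable
			}
			return
		case <-ticker.C:
			h.sendBeat(ctx)
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	h := &heartbeats{
		writeBeat:  func(context.Context) error { return nil }, // every beat succeeds in this run
		deleteSelf: func(context.Context) error { fmt.Println("deleting coordinator row"); return nil },
	}
	h.sendBeats(ctx)
	fmt.Println("graceful close: coordinator row and peers left in place")
}
```

In the real coordinator the beats are driven by `h.clock.TickerFunc` and the health transitions are sent on the update channel, as the diff shows; the sketch collapses those into prints so it stays runnable on its own.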
98 changes: 38 additions & 60 deletions enterprise/tailnet/pgcoord_test.go
@@ -480,7 +480,7 @@ func TestPGCoordinatorSingle_SendsHeartbeats(t *testing.T) {

mu := sync.Mutex{}
heartbeats := []time.Time{}
unsub, err := ps.SubscribeWithErr(tailnet.EventHeartbeats, func(_ context.Context, msg []byte, err error) {
unsub, err := ps.SubscribeWithErr(tailnet.EventHeartbeats, func(_ context.Context, _ []byte, err error) {
assert.NoError(t, err)
mu.Lock()
defer mu.Unlock()
@@ -592,8 +592,6 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) {
err = client21.recvErr(ctx, t)
require.ErrorIs(t, err, io.EOF)

assertEventuallyNoAgents(ctx, t, store, agent2.id)

t.Logf("close coord1")
err = coord1.Close()
require.NoError(t, err)
@@ -629,10 +627,6 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) {
err = client22.close()
require.NoError(t, err)
client22.waitForClose(ctx, t)

assertEventuallyNoAgents(ctx, t, store, agent1.id)
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
assertEventuallyNoClientsForAgent(ctx, t, store, agent2.id)
}

// TestPGCoordinator_MultiCoordinatorAgent tests when a single agent connects to multiple coordinators.
@@ -746,7 +740,6 @@ func TestPGCoordinator_Unhealthy(t *testing.T) {
mStore.EXPECT().DeleteTailnetPeer(gomock.Any(), gomock.Any()).
AnyTimes().Return(database.DeleteTailnetPeerRow{}, nil)
mStore.EXPECT().DeleteAllTailnetTunnels(gomock.Any(), gomock.Any()).AnyTimes().Return(nil)
mStore.EXPECT().DeleteCoordinator(gomock.Any(), gomock.Any()).AnyTimes().Return(nil)

uut, err := tailnet.NewPGCoord(ctx, logger, ps, mStore)
require.NoError(t, err)
@@ -871,51 +864,50 @@ func TestPGCoordinator_Lost(t *testing.T) {
agpltest.LostTest(ctx, t, coordinator)
}

func TestPGCoordinator_DeleteOnClose(t *testing.T) {
func TestPGCoordinator_NoDeleteOnClose(t *testing.T) {
t.Parallel()

if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
ctrl := gomock.NewController(t)
mStore := dbmock.NewMockStore(ctrl)
ps := pubsub.NewInMemory()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator.Close()

upsertDone := make(chan struct{})
deleteCalled := make(chan struct{})
finishDelete := make(chan struct{})
mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()).
MinTimes(1).
Do(func(_ context.Context, _ uuid.UUID) { close(upsertDone) }).
Return(database.TailnetCoordinator{}, nil)
mStore.EXPECT().DeleteCoordinator(gomock.Any(), gomock.Any()).
Times(1).
Do(func(_ context.Context, _ uuid.UUID) {
close(deleteCalled)
<-finishDelete
}).
Return(nil)
agent := newTestAgent(t, coordinator, "original")
defer agent.close()
agent.sendNode(&agpl.Node{PreferredDERP: 10})

// extra calls we don't particularly care about for this test
mStore.EXPECT().CleanTailnetCoordinators(gomock.Any()).AnyTimes().Return(nil)
mStore.EXPECT().CleanTailnetLostPeers(gomock.Any()).AnyTimes().Return(nil)
mStore.EXPECT().CleanTailnetTunnels(gomock.Any()).AnyTimes().Return(nil)
client := newTestClient(t, coordinator, agent.id)
defer client.close()

uut, err := tailnet.NewPGCoord(ctx, logger, ps, mStore)
// Simulate some traffic to generate
// a peer.
agentNodes := client.recvNodes(ctx, t)
require.Len(t, agentNodes, 1)
assert.Equal(t, 10, agentNodes[0].PreferredDERP)
client.sendNode(&agpl.Node{PreferredDERP: 11})
clientNodes := agent.recvNodes(ctx, t)
require.Len(t, clientNodes, 1)
assert.Equal(t, 11, clientNodes[0].PreferredDERP)

anode := coordinator.Node(agent.id)
require.NotNil(t, anode)
cnode := coordinator.Node(client.id)
require.NotNil(t, cnode)

err = coordinator.Close()
require.NoError(t, err)
testutil.RequireRecvCtx(ctx, t, upsertDone)
closeErr := make(chan error, 1)
go func() {
closeErr <- uut.Close()
}()
select {
case <-closeErr:
t.Fatal("close returned before DeleteCoordinator called")
case <-deleteCalled:
close(finishDelete)
err := testutil.RequireRecvCtx(ctx, t, closeErr)
require.NoError(t, err)
}

coordinator2, err := tailnet.NewPGCoord(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator2.Close()

anode = coordinator2.Node(agent.id)
require.NotNil(t, anode)
}

type testConn struct {
@@ -1056,20 +1048,6 @@ func assertNeverHasDERPs(ctx context.Context, t *testing.T, c *testConn, expecte
}
}

func assertEventuallyNoAgents(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
t.Helper()
assert.Eventually(t, func() bool {
agents, err := store.GetTailnetPeers(ctx, agentID)
if xerrors.Is(err, sql.ErrNoRows) {
return true
}
if err != nil {
t.Fatal(err)
}
return len(agents) == 0
}, testutil.WaitShort, testutil.IntervalFast)
}

func assertEventuallyLost(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
t.Helper()
assert.Eventually(t, func() bool {
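The diff also removes the `assertEventuallyNoAgents` helper, since a graceful close no longer deletes peers and nothing asserts their absence any more. If a test ever needed the inverse check at the database level — that peer rows are still present after a graceful close — a hypothetical helper could mirror the deleted one, as sketched below. It is not part of this PR, and it assumes the same imports (`context`, `testing`, `database`, `uuid`, `assert`, `testutil`) as `pgcoord_test.go`.

```go
// assertEventuallyHasAgents is a hypothetical inverse of the removed
// assertEventuallyNoAgents helper: it polls the store until at least one
// peer row exists for the given agent ID. Illustrative only; not part of
// this change.
func assertEventuallyHasAgents(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
	t.Helper()
	assert.Eventually(t, func() bool {
		agents, err := store.GetTailnetPeers(ctx, agentID)
		if err != nil {
			// Treat sql.ErrNoRows and transient errors as "not yet".
			return false
		}
		return len(agents) > 0
	}, testutil.WaitShort, testutil.IntervalFast)
}
```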