From 8cfb9f0c55021f45c652603cc86dda2a1cd43f10 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Thu, 23 May 2024 13:30:57 +0400 Subject: [PATCH] fix: wait for PGCoordinator to clean up db state --- enterprise/tailnet/pgcoord.go | 23 +++++++++- enterprise/tailnet/pgcoord_internal_test.go | 1 + enterprise/tailnet/pgcoord_test.go | 47 +++++++++++++++++++++ 3 files changed, 69 insertions(+), 2 deletions(-) diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index baccfe66a7fd7..857cdafe94e79 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -161,11 +161,12 @@ func newPGCoordInternal( closed: make(chan struct{}), } go func() { - // when the main context is canceled, or the coordinator closed, the binder and tunneler - // always eventually stop. Once they stop it's safe to cancel the querier context, which + // when the main context is canceled, or the coordinator closed, the binder, tunneler, and + // handshaker always eventually stop. Once they stop it's safe to cancel the querier context, which // has the effect of deleting the coordinator from the database and ceasing heartbeats. c.binder.workerWG.Wait() c.tunneler.workerWG.Wait() + c.handshaker.workerWG.Wait() querierCancel() }() logger.Info(ctx, "starting coordinator") @@ -231,6 +232,7 @@ func (c *pgCoord) Close() error { c.logger.Info(c.ctx, "closing coordinator") c.cancel() c.closeOnce.Do(func() { close(c.closed) }) + c.querier.wait() return nil } @@ -795,6 +797,8 @@ type querier struct { workQ *workQ[querierWorkKey] + wg sync.WaitGroup + heartbeats *heartbeats updates <-chan hbUpdate @@ -831,6 +835,7 @@ func newQuerier(ctx context.Context, } q.subscribe() + q.wg.Add(2 + numWorkers) go func() { <-firstHeartbeat go q.handleIncoming() @@ -842,7 +847,13 @@ func newQuerier(ctx context.Context, return q } +func (q *querier) wait() { + q.wg.Wait() + q.heartbeats.wg.Wait() +} + func (q *querier) handleIncoming() { + defer q.wg.Done() for { select { case <-q.ctx.Done(): @@ -919,6 +930,7 @@ func (q *querier) cleanupConn(c *connIO) { } func (q *querier) worker() { + defer q.wg.Done() eb := backoff.NewExponentialBackOff() eb.MaxElapsedTime = 0 // retry indefinitely eb.MaxInterval = dbMaxBackoff @@ -1204,6 +1216,7 @@ func (q *querier) resyncPeerMappings() { } func (q *querier) handleUpdates() { + defer q.wg.Done() for { select { case <-q.ctx.Done(): @@ -1451,6 +1464,8 @@ type heartbeats struct { coordinators map[uuid.UUID]time.Time timer *time.Timer + wg sync.WaitGroup + // overwritten in tests, but otherwise constant cleanupPeriod time.Duration } @@ -1472,6 +1487,7 @@ func newHeartbeats( coordinators: make(map[uuid.UUID]time.Time), cleanupPeriod: cleanupPeriod, } + h.wg.Add(3) go h.subscribe() go h.sendBeats() go h.cleanupLoop() @@ -1502,6 +1518,7 @@ func (h *heartbeats) filter(mappings []mapping) []mapping { } func (h *heartbeats) subscribe() { + defer h.wg.Done() eb := backoff.NewExponentialBackOff() eb.MaxElapsedTime = 0 // retry indefinitely eb.MaxInterval = dbMaxBackoff @@ -1611,6 +1628,7 @@ func (h *heartbeats) checkExpiry() { } func (h *heartbeats) sendBeats() { + defer h.wg.Done() // send an initial heartbeat so that other coordinators can start using our bindings right away. 
h.sendBeat() close(h.firstHeartbeat) // signal binder it can start writing @@ -1662,6 +1680,7 @@ func (h *heartbeats) sendDelete() { } func (h *heartbeats) cleanupLoop() { + defer h.wg.Done() h.cleanup() tkr := time.NewTicker(h.cleanupPeriod) defer tkr.Stop() diff --git a/enterprise/tailnet/pgcoord_internal_test.go b/enterprise/tailnet/pgcoord_internal_test.go index 53fd61d73f066..4607e6fb2ab2f 100644 --- a/enterprise/tailnet/pgcoord_internal_test.go +++ b/enterprise/tailnet/pgcoord_internal_test.go @@ -66,6 +66,7 @@ func TestHeartbeats_Cleanup(t *testing.T) { store: mStore, cleanupPeriod: time.Millisecond, } + uut.wg.Add(1) go uut.cleanupLoop() for i := 0; i < 6; i++ { diff --git a/enterprise/tailnet/pgcoord_test.go b/enterprise/tailnet/pgcoord_test.go index 5bd722533dc39..9c363ee700570 100644 --- a/enterprise/tailnet/pgcoord_test.go +++ b/enterprise/tailnet/pgcoord_test.go @@ -864,6 +864,53 @@ func TestPGCoordinator_Lost(t *testing.T) { agpltest.LostTest(ctx, t, coordinator) } +func TestPGCoordinator_DeleteOnClose(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + defer cancel() + ctrl := gomock.NewController(t) + mStore := dbmock.NewMockStore(ctrl) + ps := pubsub.NewInMemory() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + + upsertDone := make(chan struct{}) + deleteCalled := make(chan struct{}) + finishDelete := make(chan struct{}) + mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()). + MinTimes(1). + Do(func(_ context.Context, _ uuid.UUID) { close(upsertDone) }). + Return(database.TailnetCoordinator{}, nil) + mStore.EXPECT().DeleteCoordinator(gomock.Any(), gomock.Any()). + Times(1). + Do(func(_ context.Context, _ uuid.UUID) { + close(deleteCalled) + <-finishDelete + }). + Return(nil) + + // extra calls we don't particularly care about for this test + mStore.EXPECT().CleanTailnetCoordinators(gomock.Any()).AnyTimes().Return(nil) + mStore.EXPECT().CleanTailnetLostPeers(gomock.Any()).AnyTimes().Return(nil) + mStore.EXPECT().CleanTailnetTunnels(gomock.Any()).AnyTimes().Return(nil) + + uut, err := tailnet.NewPGCoord(ctx, logger, ps, mStore) + require.NoError(t, err) + testutil.RequireRecvCtx(ctx, t, upsertDone) + closeErr := make(chan error, 1) + go func() { + closeErr <- uut.Close() + }() + select { + case <-closeErr: + t.Fatal("close returned before DeleteCoordinator called") + case <-deleteCalled: + close(finishDelete) + err := testutil.RequireRecvCtx(ctx, t, closeErr) + require.NoError(t, err) + } +} + type testConn struct { ws, serverWS net.Conn nodeChan chan []*agpl.Node
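
The heart of this change is the shutdown ordering: every background goroutine that may still touch the database now registers with a `sync.WaitGroup`, and `pgCoord.Close` calls `querier.wait()` so it only returns after those goroutines, and therefore the `DeleteCoordinator` call issued on cancellation, have finished. The sketch below illustrates that pattern in isolation and under assumptions: `fakeCoord`, `cleanupDB`, and the worker names are hypothetical, and the real coordinator's cancellation chain (binder/tunneler/handshaker, then querier, then heartbeats) is collapsed into a single context for brevity.

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// fakeCoord is a hypothetical stand-in for pgCoord: it runs background
// workers that must finish their database cleanup before Close returns.
type fakeCoord struct {
	ctx    context.Context
	cancel context.CancelFunc
	wg     sync.WaitGroup
}

func newFakeCoord() *fakeCoord {
	ctx, cancel := context.WithCancel(context.Background())
	c := &fakeCoord{ctx: ctx, cancel: cancel}
	c.wg.Add(2) // mirrors q.wg.Add(2+numWorkers) and h.wg.Add(3) in the patch
	go c.worker("worker-0")
	go c.worker("worker-1")
	return c
}

func (c *fakeCoord) worker(name string) {
	defer c.wg.Done() // mirrors the `defer q.wg.Done()` added to each goroutine
	<-c.ctx.Done()
	// Cleanup runs after cancellation; Close must not return before it finishes.
	c.cleanupDB(name)
}

// cleanupDB simulates the delete queries (e.g. DeleteCoordinator) that the
// real coordinator issues when its context is canceled.
func (c *fakeCoord) cleanupDB(name string) {
	time.Sleep(10 * time.Millisecond)
	fmt.Println(name, "cleaned up database state")
}

// Close cancels the workers and then waits for them, so database state is
// guaranteed to be cleaned up by the time Close returns -- the behavior the
// patch adds via q.wg.Wait() and h.wg.Wait() inside querier.wait().
func (c *fakeCoord) Close() error {
	c.cancel()
	c.wg.Wait()
	return nil
}

func main() {
	c := newFakeCoord()
	_ = c.Close()
	fmt.Println("Close returned only after cleanup finished")
}
```

In the patch itself the same idea appears three times: `q.wg.Add(2 + numWorkers)` paired with `defer q.wg.Done()` in each querier goroutine, `h.wg.Add(3)` for the heartbeats goroutines, and `querier.wait()` invoked from `Close`.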
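
The new `TestPGCoordinator_DeleteOnClose` proves that ordering by making the mocked `DeleteCoordinator` block on a channel and asserting that `Close` has not returned in the meantime. The sketch below reproduces the same channel choreography against a hypothetical `blockingCoord` stub instead of the gomock-generated store, purely to show the shape of the test; it is an assumption-laden stand-in, not the coder test itself.

```go
package main

import (
	"context"
	"sync"
	"testing"
	"time"
)

// blockingCoord is a hypothetical coordinator whose only background job is to
// delete its database row on shutdown; deleteRow stands in for the mocked
// DeleteCoordinator call in the real test.
type blockingCoord struct {
	cancel    context.CancelFunc
	wg        sync.WaitGroup
	deleteRow func()
}

func newBlockingCoord(deleteRow func()) *blockingCoord {
	ctx, cancel := context.WithCancel(context.Background())
	c := &blockingCoord{cancel: cancel, deleteRow: deleteRow}
	c.wg.Add(1)
	go func() {
		defer c.wg.Done()
		<-ctx.Done()
		c.deleteRow()
	}()
	return c
}

func (c *blockingCoord) Close() error {
	c.cancel()
	c.wg.Wait()
	return nil
}

// TestDeleteOnClose mirrors the shape of TestPGCoordinator_DeleteOnClose:
// the delete is blocked on a channel, and the test asserts that Close does
// not return until the delete has been allowed to finish.
func TestDeleteOnClose(t *testing.T) {
	t.Parallel()

	deleteCalled := make(chan struct{})
	finishDelete := make(chan struct{})
	c := newBlockingCoord(func() {
		close(deleteCalled)
		<-finishDelete
	})

	closeErr := make(chan error, 1)
	go func() { closeErr <- c.Close() }()

	select {
	case <-closeErr:
		t.Fatal("Close returned before the delete was called")
	case <-deleteCalled:
		// The delete is in flight; Close must still be blocked.
		select {
		case <-closeErr:
			t.Fatal("Close returned while the delete was still blocked")
		case <-time.After(50 * time.Millisecond):
		}
		close(finishDelete)
		if err := <-closeErr; err != nil {
			t.Fatalf("Close: %v", err)
		}
	}
}
```

The design choice is the same as in the patch: a buffered `closeErr` channel lets `Close` run in its own goroutine, and the blocking hook turns "Close waits for cleanup" into an observable ordering rather than a race.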