
Commit a0962ba

fix: wait for PGCoordinator to clean up db state (#13351)
cf. #13192 (comment). We need to wait for PGCoordinator to finish its work before returning from `Close()`, so that we delete database state (best effort -- if this fails, other coordinators will filter it out based on heartbeats).
1 parent e5bb0a7 commit a0962ba
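
To make the fix concrete, here is a minimal sketch of the pattern the commit relies on; the names (fakeCoordinator, deleteState) are illustrative only, not the real pgcoord API. Close() cancels the coordinator's context and then blocks on a sync.WaitGroup until the goroutine responsible for deleting database state has finished, instead of returning immediately.

package main

import (
	"context"
	"fmt"
	"sync"
)

type fakeCoordinator struct {
	cancel context.CancelFunc
	wg     sync.WaitGroup
}

func newFakeCoordinator() *fakeCoordinator {
	ctx, cancel := context.WithCancel(context.Background())
	c := &fakeCoordinator{cancel: cancel}
	c.wg.Add(1)
	go func() {
		defer c.wg.Done()
		<-ctx.Done()    // Close (or a parent cancellation) was requested
		c.deleteState() // best-effort cleanup of database state
	}()
	return c
}

func (c *fakeCoordinator) deleteState() {
	fmt.Println("deleting coordinator row from the database")
}

// Close now blocks until the cleanup goroutine has finished, so the
// database row is (best effort) gone by the time it returns.
func (c *fakeCoordinator) Close() error {
	c.cancel()
	c.wg.Wait()
	return nil
}

func main() {
	c := newFakeCoordinator()
	_ = c.Close()
	fmt.Println("Close returned only after cleanup completed")
}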

3 files changed: +69 -2

enterprise/tailnet/pgcoord.go

+21 -2
@@ -161,11 +161,12 @@ func newPGCoordInternal(
 		closed: make(chan struct{}),
 	}
 	go func() {
-		// when the main context is canceled, or the coordinator closed, the binder and tunneler
-		// always eventually stop. Once they stop it's safe to cancel the querier context, which
+		// when the main context is canceled, or the coordinator closed, the binder, tunneler, and
+		// handshaker always eventually stop. Once they stop it's safe to cancel the querier context, which
 		// has the effect of deleting the coordinator from the database and ceasing heartbeats.
 		c.binder.workerWG.Wait()
 		c.tunneler.workerWG.Wait()
+		c.handshaker.workerWG.Wait()
 		querierCancel()
 	}()
 	logger.Info(ctx, "starting coordinator")
@@ -231,6 +232,7 @@ func (c *pgCoord) Close() error {
 	c.logger.Info(c.ctx, "closing coordinator")
 	c.cancel()
 	c.closeOnce.Do(func() { close(c.closed) })
+	c.querier.wait()
 	return nil
 }
 
@@ -795,6 +797,8 @@ type querier struct {
 
 	workQ *workQ[querierWorkKey]
 
+	wg sync.WaitGroup
+
 	heartbeats *heartbeats
 	updates    <-chan hbUpdate
 
@@ -831,6 +835,7 @@ func newQuerier(ctx context.Context,
 	}
 	q.subscribe()
 
+	q.wg.Add(2 + numWorkers)
 	go func() {
 		<-firstHeartbeat
 		go q.handleIncoming()
@@ -842,7 +847,13 @@ func newQuerier(ctx context.Context,
 	return q
 }
 
+func (q *querier) wait() {
+	q.wg.Wait()
+	q.heartbeats.wg.Wait()
+}
+
 func (q *querier) handleIncoming() {
+	defer q.wg.Done()
 	for {
 		select {
 		case <-q.ctx.Done():
@@ -919,6 +930,7 @@ func (q *querier) cleanupConn(c *connIO) {
 }
 
 func (q *querier) worker() {
+	defer q.wg.Done()
 	eb := backoff.NewExponentialBackOff()
 	eb.MaxElapsedTime = 0 // retry indefinitely
 	eb.MaxInterval = dbMaxBackoff
@@ -1204,6 +1216,7 @@ func (q *querier) resyncPeerMappings() {
 }
 
 func (q *querier) handleUpdates() {
+	defer q.wg.Done()
 	for {
 		select {
 		case <-q.ctx.Done():
@@ -1451,6 +1464,8 @@ type heartbeats struct {
 	coordinators map[uuid.UUID]time.Time
 	timer        *time.Timer
 
+	wg sync.WaitGroup
+
 	// overwritten in tests, but otherwise constant
 	cleanupPeriod time.Duration
 }
@@ -1472,6 +1487,7 @@ func newHeartbeats(
 		coordinators:  make(map[uuid.UUID]time.Time),
 		cleanupPeriod: cleanupPeriod,
 	}
+	h.wg.Add(3)
 	go h.subscribe()
 	go h.sendBeats()
 	go h.cleanupLoop()
@@ -1502,6 +1518,7 @@ func (h *heartbeats) filter(mappings []mapping) []mapping {
 }
 
 func (h *heartbeats) subscribe() {
+	defer h.wg.Done()
 	eb := backoff.NewExponentialBackOff()
 	eb.MaxElapsedTime = 0 // retry indefinitely
 	eb.MaxInterval = dbMaxBackoff
@@ -1611,6 +1628,7 @@ func (h *heartbeats) checkExpiry() {
 }
 
 func (h *heartbeats) sendBeats() {
+	defer h.wg.Done()
 	// send an initial heartbeat so that other coordinators can start using our bindings right away.
 	h.sendBeat()
 	close(h.firstHeartbeat) // signal binder it can start writing
@@ -1662,6 +1680,7 @@ func (h *heartbeats) sendDelete() {
 }
 
 func (h *heartbeats) cleanupLoop() {
+	defer h.wg.Done()
 	h.cleanup()
 	tkr := time.NewTicker(h.cleanupPeriod)
 	defer tkr.Stop()
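
Taken together, the additions above chain two waits: the binder, tunneler, and handshaker workers must stop before the querier context is canceled, and querier.wait() (together with the heartbeats wait group) keeps Close() from returning until the delete has been attempted. Below is a condensed, hypothetical sketch of that staging, with invented names (coord, startCoord) rather than the real pgcoord types.

package main

import (
	"context"
	"fmt"
	"sync"
)

type coord struct {
	cancelMain context.CancelFunc
	workerWG   sync.WaitGroup // stands in for the binder/tunneler/handshaker workers
	querierWG  sync.WaitGroup // stands in for the querier + heartbeats wait groups
}

func startCoord() *coord {
	mainCtx, cancelMain := context.WithCancel(context.Background())
	querierCtx, querierCancel := context.WithCancel(context.Background())
	c := &coord{cancelMain: cancelMain}

	// Workers stop once the main context is canceled.
	c.workerWG.Add(1)
	go func() {
		defer c.workerWG.Done()
		<-mainCtx.Done()
	}()

	// Only after every worker has stopped is it safe to cancel the querier
	// context, whose teardown deletes the coordinator from the database and
	// ceases heartbeats.
	go func() {
		c.workerWG.Wait()
		querierCancel()
	}()

	c.querierWG.Add(1)
	go func() {
		defer c.querierWG.Done()
		<-querierCtx.Done()
		fmt.Println("best effort: delete coordinator row, stop heartbeats")
	}()
	return c
}

// Close returns only after the querier-side teardown has run.
func (c *coord) Close() error {
	c.cancelMain()
	c.querierWG.Wait()
	return nil
}

func main() {
	c := startCoord()
	_ = c.Close()
	fmt.Println("database state cleaned up before Close returned")
}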

enterprise/tailnet/pgcoord_internal_test.go

+1
@@ -66,6 +66,7 @@ func TestHeartbeats_Cleanup(t *testing.T) {
 		store:         mStore,
 		cleanupPeriod: time.Millisecond,
 	}
+	uut.wg.Add(1)
 	go uut.cleanupLoop()
 
 	for i := 0; i < 6; i++ {

enterprise/tailnet/pgcoord_test.go

+47
@@ -864,6 +864,53 @@ func TestPGCoordinator_Lost(t *testing.T) {
 	agpltest.LostTest(ctx, t, coordinator)
 }
 
+func TestPGCoordinator_DeleteOnClose(t *testing.T) {
+	t.Parallel()
+
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
+	defer cancel()
+	ctrl := gomock.NewController(t)
+	mStore := dbmock.NewMockStore(ctrl)
+	ps := pubsub.NewInMemory()
+	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
+
+	upsertDone := make(chan struct{})
+	deleteCalled := make(chan struct{})
+	finishDelete := make(chan struct{})
+	mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()).
+		MinTimes(1).
+		Do(func(_ context.Context, _ uuid.UUID) { close(upsertDone) }).
+		Return(database.TailnetCoordinator{}, nil)
+	mStore.EXPECT().DeleteCoordinator(gomock.Any(), gomock.Any()).
+		Times(1).
+		Do(func(_ context.Context, _ uuid.UUID) {
+			close(deleteCalled)
+			<-finishDelete
+		}).
+		Return(nil)
+
+	// extra calls we don't particularly care about for this test
+	mStore.EXPECT().CleanTailnetCoordinators(gomock.Any()).AnyTimes().Return(nil)
+	mStore.EXPECT().CleanTailnetLostPeers(gomock.Any()).AnyTimes().Return(nil)
+	mStore.EXPECT().CleanTailnetTunnels(gomock.Any()).AnyTimes().Return(nil)
+
+	uut, err := tailnet.NewPGCoord(ctx, logger, ps, mStore)
+	require.NoError(t, err)
+	testutil.RequireRecvCtx(ctx, t, upsertDone)
+	closeErr := make(chan error, 1)
+	go func() {
+		closeErr <- uut.Close()
+	}()
+	select {
+	case <-closeErr:
+		t.Fatal("close returned before DeleteCoordinator called")
+	case <-deleteCalled:
+		close(finishDelete)
+		err := testutil.RequireRecvCtx(ctx, t, closeErr)
+		require.NoError(t, err)
+	}
+}
+
 type testConn struct {
 	ws, serverWS net.Conn
 	nodeChan     chan []*agpl.Node