From f8e8ac44e98fbac1abadd2695bf43c95376100a3 Mon Sep 17 00:00:00 2001
From: Spike Curtis
Date: Wed, 16 Aug 2023 11:49:14 +0000
Subject: [PATCH] fix: make PGCoordinator close connections when unhealthy

Signed-off-by: Spike Curtis
---
 .golangci.yaml                     |   1 +
 enterprise/tailnet/pgcoord.go      | 101 +++++++++++++++++++++++---
 enterprise/tailnet/pgcoord_test.go | 109 ++++++++++++++++++++++++-----
 3 files changed, 183 insertions(+), 28 deletions(-)

diff --git a/.golangci.yaml b/.golangci.yaml
index e3f3797d06b81..156d6649890b3 100644
--- a/.golangci.yaml
+++ b/.golangci.yaml
@@ -211,6 +211,7 @@ issues:
 run:
   skip-dirs:
     - node_modules
+    - .git
   skip-files:
     - scripts/rules.go
   timeout: 10m
diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go
index 03593a238201e..5d1e09d441243 100644
--- a/enterprise/tailnet/pgcoord.go
+++ b/enterprise/tailnet/pgcoord.go
@@ -586,10 +586,12 @@ type querier struct {
 	workQ      *workQ[mKey]
 	heartbeats *heartbeats
-	updates    <-chan struct{}
+	updates    <-chan hbUpdate

 	mu      sync.Mutex
 	mappers map[mKey]*countedMapper
+	conns   map[*connIO]struct{}
+	healthy bool
 }

 type countedMapper struct {
@@ -604,7 +606,7 @@ func newQuerier(
 	self uuid.UUID, newConnections chan *connIO, numWorkers int, firstHeartbeat chan<- struct{},
 ) *querier {
-	updates := make(chan struct{})
+	updates := make(chan hbUpdate)
 	q := &querier{
 		ctx:    ctx,
 		logger: logger.Named("querier"),
@@ -614,7 +616,9 @@ func newQuerier(
 		workQ:      newWorkQ[mKey](ctx),
 		heartbeats: newHeartbeats(ctx, logger, ps, store, self, updates, firstHeartbeat),
 		mappers:    make(map[mKey]*countedMapper),
+		conns:      make(map[*connIO]struct{}),
 		updates:    updates,
+		healthy:    true, // assume we start healthy
 	}
 	go q.subscribe()
 	go q.handleConnIO()
@@ -639,6 +643,15 @@ func (q *querier) handleConnIO() {
 func (q *querier) newConn(c *connIO) {
 	q.mu.Lock()
 	defer q.mu.Unlock()
+	if !q.healthy {
+		err := c.updates.Close()
+		q.logger.Info(q.ctx, "closed incoming connection while unhealthy",
+			slog.Error(err),
+			slog.F("agent_id", c.agent),
+			slog.F("client_id", c.client),
+		)
+		return
+	}
 	mk := mKey{
 		agent: c.agent,
 		// if client is Nil, this is an agent connection, and it wants the mappings for all the clients of itself
@@ -661,6 +674,7 @@ func (q *querier) newConn(c *connIO) {
 		return
 	}
 	cm.count++
+	q.conns[c] = struct{}{}
 	go q.cleanupConn(c)
 }
@@ -668,6 +682,7 @@ func (q *querier) cleanupConn(c *connIO) {
 	<-c.ctx.Done()
 	q.mu.Lock()
 	defer q.mu.Unlock()
+	delete(q.conns, c)
 	mk := mKey{
 		agent: c.agent,
 		// if client is Nil, this is an agent connection, and it wants the mappings for all the clients of itself
@@ -911,8 +926,18 @@ func (q *querier) handleUpdates() {
 		select {
 		case <-q.ctx.Done():
 			return
-		case <-q.updates:
-			q.updateAll()
+		case u := <-q.updates:
+			if u.filter == filterUpdateUpdated {
+				q.updateAll()
+			}
+			if u.health == healthUpdateUnhealthy {
+				q.unhealthyCloseAll()
+				continue
+			}
+			if u.health == healthUpdateHealthy {
+				q.setHealthy()
+				continue
+			}
 		}
 	}
 }
@@ -932,6 +957,30 @@ func (q *querier) updateAll() {
 	}
 }

+// unhealthyCloseAll marks the coordinator unhealthy and closes all connections. We do this so that clients and agents
+// are forced to reconnect to the coordinator, and will hopefully land on a healthy coordinator.
+func (q *querier) unhealthyCloseAll() {
+	q.mu.Lock()
+	defer q.mu.Unlock()
+	q.healthy = false
+	for c := range q.conns {
+		// close connections async so that we don't block the querier routine that responds to updates
+		go func(c *connIO) {
+			err := c.updates.Close()
+			if err != nil {
+				q.logger.Debug(q.ctx, "error closing conn while unhealthy", slog.Error(err))
+			}
+		}(c)
+		// NOTE: we don't need to remove the connection from the map, as that will happen async in q.cleanupConn()
+	}
+}
+
+func (q *querier) setHealthy() {
+	q.mu.Lock()
+	defer q.mu.Unlock()
+	q.healthy = true
+}
+
 func (q *querier) getAll(ctx context.Context) (map[uuid.UUID]database.TailnetAgent, map[uuid.UUID][]database.TailnetClient, error) {
 	agents, err := q.store.GetAllTailnetAgents(ctx)
 	if err != nil {
@@ -1078,6 +1127,28 @@ func (q *workQ[K]) done(key K) {
 	q.cond.Signal()
 }

+type filterUpdate int
+
+const (
+	filterUpdateNone filterUpdate = iota
+	filterUpdateUpdated
+)
+
+type healthUpdate int
+
+const (
+	healthUpdateNone healthUpdate = iota
+	healthUpdateHealthy
+	healthUpdateUnhealthy
+)
+
+// hbUpdate is an update sent from the heartbeats to the querier. Zero values of the fields mean no update of that
+// kind.
+type hbUpdate struct {
+	filter filterUpdate
+	health healthUpdate
+}
+
 // heartbeats sends heartbeats for this coordinator on a timer, and monitors heartbeats from other coordinators. If a
 // coordinator misses their heartbeat, we remove it from our map of "valid" coordinators, such that we will filter out
 // any mappings for it when filter() is called, and we send a signal on the update channel, which triggers all mappers
@@ -1089,8 +1160,9 @@ type heartbeats struct {
 	store  database.Store
 	self   uuid.UUID

-	update         chan<- struct{}
-	firstHeartbeat chan<- struct{}
+	update           chan<- hbUpdate
+	firstHeartbeat   chan<- struct{}
+	failedHeartbeats int

 	lock         sync.RWMutex
 	coordinators map[uuid.UUID]time.Time
@@ -1103,7 +1175,7 @@ type heartbeats struct {
 func newHeartbeats(
 	ctx context.Context, logger slog.Logger,
 	ps pubsub.Pubsub, store database.Store,
-	self uuid.UUID, update chan<- struct{},
+	self uuid.UUID, update chan<- hbUpdate,
 	firstHeartbeat chan<- struct{},
 ) *heartbeats {
 	h := &heartbeats{
@@ -1194,7 +1266,7 @@ func (h *heartbeats) recvBeat(id uuid.UUID) {
 		h.logger.Info(h.ctx, "heartbeats (re)started", slog.F("other_coordinator_id", id))
 		// send on a separate goroutine to avoid holding lock. Triggering update can be async
 		go func() {
-			_ = sendCtx(h.ctx, h.update, struct{}{})
+			_ = sendCtx(h.ctx, h.update, hbUpdate{filter: filterUpdateUpdated})
 		}()
 	}
 	h.coordinators[id] = time.Now()
@@ -1241,7 +1313,7 @@ func (h *heartbeats) checkExpiry() {
 	if expired {
 		// send on a separate goroutine to avoid holding lock. Triggering update can be async
 		go func() {
-			_ = sendCtx(h.ctx, h.update, struct{}{})
+			_ = sendCtx(h.ctx, h.update, hbUpdate{filter: filterUpdateUpdated})
 		}()
 	}
 	// we need to reset the timer for when the next oldest coordinator will expire, if any.
@@ -1269,11 +1341,20 @@ func (h *heartbeats) sendBeats() {
 func (h *heartbeats) sendBeat() {
 	_, err := h.store.UpsertTailnetCoordinator(h.ctx, h.self)
 	if err != nil {
-		// just log errors, heartbeats are rescheduled on a timer
 		h.logger.Error(h.ctx, "failed to send heartbeat", slog.Error(err))
+		h.failedHeartbeats++
+		if h.failedHeartbeats == 3 {
+			h.logger.Error(h.ctx, "coordinator failed 3 heartbeats and is unhealthy")
+			_ = sendCtx(h.ctx, h.update, hbUpdate{health: healthUpdateUnhealthy})
+		}
 		return
 	}
 	h.logger.Debug(h.ctx, "sent heartbeat")
+	if h.failedHeartbeats >= 3 {
+		h.logger.Info(h.ctx, "coordinator sent heartbeat and is healthy")
+		_ = sendCtx(h.ctx, h.update, hbUpdate{health: healthUpdateHealthy})
+	}
+	h.failedHeartbeats = 0
 }

 func (h *heartbeats) sendDelete() {
diff --git a/enterprise/tailnet/pgcoord_test.go b/enterprise/tailnet/pgcoord_test.go
index 25d80ca854566..200945371099e 100644
--- a/enterprise/tailnet/pgcoord_test.go
+++ b/enterprise/tailnet/pgcoord_test.go
@@ -10,6 +10,7 @@ import (
 	"testing"
 	"time"

+	"github.com/golang/mock/gomock"
 	"github.com/google/uuid"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
@@ -21,7 +22,9 @@ import (
 	"cdr.dev/slog/sloggers/slogtest"

 	"github.com/coder/coder/coderd/database"
+	"github.com/coder/coder/coderd/database/dbmock"
 	"github.com/coder/coder/coderd/database/dbtestutil"
+	"github.com/coder/coder/coderd/database/pubsub"
 	"github.com/coder/coder/enterprise/tailnet"
 	agpl "github.com/coder/coder/tailnet"
 	"github.com/coder/coder/testutil"
@@ -36,11 +39,11 @@ func TestPGCoordinatorSingle_ClientWithoutAgent(t *testing.T) {
 	if !dbtestutil.WillUsePostgres() {
 		t.Skip("test only with postgres")
 	}
-	store, pubsub := dbtestutil.NewDB(t)
+	store, ps := dbtestutil.NewDB(t)
 	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
 	defer cancel()
 	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
-	coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
+	coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
 	require.NoError(t, err)
 	defer coordinator.Close()
@@ -75,11 +78,11 @@ func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) {
 	if !dbtestutil.WillUsePostgres() {
 		t.Skip("test only with postgres")
 	}
-	store, pubsub := dbtestutil.NewDB(t)
+	store, ps := dbtestutil.NewDB(t)
 	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
 	defer cancel()
 	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
-	coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
+	coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
 	require.NoError(t, err)
 	defer coordinator.Close()
@@ -112,11 +115,11 @@ func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) {
 	if !dbtestutil.WillUsePostgres() {
 		t.Skip("test only with postgres")
 	}
-	store, pubsub := dbtestutil.NewDB(t)
+	store, ps := dbtestutil.NewDB(t)
 	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
 	defer cancel()
 	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
-	coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
+	coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
 	require.NoError(t, err)
 	defer coordinator.Close()
@@ -189,11 +192,11 @@ func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) {
 	if !dbtestutil.WillUsePostgres() {
 		t.Skip("test only with postgres")
 	}
-	store, pubsub := dbtestutil.NewDB(t)
+	store, ps := dbtestutil.NewDB(t)
 	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
 	defer cancel()
 	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
-	coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
+	coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
 	require.NoError(t, err)
 	defer coordinator.Close()
@@ -276,14 +279,14 @@ func TestPGCoordinatorSingle_SendsHeartbeats(t *testing.T) {
 	if !dbtestutil.WillUsePostgres() {
 		t.Skip("test only with postgres")
 	}
-	store, pubsub := dbtestutil.NewDB(t)
+	store, ps := dbtestutil.NewDB(t)
 	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
 	defer cancel()
 	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)

 	mu := sync.Mutex{}
 	heartbeats := []time.Time{}
-	unsub, err := pubsub.SubscribeWithErr(tailnet.EventHeartbeats, func(_ context.Context, msg []byte, err error) {
+	unsub, err := ps.SubscribeWithErr(tailnet.EventHeartbeats, func(_ context.Context, msg []byte, err error) {
 		assert.NoError(t, err)
 		mu.Lock()
 		defer mu.Unlock()
@@ -293,7 +296,7 @@ func TestPGCoordinatorSingle_SendsHeartbeats(t *testing.T) {
 	defer unsub()

 	start := time.Now()
-	coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
+	coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
 	require.NoError(t, err)
 	defer coordinator.Close()
@@ -326,14 +329,14 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) {
 	if !dbtestutil.WillUsePostgres() {
 		t.Skip("test only with postgres")
 	}
-	store, pubsub := dbtestutil.NewDB(t)
+	store, ps := dbtestutil.NewDB(t)
 	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
 	defer cancel()
 	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
-	coord1, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
+	coord1, err := tailnet.NewPGCoord(ctx, logger, ps, store)
 	require.NoError(t, err)
 	defer coord1.Close()
-	coord2, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
+	coord2, err := tailnet.NewPGCoord(ctx, logger, ps, store)
 	require.NoError(t, err)
 	defer coord2.Close()
@@ -453,17 +456,17 @@ func TestPGCoordinator_MultiAgent(t *testing.T) {
 	if !dbtestutil.WillUsePostgres() {
 		t.Skip("test only with postgres")
 	}
-	store, pubsub := dbtestutil.NewDB(t)
+	store, ps := dbtestutil.NewDB(t)
 	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
 	defer cancel()
 	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
-	coord1, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
+	coord1, err := tailnet.NewPGCoord(ctx, logger, ps, store)
 	require.NoError(t, err)
 	defer coord1.Close()
-	coord2, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
+	coord2, err := tailnet.NewPGCoord(ctx, logger, ps, store)
 	require.NoError(t, err)
 	defer coord2.Close()
-	coord3, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
+	coord3, err := tailnet.NewPGCoord(ctx, logger, ps, store)
 	require.NoError(t, err)
 	defer coord3.Close()
@@ -516,6 +519,76 @@ func TestPGCoordinator_MultiAgent(t *testing.T) {
 	assertEventuallyNoAgents(ctx, t, store, agent1.id)
 }

+func TestPGCoordinator_Unhealthy(t *testing.T) {
+	t.Parallel()
+
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
+	defer cancel()
+	ctrl := gomock.NewController(t)
+	mStore := dbmock.NewMockStore(ctrl)
+	ps := pubsub.NewInMemory()
+	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
+
+	calls := make(chan struct{})
+	threeMissed := mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()).
+		Times(3).
+		Do(func(_ context.Context, _ uuid.UUID) { <-calls }).
+		Return(database.TailnetCoordinator{}, xerrors.New("test disconnect"))
+	mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()).
+		MinTimes(1).
+		After(threeMissed).
+		Do(func(_ context.Context, _ uuid.UUID) { <-calls }).
+		Return(database.TailnetCoordinator{}, nil)
+	// extra calls we don't particularly care about for this test
+	mStore.EXPECT().CleanTailnetCoordinators(gomock.Any()).AnyTimes().Return(nil)
+	mStore.EXPECT().GetTailnetClientsForAgent(gomock.Any(), gomock.Any()).AnyTimes().Return(nil, nil)
+	mStore.EXPECT().DeleteTailnetAgent(gomock.Any(), gomock.Any()).
+		AnyTimes().Return(database.DeleteTailnetAgentRow{}, nil)
+	mStore.EXPECT().DeleteCoordinator(gomock.Any(), gomock.Any()).AnyTimes().Return(nil)
+
+	uut, err := tailnet.NewPGCoord(ctx, logger, ps, mStore)
+	require.NoError(t, err)
+	defer func() {
+		err := uut.Close()
+		require.NoError(t, err)
+	}()
+
+	agent1 := newTestAgent(t, uut)
+	defer agent1.close()
+
+	for i := 0; i < 3; i++ {
+		select {
+		case <-ctx.Done():
+			t.Fatal("timeout")
+		case calls <- struct{}{}:
+			// OK
+		}
+	}
+	// connected agent should be disconnected
+	agent1.waitForClose(ctx, t)
+
+	// new agent should immediately disconnect
+	agent2 := newTestAgent(t, uut)
+	defer agent2.close()
+	agent2.waitForClose(ctx, t)
+
+	// next heartbeats succeed, so we are healthy
+	for i := 0; i < 2; i++ {
+		select {
+		case <-ctx.Done():
+			t.Fatal("timeout")
+		case calls <- struct{}{}:
+			// OK
+		}
+	}
+	agent3 := newTestAgent(t, uut)
+	defer agent3.close()
+	select {
+	case <-agent3.closeChan:
+		t.Fatal("agent conn closed after we are healthy")
+	case <-time.After(time.Second):
+		// OK
+	}
+}
+
 type testConn struct {
 	ws, serverWS net.Conn
 	nodeChan     chan []*agpl.Node