From 7ea3be50588e9165ef2e7621e678da1ce7c720a0 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Thu, 17 Aug 2023 07:13:58 +0000 Subject: [PATCH] fix: fix race in PGCoord at startup Signed-off-by: Spike Curtis --- enterprise/tailnet/pgcoord.go | 101 ++++++++++++++++++----------- enterprise/tailnet/pgcoord_test.go | 38 +++++------ 2 files changed, 82 insertions(+), 57 deletions(-) diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 5d1e09d441243..c12494856c0d5 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -403,11 +403,15 @@ func (b *binder) writeOne(bnd binding) error { CoordinatorID: b.coordinatorID, Node: nodeRaw, }) + b.logger.Debug(b.ctx, "upserted agent binding", + slog.F("agent_id", bnd.agent), slog.F("node", nodeRaw), slog.Error(err)) case bnd.isAgent() && len(nodeRaw) == 0: _, err = b.store.DeleteTailnetAgent(b.ctx, database.DeleteTailnetAgentParams{ ID: bnd.agent, CoordinatorID: b.coordinatorID, }) + b.logger.Debug(b.ctx, "deleted agent binding", + slog.F("agent_id", bnd.agent), slog.Error(err)) if xerrors.Is(err, sql.ErrNoRows) { // treat deletes as idempotent err = nil @@ -419,11 +423,16 @@ func (b *binder) writeOne(bnd binding) error { AgentID: bnd.agent, Node: nodeRaw, }) + b.logger.Debug(b.ctx, "upserted client binding", + slog.F("agent_id", bnd.agent), slog.F("client_id", bnd.client), + slog.F("node", nodeRaw), slog.Error(err)) case bnd.isClient() && len(nodeRaw) == 0: _, err = b.store.DeleteTailnetClient(b.ctx, database.DeleteTailnetClientParams{ ID: bnd.client, CoordinatorID: b.coordinatorID, }) + b.logger.Debug(b.ctx, "deleted client binding", + slog.F("agent_id", bnd.agent), slog.F("client_id", bnd.client), slog.Error(err)) if xerrors.Is(err, sql.ErrNoRows) { // treat deletes as idempotent err = nil @@ -620,7 +629,7 @@ func newQuerier( updates: updates, healthy: true, // assume we start healthy } - go q.subscribe() + q.subscribe() go q.handleConnIO() for i := 0; i < numWorkers; i++ { go q.worker() @@ -748,6 +757,8 @@ func (q *querier) query(mk mKey) error { func (q *querier) queryClientsOfAgent(agent uuid.UUID) ([]mapping, error) { clients, err := q.store.GetTailnetClientsForAgent(q.ctx, agent) + q.logger.Debug(q.ctx, "queried clients of agent", + slog.F("agent_id", agent), slog.F("num_clients", len(clients)), slog.Error(err)) if err != nil { return nil, err } @@ -772,6 +783,8 @@ func (q *querier) queryClientsOfAgent(agent uuid.UUID) ([]mapping, error) { func (q *querier) queryAgent(agentID uuid.UUID) ([]mapping, error) { agents, err := q.store.GetTailnetAgents(q.ctx, agentID) + q.logger.Debug(q.ctx, "queried agents", + slog.F("agent_id", agentID), slog.F("num_agents", len(agents)), slog.Error(err)) if err != nil { return nil, err } @@ -793,50 +806,62 @@ func (q *querier) queryAgent(agentID uuid.UUID) ([]mapping, error) { return mappings, nil } +// subscribe starts our subscriptions to client and agent updates in a new goroutine, and returns once we are subscribed +// or the querier context is canceled. func (q *querier) subscribe() { - eb := backoff.NewExponentialBackOff() - eb.MaxElapsedTime = 0 // retry indefinitely - eb.MaxInterval = dbMaxBackoff - bkoff := backoff.WithContext(eb, q.ctx) - var cancelClient context.CancelFunc - err := backoff.Retry(func() error { - cancelFn, err := q.pubsub.SubscribeWithErr(eventClientUpdate, q.listenClient) + subscribed := make(chan struct{}) + go func() { + defer close(subscribed) + eb := backoff.NewExponentialBackOff() + eb.MaxElapsedTime = 0 // retry indefinitely + eb.MaxInterval = dbMaxBackoff + bkoff := backoff.WithContext(eb, q.ctx) + var cancelClient context.CancelFunc + err := backoff.Retry(func() error { + cancelFn, err := q.pubsub.SubscribeWithErr(eventClientUpdate, q.listenClient) + if err != nil { + q.logger.Warn(q.ctx, "failed to subscribe to client updates", slog.Error(err)) + return err + } + cancelClient = cancelFn + return nil + }, bkoff) if err != nil { - q.logger.Warn(q.ctx, "failed to subscribe to client updates", slog.Error(err)) - return err - } - cancelClient = cancelFn - return nil - }, bkoff) - if err != nil { - if q.ctx.Err() == nil { - q.logger.Error(q.ctx, "code bug: retry failed before context canceled", slog.Error(err)) + if q.ctx.Err() == nil { + q.logger.Error(q.ctx, "code bug: retry failed before context canceled", slog.Error(err)) + } + return } - return - } - defer cancelClient() - bkoff.Reset() + defer cancelClient() + bkoff.Reset() + q.logger.Debug(q.ctx, "subscribed to client updates") - var cancelAgent context.CancelFunc - err = backoff.Retry(func() error { - cancelFn, err := q.pubsub.SubscribeWithErr(eventAgentUpdate, q.listenAgent) + var cancelAgent context.CancelFunc + err = backoff.Retry(func() error { + cancelFn, err := q.pubsub.SubscribeWithErr(eventAgentUpdate, q.listenAgent) + if err != nil { + q.logger.Warn(q.ctx, "failed to subscribe to agent updates", slog.Error(err)) + return err + } + cancelAgent = cancelFn + return nil + }, bkoff) if err != nil { - q.logger.Warn(q.ctx, "failed to subscribe to agent updates", slog.Error(err)) - return err - } - cancelAgent = cancelFn - return nil - }, bkoff) - if err != nil { - if q.ctx.Err() == nil { - q.logger.Error(q.ctx, "code bug: retry failed before context canceled", slog.Error(err)) + if q.ctx.Err() == nil { + q.logger.Error(q.ctx, "code bug: retry failed before context canceled", slog.Error(err)) + } + return } - return - } - defer cancelAgent() + defer cancelAgent() + q.logger.Debug(q.ctx, "subscribed to agent updates") - // hold subscriptions open until context is canceled - <-q.ctx.Done() + // unblock the outer function from returning + subscribed <- struct{}{} + + // hold subscriptions open until context is canceled + <-q.ctx.Done() + }() + <-subscribed } func (q *querier) listenClient(_ context.Context, msg []byte, err error) { diff --git a/enterprise/tailnet/pgcoord_test.go b/enterprise/tailnet/pgcoord_test.go index 200945371099e..ec598c6ba8eb1 100644 --- a/enterprise/tailnet/pgcoord_test.go +++ b/enterprise/tailnet/pgcoord_test.go @@ -86,7 +86,7 @@ func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) { require.NoError(t, err) defer coordinator.Close() - agent := newTestAgent(t, coordinator) + agent := newTestAgent(t, coordinator, "agent") defer agent.close() agent.sendNode(&agpl.Node{PreferredDERP: 10}) require.Eventually(t, func() bool { @@ -123,7 +123,7 @@ func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) { require.NoError(t, err) defer coordinator.Close() - agent := newTestAgent(t, coordinator) + agent := newTestAgent(t, coordinator, "original") defer agent.close() agent.sendNode(&agpl.Node{PreferredDERP: 10}) @@ -151,7 +151,7 @@ func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) { agent.waitForClose(ctx, t) // Create a new agent connection. This is to simulate a reconnect! - agent = newTestAgent(t, coordinator, agent.id) + agent = newTestAgent(t, coordinator, "reconnection", agent.id) // Ensure the existing listening connIO sends its node immediately! clientNodes = agent.recvNodes(ctx, t) require.Len(t, clientNodes, 1) @@ -200,7 +200,7 @@ func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) { require.NoError(t, err) defer coordinator.Close() - agent := newTestAgent(t, coordinator) + agent := newTestAgent(t, coordinator, "agent") defer agent.close() agent.sendNode(&agpl.Node{PreferredDERP: 10}) @@ -333,16 +333,16 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) defer cancel() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - coord1, err := tailnet.NewPGCoord(ctx, logger, ps, store) + coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store) require.NoError(t, err) defer coord1.Close() - coord2, err := tailnet.NewPGCoord(ctx, logger, ps, store) + coord2, err := tailnet.NewPGCoord(ctx, logger.Named("coord2"), ps, store) require.NoError(t, err) defer coord2.Close() - agent1 := newTestAgent(t, coord1) + agent1 := newTestAgent(t, coord1, "agent1") defer agent1.close() - agent2 := newTestAgent(t, coord2) + agent2 := newTestAgent(t, coord2, "agent2") defer agent2.close() client11 := newTestClient(t, coord1, agent1.id) @@ -460,19 +460,19 @@ func TestPGCoordinator_MultiAgent(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) defer cancel() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - coord1, err := tailnet.NewPGCoord(ctx, logger, ps, store) + coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store) require.NoError(t, err) defer coord1.Close() - coord2, err := tailnet.NewPGCoord(ctx, logger, ps, store) + coord2, err := tailnet.NewPGCoord(ctx, logger.Named("coord2"), ps, store) require.NoError(t, err) defer coord2.Close() - coord3, err := tailnet.NewPGCoord(ctx, logger, ps, store) + coord3, err := tailnet.NewPGCoord(ctx, logger.Named("coord3"), ps, store) require.NoError(t, err) defer coord3.Close() - agent1 := newTestAgent(t, coord1) + agent1 := newTestAgent(t, coord1, "agent1") defer agent1.close() - agent2 := newTestAgent(t, coord2, agent1.id) + agent2 := newTestAgent(t, coord2, "agent2", agent1.id) defer agent2.close() client := newTestClient(t, coord3, agent1.id) @@ -552,7 +552,7 @@ func TestPGCoordinator_Unhealthy(t *testing.T) { err := uut.Close() require.NoError(t, err) }() - agent1 := newTestAgent(t, uut) + agent1 := newTestAgent(t, uut, "agent1") defer agent1.close() for i := 0; i < 3; i++ { select { @@ -566,7 +566,7 @@ func TestPGCoordinator_Unhealthy(t *testing.T) { agent1.waitForClose(ctx, t) // new agent should immediately disconnect - agent2 := newTestAgent(t, uut) + agent2 := newTestAgent(t, uut, "agent2") defer agent2.close() agent2.waitForClose(ctx, t) @@ -579,7 +579,7 @@ func TestPGCoordinator_Unhealthy(t *testing.T) { // OK } } - agent3 := newTestAgent(t, uut) + agent3 := newTestAgent(t, uut, "agent3") defer agent3.close() select { case <-agent3.closeChan: @@ -618,10 +618,10 @@ func newTestConn(ids []uuid.UUID) *testConn { return a } -func newTestAgent(t *testing.T, coord agpl.Coordinator, id ...uuid.UUID) *testConn { +func newTestAgent(t *testing.T, coord agpl.Coordinator, name string, id ...uuid.UUID) *testConn { a := newTestConn(id) go func() { - err := coord.ServeAgent(a.serverWS, a.id, "") + err := coord.ServeAgent(a.serverWS, a.id, name) assert.NoError(t, err) close(a.closeChan) }() @@ -636,7 +636,7 @@ func (c *testConn) recvNodes(ctx context.Context, t *testing.T) []*agpl.Node { t.Helper() select { case <-ctx.Done(): - t.Fatal("timeout receiving nodes") + t.Fatalf("testConn id %s: timeout receiving nodes ", c.id) return nil case nodes := <-c.nodeChan: return nodes