Commit 2f46f23

fix: fix race in PGCoord at startup (#9144)

Signed-off-by: Spike Curtis <spike@coder.com>
Parent: c0a7853

2 files changed: +82 -57 lines

enterprise/tailnet/pgcoord.go (+63 -38)

@@ -403,11 +403,15 @@ func (b *binder) writeOne(bnd binding) error {
             CoordinatorID: b.coordinatorID,
             Node:          nodeRaw,
         })
+        b.logger.Debug(b.ctx, "upserted agent binding",
+            slog.F("agent_id", bnd.agent), slog.F("node", nodeRaw), slog.Error(err))
     case bnd.isAgent() && len(nodeRaw) == 0:
         _, err = b.store.DeleteTailnetAgent(b.ctx, database.DeleteTailnetAgentParams{
             ID:            bnd.agent,
             CoordinatorID: b.coordinatorID,
         })
+        b.logger.Debug(b.ctx, "deleted agent binding",
+            slog.F("agent_id", bnd.agent), slog.Error(err))
         if xerrors.Is(err, sql.ErrNoRows) {
             // treat deletes as idempotent
             err = nil
@@ -419,11 +423,16 @@ func (b *binder) writeOne(bnd binding) error {
             AgentID: bnd.agent,
             Node:    nodeRaw,
         })
+        b.logger.Debug(b.ctx, "upserted client binding",
+            slog.F("agent_id", bnd.agent), slog.F("client_id", bnd.client),
+            slog.F("node", nodeRaw), slog.Error(err))
     case bnd.isClient() && len(nodeRaw) == 0:
         _, err = b.store.DeleteTailnetClient(b.ctx, database.DeleteTailnetClientParams{
             ID:            bnd.client,
             CoordinatorID: b.coordinatorID,
         })
+        b.logger.Debug(b.ctx, "deleted client binding",
+            slog.F("agent_id", bnd.agent), slog.F("client_id", bnd.client), slog.Error(err))
         if xerrors.Is(err, sql.ErrNoRows) {
             // treat deletes as idempotent
             err = nil
@@ -620,7 +629,7 @@ func newQuerier(
         updates: updates,
         healthy: true, // assume we start healthy
     }
-    go q.subscribe()
+    q.subscribe()
     go q.handleConnIO()
     for i := 0; i < numWorkers; i++ {
         go q.worker()
@@ -748,6 +757,8 @@ func (q *querier) query(mk mKey) error {
 
 func (q *querier) queryClientsOfAgent(agent uuid.UUID) ([]mapping, error) {
     clients, err := q.store.GetTailnetClientsForAgent(q.ctx, agent)
+    q.logger.Debug(q.ctx, "queried clients of agent",
+        slog.F("agent_id", agent), slog.F("num_clients", len(clients)), slog.Error(err))
     if err != nil {
         return nil, err
     }
@@ -772,6 +783,8 @@ func (q *querier) queryClientsOfAgent(agent uuid.UUID) ([]mapping, error) {
 
 func (q *querier) queryAgent(agentID uuid.UUID) ([]mapping, error) {
     agents, err := q.store.GetTailnetAgents(q.ctx, agentID)
+    q.logger.Debug(q.ctx, "queried agents",
+        slog.F("agent_id", agentID), slog.F("num_agents", len(agents)), slog.Error(err))
     if err != nil {
         return nil, err
     }
@@ -793,50 +806,62 @@ func (q *querier) queryAgent(agentID uuid.UUID) ([]mapping, error) {
     return mappings, nil
 }
 
+// subscribe starts our subscriptions to client and agent updates in a new goroutine, and returns once we are subscribed
+// or the querier context is canceled.
 func (q *querier) subscribe() {
-    eb := backoff.NewExponentialBackOff()
-    eb.MaxElapsedTime = 0 // retry indefinitely
-    eb.MaxInterval = dbMaxBackoff
-    bkoff := backoff.WithContext(eb, q.ctx)
-    var cancelClient context.CancelFunc
-    err := backoff.Retry(func() error {
-        cancelFn, err := q.pubsub.SubscribeWithErr(eventClientUpdate, q.listenClient)
+    subscribed := make(chan struct{})
+    go func() {
+        defer close(subscribed)
+        eb := backoff.NewExponentialBackOff()
+        eb.MaxElapsedTime = 0 // retry indefinitely
+        eb.MaxInterval = dbMaxBackoff
+        bkoff := backoff.WithContext(eb, q.ctx)
+        var cancelClient context.CancelFunc
+        err := backoff.Retry(func() error {
+            cancelFn, err := q.pubsub.SubscribeWithErr(eventClientUpdate, q.listenClient)
+            if err != nil {
+                q.logger.Warn(q.ctx, "failed to subscribe to client updates", slog.Error(err))
+                return err
+            }
+            cancelClient = cancelFn
+            return nil
+        }, bkoff)
         if err != nil {
-            q.logger.Warn(q.ctx, "failed to subscribe to client updates", slog.Error(err))
-            return err
-        }
-        cancelClient = cancelFn
-        return nil
-    }, bkoff)
-    if err != nil {
-        if q.ctx.Err() == nil {
-            q.logger.Error(q.ctx, "code bug: retry failed before context canceled", slog.Error(err))
+            if q.ctx.Err() == nil {
+                q.logger.Error(q.ctx, "code bug: retry failed before context canceled", slog.Error(err))
+            }
+            return
         }
-        return
-    }
-    defer cancelClient()
-    bkoff.Reset()
+        defer cancelClient()
+        bkoff.Reset()
+        q.logger.Debug(q.ctx, "subscribed to client updates")
 
-    var cancelAgent context.CancelFunc
-    err = backoff.Retry(func() error {
-        cancelFn, err := q.pubsub.SubscribeWithErr(eventAgentUpdate, q.listenAgent)
+        var cancelAgent context.CancelFunc
+        err = backoff.Retry(func() error {
+            cancelFn, err := q.pubsub.SubscribeWithErr(eventAgentUpdate, q.listenAgent)
+            if err != nil {
+                q.logger.Warn(q.ctx, "failed to subscribe to agent updates", slog.Error(err))
+                return err
+            }
+            cancelAgent = cancelFn
+            return nil
+        }, bkoff)
         if err != nil {
-            q.logger.Warn(q.ctx, "failed to subscribe to agent updates", slog.Error(err))
-            return err
-        }
-        cancelAgent = cancelFn
-        return nil
-    }, bkoff)
-    if err != nil {
-        if q.ctx.Err() == nil {
-            q.logger.Error(q.ctx, "code bug: retry failed before context canceled", slog.Error(err))
+            if q.ctx.Err() == nil {
+                q.logger.Error(q.ctx, "code bug: retry failed before context canceled", slog.Error(err))
+            }
+            return
         }
-        return
-    }
-    defer cancelAgent()
+        defer cancelAgent()
+        q.logger.Debug(q.ctx, "subscribed to agent updates")
 
-    // hold subscriptions open until context is canceled
-    <-q.ctx.Done()
+        // unblock the outer function from returning
+        subscribed <- struct{}{}
+
+        // hold subscriptions open until context is canceled
+        <-q.ctx.Done()
+    }()
+    <-subscribed
 }
 
 func (q *querier) listenClient(_ context.Context, msg []byte, err error) {
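
The heart of the fix is the pair of changes to newQuerier and subscribe above: previously subscribe ran entirely on its own goroutine (go q.subscribe()), so a freshly constructed querier could start its workers before the client and agent pubsub subscriptions existed, and updates published in that window could be missed. The rework keeps the long-lived subscriptions on a background goroutine but blocks the constructor until they are in place. Below is a minimal, self-contained sketch of this start-then-wait pattern; the names startListener and setup are illustrative, not from the coder codebase.

package main

import (
    "context"
    "fmt"
    "time"
)

// startListener spawns long-lived work on a goroutine but does not return
// until setup has completed or ctx is canceled. This mirrors the role of
// the subscribed channel in the patched subscribe().
func startListener(ctx context.Context, setup func() error) {
    ready := make(chan struct{})
    go func() {
        // The deferred close unblocks the caller even if setup fails or
        // ctx is canceled before readiness is signaled.
        defer close(ready)
        if err := setup(); err != nil {
            return
        }
        ready <- struct{}{} // unblock the outer function from returning
        <-ctx.Done()        // hold the subscription open until shutdown
    }()
    <-ready
}

func main() {
    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()
    startListener(ctx, func() error {
        time.Sleep(10 * time.Millisecond) // stand-in for SubscribeWithErr plus backoff retries
        return nil
    })
    fmt.Println("listener ready") // safe: the subscription is guaranteed to exist here
}

The success path hands off via a send on the unbuffered channel, while every failure path relies on the deferred close, so the caller can never block forever.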

enterprise/tailnet/pgcoord_test.go (+19 -19)

@@ -86,7 +86,7 @@ func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) {
     require.NoError(t, err)
     defer coordinator.Close()
 
-    agent := newTestAgent(t, coordinator)
+    agent := newTestAgent(t, coordinator, "agent")
     defer agent.close()
     agent.sendNode(&agpl.Node{PreferredDERP: 10})
     require.Eventually(t, func() bool {
@@ -123,7 +123,7 @@ func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) {
     require.NoError(t, err)
     defer coordinator.Close()
 
-    agent := newTestAgent(t, coordinator)
+    agent := newTestAgent(t, coordinator, "original")
     defer agent.close()
     agent.sendNode(&agpl.Node{PreferredDERP: 10})
 
@@ -151,7 +151,7 @@ func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) {
     agent.waitForClose(ctx, t)
 
     // Create a new agent connection. This is to simulate a reconnect!
-    agent = newTestAgent(t, coordinator, agent.id)
+    agent = newTestAgent(t, coordinator, "reconnection", agent.id)
     // Ensure the existing listening connIO sends its node immediately!
     clientNodes = agent.recvNodes(ctx, t)
     require.Len(t, clientNodes, 1)
@@ -200,7 +200,7 @@ func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) {
     require.NoError(t, err)
     defer coordinator.Close()
 
-    agent := newTestAgent(t, coordinator)
+    agent := newTestAgent(t, coordinator, "agent")
     defer agent.close()
     agent.sendNode(&agpl.Node{PreferredDERP: 10})
 
@@ -333,16 +333,16 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) {
     ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
     defer cancel()
     logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
-    coord1, err := tailnet.NewPGCoord(ctx, logger, ps, store)
+    coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
     require.NoError(t, err)
     defer coord1.Close()
-    coord2, err := tailnet.NewPGCoord(ctx, logger, ps, store)
+    coord2, err := tailnet.NewPGCoord(ctx, logger.Named("coord2"), ps, store)
     require.NoError(t, err)
     defer coord2.Close()
 
-    agent1 := newTestAgent(t, coord1)
+    agent1 := newTestAgent(t, coord1, "agent1")
     defer agent1.close()
-    agent2 := newTestAgent(t, coord2)
+    agent2 := newTestAgent(t, coord2, "agent2")
     defer agent2.close()
 
     client11 := newTestClient(t, coord1, agent1.id)
@@ -460,19 +460,19 @@ func TestPGCoordinator_MultiAgent(t *testing.T) {
     ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
     defer cancel()
     logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
-    coord1, err := tailnet.NewPGCoord(ctx, logger, ps, store)
+    coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
     require.NoError(t, err)
     defer coord1.Close()
-    coord2, err := tailnet.NewPGCoord(ctx, logger, ps, store)
+    coord2, err := tailnet.NewPGCoord(ctx, logger.Named("coord2"), ps, store)
     require.NoError(t, err)
     defer coord2.Close()
-    coord3, err := tailnet.NewPGCoord(ctx, logger, ps, store)
+    coord3, err := tailnet.NewPGCoord(ctx, logger.Named("coord3"), ps, store)
     require.NoError(t, err)
     defer coord3.Close()
 
-    agent1 := newTestAgent(t, coord1)
+    agent1 := newTestAgent(t, coord1, "agent1")
     defer agent1.close()
-    agent2 := newTestAgent(t, coord2, agent1.id)
+    agent2 := newTestAgent(t, coord2, "agent2", agent1.id)
     defer agent2.close()
 
     client := newTestClient(t, coord3, agent1.id)
@@ -552,7 +552,7 @@ func TestPGCoordinator_Unhealthy(t *testing.T) {
         err := uut.Close()
         require.NoError(t, err)
     }()
-    agent1 := newTestAgent(t, uut)
+    agent1 := newTestAgent(t, uut, "agent1")
    defer agent1.close()
     for i := 0; i < 3; i++ {
         select {
@@ -566,7 +566,7 @@ func TestPGCoordinator_Unhealthy(t *testing.T) {
     agent1.waitForClose(ctx, t)
 
     // new agent should immediately disconnect
-    agent2 := newTestAgent(t, uut)
+    agent2 := newTestAgent(t, uut, "agent2")
     defer agent2.close()
     agent2.waitForClose(ctx, t)
 
@@ -579,7 +579,7 @@ func TestPGCoordinator_Unhealthy(t *testing.T) {
             // OK
         }
     }
-    agent3 := newTestAgent(t, uut)
+    agent3 := newTestAgent(t, uut, "agent3")
     defer agent3.close()
     select {
     case <-agent3.closeChan:
@@ -618,10 +618,10 @@ func newTestConn(ids []uuid.UUID) *testConn {
     return a
 }
 
-func newTestAgent(t *testing.T, coord agpl.Coordinator, id ...uuid.UUID) *testConn {
+func newTestAgent(t *testing.T, coord agpl.Coordinator, name string, id ...uuid.UUID) *testConn {
     a := newTestConn(id)
     go func() {
-        err := coord.ServeAgent(a.serverWS, a.id, "")
+        err := coord.ServeAgent(a.serverWS, a.id, name)
         assert.NoError(t, err)
         close(a.closeChan)
     }()
@@ -636,7 +636,7 @@ func (c *testConn) recvNodes(ctx context.Context, t *testing.T) []*agpl.Node {
     t.Helper()
     select {
     case <-ctx.Done():
-        t.Fatal("timeout receiving nodes")
+        t.Fatalf("testConn id %s: timeout receiving nodes ", c.id)
         return nil
     case nodes := <-c.nodeChan:
         return nodes
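
The test changes are purely diagnostic: coordinators get named loggers and test agents get names, so interleaved debug output from multiple coordinators in one test can be told apart. A small sketch of how named loggers separate output, assuming cdr.dev/slog and its slogtest helper as used in the diff above; the test itself is hypothetical, not part of the commit:

package tailnet_test

import (
    "context"
    "testing"

    "cdr.dev/slog"
    "cdr.dev/slog/sloggers/slogtest"
)

// TestNamedLoggers shows that each Named logger prefixes its entries, so
// "coord1" and "coord2" debug lines are distinguishable in a single test's output.
func TestNamedLoggers(t *testing.T) {
    ctx := context.Background()
    logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
    logger.Named("coord1").Debug(ctx, "subscribed to client updates")
    logger.Named("coord2").Debug(ctx, "subscribed to agent updates")
}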
