From e2f3c311102f50cd071ba4c66d6e7e823c717f0c Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Sun, 4 Aug 2024 01:29:53 +0000 Subject: [PATCH 01/13] fix: avoid deleting peers on graceful close - Fixes an issue where a coordinator deletes all its peers on shutdown. This can cause disconnects whenever a coderd is redeployed. --- enterprise/tailnet/pgcoord.go | 36 +++++++------ enterprise/tailnet/pgcoord_test.go | 82 ++++++++++++++---------------- 2 files changed, 58 insertions(+), 60 deletions(-) diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 1546f0ac3087b..5de8dad1f3306 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -26,18 +26,19 @@ import ( ) const ( - EventHeartbeats = "tailnet_coordinator_heartbeat" - eventPeerUpdate = "tailnet_peer_update" - eventTunnelUpdate = "tailnet_tunnel_update" - eventReadyForHandshake = "tailnet_ready_for_handshake" - HeartbeatPeriod = time.Second * 2 - MissedHeartbeats = 3 - numQuerierWorkers = 10 - numBinderWorkers = 10 - numTunnelerWorkers = 10 - numHandshakerWorkers = 5 - dbMaxBackoff = 10 * time.Second - cleanupPeriod = time.Hour + EventHeartbeats = "tailnet_coordinator_heartbeat" + eventPeerUpdate = "tailnet_peer_update" + eventTunnelUpdate = "tailnet_tunnel_update" + eventReadyForHandshake = "tailnet_ready_for_handshake" + HeartbeatPeriod = time.Second * 2 + MissedHeartbeats = 3 + unhealthyHeartbeatThreshold = 3 + numQuerierWorkers = 10 + numBinderWorkers = 10 + numTunnelerWorkers = 10 + numHandshakerWorkers = 5 + dbMaxBackoff = 10 * time.Second + cleanupPeriod = time.Hour ) // pgCoord is a postgres-backed coordinator @@ -1646,13 +1647,18 @@ func (h *heartbeats) sendBeats() { // send an initial heartbeat so that other coordinators can start using our bindings right away. 
h.sendBeat() close(h.firstHeartbeat) // signal binder it can start writing - defer h.sendDelete() tkr := h.clock.TickerFunc(h.ctx, HeartbeatPeriod, func() error { h.sendBeat() return nil }, "heartbeats", "sendBeats") err := tkr.Wait() h.logger.Debug(h.ctx, "ending heartbeats", slog.Error(err)) + // This is unlikely to succeed if we're unhealthy but + // we get it our best effort. + if h.failedHeartbeats >= unhealthyHeartbeatThreshold { + h.logger.Debug(h.ctx, "coordinator detected unhealthy, deleting self", slog.Error(err)) + h.sendDelete() + } } func (h *heartbeats) sendBeat() { @@ -1663,14 +1669,14 @@ func (h *heartbeats) sendBeat() { if err != nil { h.logger.Error(h.ctx, "failed to send heartbeat", slog.Error(err)) h.failedHeartbeats++ - if h.failedHeartbeats == 3 { + if h.failedHeartbeats == unhealthyHeartbeatThreshold { h.logger.Error(h.ctx, "coordinator failed 3 heartbeats and is unhealthy") _ = agpl.SendCtx(h.ctx, h.update, hbUpdate{health: healthUpdateUnhealthy}) } return } h.logger.Debug(h.ctx, "sent heartbeat") - if h.failedHeartbeats >= 3 { + if h.failedHeartbeats >= unhealthyHeartbeatThreshold { h.logger.Info(h.ctx, "coordinator sent heartbeat and is healthy") _ = agpl.SendCtx(h.ctx, h.update, hbUpdate{health: healthUpdateHealthy}) } diff --git a/enterprise/tailnet/pgcoord_test.go b/enterprise/tailnet/pgcoord_test.go index 2232e3941eb0c..050eb42e2f485 100644 --- a/enterprise/tailnet/pgcoord_test.go +++ b/enterprise/tailnet/pgcoord_test.go @@ -592,8 +592,6 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) { err = client21.recvErr(ctx, t) require.ErrorIs(t, err, io.EOF) - assertEventuallyNoAgents(ctx, t, store, agent2.id) - t.Logf("close coord1") err = coord1.Close() require.NoError(t, err) @@ -629,10 +627,6 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) { err = client22.close() require.NoError(t, err) client22.waitForClose(ctx, t) - - assertEventuallyNoAgents(ctx, t, store, agent1.id) - assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) 
- assertEventuallyNoClientsForAgent(ctx, t, store, agent2.id) } // TestPGCoordinator_MultiCoordinatorAgent tests when a single agent connects to multiple coordinators. @@ -746,7 +740,6 @@ func TestPGCoordinator_Unhealthy(t *testing.T) { mStore.EXPECT().DeleteTailnetPeer(gomock.Any(), gomock.Any()). AnyTimes().Return(database.DeleteTailnetPeerRow{}, nil) mStore.EXPECT().DeleteAllTailnetTunnels(gomock.Any(), gomock.Any()).AnyTimes().Return(nil) - mStore.EXPECT().DeleteCoordinator(gomock.Any(), gomock.Any()).AnyTimes().Return(nil) uut, err := tailnet.NewPGCoord(ctx, logger, ps, mStore) require.NoError(t, err) @@ -871,51 +864,50 @@ func TestPGCoordinator_Lost(t *testing.T) { agpltest.LostTest(ctx, t, coordinator) } -func TestPGCoordinator_DeleteOnClose(t *testing.T) { +func TestPGCoordinator_NoDeleteOnClose(t *testing.T) { t.Parallel() - + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + store, ps := dbtestutil.NewDB(t) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) defer cancel() - ctrl := gomock.NewController(t) - mStore := dbmock.NewMockStore(ctrl) - ps := pubsub.NewInMemory() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store) + require.NoError(t, err) + defer coordinator.Close() - upsertDone := make(chan struct{}) - deleteCalled := make(chan struct{}) - finishDelete := make(chan struct{}) - mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()). - MinTimes(1). - Do(func(_ context.Context, _ uuid.UUID) { close(upsertDone) }). - Return(database.TailnetCoordinator{}, nil) - mStore.EXPECT().DeleteCoordinator(gomock.Any(), gomock.Any()). - Times(1). - Do(func(_ context.Context, _ uuid.UUID) { - close(deleteCalled) - <-finishDelete - }). 
- Return(nil) + agent := newTestAgent(t, coordinator, "original") + defer agent.close() + agent.sendNode(&agpl.Node{PreferredDERP: 10}) - // extra calls we don't particularly care about for this test - mStore.EXPECT().CleanTailnetCoordinators(gomock.Any()).AnyTimes().Return(nil) - mStore.EXPECT().CleanTailnetLostPeers(gomock.Any()).AnyTimes().Return(nil) - mStore.EXPECT().CleanTailnetTunnels(gomock.Any()).AnyTimes().Return(nil) + client := newTestClient(t, coordinator, agent.id) + defer client.close() - uut, err := tailnet.NewPGCoord(ctx, logger, ps, mStore) + // Simulate some traffic to generate + // a peer. + agentNodes := client.recvNodes(ctx, t) + require.Len(t, agentNodes, 1) + assert.Equal(t, 10, agentNodes[0].PreferredDERP) + client.sendNode(&agpl.Node{PreferredDERP: 11}) + clientNodes := agent.recvNodes(ctx, t) + require.Len(t, clientNodes, 1) + assert.Equal(t, 11, clientNodes[0].PreferredDERP) + + anode := coordinator.Node(agent.id) + require.NotNil(t, anode) + cnode := coordinator.Node(client.id) + require.NotNil(t, cnode) + + err = coordinator.Close() require.NoError(t, err) - testutil.RequireRecvCtx(ctx, t, upsertDone) - closeErr := make(chan error, 1) - go func() { - closeErr <- uut.Close() - }() - select { - case <-closeErr: - t.Fatal("close returned before DeleteCoordinator called") - case <-deleteCalled: - close(finishDelete) - err := testutil.RequireRecvCtx(ctx, t, closeErr) - require.NoError(t, err) - } + + coordinator2, err := tailnet.NewPGCoord(ctx, logger, ps, store) + require.NoError(t, err) + defer coordinator2.Close() + + anode = coordinator2.Node(agent.id) + require.NotNil(t, anode) } type testConn struct { From 82c6dce0a93756a2737180aacb4df82cccce2f8f Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Mon, 5 Aug 2024 15:23:05 +0000 Subject: [PATCH 02/13] lint --- enterprise/tailnet/pgcoord_test.go | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/enterprise/tailnet/pgcoord_test.go 
b/enterprise/tailnet/pgcoord_test.go index 050eb42e2f485..2bbd042830486 100644 --- a/enterprise/tailnet/pgcoord_test.go +++ b/enterprise/tailnet/pgcoord_test.go @@ -480,7 +480,7 @@ func TestPGCoordinatorSingle_SendsHeartbeats(t *testing.T) { mu := sync.Mutex{} heartbeats := []time.Time{} - unsub, err := ps.SubscribeWithErr(tailnet.EventHeartbeats, func(_ context.Context, msg []byte, err error) { + unsub, err := ps.SubscribeWithErr(tailnet.EventHeartbeats, func(_ context.Context, _ []byte, err error) { assert.NoError(t, err) mu.Lock() defer mu.Unlock() @@ -1048,20 +1048,6 @@ func assertNeverHasDERPs(ctx context.Context, t *testing.T, c *testConn, expecte } } -func assertEventuallyNoAgents(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) { - t.Helper() - assert.Eventually(t, func() bool { - agents, err := store.GetTailnetPeers(ctx, agentID) - if xerrors.Is(err, sql.ErrNoRows) { - return true - } - if err != nil { - t.Fatal(err) - } - return len(agents) == 0 - }, testutil.WaitShort, testutil.IntervalFast) -} - func assertEventuallyLost(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) { t.Helper() assert.Eventually(t, func() bool { From b5c6d96f635508c79309e83bcd1706d75d4afe01 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Sun, 11 Aug 2024 03:56:27 +0000 Subject: [PATCH 03/13] set all peers to lost on shutdown --- coderd/database/dbauthz/dbauthz.go | 4 ++ coderd/database/dbmem/dbmem.go | 9 ++++ coderd/database/dbmetrics/dbmetrics.go | 7 +++ coderd/database/dbmock/dbmock.go | 14 ++++++ coderd/database/querier.go | 1 + coderd/database/queries.sql.go | 19 ++++++++ coderd/database/queries/tailnet.sql | 8 ++++ enterprise/tailnet/pgcoord.go | 53 +++++++++------------ enterprise/tailnet/pgcoord_test.go | 64 +++++++++++++++++++++++++- tailnet/test/peer.go | 45 ++++++++++++++---- 10 files changed, 183 insertions(+), 41 deletions(-) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 
2f3567455fed8..6c9bc6d97c584 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -3160,6 +3160,10 @@ func (q *querier) UpdateReplica(ctx context.Context, arg database.UpdateReplicaP return q.db.UpdateReplica(ctx, arg) } +func (q *querier) UpdateTailnetPeerStatusByCoordinator(ctx context.Context, arg database.UpdateTailnetPeerStatusByCoordinatorParams) error { + panic("not implemented") +} + func (q *querier) UpdateTemplateACLByID(ctx context.Context, arg database.UpdateTemplateACLByIDParams) error { fetch := func(ctx context.Context, arg database.UpdateTemplateACLByIDParams) (database.Template, error) { return q.db.GetTemplateByID(ctx, arg.ID) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 5768379535668..bf5ffacd24696 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -7759,6 +7759,15 @@ func (q *FakeQuerier) UpdateReplica(_ context.Context, arg database.UpdateReplic return database.Replica{}, sql.ErrNoRows } +func (q *FakeQuerier) UpdateTailnetPeerStatusByCoordinator(ctx context.Context, arg database.UpdateTailnetPeerStatusByCoordinatorParams) error { + err := validateDatabaseType(arg) + if err != nil { + return err + } + + panic("not implemented") +} + func (q *FakeQuerier) UpdateTemplateACLByID(_ context.Context, arg database.UpdateTemplateACLByIDParams) error { if err := validateDatabaseType(arg); err != nil { return err diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index 7b6cdb147dcf9..b579ae70d10e1 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -2041,6 +2041,13 @@ func (m metricsStore) UpdateReplica(ctx context.Context, arg database.UpdateRepl return replica, err } +func (m metricsStore) UpdateTailnetPeerStatusByCoordinator(ctx context.Context, arg database.UpdateTailnetPeerStatusByCoordinatorParams) error { + start := time.Now() + r0 := 
m.s.UpdateTailnetPeerStatusByCoordinator(ctx, arg) + m.queryLatencies.WithLabelValues("UpdateTailnetPeerStatusByCoordinator").Observe(time.Since(start).Seconds()) + return r0 +} + func (m metricsStore) UpdateTemplateACLByID(ctx context.Context, arg database.UpdateTemplateACLByIDParams) error { start := time.Now() err := m.s.UpdateTemplateACLByID(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index bda8186a26a4f..a5c44d11c714f 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -4307,6 +4307,20 @@ func (mr *MockStoreMockRecorder) UpdateReplica(arg0, arg1 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateReplica", reflect.TypeOf((*MockStore)(nil).UpdateReplica), arg0, arg1) } +// UpdateTailnetPeerStatusByCoordinator mocks base method. +func (m *MockStore) UpdateTailnetPeerStatusByCoordinator(arg0 context.Context, arg1 database.UpdateTailnetPeerStatusByCoordinatorParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateTailnetPeerStatusByCoordinator", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateTailnetPeerStatusByCoordinator indicates an expected call of UpdateTailnetPeerStatusByCoordinator. +func (mr *MockStoreMockRecorder) UpdateTailnetPeerStatusByCoordinator(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTailnetPeerStatusByCoordinator", reflect.TypeOf((*MockStore)(nil).UpdateTailnetPeerStatusByCoordinator), arg0, arg1) +} + // UpdateTemplateACLByID mocks base method. 
func (m *MockStore) UpdateTemplateACLByID(arg0 context.Context, arg1 database.UpdateTemplateACLByIDParams) error { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 2d45e154b532d..77b64eda3a585 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -413,6 +413,7 @@ type sqlcQuerier interface { UpdateProvisionerJobWithCancelByID(ctx context.Context, arg UpdateProvisionerJobWithCancelByIDParams) error UpdateProvisionerJobWithCompleteByID(ctx context.Context, arg UpdateProvisionerJobWithCompleteByIDParams) error UpdateReplica(ctx context.Context, arg UpdateReplicaParams) (Replica, error) + UpdateTailnetPeerStatusByCoordinator(ctx context.Context, arg UpdateTailnetPeerStatusByCoordinatorParams) error UpdateTemplateACLByID(ctx context.Context, arg UpdateTemplateACLByIDParams) error UpdateTemplateAccessControlByID(ctx context.Context, arg UpdateTemplateAccessControlByIDParams) error UpdateTemplateActiveVersionByID(ctx context.Context, arg UpdateTemplateActiveVersionByIDParams) error diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index d8a6e3a1abb03..ae7fdd1f69948 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -7354,6 +7354,25 @@ func (q *sqlQuerier) GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUI return items, nil } +const updateTailnetPeerStatusByCoordinator = `-- name: UpdateTailnetPeerStatusByCoordinator :exec +UPDATE + tailnet_peers +SET + status = $2 +WHERE + coordinator_id = $1 +` + +type UpdateTailnetPeerStatusByCoordinatorParams struct { + CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` + Status TailnetStatus `db:"status" json:"status"` +} + +func (q *sqlQuerier) UpdateTailnetPeerStatusByCoordinator(ctx context.Context, arg UpdateTailnetPeerStatusByCoordinatorParams) error { + _, err := q.db.ExecContext(ctx, updateTailnetPeerStatusByCoordinator, arg.CoordinatorID, arg.Status) + return err +} + const 
upsertTailnetAgent = `-- name: UpsertTailnetAgent :one INSERT INTO tailnet_agents ( diff --git a/coderd/database/queries/tailnet.sql b/coderd/database/queries/tailnet.sql index 767b966cbbce3..07936e277bc52 100644 --- a/coderd/database/queries/tailnet.sql +++ b/coderd/database/queries/tailnet.sql @@ -149,6 +149,14 @@ DO UPDATE SET updated_at = now() at time zone 'utc' RETURNING *; +-- name: UpdateTailnetPeerStatusByCoordinator :exec +UPDATE + tailnet_peers +SET + status = $2 +WHERE + coordinator_id = $1; + -- name: DeleteTailnetPeer :one DELETE FROM tailnet_peers diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 5de8dad1f3306..ad6e5924a5daa 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -26,19 +26,18 @@ import ( ) const ( - EventHeartbeats = "tailnet_coordinator_heartbeat" - eventPeerUpdate = "tailnet_peer_update" - eventTunnelUpdate = "tailnet_tunnel_update" - eventReadyForHandshake = "tailnet_ready_for_handshake" - HeartbeatPeriod = time.Second * 2 - MissedHeartbeats = 3 - unhealthyHeartbeatThreshold = 3 - numQuerierWorkers = 10 - numBinderWorkers = 10 - numTunnelerWorkers = 10 - numHandshakerWorkers = 5 - dbMaxBackoff = 10 * time.Second - cleanupPeriod = time.Hour + EventHeartbeats = "tailnet_coordinator_heartbeat" + eventPeerUpdate = "tailnet_peer_update" + eventTunnelUpdate = "tailnet_tunnel_update" + eventReadyForHandshake = "tailnet_ready_for_handshake" + HeartbeatPeriod = time.Second * 2 + MissedHeartbeats = 3 + numQuerierWorkers = 10 + numBinderWorkers = 10 + numTunnelerWorkers = 10 + numHandshakerWorkers = 5 + dbMaxBackoff = 10 * time.Second + cleanupPeriod = time.Hour ) // pgCoord is a postgres-backed coordinator @@ -521,7 +520,14 @@ func (b *binder) handleBindings() { for { select { case <-b.ctx.Done(): - b.logger.Debug(b.ctx, "binder exiting", slog.Error(b.ctx.Err())) + b.logger.Debug(b.ctx, "binder exiting, updating peers to lost", slog.Error(b.ctx.Err())) + err := 
b.store.UpdateTailnetPeerStatusByCoordinator(context.Background(), database.UpdateTailnetPeerStatusByCoordinatorParams{ + CoordinatorID: b.coordinatorID, + Status: database.TailnetStatusLost, + }) + if err != nil { + b.logger.Error(b.ctx, "update peer status to lost", slog.Error(err)) + } return case bnd := <-b.bindings: b.storeBinding(bnd) @@ -1655,10 +1661,6 @@ func (h *heartbeats) sendBeats() { h.logger.Debug(h.ctx, "ending heartbeats", slog.Error(err)) // This is unlikely to succeed if we're unhealthy but // we get it our best effort. - if h.failedHeartbeats >= unhealthyHeartbeatThreshold { - h.logger.Debug(h.ctx, "coordinator detected unhealthy, deleting self", slog.Error(err)) - h.sendDelete() - } } func (h *heartbeats) sendBeat() { @@ -1669,31 +1671,20 @@ func (h *heartbeats) sendBeat() { if err != nil { h.logger.Error(h.ctx, "failed to send heartbeat", slog.Error(err)) h.failedHeartbeats++ - if h.failedHeartbeats == unhealthyHeartbeatThreshold { + if h.failedHeartbeats == 3 { h.logger.Error(h.ctx, "coordinator failed 3 heartbeats and is unhealthy") _ = agpl.SendCtx(h.ctx, h.update, hbUpdate{health: healthUpdateUnhealthy}) } return } h.logger.Debug(h.ctx, "sent heartbeat") - if h.failedHeartbeats >= unhealthyHeartbeatThreshold { + if h.failedHeartbeats >= 3 { h.logger.Info(h.ctx, "coordinator sent heartbeat and is healthy") _ = agpl.SendCtx(h.ctx, h.update, hbUpdate{health: healthUpdateHealthy}) } h.failedHeartbeats = 0 } -func (h *heartbeats) sendDelete() { - // here we don't want to use the main context, since it will have been canceled - ctx := dbauthz.As(context.Background(), pgCoordSubject) - err := h.store.DeleteCoordinator(ctx, h.self) - if err != nil { - h.logger.Error(h.ctx, "failed to send coordinator delete", slog.Error(err)) - return - } - h.logger.Debug(h.ctx, "deleted coordinator") -} - func (h *heartbeats) cleanupLoop() { defer h.wg.Done() h.cleanup() diff --git a/enterprise/tailnet/pgcoord_test.go b/enterprise/tailnet/pgcoord_test.go index 
2bbd042830486..af7a124cf3c81 100644 --- a/enterprise/tailnet/pgcoord_test.go +++ b/enterprise/tailnet/pgcoord_test.go @@ -29,6 +29,7 @@ import ( "github.com/coder/coder/v2/enterprise/tailnet" agpl "github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/tailnet/proto" + "github.com/coder/coder/v2/tailnet/test" agpltest "github.com/coder/coder/v2/tailnet/test" "github.com/coder/coder/v2/testutil" "github.com/coder/quartz" @@ -591,8 +592,10 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) { require.ErrorIs(t, err, io.EOF) err = client21.recvErr(ctx, t) require.ErrorIs(t, err, io.EOF) + assertEventuallyLost(ctx, t, store, agent2.id) + assertEventuallyLost(ctx, t, store, client21.id) + assertEventuallyLost(ctx, t, store, client22.id) - t.Logf("close coord1") err = coord1.Close() require.NoError(t, err) // this closes agent1, client12, client11 @@ -602,6 +605,9 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) { require.ErrorIs(t, err, io.EOF) err = client11.recvErr(ctx, t) require.ErrorIs(t, err, io.EOF) + assertEventuallyLost(ctx, t, store, agent1.id) + assertEventuallyLost(ctx, t, store, client11.id) + assertEventuallyLost(ctx, t, store, client12.id) // wait for all connections to close err = agent1.close() @@ -890,6 +896,7 @@ func TestPGCoordinator_NoDeleteOnClose(t *testing.T) { require.Len(t, agentNodes, 1) assert.Equal(t, 10, agentNodes[0].PreferredDERP) client.sendNode(&agpl.Node{PreferredDERP: 11}) + clientNodes := agent.recvNodes(ctx, t) require.Len(t, clientNodes, 1) assert.Equal(t, 11, clientNodes[0].PreferredDERP) @@ -908,6 +915,61 @@ func TestPGCoordinator_NoDeleteOnClose(t *testing.T) { anode = coordinator2.Node(agent.id) require.NotNil(t, anode) + assert.Equal(t, 10, anode.PreferredDERP) + + cnode = coordinator2.Node(client.id) + require.NotNil(t, cnode) + assert.Equal(t, 11, cnode.PreferredDERP) +} + +func TestPGCoordinatorPeerReconnect(t *testing.T) { + t.Parallel() + + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") 
+ } + + store, ps := dbtestutil.NewDB(t) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + + // Create two coordinators, 1 for each peer. + c1, err := tailnet.NewPGCoord(ctx, logger, ps, store) + require.NoError(t, err) + c2, err := tailnet.NewPGCoord(ctx, logger, ps, store) + require.NoError(t, err) + + p1 := test.NewPeer(ctx, t, c1, "peer1") + p2 := test.NewPeer(ctx, t, c2, "peer2") + + // Create a binding between the two. + p1.AddTunnel(p2.ID) + + // Ensure that messages pass through. + p1.UpdateDERP(1) + p2.UpdateDERP(2) + p1.AssertEventuallyHasDERP(p2.ID, 2) + p2.AssertEventuallyHasDERP(p1.ID, 1) + + // Close coordinator1. Now we will check that we + // never send a DISCONNECTED update. + err = c1.Close() + require.NoError(t, err) + p1.AssertEventuallyResponsesClosed() + + // Connect peer1 to coordinator2. + p1.ConnectToCoordinator(ctx, c2) + // Reestablish binding. + p1.AddTunnel(p2.ID) + // Ensure messages still flow back and forth. + p1.AssertEventuallyHasDERP(p2.ID, 2) + p1.UpdateDERP(3) + p2.UpdateDERP(4) + p2.AssertEventuallyHasDERP(p1.ID, 3) + p1.AssertEventuallyHasDERP(p2.ID, 4) + // Make sure peer2 never got an update about peer1 disconnecting. 
+ p2.AssertNeverUpdateKind(p1.ID, proto.CoordinateResponse_PeerUpdate_DISCONNECTED) } type testConn struct { diff --git a/tailnet/test/peer.go b/tailnet/test/peer.go index 791c3b0e9176d..df9f9c5d58118 100644 --- a/tailnet/test/peer.go +++ b/tailnet/test/peer.go @@ -6,6 +6,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/xerrors" "github.com/coder/coder/v2/tailnet" @@ -19,18 +20,24 @@ type PeerStatus struct { } type Peer struct { - ctx context.Context - cancel context.CancelFunc - t testing.TB - ID uuid.UUID - name string - resps <-chan *proto.CoordinateResponse - reqs chan<- *proto.CoordinateRequest - peers map[uuid.UUID]PeerStatus + ctx context.Context + cancel context.CancelFunc + t testing.TB + ID uuid.UUID + name string + resps <-chan *proto.CoordinateResponse + reqs chan<- *proto.CoordinateRequest + peers map[uuid.UUID]PeerStatus + peerUpdates map[uuid.UUID][]*proto.CoordinateResponse_PeerUpdate } func NewPeer(ctx context.Context, t testing.TB, coord tailnet.CoordinatorV2, name string, id ...uuid.UUID) *Peer { - p := &Peer{t: t, name: name, peers: make(map[uuid.UUID]PeerStatus)} + p := &Peer{ + t: t, + name: name, + peers: make(map[uuid.UUID]PeerStatus), + peerUpdates: make(map[uuid.UUID][]*proto.CoordinateResponse_PeerUpdate), + } p.ctx, p.cancel = context.WithCancel(ctx) if len(id) > 1 { t.Fatal("too many") @@ -45,6 +52,12 @@ func NewPeer(ctx context.Context, t testing.TB, coord tailnet.CoordinatorV2, nam return p } +func (p *Peer) ConnectToCoordinator(ctx context.Context, c tailnet.CoordinatorV2) { + p.t.Helper() + + p.reqs, p.resps = c.Coordinate(ctx, p.ID, p.name, tailnet.SingleTailnetCoordinateeAuth{}) +} + func (p *Peer) AddTunnel(other uuid.UUID) { p.t.Helper() req := &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: tailnet.UUIDToByteSlice(other)}} @@ -180,6 +193,19 @@ func (p *Peer) AssertEventuallyGetsError(match string) { } } +// 
AssertNeverUpdateKind asserts that we have not received +// any updates on the provided peer for the provided kind. +func (p *Peer) AssertNeverUpdateKind(peer uuid.UUID, kind proto.CoordinateResponse_PeerUpdate_Kind) { + p.t.Helper() + + updates, ok := p.peerUpdates[peer] + require.True(p.t, ok, "expected updates for peer %s", peer) + + for _, update := range updates { + assert.NotEqual(p.t, kind, update.Kind, update) + } +} + var responsesClosed = xerrors.New("responses closed") func (p *Peer) handleOneResp() error { @@ -213,6 +239,7 @@ func (p *Peer) handleOneResp() error { default: return xerrors.Errorf("unhandled update kind %s", update.Kind) } + p.peerUpdates[id] = append(p.peerUpdates[id], update) } } return nil From b1f6aea4e914ee95fc4aaa50b6814f150266a361 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Sun, 11 Aug 2024 23:34:24 +0000 Subject: [PATCH 04/13] add heartbeat failure test --- coderd/database/dbauthz/dbauthz.go | 5 +- coderd/database/dbauthz/dbauthz_test.go | 5 ++ coderd/database/dbmem/dbmem.go | 9 +-- coderd/database/dbtestutil/db.go | 10 +++ enterprise/tailnet/pgcoord.go | 2 - enterprise/tailnet/pgcoord_internal_test.go | 9 +-- enterprise/tailnet/pgcoord_test.go | 79 +++++++++++++++++++-- 7 files changed, 100 insertions(+), 19 deletions(-) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 6c9bc6d97c584..92d495e040d86 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -3161,7 +3161,10 @@ func (q *querier) UpdateReplica(ctx context.Context, arg database.UpdateReplicaP } func (q *querier) UpdateTailnetPeerStatusByCoordinator(ctx context.Context, arg database.UpdateTailnetPeerStatusByCoordinatorParams) error { - panic("not implemented") + if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil { + return err + } + return q.db.UpdateTailnetPeerStatusByCoordinator(ctx, arg) } func (q *querier) UpdateTemplateACLByID(ctx 
context.Context, arg database.UpdateTemplateACLByIDParams) error { diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 95d1bbcdb7f18..8161280c913e7 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -2053,6 +2053,11 @@ func (s *MethodTestSuite) TestTailnetFunctions() { Asserts(rbac.ResourceTailnetCoordinator, policy.ActionCreate). Errors(dbmem.ErrUnimplemented) })) + s.Run("UpdateTailnetPeerStatusByCoordinator", s.Subtest(func(_ database.Store, check *expects) { + check.Args(database.UpdateTailnetPeerStatusByCoordinatorParams{}). + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionUpdate). + Errors(dbmem.ErrUnimplemented) + })) } func (s *MethodTestSuite) TestDBCrypt() { diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index bf5ffacd24696..b7bff20bf014c 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -7759,13 +7759,8 @@ func (q *FakeQuerier) UpdateReplica(_ context.Context, arg database.UpdateReplic return database.Replica{}, sql.ErrNoRows } -func (q *FakeQuerier) UpdateTailnetPeerStatusByCoordinator(ctx context.Context, arg database.UpdateTailnetPeerStatusByCoordinatorParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - panic("not implemented") +func (*FakeQuerier) UpdateTailnetPeerStatusByCoordinator(context.Context, database.UpdateTailnetPeerStatusByCoordinatorParams) error { + return ErrUnimplemented } func (q *FakeQuerier) UpdateTemplateACLByID(_ context.Context, arg database.UpdateTemplateACLByIDParams) error { diff --git a/coderd/database/dbtestutil/db.go b/coderd/database/dbtestutil/db.go index 16eb3393ca346..98b4654760af4 100644 --- a/coderd/database/dbtestutil/db.go +++ b/coderd/database/dbtestutil/db.go @@ -35,6 +35,7 @@ type options struct { dumpOnFailure bool returnSQLDB func(*sql.DB) logger slog.Logger + url string } type Option func(*options) @@ 
-59,6 +60,12 @@ func WithLogger(logger slog.Logger) Option { } } +func WithURL(url string) Option { + return func(o *options) { + o.url = url + } +} + func withReturnSQLDB(f func(*sql.DB)) Option { return func(o *options) { o.returnSQLDB = f @@ -92,6 +99,9 @@ func NewDB(t testing.TB, opts ...Option) (database.Store, pubsub.Pubsub) { ps := pubsub.NewInMemory() if WillUsePostgres() { connectionURL := os.Getenv("CODER_PG_CONNECTION_URL") + if connectionURL == "" && o.url != "" { + connectionURL = o.url + } if connectionURL == "" { var ( err error diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index ad6e5924a5daa..cbccd05366edd 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -1659,8 +1659,6 @@ func (h *heartbeats) sendBeats() { }, "heartbeats", "sendBeats") err := tkr.Wait() h.logger.Debug(h.ctx, "ending heartbeats", slog.Error(err)) - // This is unlikely to succeed if we're unhealthy but - // we get it our best effort. } func (h *heartbeats) sendBeat() { diff --git a/enterprise/tailnet/pgcoord_internal_test.go b/enterprise/tailnet/pgcoord_internal_test.go index 253487d28d196..745479d3160a0 100644 --- a/enterprise/tailnet/pgcoord_internal_test.go +++ b/enterprise/tailnet/pgcoord_internal_test.go @@ -396,10 +396,6 @@ func TestPGCoordinatorUnhealthy(t *testing.T) { UpsertTailnetCoordinator(gomock.Any(), gomock.Any()). Times(3). Return(database.TailnetCoordinator{}, xerrors.New("badness")) - mStore.EXPECT(). - DeleteCoordinator(gomock.Any(), gomock.Any()). - Times(1). - Return(nil) // But, in particular we DO NOT want the coordinator to call DeleteTailnetPeer, as this is // unnecessary and can spam the database. c.f. 
https://github.com/coder/coder/issues/12923 @@ -411,6 +407,11 @@ func TestPGCoordinatorUnhealthy(t *testing.T) { coordinator, err := newPGCoordInternal(ctx, logger, ps, mStore, mClock) require.NoError(t, err) + mStore.EXPECT().UpdateTailnetPeerStatusByCoordinator(gomock.Any(), database.UpdateTailnetPeerStatusByCoordinatorParams{ + CoordinatorID: coordinator.id, + Status: database.TailnetStatusLost, + }) + expectedPeriod := HeartbeatPeriod tfCall, err := tfTrap.Wait(ctx) require.NoError(t, err) diff --git a/enterprise/tailnet/pgcoord_test.go b/enterprise/tailnet/pgcoord_test.go index af7a124cf3c81..54d8ae77b21d7 100644 --- a/enterprise/tailnet/pgcoord_test.go +++ b/enterprise/tailnet/pgcoord_test.go @@ -29,7 +29,6 @@ import ( "github.com/coder/coder/v2/enterprise/tailnet" agpl "github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/tailnet/proto" - "github.com/coder/coder/v2/tailnet/test" agpltest "github.com/coder/coder/v2/tailnet/test" "github.com/coder/coder/v2/testutil" "github.com/coder/quartz" @@ -746,6 +745,7 @@ func TestPGCoordinator_Unhealthy(t *testing.T) { mStore.EXPECT().DeleteTailnetPeer(gomock.Any(), gomock.Any()). 
AnyTimes().Return(database.DeleteTailnetPeerRow{}, nil) mStore.EXPECT().DeleteAllTailnetTunnels(gomock.Any(), gomock.Any()).AnyTimes().Return(nil) + mStore.EXPECT().UpdateTailnetPeerStatusByCoordinator(gomock.Any(), gomock.Any()) uut, err := tailnet.NewPGCoord(ctx, logger, ps, mStore) require.NoError(t, err) @@ -810,7 +810,7 @@ func TestPGCoordinator_Node_Empty(t *testing.T) { mStore.EXPECT().CleanTailnetCoordinators(gomock.Any()).AnyTimes().Return(nil) mStore.EXPECT().CleanTailnetLostPeers(gomock.Any()).AnyTimes().Return(nil) mStore.EXPECT().CleanTailnetTunnels(gomock.Any()).AnyTimes().Return(nil) - mStore.EXPECT().DeleteCoordinator(gomock.Any(), gomock.Any()).AnyTimes().Return(nil) + mStore.EXPECT().UpdateTailnetPeerStatusByCoordinator(gomock.Any(), gomock.Any()).Times(1) uut, err := tailnet.NewPGCoord(ctx, logger, ps, mStore) require.NoError(t, err) @@ -922,7 +922,75 @@ func TestPGCoordinator_NoDeleteOnClose(t *testing.T) { assert.Equal(t, 11, cnode.PreferredDERP) } -func TestPGCoordinatorPeerReconnect(t *testing.T) { +// TestPGCoordinatorDual_FailedHeartbeat tests that peers +// disconnect from a coordinator when they are unhealthy, +// are marked as LOST (not DISCONNECTED), and can reconnect to +// a new coordinator and reestablish their tunnels. +func TestPGCoordinatorDual_FailedHeartbeat(t *testing.T) { + t.Parallel() + + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + + dburl, closeFn, err := dbtestutil.Open() + require.NoError(t, err) + t.Cleanup(closeFn) + + store1, ps1, sdb1 := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithURL(dburl)) + defer sdb1.Close() + store2, ps2, sdb2 := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithURL(dburl)) + defer sdb2.Close() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + t.Cleanup(cancel) + + // We do this to avoid failing due to errors related to the + // database connection being closed.
+ logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + + // Create two coordinators, 1 for each peer. + c1, err := tailnet.NewPGCoord(ctx, logger, ps1, store1) + require.NoError(t, err) + c2, err := tailnet.NewPGCoord(ctx, logger, ps2, store2) + require.NoError(t, err) + + p1 := agpltest.NewPeer(ctx, t, c1, "peer1") + p2 := agpltest.NewPeer(ctx, t, c2, "peer2") + + // Create a binding between the two. + p1.AddTunnel(p2.ID) + + // Ensure that messages pass through. + p1.UpdateDERP(1) + p2.UpdateDERP(2) + p1.AssertEventuallyHasDERP(p2.ID, 2) + p2.AssertEventuallyHasDERP(p1.ID, 1) + + // Close the underlying database connection to induce + // a heartbeat failure scenario and assert that + // we eventually disconnect from the coordinator. + err = sdb1.Close() + require.NoError(t, err) + p1.AssertEventuallyResponsesClosed() + p2.AssertEventuallyLost(p1.ID) + + // Connect peer1 to coordinator2. + p1.ConnectToCoordinator(ctx, c2) + // Reestablish binding. + p1.AddTunnel(p2.ID) + // Ensure messages still flow back and forth. + p1.AssertEventuallyHasDERP(p2.ID, 2) + p1.UpdateDERP(3) + p2.UpdateDERP(4) + p2.AssertEventuallyHasDERP(p1.ID, 3) + p1.AssertEventuallyHasDERP(p2.ID, 4) + // Make sure peer2 never got an update about peer1 disconnecting. + p2.AssertNeverUpdateKind(p1.ID, proto.CoordinateResponse_PeerUpdate_DISCONNECTED) + +} + +func TestPGCoordinatorDual_PeerReconnect(t *testing.T) { t.Parallel() if !dbtestutil.WillUsePostgres() { @@ -940,8 +1008,8 @@ func TestPGCoordinatorPeerReconnect(t *testing.T) { c2, err := tailnet.NewPGCoord(ctx, logger, ps, store) require.NoError(t, err) - p1 := test.NewPeer(ctx, t, c1, "peer1") - p2 := test.NewPeer(ctx, t, c2, "peer2") + p1 := agpltest.NewPeer(ctx, t, c1, "peer1") + p2 := agpltest.NewPeer(ctx, t, c2, "peer2") // Create a binding between the two. 
p1.AddTunnel(p2.ID) @@ -957,6 +1025,7 @@ func TestPGCoordinatorPeerReconnect(t *testing.T) { err = c1.Close() require.NoError(t, err) p1.AssertEventuallyResponsesClosed() + p2.AssertEventuallyLost(p1.ID) // Connect peer1 to coordinator2. p1.ConnectToCoordinator(ctx, c2) From 1c905c716527819cce2cdfe4dd6113836fd03255 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Sun, 11 Aug 2024 23:41:20 +0000 Subject: [PATCH 05/13] delete stale comments --- enterprise/tailnet/pgcoord.go | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index cbccd05366edd..3b63d12735b25 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -144,9 +144,6 @@ func newPGCoordInternal( // signals when first heartbeat has been sent, so it's safe to start binding. fHB := make(chan struct{}) - // we need to arrange for the querier to stop _after_ the tunneler and binder, since we delete - // the coordinator when the querier stops (via the heartbeats). If the tunneler and binder are - // still running, they could run afoul of foreign key constraints. querierCtx, querierCancel := context.WithCancel(dbauthz.As(context.Background(), pgCoordSubject)) c := &pgCoord{ ctx: ctx, @@ -168,8 +165,9 @@ func newPGCoordInternal( } go func() { // when the main context is canceled, or the coordinator closed, the binder, tunneler, and - // handshaker always eventually stop. Once they stop it's safe to cancel the querier context, which - // has the effect of deleting the coordinator from the database and ceasing heartbeats. + // handshaker always eventually stop. When the + // binder stops it updates all the peers handled + // by this coordinator to LOST. 
c.binder.workerWG.Wait() c.tunneler.workerWG.Wait() c.handshaker.workerWG.Wait() From fadde3403934c24a578b5ac99329f175d5d94ffa Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Sun, 11 Aug 2024 23:43:04 +0000 Subject: [PATCH 06/13] lint --- coderd/database/dbtestutil/db.go | 4 ++-- enterprise/tailnet/pgcoord_test.go | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/coderd/database/dbtestutil/db.go b/coderd/database/dbtestutil/db.go index 98b4654760af4..327d880f69648 100644 --- a/coderd/database/dbtestutil/db.go +++ b/coderd/database/dbtestutil/db.go @@ -60,9 +60,9 @@ func WithLogger(logger slog.Logger) Option { } } -func WithURL(url string) Option { +func WithURL(u string) Option { return func(o *options) { - o.url = url + o.url = u } } diff --git a/enterprise/tailnet/pgcoord_test.go b/enterprise/tailnet/pgcoord_test.go index 54d8ae77b21d7..e22d003eb5e59 100644 --- a/enterprise/tailnet/pgcoord_test.go +++ b/enterprise/tailnet/pgcoord_test.go @@ -987,7 +987,6 @@ func TestPGCoordinatorDual_FailedHeartbeat(t *testing.T) { p1.AssertEventuallyHasDERP(p2.ID, 4) // Make sure peer2 never got an update about peer1 disconnecting. 
p2.AssertNeverUpdateKind(p1.ID, proto.CoordinateResponse_PeerUpdate_DISCONNECTED) - } func TestPGCoordinatorDual_PeerReconnect(t *testing.T) { From 159b509b0890886d29c9480a0894f900a84dbffa Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Mon, 12 Aug 2024 00:03:18 +0000 Subject: [PATCH 07/13] extra checks --- enterprise/tailnet/pgcoord_test.go | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/enterprise/tailnet/pgcoord_test.go b/enterprise/tailnet/pgcoord_test.go index e22d003eb5e59..7ed0b7a7d18cd 100644 --- a/enterprise/tailnet/pgcoord_test.go +++ b/enterprise/tailnet/pgcoord_test.go @@ -908,6 +908,8 @@ func TestPGCoordinator_NoDeleteOnClose(t *testing.T) { err = coordinator.Close() require.NoError(t, err) + assertEventuallyLost(ctx, t, store, agent.id) + assertEventuallyLost(ctx, t, store, client.id) coordinator2, err := tailnet.NewPGCoord(ctx, logger, ps, store) require.NoError(t, err) @@ -974,6 +976,10 @@ func TestPGCoordinatorDual_FailedHeartbeat(t *testing.T) { require.NoError(t, err) p1.AssertEventuallyResponsesClosed() p2.AssertEventuallyLost(p1.ID) + // This basically checks that peer2 had no update + // performed on their status since we are connected + // to coordinator2. + assertEventuallyStatus(ctx, t, store2, p2.ID, database.TailnetStatusOk) // Connect peer1 to coordinator2. p1.ConnectToCoordinator(ctx, c2) @@ -1025,6 +1031,10 @@ func TestPGCoordinatorDual_PeerReconnect(t *testing.T) { require.NoError(t, err) p1.AssertEventuallyResponsesClosed() p2.AssertEventuallyLost(p1.ID) + // This basically checks that peer2 had no update + // performed on their status since we are connected + // to coordinator2. + assertEventuallyStatus(ctx, t, store, p2.ID, database.TailnetStatusOk) // Connect peer1 to coordinator2. 
p1.ConnectToCoordinator(ctx, c2) @@ -1178,7 +1188,7 @@ func assertNeverHasDERPs(ctx context.Context, t *testing.T, c *testConn, expecte } } -func assertEventuallyLost(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) { +func assertEventuallyStatus(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID, status database.TailnetStatus) { t.Helper() assert.Eventually(t, func() bool { peers, err := store.GetTailnetPeers(ctx, agentID) @@ -1189,7 +1199,7 @@ func assertEventuallyLost(ctx context.Context, t *testing.T, store database.Stor t.Fatal(err) } for _, peer := range peers { - if peer.Status == database.TailnetStatusOk { + if peer.Status != status { return false } } @@ -1197,6 +1207,11 @@ func assertEventuallyLost(ctx context.Context, t *testing.T, store database.Stor }, testutil.WaitShort, testutil.IntervalFast) } +func assertEventuallyLost(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) { + t.Helper() + assertEventuallyStatus(ctx, t, store, agentID, database.TailnetStatusLost) +} + func assertEventuallyNoClientsForAgent(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) { t.Helper() assert.Eventually(t, func() bool { From 04472e324adf0c53f44845d62f4db02ab89810ec Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Mon, 12 Aug 2024 18:35:13 +0000 Subject: [PATCH 08/13] pr comments --- enterprise/tailnet/pgcoord_internal_test.go | 6 +----- tailnet/test/peer.go | 6 +++--- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/enterprise/tailnet/pgcoord_internal_test.go b/enterprise/tailnet/pgcoord_internal_test.go index 745479d3160a0..dec6d95e26178 100644 --- a/enterprise/tailnet/pgcoord_internal_test.go +++ b/enterprise/tailnet/pgcoord_internal_test.go @@ -403,15 +403,11 @@ func TestPGCoordinatorUnhealthy(t *testing.T) { mStore.EXPECT().CleanTailnetCoordinators(gomock.Any()).AnyTimes().Return(nil) 
mStore.EXPECT().CleanTailnetLostPeers(gomock.Any()).AnyTimes().Return(nil) mStore.EXPECT().CleanTailnetTunnels(gomock.Any()).AnyTimes().Return(nil) + mStore.EXPECT().UpdateTailnetPeerStatusByCoordinator(gomock.Any(), gomock.Any()) coordinator, err := newPGCoordInternal(ctx, logger, ps, mStore, mClock) require.NoError(t, err) - mStore.EXPECT().UpdateTailnetPeerStatusByCoordinator(gomock.Any(), database.UpdateTailnetPeerStatusByCoordinatorParams{ - CoordinatorID: coordinator.id, - Status: database.TailnetStatusLost, - }) - expectedPeriod := HeartbeatPeriod tfCall, err := tfTrap.Wait(ctx) require.NoError(t, err) diff --git a/tailnet/test/peer.go b/tailnet/test/peer.go index df9f9c5d58118..1b08d6886ae98 100644 --- a/tailnet/test/peer.go +++ b/tailnet/test/peer.go @@ -6,7 +6,6 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "golang.org/x/xerrors" "github.com/coder/coder/v2/tailnet" @@ -199,7 +198,7 @@ func (p *Peer) AssertNeverUpdateKind(peer uuid.UUID, kind proto.CoordinateRespon p.t.Helper() updates, ok := p.peerUpdates[peer] - require.True(p.t, ok, "expected updates for peer %s", peer) + assert.True(p.t, ok, "expected updates for peer %s", peer) for _, update := range updates { assert.NotEqual(p.t, kind, update.Kind, update) @@ -224,6 +223,8 @@ func (p *Peer) handleOneResp() error { if err != nil { return err } + p.peerUpdates[id] = append(p.peerUpdates[id], update) + switch update.Kind { case proto.CoordinateResponse_PeerUpdate_NODE, proto.CoordinateResponse_PeerUpdate_LOST: peer := p.peers[id] @@ -239,7 +240,6 @@ func (p *Peer) handleOneResp() error { default: return xerrors.Errorf("unhandled update kind %s", update.Kind) } - p.peerUpdates[id] = append(p.peerUpdates[id], update) } } return nil From eb52aa21423a2b75ddcfb5a1cee6f1ccb5f3c6ce Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Tue, 13 Aug 2024 14:36:24 +0000 Subject: [PATCH 09/13] wait for workers to exit --- enterprise/tailnet/pgcoord.go | 
29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 3b63d12735b25..2d0b8a0a11f21 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -164,11 +164,7 @@ func newPGCoordInternal( closed: make(chan struct{}), } go func() { - // when the main context is canceled, or the coordinator closed, the binder, tunneler, and - // handshaker always eventually stop. When the - // binder stops it updates all the peers handled - // by this coordinator to LOST. - c.binder.workerWG.Wait() + c.binder.wait() c.tunneler.workerWG.Wait() c.handshaker.workerWG.Wait() querierCancel() @@ -518,14 +514,7 @@ func (b *binder) handleBindings() { for { select { case <-b.ctx.Done(): - b.logger.Debug(b.ctx, "binder exiting, updating peers to lost", slog.Error(b.ctx.Err())) - err := b.store.UpdateTailnetPeerStatusByCoordinator(context.Background(), database.UpdateTailnetPeerStatusByCoordinatorParams{ - CoordinatorID: b.coordinatorID, - Status: database.TailnetStatusLost, - }) - if err != nil { - b.logger.Error(b.ctx, "update peer status to lost", slog.Error(err)) - } + b.logger.Debug(b.ctx, "binder exiting") return case bnd := <-b.bindings: b.storeBinding(bnd) @@ -637,6 +626,20 @@ func (b *binder) retrieveBinding(bk bKey) binding { return bnd } +func (b *binder) wait() { + b.workerWG.Wait() + + b.logger.Debug(b.ctx, "binder exiting, updating peers to lost", slog.Error(b.ctx.Err())) + + err := b.store.UpdateTailnetPeerStatusByCoordinator(context.Background(), database.UpdateTailnetPeerStatusByCoordinatorParams{ + CoordinatorID: b.coordinatorID, + Status: database.TailnetStatusLost, + }) + if err != nil { + b.logger.Error(b.ctx, "update peer status to lost", slog.Error(err)) + } +} + // mapper tracks data sent to a peer, and sends updates based on changes read from the database. 
type mapper struct { ctx context.Context From 3173e4492a21d481cf8ebb5efca9984f78218247 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Tue, 13 Aug 2024 14:44:27 +0000 Subject: [PATCH 10/13] simplify ctx management --- enterprise/tailnet/pgcoord.go | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 2d0b8a0a11f21..6ebe1ad8e04c6 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -144,7 +144,6 @@ func newPGCoordInternal( // signals when first heartbeat has been sent, so it's safe to start binding. fHB := make(chan struct{}) - querierCtx, querierCancel := context.WithCancel(dbauthz.As(context.Background(), pgCoordSubject)) c := &pgCoord{ ctx: ctx, cancel: cancel, @@ -160,15 +159,9 @@ func newPGCoordInternal( handshaker: newHandshaker(ctx, logger, id, ps, rfhCh, fHB), handshakerCh: rfhCh, id: id, - querier: newQuerier(querierCtx, logger, id, ps, store, id, cCh, ccCh, numQuerierWorkers, fHB, clk), + querier: newQuerier(ctx, logger, id, ps, store, id, cCh, ccCh, numQuerierWorkers, fHB, clk), closed: make(chan struct{}), } - go func() { - c.binder.wait() - c.tunneler.workerWG.Wait() - c.handshaker.workerWG.Wait() - querierCancel() - }() logger.Info(ctx, "starting coordinator") return c, nil } @@ -233,6 +226,9 @@ func (c *pgCoord) Close() error { c.cancel() c.closeOnce.Do(func() { close(c.closed) }) c.querier.wait() + c.binder.wait() + c.tunneler.workerWG.Wait() + c.handshaker.workerWG.Wait() return nil } From 4754554056d414ddc43cbcbf24e3b690a8cf0005 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 14 Aug 2024 15:54:04 +0000 Subject: [PATCH 11/13] delete peers once --- enterprise/tailnet/pgcoord.go | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 6ebe1ad8e04c6..a8c68341976ce 100644 --- a/enterprise/tailnet/pgcoord.go +++ 
b/enterprise/tailnet/pgcoord.go @@ -475,6 +475,7 @@ type binder struct { workQ *workQ[bKey] workerWG sync.WaitGroup + close chan struct{} } func newBinder(ctx context.Context, @@ -503,6 +504,26 @@ func newBinder(ctx context.Context, go b.worker() } }() + + go func() { + defer close(b.close) + <-ctx.Done() + b.logger.Debug(b.ctx, "binder exiting, waiting for workers") + + b.workerWG.Wait() + + b.logger.Debug(b.ctx, "updating peers to lost") + + ctx, cancel := context.WithTimeout(ctx, time.Second*15) + defer cancel() + err := b.store.UpdateTailnetPeerStatusByCoordinator(ctx, database.UpdateTailnetPeerStatusByCoordinatorParams{ + CoordinatorID: b.coordinatorID, + Status: database.TailnetStatusLost, + }) + if err != nil { + b.logger.Error(b.ctx, "update peer status to lost", slog.Error(err)) + } + }() return b } @@ -623,17 +644,7 @@ func (b *binder) retrieveBinding(bk bKey) binding { } func (b *binder) wait() { - b.workerWG.Wait() - - b.logger.Debug(b.ctx, "binder exiting, updating peers to lost", slog.Error(b.ctx.Err())) - - err := b.store.UpdateTailnetPeerStatusByCoordinator(context.Background(), database.UpdateTailnetPeerStatusByCoordinatorParams{ - CoordinatorID: b.coordinatorID, - Status: database.TailnetStatusLost, - }) - if err != nil { - b.logger.Error(b.ctx, "update peer status to lost", slog.Error(err)) - } + <-b.close } // mapper tracks data sent to a peer, and sends updates based on changes read from the database. 
From 09133928b143a5c09c5c630ca04aca9e89869307 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 14 Aug 2024 16:30:30 +0000 Subject: [PATCH 12/13] instantiate channel --- enterprise/tailnet/pgcoord.go | 1 + 1 file changed, 1 insertion(+) diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index a8c68341976ce..f30e7700db3f6 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -493,6 +493,7 @@ func newBinder(ctx context.Context, bindings: bindings, latest: make(map[bKey]binding), workQ: newWorkQ[bKey](ctx), + close: make(chan struct{}), } go b.handleBindings() // add to the waitgroup immediately to avoid any races waiting for it before From 8a65a2f3465314d9b44ec98242f0c7d59a7fd9f8 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 14 Aug 2024 17:02:50 +0000 Subject: [PATCH 13/13] wrong ctx --- enterprise/tailnet/pgcoord.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index f30e7700db3f6..be4722a02f317 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -508,14 +508,14 @@ func newBinder(ctx context.Context, go func() { defer close(b.close) - <-ctx.Done() + <-b.ctx.Done() b.logger.Debug(b.ctx, "binder exiting, waiting for workers") b.workerWG.Wait() b.logger.Debug(b.ctx, "updating peers to lost") - ctx, cancel := context.WithTimeout(ctx, time.Second*15) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) defer cancel() err := b.store.UpdateTailnetPeerStatusByCoordinator(ctx, database.UpdateTailnetPeerStatusByCoordinatorParams{ CoordinatorID: b.coordinatorID,