diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go
index 2f3567455fed8..92d495e040d86 100644
--- a/coderd/database/dbauthz/dbauthz.go
+++ b/coderd/database/dbauthz/dbauthz.go
@@ -3160,6 +3160,13 @@ func (q *querier) UpdateReplica(ctx context.Context, arg database.UpdateReplicaP
 	return q.db.UpdateReplica(ctx, arg)
 }
 
+func (q *querier) UpdateTailnetPeerStatusByCoordinator(ctx context.Context, arg database.UpdateTailnetPeerStatusByCoordinatorParams) error {
+	if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil {
+		return err
+	}
+	return q.db.UpdateTailnetPeerStatusByCoordinator(ctx, arg)
+}
+
 func (q *querier) UpdateTemplateACLByID(ctx context.Context, arg database.UpdateTemplateACLByIDParams) error {
 	fetch := func(ctx context.Context, arg database.UpdateTemplateACLByIDParams) (database.Template, error) {
 		return q.db.GetTemplateByID(ctx, arg.ID)
diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go
index 95d1bbcdb7f18..8161280c913e7 100644
--- a/coderd/database/dbauthz/dbauthz_test.go
+++ b/coderd/database/dbauthz/dbauthz_test.go
@@ -2053,6 +2053,11 @@ func (s *MethodTestSuite) TestTailnetFunctions() {
 			Asserts(rbac.ResourceTailnetCoordinator, policy.ActionCreate).
 			Errors(dbmem.ErrUnimplemented)
 	}))
+	s.Run("UpdateTailnetPeerStatusByCoordinator", s.Subtest(func(_ database.Store, check *expects) {
+		check.Args(database.UpdateTailnetPeerStatusByCoordinatorParams{}).
+			Asserts(rbac.ResourceTailnetCoordinator, policy.ActionUpdate).
+			Errors(dbmem.ErrUnimplemented)
+	}))
 }
 
 func (s *MethodTestSuite) TestDBCrypt() {
diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go
index 5768379535668..b7bff20bf014c 100644
--- a/coderd/database/dbmem/dbmem.go
+++ b/coderd/database/dbmem/dbmem.go
@@ -7759,6 +7759,10 @@ func (q *FakeQuerier) UpdateReplica(_ context.Context, arg database.UpdateReplic
 	return database.Replica{}, sql.ErrNoRows
 }
 
+func (*FakeQuerier) UpdateTailnetPeerStatusByCoordinator(context.Context, database.UpdateTailnetPeerStatusByCoordinatorParams) error {
+	return ErrUnimplemented
+}
+
 func (q *FakeQuerier) UpdateTemplateACLByID(_ context.Context, arg database.UpdateTemplateACLByIDParams) error {
 	if err := validateDatabaseType(arg); err != nil {
 		return err
diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go
index 7b6cdb147dcf9..b579ae70d10e1 100644
--- a/coderd/database/dbmetrics/dbmetrics.go
+++ b/coderd/database/dbmetrics/dbmetrics.go
@@ -2041,6 +2041,13 @@ func (m metricsStore) UpdateReplica(ctx context.Context, arg database.UpdateRepl
 	return replica, err
 }
 
+func (m metricsStore) UpdateTailnetPeerStatusByCoordinator(ctx context.Context, arg database.UpdateTailnetPeerStatusByCoordinatorParams) error {
+	start := time.Now()
+	r0 := m.s.UpdateTailnetPeerStatusByCoordinator(ctx, arg)
+	m.queryLatencies.WithLabelValues("UpdateTailnetPeerStatusByCoordinator").Observe(time.Since(start).Seconds())
+	return r0
+}
+
 func (m metricsStore) UpdateTemplateACLByID(ctx context.Context, arg database.UpdateTemplateACLByIDParams) error {
 	start := time.Now()
 	err := m.s.UpdateTemplateACLByID(ctx, arg)
diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go
index bda8186a26a4f..a5c44d11c714f 100644
--- a/coderd/database/dbmock/dbmock.go
+++ b/coderd/database/dbmock/dbmock.go
@@ -4307,6 +4307,20 @@ func (mr *MockStoreMockRecorder) UpdateReplica(arg0, arg1 any) *gomock.Call {
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateReplica", reflect.TypeOf((*MockStore)(nil).UpdateReplica), arg0, arg1)
 }
 
+// UpdateTailnetPeerStatusByCoordinator mocks base method.
+func (m *MockStore) UpdateTailnetPeerStatusByCoordinator(arg0 context.Context, arg1 database.UpdateTailnetPeerStatusByCoordinatorParams) error {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "UpdateTailnetPeerStatusByCoordinator", arg0, arg1)
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// UpdateTailnetPeerStatusByCoordinator indicates an expected call of UpdateTailnetPeerStatusByCoordinator.
+func (mr *MockStoreMockRecorder) UpdateTailnetPeerStatusByCoordinator(arg0, arg1 any) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTailnetPeerStatusByCoordinator", reflect.TypeOf((*MockStore)(nil).UpdateTailnetPeerStatusByCoordinator), arg0, arg1)
+}
+
 // UpdateTemplateACLByID mocks base method.
 func (m *MockStore) UpdateTemplateACLByID(arg0 context.Context, arg1 database.UpdateTemplateACLByIDParams) error {
 	m.ctrl.T.Helper()
diff --git a/coderd/database/dbtestutil/db.go b/coderd/database/dbtestutil/db.go
index 16eb3393ca346..327d880f69648 100644
--- a/coderd/database/dbtestutil/db.go
+++ b/coderd/database/dbtestutil/db.go
@@ -35,6 +35,7 @@ type options struct {
 	dumpOnFailure bool
 	returnSQLDB   func(*sql.DB)
 	logger        slog.Logger
+	url           string
 }
 
 type Option func(*options)
@@ -59,6 +60,12 @@ func WithLogger(logger slog.Logger) Option {
 	}
 }
 
+func WithURL(u string) Option {
+	return func(o *options) {
+		o.url = u
+	}
+}
+
 func withReturnSQLDB(f func(*sql.DB)) Option {
 	return func(o *options) {
 		o.returnSQLDB = f
@@ -92,6 +99,9 @@ func NewDB(t testing.TB, opts ...Option) (database.Store, pubsub.Pubsub) {
 	ps := pubsub.NewInMemory()
 	if WillUsePostgres() {
 		connectionURL := os.Getenv("CODER_PG_CONNECTION_URL")
+		if connectionURL == "" && o.url != "" {
+			connectionURL = o.url
+		}
 		if connectionURL == "" {
 			var (
 				err error
diff --git a/coderd/database/querier.go b/coderd/database/querier.go
index 2d45e154b532d..77b64eda3a585 100644
--- a/coderd/database/querier.go
+++ b/coderd/database/querier.go
@@ -413,6 +413,7 @@ type sqlcQuerier interface {
 	UpdateProvisionerJobWithCancelByID(ctx context.Context, arg UpdateProvisionerJobWithCancelByIDParams) error
 	UpdateProvisionerJobWithCompleteByID(ctx context.Context, arg UpdateProvisionerJobWithCompleteByIDParams) error
 	UpdateReplica(ctx context.Context, arg UpdateReplicaParams) (Replica, error)
+	UpdateTailnetPeerStatusByCoordinator(ctx context.Context, arg UpdateTailnetPeerStatusByCoordinatorParams) error
 	UpdateTemplateACLByID(ctx context.Context, arg UpdateTemplateACLByIDParams) error
 	UpdateTemplateAccessControlByID(ctx context.Context, arg UpdateTemplateAccessControlByIDParams) error
 	UpdateTemplateActiveVersionByID(ctx context.Context, arg UpdateTemplateActiveVersionByIDParams) error
diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go
index d8a6e3a1abb03..ae7fdd1f69948 100644
--- a/coderd/database/queries.sql.go
+++ b/coderd/database/queries.sql.go
@@ -7354,6 +7354,25 @@ func (q *sqlQuerier) GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUI
 	return items, nil
 }
 
+const updateTailnetPeerStatusByCoordinator = `-- name: UpdateTailnetPeerStatusByCoordinator :exec
+UPDATE
+	tailnet_peers
+SET
+	status = $2
+WHERE
+	coordinator_id = $1
+`
+
+type UpdateTailnetPeerStatusByCoordinatorParams struct {
+	CoordinatorID uuid.UUID     `db:"coordinator_id" json:"coordinator_id"`
+	Status        TailnetStatus `db:"status" json:"status"`
+}
+
+func (q *sqlQuerier) UpdateTailnetPeerStatusByCoordinator(ctx context.Context, arg UpdateTailnetPeerStatusByCoordinatorParams) error {
+	_, err := q.db.ExecContext(ctx, updateTailnetPeerStatusByCoordinator, arg.CoordinatorID, arg.Status)
+	return err
+}
+
 const upsertTailnetAgent = `-- name: UpsertTailnetAgent :one
 INSERT INTO
 	tailnet_agents (
diff --git a/coderd/database/queries/tailnet.sql b/coderd/database/queries/tailnet.sql
index 767b966cbbce3..07936e277bc52 100644
--- a/coderd/database/queries/tailnet.sql
+++ b/coderd/database/queries/tailnet.sql
@@ -149,6 +149,14 @@ DO UPDATE SET
 	updated_at = now() at time zone 'utc'
 RETURNING *;
 
+-- name: UpdateTailnetPeerStatusByCoordinator :exec
+UPDATE
+	tailnet_peers
+SET
+	status = $2
+WHERE
+	coordinator_id = $1;
+
 -- name: DeleteTailnetPeer :one
 DELETE
 FROM tailnet_peers
diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go
index 1546f0ac3087b..be4722a02f317 100644
--- a/enterprise/tailnet/pgcoord.go
+++ b/enterprise/tailnet/pgcoord.go
@@ -144,10 +144,6 @@ func newPGCoordInternal(
 	// signals when first heartbeat has been sent, so it's safe to start binding.
 	fHB := make(chan struct{})
 
-	// we need to arrange for the querier to stop _after_ the tunneler and binder, since we delete
-	// the coordinator when the querier stops (via the heartbeats). If the tunneler and binder are
-	// still running, they could run afoul of foreign key constraints.
-	querierCtx, querierCancel := context.WithCancel(dbauthz.As(context.Background(), pgCoordSubject))
 	c := &pgCoord{
 		ctx:    ctx,
 		cancel: cancel,
@@ -163,18 +159,9 @@ func newPGCoordInternal(
 		handshaker:   newHandshaker(ctx, logger, id, ps, rfhCh, fHB),
 		handshakerCh: rfhCh,
 		id:           id,
-		querier:      newQuerier(querierCtx, logger, id, ps, store, id, cCh, ccCh, numQuerierWorkers, fHB, clk),
+		querier:      newQuerier(ctx, logger, id, ps, store, id, cCh, ccCh, numQuerierWorkers, fHB, clk),
 		closed:       make(chan struct{}),
 	}
-	go func() {
-		// when the main context is canceled, or the coordinator closed, the binder, tunneler, and
-		// handshaker always eventually stop. Once they stop it's safe to cancel the querier context, which
-		// has the effect of deleting the coordinator from the database and ceasing heartbeats.
-		c.binder.workerWG.Wait()
-		c.tunneler.workerWG.Wait()
-		c.handshaker.workerWG.Wait()
-		querierCancel()
-	}()
 	logger.Info(ctx, "starting coordinator")
 	return c, nil
 }
@@ -239,6 +226,9 @@ func (c *pgCoord) Close() error {
 	c.cancel()
 	c.closeOnce.Do(func() { close(c.closed) })
 	c.querier.wait()
+	c.binder.wait()
+	c.tunneler.workerWG.Wait()
+	c.handshaker.workerWG.Wait()
 	return nil
 }
 
@@ -485,6 +475,7 @@ type binder struct {
 	workQ    *workQ[bKey]
 
 	workerWG sync.WaitGroup
+	close    chan struct{}
 }
 
 func newBinder(ctx context.Context,
@@ -502,6 +493,7 @@ func newBinder(ctx context.Context,
 		bindings: bindings,
 		latest:   make(map[bKey]binding),
 		workQ:    newWorkQ[bKey](ctx),
+		close:    make(chan struct{}),
 	}
 	go b.handleBindings()
 	// add to the waitgroup immediately to avoid any races waiting for it before
@@ -513,6 +505,26 @@ func newBinder(ctx context.Context,
 			go b.worker()
 		}
 	}()
+
+	go func() {
+		defer close(b.close)
+		<-b.ctx.Done()
+		b.logger.Debug(b.ctx, "binder exiting, waiting for workers")
+
+		b.workerWG.Wait()
+
+		b.logger.Debug(b.ctx, "updating peers to lost")
+
+		ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
+		defer cancel()
+		err := b.store.UpdateTailnetPeerStatusByCoordinator(ctx, database.UpdateTailnetPeerStatusByCoordinatorParams{
+			CoordinatorID: b.coordinatorID,
+			Status:        database.TailnetStatusLost,
+		})
+		if err != nil {
+			b.logger.Error(b.ctx, "update peer status to lost", slog.Error(err))
+		}
+	}()
 	return b
 }
 
@@ -520,7 +532,7 @@ func (b *binder) handleBindings() {
 	for {
 		select {
 		case <-b.ctx.Done():
-			b.logger.Debug(b.ctx, "binder exiting", slog.Error(b.ctx.Err()))
+			b.logger.Debug(b.ctx, "binder exiting")
 			return
 		case bnd := <-b.bindings:
 			b.storeBinding(bnd)
@@ -632,6 +644,10 @@ func (b *binder) retrieveBinding(bk bKey) binding {
 	return bnd
 }
 
+func (b *binder) wait() {
+	<-b.close
+}
+
 // mapper tracks data sent to a peer, and sends updates based on changes read from the database.
 type mapper struct {
 	ctx context.Context
@@ -1646,7 +1662,6 @@ func (h *heartbeats) sendBeats() {
 	// send an initial heartbeat so that other coordinators can start using our bindings right away.
 	h.sendBeat()
 	close(h.firstHeartbeat) // signal binder it can start writing
-	defer h.sendDelete()
 	tkr := h.clock.TickerFunc(h.ctx, HeartbeatPeriod, func() error {
 		h.sendBeat()
 		return nil
@@ -1677,17 +1692,6 @@
 	h.failedHeartbeats = 0
 }
 
-func (h *heartbeats) sendDelete() {
-	// here we don't want to use the main context, since it will have been canceled
-	ctx := dbauthz.As(context.Background(), pgCoordSubject)
-	err := h.store.DeleteCoordinator(ctx, h.self)
-	if err != nil {
-		h.logger.Error(h.ctx, "failed to send coordinator delete", slog.Error(err))
-		return
-	}
-	h.logger.Debug(h.ctx, "deleted coordinator")
-}
-
 func (h *heartbeats) cleanupLoop() {
 	defer h.wg.Done()
 	h.cleanup()
diff --git a/enterprise/tailnet/pgcoord_internal_test.go b/enterprise/tailnet/pgcoord_internal_test.go
index 253487d28d196..dec6d95e26178 100644
--- a/enterprise/tailnet/pgcoord_internal_test.go
+++ b/enterprise/tailnet/pgcoord_internal_test.go
@@ -396,10 +396,6 @@ func TestPGCoordinatorUnhealthy(t *testing.T) {
 		UpsertTailnetCoordinator(gomock.Any(), gomock.Any()).
 		Times(3).
 		Return(database.TailnetCoordinator{}, xerrors.New("badness"))
-	mStore.EXPECT().
-		DeleteCoordinator(gomock.Any(), gomock.Any()).
-		Times(1).
-		Return(nil)
 	// But, in particular we DO NOT want the coordinator to call DeleteTailnetPeer, as this is
 	// unnecessary and can spam the database. c.f. https://github.com/coder/coder/issues/12923
 
@@ -407,6 +403,7 @@
 	mStore.EXPECT().CleanTailnetCoordinators(gomock.Any()).AnyTimes().Return(nil)
 	mStore.EXPECT().CleanTailnetLostPeers(gomock.Any()).AnyTimes().Return(nil)
 	mStore.EXPECT().CleanTailnetTunnels(gomock.Any()).AnyTimes().Return(nil)
+	mStore.EXPECT().UpdateTailnetPeerStatusByCoordinator(gomock.Any(), gomock.Any())
 
 	coordinator, err := newPGCoordInternal(ctx, logger, ps, mStore, mClock)
 	require.NoError(t, err)
diff --git a/enterprise/tailnet/pgcoord_test.go b/enterprise/tailnet/pgcoord_test.go
index 2232e3941eb0c..7ed0b7a7d18cd 100644
--- a/enterprise/tailnet/pgcoord_test.go
+++ b/enterprise/tailnet/pgcoord_test.go
@@ -480,7 +480,7 @@ func TestPGCoordinatorSingle_SendsHeartbeats(t *testing.T) {
 
 	mu := sync.Mutex{}
 	heartbeats := []time.Time{}
-	unsub, err := ps.SubscribeWithErr(tailnet.EventHeartbeats, func(_ context.Context, msg []byte, err error) {
+	unsub, err := ps.SubscribeWithErr(tailnet.EventHeartbeats, func(_ context.Context, _ []byte, err error) {
 		assert.NoError(t, err)
 		mu.Lock()
 		defer mu.Unlock()
@@ -591,10 +591,10 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) {
 	require.ErrorIs(t, err, io.EOF)
 	err = client21.recvErr(ctx, t)
 	require.ErrorIs(t, err, io.EOF)
+	assertEventuallyLost(ctx, t, store, agent2.id)
+	assertEventuallyLost(ctx, t, store, client21.id)
+	assertEventuallyLost(ctx, t, store, client22.id)
 
-	assertEventuallyNoAgents(ctx, t, store, agent2.id)
-
-	t.Logf("close coord1")
 	err = coord1.Close()
 	require.NoError(t, err)
 	// this closes agent1, client12, client11
@@ -604,6 +604,9 @@
 	require.ErrorIs(t, err, io.EOF)
 	err = client11.recvErr(ctx, t)
 	require.ErrorIs(t, err, io.EOF)
+	assertEventuallyLost(ctx, t, store, agent1.id)
+	assertEventuallyLost(ctx, t, store, client11.id)
+	assertEventuallyLost(ctx, t, store, client12.id)
 
 	// wait for all connections to close
 	err = agent1.close()
@@ -629,10 +632,6 @@
 	err = client22.close()
 	require.NoError(t, err)
 	client22.waitForClose(ctx, t)
-
-	assertEventuallyNoAgents(ctx, t, store, agent1.id)
-	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
-	assertEventuallyNoClientsForAgent(ctx, t, store, agent2.id)
 }
 
 // TestPGCoordinator_MultiCoordinatorAgent tests when a single agent connects to multiple coordinators.
@@ -746,7 +745,7 @@ func TestPGCoordinator_Unhealthy(t *testing.T) {
 	mStore.EXPECT().DeleteTailnetPeer(gomock.Any(), gomock.Any()).
 		AnyTimes().Return(database.DeleteTailnetPeerRow{}, nil)
 	mStore.EXPECT().DeleteAllTailnetTunnels(gomock.Any(), gomock.Any()).AnyTimes().Return(nil)
-	mStore.EXPECT().DeleteCoordinator(gomock.Any(), gomock.Any()).AnyTimes().Return(nil)
+	mStore.EXPECT().UpdateTailnetPeerStatusByCoordinator(gomock.Any(), gomock.Any())
 
 	uut, err := tailnet.NewPGCoord(ctx, logger, ps, mStore)
 	require.NoError(t, err)
@@ -811,7 +810,7 @@ func TestPGCoordinator_Node_Empty(t *testing.T) {
 	mStore.EXPECT().CleanTailnetCoordinators(gomock.Any()).AnyTimes().Return(nil)
 	mStore.EXPECT().CleanTailnetLostPeers(gomock.Any()).AnyTimes().Return(nil)
 	mStore.EXPECT().CleanTailnetTunnels(gomock.Any()).AnyTimes().Return(nil)
-	mStore.EXPECT().DeleteCoordinator(gomock.Any(), gomock.Any()).AnyTimes().Return(nil)
+	mStore.EXPECT().UpdateTailnetPeerStatusByCoordinator(gomock.Any(), gomock.Any()).Times(1)
 
 	uut, err := tailnet.NewPGCoord(ctx, logger, ps, mStore)
 	require.NoError(t, err)
@@ -871,51 +870,184 @@ func TestPGCoordinator_Lost(t *testing.T) {
 	agpltest.LostTest(ctx, t, coordinator)
 }
 
-func TestPGCoordinator_DeleteOnClose(t *testing.T) {
+func TestPGCoordinator_NoDeleteOnClose(t *testing.T) {
 	t.Parallel()
-
+	if !dbtestutil.WillUsePostgres() {
+		t.Skip("test only with postgres")
+	}
+	store, ps := dbtestutil.NewDB(t)
 	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
 	defer cancel()
-	ctrl := gomock.NewController(t)
-	mStore := dbmock.NewMockStore(ctrl)
-	ps := pubsub.NewInMemory()
+	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+	coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
+	require.NoError(t, err)
+	defer coordinator.Close()
+
+	agent := newTestAgent(t, coordinator, "original")
+	defer agent.close()
+	agent.sendNode(&agpl.Node{PreferredDERP: 10})
+
+	client := newTestClient(t, coordinator, agent.id)
+	defer client.close()
+
+	// Simulate some traffic to generate
+	// a peer.
+	agentNodes := client.recvNodes(ctx, t)
+	require.Len(t, agentNodes, 1)
+	assert.Equal(t, 10, agentNodes[0].PreferredDERP)
+	client.sendNode(&agpl.Node{PreferredDERP: 11})
+
+	clientNodes := agent.recvNodes(ctx, t)
+	require.Len(t, clientNodes, 1)
+	assert.Equal(t, 11, clientNodes[0].PreferredDERP)
+
+	anode := coordinator.Node(agent.id)
+	require.NotNil(t, anode)
+	cnode := coordinator.Node(client.id)
+	require.NotNil(t, cnode)
+
+	err = coordinator.Close()
+	require.NoError(t, err)
+	assertEventuallyLost(ctx, t, store, agent.id)
+	assertEventuallyLost(ctx, t, store, client.id)
+
+	coordinator2, err := tailnet.NewPGCoord(ctx, logger, ps, store)
+	require.NoError(t, err)
+	defer coordinator2.Close()
+
+	anode = coordinator2.Node(agent.id)
+	require.NotNil(t, anode)
+	assert.Equal(t, 10, anode.PreferredDERP)
+
+	cnode = coordinator2.Node(client.id)
+	require.NotNil(t, cnode)
+	assert.Equal(t, 11, cnode.PreferredDERP)
+}
+
+// TestPGCoordinatorDual_FailedHeartbeat tests that peers
+// disconnect from a coordinator when it becomes unhealthy,
+// are marked as LOST (not DISCONNECTED), and can reconnect to
+// a new coordinator and reestablish their tunnels.
+func TestPGCoordinatorDual_FailedHeartbeat(t *testing.T) {
+	t.Parallel()
+
+	if !dbtestutil.WillUsePostgres() {
+		t.Skip("test only with postgres")
+	}
+
+	dburl, closeFn, err := dbtestutil.Open()
+	require.NoError(t, err)
+	t.Cleanup(closeFn)
+
+	store1, ps1, sdb1 := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithURL(dburl))
+	defer sdb1.Close()
+	store2, ps2, sdb2 := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithURL(dburl))
+	defer sdb2.Close()
+
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
+	t.Cleanup(cancel)
+
+	// We do this to avoid failing due to errors related to the
+	// database connection being closed.
 	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
-	upsertDone := make(chan struct{})
-	deleteCalled := make(chan struct{})
-	finishDelete := make(chan struct{})
-	mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()).
-		MinTimes(1).
-		Do(func(_ context.Context, _ uuid.UUID) { close(upsertDone) }).
-		Return(database.TailnetCoordinator{}, nil)
-	mStore.EXPECT().DeleteCoordinator(gomock.Any(), gomock.Any()).
-		Times(1).
-		Do(func(_ context.Context, _ uuid.UUID) {
-			close(deleteCalled)
-			<-finishDelete
-		}).
-		Return(nil)
+	// Create two coordinators, 1 for each peer.
+	c1, err := tailnet.NewPGCoord(ctx, logger, ps1, store1)
+	require.NoError(t, err)
+	c2, err := tailnet.NewPGCoord(ctx, logger, ps2, store2)
+	require.NoError(t, err)
 
-	// extra calls we don't particularly care about for this test
-	mStore.EXPECT().CleanTailnetCoordinators(gomock.Any()).AnyTimes().Return(nil)
-	mStore.EXPECT().CleanTailnetLostPeers(gomock.Any()).AnyTimes().Return(nil)
-	mStore.EXPECT().CleanTailnetTunnels(gomock.Any()).AnyTimes().Return(nil)
+	p1 := agpltest.NewPeer(ctx, t, c1, "peer1")
+	p2 := agpltest.NewPeer(ctx, t, c2, "peer2")
 
-	uut, err := tailnet.NewPGCoord(ctx, logger, ps, mStore)
+	// Create a binding between the two.
+	p1.AddTunnel(p2.ID)
+
+	// Ensure that messages pass through.
+	p1.UpdateDERP(1)
+	p2.UpdateDERP(2)
+	p1.AssertEventuallyHasDERP(p2.ID, 2)
+	p2.AssertEventuallyHasDERP(p1.ID, 1)
+
+	// Close the underlying database connection to induce
+	// a heartbeat failure scenario and assert that
+	// we eventually disconnect from the coordinator.
+	err = sdb1.Close()
 	require.NoError(t, err)
-	testutil.RequireRecvCtx(ctx, t, upsertDone)
-	closeErr := make(chan error, 1)
-	go func() {
-		closeErr <- uut.Close()
-	}()
-	select {
-	case <-closeErr:
-		t.Fatal("close returned before DeleteCoordinator called")
-	case <-deleteCalled:
-		close(finishDelete)
-		err := testutil.RequireRecvCtx(ctx, t, closeErr)
-		require.NoError(t, err)
+	p1.AssertEventuallyResponsesClosed()
+	p2.AssertEventuallyLost(p1.ID)
+	// This basically checks that peer2 had no update
+	// performed on their status since we are connected
+	// to coordinator2.
+	assertEventuallyStatus(ctx, t, store2, p2.ID, database.TailnetStatusOk)
+
+	// Connect peer1 to coordinator2.
+	p1.ConnectToCoordinator(ctx, c2)
+	// Reestablish binding.
+	p1.AddTunnel(p2.ID)
+	// Ensure messages still flow back and forth.
+	p1.AssertEventuallyHasDERP(p2.ID, 2)
+	p1.UpdateDERP(3)
+	p2.UpdateDERP(4)
+	p2.AssertEventuallyHasDERP(p1.ID, 3)
+	p1.AssertEventuallyHasDERP(p2.ID, 4)
+	// Make sure peer2 never got an update about peer1 disconnecting.
+	p2.AssertNeverUpdateKind(p1.ID, proto.CoordinateResponse_PeerUpdate_DISCONNECTED)
+}
+
+func TestPGCoordinatorDual_PeerReconnect(t *testing.T) {
+	t.Parallel()
+
+	if !dbtestutil.WillUsePostgres() {
+		t.Skip("test only with postgres")
 	}
+
+	store, ps := dbtestutil.NewDB(t)
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
+	defer cancel()
+	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+
+	// Create two coordinators, 1 for each peer.
+	c1, err := tailnet.NewPGCoord(ctx, logger, ps, store)
+	require.NoError(t, err)
+	c2, err := tailnet.NewPGCoord(ctx, logger, ps, store)
+	require.NoError(t, err)
+
+	p1 := agpltest.NewPeer(ctx, t, c1, "peer1")
+	p2 := agpltest.NewPeer(ctx, t, c2, "peer2")
+
+	// Create a binding between the two.
+	p1.AddTunnel(p2.ID)
+
+	// Ensure that messages pass through.
+	p1.UpdateDERP(1)
+	p2.UpdateDERP(2)
+	p1.AssertEventuallyHasDERP(p2.ID, 2)
+	p2.AssertEventuallyHasDERP(p1.ID, 1)
+
+	// Close coordinator1. Now we will check that we
+	// never send a DISCONNECTED update.
+	err = c1.Close()
+	require.NoError(t, err)
+	p1.AssertEventuallyResponsesClosed()
+	p2.AssertEventuallyLost(p1.ID)
+	// This basically checks that peer2 had no update
+	// performed on their status since we are connected
+	// to coordinator2.
+	assertEventuallyStatus(ctx, t, store, p2.ID, database.TailnetStatusOk)
+
+	// Connect peer1 to coordinator2.
+	p1.ConnectToCoordinator(ctx, c2)
+	// Reestablish binding.
+	p1.AddTunnel(p2.ID)
+	// Ensure messages still flow back and forth.
+	p1.AssertEventuallyHasDERP(p2.ID, 2)
+	p1.UpdateDERP(3)
+	p2.UpdateDERP(4)
+	p2.AssertEventuallyHasDERP(p1.ID, 3)
+	p1.AssertEventuallyHasDERP(p2.ID, 4)
+	// Make sure peer2 never got an update about peer1 disconnecting.
+	p2.AssertNeverUpdateKind(p1.ID, proto.CoordinateResponse_PeerUpdate_DISCONNECTED)
 }
 
 type testConn struct {
@@ -1056,21 +1188,7 @@ func assertNeverHasDERPs(ctx context.Context, t *testing.T, c *testConn, expecte
 	}
 }
 
-func assertEventuallyNoAgents(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
-	t.Helper()
-	assert.Eventually(t, func() bool {
-		agents, err := store.GetTailnetPeers(ctx, agentID)
-		if xerrors.Is(err, sql.ErrNoRows) {
-			return true
-		}
-		if err != nil {
-			t.Fatal(err)
-		}
-		return len(agents) == 0
-	}, testutil.WaitShort, testutil.IntervalFast)
-}
-
-func assertEventuallyLost(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
+func assertEventuallyStatus(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID, status database.TailnetStatus) {
 	t.Helper()
 	assert.Eventually(t, func() bool {
 		peers, err := store.GetTailnetPeers(ctx, agentID)
@@ -1081,7 +1199,7 @@ func assertEventuallyLost(ctx context.Context, t *testing.T, store database.Stor
 			t.Fatal(err)
 		}
 		for _, peer := range peers {
-			if peer.Status == database.TailnetStatusOk {
+			if peer.Status != status {
 				return false
 			}
 		}
@@ -1089,6 +1207,11 @@ func assertEventuallyLost(ctx context.Context, t *testing.T, store database.Stor
 	}, testutil.WaitShort, testutil.IntervalFast)
 }
 
+func assertEventuallyLost(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
+	t.Helper()
+	assertEventuallyStatus(ctx, t, store, agentID, database.TailnetStatusLost)
+}
+
 func assertEventuallyNoClientsForAgent(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
 	t.Helper()
 	assert.Eventually(t, func() bool {
diff --git a/tailnet/test/peer.go b/tailnet/test/peer.go
index 791c3b0e9176d..1b08d6886ae98 100644
--- a/tailnet/test/peer.go
+++ b/tailnet/test/peer.go
@@ -19,18 +19,24 @@ type PeerStatus struct {
 }
 
 type Peer struct {
-	ctx    context.Context
-	cancel context.CancelFunc
-	t      testing.TB
-	ID     uuid.UUID
-	name   string
-	resps  <-chan *proto.CoordinateResponse
-	reqs   chan<- *proto.CoordinateRequest
-	peers  map[uuid.UUID]PeerStatus
+	ctx         context.Context
+	cancel      context.CancelFunc
+	t           testing.TB
+	ID          uuid.UUID
+	name        string
+	resps       <-chan *proto.CoordinateResponse
+	reqs        chan<- *proto.CoordinateRequest
+	peers       map[uuid.UUID]PeerStatus
+	peerUpdates map[uuid.UUID][]*proto.CoordinateResponse_PeerUpdate
 }
 
 func NewPeer(ctx context.Context, t testing.TB, coord tailnet.CoordinatorV2, name string, id ...uuid.UUID) *Peer {
-	p := &Peer{t: t, name: name, peers: make(map[uuid.UUID]PeerStatus)}
+	p := &Peer{
+		t:           t,
+		name:        name,
+		peers:       make(map[uuid.UUID]PeerStatus),
+		peerUpdates: make(map[uuid.UUID][]*proto.CoordinateResponse_PeerUpdate),
+	}
 	p.ctx, p.cancel = context.WithCancel(ctx)
 	if len(id) > 1 {
 		t.Fatal("too many")
@@ -45,6 +51,12 @@ func NewPeer(ctx context.Context, t testing.TB, coord tailnet.CoordinatorV2, nam
 	return p
 }
 
+func (p *Peer) ConnectToCoordinator(ctx context.Context, c tailnet.CoordinatorV2) {
+	p.t.Helper()
+
+	p.reqs, p.resps = c.Coordinate(ctx, p.ID, p.name, tailnet.SingleTailnetCoordinateeAuth{})
+}
+
 func (p *Peer) AddTunnel(other uuid.UUID) {
 	p.t.Helper()
 	req := &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: tailnet.UUIDToByteSlice(other)}}
@@ -180,6 +192,19 @@ func (p *Peer) AssertEventuallyGetsError(match string) {
 	}
 }
 
+// AssertNeverUpdateKind asserts that no update of the provided kind
+// has been received for the provided peer.
+func (p *Peer) AssertNeverUpdateKind(peer uuid.UUID, kind proto.CoordinateResponse_PeerUpdate_Kind) {
+	p.t.Helper()
+
+	updates, ok := p.peerUpdates[peer]
+	assert.True(p.t, ok, "expected updates for peer %s", peer)
+
+	for _, update := range updates {
+		assert.NotEqual(p.t, kind, update.Kind, update)
+	}
+}
+
 var responsesClosed = xerrors.New("responses closed")
 
 func (p *Peer) handleOneResp() error {
@@ -198,6 +223,8 @@ func (p *Peer) handleOneResp() error {
 			if err != nil {
 				return err
 			}
+			p.peerUpdates[id] = append(p.peerUpdates[id], update)
+
 			switch update.Kind {
 			case proto.CoordinateResponse_PeerUpdate_NODE, proto.CoordinateResponse_PeerUpdate_LOST:
 				peer := p.peers[id]
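Taken together, the pgcoord.go changes replace the old delete-coordinator-on-close path with a single status write: the querier now shares the coordinator's main context, and once the binder's workers drain, every peer owned by the closing coordinator is marked LOST through the new UpdateTailnetPeerStatusByCoordinator query. The sketch below restates that final write as a standalone helper so the moving parts are easier to see. It is illustrative only; the helper name markPeersLost, the fifteen-second timeout, and the free-standing package mirror the binder goroutine in this diff but are assumptions, not code from the patch.

package tailnetsketch

import (
	"context"
	"time"

	"github.com/google/uuid"

	"cdr.dev/slog"

	"github.com/coder/coder/v2/coderd/database"
)

// markPeersLost sketches the shutdown write performed by the binder goroutine
// added in this diff: after the workers drain, every peer owned by the given
// coordinator is flipped to LOST rather than deleted.
func markPeersLost(store database.Store, coordinatorID uuid.UUID, logger slog.Logger) {
	// The coordinator's own context is already canceled by the time this runs,
	// so the write uses a fresh context with a bounded timeout (15s here,
	// matching the binder goroutine in the patch).
	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
	defer cancel()

	err := store.UpdateTailnetPeerStatusByCoordinator(ctx, database.UpdateTailnetPeerStatusByCoordinatorParams{
		CoordinatorID: coordinatorID,
		Status:        database.TailnetStatusLost,
	})
	if err != nil {
		logger.Error(ctx, "update peer status to lost", slog.Error(err))
	}
}

Note that when the store is wrapped by dbauthz, the context passed in must be authorized for policy.ActionUpdate on rbac.ResourceTailnetCoordinator, which the coordinator arranges by running its store calls as pgCoordSubject.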