diff --git a/enterprise/tailnet/connio.go b/enterprise/tailnet/connio.go
index 49b04ac8cb816..9454279adcb4c 100644
--- a/enterprise/tailnet/connio.go
+++ b/enterprise/tailnet/connio.go
@@ -24,16 +24,17 @@ type connIO struct {
 	// coordCtx is the parent context, that is, the context of the Coordinator
 	coordCtx context.Context
 	// peerCtx is the context of the connection to our peer
-	peerCtx   context.Context
-	cancel    context.CancelFunc
-	logger    slog.Logger
-	requests  <-chan *proto.CoordinateRequest
-	responses chan<- *proto.CoordinateResponse
-	bindings  chan<- binding
-	tunnels   chan<- tunnel
-	auth      agpl.TunnelAuth
-	mu        sync.Mutex
-	closed    bool
+	peerCtx      context.Context
+	cancel       context.CancelFunc
+	logger       slog.Logger
+	requests     <-chan *proto.CoordinateRequest
+	responses    chan<- *proto.CoordinateResponse
+	bindings     chan<- binding
+	tunnels      chan<- tunnel
+	auth         agpl.TunnelAuth
+	mu           sync.Mutex
+	closed       bool
+	disconnected bool
 
 	name  string
 	start int64
@@ -76,20 +77,29 @@ func newConnIO(coordContext context.Context,
 
 func (c *connIO) recvLoop() {
 	defer func() {
-		// withdraw bindings & tunnels when we exit. We need to use the parent context here, since
+		// withdraw bindings & tunnels when we exit. We need to use the coordinator context here, since
 		// our own context might be canceled, but we still need to withdraw.
 		b := binding{
 			bKey: bKey(c.UniqueID()),
+			kind: proto.CoordinateResponse_PeerUpdate_LOST,
+		}
+		if c.disconnected {
+			b.kind = proto.CoordinateResponse_PeerUpdate_DISCONNECTED
 		}
 		if err := sendCtx(c.coordCtx, c.bindings, b); err != nil {
 			c.logger.Debug(c.coordCtx, "parent context expired while withdrawing bindings", slog.Error(err))
 		}
-		t := tunnel{
-			tKey:   tKey{src: c.UniqueID()},
-			active: false,
-		}
-		if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
-			c.logger.Debug(c.coordCtx, "parent context expired while withdrawing tunnels", slog.Error(err))
+		// only remove tunnels on graceful disconnect. If we remove tunnels for lost peers, then
+		// this will look like a disconnect from the peer perspective, since we query for active peers
+		// by using the tunnel as a join in the database
+		if c.disconnected {
+			t := tunnel{
+				tKey:   tKey{src: c.UniqueID()},
+				active: false,
+			}
+			if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
+				c.logger.Debug(c.coordCtx, "parent context expired while withdrawing tunnels", slog.Error(err))
+			}
 		}
 	}()
 	defer c.Close()
@@ -111,6 +121,8 @@ func (c *connIO) recvLoop() {
 	}
 }
 
+var errDisconnect = xerrors.New("graceful disconnect")
+
 func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
 	c.logger.Debug(c.peerCtx, "got request")
 	if req.UpdateSelf != nil {
@@ -118,6 +130,7 @@ func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
 		b := binding{
 			bKey: bKey(c.UniqueID()),
 			node: req.UpdateSelf.Node,
+			kind: proto.CoordinateResponse_PeerUpdate_NODE,
 		}
 		if err := sendCtx(c.coordCtx, c.bindings, b); err != nil {
 			c.logger.Debug(c.peerCtx, "failed to send binding", slog.Error(err))
@@ -169,7 +182,11 @@ func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
 			return err
 		}
 	}
-	// TODO: (spikecurtis) support Disconnect
+	if req.Disconnect != nil {
+		c.logger.Debug(c.peerCtx, "graceful disconnect")
+		c.disconnected = true
+		return errDisconnect
+	}
 	return nil
 }
 
diff --git a/enterprise/tailnet/multiagent_test.go b/enterprise/tailnet/multiagent_test.go
index 8978c59418e95..e51cab881482b 100644
--- a/enterprise/tailnet/multiagent_test.go
+++ b/enterprise/tailnet/multiagent_test.go
@@ -58,7 +58,7 @@ func TestPGCoordinator_MultiAgent(t *testing.T) {
 	require.NoError(t, agent1.close())
 
 	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
-	assertEventuallyNoAgents(ctx, t, store, agent1.id)
+	assertEventuallyLost(ctx, t, store, agent1.id)
 }
 
 // TestPGCoordinator_MultiAgent_UnsubscribeRace tests a single coordinator with
@@ -106,7 +106,7 @@ func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) {
 	require.NoError(t, agent1.close())
 
 	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
-	assertEventuallyNoAgents(ctx, t, store, agent1.id)
+	assertEventuallyLost(ctx, t, store, agent1.id)
 }
 
 // TestPGCoordinator_MultiAgent_Unsubscribe tests a single coordinator with a
@@ -168,7 +168,7 @@ func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) {
 	require.NoError(t, agent1.close())
 
 	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
-	assertEventuallyNoAgents(ctx, t, store, agent1.id)
+	assertEventuallyLost(ctx, t, store, agent1.id)
 }
 
 // TestPGCoordinator_MultiAgent_MultiCoordinator tests two coordinators with a
@@ -220,7 +220,7 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) {
 	require.NoError(t, agent1.close())
 
 	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
-	assertEventuallyNoAgents(ctx, t, store, agent1.id)
+	assertEventuallyLost(ctx, t, store, agent1.id)
 }
 
 // TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe tests two
@@ -273,7 +273,7 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *test
 	require.NoError(t, agent1.close())
 
 	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
-	assertEventuallyNoAgents(ctx, t, store, agent1.id)
+	assertEventuallyLost(ctx, t, store, agent1.id)
 }
 
 // TestPGCoordinator_MultiAgent_TwoAgents tests three coordinators with a
@@ -344,5 +344,5 @@ func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) {
 	require.NoError(t, agent2.close())
 
 	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
-	assertEventuallyNoAgents(ctx, t, store, agent1.id)
+	assertEventuallyLost(ctx, t, store, agent1.id)
 }
diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go
index a999e5586b2dd..16c7123c9ddde 100644
--- a/enterprise/tailnet/pgcoord.go
+++ b/enterprise/tailnet/pgcoord.go
@@ -203,6 +203,7 @@ func (c *pgCoord) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn {
 			}})
 		},
 		OnRemove: func(_ agpl.Queue) {
+			_ = sendCtx(c.ctx, reqs, &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
 			cancel()
 		},
 	}).Init()
@@ -352,9 +353,14 @@ func v1SendLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logg
 			_ = q.CoordinatorClose()
 			return
 		}
+		// don't send empty updates
+		if len(nodes) == 0 {
+			logger.Debug(ctx, "skipping enqueueing 0-length v1 update")
+			continue
+		}
 		err = q.Enqueue(nodes)
 		if err != nil {
-			logger.Error(ctx, "failed to enqueue multi-agent update", slog.Error(err))
+			logger.Error(ctx, "failed to enqueue v1 update", slog.Error(err))
 		}
 	}
 }
@@ -597,6 +603,7 @@ type bKey uuid.UUID
 type binding struct {
 	bKey
 	node *proto.Node
+	kind proto.CoordinateResponse_PeerUpdate_Kind
 }
 
 // binder reads node bindings from the channel and writes them to the database. It handles retries with a backoff.
@@ -675,7 +682,16 @@ func (b *binder) worker() {
 
 func (b *binder) writeOne(bnd binding) error {
 	var err error
-	if bnd.node != nil {
+	if bnd.kind == proto.CoordinateResponse_PeerUpdate_DISCONNECTED {
+		_, err = b.store.DeleteTailnetPeer(b.ctx, database.DeleteTailnetPeerParams{
+			ID:            uuid.UUID(bnd.bKey),
+			CoordinatorID: b.coordinatorID,
+		})
+		// writeOne is idempotent
+		if xerrors.Is(err, sql.ErrNoRows) {
+			err = nil
+		}
+	} else {
 		var nodeRaw []byte
 		nodeRaw, err = gProto.Marshal(bnd.node)
 		if err != nil {
@@ -684,21 +700,16 @@ func (b *binder) writeOne(bnd binding) error {
 			b.logger.Critical(b.ctx, "failed to marshal node", slog.Error(err))
 			return err
 		}
+		status := database.TailnetStatusOk
+		if bnd.kind == proto.CoordinateResponse_PeerUpdate_LOST {
+			status = database.TailnetStatusLost
+		}
 		_, err = b.store.UpsertTailnetPeer(b.ctx, database.UpsertTailnetPeerParams{
 			ID:            uuid.UUID(bnd.bKey),
 			CoordinatorID: b.coordinatorID,
 			Node:          nodeRaw,
-			Status:        database.TailnetStatusOk,
+			Status:        status,
 		})
-	} else {
-		_, err = b.store.DeleteTailnetPeer(b.ctx, database.DeleteTailnetPeerParams{
-			ID:            uuid.UUID(bnd.bKey),
-			CoordinatorID: b.coordinatorID,
-		})
-		// writeOne is idempotent
-		if xerrors.Is(err, sql.ErrNoRows) {
-			err = nil
-		}
 	}
 
 	if err != nil && !database.IsQueryCanceledError(err) {
@@ -710,16 +721,27 @@ func (b *binder) writeOne(bnd binding) error {
 	return err
 }
 
-// storeBinding stores the latest binding, where we interpret node == nil as removing the binding. This keeps the map
+// storeBinding stores the latest binding, where we interpret kind == DISCONNECTED as removing the binding. This keeps the map
 // from growing without bound.
 func (b *binder) storeBinding(bnd binding) {
 	b.mu.Lock()
 	defer b.mu.Unlock()
-	if bnd.node != nil {
+
+	switch bnd.kind {
+	case proto.CoordinateResponse_PeerUpdate_NODE:
 		b.latest[bnd.bKey] = bnd
-	} else {
-		// nil node is interpreted as removing binding
+	case proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
 		delete(b.latest, bnd.bKey)
+	case proto.CoordinateResponse_PeerUpdate_LOST:
+		// we need to coalesce with the previously stored node, since it must
+		// be non-nil in the database
+		old, ok := b.latest[bnd.bKey]
+		if !ok {
+			// lost before we ever got a node update. No action
+			return
+		}
+		bnd.node = old.node
+		b.latest[bnd.bKey] = bnd
 	}
 }
 
@@ -732,6 +754,7 @@ func (b *binder) retrieveBinding(bk bKey) binding {
 		bnd = binding{
 			bKey: bk,
 			node: nil,
+			kind: proto.CoordinateResponse_PeerUpdate_DISCONNECTED,
 		}
 	}
 	return bnd
@@ -752,9 +775,8 @@ type mapper struct {
 	// latest is the most recent, unfiltered snapshot of the mappings we know about
 	latest []mapping
 
-	// sent is the state of mappings we have actually enqueued; used to compute diffs for updates. It is a map from peer
-	// ID to node.
-	sent map[uuid.UUID]*proto.Node
+	// sent is the state of mappings we have actually enqueued; used to compute diffs for updates.
+	sent map[uuid.UUID]mapping
 
 	// called to filter mappings to healthy coordinators
 	heartbeats *heartbeats
@@ -771,7 +793,7 @@ func newMapper(c *connIO, logger slog.Logger, h *heartbeats) *mapper {
 		update:     make(chan struct{}),
 		mappings:   make(chan []mapping),
 		heartbeats: h,
-		sent:       make(map[uuid.UUID]*proto.Node),
+		sent:       make(map[uuid.UUID]mapping),
 	}
 	go m.run()
 	return m
@@ -779,19 +801,19 @@ func newMapper(c *connIO, logger slog.Logger, h *heartbeats) *mapper {
 
 func (m *mapper) run() {
 	for {
-		var nodes map[uuid.UUID]*proto.Node
+		var best map[uuid.UUID]mapping
 		select {
 		case <-m.ctx.Done():
 			return
 		case mappings := <-m.mappings:
 			m.logger.Debug(m.ctx, "got new mappings")
 			m.latest = mappings
-			nodes = m.mappingsToNodes(mappings)
+			best = m.bestMappings(mappings)
 		case <-m.update:
 			m.logger.Debug(m.ctx, "triggered update")
-			nodes = m.mappingsToNodes(m.latest)
+			best = m.bestMappings(m.latest)
 		}
-		update := m.nodesToUpdate(nodes)
+		update := m.bestToUpdate(best)
 		if update == nil {
 			m.logger.Debug(m.ctx, "skipping nil node update")
 			continue
@@ -802,67 +824,83 @@ func (m *mapper) run() {
 	}
 }
 
-// mappingsToNodes takes a set of mappings and resolves the best set of nodes. We may get several mappings for a
+// bestMappings takes a set of mappings and resolves the best set of nodes. We may get several mappings for a
 // particular connection, from different coordinators in the distributed system. Furthermore, some coordinators
 // might be considered invalid on account of missing heartbeats. We take the most recent mapping from a valid
 // coordinator as the "best" mapping.
-func (m *mapper) mappingsToNodes(mappings []mapping) map[uuid.UUID]*proto.Node {
+func (m *mapper) bestMappings(mappings []mapping) map[uuid.UUID]mapping {
 	mappings = m.heartbeats.filter(mappings)
 	best := make(map[uuid.UUID]mapping, len(mappings))
 
-	for _, m := range mappings {
-		bestM, ok := best[m.peer]
-		if !ok || m.updatedAt.After(bestM.updatedAt) {
-			best[m.peer] = m
+	for _, mpng := range mappings {
+		bestM, ok := best[mpng.peer]
+		switch {
+		case !ok:
+			// no current best
+			best[mpng.peer] = mpng
+
+		// NODE always beats LOST mapping, since the LOST could be from a coordinator that's
+		// slow updating the DB, and the peer has reconnected to a different coordinator and
+		// given a NODE mapping.
+		case bestM.kind == proto.CoordinateResponse_PeerUpdate_LOST && mpng.kind == proto.CoordinateResponse_PeerUpdate_NODE:
+			best[mpng.peer] = mpng
+		case mpng.updatedAt.After(bestM.updatedAt) && mpng.kind == proto.CoordinateResponse_PeerUpdate_NODE:
+			// newer, and it's a NODE update.
+			best[mpng.peer] = mpng
 		}
 	}
-	nodes := make(map[uuid.UUID]*proto.Node, len(best))
-	for k, m := range best {
-		nodes[k] = m.node
-	}
-	return nodes
+	return best
 }
 
-func (m *mapper) nodesToUpdate(nodes map[uuid.UUID]*proto.Node) *proto.CoordinateResponse {
+func (m *mapper) bestToUpdate(best map[uuid.UUID]mapping) *proto.CoordinateResponse {
 	resp := new(proto.CoordinateResponse)
 
-	for k, n := range nodes {
-		sn, ok := m.sent[k]
-		if !ok {
-			resp.PeerUpdates = append(resp.PeerUpdates, &proto.CoordinateResponse_PeerUpdate{
-				Uuid:   agpl.UUIDToByteSlice(k),
-				Node:   n,
-				Kind:   proto.CoordinateResponse_PeerUpdate_NODE,
-				Reason: "new",
-			})
+	for k, mpng := range best {
+		var reason string
+		sm, ok := m.sent[k]
+		switch {
+		case !ok && mpng.kind == proto.CoordinateResponse_PeerUpdate_LOST:
+			// we don't need to send a "lost" update if we've never sent an update about this peer
 			continue
-		}
-		eq, err := sn.Equal(n)
-		if err != nil {
-			m.logger.Critical(m.ctx, "failed to compare nodes", slog.F("old", sn), slog.F("new", n))
-		}
-		if !eq {
-			resp.PeerUpdates = append(resp.PeerUpdates, &proto.CoordinateResponse_PeerUpdate{
-				Uuid:   agpl.UUIDToByteSlice(k),
-				Node:   n,
-				Kind:   proto.CoordinateResponse_PeerUpdate_NODE,
-				Reason: "update",
-			})
+		case !ok && mpng.kind == proto.CoordinateResponse_PeerUpdate_NODE:
+			reason = "new"
+		case ok && sm.kind == proto.CoordinateResponse_PeerUpdate_LOST && mpng.kind == proto.CoordinateResponse_PeerUpdate_LOST:
+			// was lost and remains lost, no update needed
 			continue
+		case ok && sm.kind == proto.CoordinateResponse_PeerUpdate_LOST && mpng.kind == proto.CoordinateResponse_PeerUpdate_NODE:
+			reason = "found"
+		case ok && sm.kind == proto.CoordinateResponse_PeerUpdate_NODE && mpng.kind == proto.CoordinateResponse_PeerUpdate_LOST:
+			reason = "lost"
+		case ok && sm.kind == proto.CoordinateResponse_PeerUpdate_NODE && mpng.kind == proto.CoordinateResponse_PeerUpdate_NODE:
+			eq, err := sm.node.Equal(mpng.node)
+			if err != nil {
+				m.logger.Critical(m.ctx, "failed to compare nodes", slog.F("old", sm.node), slog.F("new", mpng.kind))
+				continue
+			}
+			if eq {
+				continue
+			}
+			reason = "update"
 		}
+		resp.PeerUpdates = append(resp.PeerUpdates, &proto.CoordinateResponse_PeerUpdate{
+			Uuid:   agpl.UUIDToByteSlice(k),
+			Node:   mpng.node,
+			Kind:   mpng.kind,
+			Reason: reason,
+		})
+		m.sent[k] = mpng
 	}
 
 	for k := range m.sent {
-		if _, ok := nodes[k]; !ok {
+		if _, ok := best[k]; !ok {
 			resp.PeerUpdates = append(resp.PeerUpdates, &proto.CoordinateResponse_PeerUpdate{
 				Uuid:   agpl.UUIDToByteSlice(k),
 				Kind:   proto.CoordinateResponse_PeerUpdate_DISCONNECTED,
 				Reason: "disconnected",
 			})
+			delete(m.sent, k)
 		}
 	}
-	m.sent = nodes
-
 	if len(resp.PeerUpdates) == 0 {
 		return nil
 	}
@@ -1069,10 +1107,6 @@ func (q *querier) mappingQuery(peer mKey) error {
 	if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
 		return err
 	}
-	if len(bindings) == 0 {
-		logger.Debug(q.ctx, "no mappings, nothing to do")
-		return nil
-	}
 	mappings, err := q.bindingsToMappings(bindings)
 	if err != nil {
 		logger.Debug(q.ctx, "failed to convert mappings", slog.Error(err))
@@ -1100,11 +1134,16 @@ func (q *querier) bindingsToMappings(bindings []database.GetTailnetTunnelPeerBin
 			q.logger.Error(q.ctx, "failed to unmarshal node", slog.Error(err))
 			return nil, backoff.Permanent(err)
 		}
+		kind := proto.CoordinateResponse_PeerUpdate_NODE
+		if binding.Status == database.TailnetStatusLost {
+			kind = proto.CoordinateResponse_PeerUpdate_LOST
+		}
 		mappings = append(mappings, mapping{
 			peer:        binding.PeerID,
 			coordinator: binding.CoordinatorID,
 			updatedAt:   binding.UpdatedAt,
 			node:        node,
+			kind:        kind,
 		})
 	}
 	return mappings, nil
@@ -1326,6 +1365,7 @@ type mapping struct {
 	coordinator uuid.UUID
 	updatedAt   time.Time
 	node        *proto.Node
+	kind        proto.CoordinateResponse_PeerUpdate_Kind
 }
 
 // querierWorkKey describes two kinds of work the querier needs to do. If peerUpdate
diff --git a/enterprise/tailnet/pgcoord_test.go b/enterprise/tailnet/pgcoord_test.go
index d59d437f3228c..20f5f39621a00 100644
--- a/enterprise/tailnet/pgcoord_test.go
+++ b/enterprise/tailnet/pgcoord_test.go
@@ -71,7 +71,7 @@ func TestPGCoordinatorSingle_ClientWithoutAgent(t *testing.T) {
 	require.NoError(t, err)
 	<-client.errChan
 	<-client.closeChan
-	assertEventuallyNoClientsForAgent(ctx, t, store, agentID)
+	assertEventuallyLost(ctx, t, store, client.id)
 }
 
 func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) {
@@ -108,7 +108,7 @@ func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) {
 	require.NoError(t, err)
 	<-agent.errChan
 	<-agent.closeChan
-	assertEventuallyNoAgents(ctx, t, store, agent.id)
+	assertEventuallyLost(ctx, t, store, agent.id)
 }
 
 func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) {
@@ -184,8 +184,8 @@ func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) {
 	_ = client.recvErr(ctx, t)
 	client.waitForClose(ctx, t)
 
-	assertEventuallyNoAgents(ctx, t, store, agent.id)
-	assertEventuallyNoClientsForAgent(ctx, t, store, agent.id)
+	assertEventuallyLost(ctx, t, store, agent.id)
+	assertEventuallyLost(ctx, t, store, client.id)
 }
 
 func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) {
@@ -272,7 +272,7 @@ func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) {
 	_ = client.recvErr(ctx, t)
 	client.waitForClose(ctx, t)
 
-	assertEventuallyNoClientsForAgent(ctx, t, store, agent.id)
+	assertEventuallyLost(ctx, t, store, client.id)
 }
 
 func TestPGCoordinatorSingle_SendsHeartbeats(t *testing.T) {
@@ -519,8 +519,8 @@ func TestPGCoordinator_MultiCoordinatorAgent(t *testing.T) {
 	require.ErrorIs(t, err, io.ErrClosedPipe)
 	client.waitForClose(ctx, t)
 
-	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
-	assertEventuallyNoAgents(ctx, t, store, agent1.id)
+	assertEventuallyLost(ctx, t, store, client.id)
+	assertEventuallyLost(ctx, t, store, agent1.id)
 }
 
 func TestPGCoordinator_Unhealthy(t *testing.T) {
@@ -624,6 +624,63 @@ func TestPGCoordinator_BidirectionalTunnels(t *testing.T) {
 	p2.assertEventuallyHasDERP(p1.id, 1)
 }
 
+func TestPGCoordinator_GracefulDisconnect(t *testing.T) {
+	t.Parallel()
+	if !dbtestutil.WillUsePostgres() {
+		t.Skip("test only with postgres")
+	}
+	store, ps := dbtestutil.NewDB(t)
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
+	defer cancel()
+	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+	coordinator, err := tailnet.NewPGCoordV2(ctx, logger, ps, store)
+	require.NoError(t, err)
+	defer coordinator.Close()
+
+	p1 := newTestPeer(ctx, t, coordinator, "p1")
+	defer p1.close(ctx)
+	p2 := newTestPeer(ctx, t, coordinator, "p2")
+	defer p2.close(ctx)
+	p1.addTunnel(p2.id)
+	p1.updateDERP(1)
+	p2.updateDERP(2)
+
+	p1.assertEventuallyHasDERP(p2.id, 2)
+	p2.assertEventuallyHasDERP(p1.id, 1)
+
+	p2.disconnect()
+	p1.assertEventuallyDisconnected(p2.id)
+	p2.assertEventuallyResponsesClosed()
+}
+
+func TestPGCoordinator_Lost(t *testing.T) {
+	t.Parallel()
+	if !dbtestutil.WillUsePostgres() {
+		t.Skip("test only with postgres")
+	}
+	store, ps := dbtestutil.NewDB(t)
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
+	defer cancel()
+	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+	coordinator, err := tailnet.NewPGCoordV2(ctx, logger, ps, store)
+	require.NoError(t, err)
+	defer coordinator.Close()
+
+	p1 := newTestPeer(ctx, t, coordinator, "p1")
+	defer p1.close(ctx)
+	p2 := newTestPeer(ctx, t, coordinator, "p2")
+	defer p2.close(ctx)
+	p1.addTunnel(p2.id)
+	p1.updateDERP(1)
+	p2.updateDERP(2)
+
+	p1.assertEventuallyHasDERP(p2.id, 2)
+	p2.assertEventuallyHasDERP(p1.id, 1)
+
+	p2.close(ctx)
+	p1.assertEventuallyLost(p2.id)
+}
+
 type testConn struct {
 	ws, serverWS net.Conn
 	nodeChan     chan []*agpl.Node
@@ -813,6 +870,7 @@ func assertMultiAgentNeverHasDERPs(ctx context.Context, t *testing.T, ma agpl.Mu
 }
 
 func assertEventuallyNoAgents(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
+	t.Helper()
 	assert.Eventually(t, func() bool {
 		agents, err := store.GetTailnetPeers(ctx, agentID)
 		if xerrors.Is(err, sql.ErrNoRows) {
@@ -825,6 +883,25 @@ func assertEventuallyNoAgents(ctx context.Context, t *testing.T, store database.
 	}, testutil.WaitShort, testutil.IntervalFast)
 }
 
+func assertEventuallyLost(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
+	t.Helper()
+	assert.Eventually(t, func() bool {
+		peers, err := store.GetTailnetPeers(ctx, agentID)
+		if xerrors.Is(err, sql.ErrNoRows) {
+			return false
+		}
+		if err != nil {
+			t.Fatal(err)
+		}
+		for _, peer := range peers {
+			if peer.Status == database.TailnetStatusOk {
+				return false
+			}
+		}
+		return true
+	}, testutil.WaitShort, testutil.IntervalFast)
+}
+
 func assertEventuallyNoClientsForAgent(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
 	t.Helper()
 	assert.Eventually(t, func() bool {
@@ -839,6 +916,11 @@ func assertEventuallyNoClientsForAgent(ctx context.Context, t *testing.T, store
 	}, testutil.WaitShort, testutil.IntervalFast)
 }
 
+type peerStatus struct {
+	preferredDERP int32
+	status        proto.CoordinateResponse_PeerUpdate_Kind
+}
+
 type testPeer struct {
 	ctx    context.Context
 	cancel context.CancelFunc
@@ -847,11 +929,11 @@ type testPeer struct {
 	name   string
 	resps  <-chan *proto.CoordinateResponse
 	reqs   chan<- *proto.CoordinateRequest
-	derps  map[uuid.UUID]int32
+	peers  map[uuid.UUID]peerStatus
 }
 
 func newTestPeer(ctx context.Context, t testing.TB, coord agpl.CoordinatorV2, name string, id ...uuid.UUID) *testPeer {
-	p := &testPeer{t: t, name: name, derps: make(map[uuid.UUID]int32)}
+	p := &testPeer{t: t, name: name, peers: make(map[uuid.UUID]peerStatus)}
 	p.ctx, p.cancel = context.WithCancel(ctx)
 	if len(id) > 1 {
 		t.Fatal("too many")
@@ -890,38 +972,102 @@ func (p *testPeer) updateDERP(derp int32) {
 	}
 }
 
+func (p *testPeer) disconnect() {
+	p.t.Helper()
+	req := &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}
+	select {
+	case <-p.ctx.Done():
+		p.t.Errorf("timeout updating node for %s", p.name)
+		return
+	case p.reqs <- req:
+		return
+	}
+}
+
 func (p *testPeer) assertEventuallyHasDERP(other uuid.UUID, derp int32) {
 	p.t.Helper()
 	for {
-		d, ok := p.derps[other]
-		if ok && d == derp {
+		o, ok := p.peers[other]
+		if ok && o.preferredDERP == derp {
 			return
 		}
-		select {
-		case <-p.ctx.Done():
-			p.t.Errorf("timeout waiting for response for %s", p.name)
+		if err := p.handleOneResp(); err != nil {
+			assert.NoError(p.t, err)
 			return
-		case resp, ok := <-p.resps:
-			if !ok {
-				p.t.Errorf("responses closed for %s", p.name)
-				return
+		}
+	}
+}
+
+func (p *testPeer) assertEventuallyDisconnected(other uuid.UUID) {
+	p.t.Helper()
+	for {
+		_, ok := p.peers[other]
+		if !ok {
+			return
+		}
+		if err := p.handleOneResp(); err != nil {
+			assert.NoError(p.t, err)
+			return
+		}
+	}
+}
+
+func (p *testPeer) assertEventuallyLost(other uuid.UUID) {
+	p.t.Helper()
+	for {
+		o := p.peers[other]
+		if o.status == proto.CoordinateResponse_PeerUpdate_LOST {
+			return
+		}
+		if err := p.handleOneResp(); err != nil {
+			assert.NoError(p.t, err)
+			return
+		}
+	}
+}
+
+func (p *testPeer) assertEventuallyResponsesClosed() {
+	p.t.Helper()
+	for {
+		err := p.handleOneResp()
+		if xerrors.Is(err, responsesClosed) {
+			return
+		}
+		if !assert.NoError(p.t, err) {
+			return
+		}
+	}
+}
+
+var responsesClosed = xerrors.New("responses closed")
+
+func (p *testPeer) handleOneResp() error {
+	select {
+	case <-p.ctx.Done():
+		return p.ctx.Err()
+	case resp, ok := <-p.resps:
+		if !ok {
+			return responsesClosed
+		}
+		for _, update := range resp.PeerUpdates {
+			id, err := uuid.FromBytes(update.Uuid)
+			if err != nil {
+				return err
 			}
-			for _, update := range resp.PeerUpdates {
-				id, err := uuid.FromBytes(update.Uuid)
-				if !assert.NoError(p.t, err) {
-					return
-				}
-				switch update.Kind {
-				case proto.CoordinateResponse_PeerUpdate_NODE:
-					p.derps[id] = update.Node.PreferredDerp
-				case proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
-					delete(p.derps, id)
-				default:
-					p.t.Errorf("unhandled update kind %s", update.Kind)
+			switch update.Kind {
+			case proto.CoordinateResponse_PeerUpdate_NODE, proto.CoordinateResponse_PeerUpdate_LOST:
+				p.peers[id] = peerStatus{
+					preferredDERP: update.GetNode().GetPreferredDerp(),
+					status:        update.Kind,
 				}
+			case proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
+				delete(p.peers, id)
+			default:
+				return xerrors.Errorf("unhandled update kind %s", update.Kind)
 			}
 		}
 	}
+	return nil
 }
 
 func (p *testPeer) close(ctx context.Context) {