diff --git a/enterprise/tailnet/connio.go b/enterprise/tailnet/connio.go index fed307758603e..49b04ac8cb816 100644 --- a/enterprise/tailnet/connio.go +++ b/enterprise/tailnet/connio.go @@ -2,136 +2,230 @@ package tailnet import ( "context" - "encoding/json" "io" - "net" + "sync" + "sync/atomic" + "time" "github.com/google/uuid" "golang.org/x/xerrors" - "nhooyr.io/websocket" "cdr.dev/slog" + agpl "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/proto" ) -// connIO manages the reading and writing to a connected client or agent. Agent connIOs have their client field set to -// uuid.Nil. It reads node updates via its decoder, then pushes them onto the bindings channel. It receives mappings -// via its updates TrackedConn, which then writes them. +// connIO manages the reading and writing to a connected peer. It reads requests via its requests +// channel, then pushes them onto the bindings or tunnels channel. It receives responses via calls +// to Enqueue and pushes them onto the responses channel. type connIO struct { - pCtx context.Context - ctx context.Context - cancel context.CancelFunc - logger slog.Logger - decoder *json.Decoder - updates *agpl.TrackedConn - bindings chan<- binding + id uuid.UUID + // coordCtx is the parent context, that is, the context of the Coordinator + coordCtx context.Context + // peerCtx is the context of the connection to our peer + peerCtx context.Context + cancel context.CancelFunc + logger slog.Logger + requests <-chan *proto.CoordinateRequest + responses chan<- *proto.CoordinateResponse + bindings chan<- binding + tunnels chan<- tunnel + auth agpl.TunnelAuth + mu sync.Mutex + closed bool + + name string + start int64 + lastWrite int64 + overwrites int64 } -func newConnIO(pCtx context.Context, +func newConnIO(coordContext context.Context, + peerCtx context.Context, logger slog.Logger, bindings chan<- binding, - conn net.Conn, + tunnels chan<- tunnel, + requests <-chan *proto.CoordinateRequest, + responses chan<- *proto.CoordinateResponse, id uuid.UUID, name string, - kind agpl.QueueKind, + auth agpl.TunnelAuth, ) *connIO { - ctx, cancel := context.WithCancel(pCtx) + peerCtx, cancel := context.WithCancel(peerCtx) + now := time.Now().Unix() c := &connIO{ - pCtx: pCtx, - ctx: ctx, - cancel: cancel, - logger: logger, - decoder: json.NewDecoder(conn), - updates: agpl.NewTrackedConn(ctx, cancel, conn, id, logger, name, 0, kind), - bindings: bindings, + id: id, + coordCtx: coordContext, + peerCtx: peerCtx, + cancel: cancel, + logger: logger.With(slog.F("name", name)), + requests: requests, + responses: responses, + bindings: bindings, + tunnels: tunnels, + auth: auth, + name: name, + start: now, + lastWrite: now, } go c.recvLoop() - go c.updates.SendUpdates() - logger.Info(ctx, "serving connection") + c.logger.Info(coordContext, "serving connection") return c } func (c *connIO) recvLoop() { defer func() { - // withdraw bindings when we exit. We need to use the parent context here, since our own context might be - // canceled, but we still need to withdraw bindings. + // withdraw bindings & tunnels when we exit. We need to use the parent context here, since + // our own context might be canceled, but we still need to withdraw. 
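To make the request/response flow that connIO serves more concrete, here is a brief illustrative sketch, not part of the diff, of how a peer might drive those channels through the `Coordinate` method added to `pgCoord` later in this change. The proto types and `agpl` helpers are the ones this diff uses; the function name, the example package, and the caller-supplied `peerID`, `otherID`, and `myNode` values are assumptions for illustration, and it presumes `agpl.CoordinatorV2` exposes `Coordinate` as shown below.

```go
// Illustrative sketch only; assumes coord implements agpl.CoordinatorV2
// (e.g. returned by NewPGCoordV2 below).
package example

import (
	"context"

	"github.com/google/uuid"

	agpl "github.com/coder/coder/v2/tailnet"
	"github.com/coder/coder/v2/tailnet/proto"
)

func examplePeer(ctx context.Context, coord agpl.CoordinatorV2, peerID, otherID uuid.UUID, myNode *proto.Node) {
	reqs, resps := coord.Coordinate(ctx, peerID, "example-peer", agpl.AgentTunnelAuth{})

	// connIO turns this request into a tunnel{active: true} handed to the tunneler.
	// (The real code wraps channel sends in sendCtx so they respect ctx cancellation.)
	reqs <- &proto.CoordinateRequest{
		AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(otherID)},
	}
	// connIO turns this request into a binding handed to the binder.
	reqs <- &proto.CoordinateRequest{
		UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: myNode},
	}

	// Mappings computed by the querier/mapper arrive via connIO.Enqueue; resps is
	// closed when the connIO closes, ending this loop.
	for resp := range resps {
		for _, update := range resp.PeerUpdates {
			_ = update // update.Kind, update.Node, update.Reason, etc.
		}
	}
}
```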
b := binding{ - bKey: bKey{ - id: c.UniqueID(), - kind: c.Kind(), - }, + bKey: bKey(c.UniqueID()), + } + if err := sendCtx(c.coordCtx, c.bindings, b); err != nil { + c.logger.Debug(c.coordCtx, "parent context expired while withdrawing bindings", slog.Error(err)) } - if err := sendCtx(c.pCtx, c.bindings, b); err != nil { - c.logger.Debug(c.ctx, "parent context expired while withdrawing bindings", slog.Error(err)) + t := tunnel{ + tKey: tKey{src: c.UniqueID()}, + active: false, + } + if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil { + c.logger.Debug(c.coordCtx, "parent context expired while withdrawing tunnels", slog.Error(err)) } }() - defer c.cancel() + defer c.Close() for { - var node agpl.Node - err := c.decoder.Decode(&node) + req, err := recvCtx(c.peerCtx, c.requests) if err != nil { - if xerrors.Is(err, io.EOF) || - xerrors.Is(err, io.ErrClosedPipe) || - xerrors.Is(err, context.Canceled) || + if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) || - websocket.CloseStatus(err) > 0 { - c.logger.Debug(c.ctx, "exiting recvLoop", slog.Error(err)) + xerrors.Is(err, io.EOF) { + c.logger.Debug(c.coordCtx, "exiting io recvLoop", slog.Error(err)) } else { - c.logger.Error(c.ctx, "failed to decode Node update", slog.Error(err)) + c.logger.Error(c.coordCtx, "failed to receive request", slog.Error(err)) } return } - c.logger.Debug(c.ctx, "got node update", slog.F("node", node)) + if err := c.handleRequest(req); err != nil { + return + } + } +} + +func (c *connIO) handleRequest(req *proto.CoordinateRequest) error { + c.logger.Debug(c.peerCtx, "got request") + if req.UpdateSelf != nil { + c.logger.Debug(c.peerCtx, "got node update", slog.F("node", req.UpdateSelf)) b := binding{ - bKey: bKey{ - id: c.UniqueID(), - kind: c.Kind(), + bKey: bKey(c.UniqueID()), + node: req.UpdateSelf.Node, + } + if err := sendCtx(c.coordCtx, c.bindings, b); err != nil { + c.logger.Debug(c.peerCtx, "failed to send binding", slog.Error(err)) + return err + } + } + if req.AddTunnel != nil { + c.logger.Debug(c.peerCtx, "got add tunnel", slog.F("tunnel", req.AddTunnel)) + dst, err := uuid.FromBytes(req.AddTunnel.Uuid) + if err != nil { + c.logger.Error(c.peerCtx, "unable to convert bytes to UUID", slog.Error(err)) + // this shouldn't happen unless there is a client error. Close the connection so the client + // doesn't just happily continue thinking everything is fine. + return err + } + if !c.auth.Authorize(dst) { + return xerrors.New("unauthorized tunnel") + } + t := tunnel{ + tKey: tKey{ + src: c.UniqueID(), + dst: dst, }, - node: &node, + active: true, } - if err := sendCtx(c.ctx, c.bindings, b); err != nil { - c.logger.Debug(c.ctx, "recvLoop ctx expired", slog.Error(err)) - return + if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil { + c.logger.Debug(c.peerCtx, "failed to send add tunnel", slog.Error(err)) + return err + } + } + if req.RemoveTunnel != nil { + c.logger.Debug(c.peerCtx, "got remove tunnel", slog.F("tunnel", req.RemoveTunnel)) + dst, err := uuid.FromBytes(req.RemoveTunnel.Uuid) + if err != nil { + c.logger.Error(c.peerCtx, "unable to convert bytes to UUID", slog.Error(err)) + // this shouldn't happen unless there is a client error. Close the connection so the client + // doesn't just happily continue thinking everything is fine. 
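As an aside, the AddTunnel branch above is where tunnel authorization is enforced. A minimal sketch of a custom auth policy follows, assuming `agpl.TunnelAuth` is satisfied by implementing the `Authorize(dst uuid.UUID) bool` call used in `handleRequest`; the concrete `AgentTunnelAuth`, `ClientTunnelAuth`, and `SingleTailnetTunnelAuth` values used elsewhere in this diff come from the `agpl` package itself, and the type here is purely hypothetical.

```go
// Illustrative sketch only: a TunnelAuth that permits tunnels solely to an
// explicit allow-list of destination peers.
package example

import "github.com/google/uuid"

type allowListTunnelAuth struct {
	allowed map[uuid.UUID]struct{}
}

// Authorize reports whether a tunnel to dst is permitted, mirroring the
// c.auth.Authorize(dst) check in handleRequest above.
func (a allowListTunnelAuth) Authorize(dst uuid.UUID) bool {
	_, ok := a.allowed[dst]
	return ok
}
```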
+ return err + } + t := tunnel{ + tKey: tKey{ + src: c.UniqueID(), + dst: dst, + }, + active: false, + } + if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil { + c.logger.Debug(c.peerCtx, "failed to send remove tunnel", slog.Error(err)) + return err } } + // TODO: (spikecurtis) support Disconnect + return nil } func (c *connIO) UniqueID() uuid.UUID { - return c.updates.UniqueID() -} - -func (c *connIO) Kind() agpl.QueueKind { - return c.updates.Kind() + return c.id } -func (c *connIO) Enqueue(n []*agpl.Node) error { - return c.updates.Enqueue(n) +func (c *connIO) Enqueue(resp *proto.CoordinateResponse) error { + atomic.StoreInt64(&c.lastWrite, time.Now().Unix()) + c.mu.Lock() + closed := c.closed + c.mu.Unlock() + if closed { + return xerrors.New("connIO closed") + } + select { + case <-c.peerCtx.Done(): + return c.peerCtx.Err() + case c.responses <- resp: + c.logger.Debug(c.peerCtx, "wrote response") + return nil + default: + return agpl.ErrWouldBlock + } } func (c *connIO) Name() string { - return c.updates.Name() + return c.name } func (c *connIO) Stats() (start int64, lastWrite int64) { - return c.updates.Stats() + return c.start, atomic.LoadInt64(&c.lastWrite) } func (c *connIO) Overwrites() int64 { - return c.updates.Overwrites() + return atomic.LoadInt64(&c.overwrites) } // CoordinatorClose is used by the coordinator when closing a Queue. It // should skip removing itself from the coordinator. func (c *connIO) CoordinatorClose() error { - c.cancel() - return c.updates.CoordinatorClose() + return c.Close() } func (c *connIO) Done() <-chan struct{} { - return c.ctx.Done() + return c.peerCtx.Done() } func (c *connIO) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return nil + } c.cancel() - return c.updates.Close() + c.closed = true + close(c.responses) + return nil } diff --git a/enterprise/tailnet/multiagent_test.go b/enterprise/tailnet/multiagent_test.go index 7546bec350504..8978c59418e95 100644 --- a/enterprise/tailnet/multiagent_test.go +++ b/enterprise/tailnet/multiagent_test.go @@ -27,11 +27,10 @@ func TestPGCoordinator_MultiAgent(t *testing.T) { t.Skip("test only with postgres") } - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) - defer cancel() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) store, ps := dbtestutil.NewDB(t) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store) require.NoError(t, err) defer coord1.Close() @@ -75,11 +74,10 @@ func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) { t.Skip("test only with postgres") } - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) - defer cancel() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) store, ps := dbtestutil.NewDB(t) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store) require.NoError(t, err) defer coord1.Close() @@ -124,11 +122,10 @@ func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) { t.Skip("test only with postgres") } - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) store, ps := dbtestutil.NewDB(t) + ctx, cancel := 
context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store) require.NoError(t, err) defer coord1.Close() @@ -189,11 +186,10 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) { t.Skip("test only with postgres") } - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) - defer cancel() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) store, ps := dbtestutil.NewDB(t) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store) require.NoError(t, err) defer coord1.Close() @@ -243,11 +239,10 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *test t.Skip("test only with postgres") } - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) - defer cancel() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) store, ps := dbtestutil.NewDB(t) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store) require.NoError(t, err) defer coord1.Close() @@ -299,11 +294,10 @@ func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) { t.Skip("test only with postgres") } - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) - defer cancel() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) store, ps := dbtestutil.NewDB(t) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store) require.NoError(t, err) defer coord1.Close() diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index bb2a2ac7eac0e..3803b8cb20b6c 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -4,18 +4,24 @@ import ( "context" "database/sql" "encoding/json" - "fmt" + "io" "net" "net/http" "net/netip" "strings" "sync" + "sync/atomic" "time" + "nhooyr.io/websocket" + + "github.com/coder/coder/v2/tailnet/proto" + "github.com/cenkalti/backoff/v4" "github.com/google/uuid" "golang.org/x/exp/slices" "golang.org/x/xerrors" + gProto "google.golang.org/protobuf/proto" "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" @@ -27,22 +33,26 @@ import ( ) const ( - EventHeartbeats = "tailnet_coordinator_heartbeat" - eventClientUpdate = "tailnet_client_update" - eventAgentUpdate = "tailnet_agent_update" - HeartbeatPeriod = time.Second * 2 - MissedHeartbeats = 3 - numQuerierWorkers = 10 - numBinderWorkers = 10 - numSubscriberWorkers = 10 - dbMaxBackoff = 10 * time.Second - cleanupPeriod = time.Hour + EventHeartbeats = "tailnet_coordinator_heartbeat" + eventPeerUpdate = "tailnet_peer_update" + eventTunnelUpdate = "tailnet_tunnel_update" + HeartbeatPeriod = time.Second * 2 + MissedHeartbeats = 3 + numQuerierWorkers = 10 + numBinderWorkers = 10 + numTunnelerWorkers = 10 + dbMaxBackoff = 10 * time.Second + cleanupPeriod = time.Hour + requestResponseBuffSize = 32 ) -// TODO: add subscriber to this graphic // pgCoord is a postgres-backed coordinator // -// ┌────────┐ ┌────────┐ ┌───────┐ +// ┌──────────┐ +// ┌────────────► tunneler ├──────────┐ +// │ └──────────┘ │ +// │ │ +// ┌────────┐ ┌────────┐ ┌───▼───┐ // │ connIO ├───────► binder ├────────► 
store │ // └───▲────┘ │ │ │ │ // │ └────────┘ ┌──────┤ │ @@ -54,19 +64,19 @@ const ( // │ │ │ │ // └───────────┘ └────────┘ // -// each incoming connection (websocket) from a client or agent is wrapped in a connIO which handles reading & writing -// from it. Node updates from a connIO are sent to the binder, which writes them to the database.Store. The querier -// is responsible for querying the store for the nodes the connection needs (e.g. for a client, the corresponding -// agent). The querier receives pubsub notifications about changes, which trigger queries for the latest state. +// each incoming connection (websocket) from a peer is wrapped in a connIO which handles reading & writing +// from it. Node updates from a connIO are sent to the binder, which writes them to the database.Store. Tunnel +// updates from a connIO are sent to the tunneler, which writes them to the database.Store. The querier is responsible +// for querying the store for the nodes the connection needs. The querier receives pubsub notifications about changes, +// which trigger queries for the latest state. // // The querier also sends the coordinator's heartbeat, and monitors the heartbeats of other coordinators. When // heartbeats cease for a coordinator, it stops using any nodes discovered from that coordinator and pushes an update // to affected connIOs. // -// This package uses the term "binding" to mean the act of registering an association between some connection (client -// or agent) and an agpl.Node. It uses the term "mapping" to mean the act of determining the nodes that the connection -// needs to receive (e.g. for a client, the node bound to the corresponding agent, or for an agent, the nodes bound to -// all clients of the agent). +// This package uses the term "binding" to mean the act of registering an association between some connection +// and a *proto.Node. It uses the term "mapping" to mean the act of determining the nodes that the connection +// needs to receive (i.e. the nodes of all peers it shares a tunnel with). type pgCoord struct { ctx context.Context logger slog.Logger @@ -74,19 +84,18 @@ type pgCoord struct { store database.Store bindings chan binding - newConnections chan agpl.Queue - closeConnections chan agpl.Queue - subscriberCh chan subscribe - querierSubCh chan subscribe + newConnections chan *connIO + closeConnections chan *connIO + tunnelerCh chan tunnel id uuid.UUID cancel context.CancelFunc closeOnce sync.Once closed chan struct{} - binder *binder - subscriber *subscriber - querier *querier + binder *binder + tunneler *tunneler + querier *querier } var pgCoordSubject = rbac.Subject{ @@ -108,18 +117,24 @@ var pgCoordSubject = rbac.Subject{ // NewPGCoord creates a high-availability coordinator that stores state in the PostgreSQL database and // receives notifications of updates via the pubsub. 
func NewPGCoord(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, store database.Store) (agpl.Coordinator, error) { + return newPGCoordInternal(ctx, logger, ps, store) +} + +func newPGCoordInternal( + ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, store database.Store, +) ( + *pgCoord, error, +) { ctx, cancel := context.WithCancel(dbauthz.As(ctx, pgCoordSubject)) id := uuid.New() logger = logger.Named("pgcoord").With(slog.F("coordinator_id", id)) bCh := make(chan binding) // used for opening connections - cCh := make(chan agpl.Queue) + cCh := make(chan *connIO) // used for closing connections - ccCh := make(chan agpl.Queue) - // for communicating subscriptions with the subscriber - sCh := make(chan subscribe) - // for communicating subscriptions with the querier - qsCh := make(chan subscribe) + ccCh := make(chan *connIO) + // for communicating subscriptions with the tunneler + sCh := make(chan tunnel) // signals when first heartbeat has been sent, so it's safe to start binding. fHB := make(chan struct{}) @@ -133,22 +148,30 @@ func NewPGCoord(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, store bindings: bCh, newConnections: cCh, closeConnections: ccCh, - subscriber: newSubscriber(ctx, logger, id, store, sCh, fHB), - subscriberCh: sCh, - querierSubCh: qsCh, + tunneler: newTunneler(ctx, logger, id, store, sCh, fHB), + tunnelerCh: sCh, id: id, - querier: newQuerier(ctx, logger, id, ps, store, id, cCh, ccCh, qsCh, numQuerierWorkers, fHB), + querier: newQuerier(ctx, logger, id, ps, store, id, cCh, ccCh, numQuerierWorkers, fHB), closed: make(chan struct{}), } logger.Info(ctx, "starting coordinator") return c, nil } +// NewPGCoordV2 creates a high-availability coordinator that stores state in the PostgreSQL database and +// receives notifications of updates via the pubsub. +func NewPGCoordV2(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, store database.Store) (agpl.CoordinatorV2, error) { + return newPGCoordInternal(ctx, logger, ps, store) +} + // This is copied from codersdk because importing it here would cause an import // cycle. This is just temporary until wsconncache is phased out. 
var legacyAgentIP = netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4") func (c *pgCoord) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn { + logger := c.logger.With(slog.F("client_id", id)).Named("multiagent") + ctx, cancel := context.WithCancel(c.ctx) + reqs, resps := c.Coordinate(ctx, id, id.String(), agpl.SingleTailnetTunnelAuth{}) ma := (&agpl.MultiAgent{ ID: id, AgentIsLegacyFunc: func(agentID uuid.UUID) bool { @@ -166,150 +189,177 @@ func (c *pgCoord) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn { } }, OnSubscribe: func(enq agpl.Queue, agent uuid.UUID) (*agpl.Node, error) { - err := c.addSubscription(enq, agent) + err := sendCtx(ctx, reqs, &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(agent)}}) return c.Node(agent), err }, - OnUnsubscribe: c.removeSubscription, + OnUnsubscribe: func(enq agpl.Queue, agent uuid.UUID) error { + err := sendCtx(ctx, reqs, &proto.CoordinateRequest{RemoveTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(agent)}}) + return err + }, OnNodeUpdate: func(id uuid.UUID, node *agpl.Node) error { - return sendCtx(c.ctx, c.bindings, binding{ - bKey: bKey{id, agpl.QueueKindClient}, - node: node, - }) + pn, err := agpl.NodeToProto(node) + if err != nil { + return err + } + return sendCtx(c.ctx, reqs, &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{ + Node: pn, + }}) }, - OnRemove: func(enq agpl.Queue) { - _ = sendCtx(c.ctx, c.bindings, binding{ - bKey: bKey{ - id: enq.UniqueID(), - kind: enq.Kind(), - }, - }) - _ = sendCtx(c.ctx, c.subscriberCh, subscribe{ - sKey: sKey{clientID: id}, - q: enq, - active: false, - }) - _ = sendCtx(c.ctx, c.closeConnections, enq) + OnRemove: func(_ agpl.Queue) { + cancel() }, }).Init() - if err := sendCtx(c.ctx, c.newConnections, agpl.Queue(ma)); err != nil { - // If we can't successfully send the multiagent, that means the - // coordinator is shutting down. In this case, just return a closed - // multiagent. - ma.CoordinatorClose() - } - + go v1SendLoop(ctx, cancel, logger, ma, resps) return ma } -func (c *pgCoord) addSubscription(q agpl.Queue, agentID uuid.UUID) error { - sub := subscribe{ - sKey: sKey{ - clientID: q.UniqueID(), - agentID: agentID, - }, - q: q, - active: true, - } - if err := sendCtx(c.ctx, c.subscriberCh, sub); err != nil { - return err - } - if err := sendCtx(c.ctx, c.querierSubCh, sub); err != nil { - // There's no need to clean up the sub sent to the subscriber if this - // fails, since it means the entire coordinator is being torn down. - return err - } - - return nil -} - -func (c *pgCoord) removeSubscription(q agpl.Queue, agentID uuid.UUID) error { - sub := subscribe{ - sKey: sKey{ - clientID: q.UniqueID(), - agentID: agentID, - }, - q: q, - active: false, - } - if err := sendCtx(c.ctx, c.subscriberCh, sub); err != nil { - return err - } - if err := sendCtx(c.ctx, c.querierSubCh, sub); err != nil { - // There's no need to clean up the sub sent to the subscriber if this - // fails, since it means the entire coordinator is being torn down. - return err - } - - return nil -} - func (c *pgCoord) Node(id uuid.UUID) *agpl.Node { - // In production, we only ever get this request for an agent. - // We're going to directly query the database, since we would only have the agent mapping stored locally if we had - // a client of that agent connected, which isn't always the case. 
- mappings, err := c.querier.queryAgent(id) + // We're going to directly query the database, since we would only have the mapping stored locally if we had + // a tunnel peer connected, which is not always the case. + peers, err := c.store.GetTailnetPeers(c.ctx, id) if err != nil { - c.logger.Error(c.ctx, "failed to query agents", slog.Error(err)) + c.logger.Error(c.ctx, "failed to query peers", slog.Error(err)) + return nil + } + mappings := make([]mapping, 0, len(peers)) + for _, peer := range peers { + pNode := new(proto.Node) + err := gProto.Unmarshal(peer.Node, pNode) + if err != nil { + c.logger.Critical(c.ctx, "failed to unmarshal node", slog.F("bytes", peer.Node), slog.Error(err)) + return nil + } + mappings = append(mappings, mapping{ + peer: peer.ID, + coordinator: peer.CoordinatorID, + updatedAt: peer.UpdatedAt, + node: pNode, + }) } mappings = c.querier.heartbeats.filter(mappings) var bestT time.Time - var bestN *agpl.Node + var bestN *proto.Node for _, m := range mappings { if m.updatedAt.After(bestT) { bestN = m.node bestT = m.updatedAt } } - return bestN + node, err := agpl.ProtoToNode(bestN) + if err != nil { + c.logger.Critical(c.ctx, "failed to convert node", slog.F("node", bestN), slog.Error(err)) + return nil + } + return node } func (c *pgCoord) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error { + logger := c.logger.With(slog.F("client_id", id), slog.F("agent_id", agent)) defer func() { err := conn.Close() if err != nil { - c.logger.Debug(c.ctx, "closing client connection", - slog.F("client_id", id), - slog.F("agent_id", agent), - slog.Error(err)) + logger.Debug(c.ctx, "closing client connection", slog.Error(err)) } }() - - cIO := newConnIO(c.ctx, c.logger, c.bindings, conn, id, id.String(), agpl.QueueKindClient) - if err := sendCtx(c.ctx, c.newConnections, agpl.Queue(cIO)); err != nil { + ctx, cancel := context.WithCancel(c.ctx) + defer cancel() + reqs, resps := c.Coordinate(ctx, id, id.String(), agpl.ClientTunnelAuth{AgentID: agent}) + err := sendCtx(ctx, reqs, &proto.CoordinateRequest{ + AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(agent)}, + }) + if err != nil { // can only be a context error, no need to log here. 
return err } - defer func() { _ = sendCtx(c.ctx, c.closeConnections, agpl.Queue(cIO)) }() - - if err := c.addSubscription(cIO, agent); err != nil { - return err - } - defer func() { _ = c.removeSubscription(cIO, agent) }() + defer func() { + _ = sendCtx(ctx, reqs, &proto.CoordinateRequest{ + RemoveTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(agent)}, + }) + }() - <-cIO.ctx.Done() + tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, id.String(), 0, agpl.QueueKindClient) + go tc.SendUpdates() + go v1SendLoop(ctx, cancel, logger, tc, resps) + go v1RecvLoop(ctx, cancel, logger, conn, reqs) + <-ctx.Done() return nil } func (c *pgCoord) ServeAgent(conn net.Conn, id uuid.UUID, name string) error { + logger := c.logger.With(slog.F("agent_id", id), slog.F("name", name)) defer func() { + logger.Debug(c.ctx, "closing agent connection") err := conn.Close() + logger.Debug(c.ctx, "closed agent connection", slog.Error(err)) + }() + ctx, cancel := context.WithCancel(c.ctx) + defer cancel() + reqs, resps := c.Coordinate(ctx, id, name, agpl.AgentTunnelAuth{}) + tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, name, 0, agpl.QueueKindAgent) + go tc.SendUpdates() + go v1SendLoop(ctx, cancel, logger, tc, resps) + go v1RecvLoop(ctx, cancel, logger, conn, reqs) + <-ctx.Done() + return nil +} + +func v1RecvLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logger, + conn net.Conn, reqs chan<- *proto.CoordinateRequest, +) { + defer cancel() + decoder := json.NewDecoder(conn) + for { + var node agpl.Node + err := decoder.Decode(&node) + if err != nil { + if xerrors.Is(err, io.EOF) || + xerrors.Is(err, io.ErrClosedPipe) || + xerrors.Is(err, context.Canceled) || + xerrors.Is(err, context.DeadlineExceeded) || + websocket.CloseStatus(err) > 0 { + logger.Debug(ctx, "exiting recvLoop", slog.Error(err)) + } else { + logger.Error(ctx, "failed to decode Node update", slog.Error(err)) + } + return + } + logger.Debug(ctx, "got node update", slog.F("node", node)) + pn, err := agpl.NodeToProto(&node) if err != nil { - c.logger.Debug(c.ctx, "closing agent connection", - slog.F("agent_id", id), - slog.Error(err)) + logger.Critical(ctx, "failed to convert v1 node", slog.F("node", node), slog.Error(err)) + return + } + req := &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{ + Node: pn, + }} + if err := sendCtx(ctx, reqs, req); err != nil { + logger.Debug(ctx, "recvLoop ctx expired", slog.Error(err)) + return } - }() - logger := c.logger.With(slog.F("name", name)) - cIO := newConnIO(c.ctx, logger, c.bindings, conn, id, name, agpl.QueueKindAgent) - if err := sendCtx(c.ctx, c.newConnections, agpl.Queue(cIO)); err != nil { - // can only be a context error, no need to log here. 
- return err } - defer func() { _ = sendCtx(c.ctx, c.closeConnections, agpl.Queue(cIO)) }() +} - <-cIO.ctx.Done() - return nil +func v1SendLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logger, q agpl.Queue, resps <-chan *proto.CoordinateResponse) { + defer cancel() + for { + resp, err := recvCtx(ctx, resps) + if err != nil { + logger.Debug(ctx, "done reading responses", slog.Error(err)) + return + } + logger.Debug(ctx, "v1: got response", slog.F("resp", resp)) + nodes, err := agpl.OnlyNodeUpdates(resp) + if err != nil { + logger.Critical(ctx, "failed to decode resp", slog.F("resp", resp), slog.Error(err)) + _ = q.CoordinatorClose() + return + } + err = q.Enqueue(nodes) + if err != nil { + logger.Error(ctx, "failed to enqueue multi-agent update", slog.Error(err)) + } + } } func (c *pgCoord) Close() error { @@ -319,6 +369,28 @@ func (c *pgCoord) Close() error { return nil } +func (c *pgCoord) Coordinate( + ctx context.Context, id uuid.UUID, name string, a agpl.TunnelAuth, +) ( + chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse, +) { + logger := c.logger.With(slog.F("peer_id", id)) + reqs := make(chan *proto.CoordinateRequest, requestResponseBuffSize) + resps := make(chan *proto.CoordinateResponse, agpl.ResponseBufferSize) + cIO := newConnIO(c.ctx, ctx, logger, c.bindings, c.tunnelerCh, reqs, resps, id, name, a) + err := sendCtx(c.ctx, c.newConnections, cIO) + if err != nil { + // this can only happen if the context is canceled, no need to log + return reqs, resps + } + go func() { + <-cIO.Done() + _ = sendCtx(c.ctx, c.closeConnections, cIO) + }() + + return reqs, resps +} + func sendCtx[A any](ctx context.Context, c chan<- A, a A) (err error) { select { case <-ctx.Done(): @@ -328,202 +400,208 @@ func sendCtx[A any](ctx context.Context, c chan<- A, a A) (err error) { } } -type sKey struct { - clientID uuid.UUID - agentID uuid.UUID +func recvCtx[A any](ctx context.Context, c <-chan A) (a A, err error) { + select { + case <-ctx.Done(): + return a, ctx.Err() + case a, ok := <-c: + if ok { + return a, nil + } + return a, io.EOF + } } -type subscribe struct { - sKey +type tKey struct { + src uuid.UUID + dst uuid.UUID +} - q agpl.Queue +type tunnel struct { + tKey // whether the subscription should be active. if true, the subscription is // added. if false, the subscription is removed. 
active bool } -type subscriber struct { +type tunneler struct { ctx context.Context logger slog.Logger coordinatorID uuid.UUID store database.Store - subscriptions <-chan subscribe + updates <-chan tunnel - mu sync.Mutex - // map[clientID]map[agentID]subscribe - latest map[uuid.UUID]map[uuid.UUID]subscribe - workQ *workQ[sKey] + mu sync.Mutex + latest map[uuid.UUID]map[uuid.UUID]tunnel + workQ *workQ[tKey] } -func newSubscriber(ctx context.Context, +func newTunneler(ctx context.Context, logger slog.Logger, id uuid.UUID, store database.Store, - subscriptions <-chan subscribe, + updates <-chan tunnel, startWorkers <-chan struct{}, -) *subscriber { - s := &subscriber{ +) *tunneler { + s := &tunneler{ ctx: ctx, logger: logger, coordinatorID: id, store: store, - subscriptions: subscriptions, - latest: make(map[uuid.UUID]map[uuid.UUID]subscribe), - workQ: newWorkQ[sKey](ctx), + updates: updates, + latest: make(map[uuid.UUID]map[uuid.UUID]tunnel), + workQ: newWorkQ[tKey](ctx), } - go s.handleSubscriptions() + go s.handle() go func() { <-startWorkers - for i := 0; i < numSubscriberWorkers; i++ { + for i := 0; i < numTunnelerWorkers; i++ { go s.worker() } }() return s } -func (s *subscriber) handleSubscriptions() { +func (t *tunneler) handle() { for { select { - case <-s.ctx.Done(): - s.logger.Debug(s.ctx, "subscriber exiting", slog.Error(s.ctx.Err())) + case <-t.ctx.Done(): + t.logger.Debug(t.ctx, "tunneler exiting", slog.Error(t.ctx.Err())) return - case sub := <-s.subscriptions: - s.storeSubscription(sub) - s.workQ.enqueue(sub.sKey) + case tun := <-t.updates: + t.cache(tun) + t.workQ.enqueue(tun.tKey) } } } -func (s *subscriber) worker() { +func (t *tunneler) worker() { eb := backoff.NewExponentialBackOff() eb.MaxElapsedTime = 0 // retry indefinitely eb.MaxInterval = dbMaxBackoff - bkoff := backoff.WithContext(eb, s.ctx) + bkoff := backoff.WithContext(eb, t.ctx) for { - bk, err := s.workQ.acquire() + tk, err := t.workQ.acquire() if err != nil { // context expired return } err = backoff.Retry(func() error { - bnd := s.retrieveSubscription(bk) - return s.writeOne(bnd) + tun := t.retrieve(tk) + return t.writeOne(tun) }, bkoff) if err != nil { bkoff.Reset() } - s.workQ.done(bk) + t.workQ.done(tk) } } -func (s *subscriber) storeSubscription(sub subscribe) { - s.mu.Lock() - defer s.mu.Unlock() - if sub.active { - if _, ok := s.latest[sub.clientID]; !ok { - s.latest[sub.clientID] = map[uuid.UUID]subscribe{} +func (t *tunneler) cache(update tunnel) { + t.mu.Lock() + defer t.mu.Unlock() + if update.active { + if _, ok := t.latest[update.src]; !ok { + t.latest[update.src] = map[uuid.UUID]tunnel{} } - s.latest[sub.clientID][sub.agentID] = sub + t.latest[update.src][update.dst] = update } else { - // If the agentID is nil, clean up all of the clients subscriptions. - if sub.agentID == uuid.Nil { - delete(s.latest, sub.clientID) + // If inactive and dst is nil, it means clean up all tunnels. + if update.dst == uuid.Nil { + delete(t.latest, update.src) } else { - delete(s.latest[sub.clientID], sub.agentID) - // clean up the subscription map if all the subscriptions are gone. - if len(s.latest[sub.clientID]) == 0 { - delete(s.latest, sub.clientID) + delete(t.latest[update.src], update.dst) + // clean up the tunnel map if all the tunnels are gone. + if len(t.latest[update.src]) == 0 { + delete(t.latest, update.src) } } } } -// retrieveBinding gets the latest binding for a key. 
-func (s *subscriber) retrieveSubscription(sk sKey) subscribe { - s.mu.Lock() - defer s.mu.Unlock() - agents, ok := s.latest[sk.clientID] +// retrieveBinding gets the latest tunnel for a key. +func (t *tunneler) retrieve(k tKey) tunnel { + t.mu.Lock() + defer t.mu.Unlock() + dstMap, ok := t.latest[k.src] if !ok { - return subscribe{ - sKey: sk, + return tunnel{ + tKey: k, active: false, } } - sub, ok := agents[sk.agentID] + tun, ok := dstMap[k.dst] if !ok { - return subscribe{ - sKey: sk, + return tunnel{ + tKey: k, active: false, } } - return sub + return tun } -func (s *subscriber) writeOne(sub subscribe) error { +func (t *tunneler) writeOne(tun tunnel) error { var err error switch { - case sub.agentID == uuid.Nil: - err = s.store.DeleteAllTailnetClientSubscriptions(s.ctx, database.DeleteAllTailnetClientSubscriptionsParams{ - ClientID: sub.clientID, - CoordinatorID: s.coordinatorID, + case tun.dst == uuid.Nil: + err = t.store.DeleteAllTailnetTunnels(t.ctx, database.DeleteAllTailnetTunnelsParams{ + SrcID: tun.src, + CoordinatorID: t.coordinatorID, }) - s.logger.Debug(s.ctx, "deleted all client subscriptions", - slog.F("client_id", sub.clientID), + t.logger.Debug(t.ctx, "deleted all tunnels", + slog.F("src_id", tun.src), slog.Error(err), ) - case sub.active: - err = s.store.UpsertTailnetClientSubscription(s.ctx, database.UpsertTailnetClientSubscriptionParams{ - ClientID: sub.clientID, - CoordinatorID: s.coordinatorID, - AgentID: sub.agentID, + case tun.active: + _, err = t.store.UpsertTailnetTunnel(t.ctx, database.UpsertTailnetTunnelParams{ + CoordinatorID: t.coordinatorID, + SrcID: tun.src, + DstID: tun.dst, }) - s.logger.Debug(s.ctx, "upserted client subscription", - slog.F("client_id", sub.clientID), - slog.F("agent_id", sub.agentID), + t.logger.Debug(t.ctx, "upserted tunnel", + slog.F("src_id", tun.src), + slog.F("dst_id", tun.dst), slog.Error(err), ) - case !sub.active: - err = s.store.DeleteTailnetClientSubscription(s.ctx, database.DeleteTailnetClientSubscriptionParams{ - ClientID: sub.clientID, - CoordinatorID: s.coordinatorID, - AgentID: sub.agentID, + case !tun.active: + _, err = t.store.DeleteTailnetTunnel(t.ctx, database.DeleteTailnetTunnelParams{ + CoordinatorID: t.coordinatorID, + SrcID: tun.src, + DstID: tun.dst, }) - s.logger.Debug(s.ctx, "deleted client subscription", - slog.F("client_id", sub.clientID), - slog.F("agent_id", sub.agentID), + t.logger.Debug(t.ctx, "deleted tunnel", + slog.F("src_id", tun.src), + slog.F("dst_id", tun.dst), slog.Error(err), ) + // writeOne should be idempotent + if xerrors.Is(err, sql.ErrNoRows) { + err = nil + } default: panic("unreachable") } if err != nil && !database.IsQueryCanceledError(err) { - s.logger.Error(s.ctx, "write subscription to database", - slog.F("client_id", sub.clientID), - slog.F("agent_id", sub.agentID), - slog.F("active", sub.active), + t.logger.Error(t.ctx, "write tunnel to database", + slog.F("src_id", tun.src), + slog.F("dst_id", tun.dst), + slog.F("active", tun.active), slog.Error(err)) } return err } -// bKey, or "binding key" identifies a client or agent in a binding. Agents and -// clients are differentiated by the kind field. -type bKey struct { - id uuid.UUID - kind agpl.QueueKind -} +// bKey, or "binding key" identifies a peer in a binding +type bKey uuid.UUID -// binding represents an association between a client or agent and a Node. +// binding represents an association between a peer and a Node. 
type binding struct { bKey - node *agpl.Node + node *proto.Node } -func (b *binding) isAgent() bool { return b.kind == agpl.QueueKindAgent } -func (b *binding) isClient() bool { return b.kind == agpl.QueueKindClient } - // binder reads node bindings from the channel and writes them to the database. It handles retries with a backoff. type binder struct { ctx context.Context @@ -599,66 +677,37 @@ func (b *binder) worker() { } func (b *binder) writeOne(bnd binding) error { - var nodeRaw json.RawMessage var err error if bnd.node != nil { - nodeRaw, err = json.Marshal(*bnd.node) + var nodeRaw []byte + nodeRaw, err = gProto.Marshal(bnd.node) if err != nil { - // this is very bad news, but it should never happen because the node was Unmarshalled by this process - // earlier. - b.logger.Error(b.ctx, "failed to marshal node", slog.Error(err)) + // this is very bad news, but it should never happen because the node was Unmarshalled or converted by this + // process earlier. + b.logger.Critical(b.ctx, "failed to marshal node", slog.Error(err)) return err } - } - - switch { - case bnd.isAgent() && len(nodeRaw) > 0: - _, err = b.store.UpsertTailnetAgent(b.ctx, database.UpsertTailnetAgentParams{ - ID: bnd.id, - CoordinatorID: b.coordinatorID, - Node: nodeRaw, - }) - b.logger.Debug(b.ctx, "upserted agent binding", - slog.F("agent_id", bnd.id), slog.F("node", nodeRaw), slog.Error(err)) - case bnd.isAgent() && len(nodeRaw) == 0: - _, err = b.store.DeleteTailnetAgent(b.ctx, database.DeleteTailnetAgentParams{ - ID: bnd.id, - CoordinatorID: b.coordinatorID, - }) - b.logger.Debug(b.ctx, "deleted agent binding", - slog.F("agent_id", bnd.id), slog.Error(err)) - if xerrors.Is(err, sql.ErrNoRows) { - // treat deletes as idempotent - err = nil - } - case bnd.isClient() && len(nodeRaw) > 0: - _, err = b.store.UpsertTailnetClient(b.ctx, database.UpsertTailnetClientParams{ - ID: bnd.id, + _, err = b.store.UpsertTailnetPeer(b.ctx, database.UpsertTailnetPeerParams{ + ID: uuid.UUID(bnd.bKey), CoordinatorID: b.coordinatorID, Node: nodeRaw, + Status: database.TailnetStatusOk, }) - b.logger.Debug(b.ctx, "upserted client binding", - slog.F("client_id", bnd.id), - slog.F("node", nodeRaw), slog.Error(err)) - case bnd.isClient() && len(nodeRaw) == 0: - _, err = b.store.DeleteTailnetClient(b.ctx, database.DeleteTailnetClientParams{ - ID: bnd.id, + } else { + _, err = b.store.DeleteTailnetPeer(b.ctx, database.DeleteTailnetPeerParams{ + ID: uuid.UUID(bnd.bKey), CoordinatorID: b.coordinatorID, }) - b.logger.Debug(b.ctx, "deleted client binding", - slog.F("client_id", bnd.id)) + // writeOne is idempotent if xerrors.Is(err, sql.ErrNoRows) { - // treat deletes as idempotent err = nil } - default: - panic("unhittable") } + if err != nil && !database.IsQueryCanceledError(err) { b.logger.Error(b.ctx, "failed to write binding to database", - slog.F("binding_id", bnd.id), - slog.F("kind", bnd.kind), - slog.F("node", string(nodeRaw)), + slog.F("binding_id", bnd.bKey), + slog.F("node", bnd.node), slog.Error(err)) } return err @@ -691,42 +740,41 @@ func (b *binder) retrieveBinding(bk bKey) binding { return bnd } -// mapper tracks a single client or agent ID, and fans out updates to that ID->node mapping to every local connection -// that needs it. +// mapper tracks data sent to a peer, and sends updates based on changes read from the database. type mapper struct { ctx context.Context logger slog.Logger - add chan agpl.Queue - del chan agpl.Queue - - // reads from this channel trigger sending latest nodes to - // all connections. 
It is used when coordinators are added - // or removed + // reads from this channel trigger recomputing the set of mappings to send, and sending any updates. It is used when + // coordinators are added or removed update chan struct{} mappings chan []mapping - conns map[bKey]agpl.Queue + c *connIO + + // latest is the most recent, unfiltered snapshot of the mappings we know about latest []mapping + // sent is the state of mappings we have actually enqueued; used to compute diffs for updates. It is a map from peer + // ID to node. + sent map[uuid.UUID]*proto.Node + // called to filter mappings to healthy coordinators heartbeats *heartbeats } -func newMapper(ctx context.Context, logger slog.Logger, mk mKey, h *heartbeats) *mapper { +func newMapper(c *connIO, logger slog.Logger, h *heartbeats) *mapper { logger = logger.With( - slog.F("agent_id", mk.agent), - slog.F("kind", mk.kind), + slog.F("peer_id", c.UniqueID()), ) m := &mapper{ - ctx: ctx, + ctx: c.peerCtx, // mapper has same lifetime as the underlying connection it serves logger: logger, - add: make(chan agpl.Queue), - del: make(chan agpl.Queue), + c: c, update: make(chan struct{}), - conns: make(map[bKey]agpl.Queue), mappings: make(chan []mapping), heartbeats: h, + sent: make(map[uuid.UUID]*proto.Node), } go m.run() return m @@ -734,44 +782,25 @@ func newMapper(ctx context.Context, logger slog.Logger, mk mKey, h *heartbeats) func (m *mapper) run() { for { + var nodes map[uuid.UUID]*proto.Node select { case <-m.ctx.Done(): return - case c := <-m.add: - m.conns[bKey{id: c.UniqueID(), kind: c.Kind()}] = c - nodes := m.mappingsToNodes(m.latest) - if len(nodes) == 0 { - m.logger.Debug(m.ctx, "skipping 0 length node update") - continue - } - if err := c.Enqueue(nodes); err != nil { - m.logger.Error(m.ctx, "failed to enqueue node update", slog.Error(err)) - } - case c := <-m.del: - delete(m.conns, bKey{id: c.UniqueID(), kind: c.Kind()}) case mappings := <-m.mappings: + m.logger.Debug(m.ctx, "got new mappings") m.latest = mappings - nodes := m.mappingsToNodes(mappings) - if len(nodes) == 0 { - m.logger.Debug(m.ctx, "skipping 0 length node update") - continue - } - for _, conn := range m.conns { - if err := conn.Enqueue(nodes); err != nil { - m.logger.Error(m.ctx, "failed to enqueue node update", slog.Error(err)) - } - } + nodes = m.mappingsToNodes(mappings) case <-m.update: - nodes := m.mappingsToNodes(m.latest) - if len(nodes) == 0 { - m.logger.Debug(m.ctx, "skipping 0 length node update") - continue - } - for _, conn := range m.conns { - if err := conn.Enqueue(nodes); err != nil { - m.logger.Error(m.ctx, "failed to enqueue triggered node update", slog.Error(err)) - } - } + m.logger.Debug(m.ctx, "triggered update") + nodes = m.mappingsToNodes(m.latest) + } + update := m.nodesToUpdate(nodes) + if update == nil { + m.logger.Debug(m.ctx, "skipping nil node update") + continue + } + if err := m.c.Enqueue(update); err != nil { + m.logger.Error(m.ctx, "failed to enqueue node update", slog.Error(err)) } } } @@ -780,32 +809,83 @@ func (m *mapper) run() { // particular connection, from different coordinators in the distributed system. Furthermore, some coordinators // might be considered invalid on account of missing heartbeats. We take the most recent mapping from a valid // coordinator as the "best" mapping. 
-func (m *mapper) mappingsToNodes(mappings []mapping) []*agpl.Node { +func (m *mapper) mappingsToNodes(mappings []mapping) map[uuid.UUID]*proto.Node { mappings = m.heartbeats.filter(mappings) - best := make(map[bKey]mapping, len(mappings)) + best := make(map[uuid.UUID]mapping, len(mappings)) for _, m := range mappings { - var bk bKey - if m.client == uuid.Nil { - bk = bKey{id: m.agent, kind: agpl.QueueKindAgent} - } else { - bk = bKey{id: m.client, kind: agpl.QueueKindClient} - } - - bestM, ok := best[bk] + bestM, ok := best[m.peer] if !ok || m.updatedAt.After(bestM.updatedAt) { - best[bk] = m + best[m.peer] = m } } - nodes := make([]*agpl.Node, 0, len(best)) - for _, m := range best { - nodes = append(nodes, m.node) + nodes := make(map[uuid.UUID]*proto.Node, len(best)) + for k, m := range best { + nodes[k] = m.node } return nodes } -// querier is responsible for monitoring pubsub notifications and querying the database for the mappings that all -// connected clients and agents need. It also checks heartbeats and withdraws mappings from coordinators that have -// failed heartbeats. +func (m *mapper) nodesToUpdate(nodes map[uuid.UUID]*proto.Node) *proto.CoordinateResponse { + resp := new(proto.CoordinateResponse) + + for k, n := range nodes { + sn, ok := m.sent[k] + if !ok { + resp.PeerUpdates = append(resp.PeerUpdates, &proto.CoordinateResponse_PeerUpdate{ + Uuid: agpl.UUIDToByteSlice(k), + Node: n, + Kind: proto.CoordinateResponse_PeerUpdate_NODE, + Reason: "new", + }) + continue + } + eq, err := sn.Equal(n) + if err != nil { + m.logger.Critical(m.ctx, "failed to compare nodes", slog.F("old", sn), slog.F("new", n)) + } + if !eq { + resp.PeerUpdates = append(resp.PeerUpdates, &proto.CoordinateResponse_PeerUpdate{ + Uuid: agpl.UUIDToByteSlice(k), + Node: n, + Kind: proto.CoordinateResponse_PeerUpdate_NODE, + Reason: "update", + }) + continue + } + } + + for k := range m.sent { + if _, ok := nodes[k]; !ok { + resp.PeerUpdates = append(resp.PeerUpdates, &proto.CoordinateResponse_PeerUpdate{ + Uuid: agpl.UUIDToByteSlice(k), + Kind: proto.CoordinateResponse_PeerUpdate_DISCONNECTED, + Reason: "disconnected", + }) + } + } + + m.sent = nodes + + if len(resp.PeerUpdates) == 0 { + return nil + } + return resp +} + +// querier is responsible for monitoring pubsub notifications and querying the database for the +// mappings that all connected peers need. It also checks heartbeats and withdraws mappings from +// coordinators that have failed heartbeats. +// +// There are two kinds of pubsub notifications it listens for and responds to. +// +// 1. Tunnel updates --- a tunnel was added or removed. In this case we need +// to recompute the mappings for peers on both sides of the tunnel. +// 2. Peer updates --- a peer got a new binding. When a peer gets a new +// binding, we need to update all the _other_ peers it shares a tunnel with. +// However, we don't keep tunnels in memory (to avoid the +// complexity of synchronizing with the database), so we first have to query +// the database to learn the tunnel peers, then schedule an update on each +// one. 
type querier struct { ctx context.Context logger slog.Logger @@ -813,28 +893,17 @@ type querier struct { pubsub pubsub.Pubsub store database.Store - newConnections chan agpl.Queue - closeConnections chan agpl.Queue - subscriptions chan subscribe + newConnections chan *connIO + closeConnections chan *connIO - workQ *workQ[mKey] + workQ *workQ[querierWorkKey] heartbeats *heartbeats updates <-chan hbUpdate mu sync.Mutex - mappers map[mKey]*countedMapper - conns map[uuid.UUID]agpl.Queue - // clientSubscriptions maps client ids to the agent ids they're subscribed to. - // map[client_id]map[agent_id] - clientSubscriptions map[uuid.UUID]map[uuid.UUID]struct{} - healthy bool -} - -type countedMapper struct { - *mapper - count int - cancel context.CancelFunc + mappers map[mKey]*mapper + healthy bool } func newQuerier(ctx context.Context, @@ -843,29 +912,25 @@ func newQuerier(ctx context.Context, ps pubsub.Pubsub, store database.Store, self uuid.UUID, - newConnections chan agpl.Queue, - closeConnections chan agpl.Queue, - subscriptions chan subscribe, + newConnections chan *connIO, + closeConnections chan *connIO, numWorkers int, firstHeartbeat chan struct{}, ) *querier { updates := make(chan hbUpdate) q := &querier{ - ctx: ctx, - logger: logger.Named("querier"), - coordinatorID: coordinatorID, - pubsub: ps, - store: store, - newConnections: newConnections, - closeConnections: closeConnections, - subscriptions: subscriptions, - workQ: newWorkQ[mKey](ctx), - heartbeats: newHeartbeats(ctx, logger, ps, store, self, updates, firstHeartbeat), - mappers: make(map[mKey]*countedMapper), - conns: make(map[uuid.UUID]agpl.Queue), - updates: updates, - clientSubscriptions: make(map[uuid.UUID]map[uuid.UUID]struct{}), - healthy: true, // assume we start healthy + ctx: ctx, + logger: logger.Named("querier"), + coordinatorID: coordinatorID, + pubsub: ps, + store: store, + newConnections: newConnections, + closeConnections: closeConnections, + workQ: newWorkQ[querierWorkKey](ctx), + heartbeats: newHeartbeats(ctx, logger, ps, store, self, updates, firstHeartbeat), + mappers: make(map[mKey]*mapper), + updates: updates, + healthy: true, // assume we start healthy } q.subscribe() @@ -887,180 +952,62 @@ func (q *querier) handleIncoming() { return case c := <-q.newConnections: - switch c.Kind() { - case agpl.QueueKindAgent: - q.newAgentConn(c) - case agpl.QueueKindClient: - q.newClientConn(c) - default: - panic(fmt.Sprint("unreachable: invalid queue kind ", c.Kind())) - } + q.newConn(c) case c := <-q.closeConnections: q.cleanupConn(c) - - case sub := <-q.subscriptions: - if sub.active { - q.newClientSubscription(sub.q, sub.agentID) - } else { - q.removeClientSubscription(sub.q, sub.agentID) - } } } } -func (q *querier) newAgentConn(c agpl.Queue) { +func (q *querier) newConn(c *connIO) { q.mu.Lock() defer q.mu.Unlock() if !q.healthy { err := c.Close() q.logger.Info(q.ctx, "closed incoming connection while unhealthy", slog.Error(err), - slog.F("agent_id", c.UniqueID()), + slog.F("peer_id", c.UniqueID()), ) return } - mk := mKey{ - agent: c.UniqueID(), - kind: c.Kind(), - } - cm, ok := q.mappers[mk] - if !ok { - ctx, cancel := context.WithCancel(q.ctx) - mpr := newMapper(ctx, q.logger, mk, q.heartbeats) - cm = &countedMapper{ - mapper: mpr, - count: 0, - cancel: cancel, - } - q.mappers[mk] = cm - // we don't have any mapping state for this key yet - q.workQ.enqueue(mk) - } - if err := sendCtx(cm.ctx, cm.add, c); err != nil { - return - } - cm.count++ - q.conns[c.UniqueID()] = c -} - -func (q *querier) newClientSubscription(c 
agpl.Queue, agentID uuid.UUID) { - q.mu.Lock() - defer q.mu.Unlock() - - if _, ok := q.clientSubscriptions[c.UniqueID()]; !ok { - q.clientSubscriptions[c.UniqueID()] = map[uuid.UUID]struct{}{} - } - - mk := mKey{ - agent: agentID, - kind: agpl.QueueKindClient, - } - cm, ok := q.mappers[mk] - if !ok { - ctx, cancel := context.WithCancel(q.ctx) - mpr := newMapper(ctx, q.logger, mk, q.heartbeats) - cm = &countedMapper{ - mapper: mpr, - count: 0, - cancel: cancel, - } - q.mappers[mk] = cm - // we don't have any mapping state for this key yet - q.workQ.enqueue(mk) - } - if err := sendCtx(cm.ctx, cm.add, c); err != nil { - return - } - q.clientSubscriptions[c.UniqueID()][agentID] = struct{}{} - cm.count++ -} - -func (q *querier) removeClientSubscription(c agpl.Queue, agentID uuid.UUID) { - q.mu.Lock() - defer q.mu.Unlock() - - // Allow duplicate unsubscribes. It's possible for cleanupConn to race with - // an external call to removeClientSubscription, so we just ensure the - // client subscription exists before attempting to remove it. - if _, ok := q.clientSubscriptions[c.UniqueID()][agentID]; !ok { - return - } - - mk := mKey{ - agent: agentID, - kind: agpl.QueueKindClient, - } - cm := q.mappers[mk] - if err := sendCtx(cm.ctx, cm.del, c); err != nil { - return - } - delete(q.clientSubscriptions[c.UniqueID()], agentID) - cm.count-- - if cm.count == 0 { - cm.cancel() - delete(q.mappers, mk) - } - if len(q.clientSubscriptions[c.UniqueID()]) == 0 { - delete(q.clientSubscriptions, c.UniqueID()) - } -} - -func (q *querier) newClientConn(c agpl.Queue) { - q.mu.Lock() - defer q.mu.Unlock() - if !q.healthy { - err := c.Close() - q.logger.Info(q.ctx, "closed incoming connection while unhealthy", - slog.Error(err), - slog.F("client_id", c.UniqueID()), - ) - return + mpr := newMapper(c, q.logger, q.heartbeats) + mk := mKey(c.UniqueID()) + dup, ok := q.mappers[mk] + if ok { + // duplicate, overwrite and close the old one + atomic.StoreInt64(&c.overwrites, dup.c.Overwrites()+1) + err := dup.c.CoordinatorClose() + if err != nil { + q.logger.Error(q.ctx, "failed to close duplicate mapper", slog.F("peer_id", dup.c.UniqueID()), slog.Error(err)) + } } - - q.conns[c.UniqueID()] = c + q.mappers[mk] = mpr + q.workQ.enqueue(querierWorkKey{ + mappingQuery: mk, + }) } -func (q *querier) cleanupConn(c agpl.Queue) { +func (q *querier) cleanupConn(c *connIO) { + logger := q.logger.With(slog.F("peer_id", c.UniqueID())) q.mu.Lock() defer q.mu.Unlock() - delete(q.conns, c.UniqueID()) - // Iterate over all subscriptions and remove them from the mappers. 
- for agentID := range q.clientSubscriptions[c.UniqueID()] { - mk := mKey{ - agent: agentID, - kind: c.Kind(), - } - cm := q.mappers[mk] - if err := sendCtx(cm.ctx, cm.del, c); err != nil { - continue - } - cm.count-- - if cm.count == 0 { - cm.cancel() - delete(q.mappers, mk) - } - } - delete(q.clientSubscriptions, c.UniqueID()) - - mk := mKey{ - agent: c.UniqueID(), - kind: c.Kind(), - } - cm, ok := q.mappers[mk] + mk := mKey(c.UniqueID()) + mpr, ok := q.mappers[mk] if !ok { return } - - if err := sendCtx(cm.ctx, cm.del, c); err != nil { + if mpr.c != c { + logger.Debug(q.ctx, "attempt to cleanup for duplicate connection, ignoring") return } - cm.count-- - if cm.count == 0 { - cm.cancel() - delete(q.mappers, mk) + err := c.CoordinatorClose() + if err != nil { + logger.Error(q.ctx, "failed to close connIO", slog.Error(err)) } + delete(q.mappers, mk) + q.logger.Debug(q.ctx, "removed mapper") } func (q *querier) worker() { @@ -1069,108 +1016,104 @@ func (q *querier) worker() { eb.MaxInterval = dbMaxBackoff bkoff := backoff.WithContext(eb, q.ctx) for { - mk, err := q.workQ.acquire() + qk, err := q.workQ.acquire() if err != nil { // context expired return } err = backoff.Retry(func() error { - return q.query(mk) + return q.query(qk) }, bkoff) if err != nil { bkoff.Reset() } - q.workQ.done(mk) + q.workQ.done(qk) } } -func (q *querier) query(mk mKey) error { - var mappings []mapping - var err error - // If the mapping is an agent, query all of its clients. - if mk.kind == agpl.QueueKindAgent { - mappings, err = q.queryClientsOfAgent(mk.agent) - if err != nil { - return err - } - } else { - // The mapping is for clients subscribed to the agent. Query the agent - // itself. - mappings, err = q.queryAgent(mk.agent) - if err != nil { - return err - } +func (q *querier) query(qk querierWorkKey) error { + if uuid.UUID(qk.mappingQuery) != uuid.Nil { + return q.mappingQuery(qk.mappingQuery) } - q.mu.Lock() - mpr, ok := q.mappers[mk] - q.mu.Unlock() - if !ok { - q.logger.Debug(q.ctx, "query for missing mapper", - slog.F("agent_id", mk.agent), slog.F("kind", mk.kind)) - return nil + if qk.peerUpdate != uuid.Nil { + return q.peerUpdate(qk.peerUpdate) } - q.logger.Debug(q.ctx, "sending mappings", slog.F("mapping_len", len(mappings))) - mpr.mappings <- mappings - return nil + q.logger.Critical(q.ctx, "bad querierWorkKey", slog.F("work_key", qk)) + return backoff.Permanent(xerrors.Errorf("bad querierWorkKey %v", qk)) } -func (q *querier) queryClientsOfAgent(agent uuid.UUID) ([]mapping, error) { - clients, err := q.store.GetTailnetClientsForAgent(q.ctx, agent) - q.logger.Debug(q.ctx, "queried clients of agent", - slog.F("agent_id", agent), slog.F("num_clients", len(clients)), slog.Error(err)) - if err != nil { - return nil, err +// peerUpdate is work scheduled in response to a new peer->binding. We need to find out all the +// other peers that share a tunnel with the indicated peer, and then schedule a mapping update on +// each, so that they can find out about the new binding. 
+func (q *querier) peerUpdate(peer uuid.UUID) error { + logger := q.logger.With(slog.F("peer_id", peer)) + logger.Debug(q.ctx, "querying peers that share a tunnel") + others, err := q.store.GetTailnetTunnelPeerIDs(q.ctx, peer) + if err != nil && !xerrors.Is(err, sql.ErrNoRows) { + return err } - mappings := make([]mapping, 0, len(clients)) - for _, client := range clients { - node := new(agpl.Node) - err := json.Unmarshal(client.Node, node) - if err != nil { - q.logger.Error(q.ctx, "failed to unmarshal node", slog.Error(err)) - return nil, backoff.Permanent(err) - } - mappings = append(mappings, mapping{ - client: client.ID, - agent: agent, - coordinator: client.CoordinatorID, - updatedAt: client.UpdatedAt, - node: node, - }) + logger.Debug(q.ctx, "queried peers that share a tunnel", slog.F("num_peers", len(others))) + for _, other := range others { + logger.Debug(q.ctx, "got tunnel peer", slog.F("other_id", other.PeerID)) + q.workQ.enqueue(querierWorkKey{mappingQuery: mKey(other.PeerID)}) } - return mappings, nil + return nil } -func (q *querier) queryAgent(agentID uuid.UUID) ([]mapping, error) { - agents, err := q.store.GetTailnetAgents(q.ctx, agentID) - q.logger.Debug(q.ctx, "queried agents", - slog.F("agent_id", agentID), slog.F("num_agents", len(agents)), slog.Error(err)) +// mappingQuery queries the database for all the mappings that the given peer should know about, +// that is, all the peers that it shares a tunnel with and their current node mappings (if they +// exist). It then sends the mapping snapshot to the corresponding mapper, where it will get +// transmitted to the peer. +func (q *querier) mappingQuery(peer mKey) error { + logger := q.logger.With(slog.F("peer_id", uuid.UUID(peer))) + logger.Debug(q.ctx, "querying mappings") + bindings, err := q.store.GetTailnetTunnelPeerBindings(q.ctx, uuid.UUID(peer)) + logger.Debug(q.ctx, "queried mappings", slog.F("num_mappings", len(bindings))) + if err != nil && !xerrors.Is(err, sql.ErrNoRows) { + return err + } + if len(bindings) == 0 { + logger.Debug(q.ctx, "no mappings, nothing to do") + return nil + } + mappings, err := q.bindingsToMappings(bindings) if err != nil { - return nil, err + logger.Debug(q.ctx, "failed to convert mappings", slog.Error(err)) + return err + } + q.mu.Lock() + mpr, ok := q.mappers[peer] + q.mu.Unlock() + if !ok { + logger.Debug(q.ctx, "query for missing mapper") + return nil } - return q.agentsToMappings(agents) + logger.Debug(q.ctx, "sending mappings", slog.F("mapping_len", len(mappings))) + mpr.mappings <- mappings + return nil } -func (q *querier) agentsToMappings(agents []database.TailnetAgent) ([]mapping, error) { +func (q *querier) bindingsToMappings(bindings []database.GetTailnetTunnelPeerBindingsRow) ([]mapping, error) { slog.Helper() - mappings := make([]mapping, 0, len(agents)) - for _, agent := range agents { - node := new(agpl.Node) - err := json.Unmarshal(agent.Node, node) + mappings := make([]mapping, 0, len(bindings)) + for _, binding := range bindings { + node := new(proto.Node) + err := gProto.Unmarshal(binding.Node, node) if err != nil { q.logger.Error(q.ctx, "failed to unmarshal node", slog.Error(err)) return nil, backoff.Permanent(err) } mappings = append(mappings, mapping{ - agent: agent.ID, - coordinator: agent.CoordinatorID, - updatedAt: agent.UpdatedAt, + peer: binding.PeerID, + coordinator: binding.CoordinatorID, + updatedAt: binding.UpdatedAt, node: node, }) } return mappings, nil } -// subscribe starts our subscriptions to client and agent updates in a new goroutine, and returns 
once we are subscribed
+// subscribe starts our subscriptions to peer and tunnel updates in a new goroutine, and returns once we are subscribed
 // or the querier context is canceled.
 func (q *querier) subscribe() {
 	subscribed := make(chan struct{})
@@ -1180,14 +1123,14 @@ func (q *querier) subscribe() {
 		eb.MaxElapsedTime = 0 // retry indefinitely
 		eb.MaxInterval = dbMaxBackoff
 		bkoff := backoff.WithContext(eb, q.ctx)
-		var cancelClient context.CancelFunc
+		var cancelPeer context.CancelFunc
 		err := backoff.Retry(func() error {
-			cancelFn, err := q.pubsub.SubscribeWithErr(eventClientUpdate, q.listenClient)
+			cancelFn, err := q.pubsub.SubscribeWithErr(eventPeerUpdate, q.listenPeer)
 			if err != nil {
-				q.logger.Warn(q.ctx, "failed to subscribe to client updates", slog.Error(err))
+				q.logger.Warn(q.ctx, "failed to subscribe to peer updates", slog.Error(err))
 				return err
 			}
-			cancelClient = cancelFn
+			cancelPeer = cancelFn
 			return nil
 		}, bkoff)
 		if err != nil {
@@ -1196,18 +1139,18 @@ func (q *querier) subscribe() {
 			}
 			return
 		}
-		defer cancelClient()
+		defer cancelPeer()
 		bkoff.Reset()
-		q.logger.Debug(q.ctx, "subscribed to client updates")
+		q.logger.Debug(q.ctx, "subscribed to peer updates")
 
-		var cancelAgent context.CancelFunc
+		var cancelTunnel context.CancelFunc
 		err = backoff.Retry(func() error {
-			cancelFn, err := q.pubsub.SubscribeWithErr(eventAgentUpdate, q.listenAgent)
+			cancelFn, err := q.pubsub.SubscribeWithErr(eventTunnelUpdate, q.listenTunnel)
 			if err != nil {
-				q.logger.Warn(q.ctx, "failed to subscribe to agent updates", slog.Error(err))
+				q.logger.Warn(q.ctx, "failed to subscribe to tunnel updates", slog.Error(err))
 				return err
 			}
-			cancelAgent = cancelFn
+			cancelTunnel = cancelFn
 			return nil
 		}, bkoff)
 		if err != nil {
@@ -1216,8 +1159,8 @@ func (q *querier) subscribe() {
 			}
 			return
 		}
-		defer cancelAgent()
-		q.logger.Debug(q.ctx, "subscribed to agent updates")
+		defer cancelTunnel()
+		q.logger.Debug(q.ctx, "subscribed to tunnel updates")
 
 		// unblock the outer function from returning
 		subscribed <- struct{}{}
@@ -1228,87 +1171,68 @@ func (q *querier) subscribe() {
 	<-subscribed
 }
 
-func (q *querier) listenClient(_ context.Context, msg []byte, err error) {
+func (q *querier) listenPeer(_ context.Context, msg []byte, err error) {
 	if xerrors.Is(err, pubsub.ErrDroppedMessages) {
-		q.logger.Warn(q.ctx, "pubsub may have dropped client updates")
-		// we need to schedule a full resync of client mappings
-		q.resyncClientMappings()
+		q.logger.Warn(q.ctx, "pubsub may have dropped peer updates")
+		// we need to schedule a full resync of peer mappings
+		q.resyncPeerMappings()
 		return
 	}
 	if err != nil {
 		q.logger.Warn(q.ctx, "unhandled pubsub error", slog.Error(err))
 		return
 	}
-	client, agent, err := parseClientUpdate(string(msg))
+	peer, err := parsePeerUpdate(string(msg))
 	if err != nil {
-		q.logger.Error(q.ctx, "failed to parse client update", slog.F("msg", string(msg)), slog.Error(err))
+		q.logger.Error(q.ctx, "failed to parse peer update",
+			slog.F("msg", string(msg)), slog.Error(err))
 		return
 	}
-	logger := q.logger.With(slog.F("client_id", client), slog.F("agent_id", agent))
-	logger.Debug(q.ctx, "got client update")
-	mk := mKey{
-		agent: agent,
-		kind:  agpl.QueueKindAgent,
-	}
-	q.mu.Lock()
-	_, ok := q.mappers[mk]
-	q.mu.Unlock()
-	if !ok {
-		logger.Debug(q.ctx, "ignoring update because we have no mapper")
-		return
-	}
-	q.workQ.enqueue(mk)
+	logger := q.logger.With(slog.F("peer_id", peer))
+	logger.Debug(q.ctx, "got peer update")
+
+	// we know that this peer has an updated node mapping, but we don't yet know
who to send that + // update to. We need to query the database to find all the other peers that share a tunnel with + // this one, and then run mapping queries against all of them. + q.workQ.enqueue(querierWorkKey{peerUpdate: peer}) } -func (q *querier) listenAgent(_ context.Context, msg []byte, err error) { +func (q *querier) listenTunnel(_ context.Context, msg []byte, err error) { if xerrors.Is(err, pubsub.ErrDroppedMessages) { - q.logger.Warn(q.ctx, "pubsub may have dropped agent updates") - // we need to schedule a full resync of agent mappings - q.resyncAgentMappings() + q.logger.Warn(q.ctx, "pubsub may have dropped tunnel updates") + // we need to schedule a full resync of peer mappings + q.resyncPeerMappings() return } if err != nil { q.logger.Warn(q.ctx, "unhandled pubsub error", slog.Error(err)) } - agent, err := parseUpdateMessage(string(msg)) + peers, err := parseTunnelUpdate(string(msg)) if err != nil { - q.logger.Error(q.ctx, "failed to parse agent update", slog.F("msg", string(msg)), slog.Error(err)) + q.logger.Error(q.ctx, "failed to parse tunnel update", slog.F("msg", string(msg)), slog.Error(err)) return } - logger := q.logger.With(slog.F("agent_id", agent)) - logger.Debug(q.ctx, "got agent update") - mk := mKey{ - agent: agent, - kind: agpl.QueueKindClient, - } - q.mu.Lock() - _, ok := q.mappers[mk] - q.mu.Unlock() - if !ok { - logger.Debug(q.ctx, "ignoring update because we have no mapper") - return - } - q.workQ.enqueue(mk) -} - -func (q *querier) resyncClientMappings() { - q.mu.Lock() - defer q.mu.Unlock() - for mk := range q.mappers { - if mk.kind == agpl.QueueKindClient { - q.workQ.enqueue(mk) + q.logger.Debug(q.ctx, "got tunnel update", slog.F("peers", peers)) + for _, peer := range peers { + mk := mKey(peer) + q.mu.Lock() + _, ok := q.mappers[mk] + q.mu.Unlock() + if !ok { + q.logger.Debug(q.ctx, "ignoring tunnel update because we have no mapper", + slog.F("peer_id", peer)) + continue } + q.workQ.enqueue(querierWorkKey{mappingQuery: mk}) } } -func (q *querier) resyncAgentMappings() { +func (q *querier) resyncPeerMappings() { q.mu.Lock() defer q.mu.Unlock() for mk := range q.mappers { - if mk.kind == agpl.QueueKindAgent { - q.workQ.enqueue(mk) - } + q.workQ.enqueue(querierWorkKey{mappingQuery: mk}) } } @@ -1337,31 +1261,31 @@ func (q *querier) updateAll() { q.mu.Lock() defer q.mu.Unlock() - for _, cm := range q.mappers { + for _, mpr := range q.mappers { // send on goroutine to avoid holding the q.mu. Heartbeat failures come asynchronously with respect to // other kinds of work, so it's fine to deliver the command to refresh async. go func(m *mapper) { // make sure we send on the _mapper_ context, not our own in case the mapper is // shutting down or shut down. _ = sendCtx(m.ctx, m.update, struct{}{}) - }(cm.mapper) + }(mpr) } } -// unhealthyCloseAll marks the coordinator unhealthy and closes all connections. We do this so that clients and agents +// unhealthyCloseAll marks the coordinator unhealthy and closes all connections. We do this so that peers // are forced to reconnect to the coordinator, and will hopefully land on a healthy coordinator. 
func (q *querier) unhealthyCloseAll() { q.mu.Lock() defer q.mu.Unlock() q.healthy = false - for _, c := range q.conns { + for _, mpr := range q.mappers { // close connections async so that we don't block the querier routine that responds to updates - go func(c agpl.Queue) { + go func(c *connIO) { err := c.Close() if err != nil { q.logger.Debug(q.ctx, "error closing conn while unhealthy", slog.Error(err)) } - }(c) + }(mpr.c) // NOTE: we don't need to remove the connection from the map, as that will happen async in q.cleanupConn() } } @@ -1395,52 +1319,52 @@ func (q *querier) getAll(ctx context.Context) (map[uuid.UUID]database.TailnetAge return agentsMap, clientsMap, nil } -func parseClientUpdate(msg string) (client, agent uuid.UUID, err error) { +func parseTunnelUpdate(msg string) ([]uuid.UUID, error) { parts := strings.Split(msg, ",") if len(parts) != 2 { - return uuid.Nil, uuid.Nil, xerrors.Errorf("expected 2 parts separated by comma") - } - client, err = uuid.Parse(parts[0]) - if err != nil { - return uuid.Nil, uuid.Nil, xerrors.Errorf("failed to parse client UUID: %w", err) + return nil, xerrors.Errorf("expected 2 parts separated by comma") } - - agent, err = uuid.Parse(parts[1]) - if err != nil { - return uuid.Nil, uuid.Nil, xerrors.Errorf("failed to parse agent UUID: %w", err) + peers := make([]uuid.UUID, 2) + var err error + for i, part := range parts { + peers[i], err = uuid.Parse(part) + if err != nil { + return nil, xerrors.Errorf("failed to parse UUID: %w", err) + } } - - return client, agent, nil + return peers, nil } -func parseUpdateMessage(msg string) (agent uuid.UUID, err error) { - agent, err = uuid.Parse(msg) +func parsePeerUpdate(msg string) (peer uuid.UUID, err error) { + peer, err = uuid.Parse(msg) if err != nil { - return uuid.Nil, xerrors.Errorf("failed to parse update message UUID: %w", err) + return uuid.Nil, xerrors.Errorf("failed to parse peer update message UUID: %w", err) } - return agent, nil + return peer, nil } // mKey identifies a set of node mappings we want to query. -type mKey struct { - agent uuid.UUID - // we always query based on the agent ID, but if we have client connection(s), we query the agent itself. If we - // have an agent connection, we need the node mappings for all clients of the agent. - kind agpl.QueueKind -} +type mKey uuid.UUID -// mapping associates a particular client or agent, and its respective coordinator with a node. It is generalized to -// include clients or agents: agent mappings will have client set to uuid.Nil. +// mapping associates a particular peer, and its respective coordinator with a node. type mapping struct { - client uuid.UUID - agent uuid.UUID + peer uuid.UUID coordinator uuid.UUID updatedAt time.Time - node *agpl.Node + node *proto.Node +} + +// querierWorkKey describes two kinds of work the querier needs to do. If peerUpdate +// is not uuid.Nil, then the querier needs to find all tunnel peers of the given peer and +// mark them for a mapping query. If mappingQuery is not uuid.Nil, then the querier has to +// query the mappings of the tunnel peers of the given peer. +type querierWorkKey struct { + peerUpdate uuid.UUID + mappingQuery mKey } type queueKey interface { - mKey | bKey | sKey + bKey | tKey | querierWorkKey } // workQ allows scheduling work based on a key. 
Multiple enqueue requests for the same key are coalesced, and
@@ -1619,7 +1543,7 @@ func (h *heartbeats) subscribe() {
 	bErr := backoff.Retry(func() error {
 		cancelFn, err := h.pubsub.SubscribeWithErr(EventHeartbeats, h.listen)
 		if err != nil {
-			h.logger.Warn(h.ctx, "failed to subscribe to heartbeats", slog.Error(err))
+			h.logger.Warn(h.ctx, "failed to subscribe to heartbeats", slog.Error(err))
 			return err
 		}
 		cancel = cancelFn
diff --git a/enterprise/tailnet/pgcoord_test.go b/enterprise/tailnet/pgcoord_test.go
index 7bc7c89767054..d59d437f3228c 100644
--- a/enterprise/tailnet/pgcoord_test.go
+++ b/enterprise/tailnet/pgcoord_test.go
@@ -3,7 +3,6 @@ package tailnet_test
 import (
 	"context"
 	"database/sql"
-	"encoding/json"
 	"io"
 	"net"
 	"sync"
@@ -17,6 +16,7 @@ import (
 	"go.uber.org/goleak"
 	"golang.org/x/exp/slices"
 	"golang.org/x/xerrors"
+	gProto "google.golang.org/protobuf/proto"
 
 	"cdr.dev/slog"
 	"cdr.dev/slog/sloggers/slogtest"
@@ -27,6 +27,7 @@ import (
 	"github.com/coder/coder/v2/coderd/database/pubsub"
 	"github.com/coder/coder/v2/enterprise/tailnet"
 	agpl "github.com/coder/coder/v2/tailnet"
+	"github.com/coder/coder/v2/tailnet/proto"
 	"github.com/coder/coder/v2/testutil"
 )
 
@@ -52,17 +53,17 @@ func TestPGCoordinatorSingle_ClientWithoutAgent(t *testing.T) {
 	defer client.close()
 	client.sendNode(&agpl.Node{PreferredDERP: 10})
 	require.Eventually(t, func() bool {
-		clients, err := store.GetTailnetClientsForAgent(ctx, agentID)
+		clients, err := store.GetTailnetTunnelPeerBindings(ctx, agentID)
 		if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
 			t.Fatalf("database error: %v", err)
 		}
 		if len(clients) == 0 {
 			return false
 		}
-		var node agpl.Node
-		err = json.Unmarshal(clients[0].Node, &node)
+		node := new(proto.Node)
+		err = gProto.Unmarshal(clients[0].Node, node)
 		assert.NoError(t, err)
-		assert.Equal(t, 10, node.PreferredDERP)
+		assert.EqualValues(t, 10, node.PreferredDerp)
 		return true
 	}, testutil.WaitShort, testutil.IntervalFast)
 
@@ -90,17 +91,17 @@ func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) {
 	defer agent.close()
 	agent.sendNode(&agpl.Node{PreferredDERP: 10})
 	require.Eventually(t, func() bool {
-		agents, err := store.GetTailnetAgents(ctx, agent.id)
+		agents, err := store.GetTailnetPeers(ctx, agent.id)
 		if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
 			t.Fatalf("database error: %v", err)
 		}
 		if len(agents) == 0 {
 			return false
 		}
-		var node agpl.Node
-		err = json.Unmarshal(agents[0].Node, &node)
+		node := new(proto.Node)
+		err = gProto.Unmarshal(agents[0].Node, node)
 		assert.NoError(t, err)
-		assert.Equal(t, 10, node.PreferredDERP)
+		assert.EqualValues(t, 10, node.PreferredDerp)
 		return true
 	}, testutil.WaitShort, testutil.IntervalFast)
 	err = agent.close()
@@ -342,39 +343,51 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) {
 	agent1 := newTestAgent(t, coord1, "agent1")
 	defer agent1.close()
+	t.Logf("agent1=%s", agent1.id)
 	agent2 := newTestAgent(t, coord2, "agent2")
 	defer agent2.close()
+	t.Logf("agent2=%s", agent2.id)
 
 	client11 := newTestClient(t, coord1, agent1.id)
 	defer client11.close()
+	t.Logf("client11=%s", client11.id)
 	client12 := newTestClient(t, coord1, agent2.id)
 	defer client12.close()
+	t.Logf("client12=%s", client12.id)
 	client21 := newTestClient(t, coord2, agent1.id)
 	defer client21.close()
+	t.Logf("client21=%s", client21.id)
 	client22 := newTestClient(t, coord2, agent2.id)
 	defer client22.close()
+	t.Logf("client22=%s", client22.id)
 
+	t.Logf("client11 -> Node 11")
 	client11.sendNode(&agpl.Node{PreferredDERP: 11})
 	assertEventuallyHasDERPs(ctx, t, agent1, 11)
 
+	t.Logf("client21 -> Node 21")
client21.sendNode(&agpl.Node{PreferredDERP: 21}) - assertEventuallyHasDERPs(ctx, t, agent1, 21, 11) + assertEventuallyHasDERPs(ctx, t, agent1, 21) + t.Logf("client22 -> Node 22") client22.sendNode(&agpl.Node{PreferredDERP: 22}) assertEventuallyHasDERPs(ctx, t, agent2, 22) + t.Logf("agent2 -> Node 2") agent2.sendNode(&agpl.Node{PreferredDERP: 2}) assertEventuallyHasDERPs(ctx, t, client22, 2) assertEventuallyHasDERPs(ctx, t, client12, 2) + t.Logf("client12 -> Node 12") client12.sendNode(&agpl.Node{PreferredDERP: 12}) - assertEventuallyHasDERPs(ctx, t, agent2, 12, 22) + assertEventuallyHasDERPs(ctx, t, agent2, 12) + t.Logf("agent1 -> Node 1") agent1.sendNode(&agpl.Node{PreferredDERP: 1}) assertEventuallyHasDERPs(ctx, t, client21, 1) assertEventuallyHasDERPs(ctx, t, client11, 1) - // let's close coord2 + t.Logf("close coord2") err = coord2.Close() require.NoError(t, err) @@ -386,18 +399,9 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) { err = client21.recvErr(ctx, t) require.ErrorIs(t, err, io.EOF) - // agent1 will see an update that drops client21. - // In this case the update is superfluous because client11's node hasn't changed, and agents don't deprogram clients - // from the dataplane even if they are missing. Suppressing this kind of update would require the coordinator to - // store all the data its sent to each connection, so we don't bother. - assertEventuallyHasDERPs(ctx, t, agent1, 11) - - // note that although agent2 is disconnected, client12 does NOT get an update because we suppress empty updates. - // (Its easy to tell these are superfluous.) - assertEventuallyNoAgents(ctx, t, store, agent2.id) - // Close coord1 + t.Logf("close coord1") err = coord1.Close() require.NoError(t, err) // this closes agent1, client12, client11 @@ -541,9 +545,12 @@ func TestPGCoordinator_Unhealthy(t *testing.T) { Return(database.TailnetCoordinator{}, nil) // extra calls we don't particularly care about for this test mStore.EXPECT().CleanTailnetCoordinators(gomock.Any()).AnyTimes().Return(nil) - mStore.EXPECT().GetTailnetClientsForAgent(gomock.Any(), gomock.Any()).AnyTimes().Return(nil, nil) - mStore.EXPECT().DeleteTailnetAgent(gomock.Any(), gomock.Any()). - AnyTimes().Return(database.DeleteTailnetAgentRow{}, nil) + mStore.EXPECT().GetTailnetTunnelPeerIDs(gomock.Any(), gomock.Any()).AnyTimes().Return(nil, nil) + mStore.EXPECT().GetTailnetTunnelPeerBindings(gomock.Any(), gomock.Any()). + AnyTimes().Return(nil, nil) + mStore.EXPECT().DeleteTailnetPeer(gomock.Any(), gomock.Any()). + AnyTimes().Return(database.DeleteTailnetPeerRow{}, nil) + mStore.EXPECT().DeleteAllTailnetTunnels(gomock.Any(), gomock.Any()).AnyTimes().Return(nil) mStore.EXPECT().DeleteCoordinator(gomock.Any(), gomock.Any()).AnyTimes().Return(nil) uut, err := tailnet.NewPGCoord(ctx, logger, ps, mStore) @@ -589,6 +596,34 @@ func TestPGCoordinator_Unhealthy(t *testing.T) { } } +// TestPGCoordinator_BidirectionalTunnels tests when peers create tunnels to each other. We don't +// do this now, but it's schematically possible, so we should make sure it doesn't break anything. 
+func TestPGCoordinator_BidirectionalTunnels(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + store, ps := dbtestutil.NewDB(t) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + coordinator, err := tailnet.NewPGCoordV2(ctx, logger, ps, store) + require.NoError(t, err) + defer coordinator.Close() + + p1 := newTestPeer(ctx, t, coordinator, "p1") + defer p1.close(ctx) + p2 := newTestPeer(ctx, t, coordinator, "p2") + defer p2.close(ctx) + p1.addTunnel(p2.id) + p2.addTunnel(p1.id) + p1.updateDERP(1) + p2.updateDERP(2) + + p1.assertEventuallyHasDERP(p2.id, 2) + p2.assertEventuallyHasDERP(p1.id, 1) +} + type testConn struct { ws, serverWS net.Conn nodeChan chan []*agpl.Node @@ -779,7 +814,7 @@ func assertMultiAgentNeverHasDERPs(ctx context.Context, t *testing.T, ma agpl.Mu func assertEventuallyNoAgents(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) { assert.Eventually(t, func() bool { - agents, err := store.GetTailnetAgents(ctx, agentID) + agents, err := store.GetTailnetPeers(ctx, agentID) if xerrors.Is(err, sql.ErrNoRows) { return true } @@ -793,7 +828,7 @@ func assertEventuallyNoAgents(ctx context.Context, t *testing.T, store database. func assertEventuallyNoClientsForAgent(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) { t.Helper() assert.Eventually(t, func() bool { - clients, err := store.GetTailnetClientsForAgent(ctx, agentID) + clients, err := store.GetTailnetTunnelPeerIDs(ctx, agentID) if xerrors.Is(err, sql.ErrNoRows) { return true } @@ -804,6 +839,108 @@ func assertEventuallyNoClientsForAgent(ctx context.Context, t *testing.T, store }, testutil.WaitShort, testutil.IntervalFast) } +type testPeer struct { + ctx context.Context + cancel context.CancelFunc + t testing.TB + id uuid.UUID + name string + resps <-chan *proto.CoordinateResponse + reqs chan<- *proto.CoordinateRequest + derps map[uuid.UUID]int32 +} + +func newTestPeer(ctx context.Context, t testing.TB, coord agpl.CoordinatorV2, name string, id ...uuid.UUID) *testPeer { + p := &testPeer{t: t, name: name, derps: make(map[uuid.UUID]int32)} + p.ctx, p.cancel = context.WithCancel(ctx) + if len(id) > 1 { + t.Fatal("too many") + } + if len(id) == 1 { + p.id = id[0] + } else { + p.id = uuid.New() + } + // SingleTailnetTunnelAuth allows connections to arbitrary peers + p.reqs, p.resps = coord.Coordinate(p.ctx, p.id, name, agpl.SingleTailnetTunnelAuth{}) + return p +} + +func (p *testPeer) addTunnel(other uuid.UUID) { + p.t.Helper() + req := &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(other)}} + select { + case <-p.ctx.Done(): + p.t.Errorf("timeout adding tunnel for %s", p.name) + return + case p.reqs <- req: + return + } +} + +func (p *testPeer) updateDERP(derp int32) { + p.t.Helper() + req := &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: &proto.Node{PreferredDerp: derp}}} + select { + case <-p.ctx.Done(): + p.t.Errorf("timeout updating node for %s", p.name) + return + case p.reqs <- req: + return + } +} + +func (p *testPeer) assertEventuallyHasDERP(other uuid.UUID, derp int32) { + p.t.Helper() + for { + d, ok := p.derps[other] + if ok && d == derp { + return + } + select { + case <-p.ctx.Done(): + p.t.Errorf("timeout waiting for response for %s", p.name) + return + case resp, ok := <-p.resps: + if !ok { + p.t.Errorf("responses 
closed for %s", p.name) + return + } + for _, update := range resp.PeerUpdates { + id, err := uuid.FromBytes(update.Uuid) + if !assert.NoError(p.t, err) { + return + } + switch update.Kind { + case proto.CoordinateResponse_PeerUpdate_NODE: + p.derps[id] = update.Node.PreferredDerp + case proto.CoordinateResponse_PeerUpdate_DISCONNECTED: + delete(p.derps, id) + default: + p.t.Errorf("unhandled update kind %s", update.Kind) + } + } + } + } +} + +func (p *testPeer) close(ctx context.Context) { + p.t.Helper() + p.cancel() + for { + select { + case <-ctx.Done(): + p.t.Errorf("timeout waiting for responses to close for %s", p.name) + return + case _, ok := <-p.resps: + if ok { + continue + } + return + } + } +} + type fakeCoordinator struct { ctx context.Context t *testing.T @@ -819,12 +956,15 @@ func (c *fakeCoordinator) heartbeat() { func (c *fakeCoordinator) agentNode(agentID uuid.UUID, node *agpl.Node) { c.t.Helper() - nodeRaw, err := json.Marshal(node) + pNode, err := agpl.NodeToProto(node) + require.NoError(c.t, err) + nodeRaw, err := gProto.Marshal(pNode) require.NoError(c.t, err) - _, err = c.store.UpsertTailnetAgent(c.ctx, database.UpsertTailnetAgentParams{ + _, err = c.store.UpsertTailnetPeer(c.ctx, database.UpsertTailnetPeerParams{ ID: agentID, CoordinatorID: c.id, Node: nodeRaw, + Status: database.TailnetStatusOk, }) require.NoError(c.t, err) } diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index 41a75f1fc5e78..2da96bc444275 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -22,6 +22,7 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd/util/slice" + "github.com/coder/coder/v2/tailnet/proto" ) // Coordinator exchanges nodes with agents to establish connections. @@ -48,6 +49,17 @@ type Coordinator interface { ServeMultiAgent(id uuid.UUID) MultiAgentConn } +// CoordinatorV2 is the interface for interacting with the coordinator via the 2.0 tailnet API. +type CoordinatorV2 interface { + // ServeHTTPDebug serves a debug webpage that shows the internal state of + // the coordinator. + ServeHTTPDebug(w http.ResponseWriter, r *http.Request) + // Node returns a node by peer ID, if known to the coordinator. Returns nil if unknown. + Node(id uuid.UUID) *Node + Close() error + Coordinate(ctx context.Context, id uuid.UUID, name string, a TunnelAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) +} + // Node represents a node in the network. type Node struct { // ID is used to identify the connection. diff --git a/tailnet/proto/compare.go b/tailnet/proto/compare.go new file mode 100644 index 0000000000000..012ac293a07c3 --- /dev/null +++ b/tailnet/proto/compare.go @@ -0,0 +1,20 @@ +package proto + +import ( + "bytes" + + gProto "google.golang.org/protobuf/proto" +) + +// Equal returns true if the nodes have the same contents +func (s *Node) Equal(o *Node) (bool, error) { + sBytes, err := gProto.Marshal(s) + if err != nil { + return false, err + } + oBytes, err := gProto.Marshal(o) + if err != nil { + return false, err + } + return bytes.Equal(sBytes, oBytes), nil +} diff --git a/tailnet/trackedconn.go b/tailnet/trackedconn.go index be464b2327921..d083c838b238b 100644 --- a/tailnet/trackedconn.go +++ b/tailnet/trackedconn.go @@ -13,9 +13,14 @@ import ( "cdr.dev/slog" ) -// WriteTimeout is the amount of time we wait to write a node update to a connection before we declare it hung. -// It is exported so that tests can use it. 
-const WriteTimeout = time.Second * 5 +const ( + // WriteTimeout is the amount of time we wait to write a node update to a connection before we + // declare it hung. It is exported so that tests can use it. + WriteTimeout = time.Second * 5 + // ResponseBufferSize is the max number of responses to buffer per connection before we start + // dropping updates + ResponseBufferSize = 512 +) type TrackedConn struct { ctx context.Context @@ -48,7 +53,7 @@ func NewTrackedConn(ctx context.Context, cancel func(), // coordinator mutex while queuing. Node updates don't // come quickly, so 512 should be plenty for all but // the most pathological cases. - updates := make(chan []*Node, 512) + updates := make(chan []*Node, ResponseBufferSize) now := time.Now().Unix() return &TrackedConn{ ctx: ctx, diff --git a/tailnet/tunnel.go b/tailnet/tunnel.go new file mode 100644 index 0000000000000..19f4a485dc817 --- /dev/null +++ b/tailnet/tunnel.go @@ -0,0 +1,30 @@ +package tailnet + +import "github.com/google/uuid" + +type TunnelAuth interface { + Authorize(dst uuid.UUID) bool +} + +// SingleTailnetTunnelAuth allows all tunnels, since Coderd and wsproxy are allowed to initiate a tunnel to any agent +type SingleTailnetTunnelAuth struct{} + +func (SingleTailnetTunnelAuth) Authorize(uuid.UUID) bool { + return true +} + +// ClientTunnelAuth allows connecting to a single, given agent +type ClientTunnelAuth struct { + AgentID uuid.UUID +} + +func (c ClientTunnelAuth) Authorize(dst uuid.UUID) bool { + return c.AgentID == dst +} + +// AgentTunnelAuth disallows all tunnels, since agents are not allowed to initiate their own tunnels +type AgentTunnelAuth struct{} + +func (AgentTunnelAuth) Authorize(uuid.UUID) bool { + return false +}
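
As an illustration only (not part of the diff above), here is a minimal sketch of how a caller might drive the v2 Coordinate API end to end: open a stream authorized by ClientTunnelAuth, announce its own node, request a tunnel to a single agent, and watch that agent's preferred DERP region. The helper watchAgentDERP and its enclosing package are invented for the example; the types and calls it uses (CoordinatorV2.Coordinate, ClientTunnelAuth, UUIDToByteSlice, and the CoordinateRequest/CoordinateResponse messages) are the ones introduced in this change.

package coordsketch // hypothetical package; exists only for this example

import (
	"context"

	"github.com/google/uuid"

	agpl "github.com/coder/coder/v2/tailnet"
	"github.com/coder/coder/v2/tailnet/proto"
)

// watchAgentDERP coordinates as a client that is only authorized to tunnel to agentID and
// forwards the agent's preferred DERP region to derps until ctx is canceled or the
// coordinator closes the response stream.
func watchAgentDERP(ctx context.Context, coord agpl.CoordinatorV2, agentID uuid.UUID, derps chan<- int32) error {
	reqs, resps := coord.Coordinate(ctx, uuid.New(), "derp-watcher", agpl.ClientTunnelAuth{AgentID: agentID})

	// Announce our own node, then ask for a tunnel to the agent; ClientTunnelAuth
	// authorizes exactly this destination.
	init := []*proto.CoordinateRequest{
		{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: &proto.Node{PreferredDerp: 10}}},
		{AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(agentID)}},
	}
	for _, req := range init {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case reqs <- req:
		}
	}

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case resp, ok := <-resps:
			if !ok {
				return nil // coordinator closed the stream
			}
			for _, update := range resp.PeerUpdates {
				id, err := uuid.FromBytes(update.Uuid)
				if err != nil {
					return err
				}
				if id != agentID {
					continue
				}
				switch update.Kind {
				case proto.CoordinateResponse_PeerUpdate_NODE:
					derps <- update.Node.PreferredDerp
				case proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
					// the agent went away; keep waiting in case it reconnects
				}
			}
		}
	}
}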
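
Relatedly, the pubsub payloads consumed by listenTunnel and listenPeer earlier in this diff are plain strings: the parsers expect a tunnel update to carry two comma-separated peer UUIDs and a peer update to carry a single UUID. The hypothetical in-package test below (the name is invented, and it assumes the file sits alongside pgcoord.go so it can reach the unexported parsers) illustrates those formats.

package tailnet

import (
	"testing"

	"github.com/google/uuid"
	"github.com/stretchr/testify/require"
)

// TestParseUpdatePayloads is an illustrative sketch of the payload formats handled by
// listenTunnel ("src,dst") and listenPeer (a bare UUID).
func TestParseUpdatePayloads(t *testing.T) {
	t.Parallel()
	src, dst := uuid.New(), uuid.New()

	peers, err := parseTunnelUpdate(src.String() + "," + dst.String())
	require.NoError(t, err)
	require.Equal(t, []uuid.UUID{src, dst}, peers)

	peer, err := parsePeerUpdate(src.String())
	require.NoError(t, err)
	require.Equal(t, src, peer)

	// a payload without exactly two comma-separated parts is rejected
	_, err = parseTunnelUpdate(src.String())
	require.Error(t, err)
}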