From a20d318c7358936d18698e255dab317e4583c3cb Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Fri, 25 Aug 2023 22:46:46 +0000 Subject: [PATCH 01/21] feat: add single tailnet support to pgcoord --- coderd/coderd.go | 2 +- coderd/database/dbtestutil/db.go | 4 + coderd/database/dump.sql | 15 +- ...151_pg_coordinator_single_tailnet.down.sql | 19 + ...00151_pg_coordinator_single_tailnet.up.sql | 42 ++ coderd/database/models.go | 2 +- coderd/database/querier.go | 2 +- coderd/database/queries.sql.go | 27 +- coderd/database/queries/tailnet.sql | 9 +- coderd/tailnet_test.go | 2 +- enterprise/coderd/workspaceproxycoordinate.go | 2 +- .../coderd/workspaceproxycoordinator_test.go | 6 +- enterprise/tailnet/connio.go | 141 +++++ enterprise/tailnet/coordinator.go | 10 +- enterprise/tailnet/pgcoord.go | 509 +++++++++++------- enterprise/tailnet/pgcoord_test.go | 76 ++- enterprise/wsproxy/wsproxysdk/wsproxysdk.go | 3 +- tailnet/coordinator.go | 22 +- tailnet/multiagent.go | 17 +- tailnet/trackedconn.go | 19 +- 20 files changed, 671 insertions(+), 258 deletions(-) create mode 100644 coderd/database/migrations/000151_pg_coordinator_single_tailnet.down.sql create mode 100644 coderd/database/migrations/000151_pg_coordinator_single_tailnet.up.sql create mode 100644 enterprise/tailnet/connio.go diff --git a/coderd/coderd.go b/coderd/coderd.go index f71e68195509a..175394b64f63d 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -406,7 +406,7 @@ func New(options *Options) *API { api.DERPMap, options.DeploymentValues.DERP.Config.ForceWebSockets.Value(), func(context.Context) (tailnet.MultiAgentConn, error) { - return (*api.TailnetCoordinator.Load()).ServeMultiAgent(uuid.New()), nil + return (*api.TailnetCoordinator.Load()).ServeMultiAgent(uuid.New()) }, wsconncache.New(api._dialWorkspaceAgentTailnet, 0), api.TracerProvider, diff --git a/coderd/database/dbtestutil/db.go b/coderd/database/dbtestutil/db.go index 00eae9dd11218..36ecbf10d6b73 100644 --- a/coderd/database/dbtestutil/db.go +++ b/coderd/database/dbtestutil/db.go @@ -10,6 +10,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbfake" + "github.com/coder/coder/v2/coderd/database/migrations" "github.com/coder/coder/v2/coderd/database/postgres" "github.com/coder/coder/v2/coderd/database/pubsub" ) @@ -42,6 +43,9 @@ func NewDB(t testing.TB) (database.Store, pubsub.Pubsub) { }) db = database.New(sqlDB) + err = migrations.Up(sqlDB) + require.NoError(t, err) + ps, err = pubsub.New(context.Background(), sqlDB, connectionURL) require.NoError(t, err) t.Cleanup(func() { diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 3ee0ac7e19894..31a6d98af211c 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -220,12 +220,13 @@ CREATE FUNCTION tailnet_notify_client_change() RETURNS trigger LANGUAGE plpgsql AS $$ BEGIN - IF (OLD IS NOT NULL) THEN - PERFORM pg_notify('tailnet_client_update', OLD.id || ',' || OLD.agent_id); + -- check new first to get the updated agent ids. + IF (NEW IS NOT NULL) THEN + PERFORM pg_notify('tailnet_client_update', NEW.id || ',' || array_to_string(NEW.agent_ids, ',')); RETURN NULL; END IF; - IF (NEW IS NOT NULL) THEN - PERFORM pg_notify('tailnet_client_update', NEW.id || ',' || NEW.agent_id); + IF (OLD IS NOT NULL) THEN + PERFORM pg_notify('tailnet_client_update', OLD.id || ',' || array_to_string(OLD.agent_ids, ',')); RETURN NULL; END IF; END; @@ -498,9 +499,9 @@ CREATE TABLE tailnet_agents ( CREATE TABLE tailnet_clients ( id uuid NOT NULL, coordinator_id uuid NOT NULL, - agent_id uuid NOT NULL, updated_at timestamp with time zone NOT NULL, - node jsonb NOT NULL + node jsonb NOT NULL, + agent_ids uuid[] NOT NULL ); CREATE TABLE tailnet_coordinators ( @@ -1248,7 +1249,7 @@ CREATE UNIQUE INDEX idx_organization_name_lower ON organizations USING btree (lo CREATE INDEX idx_tailnet_agents_coordinator ON tailnet_agents USING btree (coordinator_id); -CREATE INDEX idx_tailnet_clients_agent ON tailnet_clients USING btree (agent_id); +CREATE INDEX idx_tailnet_clients_agent_ids ON tailnet_clients USING gin (agent_ids); CREATE INDEX idx_tailnet_clients_coordinator ON tailnet_clients USING btree (coordinator_id); diff --git a/coderd/database/migrations/000151_pg_coordinator_single_tailnet.down.sql b/coderd/database/migrations/000151_pg_coordinator_single_tailnet.down.sql new file mode 100644 index 0000000000000..9ce2798205ef8 --- /dev/null +++ b/coderd/database/migrations/000151_pg_coordinator_single_tailnet.down.sql @@ -0,0 +1,19 @@ +BEGIN; + +-- ALTER TABLE +-- tailnet_clients +-- ADD COLUMN +-- agent_id uuid; + +UPDATE + tailnet_clients +SET + -- grab just the first agent_id, or default to an empty UUID. + agent_id = COALESCE(agent_ids[0], '00000000-0000-0000-0000-000000000000'::uuid); + +ALTER TABLE + tailnet_clients +DROP COLUMN + agent_ids; + +COMMIT; diff --git a/coderd/database/migrations/000151_pg_coordinator_single_tailnet.up.sql b/coderd/database/migrations/000151_pg_coordinator_single_tailnet.up.sql new file mode 100644 index 0000000000000..f59e893942b27 --- /dev/null +++ b/coderd/database/migrations/000151_pg_coordinator_single_tailnet.up.sql @@ -0,0 +1,42 @@ +BEGIN; + +ALTER TABLE + tailnet_clients +ADD COLUMN + agent_ids uuid[]; + +UPDATE + tailnet_clients +SET + agent_ids = ARRAY[agent_id]::uuid[]; + +ALTER TABLE + tailnet_clients +ALTER COLUMN + agent_ids SET NOT NULL; + + +CREATE INDEX idx_tailnet_clients_agent_ids ON tailnet_clients USING GIN (agent_ids); + +CREATE OR REPLACE FUNCTION tailnet_notify_client_change() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + -- check new first to get the updated agent ids. + IF (NEW IS NOT NULL) THEN + PERFORM pg_notify('tailnet_client_update', NEW.id || ',' || array_to_string(NEW.agent_ids, ',')); + RETURN NULL; + END IF; + IF (OLD IS NOT NULL) THEN + PERFORM pg_notify('tailnet_client_update', OLD.id || ',' || array_to_string(OLD.agent_ids, ',')); + RETURN NULL; + END IF; +END; +$$; + +ALTER TABLE + tailnet_clients +DROP COLUMN + agent_id; + +COMMIT; diff --git a/coderd/database/models.go b/coderd/database/models.go index 4d1852a54114e..1a95dd6e5cd70 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -1783,9 +1783,9 @@ type TailnetAgent struct { type TailnetClient struct { ID uuid.UUID `db:"id" json:"id"` CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` - AgentID uuid.UUID `db:"agent_id" json:"agent_id"` UpdatedAt time.Time `db:"updated_at" json:"updated_at"` Node json.RawMessage `db:"node" json:"node"` + AgentIds []uuid.UUID `db:"agent_ids" json:"agent_ids"` } // We keep this separate from replicas in case we need to break the coordinator out into its own service diff --git a/coderd/database/querier.go b/coderd/database/querier.go index cdf4d184544bb..15caa096e3fb0 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -108,7 +108,7 @@ type sqlcQuerier interface { GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]Replica, error) GetServiceBanner(ctx context.Context) (string, error) GetTailnetAgents(ctx context.Context, id uuid.UUID) ([]TailnetAgent, error) - GetTailnetClientsForAgent(ctx context.Context, agentID uuid.UUID) ([]TailnetClient, error) + GetTailnetClientsForAgent(ctx context.Context, dollar_1 uuid.UUID) ([]TailnetClient, error) // GetTemplateAppInsights returns the aggregate usage of each app in a given // timeframe. The result can be filtered on template_ids, meaning only user data // from workspaces based on those templates will be included. diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 4d9bc72a37157..2a6fd22e60fe9 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -4172,9 +4172,8 @@ func (q *sqlQuerier) GetAllTailnetAgents(ctx context.Context) ([]TailnetAgent, e } const getAllTailnetClients = `-- name: GetAllTailnetClients :many -SELECT id, coordinator_id, agent_id, updated_at, node +SELECT id, coordinator_id, updated_at, node, agent_ids FROM tailnet_clients -ORDER BY agent_id ` func (q *sqlQuerier) GetAllTailnetClients(ctx context.Context) ([]TailnetClient, error) { @@ -4189,9 +4188,9 @@ func (q *sqlQuerier) GetAllTailnetClients(ctx context.Context) ([]TailnetClient, if err := rows.Scan( &i.ID, &i.CoordinatorID, - &i.AgentID, &i.UpdatedAt, &i.Node, + pq.Array(&i.AgentIds), ); err != nil { return nil, err } @@ -4241,13 +4240,13 @@ func (q *sqlQuerier) GetTailnetAgents(ctx context.Context, id uuid.UUID) ([]Tail } const getTailnetClientsForAgent = `-- name: GetTailnetClientsForAgent :many -SELECT id, coordinator_id, agent_id, updated_at, node +SELECT id, coordinator_id, updated_at, node, agent_ids FROM tailnet_clients -WHERE agent_id = $1 +WHERE $1::uuid = ANY(agent_ids) ` -func (q *sqlQuerier) GetTailnetClientsForAgent(ctx context.Context, agentID uuid.UUID) ([]TailnetClient, error) { - rows, err := q.db.QueryContext(ctx, getTailnetClientsForAgent, agentID) +func (q *sqlQuerier) GetTailnetClientsForAgent(ctx context.Context, dollar_1 uuid.UUID) ([]TailnetClient, error) { + rows, err := q.db.QueryContext(ctx, getTailnetClientsForAgent, dollar_1) if err != nil { return nil, err } @@ -4258,9 +4257,9 @@ func (q *sqlQuerier) GetTailnetClientsForAgent(ctx context.Context, agentID uuid if err := rows.Scan( &i.ID, &i.CoordinatorID, - &i.AgentID, &i.UpdatedAt, &i.Node, + pq.Array(&i.AgentIds), ); err != nil { return nil, err } @@ -4317,7 +4316,7 @@ INSERT INTO tailnet_clients ( id, coordinator_id, - agent_id, + agent_ids, node, updated_at ) @@ -4327,16 +4326,16 @@ ON CONFLICT (id, coordinator_id) DO UPDATE SET id = $1, coordinator_id = $2, - agent_id = $3, + agent_ids = $3, node = $4, updated_at = now() at time zone 'utc' -RETURNING id, coordinator_id, agent_id, updated_at, node +RETURNING id, coordinator_id, updated_at, node, agent_ids ` type UpsertTailnetClientParams struct { ID uuid.UUID `db:"id" json:"id"` CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` - AgentID uuid.UUID `db:"agent_id" json:"agent_id"` + AgentIds []uuid.UUID `db:"agent_ids" json:"agent_ids"` Node json.RawMessage `db:"node" json:"node"` } @@ -4344,16 +4343,16 @@ func (q *sqlQuerier) UpsertTailnetClient(ctx context.Context, arg UpsertTailnetC row := q.db.QueryRowContext(ctx, upsertTailnetClient, arg.ID, arg.CoordinatorID, - arg.AgentID, + pq.Array(arg.AgentIds), arg.Node, ) var i TailnetClient err := row.Scan( &i.ID, &i.CoordinatorID, - &i.AgentID, &i.UpdatedAt, &i.Node, + pq.Array(&i.AgentIds), ) return i, err } diff --git a/coderd/database/queries/tailnet.sql b/coderd/database/queries/tailnet.sql index fd2db296dfa54..1b63261b90d86 100644 --- a/coderd/database/queries/tailnet.sql +++ b/coderd/database/queries/tailnet.sql @@ -3,7 +3,7 @@ INSERT INTO tailnet_clients ( id, coordinator_id, - agent_id, + agent_ids, node, updated_at ) @@ -13,7 +13,7 @@ ON CONFLICT (id, coordinator_id) DO UPDATE SET id = $1, coordinator_id = $2, - agent_id = $3, + agent_ids = $3, node = $4, updated_at = now() at time zone 'utc' RETURNING *; @@ -66,12 +66,11 @@ FROM tailnet_agents; -- name: GetTailnetClientsForAgent :many SELECT * FROM tailnet_clients -WHERE agent_id = $1; +WHERE $1::uuid = ANY(agent_ids); -- name: GetAllTailnetClients :many SELECT * -FROM tailnet_clients -ORDER BY agent_id; +FROM tailnet_clients; -- name: UpsertTailnetCoordinator :one INSERT INTO diff --git a/coderd/tailnet_test.go b/coderd/tailnet_test.go index 2a0b0dfdbae70..8b1f55d9e994d 100644 --- a/coderd/tailnet_test.go +++ b/coderd/tailnet_test.go @@ -233,7 +233,7 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A derpServer, func() *tailcfg.DERPMap { return manifest.DERPMap }, false, - func(context.Context) (tailnet.MultiAgentConn, error) { return coord.ServeMultiAgent(uuid.New()), nil }, + func(context.Context) (tailnet.MultiAgentConn, error) { return coord.ServeMultiAgent(uuid.New()) }, cache, trace.NewNoopTracerProvider(), ) diff --git a/enterprise/coderd/workspaceproxycoordinate.go b/enterprise/coderd/workspaceproxycoordinate.go index ec454d73a870a..6fc0bcd6b18f7 100644 --- a/enterprise/coderd/workspaceproxycoordinate.go +++ b/enterprise/coderd/workspaceproxycoordinate.go @@ -67,7 +67,7 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request } id := uuid.New() - sub := (*api.AGPL.TailnetCoordinator.Load()).ServeMultiAgent(id) + sub, err := (*api.AGPL.TailnetCoordinator.Load()).ServeMultiAgent(id) ctx, nc := websocketNetConn(ctx, conn, websocket.MessageText) defer nc.Close() diff --git a/enterprise/coderd/workspaceproxycoordinator_test.go b/enterprise/coderd/workspaceproxycoordinator_test.go index de72c288b2eee..fb991180b3adc 100644 --- a/enterprise/coderd/workspaceproxycoordinator_test.go +++ b/enterprise/coderd/workspaceproxycoordinator_test.go @@ -59,7 +59,8 @@ func Test_agentIsLegacy(t *testing.T) { defer cancel() nodeID := uuid.New() - ma := coordinator.ServeMultiAgent(nodeID) + ma, err := coordinator.ServeMultiAgent(nodeID) + require.NoError(t, err) defer ma.Close() require.NoError(t, ma.UpdateSelf(&agpl.Node{ ID: 55, @@ -123,7 +124,8 @@ func Test_agentIsLegacy(t *testing.T) { defer cancel() nodeID := uuid.New() - ma := coordinator.ServeMultiAgent(nodeID) + ma, err := coordinator.ServeMultiAgent(nodeID) + require.NoError(t, err) defer ma.Close() require.NoError(t, ma.UpdateSelf(&agpl.Node{ ID: 55, diff --git a/enterprise/tailnet/connio.go b/enterprise/tailnet/connio.go new file mode 100644 index 0000000000000..72e378f045b1f --- /dev/null +++ b/enterprise/tailnet/connio.go @@ -0,0 +1,141 @@ +package tailnet + +import ( + "context" + "encoding/json" + "io" + "net" + + "github.com/google/uuid" + "golang.org/x/xerrors" + "nhooyr.io/websocket" + + "cdr.dev/slog" + agpl "github.com/coder/coder/v2/tailnet" +) + +// connIO manages the reading and writing to a connected client or agent. Agent connIOs have their client field set to +// uuid.Nil. It reads node updates via its decoder, then pushes them onto the bindings channel. It receives mappings +// via its updates TrackedConn, which then writes them. +type connIO struct { + pCtx context.Context + ctx context.Context + cancel context.CancelFunc + logger slog.Logger + subscriptions []uuid.UUID + decoder *json.Decoder + updates *agpl.TrackedConn + bindings chan<- binding +} + +func newConnIO(pCtx context.Context, + logger slog.Logger, + bindings chan<- binding, + conn net.Conn, + id uuid.UUID, + subs []uuid.UUID, + name string, + kind agpl.QueueKind, +) *connIO { + ctx, cancel := context.WithCancel(pCtx) + c := &connIO{ + pCtx: pCtx, + ctx: ctx, + cancel: cancel, + logger: logger, + subscriptions: subs, + decoder: json.NewDecoder(conn), + updates: agpl.NewTrackedConn(ctx, cancel, conn, id, logger, name, 0, kind), + bindings: bindings, + } + go c.recvLoop() + go c.updates.SendUpdates() + logger.Info(ctx, "serving connection") + return c +} + +func (c *connIO) recvLoop() { + defer func() { + // withdraw bindings when we exit. We need to use the parent context here, since our own context might be + // canceled, but we still need to withdraw bindings. + b := binding{ + bKey: bKey{ + id: c.UniqueID(), + kind: c.Kind(), + }, + } + if err := sendCtx(c.pCtx, c.bindings, b); err != nil { + c.logger.Debug(c.ctx, "parent context expired while withdrawing bindings", slog.Error(err)) + } + }() + defer c.cancel() + for { + var node agpl.Node + err := c.decoder.Decode(&node) + if err != nil { + if xerrors.Is(err, io.EOF) || + xerrors.Is(err, io.ErrClosedPipe) || + xerrors.Is(err, context.Canceled) || + xerrors.Is(err, context.DeadlineExceeded) || + websocket.CloseStatus(err) > 0 { + c.logger.Debug(c.ctx, "exiting recvLoop", slog.Error(err)) + } else { + c.logger.Error(c.ctx, "failed to decode Node update", slog.Error(err)) + } + return + } + c.logger.Debug(c.ctx, "got node update", slog.F("node", node)) + b := binding{ + bKey: bKey{ + id: c.UniqueID(), + kind: c.Kind(), + }, + subscriptions: c.subscriptions, + node: &node, + } + if err := sendCtx(c.ctx, c.bindings, b); err != nil { + c.logger.Debug(c.ctx, "recvLoop ctx expired", slog.Error(err)) + return + } + } +} + +func (c *connIO) UniqueID() uuid.UUID { + return c.updates.UniqueID() +} + +func (c *connIO) Kind() agpl.QueueKind { + return c.updates.Kind() +} + +func (c *connIO) Enqueue(n []*agpl.Node) error { + return c.updates.Enqueue(n) +} + +func (c *connIO) Name() string { + return c.updates.Name() +} + +func (c *connIO) Stats() (start int64, lastWrite int64) { + return c.updates.Stats() +} + +func (c *connIO) Overwrites() int64 { + return c.updates.Overwrites() +} + +// CoordinatorClose is used by the coordinator when closing a Queue. It +// should skip removing itself from the coordinator. +func (c *connIO) CoordinatorClose() error { + c.cancel() + return c.updates.CoordinatorClose() +} + +func (c *connIO) Done() <-chan struct{} { + return c.ctx.Done() +} + +func (c *connIO) Close() error { + c.cancel() + return c.updates.Close() +} diff --git a/enterprise/tailnet/coordinator.go b/enterprise/tailnet/coordinator.go index d97bf2cce7a6c..9a04670d78b02 100644 --- a/enterprise/tailnet/coordinator.go +++ b/enterprise/tailnet/coordinator.go @@ -52,16 +52,16 @@ func NewCoordinator(logger slog.Logger, ps pubsub.Pubsub) (agpl.Coordinator, err return coord, nil } -func (c *haCoordinator) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn { +func (c *haCoordinator) ServeMultiAgent(id uuid.UUID) (agpl.MultiAgentConn, error) { m := (&agpl.MultiAgent{ ID: id, AgentIsLegacyFunc: c.agentIsLegacy, OnSubscribe: c.clientSubscribeToAgent, OnNodeUpdate: c.clientNodeUpdate, - OnRemove: c.clientDisconnected, + OnRemove: func(enq agpl.Queue) { c.clientDisconnected(enq.UniqueID()) }, }).Init() c.addClient(id, m) - return m + return m, nil } func (c *haCoordinator) addClient(id uuid.UUID, q agpl.Queue) { @@ -157,7 +157,7 @@ func (c *haCoordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error defer cancel() logger := c.clientLogger(id, agentID) - tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, id.String(), 0) + tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, id.String(), 0, agpl.QueueKindClient) defer tc.Close() c.addClient(id, tc) @@ -300,7 +300,7 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) err } // This uniquely identifies a connection that belongs to this goroutine. unique := uuid.New() - tc := agpl.NewTrackedConn(ctx, cancel, conn, unique, logger, name, overwrites) + tc := agpl.NewTrackedConn(ctx, cancel, conn, unique, logger, name, overwrites, agpl.QueueKindAgent) // Publish all nodes on this instance that want to connect to this agent. nodes := c.nodesSubscribedToAgent(id) diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 62cc5a240cd98..4a55650303abb 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -4,7 +4,7 @@ import ( "context" "database/sql" "encoding/json" - "io" + "fmt" "net" "net/http" "strings" @@ -15,7 +15,6 @@ import ( "github.com/google/uuid" "golang.org/x/exp/slices" "golang.org/x/xerrors" - "nhooyr.io/websocket" "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" @@ -72,7 +71,7 @@ type pgCoord struct { store database.Store bindings chan binding - newConnections chan *connIO + newConnections chan agpl.Queue id uuid.UUID cancel context.CancelFunc @@ -106,7 +105,7 @@ func NewPGCoord(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, store id := uuid.New() logger = logger.Named("pgcoord").With(slog.F("coordinator_id", id)) bCh := make(chan binding) - cCh := make(chan *connIO) + cCh := make(chan agpl.Queue) // signals when first heartbeat has been sent, so it's safe to start binding. fHB := make(chan struct{}) @@ -127,9 +126,44 @@ func NewPGCoord(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, store return c, nil } -func (c *pgCoord) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn { - _, _ = c, id - panic("not implemented") // TODO: Implement +func (c *pgCoord) ServeMultiAgent(id uuid.UUID) (agpl.MultiAgentConn, error) { + ma := (&agpl.MultiAgent{ + ID: id, + AgentIsLegacyFunc: func(agentID uuid.UUID) bool { return true }, + OnSubscribe: func(enq agpl.Queue, agent uuid.UUID) (*agpl.Node, error) { + c.querier.newClientSubscription(enq, agent) + return c.Node(agent), nil + }, + OnUnsubscribe: func(enq agpl.Queue, agent uuid.UUID) error { + c.querier.removeClientSubscription(enq, agent) + return nil + }, + OnNodeUpdate: func(id uuid.UUID, node *agpl.Node) error { + return sendCtx(c.ctx, c.bindings, binding{ + bKey: bKey{id, agpl.QueueKindClient}, + node: node, + subscriptions: c.querier.getClientSubscriptions(id), + }) + }, + OnRemove: func(enq agpl.Queue) { + b := binding{ + bKey: bKey{ + id: enq.UniqueID(), + kind: enq.Kind(), + }, + } + if err := sendCtx(c.ctx, c.bindings, b); err != nil { + c.logger.Debug(c.ctx, "parent context expired while withdrawing bindings", slog.Error(err)) + } + c.querier.cleanupConn(enq) + }, + }).Init() + + if err := sendCtx(c.ctx, c.newConnections, agpl.Queue(ma)); err != nil { + return nil, err + } + + return ma, nil } func (c *pgCoord) Node(id uuid.UUID) *agpl.Node { @@ -162,11 +196,12 @@ func (c *pgCoord) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) erro slog.Error(err)) } }() - cIO := newConnIO(c.ctx, c.logger, c.bindings, conn, id, agent, id.String()) - if err := sendCtx(c.ctx, c.newConnections, cIO); err != nil { + cIO := newConnIO(c.ctx, c.logger, c.bindings, conn, id, []uuid.UUID{agent}, id.String(), agpl.QueueKindClient) + if err := sendCtx(c.ctx, c.newConnections, agpl.Queue(cIO)); err != nil { // can only be a context error, no need to log here. return err } + c.querier.newClientSubscription(cIO, agent) <-cIO.ctx.Done() return nil } @@ -181,8 +216,8 @@ func (c *pgCoord) ServeAgent(conn net.Conn, id uuid.UUID, name string) error { } }() logger := c.logger.With(slog.F("name", name)) - cIO := newConnIO(c.ctx, logger, c.bindings, conn, uuid.Nil, id, name) - if err := sendCtx(c.ctx, c.newConnections, cIO); err != nil { + cIO := newConnIO(c.ctx, logger, c.bindings, conn, id, nil, name, agpl.QueueKindAgent) + if err := sendCtx(c.ctx, c.newConnections, agpl.Queue(cIO)); err != nil { // can only be a context error, no need to log here. return err } @@ -197,97 +232,6 @@ func (c *pgCoord) Close() error { return nil } -// connIO manages the reading and writing to a connected client or agent. Agent connIOs have their client field set to -// uuid.Nil. It reads node updates via its decoder, then pushes them onto the bindings channel. It receives mappings -// via its updates TrackedConn, which then writes them. -type connIO struct { - pCtx context.Context - ctx context.Context - cancel context.CancelFunc - logger slog.Logger - client uuid.UUID - agent uuid.UUID - decoder *json.Decoder - updates *agpl.TrackedConn - bindings chan<- binding -} - -func newConnIO(pCtx context.Context, - logger slog.Logger, - bindings chan<- binding, - conn net.Conn, - client, agent uuid.UUID, - name string, -) *connIO { - ctx, cancel := context.WithCancel(pCtx) - id := agent - logger = logger.With(slog.F("agent_id", agent)) - if client != uuid.Nil { - logger = logger.With(slog.F("client_id", client)) - id = client - } - c := &connIO{ - pCtx: pCtx, - ctx: ctx, - cancel: cancel, - logger: logger, - client: client, - agent: agent, - decoder: json.NewDecoder(conn), - updates: agpl.NewTrackedConn(ctx, cancel, conn, id, logger, name, 0), - bindings: bindings, - } - go c.recvLoop() - go c.updates.SendUpdates() - logger.Info(ctx, "serving connection") - return c -} - -func (c *connIO) recvLoop() { - defer func() { - // withdraw bindings when we exit. We need to use the parent context here, since our own context might be - // canceled, but we still need to withdraw bindings. - b := binding{ - bKey: bKey{ - client: c.client, - agent: c.agent, - }, - } - if err := sendCtx(c.pCtx, c.bindings, b); err != nil { - c.logger.Debug(c.ctx, "parent context expired while withdrawing bindings", slog.Error(err)) - } - }() - defer c.cancel() - for { - var node agpl.Node - err := c.decoder.Decode(&node) - if err != nil { - if xerrors.Is(err, io.EOF) || - xerrors.Is(err, io.ErrClosedPipe) || - xerrors.Is(err, context.Canceled) || - xerrors.Is(err, context.DeadlineExceeded) || - websocket.CloseStatus(err) > 0 { - c.logger.Debug(c.ctx, "exiting recvLoop", slog.Error(err)) - } else { - c.logger.Error(c.ctx, "failed to decode Node update", slog.Error(err)) - } - return - } - c.logger.Debug(c.ctx, "got node update", slog.F("node", node)) - b := binding{ - bKey: bKey{ - client: c.client, - agent: c.agent, - }, - node: &node, - } - if err := sendCtx(c.ctx, c.bindings, b); err != nil { - c.logger.Debug(c.ctx, "recvLoop ctx expired", slog.Error(err)) - return - } - } -} - func sendCtx[A any](ctx context.Context, c chan<- A, a A) (err error) { select { case <-ctx.Done(): @@ -297,20 +241,23 @@ func sendCtx[A any](ctx context.Context, c chan<- A, a A) (err error) { } } -// bKey, or "binding key" identifies a client or agent in a binding. Agents have their client field set to uuid.Nil. +// bKey, or "binding key" identifies a client or agent in a binding. Agents have their client field set to uuid.Nil, +// while clients have their agent field set to uuid.Nil. type bKey struct { - client uuid.UUID - agent uuid.UUID + id uuid.UUID + kind agpl.QueueKind } // binding represents an association between a client or agent and a Node. type binding struct { bKey - node *agpl.Node + // subscriptions is a list of agents a client is subscribed to. + subscriptions []uuid.UUID + node *agpl.Node } -func (b *binding) isAgent() bool { return b.client == uuid.Nil } -func (b *binding) isClient() bool { return b.client != uuid.Nil } +func (b *binding) isAgent() bool { return b.kind == agpl.QueueKindAgent } +func (b *binding) isClient() bool { return b.kind == agpl.QueueKindClient } // binder reads node bindings from the channel and writes them to the database. It handles retries with a backoff. type binder struct { @@ -325,9 +272,12 @@ type binder struct { workQ *workQ[bKey] } -func newBinder(ctx context.Context, logger slog.Logger, - id uuid.UUID, store database.Store, - bindings <-chan binding, startWorkers <-chan struct{}, +func newBinder(ctx context.Context, + logger slog.Logger, + id uuid.UUID, + store database.Store, + bindings <-chan binding, + startWorkers <-chan struct{}, ) *binder { b := &binder{ ctx: ctx, @@ -399,40 +349,40 @@ func (b *binder) writeOne(bnd binding) error { switch { case bnd.isAgent() && len(nodeRaw) > 0: _, err = b.store.UpsertTailnetAgent(b.ctx, database.UpsertTailnetAgentParams{ - ID: bnd.agent, + ID: bnd.id, CoordinatorID: b.coordinatorID, Node: nodeRaw, }) b.logger.Debug(b.ctx, "upserted agent binding", - slog.F("agent_id", bnd.agent), slog.F("node", nodeRaw), slog.Error(err)) + slog.F("agent_id", bnd.id), slog.F("node", nodeRaw), slog.Error(err)) case bnd.isAgent() && len(nodeRaw) == 0: _, err = b.store.DeleteTailnetAgent(b.ctx, database.DeleteTailnetAgentParams{ - ID: bnd.agent, + ID: bnd.id, CoordinatorID: b.coordinatorID, }) b.logger.Debug(b.ctx, "deleted agent binding", - slog.F("agent_id", bnd.agent), slog.Error(err)) + slog.F("agent_id", bnd.id), slog.Error(err)) if xerrors.Is(err, sql.ErrNoRows) { // treat deletes as idempotent err = nil } case bnd.isClient() && len(nodeRaw) > 0: _, err = b.store.UpsertTailnetClient(b.ctx, database.UpsertTailnetClientParams{ - ID: bnd.client, + ID: bnd.id, CoordinatorID: b.coordinatorID, - AgentID: bnd.agent, + AgentIds: bnd.subscriptions, Node: nodeRaw, }) b.logger.Debug(b.ctx, "upserted client binding", - slog.F("agent_id", bnd.agent), slog.F("client_id", bnd.client), + slog.F("subscriptions", bnd.subscriptions), slog.F("client_id", bnd.id), slog.F("node", nodeRaw), slog.Error(err)) case bnd.isClient() && len(nodeRaw) == 0: _, err = b.store.DeleteTailnetClient(b.ctx, database.DeleteTailnetClientParams{ - ID: bnd.client, + ID: bnd.id, CoordinatorID: b.coordinatorID, }) b.logger.Debug(b.ctx, "deleted client binding", - slog.F("agent_id", bnd.agent), slog.F("client_id", bnd.client), slog.Error(err)) + slog.F("subscriptions", bnd.subscriptions), slog.F("client_id", bnd.id)) if xerrors.Is(err, sql.ErrNoRows) { // treat deletes as idempotent err = nil @@ -442,8 +392,8 @@ func (b *binder) writeOne(bnd binding) error { } if err != nil && !database.IsQueryCanceledError(err) { b.logger.Error(b.ctx, "failed to write binding to database", - slog.F("client_id", bnd.client), - slog.F("agent_id", bnd.agent), + slog.F("binding_id", bnd.id), + slog.F("kind", bnd.kind), slog.F("node", string(nodeRaw)), slog.Error(err)) } @@ -483,8 +433,8 @@ type mapper struct { ctx context.Context logger slog.Logger - add chan *connIO - del chan *connIO + add chan agpl.Queue + del chan agpl.Queue // reads from this channel trigger sending latest nodes to // all connections. It is used when coordinators are added @@ -493,7 +443,7 @@ type mapper struct { mappings chan []mapping - conns map[bKey]*connIO + conns map[bKey]agpl.Queue latest []mapping heartbeats *heartbeats @@ -502,15 +452,15 @@ type mapper struct { func newMapper(ctx context.Context, logger slog.Logger, mk mKey, h *heartbeats) *mapper { logger = logger.With( slog.F("agent_id", mk.agent), - slog.F("clients_of_agent", mk.clientsOfAgent), + slog.F("kind", mk.kind), ) m := &mapper{ ctx: ctx, logger: logger, - add: make(chan *connIO), - del: make(chan *connIO), + add: make(chan agpl.Queue), + del: make(chan agpl.Queue), update: make(chan struct{}), - conns: make(map[bKey]*connIO), + conns: make(map[bKey]agpl.Queue), mappings: make(chan []mapping), heartbeats: h, } @@ -524,17 +474,17 @@ func (m *mapper) run() { case <-m.ctx.Done(): return case c := <-m.add: - m.conns[bKey{c.client, c.agent}] = c + m.conns[bKey{id: c.UniqueID(), kind: c.Kind()}] = c nodes := m.mappingsToNodes(m.latest) if len(nodes) == 0 { m.logger.Debug(m.ctx, "skipping 0 length node update") continue } - if err := c.updates.Enqueue(nodes); err != nil { + if err := c.Enqueue(nodes); err != nil { m.logger.Error(m.ctx, "failed to enqueue node update", slog.Error(err)) } case c := <-m.del: - delete(m.conns, bKey{c.client, c.agent}) + delete(m.conns, bKey{id: c.UniqueID(), kind: c.Kind()}) case mappings := <-m.mappings: m.latest = mappings nodes := m.mappingsToNodes(mappings) @@ -543,7 +493,7 @@ func (m *mapper) run() { continue } for _, conn := range m.conns { - if err := conn.updates.Enqueue(nodes); err != nil { + if err := conn.Enqueue(nodes); err != nil { m.logger.Error(m.ctx, "failed to enqueue node update", slog.Error(err)) } } @@ -554,7 +504,7 @@ func (m *mapper) run() { continue } for _, conn := range m.conns { - if err := conn.updates.Enqueue(nodes); err != nil { + if err := conn.Enqueue(nodes); err != nil { m.logger.Error(m.ctx, "failed to enqueue triggered node update", slog.Error(err)) } } @@ -570,7 +520,13 @@ func (m *mapper) mappingsToNodes(mappings []mapping) []*agpl.Node { mappings = m.heartbeats.filter(mappings) best := make(map[bKey]mapping, len(mappings)) for _, m := range mappings { - bk := bKey{client: m.client, agent: m.agent} + var bk bKey + if m.client == uuid.Nil { + bk = bKey{id: m.agent, kind: agpl.QueueKindAgent} + } else { + bk = bKey{id: m.client, kind: agpl.QueueKindClient} + } + bestM, ok := best[bk] if !ok || m.updatedAt.After(bestM.updatedAt) { best[bk] = m @@ -591,16 +547,20 @@ type querier struct { logger slog.Logger pubsub pubsub.Pubsub store database.Store - newConnections chan *connIO + newConnections chan agpl.Queue + + workQ *workQ[mKey] - workQ *workQ[mKey] heartbeats *heartbeats updates <-chan hbUpdate mu sync.Mutex mappers map[mKey]*countedMapper - conns map[*connIO]struct{} - healthy bool + conns map[uuid.UUID]agpl.Queue + // clientSubscriptions maps client ids to the agent ids they're subscribed to. + // map[client_id]map[agent_id] + clientSubscriptions map[uuid.UUID]map[uuid.UUID]struct{} + healthy bool } type countedMapper struct { @@ -609,28 +569,32 @@ type countedMapper struct { cancel context.CancelFunc } -func newQuerier( - ctx context.Context, logger slog.Logger, - ps pubsub.Pubsub, store database.Store, - self uuid.UUID, newConnections chan *connIO, numWorkers int, +func newQuerier(ctx context.Context, + logger slog.Logger, + ps pubsub.Pubsub, + store database.Store, + self uuid.UUID, + newConnections chan agpl.Queue, + numWorkers int, firstHeartbeat chan<- struct{}, ) *querier { updates := make(chan hbUpdate) q := &querier{ - ctx: ctx, - logger: logger.Named("querier"), - pubsub: ps, - store: store, - newConnections: newConnections, - workQ: newWorkQ[mKey](ctx), - heartbeats: newHeartbeats(ctx, logger, ps, store, self, updates, firstHeartbeat), - mappers: make(map[mKey]*countedMapper), - conns: make(map[*connIO]struct{}), - updates: updates, - healthy: true, // assume we start healthy + ctx: ctx, + logger: logger.Named("querier"), + pubsub: ps, + store: store, + newConnections: newConnections, + workQ: newWorkQ[mKey](ctx), + heartbeats: newHeartbeats(ctx, logger, ps, store, self, updates, firstHeartbeat), + mappers: make(map[mKey]*countedMapper), + conns: make(map[uuid.UUID]agpl.Queue), + updates: updates, + clientSubscriptions: make(map[uuid.UUID]map[uuid.UUID]struct{}), + healthy: true, // assume we start healthy } q.subscribe() - go q.handleConnIO() + go q.handleNewConnections() for i := 0; i < numWorkers; i++ { go q.worker() } @@ -638,33 +602,71 @@ func newQuerier( return q } -func (q *querier) handleConnIO() { +func (q *querier) handleNewConnections() { for { select { case <-q.ctx.Done(): return case c := <-q.newConnections: - q.newConn(c) + switch c.Kind() { + case agpl.QueueKindAgent: + q.newAgentConn(c) + case agpl.QueueKindClient: + q.newClientConn(c) + default: + panic(fmt.Sprint("unreachable: invalid queue kind ", c.Kind())) + } } } } -func (q *querier) newConn(c *connIO) { +func (q *querier) newAgentConn(c agpl.Queue) { q.mu.Lock() defer q.mu.Unlock() if !q.healthy { - err := c.updates.Close() + err := c.Close() q.logger.Info(q.ctx, "closed incoming connection while unhealthy", slog.Error(err), - slog.F("agent_id", c.agent), - slog.F("client_id", c.client), + slog.F("agent_id", c.UniqueID()), ) return } mk := mKey{ - agent: c.agent, - // if client is Nil, this is an agent connection, and it wants the mappings for all the clients of itself - clientsOfAgent: c.client == uuid.Nil, + agent: c.UniqueID(), + kind: c.Kind(), + } + cm, ok := q.mappers[mk] + if !ok { + ctx, cancel := context.WithCancel(q.ctx) + mpr := newMapper(ctx, q.logger, mk, q.heartbeats) + cm = &countedMapper{ + mapper: mpr, + count: 0, + cancel: cancel, + } + q.mappers[mk] = cm + // we don't have any mapping state for this key yet + q.workQ.enqueue(mk) + } + if err := sendCtx(cm.ctx, cm.add, c); err != nil { + return + } + cm.count++ + q.conns[c.UniqueID()] = c + go q.waitCleanupConn(c) +} + +func (q *querier) newClientSubscription(c agpl.Queue, agentID uuid.UUID) { + q.mu.Lock() + defer q.mu.Unlock() + + if _, ok := q.clientSubscriptions[c.UniqueID()]; !ok { + q.clientSubscriptions[c.UniqueID()] = map[uuid.UUID]struct{}{} + } + + mk := mKey{ + agent: agentID, + kind: c.Kind(), } cm, ok := q.mappers[mk] if !ok { @@ -682,22 +684,94 @@ func (q *querier) newConn(c *connIO) { if err := sendCtx(cm.ctx, cm.add, c); err != nil { return } + q.clientSubscriptions[c.UniqueID()][agentID] = struct{}{} cm.count++ - q.conns[c] = struct{}{} - go q.cleanupConn(c) } -func (q *querier) cleanupConn(c *connIO) { - <-c.ctx.Done() +func (q *querier) removeClientSubscription(c agpl.Queue, agentID uuid.UUID) { q.mu.Lock() defer q.mu.Unlock() - delete(q.conns, c) + mk := mKey{ - agent: c.agent, - // if client is Nil, this is an agent connection, and it wants the mappings for all the clients of itself - clientsOfAgent: c.client == uuid.Nil, + agent: agentID, + kind: c.Kind(), } cm := q.mappers[mk] + if err := sendCtx(cm.ctx, cm.del, c); err != nil { + return + } + delete(q.clientSubscriptions[c.UniqueID()], agentID) + cm.count-- + if cm.count == 0 { + cm.cancel() + delete(q.mappers, mk) + } +} + +func (q *querier) newClientConn(c agpl.Queue) { + q.mu.Lock() + defer q.mu.Unlock() + if !q.healthy { + err := c.Close() + q.logger.Info(q.ctx, "closed incoming connection while unhealthy", + slog.Error(err), + slog.F("client_id", c.UniqueID()), + ) + return + } + + q.conns[c.UniqueID()] = c + go q.waitCleanupConn(c) +} + +func (q *querier) getClientSubscriptions(id uuid.UUID) []uuid.UUID { + q.mu.Lock() + defer q.mu.Unlock() + subs := []uuid.UUID{} + for sub := range q.clientSubscriptions[id] { + subs = append(subs, sub) + } + return subs +} + +func (q *querier) waitCleanupConn(c agpl.Queue) { + <-c.Done() + q.cleanupConn(c) +} + +func (q *querier) cleanupConn(c agpl.Queue) { + q.mu.Lock() + defer q.mu.Unlock() + delete(q.conns, c.UniqueID()) + + // Iterate over all subscriptions and remove them from the mappers. + for agentID := range q.clientSubscriptions[c.UniqueID()] { + mk := mKey{ + agent: agentID, + kind: c.Kind(), + } + cm, ok := q.mappers[mk] + if ok { + if err := sendCtx(cm.ctx, cm.del, c); err != nil { + return + } + cm.count-- + if cm.count == 0 { + cm.cancel() + delete(q.mappers, mk) + } + } + } + + mk := mKey{ + agent: c.UniqueID(), + kind: c.Kind(), + } + cm, ok := q.mappers[mk] + if !ok { + return + } + if err := sendCtx(cm.ctx, cm.del, c); err != nil { return } @@ -732,12 +806,15 @@ func (q *querier) worker() { func (q *querier) query(mk mKey) error { var mappings []mapping var err error - if mk.clientsOfAgent { + // If the mapping is an agent, query all of its clients. + if mk.kind == agpl.QueueKindAgent { mappings, err = q.queryClientsOfAgent(mk.agent) if err != nil { return err } } else { + // The mapping is for clients subscribed to the agent. Query the agent + // itself. mappings, err = q.queryAgent(mk.agent) if err != nil { return err @@ -748,9 +825,10 @@ func (q *querier) query(mk mKey) error { q.mu.Unlock() if !ok { q.logger.Debug(q.ctx, "query for missing mapper", - slog.F("agent_id", mk.agent), slog.F("clients_of_agent", mk.clientsOfAgent)) + slog.F("agent_id", mk.agent), slog.F("kind", mk.kind)) return nil } + q.logger.Debug(q.ctx, "sending mappings", slog.F("mapping_len", len(mappings))) mpr.mappings <- mappings return nil } @@ -772,7 +850,7 @@ func (q *querier) queryClientsOfAgent(agent uuid.UUID) ([]mapping, error) { } mappings = append(mappings, mapping{ client: client.ID, - agent: client.AgentID, + agent: agent, coordinator: client.CoordinatorID, updatedAt: client.UpdatedAt, node: node, @@ -788,6 +866,11 @@ func (q *querier) queryAgent(agentID uuid.UUID) ([]mapping, error) { if err != nil { return nil, err } + return q.agentsToMappings(agents) +} + +func (q *querier) agentsToMappings(agents []database.TailnetAgent) ([]mapping, error) { + slog.Helper() mappings := make([]mapping, 0, len(agents)) for _, agent := range agents { node := new(agpl.Node) @@ -874,25 +957,28 @@ func (q *querier) listenClient(_ context.Context, msg []byte, err error) { if err != nil { q.logger.Warn(q.ctx, "unhandled pubsub error", slog.Error(err)) } - client, agent, err := parseClientUpdate(string(msg)) + client, agents, err := parseClientUpdate(string(msg)) if err != nil { q.logger.Error(q.ctx, "failed to parse client update", slog.F("msg", string(msg)), slog.Error(err)) return } - logger := q.logger.With(slog.F("client_id", client), slog.F("agent_id", agent)) + logger := q.logger.With(slog.F("client_id", client)) logger.Debug(q.ctx, "got client update") - mk := mKey{ - agent: agent, - clientsOfAgent: true, - } - q.mu.Lock() - _, ok := q.mappers[mk] - q.mu.Unlock() - if !ok { - logger.Debug(q.ctx, "ignoring update because we have no mapper") - return + for _, agentID := range agents { + logger := q.logger.With(slog.F("agent_id", agentID)) + mk := mKey{ + agent: agentID, + kind: agpl.QueueKindAgent, + } + q.mu.Lock() + _, ok := q.mappers[mk] + q.mu.Unlock() + if !ok { + logger.Debug(q.ctx, "ignoring update because we have no mapper") + return + } + q.workQ.enqueue(mk) } - q.workQ.enqueue(mk) } func (q *querier) listenAgent(_ context.Context, msg []byte, err error) { @@ -905,7 +991,7 @@ func (q *querier) listenAgent(_ context.Context, msg []byte, err error) { if err != nil { q.logger.Warn(q.ctx, "unhandled pubsub error", slog.Error(err)) } - agent, err := parseAgentUpdate(string(msg)) + agent, err := parseUpdateMessage(string(msg)) if err != nil { q.logger.Error(q.ctx, "failed to parse agent update", slog.F("msg", string(msg)), slog.Error(err)) return @@ -913,8 +999,8 @@ func (q *querier) listenAgent(_ context.Context, msg []byte, err error) { logger := q.logger.With(slog.F("agent_id", agent)) logger.Debug(q.ctx, "got agent update") mk := mKey{ - agent: agent, - clientsOfAgent: false, + agent: agent, + kind: agpl.QueueKindClient, } q.mu.Lock() _, ok := q.mappers[mk] @@ -930,7 +1016,7 @@ func (q *querier) resyncClientMappings() { q.mu.Lock() defer q.mu.Unlock() for mk := range q.mappers { - if mk.clientsOfAgent { + if mk.kind == agpl.QueueKindClient { q.workQ.enqueue(mk) } } @@ -940,7 +1026,7 @@ func (q *querier) resyncAgentMappings() { q.mu.Lock() defer q.mu.Unlock() for mk := range q.mappers { - if !mk.clientsOfAgent { + if mk.kind == agpl.QueueKindAgent { q.workQ.enqueue(mk) } } @@ -988,10 +1074,10 @@ func (q *querier) unhealthyCloseAll() { q.mu.Lock() defer q.mu.Unlock() q.healthy = false - for c := range q.conns { + for _, c := range q.conns { // close connections async so that we don't block the querier routine that responds to updates - go func(c *connIO) { - err := c.updates.Close() + go func(c agpl.Queue) { + err := c.Close() if err != nil { q.logger.Debug(q.ctx, "error closing conn while unhealthy", slog.Error(err)) } @@ -1021,32 +1107,41 @@ func (q *querier) getAll(ctx context.Context) (map[uuid.UUID]database.TailnetAge } clientsMap := map[uuid.UUID][]database.TailnetClient{} for _, client := range clients { - clientsMap[client.AgentID] = append(clientsMap[client.AgentID], client) + for _, agentID := range client.AgentIds { + clientsMap[agentID] = append(clientsMap[agentID], client) + } } return agentsMap, clientsMap, nil } -func parseClientUpdate(msg string) (client, agent uuid.UUID, err error) { +func parseClientUpdate(msg string) (client uuid.UUID, agents []uuid.UUID, err error) { parts := strings.Split(msg, ",") if len(parts) != 2 { - return uuid.Nil, uuid.Nil, xerrors.Errorf("expected 2 parts separated by comma") + return uuid.Nil, nil, xerrors.Errorf("expected 2 parts separated by comma") } client, err = uuid.Parse(parts[0]) if err != nil { - return uuid.Nil, uuid.Nil, xerrors.Errorf("failed to parse client UUID: %w", err) + return uuid.Nil, nil, xerrors.Errorf("failed to parse client UUID: %w", err) } - agent, err = uuid.Parse(parts[1]) - if err != nil { - return uuid.Nil, uuid.Nil, xerrors.Errorf("failed to parse agent UUID: %w", err) + + agents = []uuid.UUID{} + for _, agentStr := range parts[1:] { + agent, err := uuid.Parse(agentStr) + if err != nil { + return uuid.Nil, nil, xerrors.Errorf("failed to parse agent UUID: %w", err) + } + + agents = append(agents, agent) } - return client, agent, nil + + return client, agents, nil } -func parseAgentUpdate(msg string) (agent uuid.UUID, err error) { +func parseUpdateMessage(msg string) (agent uuid.UUID, err error) { agent, err = uuid.Parse(msg) if err != nil { - return uuid.Nil, xerrors.Errorf("failed to parse agent UUID: %w", err) + return uuid.Nil, xerrors.Errorf("failed to parse update message UUID: %w", err) } return agent, nil } @@ -1056,7 +1151,7 @@ type mKey struct { agent uuid.UUID // we always query based on the agent ID, but if we have client connection(s), we query the agent itself. If we // have an agent connection, we need the node mappings for all clients of the agent. - clientsOfAgent bool + kind agpl.QueueKind } // mapping associates a particular client or agent, and its respective coordinator with a node. It is generalized to diff --git a/enterprise/tailnet/pgcoord_test.go b/enterprise/tailnet/pgcoord_test.go index 9112cd95a0791..51fd1074dcd8b 100644 --- a/enterprise/tailnet/pgcoord_test.go +++ b/enterprise/tailnet/pgcoord_test.go @@ -438,7 +438,7 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) { assertEventuallyNoClientsForAgent(ctx, t, store, agent2.id) } -// TestPGCoordinator_MultiAgent tests when a single agent connects to multiple coordinators. +// TestPGCoordinator_MultiCoordinatorAgent tests when a single agent connects to multiple coordinators. // We use two agent connections, but they share the same AgentID. This could happen due to a reconnection, // or an infrastructure problem where an old workspace is not fully cleaned up before a new one started. // @@ -451,7 +451,7 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) { // +---------+ // | coord3 | <--- client // +---------+ -func TestPGCoordinator_MultiAgent(t *testing.T) { +func TestPGCoordinator_MultiCoordinatorAgent(t *testing.T) { t.Parallel() if !dbtestutil.WillUsePostgres() { t.Skip("test only with postgres") @@ -589,6 +589,51 @@ func TestPGCoordinator_Unhealthy(t *testing.T) { } } +func TestPGCoordinator_MultiAgent(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + store, ps := dbtestutil.NewDB(t) + coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store) + require.NoError(t, err) + defer coord1.Close() + + agent1 := newTestAgent(t, coord1, "agent1") + defer agent1.close() + agent1.sendNode(&agpl.Node{PreferredDERP: 5}) + + id := uuid.New() + ma1, err := coord1.ServeMultiAgent(id) + require.NoError(t, err) + defer ma1.Close() + + err = ma1.SubscribeAgent(agent1.id) + require.NoError(t, err) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5) + + agent1.sendNode(&agpl.Node{PreferredDERP: 1}) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) + + err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}) + require.NoError(t, err) + assertEventuallyHasDERPs(ctx, t, agent1, 3) + + err = ma1.Close() + require.NoError(t, err) + + err = agent1.close() + require.NoError(t, err) + + assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) + assertEventuallyNoAgents(ctx, t, store, agent1.id) +} + type testConn struct { ws, serverWS net.Conn nodeChan chan []*agpl.Node @@ -601,7 +646,7 @@ type testConn struct { func newTestConn(ids []uuid.UUID) *testConn { a := &testConn{} a.ws, a.serverWS = net.Pipe() - a.nodeChan = make(chan []*agpl.Node) + a.nodeChan = make(chan []*agpl.Node, 5) a.sendNode, a.errChan = agpl.ServeCoordinator(a.ws, func(nodes []*agpl.Node) error { a.nodeChan <- nodes return nil @@ -698,6 +743,30 @@ func assertEventuallyHasDERPs(ctx context.Context, t *testing.T, c *testConn, ex } } +func assertMultiAgentEventuallyHasDERPs(ctx context.Context, t *testing.T, ma agpl.MultiAgentConn, expected ...int) { + t.Helper() + for { + nodes, ok := ma.NextUpdate(ctx) + require.True(t, ok) + if len(nodes) != len(expected) { + t.Logf("expected %d, got %d nodes", len(expected), len(nodes)) + continue + } + + derps := make([]int, 0, len(nodes)) + for _, n := range nodes { + derps = append(derps, n.PreferredDERP) + } + for _, e := range expected { + if !slices.Contains(derps, e) { + t.Logf("expected DERP %d to be in %v", e, derps) + continue + } + } + return + } +} + func assertEventuallyNoAgents(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) { assert.Eventually(t, func() bool { agents, err := store.GetTailnetAgents(ctx, agentID) @@ -712,6 +781,7 @@ func assertEventuallyNoAgents(ctx context.Context, t *testing.T, store database. } func assertEventuallyNoClientsForAgent(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) { + t.Helper() assert.Eventually(t, func() bool { clients, err := store.GetTailnetClientsForAgent(ctx, agentID) if xerrors.Is(err, sql.ErrNoRows) { diff --git a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go index 74c381c2d8b4a..68a42d23646ca 100644 --- a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go +++ b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go @@ -23,6 +23,7 @@ import ( "github.com/coder/coder/v2/coderd/workspaceapps" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/tailnet" + agpl "github.com/coder/coder/v2/tailnet" ) // Client is a HTTP client for a subset of Coder API routes that external @@ -469,7 +470,7 @@ func (c *Client) DialCoordinator(ctx context.Context) (tailnet.MultiAgentConn, e OnSubscribe: rma.OnSubscribe, OnUnsubscribe: rma.OnUnsubscribe, OnNodeUpdate: rma.OnNodeUpdate, - OnRemove: func(uuid.UUID) { conn.Close(websocket.StatusGoingAway, "closed") }, + OnRemove: func(agpl.Queue) { conn.Close(websocket.StatusGoingAway, "closed") }, }).Init() go func() { diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index 866633ed54a1e..073195187a9cf 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -45,7 +45,7 @@ type Coordinator interface { // Close closes the coordinator. Close() error - ServeMultiAgent(id uuid.UUID) MultiAgentConn + ServeMultiAgent(id uuid.UUID) (MultiAgentConn, error) } // Node represents a node in the network. @@ -139,17 +139,17 @@ type coordinator struct { core *core } -func (c *coordinator) ServeMultiAgent(id uuid.UUID) MultiAgentConn { +func (c *coordinator) ServeMultiAgent(id uuid.UUID) (MultiAgentConn, error) { m := (&MultiAgent{ ID: id, AgentIsLegacyFunc: c.core.agentIsLegacy, OnSubscribe: c.core.clientSubscribeToAgent, OnUnsubscribe: c.core.clientUnsubscribeFromAgent, OnNodeUpdate: c.core.clientNodeUpdate, - OnRemove: c.core.clientDisconnected, + OnRemove: func(enq Queue) { c.core.clientDisconnected(enq.UniqueID()) }, }).Init() c.core.addClient(id, m) - return m + return m, nil } func (c *core) addClient(id uuid.UUID, ma Queue) { @@ -191,8 +191,17 @@ type core struct { legacyAgents map[uuid.UUID]struct{} } +type QueueKind int + +const ( + _ QueueKind = iota + QueueKindClient + QueueKindAgent +) + type Queue interface { UniqueID() uuid.UUID + Kind() QueueKind Enqueue(n []*Node) error Name() string Stats() (start, lastWrite int64) @@ -200,6 +209,7 @@ type Queue interface { // CoordinatorClose is used by the coordinator when closing a Queue. It // should skip removing itself from the coordinator. CoordinatorClose() error + Done() <-chan struct{} Close() error } @@ -264,7 +274,7 @@ func (c *coordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error { logger := c.core.clientLogger(id, agentID) logger.Debug(ctx, "coordinating client") - tc := NewTrackedConn(ctx, cancel, conn, id, logger, id.String(), 0) + tc := NewTrackedConn(ctx, cancel, conn, id, logger, id.String(), 0, QueueKindClient) defer tc.Close() c.core.addClient(id, tc) @@ -509,7 +519,7 @@ func (c *core) initAndTrackAgent(ctx context.Context, cancel func(), conn net.Co overwrites = oldAgentSocket.Overwrites() + 1 _ = oldAgentSocket.Close() } - tc := NewTrackedConn(ctx, cancel, conn, unique, logger, name, overwrites) + tc := NewTrackedConn(ctx, cancel, conn, unique, logger, name, overwrites, QueueKindAgent) c.agentNameCache.Add(id, name) sockets, ok := c.agentToConnectionSockets[id] diff --git a/tailnet/multiagent.go b/tailnet/multiagent.go index ee76e4b88d8aa..5c3412a595152 100644 --- a/tailnet/multiagent.go +++ b/tailnet/multiagent.go @@ -29,9 +29,12 @@ type MultiAgent struct { OnSubscribe func(enq Queue, agent uuid.UUID) (*Node, error) OnUnsubscribe func(enq Queue, agent uuid.UUID) error OnNodeUpdate func(id uuid.UUID, node *Node) error - OnRemove func(id uuid.UUID) + OnRemove func(enq Queue) + ctx context.Context + ctxCancel func() closed bool + updates chan []*Node closeOnce sync.Once start int64 @@ -44,9 +47,14 @@ type MultiAgent struct { func (m *MultiAgent) Init() *MultiAgent { m.updates = make(chan []*Node, 128) m.start = time.Now().Unix() + m.ctx, m.ctxCancel = context.WithCancel(context.Background()) return m } +func (*MultiAgent) Kind() QueueKind { + return QueueKindClient +} + func (m *MultiAgent) UniqueID() uuid.UUID { return m.ID } @@ -156,8 +164,13 @@ func (m *MultiAgent) CoordinatorClose() error { return nil } +func (m *MultiAgent) Done() <-chan struct{} { + return m.ctx.Done() +} + func (m *MultiAgent) Close() error { _ = m.CoordinatorClose() - m.closeOnce.Do(func() { m.OnRemove(m.ID) }) + m.ctxCancel() + m.closeOnce.Do(func() { m.OnRemove(m) }) return nil } diff --git a/tailnet/trackedconn.go b/tailnet/trackedconn.go index 0ec19695ba29f..be464b2327921 100644 --- a/tailnet/trackedconn.go +++ b/tailnet/trackedconn.go @@ -20,6 +20,7 @@ const WriteTimeout = time.Second * 5 type TrackedConn struct { ctx context.Context cancel func() + kind QueueKind conn net.Conn updates chan []*Node logger slog.Logger @@ -35,7 +36,14 @@ type TrackedConn struct { overwrites int64 } -func NewTrackedConn(ctx context.Context, cancel func(), conn net.Conn, id uuid.UUID, logger slog.Logger, name string, overwrites int64) *TrackedConn { +func NewTrackedConn(ctx context.Context, cancel func(), + conn net.Conn, + id uuid.UUID, + logger slog.Logger, + name string, + overwrites int64, + kind QueueKind, +) *TrackedConn { // buffer updates so they don't block, since we hold the // coordinator mutex while queuing. Node updates don't // come quickly, so 512 should be plenty for all but @@ -53,6 +61,7 @@ func NewTrackedConn(ctx context.Context, cancel func(), conn net.Conn, id uuid.U lastWrite: now, name: name, overwrites: overwrites, + kind: kind, } } @@ -70,6 +79,10 @@ func (t *TrackedConn) UniqueID() uuid.UUID { return t.id } +func (t *TrackedConn) Kind() QueueKind { + return t.kind +} + func (t *TrackedConn) Name() string { return t.name } @@ -86,6 +99,10 @@ func (t *TrackedConn) CoordinatorClose() error { return t.Close() } +func (t *TrackedConn) Done() <-chan struct{} { + return t.ctx.Done() +} + // Close the connection and cancel the context for reading node updates from the queue func (t *TrackedConn) Close() error { t.cancel() From efa20718f497257b9eb73588ddb6863cae329fe4 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Fri, 25 Aug 2023 23:09:59 +0000 Subject: [PATCH 02/21] fixup! feat: add single tailnet support to pgcoord --- ...0152_pg_coordinator_single_tailnet.down.sql} | 0 ...000152_pg_coordinator_single_tailnet.up.sql} | 0 enterprise/coderd/workspaceproxycoordinate.go | 8 ++++++++ enterprise/wsproxy/wsproxysdk/wsproxysdk.go | 17 ++++++++--------- 4 files changed, 16 insertions(+), 9 deletions(-) rename coderd/database/migrations/{000151_pg_coordinator_single_tailnet.down.sql => 000152_pg_coordinator_single_tailnet.down.sql} (100%) rename coderd/database/migrations/{000151_pg_coordinator_single_tailnet.up.sql => 000152_pg_coordinator_single_tailnet.up.sql} (100%) diff --git a/coderd/database/migrations/000151_pg_coordinator_single_tailnet.down.sql b/coderd/database/migrations/000152_pg_coordinator_single_tailnet.down.sql similarity index 100% rename from coderd/database/migrations/000151_pg_coordinator_single_tailnet.down.sql rename to coderd/database/migrations/000152_pg_coordinator_single_tailnet.down.sql diff --git a/coderd/database/migrations/000151_pg_coordinator_single_tailnet.up.sql b/coderd/database/migrations/000152_pg_coordinator_single_tailnet.up.sql similarity index 100% rename from coderd/database/migrations/000151_pg_coordinator_single_tailnet.up.sql rename to coderd/database/migrations/000152_pg_coordinator_single_tailnet.up.sql diff --git a/enterprise/coderd/workspaceproxycoordinate.go b/enterprise/coderd/workspaceproxycoordinate.go index 6fc0bcd6b18f7..bb4b3fa7b69eb 100644 --- a/enterprise/coderd/workspaceproxycoordinate.go +++ b/enterprise/coderd/workspaceproxycoordinate.go @@ -68,6 +68,14 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request id := uuid.New() sub, err := (*api.AGPL.TailnetCoordinator.Load()).ServeMultiAgent(id) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to serve multi agent.", + Detail: err.Error(), + }) + return + } + ctx, nc := websocketNetConn(ctx, conn, websocket.MessageText) defer nc.Close() diff --git a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go index 68a42d23646ca..c00ab834b7c25 100644 --- a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go +++ b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go @@ -22,7 +22,6 @@ import ( "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/coderd/workspaceapps" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/tailnet" agpl "github.com/coder/coder/v2/tailnet" ) @@ -423,14 +422,14 @@ const ( type CoordinateMessage struct { Type CoordinateMessageType `json:"type"` AgentID uuid.UUID `json:"agent_id"` - Node *tailnet.Node `json:"node"` + Node *agpl.Node `json:"node"` } type CoordinateNodes struct { - Nodes []*tailnet.Node + Nodes []*agpl.Node } -func (c *Client) DialCoordinator(ctx context.Context) (tailnet.MultiAgentConn, error) { +func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, error) { ctx, cancel := context.WithCancel(ctx) coordinateURL, err := c.SDKClient.URL.Parse("/api/v2/workspaceproxies/me/coordinate") @@ -464,7 +463,7 @@ func (c *Client) DialCoordinator(ctx context.Context) (tailnet.MultiAgentConn, e legacyAgentCache: map[uuid.UUID]bool{}, } - ma := (&tailnet.MultiAgent{ + ma := (&agpl.MultiAgent{ ID: uuid.New(), AgentIsLegacyFunc: rma.AgentIsLegacy, OnSubscribe: rma.OnSubscribe, @@ -516,7 +515,7 @@ func (a *remoteMultiAgentHandler) writeJSON(v interface{}) error { // Set a deadline so that hung connections don't put back pressure on the system. // Node updates are tiny, so even the dinkiest connection can handle them if it's not hung. - err = a.nc.SetWriteDeadline(time.Now().Add(tailnet.WriteTimeout)) + err = a.nc.SetWriteDeadline(time.Now().Add(agpl.WriteTimeout)) if err != nil { return xerrors.Errorf("set write deadline: %w", err) } @@ -538,21 +537,21 @@ func (a *remoteMultiAgentHandler) writeJSON(v interface{}) error { return nil } -func (a *remoteMultiAgentHandler) OnNodeUpdate(_ uuid.UUID, node *tailnet.Node) error { +func (a *remoteMultiAgentHandler) OnNodeUpdate(_ uuid.UUID, node *agpl.Node) error { return a.writeJSON(CoordinateMessage{ Type: CoordinateMessageTypeNodeUpdate, Node: node, }) } -func (a *remoteMultiAgentHandler) OnSubscribe(_ tailnet.Queue, agentID uuid.UUID) (*tailnet.Node, error) { +func (a *remoteMultiAgentHandler) OnSubscribe(_ agpl.Queue, agentID uuid.UUID) (*agpl.Node, error) { return nil, a.writeJSON(CoordinateMessage{ Type: CoordinateMessageTypeSubscribe, AgentID: agentID, }) } -func (a *remoteMultiAgentHandler) OnUnsubscribe(_ tailnet.Queue, agentID uuid.UUID) error { +func (a *remoteMultiAgentHandler) OnUnsubscribe(_ agpl.Queue, agentID uuid.UUID) error { return a.writeJSON(CoordinateMessage{ Type: CoordinateMessageTypeUnsubscribe, AgentID: agentID, From 2976c00271beac06cf0567df441bb09a4601b79e Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Fri, 25 Aug 2023 23:19:13 +0000 Subject: [PATCH 03/21] fixup! feat: add single tailnet support to pgcoord --- .../000152_pg_coordinator_single_tailnet.down.sql | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/coderd/database/migrations/000152_pg_coordinator_single_tailnet.down.sql b/coderd/database/migrations/000152_pg_coordinator_single_tailnet.down.sql index 9ce2798205ef8..ddc05facaf677 100644 --- a/coderd/database/migrations/000152_pg_coordinator_single_tailnet.down.sql +++ b/coderd/database/migrations/000152_pg_coordinator_single_tailnet.down.sql @@ -1,9 +1,9 @@ BEGIN; --- ALTER TABLE --- tailnet_clients --- ADD COLUMN --- agent_id uuid; +ALTER TABLE + tailnet_clients +ADD COLUMN + agent_id uuid; UPDATE tailnet_clients @@ -11,6 +11,11 @@ SET -- grab just the first agent_id, or default to an empty UUID. agent_id = COALESCE(agent_ids[0], '00000000-0000-0000-0000-000000000000'::uuid); +ALTER TABLE + tailnet_clients +ALTER COLUMN + agent_id SET NOT NULL; + ALTER TABLE tailnet_clients DROP COLUMN From a7b39dfe95758951a34f526d29ef05842076b56a Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Thu, 31 Aug 2023 22:15:51 +0000 Subject: [PATCH 04/21] separate table + use channels --- coderd/database/dbauthz/dbauthz.go | 12 +- coderd/database/dbfake/dbfake.go | 15 +- coderd/database/dbmetrics/dbmetrics.go | 16 +- coderd/database/dbmock/dbmock.go | 33 ++- coderd/database/dump.sql | 78 ++++++- ...00152_pg_coordinator_single_tailnet.up.sql | 42 ---- ...54_pg_coordinator_single_tailnet.down.sql} | 0 ...00154_pg_coordinator_single_tailnet.up.sql | 97 +++++++++ coderd/database/models.go | 8 +- coderd/database/querier.go | 6 +- coderd/database/queries.sql.go | 108 +++++++--- coderd/database/queries/tailnet.sql | 41 +++- enterprise/tailnet/connio.go | 38 ++-- enterprise/tailnet/pgcoord.go | 196 ++++++++++++------ 14 files changed, 513 insertions(+), 177 deletions(-) delete mode 100644 coderd/database/migrations/000152_pg_coordinator_single_tailnet.up.sql rename coderd/database/migrations/{000152_pg_coordinator_single_tailnet.down.sql => 000154_pg_coordinator_single_tailnet.down.sql} (100%) create mode 100644 coderd/database/migrations/000154_pg_coordinator_single_tailnet.up.sql diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 8ddd779d795e9..0fa1c616e209e 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -767,6 +767,10 @@ func (q *querier) DeleteTailnetClient(ctx context.Context, arg database.DeleteTa return q.db.DeleteTailnetClient(ctx, arg) } +func (q *querier) DeleteTailnetClientSubscription(ctx context.Context, arg database.DeleteTailnetClientSubscriptionParams) (database.DeleteTailnetClientSubscriptionRow, error) { + panic("not implemented") +} + func (q *querier) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) { return fetch(q.log, q.auth, q.db.GetAPIKeyByID)(ctx, id) } @@ -809,9 +813,9 @@ func (q *querier) GetAllTailnetAgents(ctx context.Context) ([]database.TailnetAg return q.db.GetAllTailnetAgents(ctx) } -func (q *querier) GetAllTailnetClients(ctx context.Context) ([]database.TailnetClient, error) { +func (q *querier) GetAllTailnetClients(ctx context.Context) ([]database.GetAllTailnetClientsRow, error) { if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTailnetCoordinator); err != nil { - return []database.TailnetClient{}, err + return []database.GetAllTailnetClientsRow{}, err } return q.db.GetAllTailnetClients(ctx) } @@ -2778,6 +2782,10 @@ func (q *querier) UpsertTailnetClient(ctx context.Context, arg database.UpsertTa return q.db.UpsertTailnetClient(ctx, arg) } +func (q *querier) UpsertTailnetClientSubscription(ctx context.Context, arg database.UpsertTailnetClientSubscriptionParams) error { + panic("not implemented") +} + func (q *querier) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (database.TailnetCoordinator, error) { if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil { return database.TailnetCoordinator{}, err diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index e73578a61a7df..6d102098c0b18 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -909,6 +909,15 @@ func (*FakeQuerier) DeleteTailnetClient(context.Context, database.DeleteTailnetC return database.DeleteTailnetClientRow{}, ErrUnimplemented } +func (q *FakeQuerier) DeleteTailnetClientSubscription(ctx context.Context, arg database.DeleteTailnetClientSubscriptionParams) (database.DeleteTailnetClientSubscriptionRow, error) { + err := validateDatabaseType(arg) + if err != nil { + return database.DeleteTailnetClientSubscriptionRow{}, err + } + + panic("not implemented") +} + func (q *FakeQuerier) GetAPIKeyByID(_ context.Context, id string) (database.APIKey, error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -1024,7 +1033,7 @@ func (*FakeQuerier) GetAllTailnetAgents(_ context.Context) ([]database.TailnetAg return nil, ErrUnimplemented } -func (*FakeQuerier) GetAllTailnetClients(_ context.Context) ([]database.TailnetClient, error) { +func (*FakeQuerier) GetAllTailnetClients(_ context.Context) ([]database.GetAllTailnetClientsRow, error) { return nil, ErrUnimplemented } @@ -6032,6 +6041,10 @@ func (*FakeQuerier) UpsertTailnetClient(context.Context, database.UpsertTailnetC return database.TailnetClient{}, ErrUnimplemented } +func (q *FakeQuerier) UpsertTailnetClientSubscription(ctx context.Context, arg database.UpsertTailnetClientSubscriptionParams) error { + return ErrUnimplemented +} + func (*FakeQuerier) UpsertTailnetCoordinator(context.Context, uuid.UUID) (database.TailnetCoordinator, error) { return database.TailnetCoordinator{}, ErrUnimplemented } diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index 0a02896200f60..24187b2b78b22 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -195,6 +195,13 @@ func (m metricsStore) DeleteTailnetClient(ctx context.Context, arg database.Dele return m.s.DeleteTailnetClient(ctx, arg) } +func (m metricsStore) DeleteTailnetClientSubscription(ctx context.Context, arg database.DeleteTailnetClientSubscriptionParams) (database.DeleteTailnetClientSubscriptionRow, error) { + start := time.Now() + r0, r1 := m.s.DeleteTailnetClientSubscription(ctx, arg) + m.queryLatencies.WithLabelValues("DeleteTailnetClientSubscription").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) { start := time.Now() apiKey, err := m.s.GetAPIKeyByID(ctx, id) @@ -251,7 +258,7 @@ func (m metricsStore) GetAllTailnetAgents(ctx context.Context) ([]database.Tailn return r0, r1 } -func (m metricsStore) GetAllTailnetClients(ctx context.Context) ([]database.TailnetClient, error) { +func (m metricsStore) GetAllTailnetClients(ctx context.Context) ([]database.GetAllTailnetClientsRow, error) { start := time.Now() r0, r1 := m.s.GetAllTailnetClients(ctx) m.queryLatencies.WithLabelValues("GetAllTailnetClients").Observe(time.Since(start).Seconds()) @@ -1738,6 +1745,13 @@ func (m metricsStore) UpsertTailnetClient(ctx context.Context, arg database.Upse return m.s.UpsertTailnetClient(ctx, arg) } +func (m metricsStore) UpsertTailnetClientSubscription(ctx context.Context, arg database.UpsertTailnetClientSubscriptionParams) error { + start := time.Now() + r0 := m.s.UpsertTailnetClientSubscription(ctx, arg) + m.queryLatencies.WithLabelValues("UpsertTailnetClientSubscription").Observe(time.Since(start).Seconds()) + return r0 +} + func (m metricsStore) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (database.TailnetCoordinator, error) { start := time.Now() defer m.queryLatencies.WithLabelValues("UpsertTailnetCoordinator").Observe(time.Since(start).Seconds()) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index be1f994d81161..d6e88675dbed2 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -281,6 +281,21 @@ func (mr *MockStoreMockRecorder) DeleteTailnetClient(arg0, arg1 interface{}) *go return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTailnetClient", reflect.TypeOf((*MockStore)(nil).DeleteTailnetClient), arg0, arg1) } +// DeleteTailnetClientSubscription mocks base method. +func (m *MockStore) DeleteTailnetClientSubscription(arg0 context.Context, arg1 database.DeleteTailnetClientSubscriptionParams) (database.DeleteTailnetClientSubscriptionRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteTailnetClientSubscription", arg0, arg1) + ret0, _ := ret[0].(database.DeleteTailnetClientSubscriptionRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DeleteTailnetClientSubscription indicates an expected call of DeleteTailnetClientSubscription. +func (mr *MockStoreMockRecorder) DeleteTailnetClientSubscription(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTailnetClientSubscription", reflect.TypeOf((*MockStore)(nil).DeleteTailnetClientSubscription), arg0, arg1) +} + // GetAPIKeyByID mocks base method. func (m *MockStore) GetAPIKeyByID(arg0 context.Context, arg1 string) (database.APIKey, error) { m.ctrl.T.Helper() @@ -402,10 +417,10 @@ func (mr *MockStoreMockRecorder) GetAllTailnetAgents(arg0 interface{}) *gomock.C } // GetAllTailnetClients mocks base method. -func (m *MockStore) GetAllTailnetClients(arg0 context.Context) ([]database.TailnetClient, error) { +func (m *MockStore) GetAllTailnetClients(arg0 context.Context) ([]database.GetAllTailnetClientsRow, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetAllTailnetClients", arg0) - ret0, _ := ret[0].([]database.TailnetClient) + ret0, _ := ret[0].([]database.GetAllTailnetClientsRow) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -3652,6 +3667,20 @@ func (mr *MockStoreMockRecorder) UpsertTailnetClient(arg0, arg1 interface{}) *go return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertTailnetClient", reflect.TypeOf((*MockStore)(nil).UpsertTailnetClient), arg0, arg1) } +// UpsertTailnetClientSubscription mocks base method. +func (m *MockStore) UpsertTailnetClientSubscription(arg0 context.Context, arg1 database.UpsertTailnetClientSubscriptionParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpsertTailnetClientSubscription", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpsertTailnetClientSubscription indicates an expected call of UpsertTailnetClientSubscription. +func (mr *MockStoreMockRecorder) UpsertTailnetClientSubscription(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertTailnetClientSubscription", reflect.TypeOf((*MockStore)(nil).UpsertTailnetClientSubscription), arg0, arg1) +} + // UpsertTailnetCoordinator mocks base method. func (m *MockStore) UpsertTailnetCoordinator(arg0 context.Context, arg1 uuid.UUID) (database.TailnetCoordinator, error) { m.ctrl.T.Helper() diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 31a6d98af211c..62372b63be3fc 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -219,14 +219,66 @@ $$; CREATE FUNCTION tailnet_notify_client_change() RETURNS trigger LANGUAGE plpgsql AS $$ +DECLARE + var_client_id uuid; + var_agent_ids uuid[]; +BEGIN + IF (NEW.id IS NOT NULL) THEN + IF (NEW.node IS NULL) THEN + return NULL; + END IF; + + var_client_id = NEW.id; + SELECT + array_agg(agent_id) + INTO + var_agent_ids + FROM + tailnet_client_subscriptions subs + WHERE + subs.client_id = NEW.id AND + subs.coordinator_id = NEW.coordinator_id; + ELSIF (OLD.id IS NOT NULL) THEN + -- if new is null and old is not null, that means the row was deleted. + var_client_id = OLD.id; + WITH agent_ids AS ( + DELETE FROM + tailnet_client_subscriptions subs + WHERE + subs.client_id = OLD.id AND + subs.coordinator_id = OLD.coordinator_id + RETURNING + subs.agent_id + ) + SELECT + array_agg(agent_id) + INTO + var_agent_ids + FROM + agent_ids; + END IF; + + -- Read all agents the client is subscribed to, so we can notify them. + -- No agents to notify + if (var_agent_ids IS NULL) THEN + return NULL; + END IF; + + PERFORM pg_notify('tailnet_client_update', var_client_id || ',' || array_to_string(var_agent_ids, ',')); + return NULL; +END; +$$; + +CREATE FUNCTION tailnet_notify_client_subscription_change() RETURNS trigger + LANGUAGE plpgsql + AS $$ BEGIN - -- check new first to get the updated agent ids. IF (NEW IS NOT NULL) THEN - PERFORM pg_notify('tailnet_client_update', NEW.id || ',' || array_to_string(NEW.agent_ids, ',')); + PERFORM pg_notify('tailnet_client_update', NEW.client_id || ',' || NEW.agent_id); RETURN NULL; END IF; IF (OLD IS NOT NULL) THEN - PERFORM pg_notify('tailnet_client_update', OLD.id || ',' || array_to_string(OLD.agent_ids, ',')); + PERFORM pg_notify('tailnet_client_update', OLD.client_id || ',' || OLD.agent_id); RETURN NULL; END IF; END; @@ -496,12 +548,18 @@ CREATE TABLE tailnet_agents ( node jsonb NOT NULL ); +CREATE TABLE tailnet_client_subscriptions ( + client_id uuid NOT NULL, + coordinator_id uuid NOT NULL, + agent_id uuid NOT NULL, + updated_at timestamp with time zone NOT NULL +); + CREATE TABLE tailnet_clients ( id uuid NOT NULL, coordinator_id uuid NOT NULL, updated_at timestamp with time zone NOT NULL, - node jsonb NOT NULL, - agent_ids uuid[] NOT NULL + node jsonb NOT NULL ); CREATE TABLE tailnet_coordinators ( @@ -1145,6 +1203,9 @@ ALTER TABLE ONLY site_configs ALTER TABLE ONLY tailnet_agents ADD CONSTRAINT tailnet_agents_pkey PRIMARY KEY (id, coordinator_id); +ALTER TABLE ONLY tailnet_client_subscriptions + ADD CONSTRAINT tailnet_client_subscriptions_pkey PRIMARY KEY (client_id, coordinator_id, agent_id); + ALTER TABLE ONLY tailnet_clients ADD CONSTRAINT tailnet_clients_pkey PRIMARY KEY (id, coordinator_id); @@ -1249,8 +1310,6 @@ CREATE UNIQUE INDEX idx_organization_name_lower ON organizations USING btree (lo CREATE INDEX idx_tailnet_agents_coordinator ON tailnet_agents USING btree (coordinator_id); -CREATE INDEX idx_tailnet_clients_agent_ids ON tailnet_clients USING gin (agent_ids); - CREATE INDEX idx_tailnet_clients_coordinator ON tailnet_clients USING btree (coordinator_id); CREATE UNIQUE INDEX idx_users_email ON users USING btree (email) WHERE (deleted = false); @@ -1285,6 +1344,8 @@ CREATE TRIGGER tailnet_notify_agent_change AFTER INSERT OR DELETE OR UPDATE ON t CREATE TRIGGER tailnet_notify_client_change AFTER INSERT OR DELETE OR UPDATE ON tailnet_clients FOR EACH ROW EXECUTE FUNCTION tailnet_notify_client_change(); +CREATE TRIGGER tailnet_notify_client_subscription_change AFTER INSERT OR DELETE OR UPDATE ON tailnet_client_subscriptions FOR EACH ROW EXECUTE FUNCTION tailnet_notify_client_subscription_change(); + CREATE TRIGGER tailnet_notify_coordinator_heartbeat AFTER INSERT OR UPDATE ON tailnet_coordinators FOR EACH ROW EXECUTE FUNCTION tailnet_notify_coordinator_heartbeat(); CREATE TRIGGER trigger_insert_apikeys BEFORE INSERT ON api_keys FOR EACH ROW EXECUTE FUNCTION insert_apikey_fail_if_user_deleted(); @@ -1330,6 +1391,9 @@ ALTER TABLE ONLY provisioner_jobs ALTER TABLE ONLY tailnet_agents ADD CONSTRAINT tailnet_agents_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE; +ALTER TABLE ONLY tailnet_client_subscriptions + ADD CONSTRAINT tailnet_client_subscriptions_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE; + ALTER TABLE ONLY tailnet_clients ADD CONSTRAINT tailnet_clients_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE; diff --git a/coderd/database/migrations/000152_pg_coordinator_single_tailnet.up.sql b/coderd/database/migrations/000152_pg_coordinator_single_tailnet.up.sql deleted file mode 100644 index f59e893942b27..0000000000000 --- a/coderd/database/migrations/000152_pg_coordinator_single_tailnet.up.sql +++ /dev/null @@ -1,42 +0,0 @@ -BEGIN; - -ALTER TABLE - tailnet_clients -ADD COLUMN - agent_ids uuid[]; - -UPDATE - tailnet_clients -SET - agent_ids = ARRAY[agent_id]::uuid[]; - -ALTER TABLE - tailnet_clients -ALTER COLUMN - agent_ids SET NOT NULL; - - -CREATE INDEX idx_tailnet_clients_agent_ids ON tailnet_clients USING GIN (agent_ids); - -CREATE OR REPLACE FUNCTION tailnet_notify_client_change() RETURNS trigger - LANGUAGE plpgsql - AS $$ -BEGIN - -- check new first to get the updated agent ids. - IF (NEW IS NOT NULL) THEN - PERFORM pg_notify('tailnet_client_update', NEW.id || ',' || array_to_string(NEW.agent_ids, ',')); - RETURN NULL; - END IF; - IF (OLD IS NOT NULL) THEN - PERFORM pg_notify('tailnet_client_update', OLD.id || ',' || array_to_string(OLD.agent_ids, ',')); - RETURN NULL; - END IF; -END; -$$; - -ALTER TABLE - tailnet_clients -DROP COLUMN - agent_id; - -COMMIT; diff --git a/coderd/database/migrations/000152_pg_coordinator_single_tailnet.down.sql b/coderd/database/migrations/000154_pg_coordinator_single_tailnet.down.sql similarity index 100% rename from coderd/database/migrations/000152_pg_coordinator_single_tailnet.down.sql rename to coderd/database/migrations/000154_pg_coordinator_single_tailnet.down.sql diff --git a/coderd/database/migrations/000154_pg_coordinator_single_tailnet.up.sql b/coderd/database/migrations/000154_pg_coordinator_single_tailnet.up.sql new file mode 100644 index 0000000000000..800594a5b00c2 --- /dev/null +++ b/coderd/database/migrations/000154_pg_coordinator_single_tailnet.up.sql @@ -0,0 +1,97 @@ +BEGIN; + +CREATE TABLE tailnet_client_subscriptions ( + client_id uuid NOT NULL, + coordinator_id uuid NOT NULL, + -- this isn't a foreign key since it's more of a list of agents the client + -- *wants* to connect to, and they don't necessarily have to currently + -- exist in the db. + agent_id uuid NOT NULL, + updated_at timestamp with time zone NOT NULL, + PRIMARY KEY (client_id, coordinator_id, agent_id), + FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators (id) ON DELETE CASCADE + -- we don't keep a foreign key to the tailnet_clients table since there's + -- not a great way to guarantee that a subscription is always added after + -- the client is inserted. clients are only created after the client sends + -- its first node update, which can take an undetermined amount of time. +); + +CREATE FUNCTION tailnet_notify_client_subscription_change() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + IF (NEW IS NOT NULL) THEN + PERFORM pg_notify('tailnet_client_update', NEW.client_id || ',' || NEW.agent_id); + RETURN NULL; + END IF; + IF (OLD IS NOT NULL) THEN + PERFORM pg_notify('tailnet_client_update', OLD.client_id || ',' || OLD.agent_id); + RETURN NULL; + END IF; +END; +$$; + +CREATE TRIGGER tailnet_notify_client_subscription_change + AFTER INSERT OR UPDATE OR DELETE ON tailnet_client_subscriptions + FOR EACH ROW +EXECUTE PROCEDURE tailnet_notify_client_subscription_change(); + +CREATE OR REPLACE FUNCTION tailnet_notify_client_change() RETURNS trigger + LANGUAGE plpgsql + AS $$ +DECLARE + var_client_id uuid; + var_agent_ids uuid[]; +BEGIN + IF (NEW.id IS NOT NULL) THEN + IF (NEW.node IS NULL) THEN + return NULL; + END IF; + + var_client_id = NEW.id; + SELECT + array_agg(agent_id) + INTO + var_agent_ids + FROM + tailnet_client_subscriptions subs + WHERE + subs.client_id = NEW.id AND + subs.coordinator_id = NEW.coordinator_id; + ELSIF (OLD.id IS NOT NULL) THEN + -- if new is null and old is not null, that means the row was deleted. + var_client_id = OLD.id; + WITH agent_ids AS ( + DELETE FROM + tailnet_client_subscriptions subs + WHERE + subs.client_id = OLD.id AND + subs.coordinator_id = OLD.coordinator_id + RETURNING + subs.agent_id + ) + SELECT + array_agg(agent_id) + INTO + var_agent_ids + FROM + agent_ids; + END IF; + + -- Read all agents the client is subscribed to, so we can notify them. + -- No agents to notify + if (var_agent_ids IS NULL) THEN + return NULL; + END IF; + + PERFORM pg_notify('tailnet_client_update', var_client_id || ',' || array_to_string(var_agent_ids, ',')); + return NULL; +END; +$$; + +ALTER TABLE + tailnet_clients +DROP COLUMN + agent_id; + +COMMIT; diff --git a/coderd/database/models.go b/coderd/database/models.go index 1a95dd6e5cd70..aaa141bca68ff 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -1785,7 +1785,13 @@ type TailnetClient struct { CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` UpdatedAt time.Time `db:"updated_at" json:"updated_at"` Node json.RawMessage `db:"node" json:"node"` - AgentIds []uuid.UUID `db:"agent_ids" json:"agent_ids"` +} + +type TailnetClientSubscription struct { + ClientID uuid.UUID `db:"client_id" json:"client_id"` + CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` + AgentID uuid.UUID `db:"agent_id" json:"agent_id"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` } // We keep this separate from replicas in case we need to break the coordinator out into its own service diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 15caa096e3fb0..53c064b9d2910 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -41,6 +41,7 @@ type sqlcQuerier interface { DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt time.Time) error DeleteTailnetAgent(ctx context.Context, arg DeleteTailnetAgentParams) (DeleteTailnetAgentRow, error) DeleteTailnetClient(ctx context.Context, arg DeleteTailnetClientParams) (DeleteTailnetClientRow, error) + DeleteTailnetClientSubscription(ctx context.Context, arg DeleteTailnetClientSubscriptionParams) (DeleteTailnetClientSubscriptionRow, error) GetAPIKeyByID(ctx context.Context, id string) (APIKey, error) // there is no unique constraint on empty token names GetAPIKeyByName(ctx context.Context, arg GetAPIKeyByNameParams) (APIKey, error) @@ -50,7 +51,7 @@ type sqlcQuerier interface { GetActiveUserCount(ctx context.Context) (int64, error) GetActiveWorkspaceBuildsByTemplateID(ctx context.Context, templateID uuid.UUID) ([]WorkspaceBuild, error) GetAllTailnetAgents(ctx context.Context) ([]TailnetAgent, error) - GetAllTailnetClients(ctx context.Context) ([]TailnetClient, error) + GetAllTailnetClients(ctx context.Context) ([]GetAllTailnetClientsRow, error) GetAppSecurityKey(ctx context.Context) (string, error) // GetAuditLogsBefore retrieves `row_limit` number of audit logs before the provided // ID. @@ -108,7 +109,7 @@ type sqlcQuerier interface { GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]Replica, error) GetServiceBanner(ctx context.Context) (string, error) GetTailnetAgents(ctx context.Context, id uuid.UUID) ([]TailnetAgent, error) - GetTailnetClientsForAgent(ctx context.Context, dollar_1 uuid.UUID) ([]TailnetClient, error) + GetTailnetClientsForAgent(ctx context.Context, agentID uuid.UUID) ([]TailnetClient, error) // GetTemplateAppInsights returns the aggregate usage of each app in a given // timeframe. The result can be filtered on template_ids, meaning only user data // from workspaces based on those templates will be included. @@ -315,6 +316,7 @@ type sqlcQuerier interface { UpsertServiceBanner(ctx context.Context, value string) error UpsertTailnetAgent(ctx context.Context, arg UpsertTailnetAgentParams) (TailnetAgent, error) UpsertTailnetClient(ctx context.Context, arg UpsertTailnetClientParams) (TailnetClient, error) + UpsertTailnetClientSubscription(ctx context.Context, arg UpsertTailnetClientSubscriptionParams) error UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (TailnetCoordinator, error) } diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 2a6fd22e60fe9..3258f4b444a68 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -4138,6 +4138,32 @@ func (q *sqlQuerier) DeleteTailnetClient(ctx context.Context, arg DeleteTailnetC return i, err } +const deleteTailnetClientSubscription = `-- name: DeleteTailnetClientSubscription :one +DELETE +FROM tailnet_client_subscriptions +WHERE client_id = $1 and agent_id = $2 and coordinator_id = $3 +RETURNING client_id, agent_id, coordinator_id +` + +type DeleteTailnetClientSubscriptionParams struct { + ClientID uuid.UUID `db:"client_id" json:"client_id"` + AgentID uuid.UUID `db:"agent_id" json:"agent_id"` + CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` +} + +type DeleteTailnetClientSubscriptionRow struct { + ClientID uuid.UUID `db:"client_id" json:"client_id"` + AgentID uuid.UUID `db:"agent_id" json:"agent_id"` + CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` +} + +func (q *sqlQuerier) DeleteTailnetClientSubscription(ctx context.Context, arg DeleteTailnetClientSubscriptionParams) (DeleteTailnetClientSubscriptionRow, error) { + row := q.db.QueryRowContext(ctx, deleteTailnetClientSubscription, arg.ClientID, arg.AgentID, arg.CoordinatorID) + var i DeleteTailnetClientSubscriptionRow + err := row.Scan(&i.ClientID, &i.AgentID, &i.CoordinatorID) + return i, err +} + const getAllTailnetAgents = `-- name: GetAllTailnetAgents :many SELECT id, coordinator_id, updated_at, node FROM tailnet_agents @@ -4172,24 +4198,31 @@ func (q *sqlQuerier) GetAllTailnetAgents(ctx context.Context) ([]TailnetAgent, e } const getAllTailnetClients = `-- name: GetAllTailnetClients :many -SELECT id, coordinator_id, updated_at, node, agent_ids +SELECT tailnet_clients.id, tailnet_clients.coordinator_id, tailnet_clients.updated_at, tailnet_clients.node, array_agg(tailnet_client_subscriptions.agent_id)::uuid[] as agent_ids FROM tailnet_clients +LEFT JOIN tailnet_client_subscriptions +ON tailnet_clients.id = tailnet_client_subscriptions.client_id ` -func (q *sqlQuerier) GetAllTailnetClients(ctx context.Context) ([]TailnetClient, error) { +type GetAllTailnetClientsRow struct { + TailnetClient TailnetClient `db:"tailnet_client" json:"tailnet_client"` + AgentIds []uuid.UUID `db:"agent_ids" json:"agent_ids"` +} + +func (q *sqlQuerier) GetAllTailnetClients(ctx context.Context) ([]GetAllTailnetClientsRow, error) { rows, err := q.db.QueryContext(ctx, getAllTailnetClients) if err != nil { return nil, err } defer rows.Close() - var items []TailnetClient + var items []GetAllTailnetClientsRow for rows.Next() { - var i TailnetClient + var i GetAllTailnetClientsRow if err := rows.Scan( - &i.ID, - &i.CoordinatorID, - &i.UpdatedAt, - &i.Node, + &i.TailnetClient.ID, + &i.TailnetClient.CoordinatorID, + &i.TailnetClient.UpdatedAt, + &i.TailnetClient.Node, pq.Array(&i.AgentIds), ); err != nil { return nil, err @@ -4240,13 +4273,17 @@ func (q *sqlQuerier) GetTailnetAgents(ctx context.Context, id uuid.UUID) ([]Tail } const getTailnetClientsForAgent = `-- name: GetTailnetClientsForAgent :many -SELECT id, coordinator_id, updated_at, node, agent_ids +SELECT id, coordinator_id, updated_at, node FROM tailnet_clients -WHERE $1::uuid = ANY(agent_ids) +WHERE id IN ( + SELECT tailnet_client_subscriptions.client_id + FROM tailnet_client_subscriptions + WHERE tailnet_client_subscriptions.agent_id = $1 +) ` -func (q *sqlQuerier) GetTailnetClientsForAgent(ctx context.Context, dollar_1 uuid.UUID) ([]TailnetClient, error) { - rows, err := q.db.QueryContext(ctx, getTailnetClientsForAgent, dollar_1) +func (q *sqlQuerier) GetTailnetClientsForAgent(ctx context.Context, agentID uuid.UUID) ([]TailnetClient, error) { + rows, err := q.db.QueryContext(ctx, getTailnetClientsForAgent, agentID) if err != nil { return nil, err } @@ -4259,7 +4296,6 @@ func (q *sqlQuerier) GetTailnetClientsForAgent(ctx context.Context, dollar_1 uui &i.CoordinatorID, &i.UpdatedAt, &i.Node, - pq.Array(&i.AgentIds), ); err != nil { return nil, err } @@ -4316,47 +4352,67 @@ INSERT INTO tailnet_clients ( id, coordinator_id, - agent_ids, node, updated_at ) VALUES - ($1, $2, $3, $4, now() at time zone 'utc') + ($1, $2, $3, now() at time zone 'utc') ON CONFLICT (id, coordinator_id) DO UPDATE SET id = $1, coordinator_id = $2, - agent_ids = $3, - node = $4, + node = $3, updated_at = now() at time zone 'utc' -RETURNING id, coordinator_id, updated_at, node, agent_ids +RETURNING id, coordinator_id, updated_at, node ` type UpsertTailnetClientParams struct { ID uuid.UUID `db:"id" json:"id"` CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` - AgentIds []uuid.UUID `db:"agent_ids" json:"agent_ids"` Node json.RawMessage `db:"node" json:"node"` } func (q *sqlQuerier) UpsertTailnetClient(ctx context.Context, arg UpsertTailnetClientParams) (TailnetClient, error) { - row := q.db.QueryRowContext(ctx, upsertTailnetClient, - arg.ID, - arg.CoordinatorID, - pq.Array(arg.AgentIds), - arg.Node, - ) + row := q.db.QueryRowContext(ctx, upsertTailnetClient, arg.ID, arg.CoordinatorID, arg.Node) var i TailnetClient err := row.Scan( &i.ID, &i.CoordinatorID, &i.UpdatedAt, &i.Node, - pq.Array(&i.AgentIds), ) return i, err } +const upsertTailnetClientSubscription = `-- name: UpsertTailnetClientSubscription :exec +INSERT INTO + tailnet_client_subscriptions ( + client_id, + coordinator_id, + agent_id, + updated_at +) +VALUES + ($1, $2, $3, now() at time zone 'utc') +ON CONFLICT (client_id, coordinator_id, agent_id) +DO UPDATE SET + client_id = $1, + coordinator_id = $2, + agent_id = $3, + updated_at = now() at time zone 'utc' +` + +type UpsertTailnetClientSubscriptionParams struct { + ClientID uuid.UUID `db:"client_id" json:"client_id"` + CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` + AgentID uuid.UUID `db:"agent_id" json:"agent_id"` +} + +func (q *sqlQuerier) UpsertTailnetClientSubscription(ctx context.Context, arg UpsertTailnetClientSubscriptionParams) error { + _, err := q.db.ExecContext(ctx, upsertTailnetClientSubscription, arg.ClientID, arg.CoordinatorID, arg.AgentID) + return err +} + const upsertTailnetCoordinator = `-- name: UpsertTailnetCoordinator :one INSERT INTO tailnet_coordinators ( diff --git a/coderd/database/queries/tailnet.sql b/coderd/database/queries/tailnet.sql index 1b63261b90d86..86a0a1b9abd3e 100644 --- a/coderd/database/queries/tailnet.sql +++ b/coderd/database/queries/tailnet.sql @@ -3,21 +3,36 @@ INSERT INTO tailnet_clients ( id, coordinator_id, - agent_ids, node, updated_at ) VALUES - ($1, $2, $3, $4, now() at time zone 'utc') + ($1, $2, $3, now() at time zone 'utc') ON CONFLICT (id, coordinator_id) DO UPDATE SET id = $1, coordinator_id = $2, - agent_ids = $3, - node = $4, + node = $3, updated_at = now() at time zone 'utc' RETURNING *; +-- name: UpsertTailnetClientSubscription :exec +INSERT INTO + tailnet_client_subscriptions ( + client_id, + coordinator_id, + agent_id, + updated_at +) +VALUES + ($1, $2, $3, now() at time zone 'utc') +ON CONFLICT (client_id, coordinator_id, agent_id) +DO UPDATE SET + client_id = $1, + coordinator_id = $2, + agent_id = $3, + updated_at = now() at time zone 'utc'; + -- name: UpsertTailnetAgent :one INSERT INTO tailnet_agents ( @@ -43,6 +58,12 @@ FROM tailnet_clients WHERE id = $1 and coordinator_id = $2 RETURNING id, coordinator_id; +-- name: DeleteTailnetClientSubscription :one +DELETE +FROM tailnet_client_subscriptions +WHERE client_id = $1 and agent_id = $2 and coordinator_id = $3 +RETURNING client_id, agent_id, coordinator_id; + -- name: DeleteTailnetAgent :one DELETE FROM tailnet_agents @@ -66,11 +87,17 @@ FROM tailnet_agents; -- name: GetTailnetClientsForAgent :many SELECT * FROM tailnet_clients -WHERE $1::uuid = ANY(agent_ids); +WHERE id IN ( + SELECT tailnet_client_subscriptions.client_id + FROM tailnet_client_subscriptions + WHERE tailnet_client_subscriptions.agent_id = $1 +); -- name: GetAllTailnetClients :many -SELECT * -FROM tailnet_clients; +SELECT sqlc.embed(tailnet_clients), array_agg(tailnet_client_subscriptions.agent_id)::uuid[] as agent_ids +FROM tailnet_clients +LEFT JOIN tailnet_client_subscriptions +ON tailnet_clients.id = tailnet_client_subscriptions.client_id; -- name: UpsertTailnetCoordinator :one INSERT INTO diff --git a/enterprise/tailnet/connio.go b/enterprise/tailnet/connio.go index 72e378f045b1f..5d429ec3d5398 100644 --- a/enterprise/tailnet/connio.go +++ b/enterprise/tailnet/connio.go @@ -18,14 +18,13 @@ import ( // uuid.Nil. It reads node updates via its decoder, then pushes them onto the bindings channel. It receives mappings // via its updates TrackedConn, which then writes them. type connIO struct { - pCtx context.Context - ctx context.Context - cancel context.CancelFunc - logger slog.Logger - subscriptions []uuid.UUID - decoder *json.Decoder - updates *agpl.TrackedConn - bindings chan<- binding + pCtx context.Context + ctx context.Context + cancel context.CancelFunc + logger slog.Logger + decoder *json.Decoder + updates *agpl.TrackedConn + bindings chan<- binding } func newConnIO(pCtx context.Context, @@ -33,25 +32,23 @@ func newConnIO(pCtx context.Context, bindings chan<- binding, conn net.Conn, id uuid.UUID, - subs []uuid.UUID, name string, kind agpl.QueueKind, -) *connIO { +) (*connIO, error) { ctx, cancel := context.WithCancel(pCtx) c := &connIO{ - pCtx: pCtx, - ctx: ctx, - cancel: cancel, - logger: logger, - subscriptions: subs, - decoder: json.NewDecoder(conn), - updates: agpl.NewTrackedConn(ctx, cancel, conn, id, logger, name, 0, kind), - bindings: bindings, + pCtx: pCtx, + ctx: ctx, + cancel: cancel, + logger: logger, + decoder: json.NewDecoder(conn), + updates: agpl.NewTrackedConn(ctx, cancel, conn, id, logger, name, 0, kind), + bindings: bindings, } go c.recvLoop() go c.updates.SendUpdates() logger.Info(ctx, "serving connection") - return c + return c, nil } func (c *connIO) recvLoop() { @@ -90,8 +87,7 @@ func (c *connIO) recvLoop() { id: c.UniqueID(), kind: c.Kind(), }, - subscriptions: c.subscriptions, - node: &node, + node: &node, } if err := sendCtx(c.ctx, c.bindings, b); err != nil { c.logger.Debug(c.ctx, "recvLoop ctx expired", slog.Error(err)) diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 4a55650303abb..a3ad0d0ccf34d 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -26,15 +26,16 @@ import ( ) const ( - EventHeartbeats = "tailnet_coordinator_heartbeat" - eventClientUpdate = "tailnet_client_update" - eventAgentUpdate = "tailnet_agent_update" - HeartbeatPeriod = time.Second * 2 - MissedHeartbeats = 3 - numQuerierWorkers = 10 - numBinderWorkers = 10 - dbMaxBackoff = 10 * time.Second - cleanupPeriod = time.Hour + EventHeartbeats = "tailnet_coordinator_heartbeat" + eventClientUpdate = "tailnet_client_update" + eventClientSubscription = "tailnet_client_subscription_update" + eventAgentUpdate = "tailnet_agent_update" + HeartbeatPeriod = time.Second * 2 + MissedHeartbeats = 3 + numQuerierWorkers = 10 + numBinderWorkers = 10 + dbMaxBackoff = 10 * time.Second + cleanupPeriod = time.Hour ) // pgCoord is a postgres-backed coordinator @@ -70,9 +71,10 @@ type pgCoord struct { pubsub pubsub.Pubsub store database.Store - bindings chan binding - newConnections chan agpl.Queue - id uuid.UUID + bindings chan binding + newConnections chan agpl.Queue + newSubscriptions chan subscribe + id uuid.UUID cancel context.CancelFunc closeOnce sync.Once @@ -106,21 +108,23 @@ func NewPGCoord(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, store logger = logger.Named("pgcoord").With(slog.F("coordinator_id", id)) bCh := make(chan binding) cCh := make(chan agpl.Queue) + sCh := make(chan subscribe) // signals when first heartbeat has been sent, so it's safe to start binding. fHB := make(chan struct{}) c := &pgCoord{ - ctx: ctx, - cancel: cancel, - logger: logger, - pubsub: ps, - store: store, - binder: newBinder(ctx, logger, id, store, bCh, fHB), - bindings: bCh, - newConnections: cCh, - id: id, - querier: newQuerier(ctx, logger, ps, store, id, cCh, numQuerierWorkers, fHB), - closed: make(chan struct{}), + ctx: ctx, + cancel: cancel, + logger: logger, + pubsub: ps, + store: store, + binder: newBinder(ctx, logger, id, store, bCh, fHB), + bindings: bCh, + newConnections: cCh, + newSubscriptions: sCh, + id: id, + querier: newQuerier(ctx, logger, id, ps, store, id, cCh, sCh, numQuerierWorkers, fHB), + closed: make(chan struct{}), } logger.Info(ctx, "starting coordinator") return c, nil @@ -131,18 +135,25 @@ func (c *pgCoord) ServeMultiAgent(id uuid.UUID) (agpl.MultiAgentConn, error) { ID: id, AgentIsLegacyFunc: func(agentID uuid.UUID) bool { return true }, OnSubscribe: func(enq agpl.Queue, agent uuid.UUID) (*agpl.Node, error) { - c.querier.newClientSubscription(enq, agent) - return c.Node(agent), nil + err := sendCtx(c.ctx, c.newSubscriptions, subscribe{ + q: enq, + agentID: agent, + active: true, + }) + + return c.Node(agent), err }, OnUnsubscribe: func(enq agpl.Queue, agent uuid.UUID) error { - c.querier.removeClientSubscription(enq, agent) - return nil + return sendCtx(c.ctx, c.newSubscriptions, subscribe{ + q: enq, + agentID: agent, + active: false, + }) }, OnNodeUpdate: func(id uuid.UUID, node *agpl.Node) error { return sendCtx(c.ctx, c.bindings, binding{ - bKey: bKey{id, agpl.QueueKindClient}, - node: node, - subscriptions: c.querier.getClientSubscriptions(id), + bKey: bKey{id, agpl.QueueKindClient}, + node: node, }) }, OnRemove: func(enq agpl.Queue) { @@ -196,12 +207,21 @@ func (c *pgCoord) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) erro slog.Error(err)) } }() - cIO := newConnIO(c.ctx, c.logger, c.bindings, conn, id, []uuid.UUID{agent}, id.String(), agpl.QueueKindClient) + cIO, err := newConnIO(c.ctx, c.logger, c.bindings, conn, id, id.String(), agpl.QueueKindClient) + if err != nil { + return err + } if err := sendCtx(c.ctx, c.newConnections, agpl.Queue(cIO)); err != nil { // can only be a context error, no need to log here. return err } - c.querier.newClientSubscription(cIO, agent) + if err := sendCtx(c.ctx, c.newSubscriptions, subscribe{ + q: agpl.Queue(cIO), + agentID: agent, + active: true, + }); err != nil { + return err + } <-cIO.ctx.Done() return nil } @@ -216,7 +236,10 @@ func (c *pgCoord) ServeAgent(conn net.Conn, id uuid.UUID, name string) error { } }() logger := c.logger.With(slog.F("name", name)) - cIO := newConnIO(c.ctx, logger, c.bindings, conn, id, nil, name, agpl.QueueKindAgent) + cIO, err := newConnIO(c.ctx, logger, c.bindings, conn, id, name, agpl.QueueKindAgent) + if err != nil { + return err + } if err := sendCtx(c.ctx, c.newConnections, agpl.Queue(cIO)); err != nil { // can only be a context error, no need to log here. return err @@ -251,9 +274,7 @@ type bKey struct { // binding represents an association between a client or agent and a Node. type binding struct { bKey - // subscriptions is a list of agents a client is subscribed to. - subscriptions []uuid.UUID - node *agpl.Node + node *agpl.Node } func (b *binding) isAgent() bool { return b.kind == agpl.QueueKindAgent } @@ -370,11 +391,10 @@ func (b *binder) writeOne(bnd binding) error { _, err = b.store.UpsertTailnetClient(b.ctx, database.UpsertTailnetClientParams{ ID: bnd.id, CoordinatorID: b.coordinatorID, - AgentIds: bnd.subscriptions, Node: nodeRaw, }) b.logger.Debug(b.ctx, "upserted client binding", - slog.F("subscriptions", bnd.subscriptions), slog.F("client_id", bnd.id), + slog.F("client_id", bnd.id), slog.F("node", nodeRaw), slog.Error(err)) case bnd.isClient() && len(nodeRaw) == 0: _, err = b.store.DeleteTailnetClient(b.ctx, database.DeleteTailnetClientParams{ @@ -382,7 +402,7 @@ func (b *binder) writeOne(bnd binding) error { CoordinatorID: b.coordinatorID, }) b.logger.Debug(b.ctx, "deleted client binding", - slog.F("subscriptions", bnd.subscriptions), slog.F("client_id", bnd.id)) + slog.F("client_id", bnd.id)) if xerrors.Is(err, sql.ErrNoRows) { // treat deletes as idempotent err = nil @@ -539,15 +559,26 @@ func (m *mapper) mappingsToNodes(mappings []mapping) []*agpl.Node { return nodes } +type subscribe struct { + q agpl.Queue + agentID uuid.UUID + // whether the subscription should be active. if true, the subscription is + // added. if false, the subscription is removed. + active bool +} + // querier is responsible for monitoring pubsub notifications and querying the database for the mappings that all // connected clients and agents need. It also checks heartbeats and withdraws mappings from coordinators that have // failed heartbeats. type querier struct { - ctx context.Context - logger slog.Logger - pubsub pubsub.Pubsub - store database.Store - newConnections chan agpl.Queue + ctx context.Context + logger slog.Logger + coordinatorID uuid.UUID + pubsub pubsub.Pubsub + store database.Store + + newConnections chan agpl.Queue + newSubscriptions chan subscribe workQ *workQ[mKey] @@ -571,20 +602,24 @@ type countedMapper struct { func newQuerier(ctx context.Context, logger slog.Logger, + coordinatorID uuid.UUID, ps pubsub.Pubsub, store database.Store, self uuid.UUID, newConnections chan agpl.Queue, + newSubscriptions chan subscribe, numWorkers int, - firstHeartbeat chan<- struct{}, + firstHeartbeat chan struct{}, ) *querier { updates := make(chan hbUpdate) q := &querier{ ctx: ctx, logger: logger.Named("querier"), + coordinatorID: coordinatorID, pubsub: ps, store: store, newConnections: newConnections, + newSubscriptions: newSubscriptions, workQ: newWorkQ[mKey](ctx), heartbeats: newHeartbeats(ctx, logger, ps, store, self, updates, firstHeartbeat), mappers: make(map[mKey]*countedMapper), @@ -594,11 +629,16 @@ func newQuerier(ctx context.Context, healthy: true, // assume we start healthy } q.subscribe() - go q.handleNewConnections() - for i := 0; i < numWorkers; i++ { - go q.worker() - } - go q.handleUpdates() + + go func() { + <-firstHeartbeat + go q.handleNewConnections() + go q.handleNewSubscriptions() + for i := 0; i < numWorkers; i++ { + go q.worker() + } + go q.handleUpdates() + }() return q } @@ -620,6 +660,21 @@ func (q *querier) handleNewConnections() { } } +func (q *querier) handleNewSubscriptions() { + for { + select { + case <-q.ctx.Done(): + return + case c := <-q.newSubscriptions: + if c.active { + q.newClientSubscription(c.q, c.agentID) + } else { + q.removeClientSubscription(c.q, c.agentID) + } + } + } +} + func (q *querier) newAgentConn(c agpl.Queue) { q.mu.Lock() defer q.mu.Unlock() @@ -656,7 +711,7 @@ func (q *querier) newAgentConn(c agpl.Queue) { go q.waitCleanupConn(c) } -func (q *querier) newClientSubscription(c agpl.Queue, agentID uuid.UUID) { +func (q *querier) newClientSubscription(c agpl.Queue, agentID uuid.UUID) error { q.mu.Lock() defer q.mu.Unlock() @@ -664,6 +719,15 @@ func (q *querier) newClientSubscription(c agpl.Queue, agentID uuid.UUID) { q.clientSubscriptions[c.UniqueID()] = map[uuid.UUID]struct{}{} } + err := q.store.UpsertTailnetClientSubscription(q.ctx, database.UpsertTailnetClientSubscriptionParams{ + ClientID: c.UniqueID(), + CoordinatorID: q.coordinatorID, + AgentID: agentID, + }) + if err != nil { + return xerrors.Errorf("upsert subscription: %w", err) + } + mk := mKey{ agent: agentID, kind: c.Kind(), @@ -682,23 +746,33 @@ func (q *querier) newClientSubscription(c agpl.Queue, agentID uuid.UUID) { q.workQ.enqueue(mk) } if err := sendCtx(cm.ctx, cm.add, c); err != nil { - return + return xerrors.Errorf("send subscription to mapper: %w", err) } q.clientSubscriptions[c.UniqueID()][agentID] = struct{}{} cm.count++ + return nil } -func (q *querier) removeClientSubscription(c agpl.Queue, agentID uuid.UUID) { +func (q *querier) removeClientSubscription(c agpl.Queue, agentID uuid.UUID) error { q.mu.Lock() defer q.mu.Unlock() + _, err := q.store.DeleteTailnetClientSubscription(q.ctx, database.DeleteTailnetClientSubscriptionParams{ + ClientID: c.UniqueID(), + CoordinatorID: q.coordinatorID, + AgentID: agentID, + }) + if err != nil { + return xerrors.Errorf("delete subscription: %w", err) + } + mk := mKey{ agent: agentID, kind: c.Kind(), } cm := q.mappers[mk] if err := sendCtx(cm.ctx, cm.del, c); err != nil { - return + return xerrors.Errorf("send deletion to mapper: %w", err) } delete(q.clientSubscriptions[c.UniqueID()], agentID) cm.count-- @@ -706,6 +780,7 @@ func (q *querier) removeClientSubscription(c agpl.Queue, agentID uuid.UUID) { cm.cancel() delete(q.mappers, mk) } + return nil } func (q *querier) newClientConn(c agpl.Queue) { @@ -724,16 +799,6 @@ func (q *querier) newClientConn(c agpl.Queue) { go q.waitCleanupConn(c) } -func (q *querier) getClientSubscriptions(id uuid.UUID) []uuid.UUID { - q.mu.Lock() - defer q.mu.Unlock() - subs := []uuid.UUID{} - for sub := range q.clientSubscriptions[id] { - subs = append(subs, sub) - } - return subs -} - func (q *querier) waitCleanupConn(c agpl.Queue) { <-c.Done() q.cleanupConn(c) @@ -956,6 +1021,7 @@ func (q *querier) listenClient(_ context.Context, msg []byte, err error) { } if err != nil { q.logger.Warn(q.ctx, "unhandled pubsub error", slog.Error(err)) + return } client, agents, err := parseClientUpdate(string(msg)) if err != nil { @@ -1108,7 +1174,7 @@ func (q *querier) getAll(ctx context.Context) (map[uuid.UUID]database.TailnetAge clientsMap := map[uuid.UUID][]database.TailnetClient{} for _, client := range clients { for _, agentID := range client.AgentIds { - clientsMap[agentID] = append(clientsMap[agentID], client) + clientsMap[agentID] = append(clientsMap[agentID], client.TailnetClient) } } From 618b0e0fdd20e22e5937d0223d484239622e1ab1 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Thu, 7 Sep 2023 21:44:19 +0000 Subject: [PATCH 05/21] fix migrations --- coderd/database/dbauthz/dbauthz.go | 10 ++++- coderd/database/dbfake/dbfake.go | 9 +---- ...154_pg_coordinator_single_tailnet.down.sql | 24 ------------ ...155_pg_coordinator_single_tailnet.down.sql | 39 +++++++++++++++++++ ...0155_pg_coordinator_single_tailnet.up.sql} | 14 +++---- .../fixtures/000130_ha_coordinator.up.sql | 2 +- ...00155_pg_coordinator_single_tailnet.up.sql | 9 +++++ enterprise/tailnet/pgcoord_test.go | 2 +- 8 files changed, 65 insertions(+), 44 deletions(-) delete mode 100644 coderd/database/migrations/000154_pg_coordinator_single_tailnet.down.sql create mode 100644 coderd/database/migrations/000155_pg_coordinator_single_tailnet.down.sql rename coderd/database/migrations/{000154_pg_coordinator_single_tailnet.up.sql => 000155_pg_coordinator_single_tailnet.up.sql} (92%) create mode 100644 coderd/database/migrations/testdata/fixtures/000155_pg_coordinator_single_tailnet.up.sql diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 0fa1c616e209e..7b2d8dd6e0f0b 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -768,7 +768,10 @@ func (q *querier) DeleteTailnetClient(ctx context.Context, arg database.DeleteTa } func (q *querier) DeleteTailnetClientSubscription(ctx context.Context, arg database.DeleteTailnetClientSubscriptionParams) (database.DeleteTailnetClientSubscriptionRow, error) { - panic("not implemented") + if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil { + return database.DeleteTailnetClientSubscriptionRow{}, err + } + return q.db.DeleteTailnetClientSubscription(ctx, arg) } func (q *querier) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) { @@ -2783,7 +2786,10 @@ func (q *querier) UpsertTailnetClient(ctx context.Context, arg database.UpsertTa } func (q *querier) UpsertTailnetClientSubscription(ctx context.Context, arg database.UpsertTailnetClientSubscriptionParams) error { - panic("not implemented") + if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil { + return err + } + return q.db.UpsertTailnetClientSubscription(ctx, arg) } func (q *querier) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (database.TailnetCoordinator, error) { diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index 6d102098c0b18..c0b049063c8cf 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -909,13 +909,8 @@ func (*FakeQuerier) DeleteTailnetClient(context.Context, database.DeleteTailnetC return database.DeleteTailnetClientRow{}, ErrUnimplemented } -func (q *FakeQuerier) DeleteTailnetClientSubscription(ctx context.Context, arg database.DeleteTailnetClientSubscriptionParams) (database.DeleteTailnetClientSubscriptionRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.DeleteTailnetClientSubscriptionRow{}, err - } - - panic("not implemented") +func (*FakeQuerier) DeleteTailnetClientSubscription(context.Context, database.DeleteTailnetClientSubscriptionParams) (database.DeleteTailnetClientSubscriptionRow, error) { + return database.DeleteTailnetClientSubscriptionRow{}, ErrUnimplemented } func (q *FakeQuerier) GetAPIKeyByID(_ context.Context, id string) (database.APIKey, error) { diff --git a/coderd/database/migrations/000154_pg_coordinator_single_tailnet.down.sql b/coderd/database/migrations/000154_pg_coordinator_single_tailnet.down.sql deleted file mode 100644 index ddc05facaf677..0000000000000 --- a/coderd/database/migrations/000154_pg_coordinator_single_tailnet.down.sql +++ /dev/null @@ -1,24 +0,0 @@ -BEGIN; - -ALTER TABLE - tailnet_clients -ADD COLUMN - agent_id uuid; - -UPDATE - tailnet_clients -SET - -- grab just the first agent_id, or default to an empty UUID. - agent_id = COALESCE(agent_ids[0], '00000000-0000-0000-0000-000000000000'::uuid); - -ALTER TABLE - tailnet_clients -ALTER COLUMN - agent_id SET NOT NULL; - -ALTER TABLE - tailnet_clients -DROP COLUMN - agent_ids; - -COMMIT; diff --git a/coderd/database/migrations/000155_pg_coordinator_single_tailnet.down.sql b/coderd/database/migrations/000155_pg_coordinator_single_tailnet.down.sql new file mode 100644 index 0000000000000..7cc418489f59a --- /dev/null +++ b/coderd/database/migrations/000155_pg_coordinator_single_tailnet.down.sql @@ -0,0 +1,39 @@ +BEGIN; + +ALTER TABLE + tailnet_clients +ADD COLUMN + agent_id uuid; + +UPDATE + tailnet_clients +SET + -- there's no reason for us to try and preserve data since coordinators will + -- have to restart anyways, which will create all of the client mappings. + agent_id = '00000000-0000-0000-0000-000000000000'::uuid; + +ALTER TABLE + tailnet_clients +ALTER COLUMN + agent_id SET NOT NULL; + +DROP TABLE tailnet_client_subscriptions; +DROP FUNCTION tailnet_notify_client_subscription_change; + +-- update the tailnet_clients trigger to the old version. +CREATE OR REPLACE FUNCTION tailnet_notify_client_change() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + IF (OLD IS NOT NULL) THEN + PERFORM pg_notify('tailnet_client_update', OLD.id || ',' || OLD.agent_id); + RETURN NULL; + END IF; + IF (NEW IS NOT NULL) THEN + PERFORM pg_notify('tailnet_client_update', NEW.id || ',' || NEW.agent_id); + RETURN NULL; + END IF; +END; +$$; + +COMMIT; diff --git a/coderd/database/migrations/000154_pg_coordinator_single_tailnet.up.sql b/coderd/database/migrations/000155_pg_coordinator_single_tailnet.up.sql similarity index 92% rename from coderd/database/migrations/000154_pg_coordinator_single_tailnet.up.sql rename to coderd/database/migrations/000155_pg_coordinator_single_tailnet.up.sql index 800594a5b00c2..663e23ab4f53f 100644 --- a/coderd/database/migrations/000154_pg_coordinator_single_tailnet.up.sql +++ b/coderd/database/migrations/000155_pg_coordinator_single_tailnet.up.sql @@ -1,13 +1,13 @@ BEGIN; CREATE TABLE tailnet_client_subscriptions ( - client_id uuid NOT NULL, + client_id uuid NOT NULL, coordinator_id uuid NOT NULL, -- this isn't a foreign key since it's more of a list of agents the client -- *wants* to connect to, and they don't necessarily have to currently -- exist in the db. - agent_id uuid NOT NULL, - updated_at timestamp with time zone NOT NULL, + agent_id uuid NOT NULL, + updated_at timestamp with time zone NOT NULL, PRIMARY KEY (client_id, coordinator_id, agent_id), FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators (id) ON DELETE CASCADE -- we don't keep a foreign key to the tailnet_clients table since there's @@ -23,8 +23,7 @@ BEGIN IF (NEW IS NOT NULL) THEN PERFORM pg_notify('tailnet_client_update', NEW.client_id || ',' || NEW.agent_id); RETURN NULL; - END IF; - IF (OLD IS NOT NULL) THEN + ELSIF (OLD IS NOT NULL) THEN PERFORM pg_notify('tailnet_client_update', OLD.client_id || ',' || OLD.agent_id); RETURN NULL; END IF; @@ -44,10 +43,6 @@ DECLARE var_agent_ids uuid[]; BEGIN IF (NEW.id IS NOT NULL) THEN - IF (NEW.node IS NULL) THEN - return NULL; - END IF; - var_client_id = NEW.id; SELECT array_agg(agent_id) @@ -60,6 +55,7 @@ BEGIN subs.coordinator_id = NEW.coordinator_id; ELSIF (OLD.id IS NOT NULL) THEN -- if new is null and old is not null, that means the row was deleted. + -- simulate a foreign key by deleting all of the subscriptions. var_client_id = OLD.id; WITH agent_ids AS ( DELETE FROM diff --git a/coderd/database/migrations/testdata/fixtures/000130_ha_coordinator.up.sql b/coderd/database/migrations/testdata/fixtures/000130_ha_coordinator.up.sql index 8af4fa4827997..dbebd6d5dd384 100644 --- a/coderd/database/migrations/testdata/fixtures/000130_ha_coordinator.up.sql +++ b/coderd/database/migrations/testdata/fixtures/000130_ha_coordinator.up.sql @@ -18,7 +18,7 @@ VALUES ); INSERT INTO tailnet_agents -(id, coordinator_id, updated_at, node) + (id, coordinator_id, updated_at, node) VALUES ( 'c0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11', diff --git a/coderd/database/migrations/testdata/fixtures/000155_pg_coordinator_single_tailnet.up.sql b/coderd/database/migrations/testdata/fixtures/000155_pg_coordinator_single_tailnet.up.sql new file mode 100644 index 0000000000000..b5b744d6d1dc8 --- /dev/null +++ b/coderd/database/migrations/testdata/fixtures/000155_pg_coordinator_single_tailnet.up.sql @@ -0,0 +1,9 @@ +INSERT INTO tailnet_client_subscriptions + (client_id, agent_id, coordinator_id, updated_at) +VALUES + ( + 'b0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11', + 'c0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11', + 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11', + '2023-06-15 10:23:54+00' + ); diff --git a/enterprise/tailnet/pgcoord_test.go b/enterprise/tailnet/pgcoord_test.go index 51fd1074dcd8b..aa6de92b8def8 100644 --- a/enterprise/tailnet/pgcoord_test.go +++ b/enterprise/tailnet/pgcoord_test.go @@ -646,7 +646,7 @@ type testConn struct { func newTestConn(ids []uuid.UUID) *testConn { a := &testConn{} a.ws, a.serverWS = net.Pipe() - a.nodeChan = make(chan []*agpl.Node, 5) + a.nodeChan = make(chan []*agpl.Node) a.sendNode, a.errChan = agpl.ServeCoordinator(a.ws, func(nodes []*agpl.Node) error { a.nodeChan <- nodes return nil From f50f929cc4bb053d38cd80708985e9b18607ffec Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Thu, 7 Sep 2023 23:49:39 +0000 Subject: [PATCH 06/21] fixup! fix migrations --- coderd/database/dump.sql | 8 ++------ enterprise/tailnet/pgcoord.go | 12 ++++++++++-- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 62372b63be3fc..8b946aa4e6299 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -224,10 +224,6 @@ DECLARE var_agent_ids uuid[]; BEGIN IF (NEW.id IS NOT NULL) THEN - IF (NEW.node IS NULL) THEN - return NULL; - END IF; - var_client_id = NEW.id; SELECT array_agg(agent_id) @@ -240,6 +236,7 @@ BEGIN subs.coordinator_id = NEW.coordinator_id; ELSIF (OLD.id IS NOT NULL) THEN -- if new is null and old is not null, that means the row was deleted. + -- simulate a foreign key by deleting all of the subscriptions. var_client_id = OLD.id; WITH agent_ids AS ( DELETE FROM @@ -276,8 +273,7 @@ BEGIN IF (NEW IS NOT NULL) THEN PERFORM pg_notify('tailnet_client_update', NEW.client_id || ',' || NEW.agent_id); RETURN NULL; - END IF; - IF (OLD IS NOT NULL) THEN + ELSIF (OLD IS NOT NULL) THEN PERFORM pg_notify('tailnet_client_update', OLD.client_id || ',' || OLD.agent_id); RETURN NULL; END IF; diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index a3ad0d0ccf34d..937c1ebc7c48e 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -667,9 +667,17 @@ func (q *querier) handleNewSubscriptions() { return case c := <-q.newSubscriptions: if c.active { - q.newClientSubscription(c.q, c.agentID) + err := q.newClientSubscription(c.q, c.agentID) + if err != nil { + q.logger.Error(q.ctx, "create client subscription", slog.Error(err), + slog.F("client_id", c.q.UniqueID()), slog.F("agent_id", c.agentID)) + } } else { - q.removeClientSubscription(c.q, c.agentID) + err := q.removeClientSubscription(c.q, c.agentID) + if err != nil { + q.logger.Error(q.ctx, "remove client subscription", slog.Error(err), + slog.F("client_id", c.q.UniqueID()), slog.F("agent_id", c.agentID)) + } } } } From 3a5bb7608a906a1ef63879aa9e6c81e4383cb3a8 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Fri, 8 Sep 2023 00:04:53 +0000 Subject: [PATCH 07/21] fixup! fix migrations --- enterprise/tailnet/pgcoord.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 937c1ebc7c48e..a40632dbda6ab 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -264,8 +264,8 @@ func sendCtx[A any](ctx context.Context, c chan<- A, a A) (err error) { } } -// bKey, or "binding key" identifies a client or agent in a binding. Agents have their client field set to uuid.Nil, -// while clients have their agent field set to uuid.Nil. +// bKey, or "binding key" identifies a client or agent in a binding. Agents and +// clients are differentiated by the kind field. type bKey struct { id uuid.UUID kind agpl.QueueKind @@ -826,7 +826,7 @@ func (q *querier) cleanupConn(c agpl.Queue) { cm, ok := q.mappers[mk] if ok { if err := sendCtx(cm.ctx, cm.del, c); err != nil { - return + continue } cm.count-- if cm.count == 0 { From 3ddc7837b3606ac89ac9f875f10b1785d7329e97 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Fri, 8 Sep 2023 00:29:57 +0000 Subject: [PATCH 08/21] fixup! fix migrations --- coderd/database/dbfake/dbfake.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index c0b049063c8cf..bbe91e2b396c6 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -6036,7 +6036,7 @@ func (*FakeQuerier) UpsertTailnetClient(context.Context, database.UpsertTailnetC return database.TailnetClient{}, ErrUnimplemented } -func (q *FakeQuerier) UpsertTailnetClientSubscription(ctx context.Context, arg database.UpsertTailnetClientSubscriptionParams) error { +func (*FakeQuerier) UpsertTailnetClientSubscription(context.Context, database.UpsertTailnetClientSubscriptionParams) error { return ErrUnimplemented } From dcb007d87064449e95134dbebb9c7644077ee4cb Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Fri, 15 Sep 2023 19:46:06 +0000 Subject: [PATCH 09/21] add subscriber subsystem --- coderd/coderd.go | 2 +- coderd/database/dbauthz/dbauthz.go | 11 +- coderd/database/dbfake/dbfake.go | 13 +- coderd/database/dbmetrics/dbmetrics.go | 13 +- coderd/database/dbmock/dbmock.go | 21 +- coderd/database/foreign_key_constraint.go | 1 + ...00155_pg_coordinator_single_tailnet.up.sql | 49 ++- coderd/database/querier.go | 3 +- coderd/database/queries.sql.go | 33 +- coderd/database/queries/tailnet.sql | 10 +- coderd/tailnet_test.go | 2 +- enterprise/coderd/workspaceproxycoordinate.go | 9 +- .../coderd/workspaceproxycoordinator_test.go | 6 +- enterprise/tailnet/coordinator.go | 4 +- enterprise/tailnet/pgcoord.go | 337 +++++++++++++----- enterprise/tailnet/pgcoord_test.go | 3 +- tailnet/coordinator.go | 6 +- 17 files changed, 358 insertions(+), 165 deletions(-) diff --git a/coderd/coderd.go b/coderd/coderd.go index 38e15fc044e6f..52c7740de481d 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -405,7 +405,7 @@ func New(options *Options) *API { api.DERPMap, options.DeploymentValues.DERP.Config.ForceWebSockets.Value(), func(context.Context) (tailnet.MultiAgentConn, error) { - return (*api.TailnetCoordinator.Load()).ServeMultiAgent(uuid.New()) + return (*api.TailnetCoordinator.Load()).ServeMultiAgent(uuid.New()), nil }, wsconncache.New(api._dialWorkspaceAgentTailnet, 0), api.TracerProvider, diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index fef0e4b0546db..bbf538e3be536 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -774,13 +774,20 @@ func (q *querier) DeleteTailnetClient(ctx context.Context, arg database.DeleteTa return q.db.DeleteTailnetClient(ctx, arg) } -func (q *querier) DeleteTailnetClientSubscription(ctx context.Context, arg database.DeleteTailnetClientSubscriptionParams) (database.DeleteTailnetClientSubscriptionRow, error) { +func (q *querier) DeleteTailnetClientSubscription(ctx context.Context, arg database.DeleteTailnetClientSubscriptionParams) error { if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil { - return database.DeleteTailnetClientSubscriptionRow{}, err + return err } return q.db.DeleteTailnetClientSubscription(ctx, arg) } +func (q *querier) DeleteAllTailnetClientSubscriptions(ctx context.Context, arg database.DeleteAllTailnetClientSubscriptionsParams) error { + if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil { + return err + } + return q.db.DeleteAllTailnetClientSubscriptions(ctx, arg) +} + func (q *querier) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) { return fetch(q.log, q.auth, q.db.GetAPIKeyByID)(ctx, id) } diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index 3b35eaa458d64..f3b092b8801a5 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -844,6 +844,15 @@ func (q *FakeQuerier) DeleteAPIKeysByUserID(_ context.Context, userID uuid.UUID) return nil } +func (q *FakeQuerier) DeleteAllTailnetClientSubscriptions(ctx context.Context, arg database.DeleteAllTailnetClientSubscriptionsParams) error { + err := validateDatabaseType(arg) + if err != nil { + return err + } + + panic("not implemented") +} + func (q *FakeQuerier) DeleteApplicationConnectAPIKeysByUserID(_ context.Context, userID uuid.UUID) error { q.mutex.Lock() defer q.mutex.Unlock() @@ -977,8 +986,8 @@ func (*FakeQuerier) DeleteTailnetClient(context.Context, database.DeleteTailnetC return database.DeleteTailnetClientRow{}, ErrUnimplemented } -func (*FakeQuerier) DeleteTailnetClientSubscription(context.Context, database.DeleteTailnetClientSubscriptionParams) (database.DeleteTailnetClientSubscriptionRow, error) { - return database.DeleteTailnetClientSubscriptionRow{}, ErrUnimplemented +func (*FakeQuerier) DeleteTailnetClientSubscription(context.Context, database.DeleteTailnetClientSubscriptionParams) error { + return ErrUnimplemented } func (q *FakeQuerier) GetAPIKeyByID(_ context.Context, id string) (database.APIKey, error) { diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index 2f97510b8f482..0892166d8868c 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -121,6 +121,13 @@ func (m metricsStore) DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUI return err } +func (m metricsStore) DeleteAllTailnetClientSubscriptions(ctx context.Context, arg database.DeleteAllTailnetClientSubscriptionsParams) error { + start := time.Now() + r0 := m.s.DeleteAllTailnetClientSubscriptions(ctx, arg) + m.queryLatencies.WithLabelValues("DeleteAllTailnetClientSubscriptions").Observe(time.Since(start).Seconds()) + return r0 +} + func (m metricsStore) DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error { start := time.Now() err := m.s.DeleteApplicationConnectAPIKeysByUserID(ctx, userID) @@ -202,11 +209,11 @@ func (m metricsStore) DeleteTailnetClient(ctx context.Context, arg database.Dele return m.s.DeleteTailnetClient(ctx, arg) } -func (m metricsStore) DeleteTailnetClientSubscription(ctx context.Context, arg database.DeleteTailnetClientSubscriptionParams) (database.DeleteTailnetClientSubscriptionRow, error) { +func (m metricsStore) DeleteTailnetClientSubscription(ctx context.Context, arg database.DeleteTailnetClientSubscriptionParams) error { start := time.Now() - r0, r1 := m.s.DeleteTailnetClientSubscription(ctx, arg) + r0 := m.s.DeleteTailnetClientSubscription(ctx, arg) m.queryLatencies.WithLabelValues("DeleteTailnetClientSubscription").Observe(time.Since(start).Seconds()) - return r0, r1 + return r0 } func (m metricsStore) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) { diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 4969acdcbba5e..393311abb76c0 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -124,6 +124,20 @@ func (mr *MockStoreMockRecorder) DeleteAPIKeysByUserID(arg0, arg1 interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAPIKeysByUserID", reflect.TypeOf((*MockStore)(nil).DeleteAPIKeysByUserID), arg0, arg1) } +// DeleteAllTailnetClientSubscriptions mocks base method. +func (m *MockStore) DeleteAllTailnetClientSubscriptions(arg0 context.Context, arg1 database.DeleteAllTailnetClientSubscriptionsParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteAllTailnetClientSubscriptions", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteAllTailnetClientSubscriptions indicates an expected call of DeleteAllTailnetClientSubscriptions. +func (mr *MockStoreMockRecorder) DeleteAllTailnetClientSubscriptions(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllTailnetClientSubscriptions", reflect.TypeOf((*MockStore)(nil).DeleteAllTailnetClientSubscriptions), arg0, arg1) +} + // DeleteApplicationConnectAPIKeysByUserID mocks base method. func (m *MockStore) DeleteApplicationConnectAPIKeysByUserID(arg0 context.Context, arg1 uuid.UUID) error { m.ctrl.T.Helper() @@ -296,12 +310,11 @@ func (mr *MockStoreMockRecorder) DeleteTailnetClient(arg0, arg1 interface{}) *go } // DeleteTailnetClientSubscription mocks base method. -func (m *MockStore) DeleteTailnetClientSubscription(arg0 context.Context, arg1 database.DeleteTailnetClientSubscriptionParams) (database.DeleteTailnetClientSubscriptionRow, error) { +func (m *MockStore) DeleteTailnetClientSubscription(arg0 context.Context, arg1 database.DeleteTailnetClientSubscriptionParams) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "DeleteTailnetClientSubscription", arg0, arg1) - ret0, _ := ret[0].(database.DeleteTailnetClientSubscriptionRow) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret0, _ := ret[0].(error) + return ret0 } // DeleteTailnetClientSubscription indicates an expected call of DeleteTailnetClientSubscription. diff --git a/coderd/database/foreign_key_constraint.go b/coderd/database/foreign_key_constraint.go index db2021166f621..5f3b9aa5c32b3 100644 --- a/coderd/database/foreign_key_constraint.go +++ b/coderd/database/foreign_key_constraint.go @@ -19,6 +19,7 @@ const ( ForeignKeyProvisionerJobLogsJobID ForeignKeyConstraint = "provisioner_job_logs_job_id_fkey" // ALTER TABLE ONLY provisioner_job_logs ADD CONSTRAINT provisioner_job_logs_job_id_fkey FOREIGN KEY (job_id) REFERENCES provisioner_jobs(id) ON DELETE CASCADE; ForeignKeyProvisionerJobsOrganizationID ForeignKeyConstraint = "provisioner_jobs_organization_id_fkey" // ALTER TABLE ONLY provisioner_jobs ADD CONSTRAINT provisioner_jobs_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE; ForeignKeyTailnetAgentsCoordinatorID ForeignKeyConstraint = "tailnet_agents_coordinator_id_fkey" // ALTER TABLE ONLY tailnet_agents ADD CONSTRAINT tailnet_agents_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE; + ForeignKeyTailnetClientSubscriptionsCoordinatorID ForeignKeyConstraint = "tailnet_client_subscriptions_coordinator_id_fkey" // ALTER TABLE ONLY tailnet_client_subscriptions ADD CONSTRAINT tailnet_client_subscriptions_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE; ForeignKeyTailnetClientsCoordinatorID ForeignKeyConstraint = "tailnet_clients_coordinator_id_fkey" // ALTER TABLE ONLY tailnet_clients ADD CONSTRAINT tailnet_clients_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE; ForeignKeyTemplateVersionParametersTemplateVersionID ForeignKeyConstraint = "template_version_parameters_template_version_id_fkey" // ALTER TABLE ONLY template_version_parameters ADD CONSTRAINT template_version_parameters_template_version_id_fkey FOREIGN KEY (template_version_id) REFERENCES template_versions(id) ON DELETE CASCADE; ForeignKeyTemplateVersionVariablesTemplateVersionID ForeignKeyConstraint = "template_version_variables_template_version_id_fkey" // ALTER TABLE ONLY template_version_variables ADD CONSTRAINT template_version_variables_template_version_id_fkey FOREIGN KEY (template_version_id) REFERENCES template_versions(id) ON DELETE CASCADE; diff --git a/coderd/database/migrations/000155_pg_coordinator_single_tailnet.up.sql b/coderd/database/migrations/000155_pg_coordinator_single_tailnet.up.sql index 663e23ab4f53f..4ca218248ef4a 100644 --- a/coderd/database/migrations/000155_pg_coordinator_single_tailnet.up.sql +++ b/coderd/database/migrations/000155_pg_coordinator_single_tailnet.up.sql @@ -40,47 +40,42 @@ CREATE OR REPLACE FUNCTION tailnet_notify_client_change() RETURNS trigger AS $$ DECLARE var_client_id uuid; + var_coordinator_id uuid; var_agent_ids uuid[]; + var_agent_id uuid; BEGIN IF (NEW.id IS NOT NULL) THEN var_client_id = NEW.id; - SELECT - array_agg(agent_id) - INTO - var_agent_ids - FROM - tailnet_client_subscriptions subs - WHERE - subs.client_id = NEW.id AND - subs.coordinator_id = NEW.coordinator_id; + var_coordinator_id = NEW.coordinator_id; ELSIF (OLD.id IS NOT NULL) THEN - -- if new is null and old is not null, that means the row was deleted. - -- simulate a foreign key by deleting all of the subscriptions. var_client_id = OLD.id; - WITH agent_ids AS ( - DELETE FROM - tailnet_client_subscriptions subs - WHERE - subs.client_id = OLD.id AND - subs.coordinator_id = OLD.coordinator_id - RETURNING - subs.agent_id - ) - SELECT - array_agg(agent_id) - INTO - var_agent_ids - FROM - agent_ids; + var_coordinator_id = OLD.coordinator_id; END IF; -- Read all agents the client is subscribed to, so we can notify them. + SELECT + array_agg(agent_id) + INTO + var_agent_ids + FROM + tailnet_client_subscriptions subs + WHERE + subs.client_id = NEW.id AND + subs.coordinator_id = NEW.coordinator_id; + -- No agents to notify if (var_agent_ids IS NULL) THEN return NULL; END IF; - PERFORM pg_notify('tailnet_client_update', var_client_id || ',' || array_to_string(var_agent_ids, ',')); + -- pg_notify is limited to 8k bytes, which is approximately 221 UUIDs. + -- Instead of sending all agent ids in a single update, send one for each + -- agent id to prevent overflow. + FOREACH var_agent_id IN ARRAY var_agent_ids + LOOP + PERFORM pg_notify('tailnet_client_update', var_client_id || ',' || var_agent_id); + END LOOP; + return NULL; END; $$; diff --git a/coderd/database/querier.go b/coderd/database/querier.go index ee015bbb325fb..df79008288b55 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -34,6 +34,7 @@ type sqlcQuerier interface { CleanTailnetCoordinators(ctx context.Context) error DeleteAPIKeyByID(ctx context.Context, id string) error DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error + DeleteAllTailnetClientSubscriptions(ctx context.Context, arg DeleteAllTailnetClientSubscriptionsParams) error DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error DeleteCoordinator(ctx context.Context, id uuid.UUID) error DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error @@ -48,7 +49,7 @@ type sqlcQuerier interface { DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt time.Time) error DeleteTailnetAgent(ctx context.Context, arg DeleteTailnetAgentParams) (DeleteTailnetAgentRow, error) DeleteTailnetClient(ctx context.Context, arg DeleteTailnetClientParams) (DeleteTailnetClientRow, error) - DeleteTailnetClientSubscription(ctx context.Context, arg DeleteTailnetClientSubscriptionParams) (DeleteTailnetClientSubscriptionRow, error) + DeleteTailnetClientSubscription(ctx context.Context, arg DeleteTailnetClientSubscriptionParams) error GetAPIKeyByID(ctx context.Context, id string) (APIKey, error) // there is no unique constraint on empty token names GetAPIKeyByName(ctx context.Context, arg GetAPIKeyByNameParams) (APIKey, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index e3a4214100ae6..88b4e74f033f0 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -4125,6 +4125,22 @@ func (q *sqlQuerier) CleanTailnetCoordinators(ctx context.Context) error { return err } +const deleteAllTailnetClientSubscriptions = `-- name: DeleteAllTailnetClientSubscriptions :exec +DELETE +FROM tailnet_client_subscriptions +WHERE client_id = $1 and coordinator_id = $2 +` + +type DeleteAllTailnetClientSubscriptionsParams struct { + ClientID uuid.UUID `db:"client_id" json:"client_id"` + CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` +} + +func (q *sqlQuerier) DeleteAllTailnetClientSubscriptions(ctx context.Context, arg DeleteAllTailnetClientSubscriptionsParams) error { + _, err := q.db.ExecContext(ctx, deleteAllTailnetClientSubscriptions, arg.ClientID, arg.CoordinatorID) + return err +} + const deleteCoordinator = `-- name: DeleteCoordinator :exec DELETE FROM tailnet_coordinators @@ -4184,11 +4200,10 @@ func (q *sqlQuerier) DeleteTailnetClient(ctx context.Context, arg DeleteTailnetC return i, err } -const deleteTailnetClientSubscription = `-- name: DeleteTailnetClientSubscription :one +const deleteTailnetClientSubscription = `-- name: DeleteTailnetClientSubscription :exec DELETE FROM tailnet_client_subscriptions WHERE client_id = $1 and agent_id = $2 and coordinator_id = $3 -RETURNING client_id, agent_id, coordinator_id ` type DeleteTailnetClientSubscriptionParams struct { @@ -4197,17 +4212,9 @@ type DeleteTailnetClientSubscriptionParams struct { CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` } -type DeleteTailnetClientSubscriptionRow struct { - ClientID uuid.UUID `db:"client_id" json:"client_id"` - AgentID uuid.UUID `db:"agent_id" json:"agent_id"` - CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` -} - -func (q *sqlQuerier) DeleteTailnetClientSubscription(ctx context.Context, arg DeleteTailnetClientSubscriptionParams) (DeleteTailnetClientSubscriptionRow, error) { - row := q.db.QueryRowContext(ctx, deleteTailnetClientSubscription, arg.ClientID, arg.AgentID, arg.CoordinatorID) - var i DeleteTailnetClientSubscriptionRow - err := row.Scan(&i.ClientID, &i.AgentID, &i.CoordinatorID) - return i, err +func (q *sqlQuerier) DeleteTailnetClientSubscription(ctx context.Context, arg DeleteTailnetClientSubscriptionParams) error { + _, err := q.db.ExecContext(ctx, deleteTailnetClientSubscription, arg.ClientID, arg.AgentID, arg.CoordinatorID) + return err } const getAllTailnetAgents = `-- name: GetAllTailnetAgents :many diff --git a/coderd/database/queries/tailnet.sql b/coderd/database/queries/tailnet.sql index 86a0a1b9abd3e..16f8708f3210a 100644 --- a/coderd/database/queries/tailnet.sql +++ b/coderd/database/queries/tailnet.sql @@ -58,11 +58,15 @@ FROM tailnet_clients WHERE id = $1 and coordinator_id = $2 RETURNING id, coordinator_id; --- name: DeleteTailnetClientSubscription :one +-- name: DeleteTailnetClientSubscription :exec DELETE FROM tailnet_client_subscriptions -WHERE client_id = $1 and agent_id = $2 and coordinator_id = $3 -RETURNING client_id, agent_id, coordinator_id; +WHERE client_id = $1 and agent_id = $2 and coordinator_id = $3; + +-- name: DeleteAllTailnetClientSubscriptions :exec +DELETE +FROM tailnet_client_subscriptions +WHERE client_id = $1 and coordinator_id = $2; -- name: DeleteTailnetAgent :one DELETE diff --git a/coderd/tailnet_test.go b/coderd/tailnet_test.go index 8b1f55d9e994d..2a0b0dfdbae70 100644 --- a/coderd/tailnet_test.go +++ b/coderd/tailnet_test.go @@ -233,7 +233,7 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A derpServer, func() *tailcfg.DERPMap { return manifest.DERPMap }, false, - func(context.Context) (tailnet.MultiAgentConn, error) { return coord.ServeMultiAgent(uuid.New()) }, + func(context.Context) (tailnet.MultiAgentConn, error) { return coord.ServeMultiAgent(uuid.New()), nil }, cache, trace.NewNoopTracerProvider(), ) diff --git a/enterprise/coderd/workspaceproxycoordinate.go b/enterprise/coderd/workspaceproxycoordinate.go index bb4b3fa7b69eb..501095d44477e 100644 --- a/enterprise/coderd/workspaceproxycoordinate.go +++ b/enterprise/coderd/workspaceproxycoordinate.go @@ -67,14 +67,7 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request } id := uuid.New() - sub, err := (*api.AGPL.TailnetCoordinator.Load()).ServeMultiAgent(id) - if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to serve multi agent.", - Detail: err.Error(), - }) - return - } + sub := (*api.AGPL.TailnetCoordinator.Load()).ServeMultiAgent(id) ctx, nc := websocketNetConn(ctx, conn, websocket.MessageText) defer nc.Close() diff --git a/enterprise/coderd/workspaceproxycoordinator_test.go b/enterprise/coderd/workspaceproxycoordinator_test.go index fb991180b3adc..de72c288b2eee 100644 --- a/enterprise/coderd/workspaceproxycoordinator_test.go +++ b/enterprise/coderd/workspaceproxycoordinator_test.go @@ -59,8 +59,7 @@ func Test_agentIsLegacy(t *testing.T) { defer cancel() nodeID := uuid.New() - ma, err := coordinator.ServeMultiAgent(nodeID) - require.NoError(t, err) + ma := coordinator.ServeMultiAgent(nodeID) defer ma.Close() require.NoError(t, ma.UpdateSelf(&agpl.Node{ ID: 55, @@ -124,8 +123,7 @@ func Test_agentIsLegacy(t *testing.T) { defer cancel() nodeID := uuid.New() - ma, err := coordinator.ServeMultiAgent(nodeID) - require.NoError(t, err) + ma := coordinator.ServeMultiAgent(nodeID) defer ma.Close() require.NoError(t, ma.UpdateSelf(&agpl.Node{ ID: 55, diff --git a/enterprise/tailnet/coordinator.go b/enterprise/tailnet/coordinator.go index 9a04670d78b02..70ad50687b1f3 100644 --- a/enterprise/tailnet/coordinator.go +++ b/enterprise/tailnet/coordinator.go @@ -52,7 +52,7 @@ func NewCoordinator(logger slog.Logger, ps pubsub.Pubsub) (agpl.Coordinator, err return coord, nil } -func (c *haCoordinator) ServeMultiAgent(id uuid.UUID) (agpl.MultiAgentConn, error) { +func (c *haCoordinator) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn { m := (&agpl.MultiAgent{ ID: id, AgentIsLegacyFunc: c.agentIsLegacy, @@ -61,7 +61,7 @@ func (c *haCoordinator) ServeMultiAgent(id uuid.UUID) (agpl.MultiAgentConn, erro OnRemove: func(enq agpl.Queue) { c.clientDisconnected(enq.UniqueID()) }, }).Init() c.addClient(id, m) - return m, nil + return m } func (c *haCoordinator) addClient(id uuid.UUID, q agpl.Queue) { diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index a40632dbda6ab..0f8f4713d68fe 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -38,6 +38,7 @@ const ( cleanupPeriod = time.Hour ) +// TODO: add subscriber to this graphic // pgCoord is a postgres-backed coordinator // // ┌────────┐ ┌────────┐ ┌───────┐ @@ -80,8 +81,9 @@ type pgCoord struct { closeOnce sync.Once closed chan struct{} - binder *binder - querier *querier + binder *binder + subscriber *subscriber + querier *querier } var pgCoordSubject = rbac.Subject{ @@ -121,35 +123,25 @@ func NewPGCoord(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, store binder: newBinder(ctx, logger, id, store, bCh, fHB), bindings: bCh, newConnections: cCh, + subscriber: newSubscriber(ctx, logger, id, store, sCh, fHB), newSubscriptions: sCh, id: id, - querier: newQuerier(ctx, logger, id, ps, store, id, cCh, sCh, numQuerierWorkers, fHB), + querier: newQuerier(ctx, logger, id, ps, store, id, cCh, numQuerierWorkers, fHB), closed: make(chan struct{}), } logger.Info(ctx, "starting coordinator") return c, nil } -func (c *pgCoord) ServeMultiAgent(id uuid.UUID) (agpl.MultiAgentConn, error) { +func (c *pgCoord) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn { ma := (&agpl.MultiAgent{ ID: id, AgentIsLegacyFunc: func(agentID uuid.UUID) bool { return true }, OnSubscribe: func(enq agpl.Queue, agent uuid.UUID) (*agpl.Node, error) { - err := sendCtx(c.ctx, c.newSubscriptions, subscribe{ - q: enq, - agentID: agent, - active: true, - }) - + err := c.addSubscription(enq, agent) return c.Node(agent), err }, - OnUnsubscribe: func(enq agpl.Queue, agent uuid.UUID) error { - return sendCtx(c.ctx, c.newSubscriptions, subscribe{ - q: enq, - agentID: agent, - active: false, - }) - }, + OnUnsubscribe: c.removeSubscription, OnNodeUpdate: func(id uuid.UUID, node *agpl.Node) error { return sendCtx(c.ctx, c.bindings, binding{ bKey: bKey{id, agpl.QueueKindClient}, @@ -164,17 +156,58 @@ func (c *pgCoord) ServeMultiAgent(id uuid.UUID) (agpl.MultiAgentConn, error) { }, } if err := sendCtx(c.ctx, c.bindings, b); err != nil { - c.logger.Debug(c.ctx, "parent context expired while withdrawing bindings", slog.Error(err)) + c.logger.Debug(c.ctx, "parent context expired while withdrawing binding", slog.Error(err)) + } + if err := sendCtx(c.ctx, c.newSubscriptions, subscribe{ + sKey: sKey{clientID: id}, + active: false, + }); err != nil { + c.logger.Debug(c.ctx, "parent context expired while withdrawing subscriptions", slog.Error(err)) } c.querier.cleanupConn(enq) }, }).Init() if err := sendCtx(c.ctx, c.newConnections, agpl.Queue(ma)); err != nil { - return nil, err + // If we can't successfully send the multiagent, that means the + // coordinator is shutting down. In this case, just return a closed + // multiagent. + ma.CoordinatorClose() + } + + return ma +} + +func (c *pgCoord) addSubscription(q agpl.Queue, agentID uuid.UUID) error { + err := sendCtx(c.ctx, c.newSubscriptions, subscribe{ + sKey: sKey{ + clientID: q.UniqueID(), + agentID: agentID, + }, + active: true, + }) + if err != nil { + return err } - return ma, nil + c.querier.newClientSubscription(q, agentID) + return nil +} + +func (c *pgCoord) removeSubscription(q agpl.Queue, agentID uuid.UUID) error { + err := sendCtx(c.ctx, c.newSubscriptions, subscribe{ + sKey: sKey{ + clientID: q.UniqueID(), + agentID: agentID, + }, + active: false, + }) + if err != nil { + return err + } + + c.querier.removeClientSubscription(q, agentID) + return nil } func (c *pgCoord) Node(id uuid.UUID) *agpl.Node { @@ -215,13 +248,12 @@ func (c *pgCoord) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) erro // can only be a context error, no need to log here. return err } - if err := sendCtx(c.ctx, c.newSubscriptions, subscribe{ - q: agpl.Queue(cIO), - agentID: agent, - active: true, - }); err != nil { + + if err := c.addSubscription(cIO, agent); err != nil { return err } + defer c.removeSubscription(cIO, agent) + <-cIO.ctx.Done() return nil } @@ -264,6 +296,184 @@ func sendCtx[A any](ctx context.Context, c chan<- A, a A) (err error) { } } +type sKey struct { + clientID uuid.UUID + agentID uuid.UUID +} + +type subscribe struct { + sKey + // whether the subscription should be active. if true, the subscription is + // added. if false, the subscription is removed. + active bool +} + +type subscriber struct { + ctx context.Context + logger slog.Logger + coordinatorID uuid.UUID + store database.Store + subscriptions <-chan subscribe + + mu sync.Mutex + // map[clientID]map[agentID]subscribe + latest map[uuid.UUID]map[uuid.UUID]subscribe + workQ *workQ[sKey] +} + +func newSubscriber(ctx context.Context, + logger slog.Logger, + id uuid.UUID, + store database.Store, + subscriptions <-chan subscribe, + startWorkers <-chan struct{}, +) *subscriber { + s := &subscriber{ + ctx: ctx, + logger: logger, + coordinatorID: id, + store: store, + subscriptions: subscriptions, + latest: make(map[uuid.UUID]map[uuid.UUID]subscribe), + workQ: newWorkQ[sKey](ctx), + } + go s.handleSubscriptions() + go func() { + <-startWorkers + for i := 0; i < numBinderWorkers; i++ { + go s.worker() + } + }() + return s +} + +func (s *subscriber) handleSubscriptions() { + for { + select { + case <-s.ctx.Done(): + s.logger.Debug(s.ctx, "subscriber exiting", slog.Error(s.ctx.Err())) + return + case sub := <-s.subscriptions: + s.storeSubscription(sub) + s.workQ.enqueue(sub.sKey) + } + } +} + +func (s *subscriber) worker() { + eb := backoff.NewExponentialBackOff() + eb.MaxElapsedTime = 0 // retry indefinitely + eb.MaxInterval = dbMaxBackoff + bkoff := backoff.WithContext(eb, s.ctx) + for { + bk, err := s.workQ.acquire() + if err != nil { + // context expired + return + } + err = backoff.Retry(func() error { + bnd := s.retrieveSubscription(bk) + return s.writeOne(bnd) + }, bkoff) + if err != nil { + bkoff.Reset() + } + s.workQ.done(bk) + } +} + +func (s *subscriber) storeSubscription(sub subscribe) { + s.mu.Lock() + defer s.mu.Unlock() + if sub.active { + if _, ok := s.latest[sub.clientID]; !ok { + s.latest[sub.clientID] = map[uuid.UUID]subscribe{} + } + s.latest[sub.clientID][sub.agentID] = sub + } else { + // If the agentID is nil, clean up all of the clients subscriptions. + if sub.agentID == uuid.Nil { + delete(s.latest, sub.clientID) + } else { + delete(s.latest[sub.clientID], sub.agentID) + // clean up the subscription map if all the subscriptions are gone. + if len(s.latest[sub.clientID]) == 0 { + delete(s.latest, sub.clientID) + } + } + } +} + +// retrieveBinding gets the latest binding for a key. +func (s *subscriber) retrieveSubscription(sk sKey) subscribe { + s.mu.Lock() + defer s.mu.Unlock() + agents, ok := s.latest[sk.clientID] + if !ok { + return subscribe{ + sKey: sk, + active: false, + } + } + + sub, ok := agents[sk.agentID] + if !ok { + return subscribe{ + sKey: sk, + active: false, + } + } + + return sub +} + +func (s *subscriber) writeOne(sub subscribe) error { + var err error + switch { + case sub.agentID == uuid.Nil: + err = s.store.DeleteAllTailnetClientSubscriptions(s.ctx, database.DeleteAllTailnetClientSubscriptionsParams{ + ClientID: sub.clientID, + CoordinatorID: s.coordinatorID, + }) + s.logger.Debug(s.ctx, "deleted all client subscriptions", + slog.F("client_id", sub.clientID), + slog.Error(err), + ) + case sub.active: + err = s.store.UpsertTailnetClientSubscription(s.ctx, database.UpsertTailnetClientSubscriptionParams{ + ClientID: sub.clientID, + CoordinatorID: s.coordinatorID, + AgentID: sub.agentID, + }) + s.logger.Debug(s.ctx, "upserted client subscription", + slog.F("client_id", sub.clientID), + slog.F("agent_id", sub.agentID), + slog.Error(err), + ) + case !sub.active: + err = s.store.DeleteTailnetClientSubscription(s.ctx, database.DeleteTailnetClientSubscriptionParams{ + ClientID: sub.clientID, + CoordinatorID: s.coordinatorID, + AgentID: sub.agentID, + }) + s.logger.Debug(s.ctx, "deleted client subscription", + slog.F("client_id", sub.clientID), + slog.F("agent_id", sub.agentID), + slog.Error(err), + ) + default: + panic("unreachable") + } + if err != nil && !database.IsQueryCanceledError(err) { + s.logger.Error(s.ctx, "write subscription to database", + slog.F("client_id", sub.clientID), + slog.F("agent_id", sub.agentID), + slog.F("active", sub.active), + slog.Error(err)) + } + return err +} + // bKey, or "binding key" identifies a client or agent in a binding. Agents and // clients are differentiated by the kind field. type bKey struct { @@ -559,14 +769,6 @@ func (m *mapper) mappingsToNodes(mappings []mapping) []*agpl.Node { return nodes } -type subscribe struct { - q agpl.Queue - agentID uuid.UUID - // whether the subscription should be active. if true, the subscription is - // added. if false, the subscription is removed. - active bool -} - // querier is responsible for monitoring pubsub notifications and querying the database for the mappings that all // connected clients and agents need. It also checks heartbeats and withdraws mappings from coordinators that have // failed heartbeats. @@ -577,8 +779,7 @@ type querier struct { pubsub pubsub.Pubsub store database.Store - newConnections chan agpl.Queue - newSubscriptions chan subscribe + newConnections chan agpl.Queue workQ *workQ[mKey] @@ -607,7 +808,6 @@ func newQuerier(ctx context.Context, store database.Store, self uuid.UUID, newConnections chan agpl.Queue, - newSubscriptions chan subscribe, numWorkers int, firstHeartbeat chan struct{}, ) *querier { @@ -619,7 +819,6 @@ func newQuerier(ctx context.Context, pubsub: ps, store: store, newConnections: newConnections, - newSubscriptions: newSubscriptions, workQ: newWorkQ[mKey](ctx), heartbeats: newHeartbeats(ctx, logger, ps, store, self, updates, firstHeartbeat), mappers: make(map[mKey]*countedMapper), @@ -633,7 +832,6 @@ func newQuerier(ctx context.Context, go func() { <-firstHeartbeat go q.handleNewConnections() - go q.handleNewSubscriptions() for i := 0; i < numWorkers; i++ { go q.worker() } @@ -660,29 +858,6 @@ func (q *querier) handleNewConnections() { } } -func (q *querier) handleNewSubscriptions() { - for { - select { - case <-q.ctx.Done(): - return - case c := <-q.newSubscriptions: - if c.active { - err := q.newClientSubscription(c.q, c.agentID) - if err != nil { - q.logger.Error(q.ctx, "create client subscription", slog.Error(err), - slog.F("client_id", c.q.UniqueID()), slog.F("agent_id", c.agentID)) - } - } else { - err := q.removeClientSubscription(c.q, c.agentID) - if err != nil { - q.logger.Error(q.ctx, "remove client subscription", slog.Error(err), - slog.F("client_id", c.q.UniqueID()), slog.F("agent_id", c.agentID)) - } - } - } - } -} - func (q *querier) newAgentConn(c agpl.Queue) { q.mu.Lock() defer q.mu.Unlock() @@ -719,7 +894,7 @@ func (q *querier) newAgentConn(c agpl.Queue) { go q.waitCleanupConn(c) } -func (q *querier) newClientSubscription(c agpl.Queue, agentID uuid.UUID) error { +func (q *querier) newClientSubscription(c agpl.Queue, agentID uuid.UUID) { q.mu.Lock() defer q.mu.Unlock() @@ -727,18 +902,9 @@ func (q *querier) newClientSubscription(c agpl.Queue, agentID uuid.UUID) error { q.clientSubscriptions[c.UniqueID()] = map[uuid.UUID]struct{}{} } - err := q.store.UpsertTailnetClientSubscription(q.ctx, database.UpsertTailnetClientSubscriptionParams{ - ClientID: c.UniqueID(), - CoordinatorID: q.coordinatorID, - AgentID: agentID, - }) - if err != nil { - return xerrors.Errorf("upsert subscription: %w", err) - } - mk := mKey{ agent: agentID, - kind: c.Kind(), + kind: agpl.QueueKindClient, } cm, ok := q.mappers[mk] if !ok { @@ -754,33 +920,23 @@ func (q *querier) newClientSubscription(c agpl.Queue, agentID uuid.UUID) error { q.workQ.enqueue(mk) } if err := sendCtx(cm.ctx, cm.add, c); err != nil { - return xerrors.Errorf("send subscription to mapper: %w", err) + return } q.clientSubscriptions[c.UniqueID()][agentID] = struct{}{} cm.count++ - return nil } -func (q *querier) removeClientSubscription(c agpl.Queue, agentID uuid.UUID) error { +func (q *querier) removeClientSubscription(c agpl.Queue, agentID uuid.UUID) { q.mu.Lock() defer q.mu.Unlock() - _, err := q.store.DeleteTailnetClientSubscription(q.ctx, database.DeleteTailnetClientSubscriptionParams{ - ClientID: c.UniqueID(), - CoordinatorID: q.coordinatorID, - AgentID: agentID, - }) - if err != nil { - return xerrors.Errorf("delete subscription: %w", err) - } - mk := mKey{ agent: agentID, - kind: c.Kind(), + kind: agpl.QueueKindClient, } cm := q.mappers[mk] if err := sendCtx(cm.ctx, cm.del, c); err != nil { - return xerrors.Errorf("send deletion to mapper: %w", err) + return } delete(q.clientSubscriptions[c.UniqueID()], agentID) cm.count-- @@ -788,7 +944,6 @@ func (q *querier) removeClientSubscription(c agpl.Queue, agentID uuid.UUID) erro cm.cancel() delete(q.mappers, mk) } - return nil } func (q *querier) newClientConn(c agpl.Queue) { @@ -1238,9 +1393,13 @@ type mapping struct { node *agpl.Node } +type queueKey interface { + mKey | bKey | sKey +} + // workQ allows scheduling work based on a key. Multiple enqueue requests for the same key are coalesced, and // only one in-progress job per key is scheduled. -type workQ[K mKey | bKey] struct { +type workQ[K queueKey] struct { ctx context.Context cond *sync.Cond @@ -1248,7 +1407,7 @@ type workQ[K mKey | bKey] struct { inProgress map[K]bool } -func newWorkQ[K mKey | bKey](ctx context.Context) *workQ[K] { +func newWorkQ[K queueKey](ctx context.Context) *workQ[K] { q := &workQ[K]{ ctx: ctx, cond: sync.NewCond(&sync.Mutex{}), diff --git a/enterprise/tailnet/pgcoord_test.go b/enterprise/tailnet/pgcoord_test.go index aa6de92b8def8..b06fc005211bb 100644 --- a/enterprise/tailnet/pgcoord_test.go +++ b/enterprise/tailnet/pgcoord_test.go @@ -609,8 +609,7 @@ func TestPGCoordinator_MultiAgent(t *testing.T) { agent1.sendNode(&agpl.Node{PreferredDERP: 5}) id := uuid.New() - ma1, err := coord1.ServeMultiAgent(id) - require.NoError(t, err) + ma1 := coord1.ServeMultiAgent(id) defer ma1.Close() err = ma1.SubscribeAgent(agent1.id) diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index 073195187a9cf..effe9173b4f81 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -45,7 +45,7 @@ type Coordinator interface { // Close closes the coordinator. Close() error - ServeMultiAgent(id uuid.UUID) (MultiAgentConn, error) + ServeMultiAgent(id uuid.UUID) MultiAgentConn } // Node represents a node in the network. @@ -139,7 +139,7 @@ type coordinator struct { core *core } -func (c *coordinator) ServeMultiAgent(id uuid.UUID) (MultiAgentConn, error) { +func (c *coordinator) ServeMultiAgent(id uuid.UUID) MultiAgentConn { m := (&MultiAgent{ ID: id, AgentIsLegacyFunc: c.core.agentIsLegacy, @@ -149,7 +149,7 @@ func (c *coordinator) ServeMultiAgent(id uuid.UUID) (MultiAgentConn, error) { OnRemove: func(enq Queue) { c.core.clientDisconnected(enq.UniqueID()) }, }).Init() c.core.addClient(id, m) - return m, nil + return m } func (c *core) addClient(id uuid.UUID, ma Queue) { From bdd7ef1526992bf7f02e226990dda891b433c935 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Fri, 15 Sep 2023 19:48:52 +0000 Subject: [PATCH 10/21] fixup! add subscriber subsystem --- coderd/database/dbfake/dbfake.go | 4 ++-- enterprise/tailnet/pgcoord.go | 11 ++++++++++- provisionersdk/session.go | 2 +- scaletest/workspacetraffic/run_test.go | 3 ++- scripts/apitypings/main.go | 2 +- 5 files changed, 16 insertions(+), 6 deletions(-) diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index f3b092b8801a5..18411447153df 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -844,13 +844,13 @@ func (q *FakeQuerier) DeleteAPIKeysByUserID(_ context.Context, userID uuid.UUID) return nil } -func (q *FakeQuerier) DeleteAllTailnetClientSubscriptions(ctx context.Context, arg database.DeleteAllTailnetClientSubscriptionsParams) error { +func (q *FakeQuerier) DeleteAllTailnetClientSubscriptions(_ context.Context, arg database.DeleteAllTailnetClientSubscriptionsParams) error { err := validateDatabaseType(arg) if err != nil { return err } - panic("not implemented") + return ErrUnimplemented } func (q *FakeQuerier) DeleteApplicationConnectAPIKeysByUserID(_ context.Context, userID uuid.UUID) error { diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 0f8f4713d68fe..1edcf27b0e32c 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -252,7 +252,16 @@ func (c *pgCoord) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) erro if err := c.addSubscription(cIO, agent); err != nil { return err } - defer c.removeSubscription(cIO, agent) + defer func() { + err := c.removeSubscription(cIO, agent) + if err != nil { + c.logger.Debug(c.ctx, "remove client subscription", + slog.F("client_id", id), + slog.F("agent_id", agent), + slog.Error(err), + ) + } + }() <-cIO.ctx.Done() return nil diff --git a/provisionersdk/session.go b/provisionersdk/session.go index 840e02d30e337..d4b2935b5d95a 100644 --- a/provisionersdk/session.go +++ b/provisionersdk/session.go @@ -234,7 +234,7 @@ func (s *Session) extractArchive() error { // Always check for context cancellation before reading the next header. // This is mainly important for unit tests, since a canceled context means // the underlying directory is going to be deleted. There still exists - // the small race condition that the context is cancelled after this, and + // the small race condition that the context is canceled after this, and // before the disk write. if ctx.Err() != nil { return xerrors.Errorf("context canceled: %w", ctx.Err()) diff --git a/scaletest/workspacetraffic/run_test.go b/scaletest/workspacetraffic/run_test.go index 6e759cf46ebaf..961263e972e1c 100644 --- a/scaletest/workspacetraffic/run_test.go +++ b/scaletest/workspacetraffic/run_test.go @@ -8,6 +8,8 @@ import ( "testing" "time" + "golang.org/x/exp/slices" + "github.com/coder/coder/v2/agent" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/codersdk" @@ -16,7 +18,6 @@ import ( "github.com/coder/coder/v2/provisionersdk/proto" "github.com/coder/coder/v2/scaletest/workspacetraffic" "github.com/coder/coder/v2/testutil" - "golang.org/x/exp/slices" "github.com/google/uuid" "github.com/stretchr/testify/assert" diff --git a/scripts/apitypings/main.go b/scripts/apitypings/main.go index 45c97399610e3..3b5709cc01646 100644 --- a/scripts/apitypings/main.go +++ b/scripts/apitypings/main.go @@ -990,7 +990,7 @@ func (g *Generator) typescriptType(ty types.Type) (TypescriptType, error) { } // Do support "Stringer" interfaces, they likely can get string - // marshalled. + // marshaled. for i := 0; i < intf.NumMethods(); i++ { meth := intf.Method(i) if meth.Name() == "String" { From 751b22ffb754509a7e533c7c4104d7dbefb41806 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Fri, 15 Sep 2023 20:00:09 +0000 Subject: [PATCH 11/21] fixup! add subscriber subsystem --- coderd/database/dbauthz/dbauthz.go | 14 ++++----- coderd/database/dbfake/dbfake.go | 2 +- coderd/database/dump.sql | 49 ++++++++++++++---------------- 3 files changed, 30 insertions(+), 35 deletions(-) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index bbf538e3be536..d264c98458db8 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -685,6 +685,13 @@ func (q *querier) DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUID) e return q.db.DeleteAPIKeysByUserID(ctx, userID) } +func (q *querier) DeleteAllTailnetClientSubscriptions(ctx context.Context, arg database.DeleteAllTailnetClientSubscriptionsParams) error { + if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil { + return err + } + return q.db.DeleteAllTailnetClientSubscriptions(ctx, arg) +} + func (q *querier) DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error { // TODO: This is not 100% correct because it omits apikey IDs. err := q.authorizeContext(ctx, rbac.ActionDelete, @@ -781,13 +788,6 @@ func (q *querier) DeleteTailnetClientSubscription(ctx context.Context, arg datab return q.db.DeleteTailnetClientSubscription(ctx, arg) } -func (q *querier) DeleteAllTailnetClientSubscriptions(ctx context.Context, arg database.DeleteAllTailnetClientSubscriptionsParams) error { - if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil { - return err - } - return q.db.DeleteAllTailnetClientSubscriptions(ctx, arg) -} - func (q *querier) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) { return fetch(q.log, q.auth, q.db.GetAPIKeyByID)(ctx, id) } diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index 18411447153df..002932407a181 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -844,7 +844,7 @@ func (q *FakeQuerier) DeleteAPIKeysByUserID(_ context.Context, userID uuid.UUID) return nil } -func (q *FakeQuerier) DeleteAllTailnetClientSubscriptions(_ context.Context, arg database.DeleteAllTailnetClientSubscriptionsParams) error { +func (*FakeQuerier) DeleteAllTailnetClientSubscriptions(_ context.Context, arg database.DeleteAllTailnetClientSubscriptionsParams) error { err := validateDatabaseType(arg) if err != nil { return err diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 8b946aa4e6299..09a50730fa819 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -221,47 +221,42 @@ CREATE FUNCTION tailnet_notify_client_change() RETURNS trigger AS $$ DECLARE var_client_id uuid; + var_coordinator_id uuid; var_agent_ids uuid[]; + var_agent_id uuid; BEGIN IF (NEW.id IS NOT NULL) THEN var_client_id = NEW.id; - SELECT - array_agg(agent_id) - INTO - var_agent_ids - FROM - tailnet_client_subscriptions subs - WHERE - subs.client_id = NEW.id AND - subs.coordinator_id = NEW.coordinator_id; + var_coordinator_id = NEW.coordinator_id; ELSIF (OLD.id IS NOT NULL) THEN - -- if new is null and old is not null, that means the row was deleted. - -- simulate a foreign key by deleting all of the subscriptions. var_client_id = OLD.id; - WITH agent_ids AS ( - DELETE FROM - tailnet_client_subscriptions subs - WHERE - subs.client_id = OLD.id AND - subs.coordinator_id = OLD.coordinator_id - RETURNING - subs.agent_id - ) - SELECT - array_agg(agent_id) - INTO - var_agent_ids - FROM - agent_ids; + var_coordinator_id = OLD.coordinator_id; END IF; -- Read all agents the client is subscribed to, so we can notify them. + SELECT + array_agg(agent_id) + INTO + var_agent_ids + FROM + tailnet_client_subscriptions subs + WHERE + subs.client_id = NEW.id AND + subs.coordinator_id = NEW.coordinator_id; + -- No agents to notify if (var_agent_ids IS NULL) THEN return NULL; END IF; - PERFORM pg_notify('tailnet_client_update', var_client_id || ',' || array_to_string(var_agent_ids, ',')); + -- pg_notify is limited to 8k bytes, which is approximately 221 UUIDs. + -- Instead of sending all agent ids in a single update, send one for each + -- agent id to prevent overflow. + FOREACH var_agent_id IN ARRAY var_agent_ids + LOOP + PERFORM pg_notify('tailnet_client_update', var_client_id || ',' || var_agent_id); + END LOOP; + return NULL; END; $$; From 3af1af1f3560b4be1eac66a136626bebe7f56fba Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Fri, 15 Sep 2023 22:29:53 +0000 Subject: [PATCH 12/21] fixup! add subscriber subsystem --- enterprise/tailnet/connio.go | 4 ++-- enterprise/tailnet/pgcoord.go | 33 ++++++++++++++------------------- tailnet/coordinator.go | 3 +-- 3 files changed, 17 insertions(+), 23 deletions(-) diff --git a/enterprise/tailnet/connio.go b/enterprise/tailnet/connio.go index 5d429ec3d5398..fed307758603e 100644 --- a/enterprise/tailnet/connio.go +++ b/enterprise/tailnet/connio.go @@ -34,7 +34,7 @@ func newConnIO(pCtx context.Context, id uuid.UUID, name string, kind agpl.QueueKind, -) (*connIO, error) { +) *connIO { ctx, cancel := context.WithCancel(pCtx) c := &connIO{ pCtx: pCtx, @@ -48,7 +48,7 @@ func newConnIO(pCtx context.Context, go c.recvLoop() go c.updates.SendUpdates() logger.Info(ctx, "serving connection") - return c, nil + return c } func (c *connIO) recvLoop() { diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 1edcf27b0e32c..61f495ac8a3a2 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -26,16 +26,16 @@ import ( ) const ( - EventHeartbeats = "tailnet_coordinator_heartbeat" - eventClientUpdate = "tailnet_client_update" - eventClientSubscription = "tailnet_client_subscription_update" - eventAgentUpdate = "tailnet_agent_update" - HeartbeatPeriod = time.Second * 2 - MissedHeartbeats = 3 - numQuerierWorkers = 10 - numBinderWorkers = 10 - dbMaxBackoff = 10 * time.Second - cleanupPeriod = time.Hour + EventHeartbeats = "tailnet_coordinator_heartbeat" + eventClientUpdate = "tailnet_client_update" + eventAgentUpdate = "tailnet_agent_update" + HeartbeatPeriod = time.Second * 2 + MissedHeartbeats = 3 + numQuerierWorkers = 10 + numBinderWorkers = 10 + numSubscriberWorkers = 10 + dbMaxBackoff = 10 * time.Second + cleanupPeriod = time.Hour ) // TODO: add subscriber to this graphic @@ -240,10 +240,8 @@ func (c *pgCoord) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) erro slog.Error(err)) } }() - cIO, err := newConnIO(c.ctx, c.logger, c.bindings, conn, id, id.String(), agpl.QueueKindClient) - if err != nil { - return err - } + + cIO := newConnIO(c.ctx, c.logger, c.bindings, conn, id, id.String(), agpl.QueueKindClient) if err := sendCtx(c.ctx, c.newConnections, agpl.Queue(cIO)); err != nil { // can only be a context error, no need to log here. return err @@ -277,10 +275,7 @@ func (c *pgCoord) ServeAgent(conn net.Conn, id uuid.UUID, name string) error { } }() logger := c.logger.With(slog.F("name", name)) - cIO, err := newConnIO(c.ctx, logger, c.bindings, conn, id, name, agpl.QueueKindAgent) - if err != nil { - return err - } + cIO := newConnIO(c.ctx, logger, c.bindings, conn, id, name, agpl.QueueKindAgent) if err := sendCtx(c.ctx, c.newConnections, agpl.Queue(cIO)); err != nil { // can only be a context error, no need to log here. return err @@ -349,7 +344,7 @@ func newSubscriber(ctx context.Context, go s.handleSubscriptions() go func() { <-startWorkers - for i := 0; i < numBinderWorkers; i++ { + for i := 0; i < numSubscriberWorkers; i++ { go s.worker() } }() diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index effe9173b4f81..41a75f1fc5e78 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -194,8 +194,7 @@ type core struct { type QueueKind int const ( - _ QueueKind = iota - QueueKindClient + QueueKindClient QueueKind = 1 + iota QueueKindAgent ) From 7762a7393eb80d5e60dcae639dcb736adc18bc02 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Tue, 19 Sep 2023 18:14:09 +0000 Subject: [PATCH 13/21] querier <- subscriber --- enterprise/tailnet/pgcoord.go | 118 ++++++++++++++++++++-------------- 1 file changed, 71 insertions(+), 47 deletions(-) diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 61f495ac8a3a2..63d223b631203 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -110,7 +110,10 @@ func NewPGCoord(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, store logger = logger.Named("pgcoord").With(slog.F("coordinator_id", id)) bCh := make(chan binding) cCh := make(chan agpl.Queue) + // for communicating subscriptions with the subscriber sCh := make(chan subscribe) + // for communicating subscriptions with the querier + qsCh := make(chan subscribe) // signals when first heartbeat has been sent, so it's safe to start binding. fHB := make(chan struct{}) @@ -123,10 +126,10 @@ func NewPGCoord(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, store binder: newBinder(ctx, logger, id, store, bCh, fHB), bindings: bCh, newConnections: cCh, - subscriber: newSubscriber(ctx, logger, id, store, sCh, fHB), + subscriber: newSubscriber(ctx, logger, id, store, sCh, qsCh, fHB), newSubscriptions: sCh, id: id, - querier: newQuerier(ctx, logger, id, ps, store, id, cCh, numQuerierWorkers, fHB), + querier: newQuerier(ctx, logger, id, ps, store, id, cCh, qsCh, numQuerierWorkers, fHB), closed: make(chan struct{}), } logger.Info(ctx, "starting coordinator") @@ -160,11 +163,11 @@ func (c *pgCoord) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn { } if err := sendCtx(c.ctx, c.newSubscriptions, subscribe{ sKey: sKey{clientID: id}, + q: enq, active: false, }); err != nil { c.logger.Debug(c.ctx, "parent context expired while withdrawing subscriptions", slog.Error(err)) } - c.querier.cleanupConn(enq) }, }).Init() @@ -184,13 +187,12 @@ func (c *pgCoord) addSubscription(q agpl.Queue, agentID uuid.UUID) error { clientID: q.UniqueID(), agentID: agentID, }, + q: q, active: true, }) if err != nil { return err } - - c.querier.newClientSubscription(q, agentID) return nil } @@ -200,13 +202,12 @@ func (c *pgCoord) removeSubscription(q agpl.Queue, agentID uuid.UUID) error { clientID: q.UniqueID(), agentID: agentID, }, + q: q, active: false, }) if err != nil { return err } - - c.querier.removeClientSubscription(q, agentID) return nil } @@ -307,6 +308,8 @@ type sKey struct { type subscribe struct { sKey + + q agpl.Queue // whether the subscription should be active. if true, the subscription is // added. if false, the subscription is removed. active bool @@ -318,6 +321,7 @@ type subscriber struct { coordinatorID uuid.UUID store database.Store subscriptions <-chan subscribe + querierCh chan<- subscribe mu sync.Mutex // map[clientID]map[agentID]subscribe @@ -330,6 +334,7 @@ func newSubscriber(ctx context.Context, id uuid.UUID, store database.Store, subscriptions <-chan subscribe, + querierCh chan<- subscribe, startWorkers <-chan struct{}, ) *subscriber { s := &subscriber{ @@ -338,6 +343,7 @@ func newSubscriber(ctx context.Context, coordinatorID: id, store: store, subscriptions: subscriptions, + querierCh: querierCh, latest: make(map[uuid.UUID]map[uuid.UUID]subscribe), workQ: newWorkQ[sKey](ctx), } @@ -360,6 +366,7 @@ func (s *subscriber) handleSubscriptions() { case sub := <-s.subscriptions: s.storeSubscription(sub) s.workQ.enqueue(sub.sKey) + s.querierCh <- sub } } } @@ -784,6 +791,7 @@ type querier struct { store database.Store newConnections chan agpl.Queue + subscriptions chan subscribe workQ *workQ[mKey] @@ -812,6 +820,7 @@ func newQuerier(ctx context.Context, store database.Store, self uuid.UUID, newConnections chan agpl.Queue, + subscriptions chan subscribe, numWorkers int, firstHeartbeat chan struct{}, ) *querier { @@ -823,6 +832,7 @@ func newQuerier(ctx context.Context, pubsub: ps, store: store, newConnections: newConnections, + subscriptions: subscriptions, workQ: newWorkQ[mKey](ctx), heartbeats: newHeartbeats(ctx, logger, ps, store, self, updates, firstHeartbeat), mappers: make(map[mKey]*countedMapper), @@ -835,7 +845,7 @@ func newQuerier(ctx context.Context, go func() { <-firstHeartbeat - go q.handleNewConnections() + go q.handleIncoming() for i := 0; i < numWorkers; i++ { go q.worker() } @@ -844,11 +854,12 @@ func newQuerier(ctx context.Context, return q } -func (q *querier) handleNewConnections() { +func (q *querier) handleIncoming() { for { select { case <-q.ctx.Done(): return + case c := <-q.newConnections: switch c.Kind() { case agpl.QueueKindAgent: @@ -858,6 +869,13 @@ func (q *querier) handleNewConnections() { default: panic(fmt.Sprint("unreachable: invalid queue kind ", c.Kind())) } + + case sub := <-q.subscriptions: + if sub.active { + q.newClientSubscription(sub.q, sub.agentID) + } else { + q.removeClientSubscription(sub.q, sub.agentID) + } } } } @@ -905,6 +923,11 @@ func (q *querier) newClientSubscription(c agpl.Queue, agentID uuid.UUID) { if _, ok := q.clientSubscriptions[c.UniqueID()]; !ok { q.clientSubscriptions[c.UniqueID()] = map[uuid.UUID]struct{}{} } + fmt.Println("CREATEDC SUBSCRIPTION", c.UniqueID(), agentID) + fmt.Println("CREATEDC SUBSCRIPTION", c.UniqueID(), agentID) + fmt.Println("CREATEDC SUBSCRIPTION", c.UniqueID(), agentID) + fmt.Println("CREATEDC SUBSCRIPTION", c.UniqueID(), agentID) + fmt.Println("CREATEDC SUBSCRIPTION", c.UniqueID(), agentID) mk := mKey{ agent: agentID, @@ -934,6 +957,12 @@ func (q *querier) removeClientSubscription(c agpl.Queue, agentID uuid.UUID) { q.mu.Lock() defer q.mu.Unlock() + // agentID: uuid.Nil indicates that a client is going away. The querier + // handles that in cleanupConn below instead. + if agentID == uuid.Nil { + return + } + mk := mKey{ agent: agentID, kind: agpl.QueueKindClient, @@ -948,6 +977,9 @@ func (q *querier) removeClientSubscription(c agpl.Queue, agentID uuid.UUID) { cm.cancel() delete(q.mappers, mk) } + if len(q.clientSubscriptions[c.UniqueID()]) == 0 { + delete(q.clientSubscriptions, c.UniqueID()) + } } func (q *querier) newClientConn(c agpl.Queue) { @@ -982,18 +1014,17 @@ func (q *querier) cleanupConn(c agpl.Queue) { agent: agentID, kind: c.Kind(), } - cm, ok := q.mappers[mk] - if ok { - if err := sendCtx(cm.ctx, cm.del, c); err != nil { - continue - } - cm.count-- - if cm.count == 0 { - cm.cancel() - delete(q.mappers, mk) - } + cm := q.mappers[mk] + if err := sendCtx(cm.ctx, cm.del, c); err != nil { + continue + } + cm.count-- + if cm.count == 0 { + cm.cancel() + delete(q.mappers, mk) } } + delete(q.clientSubscriptions, c.UniqueID()) mk := mKey{ agent: c.UniqueID(), @@ -1190,28 +1221,26 @@ func (q *querier) listenClient(_ context.Context, msg []byte, err error) { q.logger.Warn(q.ctx, "unhandled pubsub error", slog.Error(err)) return } - client, agents, err := parseClientUpdate(string(msg)) + client, agent, err := parseClientUpdate(string(msg)) if err != nil { q.logger.Error(q.ctx, "failed to parse client update", slog.F("msg", string(msg)), slog.Error(err)) return } - logger := q.logger.With(slog.F("client_id", client)) + logger := q.logger.With(slog.F("client_id", client), slog.F("agent_id", agent)) logger.Debug(q.ctx, "got client update") - for _, agentID := range agents { - logger := q.logger.With(slog.F("agent_id", agentID)) - mk := mKey{ - agent: agentID, - kind: agpl.QueueKindAgent, - } - q.mu.Lock() - _, ok := q.mappers[mk] - q.mu.Unlock() - if !ok { - logger.Debug(q.ctx, "ignoring update because we have no mapper") - return - } - q.workQ.enqueue(mk) + + mk := mKey{ + agent: agent, + kind: agpl.QueueKindAgent, } + q.mu.Lock() + _, ok := q.mappers[mk] + q.mu.Unlock() + if !ok { + logger.Debug(q.ctx, "ignoring update because we have no mapper") + return + } + q.workQ.enqueue(mk) } func (q *querier) listenAgent(_ context.Context, msg []byte, err error) { @@ -1348,27 +1377,22 @@ func (q *querier) getAll(ctx context.Context) (map[uuid.UUID]database.TailnetAge return agentsMap, clientsMap, nil } -func parseClientUpdate(msg string) (client uuid.UUID, agents []uuid.UUID, err error) { +func parseClientUpdate(msg string) (client, agent uuid.UUID, err error) { parts := strings.Split(msg, ",") if len(parts) != 2 { - return uuid.Nil, nil, xerrors.Errorf("expected 2 parts separated by comma") + return uuid.Nil, uuid.Nil, xerrors.Errorf("expected 2 parts separated by comma") } client, err = uuid.Parse(parts[0]) if err != nil { - return uuid.Nil, nil, xerrors.Errorf("failed to parse client UUID: %w", err) + return uuid.Nil, uuid.Nil, xerrors.Errorf("failed to parse client UUID: %w", err) } - agents = []uuid.UUID{} - for _, agentStr := range parts[1:] { - agent, err := uuid.Parse(agentStr) - if err != nil { - return uuid.Nil, nil, xerrors.Errorf("failed to parse agent UUID: %w", err) - } - - agents = append(agents, agent) + agent, err = uuid.Parse(parts[1]) + if err != nil { + return uuid.Nil, uuid.Nil, xerrors.Errorf("failed to parse agent UUID: %w", err) } - return client, agents, nil + return client, agent, nil } func parseUpdateMessage(msg string) (agent uuid.UUID, err error) { From eb681ffea2eb4b3f3005b2942697399da178ae3c Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Tue, 19 Sep 2023 18:17:28 +0000 Subject: [PATCH 14/21] fixup! Merge branch 'main' into colin/single-pgcoord --- ...net.down.sql => 000156_pg_coordinator_single_tailnet.down.sql} | 0 ...tailnet.up.sql => 000156_pg_coordinator_single_tailnet.up.sql} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename coderd/database/migrations/{000155_pg_coordinator_single_tailnet.down.sql => 000156_pg_coordinator_single_tailnet.down.sql} (100%) rename coderd/database/migrations/{000155_pg_coordinator_single_tailnet.up.sql => 000156_pg_coordinator_single_tailnet.up.sql} (100%) diff --git a/coderd/database/migrations/000155_pg_coordinator_single_tailnet.down.sql b/coderd/database/migrations/000156_pg_coordinator_single_tailnet.down.sql similarity index 100% rename from coderd/database/migrations/000155_pg_coordinator_single_tailnet.down.sql rename to coderd/database/migrations/000156_pg_coordinator_single_tailnet.down.sql diff --git a/coderd/database/migrations/000155_pg_coordinator_single_tailnet.up.sql b/coderd/database/migrations/000156_pg_coordinator_single_tailnet.up.sql similarity index 100% rename from coderd/database/migrations/000155_pg_coordinator_single_tailnet.up.sql rename to coderd/database/migrations/000156_pg_coordinator_single_tailnet.up.sql From 390e837e50ed09aaccfeea716e490348e82de01c Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Tue, 19 Sep 2023 18:27:12 +0000 Subject: [PATCH 15/21] fixup! Merge branch 'main' into colin/single-pgcoord --- enterprise/tailnet/pgcoord.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 63d223b631203..22f72f5ff5d2b 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -923,11 +923,6 @@ func (q *querier) newClientSubscription(c agpl.Queue, agentID uuid.UUID) { if _, ok := q.clientSubscriptions[c.UniqueID()]; !ok { q.clientSubscriptions[c.UniqueID()] = map[uuid.UUID]struct{}{} } - fmt.Println("CREATEDC SUBSCRIPTION", c.UniqueID(), agentID) - fmt.Println("CREATEDC SUBSCRIPTION", c.UniqueID(), agentID) - fmt.Println("CREATEDC SUBSCRIPTION", c.UniqueID(), agentID) - fmt.Println("CREATEDC SUBSCRIPTION", c.UniqueID(), agentID) - fmt.Println("CREATEDC SUBSCRIPTION", c.UniqueID(), agentID) mk := mKey{ agent: agentID, From a1c3acf9425d010b137dbcc3fa9756c32cf47af1 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Tue, 19 Sep 2023 19:03:54 +0000 Subject: [PATCH 16/21] fixup! Merge branch 'main' into colin/single-pgcoord --- ...> 000156_pg_coordinator_single_tailnet.up.sql} | 0 enterprise/tailnet/pgcoord.go | 15 +++++---------- 2 files changed, 5 insertions(+), 10 deletions(-) rename coderd/database/migrations/testdata/fixtures/{000155_pg_coordinator_single_tailnet.up.sql => 000156_pg_coordinator_single_tailnet.up.sql} (100%) diff --git a/coderd/database/migrations/testdata/fixtures/000155_pg_coordinator_single_tailnet.up.sql b/coderd/database/migrations/testdata/fixtures/000156_pg_coordinator_single_tailnet.up.sql similarity index 100% rename from coderd/database/migrations/testdata/fixtures/000155_pg_coordinator_single_tailnet.up.sql rename to coderd/database/migrations/testdata/fixtures/000156_pg_coordinator_single_tailnet.up.sql diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 22f72f5ff5d2b..18c189259f151 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -251,16 +251,6 @@ func (c *pgCoord) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) erro if err := c.addSubscription(cIO, agent); err != nil { return err } - defer func() { - err := c.removeSubscription(cIO, agent) - if err != nil { - c.logger.Debug(c.ctx, "remove client subscription", - slog.F("client_id", id), - slog.F("agent_id", agent), - slog.Error(err), - ) - } - }() <-cIO.ctx.Done() return nil @@ -924,12 +914,15 @@ func (q *querier) newClientSubscription(c agpl.Queue, agentID uuid.UUID) { q.clientSubscriptions[c.UniqueID()] = map[uuid.UUID]struct{}{} } + fmt.Println("add sub", c.UniqueID(), agentID) + mk := mKey{ agent: agentID, kind: agpl.QueueKindClient, } cm, ok := q.mappers[mk] if !ok { + fmt.Println("new mapper") ctx, cancel := context.WithCancel(q.ctx) mpr := newMapper(ctx, q.logger, mk, q.heartbeats) cm = &countedMapper{ @@ -952,6 +945,8 @@ func (q *querier) removeClientSubscription(c agpl.Queue, agentID uuid.UUID) { q.mu.Lock() defer q.mu.Unlock() + fmt.Println("remove sub", c.UniqueID(), agentID) + // agentID: uuid.Nil indicates that a client is going away. The querier // handles that in cleanupConn below instead. if agentID == uuid.Nil { From 8256670d45ed78a7d9371e81db201fbbf8ebcfba Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Tue, 19 Sep 2023 19:07:50 +0000 Subject: [PATCH 17/21] fixup! Merge branch 'main' into colin/single-pgcoord --- enterprise/tailnet/pgcoord.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 18c189259f151..0f3c20699b4d0 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -914,15 +914,12 @@ func (q *querier) newClientSubscription(c agpl.Queue, agentID uuid.UUID) { q.clientSubscriptions[c.UniqueID()] = map[uuid.UUID]struct{}{} } - fmt.Println("add sub", c.UniqueID(), agentID) - mk := mKey{ agent: agentID, kind: agpl.QueueKindClient, } cm, ok := q.mappers[mk] if !ok { - fmt.Println("new mapper") ctx, cancel := context.WithCancel(q.ctx) mpr := newMapper(ctx, q.logger, mk, q.heartbeats) cm = &countedMapper{ @@ -945,8 +942,6 @@ func (q *querier) removeClientSubscription(c agpl.Queue, agentID uuid.UUID) { q.mu.Lock() defer q.mu.Unlock() - fmt.Println("remove sub", c.UniqueID(), agentID) - // agentID: uuid.Nil indicates that a client is going away. The querier // handles that in cleanupConn below instead. if agentID == uuid.Nil { From a75f6f1cf35a6357cabf0e7676d2a4fffaf01b0c Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Wed, 20 Sep 2023 05:39:41 +0000 Subject: [PATCH 18/21] add extensive multiagent tests --- enterprise/tailnet/multiagent_test.go | 354 ++++++++++++++++++++++++++ enterprise/tailnet/pgcoord.go | 7 +- enterprise/tailnet/pgcoord_test.go | 95 +++---- 3 files changed, 407 insertions(+), 49 deletions(-) create mode 100644 enterprise/tailnet/multiagent_test.go diff --git a/enterprise/tailnet/multiagent_test.go b/enterprise/tailnet/multiagent_test.go new file mode 100644 index 0000000000000..ee2e835cecb64 --- /dev/null +++ b/enterprise/tailnet/multiagent_test.go @@ -0,0 +1,354 @@ +package tailnet_test + +import ( + "context" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/enterprise/tailnet" + agpl "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/testutil" +) + +// TestPGCoordinator_MultiAgent tests a single coordinator with a MultiAgent +// connecting to one agent. +// +// +--------+ +// agent1 ---> | coord1 | <--- client +// +--------+ +func TestPGCoordinator_MultiAgent(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + store, ps := dbtestutil.NewDB(t) + coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store) + require.NoError(t, err) + defer coord1.Close() + + agent1 := newTestAgent(t, coord1, "agent1") + defer agent1.close() + agent1.sendNode(&agpl.Node{PreferredDERP: 5}) + + id := uuid.New() + ma1 := coord1.ServeMultiAgent(id) + defer ma1.Close() + + err = ma1.SubscribeAgent(agent1.id) + require.NoError(t, err) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5) + + agent1.sendNode(&agpl.Node{PreferredDERP: 1}) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) + + err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}) + require.NoError(t, err) + assertEventuallyHasDERPs(ctx, t, agent1, 3) + + require.NoError(t, ma1.Close()) + require.NoError(t, agent1.close()) + + assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) + assertEventuallyNoAgents(ctx, t, store, agent1.id) +} + +// TestPGCoordinator_MultiAgent_UnsubscribeRace tests a single coordinator with +// a MultiAgent connecting to one agent. It tries to race a call to Unsubscribe +// with the MultiAgent closing. +// +// +--------+ +// agent1 ---> | coord1 | <--- client +// +--------+ +func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + store, ps := dbtestutil.NewDB(t) + coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store) + require.NoError(t, err) + defer coord1.Close() + + agent1 := newTestAgent(t, coord1, "agent1") + defer agent1.close() + agent1.sendNode(&agpl.Node{PreferredDERP: 5}) + + id := uuid.New() + ma1 := coord1.ServeMultiAgent(id) + defer ma1.Close() + + err = ma1.SubscribeAgent(agent1.id) + require.NoError(t, err) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5) + + agent1.sendNode(&agpl.Node{PreferredDERP: 1}) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) + + err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}) + require.NoError(t, err) + assertEventuallyHasDERPs(ctx, t, agent1, 3) + + require.NoError(t, ma1.UnsubscribeAgent(agent1.id)) + require.NoError(t, ma1.Close()) + require.NoError(t, agent1.close()) + + assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) + assertEventuallyNoAgents(ctx, t, store, agent1.id) +} + +// TestPGCoordinator_MultiAgent_Unsubscribe tests a single coordinator with a +// MultiAgent connecting to one agent. It unsubscribes before closing, and +// ensures node updates are no longer propagated. +// +// +--------+ +// agent1 ---> | coord1 | <--- client +// +--------+ +func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + store, ps := dbtestutil.NewDB(t) + coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store) + require.NoError(t, err) + defer coord1.Close() + + agent1 := newTestAgent(t, coord1, "agent1") + defer agent1.close() + agent1.sendNode(&agpl.Node{PreferredDERP: 5}) + + id := uuid.New() + ma1 := coord1.ServeMultiAgent(id) + defer ma1.Close() + + err = ma1.SubscribeAgent(agent1.id) + require.NoError(t, err) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5) + + agent1.sendNode(&agpl.Node{PreferredDERP: 1}) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) + + require.NoError(t, ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3})) + assertEventuallyHasDERPs(ctx, t, agent1, 3) + + require.NoError(t, ma1.UnsubscribeAgent(agent1.id)) + assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) + + func() { + ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3) + defer cancel() + require.NoError(t, ma1.UpdateSelf(&agpl.Node{PreferredDERP: 9})) + assertNeverHasDERPs(ctx, t, agent1, 9) + }() + func() { + ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3) + defer cancel() + agent1.sendNode(&agpl.Node{PreferredDERP: 8}) + assertMultiAgentNeverHasDERPs(ctx, t, ma1, 8) + }() + + require.NoError(t, ma1.Close()) + require.NoError(t, agent1.close()) + + assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) + assertEventuallyNoAgents(ctx, t, store, agent1.id) +} + +// TestPGCoordinator_MultiAgent_MultiCoordinator tests two coordinators with a +// MultiAgent connecting to an agent on a separate coordinator. +// +// +--------+ +// agent1 ---> | coord1 | +// +--------+ +// +--------+ +// | coord2 | <--- client +// +--------+ +func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + store, ps := dbtestutil.NewDB(t) + coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store) + require.NoError(t, err) + defer coord1.Close() + coord2, err := tailnet.NewPGCoord(ctx, logger.Named("coord2"), ps, store) + require.NoError(t, err) + defer coord2.Close() + + agent1 := newTestAgent(t, coord1, "agent1") + defer agent1.close() + agent1.sendNode(&agpl.Node{PreferredDERP: 5}) + + id := uuid.New() + ma1 := coord2.ServeMultiAgent(id) + defer ma1.Close() + + err = ma1.SubscribeAgent(agent1.id) + require.NoError(t, err) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5) + + agent1.sendNode(&agpl.Node{PreferredDERP: 1}) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) + + err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}) + require.NoError(t, err) + assertEventuallyHasDERPs(ctx, t, agent1, 3) + + require.NoError(t, ma1.Close()) + require.NoError(t, agent1.close()) + + assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) + assertEventuallyNoAgents(ctx, t, store, agent1.id) +} + +// TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe tests two +// coordinators with a MultiAgent connecting to an agent on a separate +// coordinator. The MultiAgent updates its own node before subscribing. +// +// +--------+ +// agent1 ---> | coord1 | +// +--------+ +// +--------+ +// | coord2 | <--- client +// +--------+ +func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + store, ps := dbtestutil.NewDB(t) + coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store) + require.NoError(t, err) + defer coord1.Close() + coord2, err := tailnet.NewPGCoord(ctx, logger.Named("coord2"), ps, store) + require.NoError(t, err) + defer coord2.Close() + + agent1 := newTestAgent(t, coord1, "agent1") + defer agent1.close() + agent1.sendNode(&agpl.Node{PreferredDERP: 5}) + + id := uuid.New() + ma1 := coord2.ServeMultiAgent(id) + defer ma1.Close() + + err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}) + require.NoError(t, err) + + err = ma1.SubscribeAgent(agent1.id) + require.NoError(t, err) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5) + assertEventuallyHasDERPs(ctx, t, agent1, 3) + + agent1.sendNode(&agpl.Node{PreferredDERP: 1}) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) + + require.NoError(t, ma1.Close()) + require.NoError(t, agent1.close()) + + assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) + assertEventuallyNoAgents(ctx, t, store, agent1.id) +} + +// TestPGCoordinator_MultiAgent_TwoAgents tests three coordinators with a +// MultiAgent connecting to two agents on separate coordinators. +// +// +--------+ +// agent1 ---> | coord1 | +// +--------+ +// +--------+ +// agent2 ---> | coord2 | +// +--------+ +// +--------+ +// | coord3 | <--- client +// +--------+ +func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + store, ps := dbtestutil.NewDB(t) + coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store) + require.NoError(t, err) + defer coord1.Close() + coord2, err := tailnet.NewPGCoord(ctx, logger.Named("coord2"), ps, store) + require.NoError(t, err) + defer coord2.Close() + coord3, err := tailnet.NewPGCoord(ctx, logger.Named("coord3"), ps, store) + require.NoError(t, err) + defer coord3.Close() + + agent1 := newTestAgent(t, coord1, "agent1") + defer agent1.close() + agent1.sendNode(&agpl.Node{PreferredDERP: 5}) + + agent2 := newTestAgent(t, coord2, "agent2") + defer agent1.close() + agent2.sendNode(&agpl.Node{PreferredDERP: 6}) + + id := uuid.New() + ma1 := coord2.ServeMultiAgent(id) + defer ma1.Close() + + err = ma1.SubscribeAgent(agent1.id) + require.NoError(t, err) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5) + + agent1.sendNode(&agpl.Node{PreferredDERP: 1}) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) + + err = ma1.SubscribeAgent(agent2.id) + require.NoError(t, err) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 6) + + agent2.sendNode(&agpl.Node{PreferredDERP: 2}) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 2) + + err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}) + require.NoError(t, err) + assertEventuallyHasDERPs(ctx, t, agent1, 3) + assertEventuallyHasDERPs(ctx, t, agent2, 3) + + require.NoError(t, ma1.Close()) + require.NoError(t, agent1.close()) + require.NoError(t, agent2.close()) + + assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) + assertEventuallyNoAgents(ctx, t, store, agent1.id) +} diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 0f3c20699b4d0..1dd61dd9a608d 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -942,9 +942,10 @@ func (q *querier) removeClientSubscription(c agpl.Queue, agentID uuid.UUID) { q.mu.Lock() defer q.mu.Unlock() - // agentID: uuid.Nil indicates that a client is going away. The querier - // handles that in cleanupConn below instead. - if agentID == uuid.Nil { + // Allow duplicate unsubscribes. It's possible for cleanupConn to race with + // an external call to removeClientSubscription, so we just ensure the + // client subscription exists before attempting to remove it. + if _, ok := q.clientSubscriptions[c.UniqueID()][agentID]; !ok { return } diff --git a/enterprise/tailnet/pgcoord_test.go b/enterprise/tailnet/pgcoord_test.go index b06fc005211bb..031b863144e92 100644 --- a/enterprise/tailnet/pgcoord_test.go +++ b/enterprise/tailnet/pgcoord_test.go @@ -589,50 +589,6 @@ func TestPGCoordinator_Unhealthy(t *testing.T) { } } -func TestPGCoordinator_MultiAgent(t *testing.T) { - t.Parallel() - if !dbtestutil.WillUsePostgres() { - t.Skip("test only with postgres") - } - - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) - defer cancel() - - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) - store, ps := dbtestutil.NewDB(t) - coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store) - require.NoError(t, err) - defer coord1.Close() - - agent1 := newTestAgent(t, coord1, "agent1") - defer agent1.close() - agent1.sendNode(&agpl.Node{PreferredDERP: 5}) - - id := uuid.New() - ma1 := coord1.ServeMultiAgent(id) - defer ma1.Close() - - err = ma1.SubscribeAgent(agent1.id) - require.NoError(t, err) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5) - - agent1.sendNode(&agpl.Node{PreferredDERP: 1}) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) - - err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}) - require.NoError(t, err) - assertEventuallyHasDERPs(ctx, t, agent1, 3) - - err = ma1.Close() - require.NoError(t, err) - - err = agent1.close() - require.NoError(t, err) - - assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) - assertEventuallyNoAgents(ctx, t, store, agent1.id) -} - type testConn struct { ws, serverWS net.Conn nodeChan chan []*agpl.Node @@ -737,8 +693,29 @@ func assertEventuallyHasDERPs(ctx context.Context, t *testing.T, c *testConn, ex t.Logf("expected DERP %d to be in %v", e, derps) continue } + return + } + } +} + +func assertNeverHasDERPs(ctx context.Context, t *testing.T, c *testConn, expected ...int) { + t.Helper() + for { + select { + case <-ctx.Done(): + return + case nodes := <-c.nodeChan: + derps := make([]int, 0, len(nodes)) + for _, n := range nodes { + derps = append(derps, n.PreferredDERP) + } + for _, e := range expected { + if slices.Contains(derps, e) { + t.Fatalf("expected not to get DERP %d, but received it", e) + return + } + } } - return } } @@ -761,8 +738,34 @@ func assertMultiAgentEventuallyHasDERPs(ctx context.Context, t *testing.T, ma ag t.Logf("expected DERP %d to be in %v", e, derps) continue } + return + } + } +} + +func assertMultiAgentNeverHasDERPs(ctx context.Context, t *testing.T, ma agpl.MultiAgentConn, expected ...int) { + t.Helper() + for { + nodes, ok := ma.NextUpdate(ctx) + if !ok { + return + } + if len(nodes) != len(expected) { + t.Logf("expected %d, got %d nodes", len(expected), len(nodes)) + continue + } + + derps := make([]int, 0, len(nodes)) + for _, n := range nodes { + derps = append(derps, n.PreferredDERP) + } + for _, e := range expected { + if !slices.Contains(derps, e) { + t.Logf("expected DERP %d to be in %v", e, derps) + continue + } + return } - return } } From d143d2d9930de26d9fd6aa72d36d0f2418d12bf2 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Wed, 20 Sep 2023 15:27:54 +0000 Subject: [PATCH 19/21] use dedicated channels for querier subscribe and closing conns --- enterprise/tailnet/multiagent_test.go | 2 +- enterprise/tailnet/pgcoord.go | 77 ++++++++++++++++----------- 2 files changed, 46 insertions(+), 33 deletions(-) diff --git a/enterprise/tailnet/multiagent_test.go b/enterprise/tailnet/multiagent_test.go index ee2e835cecb64..7546bec350504 100644 --- a/enterprise/tailnet/multiagent_test.go +++ b/enterprise/tailnet/multiagent_test.go @@ -323,7 +323,7 @@ func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) { agent2.sendNode(&agpl.Node{PreferredDERP: 6}) id := uuid.New() - ma1 := coord2.ServeMultiAgent(id) + ma1 := coord3.ServeMultiAgent(id) defer ma1.Close() err = ma1.SubscribeAgent(agent1.id) diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 1dd61dd9a608d..32595005320b2 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -74,7 +74,9 @@ type pgCoord struct { bindings chan binding newConnections chan agpl.Queue - newSubscriptions chan subscribe + closeConnections chan agpl.Queue + subscriberCh chan subscribe + querierSubCh chan subscribe id uuid.UUID cancel context.CancelFunc @@ -109,7 +111,10 @@ func NewPGCoord(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, store id := uuid.New() logger = logger.Named("pgcoord").With(slog.F("coordinator_id", id)) bCh := make(chan binding) + // used for opening connections cCh := make(chan agpl.Queue) + // used for closing connections + ccCh := make(chan agpl.Queue) // for communicating subscriptions with the subscriber sCh := make(chan subscribe) // for communicating subscriptions with the querier @@ -126,10 +131,12 @@ func NewPGCoord(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, store binder: newBinder(ctx, logger, id, store, bCh, fHB), bindings: bCh, newConnections: cCh, - subscriber: newSubscriber(ctx, logger, id, store, sCh, qsCh, fHB), - newSubscriptions: sCh, + closeConnections: ccCh, + subscriber: newSubscriber(ctx, logger, id, store, sCh, fHB), + subscriberCh: sCh, + querierSubCh: qsCh, id: id, - querier: newQuerier(ctx, logger, id, ps, store, id, cCh, qsCh, numQuerierWorkers, fHB), + querier: newQuerier(ctx, logger, id, ps, store, id, cCh, ccCh, qsCh, numQuerierWorkers, fHB), closed: make(chan struct{}), } logger.Info(ctx, "starting coordinator") @@ -152,22 +159,18 @@ func (c *pgCoord) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn { }) }, OnRemove: func(enq agpl.Queue) { - b := binding{ + _ = sendCtx(c.ctx, c.bindings, binding{ bKey: bKey{ id: enq.UniqueID(), kind: enq.Kind(), }, - } - if err := sendCtx(c.ctx, c.bindings, b); err != nil { - c.logger.Debug(c.ctx, "parent context expired while withdrawing binding", slog.Error(err)) - } - if err := sendCtx(c.ctx, c.newSubscriptions, subscribe{ + }) + _ = sendCtx(c.ctx, c.subscriberCh, subscribe{ sKey: sKey{clientID: id}, q: enq, active: false, - }); err != nil { - c.logger.Debug(c.ctx, "parent context expired while withdrawing subscriptions", slog.Error(err)) - } + }) + _ = sendCtx(c.ctx, c.closeConnections, enq) }, }).Init() @@ -182,32 +185,44 @@ func (c *pgCoord) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn { } func (c *pgCoord) addSubscription(q agpl.Queue, agentID uuid.UUID) error { - err := sendCtx(c.ctx, c.newSubscriptions, subscribe{ + sub := subscribe{ sKey: sKey{ clientID: q.UniqueID(), agentID: agentID, }, q: q, active: true, - }) - if err != nil { + } + if err := sendCtx(c.ctx, c.subscriberCh, sub); err != nil { return err } + if err := sendCtx(c.ctx, c.querierSubCh, sub); err != nil { + // There's no need to clean up the sub sent to the subscriber if this + // fails, since it means the entire coordinator is being torn down. + return err + } + return nil } func (c *pgCoord) removeSubscription(q agpl.Queue, agentID uuid.UUID) error { - err := sendCtx(c.ctx, c.newSubscriptions, subscribe{ + sub := subscribe{ sKey: sKey{ clientID: q.UniqueID(), agentID: agentID, }, q: q, active: false, - }) - if err != nil { + } + if err := sendCtx(c.ctx, c.subscriberCh, sub); err != nil { return err } + if err := sendCtx(c.ctx, c.querierSubCh, sub); err != nil { + // There's no need to clean up the sub sent to the subscriber if this + // fails, since it means the entire coordinator is being torn down. + return err + } + return nil } @@ -247,6 +262,7 @@ func (c *pgCoord) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) erro // can only be a context error, no need to log here. return err } + defer func() { _ = sendCtx(c.ctx, c.closeConnections, agpl.Queue(cIO)) }() if err := c.addSubscription(cIO, agent); err != nil { return err @@ -271,6 +287,8 @@ func (c *pgCoord) ServeAgent(conn net.Conn, id uuid.UUID, name string) error { // can only be a context error, no need to log here. return err } + defer func() { _ = sendCtx(c.ctx, c.closeConnections, agpl.Queue(cIO)) }() + <-cIO.ctx.Done() return nil } @@ -311,7 +329,6 @@ type subscriber struct { coordinatorID uuid.UUID store database.Store subscriptions <-chan subscribe - querierCh chan<- subscribe mu sync.Mutex // map[clientID]map[agentID]subscribe @@ -324,7 +341,6 @@ func newSubscriber(ctx context.Context, id uuid.UUID, store database.Store, subscriptions <-chan subscribe, - querierCh chan<- subscribe, startWorkers <-chan struct{}, ) *subscriber { s := &subscriber{ @@ -333,7 +349,6 @@ func newSubscriber(ctx context.Context, coordinatorID: id, store: store, subscriptions: subscriptions, - querierCh: querierCh, latest: make(map[uuid.UUID]map[uuid.UUID]subscribe), workQ: newWorkQ[sKey](ctx), } @@ -356,7 +371,6 @@ func (s *subscriber) handleSubscriptions() { case sub := <-s.subscriptions: s.storeSubscription(sub) s.workQ.enqueue(sub.sKey) - s.querierCh <- sub } } } @@ -780,8 +794,9 @@ type querier struct { pubsub pubsub.Pubsub store database.Store - newConnections chan agpl.Queue - subscriptions chan subscribe + newConnections chan agpl.Queue + closeConnections chan agpl.Queue + subscriptions chan subscribe workQ *workQ[mKey] @@ -810,6 +825,7 @@ func newQuerier(ctx context.Context, store database.Store, self uuid.UUID, newConnections chan agpl.Queue, + closeConnections chan agpl.Queue, subscriptions chan subscribe, numWorkers int, firstHeartbeat chan struct{}, @@ -822,6 +838,7 @@ func newQuerier(ctx context.Context, pubsub: ps, store: store, newConnections: newConnections, + closeConnections: closeConnections, subscriptions: subscriptions, workQ: newWorkQ[mKey](ctx), heartbeats: newHeartbeats(ctx, logger, ps, store, self, updates, firstHeartbeat), @@ -860,6 +877,9 @@ func (q *querier) handleIncoming() { panic(fmt.Sprint("unreachable: invalid queue kind ", c.Kind())) } + case c := <-q.closeConnections: + q.cleanupConn(c) + case sub := <-q.subscriptions: if sub.active { q.newClientSubscription(sub.q, sub.agentID) @@ -903,7 +923,6 @@ func (q *querier) newAgentConn(c agpl.Queue) { } cm.count++ q.conns[c.UniqueID()] = c - go q.waitCleanupConn(c) } func (q *querier) newClientSubscription(c agpl.Queue, agentID uuid.UUID) { @@ -981,12 +1000,6 @@ func (q *querier) newClientConn(c agpl.Queue) { } q.conns[c.UniqueID()] = c - go q.waitCleanupConn(c) -} - -func (q *querier) waitCleanupConn(c agpl.Queue) { - <-c.Done() - q.cleanupConn(c) } func (q *querier) cleanupConn(c agpl.Queue) { From 036094fbf27a1aa171468b480b00653abb50dfd6 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Wed, 20 Sep 2023 15:38:08 +0000 Subject: [PATCH 20/21] fixup! use dedicated channels for querier subscribe and closing conns --- enterprise/tailnet/pgcoord.go | 1 + 1 file changed, 1 insertion(+) diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 32595005320b2..5e3f6b2f12205 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -267,6 +267,7 @@ func (c *pgCoord) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) erro if err := c.addSubscription(cIO, agent); err != nil { return err } + defer func() { _ = c.removeSubscription(cIO, agent) }() <-cIO.ctx.Done() return nil From e8a2b0104a1943ca2bd59ddde56f23d2b059b008 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Thu, 21 Sep 2023 18:56:26 +0000 Subject: [PATCH 21/21] fixup! use dedicated channels for querier subscribe and closing conns --- coderd/database/dbtestutil/db.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/coderd/database/dbtestutil/db.go b/coderd/database/dbtestutil/db.go index 84c9cee3ee224..c6ebbcee35b1a 100644 --- a/coderd/database/dbtestutil/db.go +++ b/coderd/database/dbtestutil/db.go @@ -19,7 +19,6 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbfake" - "github.com/coder/coder/v2/coderd/database/migrations" "github.com/coder/coder/v2/coderd/database/postgres" "github.com/coder/coder/v2/coderd/database/pubsub" ) @@ -94,9 +93,6 @@ func NewDB(t testing.TB, opts ...Option) (database.Store, pubsub.Pubsub) { } db = database.New(sqlDB) - err = migrations.Up(sqlDB) - require.NoError(t, err) - ps, err = pubsub.New(context.Background(), sqlDB, connectionURL) require.NoError(t, err) t.Cleanup(func() {