diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 8eb3a44560a6a..eb944edb6ee54 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -694,6 +694,13 @@ func (q *querier) DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUID) e return q.db.DeleteAPIKeysByUserID(ctx, userID) } +func (q *querier) DeleteAllTailnetClientSubscriptions(ctx context.Context, arg database.DeleteAllTailnetClientSubscriptionsParams) error { + if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil { + return err + } + return q.db.DeleteAllTailnetClientSubscriptions(ctx, arg) +} + func (q *querier) DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error { // TODO: This is not 100% correct because it omits apikey IDs. err := q.authorizeContext(ctx, rbac.ActionDelete, @@ -783,6 +790,13 @@ func (q *querier) DeleteTailnetClient(ctx context.Context, arg database.DeleteTa return q.db.DeleteTailnetClient(ctx, arg) } +func (q *querier) DeleteTailnetClientSubscription(ctx context.Context, arg database.DeleteTailnetClientSubscriptionParams) error { + if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil { + return err + } + return q.db.DeleteTailnetClientSubscription(ctx, arg) +} + func (q *querier) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) { return fetch(q.log, q.auth, q.db.GetAPIKeyByID)(ctx, id) } @@ -825,9 +839,9 @@ func (q *querier) GetAllTailnetAgents(ctx context.Context) ([]database.TailnetAg return q.db.GetAllTailnetAgents(ctx) } -func (q *querier) GetAllTailnetClients(ctx context.Context) ([]database.TailnetClient, error) { +func (q *querier) GetAllTailnetClients(ctx context.Context) ([]database.GetAllTailnetClientsRow, error) { if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTailnetCoordinator); err != nil { - return []database.TailnetClient{}, err + return []database.GetAllTailnetClientsRow{}, err } return q.db.GetAllTailnetClients(ctx) } @@ -2794,6 +2808,13 @@ func (q *querier) UpsertTailnetClient(ctx context.Context, arg database.UpsertTa return q.db.UpsertTailnetClient(ctx, arg) } +func (q *querier) UpsertTailnetClientSubscription(ctx context.Context, arg database.UpsertTailnetClientSubscriptionParams) error { + if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil { + return err + } + return q.db.UpsertTailnetClientSubscription(ctx, arg) +} + func (q *querier) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (database.TailnetCoordinator, error) { if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil { return database.TailnetCoordinator{}, err diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index 7d44c634ee49c..9e9d4ab61e8e8 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -854,6 +854,15 @@ func (q *FakeQuerier) DeleteAPIKeysByUserID(_ context.Context, userID uuid.UUID) return nil } +func (*FakeQuerier) DeleteAllTailnetClientSubscriptions(_ context.Context, arg database.DeleteAllTailnetClientSubscriptionsParams) error { + err := validateDatabaseType(arg) + if err != nil { + return err + } + + return ErrUnimplemented +} + func (q *FakeQuerier) DeleteApplicationConnectAPIKeysByUserID(_ context.Context, userID uuid.UUID) error { q.mutex.Lock() defer q.mutex.Unlock() @@ -987,6 +996,10 @@ func (*FakeQuerier) 
DeleteTailnetClient(context.Context, database.DeleteTailnetC return database.DeleteTailnetClientRow{}, ErrUnimplemented } +func (*FakeQuerier) DeleteTailnetClientSubscription(context.Context, database.DeleteTailnetClientSubscriptionParams) error { + return ErrUnimplemented +} + func (q *FakeQuerier) GetAPIKeyByID(_ context.Context, id string) (database.APIKey, error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -1102,7 +1115,7 @@ func (*FakeQuerier) GetAllTailnetAgents(_ context.Context) ([]database.TailnetAg return nil, ErrUnimplemented } -func (*FakeQuerier) GetAllTailnetClients(_ context.Context) ([]database.TailnetClient, error) { +func (*FakeQuerier) GetAllTailnetClients(_ context.Context) ([]database.GetAllTailnetClientsRow, error) { return nil, ErrUnimplemented } @@ -6112,6 +6125,10 @@ func (*FakeQuerier) UpsertTailnetClient(context.Context, database.UpsertTailnetC return database.TailnetClient{}, ErrUnimplemented } +func (*FakeQuerier) UpsertTailnetClientSubscription(context.Context, database.UpsertTailnetClientSubscriptionParams) error { + return ErrUnimplemented +} + func (*FakeQuerier) UpsertTailnetCoordinator(context.Context, uuid.UUID) (database.TailnetCoordinator, error) { return database.TailnetCoordinator{}, ErrUnimplemented } diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index a1210ad7435eb..2e0210d707b61 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -128,6 +128,13 @@ func (m metricsStore) DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUI return err } +func (m metricsStore) DeleteAllTailnetClientSubscriptions(ctx context.Context, arg database.DeleteAllTailnetClientSubscriptionsParams) error { + start := time.Now() + r0 := m.s.DeleteAllTailnetClientSubscriptions(ctx, arg) + m.queryLatencies.WithLabelValues("DeleteAllTailnetClientSubscriptions").Observe(time.Since(start).Seconds()) + return r0 +} + func (m metricsStore) DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error { start := time.Now() err := m.s.DeleteApplicationConnectAPIKeysByUserID(ctx, userID) @@ -209,6 +216,13 @@ func (m metricsStore) DeleteTailnetClient(ctx context.Context, arg database.Dele return m.s.DeleteTailnetClient(ctx, arg) } +func (m metricsStore) DeleteTailnetClientSubscription(ctx context.Context, arg database.DeleteTailnetClientSubscriptionParams) error { + start := time.Now() + r0 := m.s.DeleteTailnetClientSubscription(ctx, arg) + m.queryLatencies.WithLabelValues("DeleteTailnetClientSubscription").Observe(time.Since(start).Seconds()) + return r0 +} + func (m metricsStore) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) { start := time.Now() apiKey, err := m.s.GetAPIKeyByID(ctx, id) @@ -265,7 +279,7 @@ func (m metricsStore) GetAllTailnetAgents(ctx context.Context) ([]database.Tailn return r0, r1 } -func (m metricsStore) GetAllTailnetClients(ctx context.Context) ([]database.TailnetClient, error) { +func (m metricsStore) GetAllTailnetClients(ctx context.Context) ([]database.GetAllTailnetClientsRow, error) { start := time.Now() r0, r1 := m.s.GetAllTailnetClients(ctx) m.queryLatencies.WithLabelValues("GetAllTailnetClients").Observe(time.Since(start).Seconds()) @@ -1752,6 +1766,13 @@ func (m metricsStore) UpsertTailnetClient(ctx context.Context, arg database.Upse return m.s.UpsertTailnetClient(ctx, arg) } +func (m metricsStore) UpsertTailnetClientSubscription(ctx context.Context, arg database.UpsertTailnetClientSubscriptionParams) error { + 
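+ // Same pattern as the other metricsStore wrappers: time the underlying store call and record the latency under the "UpsertTailnetClientSubscription" query label.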
start := time.Now() + r0 := m.s.UpsertTailnetClientSubscription(ctx, arg) + m.queryLatencies.WithLabelValues("UpsertTailnetClientSubscription").Observe(time.Since(start).Seconds()) + return r0 +} + func (m metricsStore) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (database.TailnetCoordinator, error) { start := time.Now() defer m.queryLatencies.WithLabelValues("UpsertTailnetCoordinator").Observe(time.Since(start).Seconds()) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 3a791a601383a..9a7113321cc46 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -139,6 +139,20 @@ func (mr *MockStoreMockRecorder) DeleteAPIKeysByUserID(arg0, arg1 interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAPIKeysByUserID", reflect.TypeOf((*MockStore)(nil).DeleteAPIKeysByUserID), arg0, arg1) } +// DeleteAllTailnetClientSubscriptions mocks base method. +func (m *MockStore) DeleteAllTailnetClientSubscriptions(arg0 context.Context, arg1 database.DeleteAllTailnetClientSubscriptionsParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteAllTailnetClientSubscriptions", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteAllTailnetClientSubscriptions indicates an expected call of DeleteAllTailnetClientSubscriptions. +func (mr *MockStoreMockRecorder) DeleteAllTailnetClientSubscriptions(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllTailnetClientSubscriptions", reflect.TypeOf((*MockStore)(nil).DeleteAllTailnetClientSubscriptions), arg0, arg1) +} + // DeleteApplicationConnectAPIKeysByUserID mocks base method. func (m *MockStore) DeleteApplicationConnectAPIKeysByUserID(arg0 context.Context, arg1 uuid.UUID) error { m.ctrl.T.Helper() @@ -310,6 +324,20 @@ func (mr *MockStoreMockRecorder) DeleteTailnetClient(arg0, arg1 interface{}) *go return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTailnetClient", reflect.TypeOf((*MockStore)(nil).DeleteTailnetClient), arg0, arg1) } +// DeleteTailnetClientSubscription mocks base method. +func (m *MockStore) DeleteTailnetClientSubscription(arg0 context.Context, arg1 database.DeleteTailnetClientSubscriptionParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteTailnetClientSubscription", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteTailnetClientSubscription indicates an expected call of DeleteTailnetClientSubscription. +func (mr *MockStoreMockRecorder) DeleteTailnetClientSubscription(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTailnetClientSubscription", reflect.TypeOf((*MockStore)(nil).DeleteTailnetClientSubscription), arg0, arg1) +} + // GetAPIKeyByID mocks base method. func (m *MockStore) GetAPIKeyByID(arg0 context.Context, arg1 string) (database.APIKey, error) { m.ctrl.T.Helper() @@ -431,10 +459,10 @@ func (mr *MockStoreMockRecorder) GetAllTailnetAgents(arg0 interface{}) *gomock.C } // GetAllTailnetClients mocks base method. 
-func (m *MockStore) GetAllTailnetClients(arg0 context.Context) ([]database.TailnetClient, error) { +func (m *MockStore) GetAllTailnetClients(arg0 context.Context) ([]database.GetAllTailnetClientsRow, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetAllTailnetClients", arg0) - ret0, _ := ret[0].([]database.TailnetClient) + ret0, _ := ret[0].([]database.GetAllTailnetClientsRow) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -3681,6 +3709,20 @@ func (mr *MockStoreMockRecorder) UpsertTailnetClient(arg0, arg1 interface{}) *go return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertTailnetClient", reflect.TypeOf((*MockStore)(nil).UpsertTailnetClient), arg0, arg1) } +// UpsertTailnetClientSubscription mocks base method. +func (m *MockStore) UpsertTailnetClientSubscription(arg0 context.Context, arg1 database.UpsertTailnetClientSubscriptionParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpsertTailnetClientSubscription", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpsertTailnetClientSubscription indicates an expected call of UpsertTailnetClientSubscription. +func (mr *MockStoreMockRecorder) UpsertTailnetClientSubscription(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertTailnetClientSubscription", reflect.TypeOf((*MockStore)(nil).UpsertTailnetClientSubscription), arg0, arg1) +} + // UpsertTailnetCoordinator mocks base method. func (m *MockStore) UpsertTailnetCoordinator(arg0 context.Context, arg1 uuid.UUID) (database.TailnetCoordinator, error) { m.ctrl.T.Helper() diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 995fd93739d74..644349bb58cae 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -219,13 +219,57 @@ $$; CREATE FUNCTION tailnet_notify_client_change() RETURNS trigger LANGUAGE plpgsql AS $$ +DECLARE + var_client_id uuid; + var_coordinator_id uuid; + var_agent_ids uuid[]; + var_agent_id uuid; BEGIN - IF (OLD IS NOT NULL) THEN - PERFORM pg_notify('tailnet_client_update', OLD.id || ',' || OLD.agent_id); - RETURN NULL; + IF (NEW.id IS NOT NULL) THEN + var_client_id = NEW.id; + var_coordinator_id = NEW.coordinator_id; + ELSIF (OLD.id IS NOT NULL) THEN + var_client_id = OLD.id; + var_coordinator_id = OLD.coordinator_id; END IF; + + -- Read all agents the client is subscribed to, so we can notify them. + SELECT + array_agg(agent_id) + INTO + var_agent_ids + FROM + tailnet_client_subscriptions subs + WHERE + subs.client_id = NEW.id AND + subs.coordinator_id = NEW.coordinator_id; + + -- No agents to notify + if (var_agent_ids IS NULL) THEN + return NULL; + END IF; + + -- pg_notify is limited to 8k bytes, which is approximately 221 UUIDs. + -- Instead of sending all agent ids in a single update, send one for each + -- agent id to prevent overflow. 
+ FOREACH var_agent_id IN ARRAY var_agent_ids + LOOP + PERFORM pg_notify('tailnet_client_update', var_client_id || ',' || var_agent_id); + END LOOP; + + return NULL; +END; +$$; + +CREATE FUNCTION tailnet_notify_client_subscription_change() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN IF (NEW IS NOT NULL) THEN - PERFORM pg_notify('tailnet_client_update', NEW.id || ',' || NEW.agent_id); + PERFORM pg_notify('tailnet_client_update', NEW.client_id || ',' || NEW.agent_id); + RETURN NULL; + ELSIF (OLD IS NOT NULL) THEN + PERFORM pg_notify('tailnet_client_update', OLD.client_id || ',' || OLD.agent_id); RETURN NULL; END IF; END; @@ -495,10 +539,16 @@ CREATE TABLE tailnet_agents ( node jsonb NOT NULL ); +CREATE TABLE tailnet_client_subscriptions ( + client_id uuid NOT NULL, + coordinator_id uuid NOT NULL, + agent_id uuid NOT NULL, + updated_at timestamp with time zone NOT NULL +); + CREATE TABLE tailnet_clients ( id uuid NOT NULL, coordinator_id uuid NOT NULL, - agent_id uuid NOT NULL, updated_at timestamp with time zone NOT NULL, node jsonb NOT NULL ); @@ -1144,6 +1194,9 @@ ALTER TABLE ONLY site_configs ALTER TABLE ONLY tailnet_agents ADD CONSTRAINT tailnet_agents_pkey PRIMARY KEY (id, coordinator_id); +ALTER TABLE ONLY tailnet_client_subscriptions + ADD CONSTRAINT tailnet_client_subscriptions_pkey PRIMARY KEY (client_id, coordinator_id, agent_id); + ALTER TABLE ONLY tailnet_clients ADD CONSTRAINT tailnet_clients_pkey PRIMARY KEY (id, coordinator_id); @@ -1248,8 +1301,6 @@ CREATE UNIQUE INDEX idx_organization_name_lower ON organizations USING btree (lo CREATE INDEX idx_tailnet_agents_coordinator ON tailnet_agents USING btree (coordinator_id); -CREATE INDEX idx_tailnet_clients_agent ON tailnet_clients USING btree (agent_id); - CREATE INDEX idx_tailnet_clients_coordinator ON tailnet_clients USING btree (coordinator_id); CREATE UNIQUE INDEX idx_users_email ON users USING btree (email) WHERE (deleted = false); @@ -1284,6 +1335,8 @@ CREATE TRIGGER tailnet_notify_agent_change AFTER INSERT OR DELETE OR UPDATE ON t CREATE TRIGGER tailnet_notify_client_change AFTER INSERT OR DELETE OR UPDATE ON tailnet_clients FOR EACH ROW EXECUTE FUNCTION tailnet_notify_client_change(); +CREATE TRIGGER tailnet_notify_client_subscription_change AFTER INSERT OR DELETE OR UPDATE ON tailnet_client_subscriptions FOR EACH ROW EXECUTE FUNCTION tailnet_notify_client_subscription_change(); + CREATE TRIGGER tailnet_notify_coordinator_heartbeat AFTER INSERT OR UPDATE ON tailnet_coordinators FOR EACH ROW EXECUTE FUNCTION tailnet_notify_coordinator_heartbeat(); CREATE TRIGGER trigger_insert_apikeys BEFORE INSERT ON api_keys FOR EACH ROW EXECUTE FUNCTION insert_apikey_fail_if_user_deleted(); @@ -1329,6 +1382,9 @@ ALTER TABLE ONLY provisioner_jobs ALTER TABLE ONLY tailnet_agents ADD CONSTRAINT tailnet_agents_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE; +ALTER TABLE ONLY tailnet_client_subscriptions + ADD CONSTRAINT tailnet_client_subscriptions_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE; + ALTER TABLE ONLY tailnet_clients ADD CONSTRAINT tailnet_clients_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE; diff --git a/coderd/database/foreign_key_constraint.go b/coderd/database/foreign_key_constraint.go index db2021166f621..5f3b9aa5c32b3 100644 --- a/coderd/database/foreign_key_constraint.go +++ b/coderd/database/foreign_key_constraint.go @@ -19,6 +19,7 @@ const ( 
ForeignKeyProvisionerJobLogsJobID ForeignKeyConstraint = "provisioner_job_logs_job_id_fkey" // ALTER TABLE ONLY provisioner_job_logs ADD CONSTRAINT provisioner_job_logs_job_id_fkey FOREIGN KEY (job_id) REFERENCES provisioner_jobs(id) ON DELETE CASCADE; ForeignKeyProvisionerJobsOrganizationID ForeignKeyConstraint = "provisioner_jobs_organization_id_fkey" // ALTER TABLE ONLY provisioner_jobs ADD CONSTRAINT provisioner_jobs_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE; ForeignKeyTailnetAgentsCoordinatorID ForeignKeyConstraint = "tailnet_agents_coordinator_id_fkey" // ALTER TABLE ONLY tailnet_agents ADD CONSTRAINT tailnet_agents_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE; + ForeignKeyTailnetClientSubscriptionsCoordinatorID ForeignKeyConstraint = "tailnet_client_subscriptions_coordinator_id_fkey" // ALTER TABLE ONLY tailnet_client_subscriptions ADD CONSTRAINT tailnet_client_subscriptions_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE; ForeignKeyTailnetClientsCoordinatorID ForeignKeyConstraint = "tailnet_clients_coordinator_id_fkey" // ALTER TABLE ONLY tailnet_clients ADD CONSTRAINT tailnet_clients_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE; ForeignKeyTemplateVersionParametersTemplateVersionID ForeignKeyConstraint = "template_version_parameters_template_version_id_fkey" // ALTER TABLE ONLY template_version_parameters ADD CONSTRAINT template_version_parameters_template_version_id_fkey FOREIGN KEY (template_version_id) REFERENCES template_versions(id) ON DELETE CASCADE; ForeignKeyTemplateVersionVariablesTemplateVersionID ForeignKeyConstraint = "template_version_variables_template_version_id_fkey" // ALTER TABLE ONLY template_version_variables ADD CONSTRAINT template_version_variables_template_version_id_fkey FOREIGN KEY (template_version_id) REFERENCES template_versions(id) ON DELETE CASCADE; diff --git a/coderd/database/migrations/000156_pg_coordinator_single_tailnet.down.sql b/coderd/database/migrations/000156_pg_coordinator_single_tailnet.down.sql new file mode 100644 index 0000000000000..7cc418489f59a --- /dev/null +++ b/coderd/database/migrations/000156_pg_coordinator_single_tailnet.down.sql @@ -0,0 +1,39 @@ +BEGIN; + +ALTER TABLE + tailnet_clients +ADD COLUMN + agent_id uuid; + +UPDATE + tailnet_clients +SET + -- there's no reason for us to try and preserve data since coordinators will + -- have to restart anyways, which will create all of the client mappings. + agent_id = '00000000-0000-0000-0000-000000000000'::uuid; + +ALTER TABLE + tailnet_clients +ALTER COLUMN + agent_id SET NOT NULL; + +DROP TABLE tailnet_client_subscriptions; +DROP FUNCTION tailnet_notify_client_subscription_change; + +-- update the tailnet_clients trigger to the old version. 
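+-- The restored trigger below notifies on the re-added agent_id column directly, since the subscriptions table no longer exists once this down migration runs.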
+CREATE OR REPLACE FUNCTION tailnet_notify_client_change() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + IF (OLD IS NOT NULL) THEN + PERFORM pg_notify('tailnet_client_update', OLD.id || ',' || OLD.agent_id); + RETURN NULL; + END IF; + IF (NEW IS NOT NULL) THEN + PERFORM pg_notify('tailnet_client_update', NEW.id || ',' || NEW.agent_id); + RETURN NULL; + END IF; +END; +$$; + +COMMIT; diff --git a/coderd/database/migrations/000156_pg_coordinator_single_tailnet.up.sql b/coderd/database/migrations/000156_pg_coordinator_single_tailnet.up.sql new file mode 100644 index 0000000000000..4ca218248ef4a --- /dev/null +++ b/coderd/database/migrations/000156_pg_coordinator_single_tailnet.up.sql @@ -0,0 +1,88 @@ +BEGIN; + +CREATE TABLE tailnet_client_subscriptions ( + client_id uuid NOT NULL, + coordinator_id uuid NOT NULL, + -- this isn't a foreign key since it's more of a list of agents the client + -- *wants* to connect to, and they don't necessarily have to currently + -- exist in the db. + agent_id uuid NOT NULL, + updated_at timestamp with time zone NOT NULL, + PRIMARY KEY (client_id, coordinator_id, agent_id), + FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators (id) ON DELETE CASCADE + -- we don't keep a foreign key to the tailnet_clients table since there's + -- not a great way to guarantee that a subscription is always added after + -- the client is inserted. clients are only created after the client sends + -- its first node update, which can take an undetermined amount of time. +); + +CREATE FUNCTION tailnet_notify_client_subscription_change() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + IF (NEW IS NOT NULL) THEN + PERFORM pg_notify('tailnet_client_update', NEW.client_id || ',' || NEW.agent_id); + RETURN NULL; + ELSIF (OLD IS NOT NULL) THEN + PERFORM pg_notify('tailnet_client_update', OLD.client_id || ',' || OLD.agent_id); + RETURN NULL; + END IF; +END; +$$; + +CREATE TRIGGER tailnet_notify_client_subscription_change + AFTER INSERT OR UPDATE OR DELETE ON tailnet_client_subscriptions + FOR EACH ROW +EXECUTE PROCEDURE tailnet_notify_client_subscription_change(); + +CREATE OR REPLACE FUNCTION tailnet_notify_client_change() RETURNS trigger + LANGUAGE plpgsql + AS $$ +DECLARE + var_client_id uuid; + var_coordinator_id uuid; + var_agent_ids uuid[]; + var_agent_id uuid; +BEGIN + IF (NEW.id IS NOT NULL) THEN + var_client_id = NEW.id; + var_coordinator_id = NEW.coordinator_id; + ELSIF (OLD.id IS NOT NULL) THEN + var_client_id = OLD.id; + var_coordinator_id = OLD.coordinator_id; + END IF; + + -- Read all agents the client is subscribed to, so we can notify them. + SELECT + array_agg(agent_id) + INTO + var_agent_ids + FROM + tailnet_client_subscriptions subs + WHERE + subs.client_id = NEW.id AND + subs.coordinator_id = NEW.coordinator_id; + + -- No agents to notify + if (var_agent_ids IS NULL) THEN + return NULL; + END IF; + + -- pg_notify is limited to 8k bytes, which is approximately 221 UUIDs. + -- Instead of sending all agent ids in a single update, send one for each + -- agent id to prevent overflow. 
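+  -- (A UUID in text form is 36 bytes, so 8000 / 36 is roughly 222 entries before separators, which is where the ~221 figure above comes from.)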
+ FOREACH var_agent_id IN ARRAY var_agent_ids + LOOP + PERFORM pg_notify('tailnet_client_update', var_client_id || ',' || var_agent_id); + END LOOP; + + return NULL; +END; +$$; + +ALTER TABLE + tailnet_clients +DROP COLUMN + agent_id; + +COMMIT; diff --git a/coderd/database/migrations/testdata/fixtures/000130_ha_coordinator.up.sql b/coderd/database/migrations/testdata/fixtures/000130_ha_coordinator.up.sql index 8af4fa4827997..dbebd6d5dd384 100644 --- a/coderd/database/migrations/testdata/fixtures/000130_ha_coordinator.up.sql +++ b/coderd/database/migrations/testdata/fixtures/000130_ha_coordinator.up.sql @@ -18,7 +18,7 @@ VALUES ); INSERT INTO tailnet_agents -(id, coordinator_id, updated_at, node) + (id, coordinator_id, updated_at, node) VALUES ( 'c0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11', diff --git a/coderd/database/migrations/testdata/fixtures/000156_pg_coordinator_single_tailnet.up.sql b/coderd/database/migrations/testdata/fixtures/000156_pg_coordinator_single_tailnet.up.sql new file mode 100644 index 0000000000000..b5b744d6d1dc8 --- /dev/null +++ b/coderd/database/migrations/testdata/fixtures/000156_pg_coordinator_single_tailnet.up.sql @@ -0,0 +1,9 @@ +INSERT INTO tailnet_client_subscriptions + (client_id, agent_id, coordinator_id, updated_at) +VALUES + ( + 'b0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11', + 'c0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11', + 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11', + '2023-06-15 10:23:54+00' + ); diff --git a/coderd/database/models.go b/coderd/database/models.go index 4d1852a54114e..aaa141bca68ff 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -1783,11 +1783,17 @@ type TailnetAgent struct { type TailnetClient struct { ID uuid.UUID `db:"id" json:"id"` CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` - AgentID uuid.UUID `db:"agent_id" json:"agent_id"` UpdatedAt time.Time `db:"updated_at" json:"updated_at"` Node json.RawMessage `db:"node" json:"node"` } +type TailnetClientSubscription struct { + ClientID uuid.UUID `db:"client_id" json:"client_id"` + CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` + AgentID uuid.UUID `db:"agent_id" json:"agent_id"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` +} + // We keep this separate from replicas in case we need to break the coordinator out into its own service type TailnetCoordinator struct { ID uuid.UUID `db:"id" json:"id"` diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 483976c0355d7..07d3ce26aa6c9 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -36,6 +36,7 @@ type sqlcQuerier interface { CleanTailnetCoordinators(ctx context.Context) error DeleteAPIKeyByID(ctx context.Context, id string) error DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error + DeleteAllTailnetClientSubscriptions(ctx context.Context, arg DeleteAllTailnetClientSubscriptionsParams) error DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error DeleteCoordinator(ctx context.Context, id uuid.UUID) error DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error @@ -50,6 +51,7 @@ type sqlcQuerier interface { DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt time.Time) error DeleteTailnetAgent(ctx context.Context, arg DeleteTailnetAgentParams) (DeleteTailnetAgentRow, error) DeleteTailnetClient(ctx context.Context, arg DeleteTailnetClientParams) (DeleteTailnetClientRow, error) + DeleteTailnetClientSubscription(ctx context.Context, arg DeleteTailnetClientSubscriptionParams) error 
GetAPIKeyByID(ctx context.Context, id string) (APIKey, error) // there is no unique constraint on empty token names GetAPIKeyByName(ctx context.Context, arg GetAPIKeyByNameParams) (APIKey, error) @@ -59,7 +61,7 @@ type sqlcQuerier interface { GetActiveUserCount(ctx context.Context) (int64, error) GetActiveWorkspaceBuildsByTemplateID(ctx context.Context, templateID uuid.UUID) ([]WorkspaceBuild, error) GetAllTailnetAgents(ctx context.Context) ([]TailnetAgent, error) - GetAllTailnetClients(ctx context.Context) ([]TailnetClient, error) + GetAllTailnetClients(ctx context.Context) ([]GetAllTailnetClientsRow, error) GetAppSecurityKey(ctx context.Context) (string, error) // GetAuditLogsBefore retrieves `row_limit` number of audit logs before the provided // ID. @@ -324,6 +326,7 @@ type sqlcQuerier interface { UpsertServiceBanner(ctx context.Context, value string) error UpsertTailnetAgent(ctx context.Context, arg UpsertTailnetAgentParams) (TailnetAgent, error) UpsertTailnetClient(ctx context.Context, arg UpsertTailnetClientParams) (TailnetClient, error) + UpsertTailnetClientSubscription(ctx context.Context, arg UpsertTailnetClientSubscriptionParams) error UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (TailnetCoordinator, error) } diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 17b12293e0c08..b63f86ce288b3 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -4131,6 +4131,22 @@ func (q *sqlQuerier) CleanTailnetCoordinators(ctx context.Context) error { return err } +const deleteAllTailnetClientSubscriptions = `-- name: DeleteAllTailnetClientSubscriptions :exec +DELETE +FROM tailnet_client_subscriptions +WHERE client_id = $1 and coordinator_id = $2 +` + +type DeleteAllTailnetClientSubscriptionsParams struct { + ClientID uuid.UUID `db:"client_id" json:"client_id"` + CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` +} + +func (q *sqlQuerier) DeleteAllTailnetClientSubscriptions(ctx context.Context, arg DeleteAllTailnetClientSubscriptionsParams) error { + _, err := q.db.ExecContext(ctx, deleteAllTailnetClientSubscriptions, arg.ClientID, arg.CoordinatorID) + return err +} + const deleteCoordinator = `-- name: DeleteCoordinator :exec DELETE FROM tailnet_coordinators @@ -4190,6 +4206,23 @@ func (q *sqlQuerier) DeleteTailnetClient(ctx context.Context, arg DeleteTailnetC return i, err } +const deleteTailnetClientSubscription = `-- name: DeleteTailnetClientSubscription :exec +DELETE +FROM tailnet_client_subscriptions +WHERE client_id = $1 and agent_id = $2 and coordinator_id = $3 +` + +type DeleteTailnetClientSubscriptionParams struct { + ClientID uuid.UUID `db:"client_id" json:"client_id"` + AgentID uuid.UUID `db:"agent_id" json:"agent_id"` + CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` +} + +func (q *sqlQuerier) DeleteTailnetClientSubscription(ctx context.Context, arg DeleteTailnetClientSubscriptionParams) error { + _, err := q.db.ExecContext(ctx, deleteTailnetClientSubscription, arg.ClientID, arg.AgentID, arg.CoordinatorID) + return err +} + const getAllTailnetAgents = `-- name: GetAllTailnetAgents :many SELECT id, coordinator_id, updated_at, node FROM tailnet_agents @@ -4224,26 +4257,32 @@ func (q *sqlQuerier) GetAllTailnetAgents(ctx context.Context) ([]TailnetAgent, e } const getAllTailnetClients = `-- name: GetAllTailnetClients :many -SELECT id, coordinator_id, agent_id, updated_at, node +SELECT tailnet_clients.id, tailnet_clients.coordinator_id, tailnet_clients.updated_at, 
tailnet_clients.node, array_agg(tailnet_client_subscriptions.agent_id)::uuid[] as agent_ids FROM tailnet_clients -ORDER BY agent_id +LEFT JOIN tailnet_client_subscriptions +ON tailnet_clients.id = tailnet_client_subscriptions.client_id ` -func (q *sqlQuerier) GetAllTailnetClients(ctx context.Context) ([]TailnetClient, error) { +type GetAllTailnetClientsRow struct { + TailnetClient TailnetClient `db:"tailnet_client" json:"tailnet_client"` + AgentIds []uuid.UUID `db:"agent_ids" json:"agent_ids"` +} + +func (q *sqlQuerier) GetAllTailnetClients(ctx context.Context) ([]GetAllTailnetClientsRow, error) { rows, err := q.db.QueryContext(ctx, getAllTailnetClients) if err != nil { return nil, err } defer rows.Close() - var items []TailnetClient + var items []GetAllTailnetClientsRow for rows.Next() { - var i TailnetClient + var i GetAllTailnetClientsRow if err := rows.Scan( - &i.ID, - &i.CoordinatorID, - &i.AgentID, - &i.UpdatedAt, - &i.Node, + &i.TailnetClient.ID, + &i.TailnetClient.CoordinatorID, + &i.TailnetClient.UpdatedAt, + &i.TailnetClient.Node, + pq.Array(&i.AgentIds), ); err != nil { return nil, err } @@ -4293,9 +4332,13 @@ func (q *sqlQuerier) GetTailnetAgents(ctx context.Context, id uuid.UUID) ([]Tail } const getTailnetClientsForAgent = `-- name: GetTailnetClientsForAgent :many -SELECT id, coordinator_id, agent_id, updated_at, node +SELECT id, coordinator_id, updated_at, node FROM tailnet_clients -WHERE agent_id = $1 +WHERE id IN ( + SELECT tailnet_client_subscriptions.client_id + FROM tailnet_client_subscriptions + WHERE tailnet_client_subscriptions.agent_id = $1 +) ` func (q *sqlQuerier) GetTailnetClientsForAgent(ctx context.Context, agentID uuid.UUID) ([]TailnetClient, error) { @@ -4310,7 +4353,6 @@ func (q *sqlQuerier) GetTailnetClientsForAgent(ctx context.Context, agentID uuid if err := rows.Scan( &i.ID, &i.CoordinatorID, - &i.AgentID, &i.UpdatedAt, &i.Node, ); err != nil { @@ -4369,47 +4411,67 @@ INSERT INTO tailnet_clients ( id, coordinator_id, - agent_id, node, updated_at ) VALUES - ($1, $2, $3, $4, now() at time zone 'utc') + ($1, $2, $3, now() at time zone 'utc') ON CONFLICT (id, coordinator_id) DO UPDATE SET id = $1, coordinator_id = $2, - agent_id = $3, - node = $4, + node = $3, updated_at = now() at time zone 'utc' -RETURNING id, coordinator_id, agent_id, updated_at, node +RETURNING id, coordinator_id, updated_at, node ` type UpsertTailnetClientParams struct { ID uuid.UUID `db:"id" json:"id"` CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` - AgentID uuid.UUID `db:"agent_id" json:"agent_id"` Node json.RawMessage `db:"node" json:"node"` } func (q *sqlQuerier) UpsertTailnetClient(ctx context.Context, arg UpsertTailnetClientParams) (TailnetClient, error) { - row := q.db.QueryRowContext(ctx, upsertTailnetClient, - arg.ID, - arg.CoordinatorID, - arg.AgentID, - arg.Node, - ) + row := q.db.QueryRowContext(ctx, upsertTailnetClient, arg.ID, arg.CoordinatorID, arg.Node) var i TailnetClient err := row.Scan( &i.ID, &i.CoordinatorID, - &i.AgentID, &i.UpdatedAt, &i.Node, ) return i, err } +const upsertTailnetClientSubscription = `-- name: UpsertTailnetClientSubscription :exec +INSERT INTO + tailnet_client_subscriptions ( + client_id, + coordinator_id, + agent_id, + updated_at +) +VALUES + ($1, $2, $3, now() at time zone 'utc') +ON CONFLICT (client_id, coordinator_id, agent_id) +DO UPDATE SET + client_id = $1, + coordinator_id = $2, + agent_id = $3, + updated_at = now() at time zone 'utc' +` + +type UpsertTailnetClientSubscriptionParams struct { + ClientID uuid.UUID 
`db:"client_id" json:"client_id"` + CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` + AgentID uuid.UUID `db:"agent_id" json:"agent_id"` +} + +func (q *sqlQuerier) UpsertTailnetClientSubscription(ctx context.Context, arg UpsertTailnetClientSubscriptionParams) error { + _, err := q.db.ExecContext(ctx, upsertTailnetClientSubscription, arg.ClientID, arg.CoordinatorID, arg.AgentID) + return err +} + const upsertTailnetCoordinator = `-- name: UpsertTailnetCoordinator :one INSERT INTO tailnet_coordinators ( diff --git a/coderd/database/queries/tailnet.sql b/coderd/database/queries/tailnet.sql index fd2db296dfa54..16f8708f3210a 100644 --- a/coderd/database/queries/tailnet.sql +++ b/coderd/database/queries/tailnet.sql @@ -3,21 +3,36 @@ INSERT INTO tailnet_clients ( id, coordinator_id, - agent_id, node, updated_at ) VALUES - ($1, $2, $3, $4, now() at time zone 'utc') + ($1, $2, $3, now() at time zone 'utc') ON CONFLICT (id, coordinator_id) DO UPDATE SET id = $1, coordinator_id = $2, - agent_id = $3, - node = $4, + node = $3, updated_at = now() at time zone 'utc' RETURNING *; +-- name: UpsertTailnetClientSubscription :exec +INSERT INTO + tailnet_client_subscriptions ( + client_id, + coordinator_id, + agent_id, + updated_at +) +VALUES + ($1, $2, $3, now() at time zone 'utc') +ON CONFLICT (client_id, coordinator_id, agent_id) +DO UPDATE SET + client_id = $1, + coordinator_id = $2, + agent_id = $3, + updated_at = now() at time zone 'utc'; + -- name: UpsertTailnetAgent :one INSERT INTO tailnet_agents ( @@ -43,6 +58,16 @@ FROM tailnet_clients WHERE id = $1 and coordinator_id = $2 RETURNING id, coordinator_id; +-- name: DeleteTailnetClientSubscription :exec +DELETE +FROM tailnet_client_subscriptions +WHERE client_id = $1 and agent_id = $2 and coordinator_id = $3; + +-- name: DeleteAllTailnetClientSubscriptions :exec +DELETE +FROM tailnet_client_subscriptions +WHERE client_id = $1 and coordinator_id = $2; + -- name: DeleteTailnetAgent :one DELETE FROM tailnet_agents @@ -66,12 +91,17 @@ FROM tailnet_agents; -- name: GetTailnetClientsForAgent :many SELECT * FROM tailnet_clients -WHERE agent_id = $1; +WHERE id IN ( + SELECT tailnet_client_subscriptions.client_id + FROM tailnet_client_subscriptions + WHERE tailnet_client_subscriptions.agent_id = $1 +); -- name: GetAllTailnetClients :many -SELECT * +SELECT sqlc.embed(tailnet_clients), array_agg(tailnet_client_subscriptions.agent_id)::uuid[] as agent_ids FROM tailnet_clients -ORDER BY agent_id; +LEFT JOIN tailnet_client_subscriptions +ON tailnet_clients.id = tailnet_client_subscriptions.client_id; -- name: UpsertTailnetCoordinator :one INSERT INTO diff --git a/enterprise/coderd/workspaceproxycoordinate.go b/enterprise/coderd/workspaceproxycoordinate.go index ec454d73a870a..501095d44477e 100644 --- a/enterprise/coderd/workspaceproxycoordinate.go +++ b/enterprise/coderd/workspaceproxycoordinate.go @@ -68,6 +68,7 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request id := uuid.New() sub := (*api.AGPL.TailnetCoordinator.Load()).ServeMultiAgent(id) + ctx, nc := websocketNetConn(ctx, conn, websocket.MessageText) defer nc.Close() diff --git a/enterprise/tailnet/connio.go b/enterprise/tailnet/connio.go new file mode 100644 index 0000000000000..fed307758603e --- /dev/null +++ b/enterprise/tailnet/connio.go @@ -0,0 +1,137 @@ +package tailnet + +import ( + "context" + "encoding/json" + "io" + "net" + + "github.com/google/uuid" + "golang.org/x/xerrors" + "nhooyr.io/websocket" + + "cdr.dev/slog" + agpl 
"github.com/coder/coder/v2/tailnet" +) + +// connIO manages the reading and writing to a connected client or agent. Agent connIOs have their client field set to +// uuid.Nil. It reads node updates via its decoder, then pushes them onto the bindings channel. It receives mappings +// via its updates TrackedConn, which then writes them. +type connIO struct { + pCtx context.Context + ctx context.Context + cancel context.CancelFunc + logger slog.Logger + decoder *json.Decoder + updates *agpl.TrackedConn + bindings chan<- binding +} + +func newConnIO(pCtx context.Context, + logger slog.Logger, + bindings chan<- binding, + conn net.Conn, + id uuid.UUID, + name string, + kind agpl.QueueKind, +) *connIO { + ctx, cancel := context.WithCancel(pCtx) + c := &connIO{ + pCtx: pCtx, + ctx: ctx, + cancel: cancel, + logger: logger, + decoder: json.NewDecoder(conn), + updates: agpl.NewTrackedConn(ctx, cancel, conn, id, logger, name, 0, kind), + bindings: bindings, + } + go c.recvLoop() + go c.updates.SendUpdates() + logger.Info(ctx, "serving connection") + return c +} + +func (c *connIO) recvLoop() { + defer func() { + // withdraw bindings when we exit. We need to use the parent context here, since our own context might be + // canceled, but we still need to withdraw bindings. + b := binding{ + bKey: bKey{ + id: c.UniqueID(), + kind: c.Kind(), + }, + } + if err := sendCtx(c.pCtx, c.bindings, b); err != nil { + c.logger.Debug(c.ctx, "parent context expired while withdrawing bindings", slog.Error(err)) + } + }() + defer c.cancel() + for { + var node agpl.Node + err := c.decoder.Decode(&node) + if err != nil { + if xerrors.Is(err, io.EOF) || + xerrors.Is(err, io.ErrClosedPipe) || + xerrors.Is(err, context.Canceled) || + xerrors.Is(err, context.DeadlineExceeded) || + websocket.CloseStatus(err) > 0 { + c.logger.Debug(c.ctx, "exiting recvLoop", slog.Error(err)) + } else { + c.logger.Error(c.ctx, "failed to decode Node update", slog.Error(err)) + } + return + } + c.logger.Debug(c.ctx, "got node update", slog.F("node", node)) + b := binding{ + bKey: bKey{ + id: c.UniqueID(), + kind: c.Kind(), + }, + node: &node, + } + if err := sendCtx(c.ctx, c.bindings, b); err != nil { + c.logger.Debug(c.ctx, "recvLoop ctx expired", slog.Error(err)) + return + } + } +} + +func (c *connIO) UniqueID() uuid.UUID { + return c.updates.UniqueID() +} + +func (c *connIO) Kind() agpl.QueueKind { + return c.updates.Kind() +} + +func (c *connIO) Enqueue(n []*agpl.Node) error { + return c.updates.Enqueue(n) +} + +func (c *connIO) Name() string { + return c.updates.Name() +} + +func (c *connIO) Stats() (start int64, lastWrite int64) { + return c.updates.Stats() +} + +func (c *connIO) Overwrites() int64 { + return c.updates.Overwrites() +} + +// CoordinatorClose is used by the coordinator when closing a Queue. It +// should skip removing itself from the coordinator. 
+func (c *connIO) CoordinatorClose() error { + c.cancel() + return c.updates.CoordinatorClose() +} + +func (c *connIO) Done() <-chan struct{} { + return c.ctx.Done() +} + +func (c *connIO) Close() error { + c.cancel() + return c.updates.Close() +} diff --git a/enterprise/tailnet/coordinator.go b/enterprise/tailnet/coordinator.go index d97bf2cce7a6c..70ad50687b1f3 100644 --- a/enterprise/tailnet/coordinator.go +++ b/enterprise/tailnet/coordinator.go @@ -58,7 +58,7 @@ func (c *haCoordinator) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn { AgentIsLegacyFunc: c.agentIsLegacy, OnSubscribe: c.clientSubscribeToAgent, OnNodeUpdate: c.clientNodeUpdate, - OnRemove: c.clientDisconnected, + OnRemove: func(enq agpl.Queue) { c.clientDisconnected(enq.UniqueID()) }, }).Init() c.addClient(id, m) return m @@ -157,7 +157,7 @@ func (c *haCoordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error defer cancel() logger := c.clientLogger(id, agentID) - tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, id.String(), 0) + tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, id.String(), 0, agpl.QueueKindClient) defer tc.Close() c.addClient(id, tc) @@ -300,7 +300,7 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) err } // This uniquely identifies a connection that belongs to this goroutine. unique := uuid.New() - tc := agpl.NewTrackedConn(ctx, cancel, conn, unique, logger, name, overwrites) + tc := agpl.NewTrackedConn(ctx, cancel, conn, unique, logger, name, overwrites, agpl.QueueKindAgent) // Publish all nodes on this instance that want to connect to this agent. nodes := c.nodesSubscribedToAgent(id) diff --git a/enterprise/tailnet/multiagent_test.go b/enterprise/tailnet/multiagent_test.go new file mode 100644 index 0000000000000..7546bec350504 --- /dev/null +++ b/enterprise/tailnet/multiagent_test.go @@ -0,0 +1,354 @@ +package tailnet_test + +import ( + "context" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/enterprise/tailnet" + agpl "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/testutil" +) + +// TestPGCoordinator_MultiAgent tests a single coordinator with a MultiAgent +// connecting to one agent. 
+// +// +--------+ +// agent1 ---> | coord1 | <--- client +// +--------+ +func TestPGCoordinator_MultiAgent(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + store, ps := dbtestutil.NewDB(t) + coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store) + require.NoError(t, err) + defer coord1.Close() + + agent1 := newTestAgent(t, coord1, "agent1") + defer agent1.close() + agent1.sendNode(&agpl.Node{PreferredDERP: 5}) + + id := uuid.New() + ma1 := coord1.ServeMultiAgent(id) + defer ma1.Close() + + err = ma1.SubscribeAgent(agent1.id) + require.NoError(t, err) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5) + + agent1.sendNode(&agpl.Node{PreferredDERP: 1}) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) + + err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}) + require.NoError(t, err) + assertEventuallyHasDERPs(ctx, t, agent1, 3) + + require.NoError(t, ma1.Close()) + require.NoError(t, agent1.close()) + + assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) + assertEventuallyNoAgents(ctx, t, store, agent1.id) +} + +// TestPGCoordinator_MultiAgent_UnsubscribeRace tests a single coordinator with +// a MultiAgent connecting to one agent. It tries to race a call to Unsubscribe +// with the MultiAgent closing. +// +// +--------+ +// agent1 ---> | coord1 | <--- client +// +--------+ +func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + store, ps := dbtestutil.NewDB(t) + coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store) + require.NoError(t, err) + defer coord1.Close() + + agent1 := newTestAgent(t, coord1, "agent1") + defer agent1.close() + agent1.sendNode(&agpl.Node{PreferredDERP: 5}) + + id := uuid.New() + ma1 := coord1.ServeMultiAgent(id) + defer ma1.Close() + + err = ma1.SubscribeAgent(agent1.id) + require.NoError(t, err) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5) + + agent1.sendNode(&agpl.Node{PreferredDERP: 1}) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) + + err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}) + require.NoError(t, err) + assertEventuallyHasDERPs(ctx, t, agent1, 3) + + require.NoError(t, ma1.UnsubscribeAgent(agent1.id)) + require.NoError(t, ma1.Close()) + require.NoError(t, agent1.close()) + + assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) + assertEventuallyNoAgents(ctx, t, store, agent1.id) +} + +// TestPGCoordinator_MultiAgent_Unsubscribe tests a single coordinator with a +// MultiAgent connecting to one agent. It unsubscribes before closing, and +// ensures node updates are no longer propagated. 
+// +// +--------+ +// agent1 ---> | coord1 | <--- client +// +--------+ +func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + store, ps := dbtestutil.NewDB(t) + coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store) + require.NoError(t, err) + defer coord1.Close() + + agent1 := newTestAgent(t, coord1, "agent1") + defer agent1.close() + agent1.sendNode(&agpl.Node{PreferredDERP: 5}) + + id := uuid.New() + ma1 := coord1.ServeMultiAgent(id) + defer ma1.Close() + + err = ma1.SubscribeAgent(agent1.id) + require.NoError(t, err) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5) + + agent1.sendNode(&agpl.Node{PreferredDERP: 1}) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) + + require.NoError(t, ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3})) + assertEventuallyHasDERPs(ctx, t, agent1, 3) + + require.NoError(t, ma1.UnsubscribeAgent(agent1.id)) + assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) + + func() { + ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3) + defer cancel() + require.NoError(t, ma1.UpdateSelf(&agpl.Node{PreferredDERP: 9})) + assertNeverHasDERPs(ctx, t, agent1, 9) + }() + func() { + ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3) + defer cancel() + agent1.sendNode(&agpl.Node{PreferredDERP: 8}) + assertMultiAgentNeverHasDERPs(ctx, t, ma1, 8) + }() + + require.NoError(t, ma1.Close()) + require.NoError(t, agent1.close()) + + assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) + assertEventuallyNoAgents(ctx, t, store, agent1.id) +} + +// TestPGCoordinator_MultiAgent_MultiCoordinator tests two coordinators with a +// MultiAgent connecting to an agent on a separate coordinator. 
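+// Node updates must therefore fan out through the Postgres pubsub rather than a single coordinator's in-memory state.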
+// +// +--------+ +// agent1 ---> | coord1 | +// +--------+ +// +--------+ +// | coord2 | <--- client +// +--------+ +func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + store, ps := dbtestutil.NewDB(t) + coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store) + require.NoError(t, err) + defer coord1.Close() + coord2, err := tailnet.NewPGCoord(ctx, logger.Named("coord2"), ps, store) + require.NoError(t, err) + defer coord2.Close() + + agent1 := newTestAgent(t, coord1, "agent1") + defer agent1.close() + agent1.sendNode(&agpl.Node{PreferredDERP: 5}) + + id := uuid.New() + ma1 := coord2.ServeMultiAgent(id) + defer ma1.Close() + + err = ma1.SubscribeAgent(agent1.id) + require.NoError(t, err) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5) + + agent1.sendNode(&agpl.Node{PreferredDERP: 1}) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) + + err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}) + require.NoError(t, err) + assertEventuallyHasDERPs(ctx, t, agent1, 3) + + require.NoError(t, ma1.Close()) + require.NoError(t, agent1.close()) + + assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) + assertEventuallyNoAgents(ctx, t, store, agent1.id) +} + +// TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe tests two +// coordinators with a MultiAgent connecting to an agent on a separate +// coordinator. The MultiAgent updates its own node before subscribing. +// +// +--------+ +// agent1 ---> | coord1 | +// +--------+ +// +--------+ +// | coord2 | <--- client +// +--------+ +func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + store, ps := dbtestutil.NewDB(t) + coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store) + require.NoError(t, err) + defer coord1.Close() + coord2, err := tailnet.NewPGCoord(ctx, logger.Named("coord2"), ps, store) + require.NoError(t, err) + defer coord2.Close() + + agent1 := newTestAgent(t, coord1, "agent1") + defer agent1.close() + agent1.sendNode(&agpl.Node{PreferredDERP: 5}) + + id := uuid.New() + ma1 := coord2.ServeMultiAgent(id) + defer ma1.Close() + + err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}) + require.NoError(t, err) + + err = ma1.SubscribeAgent(agent1.id) + require.NoError(t, err) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5) + assertEventuallyHasDERPs(ctx, t, agent1, 3) + + agent1.sendNode(&agpl.Node{PreferredDERP: 1}) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) + + require.NoError(t, ma1.Close()) + require.NoError(t, agent1.close()) + + assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) + assertEventuallyNoAgents(ctx, t, store, agent1.id) +} + +// TestPGCoordinator_MultiAgent_TwoAgents tests three coordinators with a +// MultiAgent connecting to two agents on separate coordinators. 
+// +// +--------+ +// agent1 ---> | coord1 | +// +--------+ +// +--------+ +// agent2 ---> | coord2 | +// +--------+ +// +--------+ +// | coord3 | <--- client +// +--------+ +func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + store, ps := dbtestutil.NewDB(t) + coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store) + require.NoError(t, err) + defer coord1.Close() + coord2, err := tailnet.NewPGCoord(ctx, logger.Named("coord2"), ps, store) + require.NoError(t, err) + defer coord2.Close() + coord3, err := tailnet.NewPGCoord(ctx, logger.Named("coord3"), ps, store) + require.NoError(t, err) + defer coord3.Close() + + agent1 := newTestAgent(t, coord1, "agent1") + defer agent1.close() + agent1.sendNode(&agpl.Node{PreferredDERP: 5}) + + agent2 := newTestAgent(t, coord2, "agent2") + defer agent1.close() + agent2.sendNode(&agpl.Node{PreferredDERP: 6}) + + id := uuid.New() + ma1 := coord3.ServeMultiAgent(id) + defer ma1.Close() + + err = ma1.SubscribeAgent(agent1.id) + require.NoError(t, err) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5) + + agent1.sendNode(&agpl.Node{PreferredDERP: 1}) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) + + err = ma1.SubscribeAgent(agent2.id) + require.NoError(t, err) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 6) + + agent2.sendNode(&agpl.Node{PreferredDERP: 2}) + assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 2) + + err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}) + require.NoError(t, err) + assertEventuallyHasDERPs(ctx, t, agent1, 3) + assertEventuallyHasDERPs(ctx, t, agent2, 3) + + require.NoError(t, ma1.Close()) + require.NoError(t, agent1.close()) + require.NoError(t, agent2.close()) + + assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) + assertEventuallyNoAgents(ctx, t, store, agent1.id) +} diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 62cc5a240cd98..5e3f6b2f12205 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -4,7 +4,7 @@ import ( "context" "database/sql" "encoding/json" - "io" + "fmt" "net" "net/http" "strings" @@ -15,7 +15,6 @@ import ( "github.com/google/uuid" "golang.org/x/exp/slices" "golang.org/x/xerrors" - "nhooyr.io/websocket" "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" @@ -27,17 +26,19 @@ import ( ) const ( - EventHeartbeats = "tailnet_coordinator_heartbeat" - eventClientUpdate = "tailnet_client_update" - eventAgentUpdate = "tailnet_agent_update" - HeartbeatPeriod = time.Second * 2 - MissedHeartbeats = 3 - numQuerierWorkers = 10 - numBinderWorkers = 10 - dbMaxBackoff = 10 * time.Second - cleanupPeriod = time.Hour + EventHeartbeats = "tailnet_coordinator_heartbeat" + eventClientUpdate = "tailnet_client_update" + eventAgentUpdate = "tailnet_agent_update" + HeartbeatPeriod = time.Second * 2 + MissedHeartbeats = 3 + numQuerierWorkers = 10 + numBinderWorkers = 10 + numSubscriberWorkers = 10 + dbMaxBackoff = 10 * time.Second + cleanupPeriod = time.Hour ) +// TODO: add subscriber to this graphic // pgCoord is a postgres-backed coordinator // // ┌────────┐ ┌────────┐ ┌───────┐ @@ -71,16 +72,20 @@ type pgCoord struct { pubsub pubsub.Pubsub store database.Store - bindings chan binding - newConnections chan *connIO - id uuid.UUID + bindings 
chan binding + newConnections chan agpl.Queue + closeConnections chan agpl.Queue + subscriberCh chan subscribe + querierSubCh chan subscribe + id uuid.UUID cancel context.CancelFunc closeOnce sync.Once closed chan struct{} - binder *binder - querier *querier + binder *binder + subscriber *subscriber + querier *querier } var pgCoordSubject = rbac.Subject{ @@ -106,30 +111,119 @@ func NewPGCoord(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, store id := uuid.New() logger = logger.Named("pgcoord").With(slog.F("coordinator_id", id)) bCh := make(chan binding) - cCh := make(chan *connIO) + // used for opening connections + cCh := make(chan agpl.Queue) + // used for closing connections + ccCh := make(chan agpl.Queue) + // for communicating subscriptions with the subscriber + sCh := make(chan subscribe) + // for communicating subscriptions with the querier + qsCh := make(chan subscribe) // signals when first heartbeat has been sent, so it's safe to start binding. fHB := make(chan struct{}) c := &pgCoord{ - ctx: ctx, - cancel: cancel, - logger: logger, - pubsub: ps, - store: store, - binder: newBinder(ctx, logger, id, store, bCh, fHB), - bindings: bCh, - newConnections: cCh, - id: id, - querier: newQuerier(ctx, logger, ps, store, id, cCh, numQuerierWorkers, fHB), - closed: make(chan struct{}), + ctx: ctx, + cancel: cancel, + logger: logger, + pubsub: ps, + store: store, + binder: newBinder(ctx, logger, id, store, bCh, fHB), + bindings: bCh, + newConnections: cCh, + closeConnections: ccCh, + subscriber: newSubscriber(ctx, logger, id, store, sCh, fHB), + subscriberCh: sCh, + querierSubCh: qsCh, + id: id, + querier: newQuerier(ctx, logger, id, ps, store, id, cCh, ccCh, qsCh, numQuerierWorkers, fHB), + closed: make(chan struct{}), } logger.Info(ctx, "starting coordinator") return c, nil } func (c *pgCoord) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn { - _, _ = c, id - panic("not implemented") // TODO: Implement + ma := (&agpl.MultiAgent{ + ID: id, + AgentIsLegacyFunc: func(agentID uuid.UUID) bool { return true }, + OnSubscribe: func(enq agpl.Queue, agent uuid.UUID) (*agpl.Node, error) { + err := c.addSubscription(enq, agent) + return c.Node(agent), err + }, + OnUnsubscribe: c.removeSubscription, + OnNodeUpdate: func(id uuid.UUID, node *agpl.Node) error { + return sendCtx(c.ctx, c.bindings, binding{ + bKey: bKey{id, agpl.QueueKindClient}, + node: node, + }) + }, + OnRemove: func(enq agpl.Queue) { + _ = sendCtx(c.ctx, c.bindings, binding{ + bKey: bKey{ + id: enq.UniqueID(), + kind: enq.Kind(), + }, + }) + _ = sendCtx(c.ctx, c.subscriberCh, subscribe{ + sKey: sKey{clientID: id}, + q: enq, + active: false, + }) + _ = sendCtx(c.ctx, c.closeConnections, enq) + }, + }).Init() + + if err := sendCtx(c.ctx, c.newConnections, agpl.Queue(ma)); err != nil { + // If we can't successfully send the multiagent, that means the + // coordinator is shutting down. In this case, just return a closed + // multiagent. + ma.CoordinatorClose() + } + + return ma +} + +func (c *pgCoord) addSubscription(q agpl.Queue, agentID uuid.UUID) error { + sub := subscribe{ + sKey: sKey{ + clientID: q.UniqueID(), + agentID: agentID, + }, + q: q, + active: true, + } + if err := sendCtx(c.ctx, c.subscriberCh, sub); err != nil { + return err + } + if err := sendCtx(c.ctx, c.querierSubCh, sub); err != nil { + // There's no need to clean up the sub sent to the subscriber if this + // fails, since it means the entire coordinator is being torn down. 
+ return err + } + + return nil +} + +func (c *pgCoord) removeSubscription(q agpl.Queue, agentID uuid.UUID) error { + sub := subscribe{ + sKey: sKey{ + clientID: q.UniqueID(), + agentID: agentID, + }, + q: q, + active: false, + } + if err := sendCtx(c.ctx, c.subscriberCh, sub); err != nil { + return err + } + if err := sendCtx(c.ctx, c.querierSubCh, sub); err != nil { + // There's no need to clean up the sub sent to the subscriber if this + // fails, since it means the entire coordinator is being torn down. + return err + } + + return nil } func (c *pgCoord) Node(id uuid.UUID) *agpl.Node { @@ -162,11 +256,19 @@ func (c *pgCoord) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) erro slog.Error(err)) } }() - cIO := newConnIO(c.ctx, c.logger, c.bindings, conn, id, agent, id.String()) - if err := sendCtx(c.ctx, c.newConnections, cIO); err != nil { + + cIO := newConnIO(c.ctx, c.logger, c.bindings, conn, id, id.String(), agpl.QueueKindClient) + if err := sendCtx(c.ctx, c.newConnections, agpl.Queue(cIO)); err != nil { // can only be a context error, no need to log here. return err } + defer func() { _ = sendCtx(c.ctx, c.closeConnections, agpl.Queue(cIO)) }() + + if err := c.addSubscription(cIO, agent); err != nil { + return err + } + defer func() { _ = c.removeSubscription(cIO, agent) }() + <-cIO.ctx.Done() return nil } @@ -181,11 +283,13 @@ func (c *pgCoord) ServeAgent(conn net.Conn, id uuid.UUID, name string) error { } }() logger := c.logger.With(slog.F("name", name)) - cIO := newConnIO(c.ctx, logger, c.bindings, conn, uuid.Nil, id, name) - if err := sendCtx(c.ctx, c.newConnections, cIO); err != nil { + cIO := newConnIO(c.ctx, logger, c.bindings, conn, id, name, agpl.QueueKindAgent) + if err := sendCtx(c.ctx, c.newConnections, agpl.Queue(cIO)); err != nil { // can only be a context error, no need to log here. return err } + defer func() { _ = sendCtx(c.ctx, c.closeConnections, agpl.Queue(cIO)) }() + <-cIO.ctx.Done() return nil } @@ -197,110 +301,200 @@ func (c *pgCoord) Close() error { return nil } -// connIO manages the reading and writing to a connected client or agent. Agent connIOs have their client field set to -// uuid.Nil. It reads node updates via its decoder, then pushes them onto the bindings channel. It receives mappings -// via its updates TrackedConn, which then writes them. -type connIO struct { - pCtx context.Context - ctx context.Context - cancel context.CancelFunc - logger slog.Logger - client uuid.UUID - agent uuid.UUID - decoder *json.Decoder - updates *agpl.TrackedConn - bindings chan<- binding +func sendCtx[A any](ctx context.Context, c chan<- A, a A) (err error) { + select { + case <-ctx.Done(): + return ctx.Err() + case c <- a: + return nil + } +} + +type sKey struct { + clientID uuid.UUID + agentID uuid.UUID +} + +type subscribe struct { + sKey + + q agpl.Queue + // whether the subscription should be active. if true, the subscription is + // added. if false, the subscription is removed. 
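+	//
+	// A rough sketch of the three messages the subscriber acts on (c, a, and q
+	// are placeholders; see storeSubscription and writeOne below):
+	//
+	//	subscribe{sKey: sKey{clientID: c, agentID: a}, q: q, active: true}  // upsert the (c, a) subscription row
+	//	subscribe{sKey: sKey{clientID: c, agentID: a}, q: q, active: false} // delete the (c, a) subscription row
+	//	subscribe{sKey: sKey{clientID: c}, q: q, active: false}             // agentID == uuid.Nil: delete all of c's rows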
+ active bool +} + +type subscriber struct { + ctx context.Context + logger slog.Logger + coordinatorID uuid.UUID + store database.Store + subscriptions <-chan subscribe + + mu sync.Mutex + // map[clientID]map[agentID]subscribe + latest map[uuid.UUID]map[uuid.UUID]subscribe + workQ *workQ[sKey] } -func newConnIO(pCtx context.Context, +func newSubscriber(ctx context.Context, logger slog.Logger, - bindings chan<- binding, - conn net.Conn, - client, agent uuid.UUID, - name string, -) *connIO { - ctx, cancel := context.WithCancel(pCtx) - id := agent - logger = logger.With(slog.F("agent_id", agent)) - if client != uuid.Nil { - logger = logger.With(slog.F("client_id", client)) - id = client - } - c := &connIO{ - pCtx: pCtx, - ctx: ctx, - cancel: cancel, - logger: logger, - client: client, - agent: agent, - decoder: json.NewDecoder(conn), - updates: agpl.NewTrackedConn(ctx, cancel, conn, id, logger, name, 0), - bindings: bindings, - } - go c.recvLoop() - go c.updates.SendUpdates() - logger.Info(ctx, "serving connection") - return c -} - -func (c *connIO) recvLoop() { - defer func() { - // withdraw bindings when we exit. We need to use the parent context here, since our own context might be - // canceled, but we still need to withdraw bindings. - b := binding{ - bKey: bKey{ - client: c.client, - agent: c.agent, - }, - } - if err := sendCtx(c.pCtx, c.bindings, b); err != nil { - c.logger.Debug(c.ctx, "parent context expired while withdrawing bindings", slog.Error(err)) + id uuid.UUID, + store database.Store, + subscriptions <-chan subscribe, + startWorkers <-chan struct{}, +) *subscriber { + s := &subscriber{ + ctx: ctx, + logger: logger, + coordinatorID: id, + store: store, + subscriptions: subscriptions, + latest: make(map[uuid.UUID]map[uuid.UUID]subscribe), + workQ: newWorkQ[sKey](ctx), + } + go s.handleSubscriptions() + go func() { + <-startWorkers + for i := 0; i < numSubscriberWorkers; i++ { + go s.worker() } }() - defer c.cancel() + return s +} + +func (s *subscriber) handleSubscriptions() { for { - var node agpl.Node - err := c.decoder.Decode(&node) + select { + case <-s.ctx.Done(): + s.logger.Debug(s.ctx, "subscriber exiting", slog.Error(s.ctx.Err())) + return + case sub := <-s.subscriptions: + s.storeSubscription(sub) + s.workQ.enqueue(sub.sKey) + } + } +} + +func (s *subscriber) worker() { + eb := backoff.NewExponentialBackOff() + eb.MaxElapsedTime = 0 // retry indefinitely + eb.MaxInterval = dbMaxBackoff + bkoff := backoff.WithContext(eb, s.ctx) + for { + bk, err := s.workQ.acquire() if err != nil { - if xerrors.Is(err, io.EOF) || - xerrors.Is(err, io.ErrClosedPipe) || - xerrors.Is(err, context.Canceled) || - xerrors.Is(err, context.DeadlineExceeded) || - websocket.CloseStatus(err) > 0 { - c.logger.Debug(c.ctx, "exiting recvLoop", slog.Error(err)) - } else { - c.logger.Error(c.ctx, "failed to decode Node update", slog.Error(err)) - } + // context expired return } - c.logger.Debug(c.ctx, "got node update", slog.F("node", node)) - b := binding{ - bKey: bKey{ - client: c.client, - agent: c.agent, - }, - node: &node, + err = backoff.Retry(func() error { + bnd := s.retrieveSubscription(bk) + return s.writeOne(bnd) + }, bkoff) + if err != nil { + bkoff.Reset() } - if err := sendCtx(c.ctx, c.bindings, b); err != nil { - c.logger.Debug(c.ctx, "recvLoop ctx expired", slog.Error(err)) - return + s.workQ.done(bk) + } +} + +func (s *subscriber) storeSubscription(sub subscribe) { + s.mu.Lock() + defer s.mu.Unlock() + if sub.active { + if _, ok := s.latest[sub.clientID]; !ok { + 
s.latest[sub.clientID] = map[uuid.UUID]subscribe{} + } + s.latest[sub.clientID][sub.agentID] = sub + } else { + // If the agentID is nil, clean up all of the client's subscriptions. + if sub.agentID == uuid.Nil { + delete(s.latest, sub.clientID) + } else { + delete(s.latest[sub.clientID], sub.agentID) + // clean up the subscription map if all the subscriptions are gone. + if len(s.latest[sub.clientID]) == 0 { + delete(s.latest, sub.clientID) + } } } } -func sendCtx[A any](ctx context.Context, c chan<- A, a A) (err error) { - select { - case <-ctx.Done(): - return ctx.Err() - case c <- a: - return nil +// retrieveSubscription gets the latest subscription for a key. +func (s *subscriber) retrieveSubscription(sk sKey) subscribe { + s.mu.Lock() + defer s.mu.Unlock() + agents, ok := s.latest[sk.clientID] + if !ok { + return subscribe{ + sKey: sk, + active: false, + } + } + + sub, ok := agents[sk.agentID] + if !ok { + return subscribe{ + sKey: sk, + active: false, + } + } + + return sub +} + +func (s *subscriber) writeOne(sub subscribe) error { + var err error + switch { + case sub.agentID == uuid.Nil: + err = s.store.DeleteAllTailnetClientSubscriptions(s.ctx, database.DeleteAllTailnetClientSubscriptionsParams{ + ClientID: sub.clientID, + CoordinatorID: s.coordinatorID, + }) + s.logger.Debug(s.ctx, "deleted all client subscriptions", + slog.F("client_id", sub.clientID), + slog.Error(err), + ) + case sub.active: + err = s.store.UpsertTailnetClientSubscription(s.ctx, database.UpsertTailnetClientSubscriptionParams{ + ClientID: sub.clientID, + CoordinatorID: s.coordinatorID, + AgentID: sub.agentID, + }) + s.logger.Debug(s.ctx, "upserted client subscription", + slog.F("client_id", sub.clientID), + slog.F("agent_id", sub.agentID), + slog.Error(err), + ) + case !sub.active: + err = s.store.DeleteTailnetClientSubscription(s.ctx, database.DeleteTailnetClientSubscriptionParams{ + ClientID: sub.clientID, + CoordinatorID: s.coordinatorID, + AgentID: sub.agentID, + }) + s.logger.Debug(s.ctx, "deleted client subscription", + slog.F("client_id", sub.clientID), + slog.F("agent_id", sub.agentID), + slog.Error(err), + ) + default: + panic("unreachable") } + if err != nil && !database.IsQueryCanceledError(err) { + s.logger.Error(s.ctx, "write subscription to database", + slog.F("client_id", sub.clientID), + slog.F("agent_id", sub.agentID), + slog.F("active", sub.active), + slog.Error(err)) + } + return err } -// bKey, or "binding key" identifies a client or agent in a binding. Agents have their client field set to uuid.Nil. +// bKey, or "binding key" identifies a client or agent in a binding. Agents and +// clients are differentiated by the kind field. type bKey struct { - client uuid.UUID - agent uuid.UUID + id uuid.UUID + kind agpl.QueueKind } // binding represents an association between a client or agent and a Node. @@ -309,8 +503,8 @@ type binding struct { node *agpl.Node } -func (b *binding) isAgent() bool { return b.client == uuid.Nil } -func (b *binding) isClient() bool { return b.client != uuid.Nil } +func (b *binding) isAgent() bool { return b.kind == agpl.QueueKindAgent } +func (b *binding) isClient() bool { return b.kind == agpl.QueueKindClient } // binder reads node bindings from the channel and writes them to the database. It handles retries with a backoff. 
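+//
+// A binding with a nil node is treated as a withdrawal: writeOne deletes the
+// row instead of upserting it. Roughly (IDs are placeholders):
+//
+//	binding{bKey: bKey{id: agentID, kind: agpl.QueueKindAgent}, node: &n} // upsert the agent's row
+//	binding{bKey: bKey{id: agentID, kind: agpl.QueueKindAgent}}           // delete the agent's row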
type binder struct { @@ -325,9 +519,12 @@ type binder struct { workQ *workQ[bKey] } -func newBinder(ctx context.Context, logger slog.Logger, - id uuid.UUID, store database.Store, - bindings <-chan binding, startWorkers <-chan struct{}, +func newBinder(ctx context.Context, + logger slog.Logger, + id uuid.UUID, + store database.Store, + bindings <-chan binding, + startWorkers <-chan struct{}, ) *binder { b := &binder{ ctx: ctx, @@ -399,40 +596,39 @@ func (b *binder) writeOne(bnd binding) error { switch { case bnd.isAgent() && len(nodeRaw) > 0: _, err = b.store.UpsertTailnetAgent(b.ctx, database.UpsertTailnetAgentParams{ - ID: bnd.agent, + ID: bnd.id, CoordinatorID: b.coordinatorID, Node: nodeRaw, }) b.logger.Debug(b.ctx, "upserted agent binding", - slog.F("agent_id", bnd.agent), slog.F("node", nodeRaw), slog.Error(err)) + slog.F("agent_id", bnd.id), slog.F("node", nodeRaw), slog.Error(err)) case bnd.isAgent() && len(nodeRaw) == 0: _, err = b.store.DeleteTailnetAgent(b.ctx, database.DeleteTailnetAgentParams{ - ID: bnd.agent, + ID: bnd.id, CoordinatorID: b.coordinatorID, }) b.logger.Debug(b.ctx, "deleted agent binding", - slog.F("agent_id", bnd.agent), slog.Error(err)) + slog.F("agent_id", bnd.id), slog.Error(err)) if xerrors.Is(err, sql.ErrNoRows) { // treat deletes as idempotent err = nil } case bnd.isClient() && len(nodeRaw) > 0: _, err = b.store.UpsertTailnetClient(b.ctx, database.UpsertTailnetClientParams{ - ID: bnd.client, + ID: bnd.id, CoordinatorID: b.coordinatorID, - AgentID: bnd.agent, Node: nodeRaw, }) b.logger.Debug(b.ctx, "upserted client binding", - slog.F("agent_id", bnd.agent), slog.F("client_id", bnd.client), + slog.F("client_id", bnd.id), slog.F("node", nodeRaw), slog.Error(err)) case bnd.isClient() && len(nodeRaw) == 0: _, err = b.store.DeleteTailnetClient(b.ctx, database.DeleteTailnetClientParams{ - ID: bnd.client, + ID: bnd.id, CoordinatorID: b.coordinatorID, }) b.logger.Debug(b.ctx, "deleted client binding", - slog.F("agent_id", bnd.agent), slog.F("client_id", bnd.client), slog.Error(err)) + slog.F("client_id", bnd.id)) if xerrors.Is(err, sql.ErrNoRows) { // treat deletes as idempotent err = nil @@ -442,8 +638,8 @@ func (b *binder) writeOne(bnd binding) error { } if err != nil && !database.IsQueryCanceledError(err) { b.logger.Error(b.ctx, "failed to write binding to database", - slog.F("client_id", bnd.client), - slog.F("agent_id", bnd.agent), + slog.F("binding_id", bnd.id), + slog.F("kind", bnd.kind), slog.F("node", string(nodeRaw)), slog.Error(err)) } @@ -483,8 +679,8 @@ type mapper struct { ctx context.Context logger slog.Logger - add chan *connIO - del chan *connIO + add chan agpl.Queue + del chan agpl.Queue // reads from this channel trigger sending latest nodes to // all connections. 
It is used when coordinators are added @@ -493,7 +689,7 @@ type mapper struct { mappings chan []mapping - conns map[bKey]*connIO + conns map[bKey]agpl.Queue latest []mapping heartbeats *heartbeats @@ -502,15 +698,15 @@ type mapper struct { func newMapper(ctx context.Context, logger slog.Logger, mk mKey, h *heartbeats) *mapper { logger = logger.With( slog.F("agent_id", mk.agent), - slog.F("clients_of_agent", mk.clientsOfAgent), + slog.F("kind", mk.kind), ) m := &mapper{ ctx: ctx, logger: logger, - add: make(chan *connIO), - del: make(chan *connIO), + add: make(chan agpl.Queue), + del: make(chan agpl.Queue), update: make(chan struct{}), - conns: make(map[bKey]*connIO), + conns: make(map[bKey]agpl.Queue), mappings: make(chan []mapping), heartbeats: h, } @@ -524,17 +720,17 @@ func (m *mapper) run() { case <-m.ctx.Done(): return case c := <-m.add: - m.conns[bKey{c.client, c.agent}] = c + m.conns[bKey{id: c.UniqueID(), kind: c.Kind()}] = c nodes := m.mappingsToNodes(m.latest) if len(nodes) == 0 { m.logger.Debug(m.ctx, "skipping 0 length node update") continue } - if err := c.updates.Enqueue(nodes); err != nil { + if err := c.Enqueue(nodes); err != nil { m.logger.Error(m.ctx, "failed to enqueue node update", slog.Error(err)) } case c := <-m.del: - delete(m.conns, bKey{c.client, c.agent}) + delete(m.conns, bKey{id: c.UniqueID(), kind: c.Kind()}) case mappings := <-m.mappings: m.latest = mappings nodes := m.mappingsToNodes(mappings) @@ -543,7 +739,7 @@ func (m *mapper) run() { continue } for _, conn := range m.conns { - if err := conn.updates.Enqueue(nodes); err != nil { + if err := conn.Enqueue(nodes); err != nil { m.logger.Error(m.ctx, "failed to enqueue node update", slog.Error(err)) } } @@ -554,7 +750,7 @@ func (m *mapper) run() { continue } for _, conn := range m.conns { - if err := conn.updates.Enqueue(nodes); err != nil { + if err := conn.Enqueue(nodes); err != nil { m.logger.Error(m.ctx, "failed to enqueue triggered node update", slog.Error(err)) } } @@ -570,7 +766,13 @@ func (m *mapper) mappingsToNodes(mappings []mapping) []*agpl.Node { mappings = m.heartbeats.filter(mappings) best := make(map[bKey]mapping, len(mappings)) for _, m := range mappings { - bk := bKey{client: m.client, agent: m.agent} + var bk bKey + if m.client == uuid.Nil { + bk = bKey{id: m.agent, kind: agpl.QueueKindAgent} + } else { + bk = bKey{id: m.client, kind: agpl.QueueKindClient} + } + bestM, ok := best[bk] if !ok || m.updatedAt.After(bestM.updatedAt) { best[bk] = m @@ -587,20 +789,28 @@ func (m *mapper) mappingsToNodes(mappings []mapping) []*agpl.Node { // connected clients and agents need. It also checks heartbeats and withdraws mappings from coordinators that have // failed heartbeats. type querier struct { - ctx context.Context - logger slog.Logger - pubsub pubsub.Pubsub - store database.Store - newConnections chan *connIO + ctx context.Context + logger slog.Logger + coordinatorID uuid.UUID + pubsub pubsub.Pubsub + store database.Store + + newConnections chan agpl.Queue + closeConnections chan agpl.Queue + subscriptions chan subscribe + + workQ *workQ[mKey] - workQ *workQ[mKey] heartbeats *heartbeats updates <-chan hbUpdate mu sync.Mutex mappers map[mKey]*countedMapper - conns map[*connIO]struct{} - healthy bool + conns map[uuid.UUID]agpl.Queue + // clientSubscriptions maps client ids to the agent ids they're subscribed to. 
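+	// It roughly mirrors subscriber.latest, but is used only for reference
+	// counting mappers rather than for database writes.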
+ // map[client_id]map[agent_id] + clientSubscriptions map[uuid.UUID]map[uuid.UUID]struct{} + healthy bool } type countedMapper struct { @@ -609,62 +819,92 @@ type countedMapper struct { cancel context.CancelFunc } -func newQuerier( - ctx context.Context, logger slog.Logger, - ps pubsub.Pubsub, store database.Store, - self uuid.UUID, newConnections chan *connIO, numWorkers int, - firstHeartbeat chan<- struct{}, +func newQuerier(ctx context.Context, + logger slog.Logger, + coordinatorID uuid.UUID, + ps pubsub.Pubsub, + store database.Store, + self uuid.UUID, + newConnections chan agpl.Queue, + closeConnections chan agpl.Queue, + subscriptions chan subscribe, + numWorkers int, + firstHeartbeat chan struct{}, ) *querier { updates := make(chan hbUpdate) q := &querier{ - ctx: ctx, - logger: logger.Named("querier"), - pubsub: ps, - store: store, - newConnections: newConnections, - workQ: newWorkQ[mKey](ctx), - heartbeats: newHeartbeats(ctx, logger, ps, store, self, updates, firstHeartbeat), - mappers: make(map[mKey]*countedMapper), - conns: make(map[*connIO]struct{}), - updates: updates, - healthy: true, // assume we start healthy + ctx: ctx, + logger: logger.Named("querier"), + coordinatorID: coordinatorID, + pubsub: ps, + store: store, + newConnections: newConnections, + closeConnections: closeConnections, + subscriptions: subscriptions, + workQ: newWorkQ[mKey](ctx), + heartbeats: newHeartbeats(ctx, logger, ps, store, self, updates, firstHeartbeat), + mappers: make(map[mKey]*countedMapper), + conns: make(map[uuid.UUID]agpl.Queue), + updates: updates, + clientSubscriptions: make(map[uuid.UUID]map[uuid.UUID]struct{}), + healthy: true, // assume we start healthy } q.subscribe() - go q.handleConnIO() - for i := 0; i < numWorkers; i++ { - go q.worker() - } - go q.handleUpdates() + + go func() { + <-firstHeartbeat + go q.handleIncoming() + for i := 0; i < numWorkers; i++ { + go q.worker() + } + go q.handleUpdates() + }() return q } -func (q *querier) handleConnIO() { +func (q *querier) handleIncoming() { for { select { case <-q.ctx.Done(): return + case c := <-q.newConnections: - q.newConn(c) + switch c.Kind() { + case agpl.QueueKindAgent: + q.newAgentConn(c) + case agpl.QueueKindClient: + q.newClientConn(c) + default: + panic(fmt.Sprint("unreachable: invalid queue kind ", c.Kind())) + } + + case c := <-q.closeConnections: + q.cleanupConn(c) + + case sub := <-q.subscriptions: + if sub.active { + q.newClientSubscription(sub.q, sub.agentID) + } else { + q.removeClientSubscription(sub.q, sub.agentID) + } } } } -func (q *querier) newConn(c *connIO) { +func (q *querier) newAgentConn(c agpl.Queue) { q.mu.Lock() defer q.mu.Unlock() if !q.healthy { - err := c.updates.Close() + err := c.Close() q.logger.Info(q.ctx, "closed incoming connection while unhealthy", slog.Error(err), - slog.F("agent_id", c.agent), - slog.F("client_id", c.client), + slog.F("agent_id", c.UniqueID()), ) return } mk := mKey{ - agent: c.agent, - // if client is Nil, this is an agent connection, and it wants the mappings for all the clients of itself - clientsOfAgent: c.client == uuid.Nil, + agent: c.UniqueID(), + kind: c.Kind(), } cm, ok := q.mappers[mk] if !ok { @@ -683,21 +923,118 @@ func (q *querier) newConn(c *connIO) { return } cm.count++ - q.conns[c] = struct{}{} - go q.cleanupConn(c) + q.conns[c.UniqueID()] = c } -func (q *querier) cleanupConn(c *connIO) { - <-c.ctx.Done() +func (q *querier) newClientSubscription(c agpl.Queue, agentID uuid.UUID) { q.mu.Lock() defer q.mu.Unlock() - delete(q.conns, c) + + if _, ok := 
q.clientSubscriptions[c.UniqueID()]; !ok { + q.clientSubscriptions[c.UniqueID()] = map[uuid.UUID]struct{}{} + } + mk := mKey{ - agent: c.agent, - // if client is Nil, this is an agent connection, and it wants the mappings for all the clients of itself - clientsOfAgent: c.client == uuid.Nil, + agent: agentID, + kind: agpl.QueueKindClient, + } + cm, ok := q.mappers[mk] + if !ok { + ctx, cancel := context.WithCancel(q.ctx) + mpr := newMapper(ctx, q.logger, mk, q.heartbeats) + cm = &countedMapper{ + mapper: mpr, + count: 0, + cancel: cancel, + } + q.mappers[mk] = cm + // we don't have any mapping state for this key yet + q.workQ.enqueue(mk) + } + if err := sendCtx(cm.ctx, cm.add, c); err != nil { + return + } + q.clientSubscriptions[c.UniqueID()][agentID] = struct{}{} + cm.count++ +} + +func (q *querier) removeClientSubscription(c agpl.Queue, agentID uuid.UUID) { + q.mu.Lock() + defer q.mu.Unlock() + + // Allow duplicate unsubscribes. It's possible for cleanupConn to race with + // an external call to removeClientSubscription, so we just ensure the + // client subscription exists before attempting to remove it. + if _, ok := q.clientSubscriptions[c.UniqueID()][agentID]; !ok { + return + } + + mk := mKey{ + agent: agentID, + kind: agpl.QueueKindClient, } cm := q.mappers[mk] + if err := sendCtx(cm.ctx, cm.del, c); err != nil { + return + } + delete(q.clientSubscriptions[c.UniqueID()], agentID) + cm.count-- + if cm.count == 0 { + cm.cancel() + delete(q.mappers, mk) + } + if len(q.clientSubscriptions[c.UniqueID()]) == 0 { + delete(q.clientSubscriptions, c.UniqueID()) + } +} + +func (q *querier) newClientConn(c agpl.Queue) { + q.mu.Lock() + defer q.mu.Unlock() + if !q.healthy { + err := c.Close() + q.logger.Info(q.ctx, "closed incoming connection while unhealthy", + slog.Error(err), + slog.F("client_id", c.UniqueID()), + ) + return + } + + q.conns[c.UniqueID()] = c +} + +func (q *querier) cleanupConn(c agpl.Queue) { + q.mu.Lock() + defer q.mu.Unlock() + delete(q.conns, c.UniqueID()) + + // Iterate over all subscriptions and remove them from the mappers. + for agentID := range q.clientSubscriptions[c.UniqueID()] { + mk := mKey{ + agent: agentID, + kind: c.Kind(), + } + cm := q.mappers[mk] + if err := sendCtx(cm.ctx, cm.del, c); err != nil { + continue + } + cm.count-- + if cm.count == 0 { + cm.cancel() + delete(q.mappers, mk) + } + } + delete(q.clientSubscriptions, c.UniqueID()) + + mk := mKey{ + agent: c.UniqueID(), + kind: c.Kind(), + } + cm, ok := q.mappers[mk] + if !ok { + return + } + if err := sendCtx(cm.ctx, cm.del, c); err != nil { return } @@ -732,12 +1069,15 @@ func (q *querier) worker() { func (q *querier) query(mk mKey) error { var mappings []mapping var err error - if mk.clientsOfAgent { + // If the mapping is an agent, query all of its clients. + if mk.kind == agpl.QueueKindAgent { mappings, err = q.queryClientsOfAgent(mk.agent) if err != nil { return err } } else { + // The mapping is for clients subscribed to the agent. Query the agent + // itself. 
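+		// In other words (kinds as used elsewhere in this file):
+		//   mKey{agent: a, kind: agpl.QueueKindAgent}  -> rows of clients subscribed to a, fanned out to agent queues
+		//   mKey{agent: a, kind: agpl.QueueKindClient} -> a's own row, fanned out to client queues subscribed to a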
mappings, err = q.queryAgent(mk.agent) if err != nil { return err @@ -748,9 +1088,10 @@ func (q *querier) query(mk mKey) error { q.mu.Unlock() if !ok { q.logger.Debug(q.ctx, "query for missing mapper", - slog.F("agent_id", mk.agent), slog.F("clients_of_agent", mk.clientsOfAgent)) + slog.F("agent_id", mk.agent), slog.F("kind", mk.kind)) return nil } + q.logger.Debug(q.ctx, "sending mappings", slog.F("mapping_len", len(mappings))) mpr.mappings <- mappings return nil } @@ -772,7 +1113,7 @@ func (q *querier) queryClientsOfAgent(agent uuid.UUID) ([]mapping, error) { } mappings = append(mappings, mapping{ client: client.ID, - agent: client.AgentID, + agent: agent, coordinator: client.CoordinatorID, updatedAt: client.UpdatedAt, node: node, @@ -788,6 +1129,11 @@ func (q *querier) queryAgent(agentID uuid.UUID) ([]mapping, error) { if err != nil { return nil, err } + return q.agentsToMappings(agents) +} + +func (q *querier) agentsToMappings(agents []database.TailnetAgent) ([]mapping, error) { + slog.Helper() mappings := make([]mapping, 0, len(agents)) for _, agent := range agents { node := new(agpl.Node) @@ -873,6 +1219,7 @@ func (q *querier) listenClient(_ context.Context, msg []byte, err error) { } if err != nil { q.logger.Warn(q.ctx, "unhandled pubsub error", slog.Error(err)) + return } client, agent, err := parseClientUpdate(string(msg)) if err != nil { @@ -881,9 +1228,10 @@ func (q *querier) listenClient(_ context.Context, msg []byte, err error) { } logger := q.logger.With(slog.F("client_id", client), slog.F("agent_id", agent)) logger.Debug(q.ctx, "got client update") + mk := mKey{ - agent: agent, - clientsOfAgent: true, + agent: agent, + kind: agpl.QueueKindAgent, } q.mu.Lock() _, ok := q.mappers[mk] @@ -905,7 +1253,7 @@ func (q *querier) listenAgent(_ context.Context, msg []byte, err error) { if err != nil { q.logger.Warn(q.ctx, "unhandled pubsub error", slog.Error(err)) } - agent, err := parseAgentUpdate(string(msg)) + agent, err := parseUpdateMessage(string(msg)) if err != nil { q.logger.Error(q.ctx, "failed to parse agent update", slog.F("msg", string(msg)), slog.Error(err)) return @@ -913,8 +1261,8 @@ func (q *querier) listenAgent(_ context.Context, msg []byte, err error) { logger := q.logger.With(slog.F("agent_id", agent)) logger.Debug(q.ctx, "got agent update") mk := mKey{ - agent: agent, - clientsOfAgent: false, + agent: agent, + kind: agpl.QueueKindClient, } q.mu.Lock() _, ok := q.mappers[mk] @@ -930,7 +1278,7 @@ func (q *querier) resyncClientMappings() { q.mu.Lock() defer q.mu.Unlock() for mk := range q.mappers { - if mk.clientsOfAgent { + if mk.kind == agpl.QueueKindClient { q.workQ.enqueue(mk) } } @@ -940,7 +1288,7 @@ func (q *querier) resyncAgentMappings() { q.mu.Lock() defer q.mu.Unlock() for mk := range q.mappers { - if !mk.clientsOfAgent { + if mk.kind == agpl.QueueKindAgent { q.workQ.enqueue(mk) } } @@ -988,10 +1336,10 @@ func (q *querier) unhealthyCloseAll() { q.mu.Lock() defer q.mu.Unlock() q.healthy = false - for c := range q.conns { + for _, c := range q.conns { // close connections async so that we don't block the querier routine that responds to updates - go func(c *connIO) { - err := c.updates.Close() + go func(c agpl.Queue) { + err := c.Close() if err != nil { q.logger.Debug(q.ctx, "error closing conn while unhealthy", slog.Error(err)) } @@ -1021,7 +1369,9 @@ func (q *querier) getAll(ctx context.Context) (map[uuid.UUID]database.TailnetAge } clientsMap := map[uuid.UUID][]database.TailnetClient{} for _, client := range clients { - clientsMap[client.AgentID] = 
append(clientsMap[client.AgentID], client) + for _, agentID := range client.AgentIds { + clientsMap[agentID] = append(clientsMap[agentID], client.TailnetClient) + } } return agentsMap, clientsMap, nil @@ -1036,17 +1386,19 @@ func parseClientUpdate(msg string) (client, agent uuid.UUID, err error) { if err != nil { return uuid.Nil, uuid.Nil, xerrors.Errorf("failed to parse client UUID: %w", err) } + agent, err = uuid.Parse(parts[1]) if err != nil { return uuid.Nil, uuid.Nil, xerrors.Errorf("failed to parse agent UUID: %w", err) } + return client, agent, nil } -func parseAgentUpdate(msg string) (agent uuid.UUID, err error) { +func parseUpdateMessage(msg string) (agent uuid.UUID, err error) { agent, err = uuid.Parse(msg) if err != nil { - return uuid.Nil, xerrors.Errorf("failed to parse agent UUID: %w", err) + return uuid.Nil, xerrors.Errorf("failed to parse update message UUID: %w", err) } return agent, nil } @@ -1056,7 +1408,7 @@ type mKey struct { agent uuid.UUID // we always query based on the agent ID, but if we have client connection(s), we query the agent itself. If we // have an agent connection, we need the node mappings for all clients of the agent. - clientsOfAgent bool + kind agpl.QueueKind } // mapping associates a particular client or agent, and its respective coordinator with a node. It is generalized to @@ -1069,9 +1421,13 @@ type mapping struct { node *agpl.Node } +type queueKey interface { + mKey | bKey | sKey +} + // workQ allows scheduling work based on a key. Multiple enqueue requests for the same key are coalesced, and // only one in-progress job per key is scheduled. -type workQ[K mKey | bKey] struct { +type workQ[K queueKey] struct { ctx context.Context cond *sync.Cond @@ -1079,7 +1435,7 @@ type workQ[K mKey | bKey] struct { inProgress map[K]bool } -func newWorkQ[K mKey | bKey](ctx context.Context) *workQ[K] { +func newWorkQ[K queueKey](ctx context.Context) *workQ[K] { q := &workQ[K]{ ctx: ctx, cond: sync.NewCond(&sync.Mutex{}), diff --git a/enterprise/tailnet/pgcoord_test.go b/enterprise/tailnet/pgcoord_test.go index 9112cd95a0791..031b863144e92 100644 --- a/enterprise/tailnet/pgcoord_test.go +++ b/enterprise/tailnet/pgcoord_test.go @@ -438,7 +438,7 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) { assertEventuallyNoClientsForAgent(ctx, t, store, agent2.id) } -// TestPGCoordinator_MultiAgent tests when a single agent connects to multiple coordinators. +// TestPGCoordinator_MultiCoordinatorAgent tests when a single agent connects to multiple coordinators. // We use two agent connections, but they share the same AgentID. This could happen due to a reconnection, // or an infrastructure problem where an old workspace is not fully cleaned up before a new one started. 
// @@ -451,7 +451,7 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) { // +---------+ // | coord3 | <--- client // +---------+ -func TestPGCoordinator_MultiAgent(t *testing.T) { +func TestPGCoordinator_MultiCoordinatorAgent(t *testing.T) { t.Parallel() if !dbtestutil.WillUsePostgres() { t.Skip("test only with postgres") @@ -693,8 +693,79 @@ func assertEventuallyHasDERPs(ctx context.Context, t *testing.T, c *testConn, ex t.Logf("expected DERP %d to be in %v", e, derps) continue } + return + } + } +} + +func assertNeverHasDERPs(ctx context.Context, t *testing.T, c *testConn, expected ...int) { + t.Helper() + for { + select { + case <-ctx.Done(): + return + case nodes := <-c.nodeChan: + derps := make([]int, 0, len(nodes)) + for _, n := range nodes { + derps = append(derps, n.PreferredDERP) + } + for _, e := range expected { + if slices.Contains(derps, e) { + t.Fatalf("expected not to get DERP %d, but received it", e) + return + } + } + } + } +} + +func assertMultiAgentEventuallyHasDERPs(ctx context.Context, t *testing.T, ma agpl.MultiAgentConn, expected ...int) { + t.Helper() + for { + nodes, ok := ma.NextUpdate(ctx) + require.True(t, ok) + if len(nodes) != len(expected) { + t.Logf("expected %d, got %d nodes", len(expected), len(nodes)) + continue + } + + derps := make([]int, 0, len(nodes)) + for _, n := range nodes { + derps = append(derps, n.PreferredDERP) + } + for _, e := range expected { + if !slices.Contains(derps, e) { + t.Logf("expected DERP %d to be in %v", e, derps) + continue + } + return + } + } +} + +func assertMultiAgentNeverHasDERPs(ctx context.Context, t *testing.T, ma agpl.MultiAgentConn, expected ...int) { + t.Helper() + for { + nodes, ok := ma.NextUpdate(ctx) + if !ok { + return + } + + derps := make([]int, 0, len(nodes)) + for _, n := range nodes { + derps = append(derps, n.PreferredDERP) + } + for _, e := range expected { + if slices.Contains(derps, e) { + t.Fatalf("expected not to get DERP %d, but received it", e) + return + } } - return } } @@ -712,6 +783,7 @@ func assertEventuallyNoAgents(ctx context.Context, t *testing.T, store database.
} func assertEventuallyNoClientsForAgent(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) { + t.Helper() assert.Eventually(t, func() bool { clients, err := store.GetTailnetClientsForAgent(ctx, agentID) if xerrors.Is(err, sql.ErrNoRows) { diff --git a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go index 74c381c2d8b4a..c00ab834b7c25 100644 --- a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go +++ b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go @@ -22,7 +22,7 @@ import ( "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/coderd/workspaceapps" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/tailnet" + agpl "github.com/coder/coder/v2/tailnet" ) // Client is a HTTP client for a subset of Coder API routes that external @@ -422,14 +422,14 @@ const ( type CoordinateMessage struct { Type CoordinateMessageType `json:"type"` AgentID uuid.UUID `json:"agent_id"` - Node *tailnet.Node `json:"node"` + Node *agpl.Node `json:"node"` } type CoordinateNodes struct { - Nodes []*tailnet.Node + Nodes []*agpl.Node } -func (c *Client) DialCoordinator(ctx context.Context) (tailnet.MultiAgentConn, error) { +func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, error) { ctx, cancel := context.WithCancel(ctx) coordinateURL, err := c.SDKClient.URL.Parse("/api/v2/workspaceproxies/me/coordinate") @@ -463,13 +463,13 @@ func (c *Client) DialCoordinator(ctx context.Context) (tailnet.MultiAgentConn, e legacyAgentCache: map[uuid.UUID]bool{}, } - ma := (&tailnet.MultiAgent{ + ma := (&agpl.MultiAgent{ ID: uuid.New(), AgentIsLegacyFunc: rma.AgentIsLegacy, OnSubscribe: rma.OnSubscribe, OnUnsubscribe: rma.OnUnsubscribe, OnNodeUpdate: rma.OnNodeUpdate, - OnRemove: func(uuid.UUID) { conn.Close(websocket.StatusGoingAway, "closed") }, + OnRemove: func(agpl.Queue) { conn.Close(websocket.StatusGoingAway, "closed") }, }).Init() go func() { @@ -515,7 +515,7 @@ func (a *remoteMultiAgentHandler) writeJSON(v interface{}) error { // Set a deadline so that hung connections don't put back pressure on the system. // Node updates are tiny, so even the dinkiest connection can handle them if it's not hung. 
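+	// (agpl.WriteTimeout is the 5s WriteTimeout constant from tailnet/trackedconn.go.)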
- err = a.nc.SetWriteDeadline(time.Now().Add(tailnet.WriteTimeout)) + err = a.nc.SetWriteDeadline(time.Now().Add(agpl.WriteTimeout)) if err != nil { return xerrors.Errorf("set write deadline: %w", err) } @@ -537,21 +537,21 @@ func (a *remoteMultiAgentHandler) writeJSON(v interface{}) error { return nil } -func (a *remoteMultiAgentHandler) OnNodeUpdate(_ uuid.UUID, node *tailnet.Node) error { +func (a *remoteMultiAgentHandler) OnNodeUpdate(_ uuid.UUID, node *agpl.Node) error { return a.writeJSON(CoordinateMessage{ Type: CoordinateMessageTypeNodeUpdate, Node: node, }) } -func (a *remoteMultiAgentHandler) OnSubscribe(_ tailnet.Queue, agentID uuid.UUID) (*tailnet.Node, error) { +func (a *remoteMultiAgentHandler) OnSubscribe(_ agpl.Queue, agentID uuid.UUID) (*agpl.Node, error) { return nil, a.writeJSON(CoordinateMessage{ Type: CoordinateMessageTypeSubscribe, AgentID: agentID, }) } -func (a *remoteMultiAgentHandler) OnUnsubscribe(_ tailnet.Queue, agentID uuid.UUID) error { +func (a *remoteMultiAgentHandler) OnUnsubscribe(_ agpl.Queue, agentID uuid.UUID) error { return a.writeJSON(CoordinateMessage{ Type: CoordinateMessageTypeUnsubscribe, AgentID: agentID, diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index 866633ed54a1e..41a75f1fc5e78 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -146,7 +146,7 @@ func (c *coordinator) ServeMultiAgent(id uuid.UUID) MultiAgentConn { OnSubscribe: c.core.clientSubscribeToAgent, OnUnsubscribe: c.core.clientUnsubscribeFromAgent, OnNodeUpdate: c.core.clientNodeUpdate, - OnRemove: c.core.clientDisconnected, + OnRemove: func(enq Queue) { c.core.clientDisconnected(enq.UniqueID()) }, }).Init() c.core.addClient(id, m) return m @@ -191,8 +191,16 @@ type core struct { legacyAgents map[uuid.UUID]struct{} } +type QueueKind int + +const ( + QueueKindClient QueueKind = 1 + iota + QueueKindAgent +) + type Queue interface { UniqueID() uuid.UUID + Kind() QueueKind Enqueue(n []*Node) error Name() string Stats() (start, lastWrite int64) @@ -200,6 +208,7 @@ type Queue interface { // CoordinatorClose is used by the coordinator when closing a Queue. It // should skip removing itself from the coordinator. 
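+	// Close, by contrast, is also expected to notify the coordinator; for
+	// example, MultiAgent.Close fires OnRemove in addition to CoordinatorClose
+	// (see tailnet/multiagent.go below).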
CoordinatorClose() error + Done() <-chan struct{} Close() error } @@ -264,7 +273,7 @@ func (c *coordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error { logger := c.core.clientLogger(id, agentID) logger.Debug(ctx, "coordinating client") - tc := NewTrackedConn(ctx, cancel, conn, id, logger, id.String(), 0) + tc := NewTrackedConn(ctx, cancel, conn, id, logger, id.String(), 0, QueueKindClient) defer tc.Close() c.core.addClient(id, tc) @@ -509,7 +518,7 @@ func (c *core) initAndTrackAgent(ctx context.Context, cancel func(), conn net.Co overwrites = oldAgentSocket.Overwrites() + 1 _ = oldAgentSocket.Close() } - tc := NewTrackedConn(ctx, cancel, conn, unique, logger, name, overwrites) + tc := NewTrackedConn(ctx, cancel, conn, unique, logger, name, overwrites, QueueKindAgent) c.agentNameCache.Add(id, name) sockets, ok := c.agentToConnectionSockets[id] diff --git a/tailnet/multiagent.go b/tailnet/multiagent.go index ee76e4b88d8aa..5c3412a595152 100644 --- a/tailnet/multiagent.go +++ b/tailnet/multiagent.go @@ -29,9 +29,12 @@ type MultiAgent struct { OnSubscribe func(enq Queue, agent uuid.UUID) (*Node, error) OnUnsubscribe func(enq Queue, agent uuid.UUID) error OnNodeUpdate func(id uuid.UUID, node *Node) error - OnRemove func(id uuid.UUID) + OnRemove func(enq Queue) + ctx context.Context + ctxCancel func() closed bool + updates chan []*Node closeOnce sync.Once start int64 @@ -44,9 +47,14 @@ type MultiAgent struct { func (m *MultiAgent) Init() *MultiAgent { m.updates = make(chan []*Node, 128) m.start = time.Now().Unix() + m.ctx, m.ctxCancel = context.WithCancel(context.Background()) return m } +func (*MultiAgent) Kind() QueueKind { + return QueueKindClient +} + func (m *MultiAgent) UniqueID() uuid.UUID { return m.ID } @@ -156,8 +164,13 @@ func (m *MultiAgent) CoordinatorClose() error { return nil } +func (m *MultiAgent) Done() <-chan struct{} { + return m.ctx.Done() +} + func (m *MultiAgent) Close() error { _ = m.CoordinatorClose() - m.closeOnce.Do(func() { m.OnRemove(m.ID) }) + m.ctxCancel() + m.closeOnce.Do(func() { m.OnRemove(m) }) return nil } diff --git a/tailnet/trackedconn.go b/tailnet/trackedconn.go index 0ec19695ba29f..be464b2327921 100644 --- a/tailnet/trackedconn.go +++ b/tailnet/trackedconn.go @@ -20,6 +20,7 @@ const WriteTimeout = time.Second * 5 type TrackedConn struct { ctx context.Context cancel func() + kind QueueKind conn net.Conn updates chan []*Node logger slog.Logger @@ -35,7 +36,14 @@ type TrackedConn struct { overwrites int64 } -func NewTrackedConn(ctx context.Context, cancel func(), conn net.Conn, id uuid.UUID, logger slog.Logger, name string, overwrites int64) *TrackedConn { +func NewTrackedConn(ctx context.Context, cancel func(), + conn net.Conn, + id uuid.UUID, + logger slog.Logger, + name string, + overwrites int64, + kind QueueKind, +) *TrackedConn { // buffer updates so they don't block, since we hold the // coordinator mutex while queuing. 
Node updates don't // come quickly, so 512 should be plenty for all but @@ -53,6 +61,7 @@ func NewTrackedConn(ctx context.Context, cancel func(), conn net.Conn, id uuid.U lastWrite: now, name: name, overwrites: overwrites, + kind: kind, } } @@ -70,6 +79,10 @@ func (t *TrackedConn) UniqueID() uuid.UUID { return t.id } +func (t *TrackedConn) Kind() QueueKind { + return t.kind +} + func (t *TrackedConn) Name() string { return t.name } @@ -86,6 +99,10 @@ func (t *TrackedConn) CoordinatorClose() error { return t.Close() } +func (t *TrackedConn) Done() <-chan struct{} { + return t.ctx.Done() +} + // Close the connection and cancel the context for reading node updates from the queue func (t *TrackedConn) Close() error { t.cancel()