From d00c556103a5455bd9f646b88a410d565d70aaea Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Tue, 6 Jun 2023 04:55:19 +0000 Subject: [PATCH 01/11] postgres tailnet coordinator Signed-off-by: Spike Curtis --- coderd/database/dbauthz/tailnetcoordinator.go | 66 + coderd/database/dbfake/dbfake.go | 39 + coderd/database/dbmetrics/dbmetrics.go | 48 + coderd/database/dbmock/dbmock.go | 119 ++ coderd/database/dump.sql | 86 ++ .../migrations/000125_ha_coordinator.down.sql | 18 + .../migrations/000125_ha_coordinator.up.sql | 94 ++ coderd/database/models.go | 20 + coderd/database/querier.go | 8 + coderd/database/queries.sql.go | 233 ++++ coderd/database/queries/tailnet.sql | 79 ++ coderd/rbac/object.go | 5 + coderd/rbac/object_gen.go | 1 + enterprise/tailnet/pgcoord.go | 1208 +++++++++++++++++ enterprise/tailnet/pgcoord_test.go | 656 +++++++++ tailnet/coordinator.go | 17 +- 16 files changed, 2692 insertions(+), 5 deletions(-) create mode 100644 coderd/database/dbauthz/tailnetcoordinator.go create mode 100644 coderd/database/migrations/000125_ha_coordinator.down.sql create mode 100644 coderd/database/migrations/000125_ha_coordinator.up.sql create mode 100644 coderd/database/queries/tailnet.sql create mode 100644 enterprise/tailnet/pgcoord.go create mode 100644 enterprise/tailnet/pgcoord_test.go diff --git a/coderd/database/dbauthz/tailnetcoordinator.go b/coderd/database/dbauthz/tailnetcoordinator.go new file mode 100644 index 0000000000000..ddf924a498364 --- /dev/null +++ b/coderd/database/dbauthz/tailnetcoordinator.go @@ -0,0 +1,66 @@ +package dbauthz + +import ( + "context" + + "github.com/google/uuid" + + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/rbac" +) + +func (q *querier) UpsertTailnetClient(ctx context.Context, arg database.UpsertTailnetClientParams) (database.TailnetClient, error) { + if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil { + return database.TailnetClient{}, err + } + return q.db.UpsertTailnetClient(ctx, arg) +} + +func (q *querier) UpsertTailnetAgent(ctx context.Context, arg database.UpsertTailnetAgentParams) (database.TailnetAgent, error) { + if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil { + return database.TailnetAgent{}, err + } + return q.db.UpsertTailnetAgent(ctx, arg) +} + +func (q *querier) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (database.TailnetCoordinator, error) { + if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil { + return database.TailnetCoordinator{}, err + } + return q.db.UpsertTailnetCoordinator(ctx, id) +} + +func (q *querier) DeleteTailnetClient(ctx context.Context, arg database.DeleteTailnetClientParams) (database.DeleteTailnetClientRow, error) { + if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil { + return database.DeleteTailnetClientRow{}, err + } + return q.db.DeleteTailnetClient(ctx, arg) +} + +func (q *querier) DeleteTailnetAgent(ctx context.Context, arg database.DeleteTailnetAgentParams) (database.DeleteTailnetAgentRow, error) { + if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil { + return database.DeleteTailnetAgentRow{}, err + } + return q.db.DeleteTailnetAgent(ctx, arg) +} + +func (q *querier) DeleteCoordinator(ctx context.Context, id uuid.UUID) error { + if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil { + 
return err + } + return q.db.DeleteCoordinator(ctx, id) +} + +func (q *querier) GetTailnetAgents(ctx context.Context, id uuid.UUID) ([]database.TailnetAgent, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTailnetCoordinator); err != nil { + return nil, err + } + return q.db.GetTailnetAgents(ctx, id) +} + +func (q *querier) GetTailnetClientsForAgent(ctx context.Context, agentID uuid.UUID) ([]database.TailnetClient, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTailnetCoordinator); err != nil { + return nil, err + } + return q.db.GetTailnetClientsForAgent(ctx, agentID) +} diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index dad08f081a4a9..f0bf58830f427 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -5189,3 +5189,42 @@ func (q *fakeQuerier) UpsertServiceBanner(_ context.Context, data string) error q.serviceBanner = []byte(data) return nil } + +// The remaining methods are only used by the enterprise/tailnet.pgCoord. This coordinator explicitly depends on +// postgres triggers that announce changes on the pubsub. Implementing support for this in the fake database would +// strongly couple the fakeQuerier to the pubsub, which is undesirable. Furthermore, it makes little sense to directly +// test the pgCoord against anything other than postgres. The fakeQuerier is designed to allow us to test the Coderd +// API, and for that kind of test, the in-memory, AGPL tailnet coordinator is sufficient. Therefore, these methods +// remain unimplemented in the fakeQuerier. + +func (*fakeQuerier) UpsertTailnetClient(context.Context, database.UpsertTailnetClientParams) (database.TailnetClient, error) { + panic("unimplemented") +} + +func (*fakeQuerier) UpsertTailnetAgent(context.Context, database.UpsertTailnetAgentParams) (database.TailnetAgent, error) { + panic("unimplemented") +} + +func (*fakeQuerier) UpsertTailnetCoordinator(context.Context, uuid.UUID) (database.TailnetCoordinator, error) { + panic("unimplemented") +} + +func (*fakeQuerier) DeleteTailnetClient(context.Context, database.DeleteTailnetClientParams) (database.DeleteTailnetClientRow, error) { + panic("unimplemented") +} + +func (*fakeQuerier) DeleteTailnetAgent(context.Context, database.DeleteTailnetAgentParams) (database.DeleteTailnetAgentRow, error) { + panic("unimplemented") +} + +func (*fakeQuerier) DeleteCoordinator(context.Context, uuid.UUID) error { + panic("unimplemented") +} + +func (*fakeQuerier) GetTailnetAgents(context.Context, uuid.UUID) ([]database.TailnetAgent, error) { + panic("unimplemented") +} + +func (*fakeQuerier) GetTailnetClientsForAgent(context.Context, uuid.UUID) ([]database.TailnetClient, error) { + panic("unimplemented") +} diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index 49cf5fc402c5e..dd281c6831914 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -1535,3 +1535,51 @@ func (m metricsStore) UpsertServiceBanner(ctx context.Context, value string) err m.queryLatencies.WithLabelValues("UpsertServiceBanner").Observe(time.Since(start).Seconds()) return r0 } + +func (m metricsStore) UpsertTailnetClient(ctx context.Context, arg database.UpsertTailnetClientParams) (database.TailnetClient, error) { + start := time.Now() + defer m.queryLatencies.WithLabelValues("UpsertTailnetClient").Observe(time.Since(start).Seconds()) + return m.s.UpsertTailnetClient(ctx, arg) +} + +func (m metricsStore) 
UpsertTailnetAgent(ctx context.Context, arg database.UpsertTailnetAgentParams) (database.TailnetAgent, error) { + start := time.Now() + defer m.queryLatencies.WithLabelValues("UpsertTailnetAgent").Observe(time.Since(start).Seconds()) + return m.s.UpsertTailnetAgent(ctx, arg) +} + +func (m metricsStore) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (database.TailnetCoordinator, error) { + start := time.Now() + defer m.queryLatencies.WithLabelValues("UpsertTailnetCoordinator").Observe(time.Since(start).Seconds()) + return m.s.UpsertTailnetCoordinator(ctx, id) +} + +func (m metricsStore) DeleteTailnetClient(ctx context.Context, arg database.DeleteTailnetClientParams) (database.DeleteTailnetClientRow, error) { + start := time.Now() + defer m.queryLatencies.WithLabelValues("DeleteTailnetClient").Observe(time.Since(start).Seconds()) + return m.s.DeleteTailnetClient(ctx, arg) +} + +func (m metricsStore) DeleteTailnetAgent(ctx context.Context, arg database.DeleteTailnetAgentParams) (database.DeleteTailnetAgentRow, error) { + start := time.Now() + defer m.queryLatencies.WithLabelValues("DeleteTailnetAgent").Observe(time.Since(start).Seconds()) + return m.s.DeleteTailnetAgent(ctx, arg) +} + +func (m metricsStore) DeleteCoordinator(ctx context.Context, id uuid.UUID) error { + start := time.Now() + defer m.queryLatencies.WithLabelValues("DeleteCoordinator").Observe(time.Since(start).Seconds()) + return m.s.DeleteCoordinator(ctx, id) +} + +func (m metricsStore) GetTailnetAgents(ctx context.Context, id uuid.UUID) ([]database.TailnetAgent, error) { + start := time.Now() + defer m.queryLatencies.WithLabelValues("GetTailnetAgents").Observe(time.Since(start).Seconds()) + return m.s.GetTailnetAgents(ctx, id) +} + +func (m metricsStore) GetTailnetClientsForAgent(ctx context.Context, agentID uuid.UUID) ([]database.TailnetClient, error) { + start := time.Now() + defer m.queryLatencies.WithLabelValues("GetTailnetClientsForAgent").Observe(time.Since(start).Seconds()) + return m.s.GetTailnetClientsForAgent(ctx, agentID) +} diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index a1c75353b7a96..d6a3a47e93e14 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -110,6 +110,20 @@ func (mr *MockStoreMockRecorder) DeleteApplicationConnectAPIKeysByUserID(arg0, a return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteApplicationConnectAPIKeysByUserID", reflect.TypeOf((*MockStore)(nil).DeleteApplicationConnectAPIKeysByUserID), arg0, arg1) } +// DeleteCoordinator mocks base method. +func (m *MockStore) DeleteCoordinator(arg0 context.Context, arg1 uuid.UUID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteCoordinator", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteCoordinator indicates an expected call of DeleteCoordinator. +func (mr *MockStoreMockRecorder) DeleteCoordinator(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCoordinator", reflect.TypeOf((*MockStore)(nil).DeleteCoordinator), arg0, arg1) +} + // DeleteGitSSHKey mocks base method. 
func (m *MockStore) DeleteGitSSHKey(arg0 context.Context, arg1 uuid.UUID) error { m.ctrl.T.Helper() @@ -223,6 +237,36 @@ func (mr *MockStoreMockRecorder) DeleteReplicasUpdatedBefore(arg0, arg1 interfac return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteReplicasUpdatedBefore", reflect.TypeOf((*MockStore)(nil).DeleteReplicasUpdatedBefore), arg0, arg1) } +// DeleteTailnetAgent mocks base method. +func (m *MockStore) DeleteTailnetAgent(arg0 context.Context, arg1 database.DeleteTailnetAgentParams) (database.DeleteTailnetAgentRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteTailnetAgent", arg0, arg1) + ret0, _ := ret[0].(database.DeleteTailnetAgentRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DeleteTailnetAgent indicates an expected call of DeleteTailnetAgent. +func (mr *MockStoreMockRecorder) DeleteTailnetAgent(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTailnetAgent", reflect.TypeOf((*MockStore)(nil).DeleteTailnetAgent), arg0, arg1) +} + +// DeleteTailnetClient mocks base method. +func (m *MockStore) DeleteTailnetClient(arg0 context.Context, arg1 database.DeleteTailnetClientParams) (database.DeleteTailnetClientRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteTailnetClient", arg0, arg1) + ret0, _ := ret[0].(database.DeleteTailnetClientRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DeleteTailnetClient indicates an expected call of DeleteTailnetClient. +func (mr *MockStoreMockRecorder) DeleteTailnetClient(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTailnetClient", reflect.TypeOf((*MockStore)(nil).DeleteTailnetClient), arg0, arg1) +} + // GetAPIKeyByID mocks base method. func (m *MockStore) GetAPIKeyByID(arg0 context.Context, arg1 string) (database.APIKey, error) { m.ctrl.T.Helper() @@ -1018,6 +1062,36 @@ func (mr *MockStoreMockRecorder) GetServiceBanner(arg0 interface{}) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceBanner", reflect.TypeOf((*MockStore)(nil).GetServiceBanner), arg0) } +// GetTailnetAgents mocks base method. +func (m *MockStore) GetTailnetAgents(arg0 context.Context, arg1 uuid.UUID) ([]database.TailnetAgent, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTailnetAgents", arg0, arg1) + ret0, _ := ret[0].([]database.TailnetAgent) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTailnetAgents indicates an expected call of GetTailnetAgents. +func (mr *MockStoreMockRecorder) GetTailnetAgents(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTailnetAgents", reflect.TypeOf((*MockStore)(nil).GetTailnetAgents), arg0, arg1) +} + +// GetTailnetClientsForAgent mocks base method. +func (m *MockStore) GetTailnetClientsForAgent(arg0 context.Context, arg1 uuid.UUID) ([]database.TailnetClient, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTailnetClientsForAgent", arg0, arg1) + ret0, _ := ret[0].([]database.TailnetClient) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTailnetClientsForAgent indicates an expected call of GetTailnetClientsForAgent. 
+func (mr *MockStoreMockRecorder) GetTailnetClientsForAgent(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTailnetClientsForAgent", reflect.TypeOf((*MockStore)(nil).GetTailnetClientsForAgent), arg0, arg1) +} + // GetTemplateAverageBuildTime mocks base method. func (m *MockStore) GetTemplateAverageBuildTime(arg0 context.Context, arg1 database.GetTemplateAverageBuildTimeParams) (database.GetTemplateAverageBuildTimeRow, error) { m.ctrl.T.Helper() @@ -3159,6 +3233,51 @@ func (mr *MockStoreMockRecorder) UpsertServiceBanner(arg0, arg1 interface{}) *go return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertServiceBanner", reflect.TypeOf((*MockStore)(nil).UpsertServiceBanner), arg0, arg1) } +// UpsertTailnetAgent mocks base method. +func (m *MockStore) UpsertTailnetAgent(arg0 context.Context, arg1 database.UpsertTailnetAgentParams) (database.TailnetAgent, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpsertTailnetAgent", arg0, arg1) + ret0, _ := ret[0].(database.TailnetAgent) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpsertTailnetAgent indicates an expected call of UpsertTailnetAgent. +func (mr *MockStoreMockRecorder) UpsertTailnetAgent(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertTailnetAgent", reflect.TypeOf((*MockStore)(nil).UpsertTailnetAgent), arg0, arg1) +} + +// UpsertTailnetClient mocks base method. +func (m *MockStore) UpsertTailnetClient(arg0 context.Context, arg1 database.UpsertTailnetClientParams) (database.TailnetClient, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpsertTailnetClient", arg0, arg1) + ret0, _ := ret[0].(database.TailnetClient) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpsertTailnetClient indicates an expected call of UpsertTailnetClient. +func (mr *MockStoreMockRecorder) UpsertTailnetClient(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertTailnetClient", reflect.TypeOf((*MockStore)(nil).UpsertTailnetClient), arg0, arg1) +} + +// UpsertTailnetCoordinator mocks base method. +func (m *MockStore) UpsertTailnetCoordinator(arg0 context.Context, arg1 uuid.UUID) (database.TailnetCoordinator, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpsertTailnetCoordinator", arg0, arg1) + ret0, _ := ret[0].(database.TailnetCoordinator) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpsertTailnetCoordinator indicates an expected call of UpsertTailnetCoordinator. +func (mr *MockStoreMockRecorder) UpsertTailnetCoordinator(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertTailnetCoordinator", reflect.TypeOf((*MockStore)(nil).UpsertTailnetCoordinator), arg0, arg1) +} + // Wrappers mocks base method. 
func (m *MockStore) Wrappers() []string { m.ctrl.T.Helper() diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index a4e796e1f60e2..e9f91c1a67d99 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -171,6 +171,45 @@ BEGIN END; $$; +CREATE FUNCTION notify_agent_change() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + IF (OLD IS NOT NULL) THEN + PERFORM pg_notify('tailnet_agent_update', OLD.id::text); + RETURN NULL; + END IF; + IF (NEW IS NOT NULL) THEN + PERFORM pg_notify('tailnet_agent_update', NEW.id::text); + RETURN NULL; + END IF; +END; +$$; + +CREATE FUNCTION notify_client_change() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + IF (OLD IS NOT NULL) THEN + PERFORM pg_notify('tailnet_client_update', OLD.id || ',' || OLD.agent_id); + RETURN NULL; + END IF; + IF (NEW IS NOT NULL) THEN + PERFORM pg_notify('tailnet_client_update', NEW.id || ',' || NEW.agent_id); + RETURN NULL; + END IF; +END; +$$; + +CREATE FUNCTION notify_coordinator_heartbeat() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + PERFORM pg_notify('tailnet_coordinator_heartbeat', NEW.id::text); + RETURN NULL; +END; +$$; + CREATE TABLE api_keys ( id text NOT NULL, hashed_secret bytea NOT NULL, @@ -383,6 +422,26 @@ CREATE TABLE site_configs ( value character varying(8192) NOT NULL ); +CREATE TABLE tailnet_agents ( + id uuid NOT NULL, + coordinator_id uuid NOT NULL, + updated_at timestamp with time zone NOT NULL, + node jsonb NOT NULL +); + +CREATE TABLE tailnet_clients ( + id uuid NOT NULL, + coordinator_id uuid NOT NULL, + agent_id uuid NOT NULL, + updated_at timestamp with time zone NOT NULL, + node jsonb NOT NULL +); + +CREATE TABLE tailnet_coordinators ( + id uuid NOT NULL, + heartbeat_at timestamp with time zone NOT NULL +); + CREATE TABLE template_version_parameters ( template_version_id uuid NOT NULL, name text NOT NULL, @@ -828,6 +887,15 @@ ALTER TABLE ONLY provisioner_jobs ALTER TABLE ONLY site_configs ADD CONSTRAINT site_configs_key_key UNIQUE (key); +ALTER TABLE ONLY tailnet_agents + ADD CONSTRAINT tailnet_agents_pkey PRIMARY KEY (id, coordinator_id); + +ALTER TABLE ONLY tailnet_clients + ADD CONSTRAINT tailnet_clients_pkey PRIMARY KEY (id, coordinator_id); + +ALTER TABLE ONLY tailnet_coordinators + ADD CONSTRAINT tailnet_coordinators_pkey PRIMARY KEY (id); + ALTER TABLE ONLY template_version_parameters ADD CONSTRAINT template_version_parameters_template_version_id_name_key UNIQUE (template_version_id, name); @@ -915,6 +983,12 @@ CREATE UNIQUE INDEX idx_organization_name ON organizations USING btree (name); CREATE UNIQUE INDEX idx_organization_name_lower ON organizations USING btree (lower(name)); +CREATE INDEX idx_tailnet_agents_coordinator ON tailnet_agents USING btree (coordinator_id); + +CREATE INDEX idx_tailnet_clients_agent ON tailnet_clients USING btree (agent_id); + +CREATE INDEX idx_tailnet_clients_coordinator ON tailnet_clients USING btree (coordinator_id); + CREATE UNIQUE INDEX idx_users_email ON users USING btree (email) WHERE (deleted = false); CREATE UNIQUE INDEX idx_users_username ON users USING btree (username) WHERE (deleted = false); @@ -941,6 +1015,12 @@ CREATE INDEX workspace_resources_job_id_idx ON workspace_resources USING btree ( CREATE UNIQUE INDEX workspaces_owner_id_lower_idx ON workspaces USING btree (owner_id, lower((name)::text)) WHERE (deleted = false); +CREATE TRIGGER notify_agent_change AFTER INSERT OR DELETE OR UPDATE ON tailnet_agents FOR EACH ROW EXECUTE FUNCTION notify_agent_change(); + +CREATE TRIGGER notify_client_change AFTER INSERT 
OR DELETE OR UPDATE ON tailnet_clients FOR EACH ROW EXECUTE FUNCTION notify_client_change(); + +CREATE TRIGGER notify_coordinator_heartbeat AFTER INSERT OR UPDATE ON tailnet_coordinators FOR EACH ROW EXECUTE FUNCTION notify_coordinator_heartbeat(); + CREATE TRIGGER trigger_insert_apikeys BEFORE INSERT ON api_keys FOR EACH ROW EXECUTE FUNCTION insert_apikey_fail_if_user_deleted(); CREATE TRIGGER trigger_update_users AFTER INSERT OR UPDATE ON users FOR EACH ROW WHEN ((new.deleted = true)) EXECUTE FUNCTION delete_deleted_user_api_keys(); @@ -975,6 +1055,12 @@ ALTER TABLE ONLY provisioner_job_logs ALTER TABLE ONLY provisioner_jobs ADD CONSTRAINT provisioner_jobs_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE; +ALTER TABLE ONLY tailnet_agents + ADD CONSTRAINT tailnet_agents_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE; + +ALTER TABLE ONLY tailnet_clients + ADD CONSTRAINT tailnet_clients_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE; + ALTER TABLE ONLY template_version_parameters ADD CONSTRAINT template_version_parameters_template_version_id_fkey FOREIGN KEY (template_version_id) REFERENCES template_versions(id) ON DELETE CASCADE; diff --git a/coderd/database/migrations/000125_ha_coordinator.down.sql b/coderd/database/migrations/000125_ha_coordinator.down.sql new file mode 100644 index 0000000000000..5bf7b888b347b --- /dev/null +++ b/coderd/database/migrations/000125_ha_coordinator.down.sql @@ -0,0 +1,18 @@ +BEGIN; + +DROP TRIGGER IF EXISTS notify_client_change ON tailnet_clients; +DROP FUNCTION IF EXISTS notify_client_change; +DROP INDEX IF EXISTS idx_tailnet_clients_agent; +DROP INDEX IF EXISTS idx_tailnet_clients_coordinator; +DROP TABLE tailnet_clients; + +DROP TRIGGER IF EXISTS notify_agent_change ON tailnet_agents; +DROP FUNCTION IF EXISTS notify_agent_change; +DROP INDEX IF EXISTS idx_tailnet_agents_coordinator; +DROP TABLE IF EXISTS tailnet_agents; + +DROP TRIGGER IF EXISTS notify_coordinator_heartbeat ON tailnet_coordinators; +DROP FUNCTION IF EXISTS notify_coordinator_heartbeat; +DROP TABLE IF EXISTS tailnet_coordinators; + +COMMIT; diff --git a/coderd/database/migrations/000125_ha_coordinator.up.sql b/coderd/database/migrations/000125_ha_coordinator.up.sql new file mode 100644 index 0000000000000..3b1431e173840 --- /dev/null +++ b/coderd/database/migrations/000125_ha_coordinator.up.sql @@ -0,0 +1,94 @@ +BEGIN; + +CREATE TABLE tailnet_coordinators ( + id uuid NOT NULL PRIMARY KEY, + heartbeat_at timestamp with time zone NOT NULL +); + +CREATE TABLE tailnet_clients ( + id uuid NOT NULL, + coordinator_id uuid NOT NULL, + agent_id uuid NOT NULL, + updated_at timestamp with time zone NOT NULL, + node jsonb NOT NULL, + PRIMARY KEY (id, coordinator_id), + FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE +); + + +-- For querying/deleting mappings +CREATE INDEX idx_tailnet_clients_agent ON tailnet_clients (agent_id); + +-- For shutting down / GC a coordinator +CREATE INDEX idx_tailnet_clients_coordinator ON tailnet_clients (coordinator_id); + +CREATE TABLE tailnet_agents ( + id uuid NOT NULL, + coordinator_id uuid NOT NULL, + updated_at timestamp with time zone NOT NULL, + node jsonb NOT NULL, + PRIMARY KEY (id, coordinator_id), + FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE +); + +-- For shutting down / GC a coordinator +CREATE INDEX 
idx_tailnet_agents_coordinator ON tailnet_agents (coordinator_id); + +-- Any time the tailnet_clients table changes, send an update with the affected client and agent IDs +CREATE FUNCTION notify_client_change() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + IF (OLD IS NOT NULL) THEN + PERFORM pg_notify('tailnet_client_update', OLD.id || ',' || OLD.agent_id); + RETURN NULL; + END IF; + IF (NEW IS NOT NULL) THEN + PERFORM pg_notify('tailnet_client_update', NEW.id || ',' || NEW.agent_id); + RETURN NULL; + END IF; +END; +$$; + +CREATE TRIGGER notify_client_change + AFTER INSERT OR UPDATE OR DELETE ON tailnet_clients + FOR EACH ROW +EXECUTE PROCEDURE notify_client_change(); + +-- Any time tailnet_agents table changes, send an update with the affected agent ID. +CREATE FUNCTION notify_agent_change() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + IF (OLD IS NOT NULL) THEN + PERFORM pg_notify('tailnet_agent_update', OLD.id::text); + RETURN NULL; + END IF; + IF (NEW IS NOT NULL) THEN + PERFORM pg_notify('tailnet_agent_update', NEW.id::text); + RETURN NULL; + END IF; +END; +$$; + +CREATE TRIGGER notify_agent_change + AFTER INSERT OR UPDATE OR DELETE ON tailnet_agents + FOR EACH ROW +EXECUTE PROCEDURE notify_agent_change(); + +-- Send coordinator heartbeats +CREATE FUNCTION notify_coordinator_heartbeat() RETURNS trigger + LANGUAGE plpgsql +AS $$ +BEGIN + PERFORM pg_notify('tailnet_coordinator_heartbeat', NEW.id::text); + RETURN NULL; +END; +$$; + +CREATE TRIGGER notify_coordinator_heartbeat + AFTER INSERT OR UPDATE ON tailnet_coordinators + FOR EACH ROW +EXECUTE PROCEDURE notify_coordinator_heartbeat(); + +COMMIT; diff --git a/coderd/database/models.go b/coderd/database/models.go index e7a0e3b1ee38d..d23ff149eb873 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -1534,6 +1534,26 @@ type SiteConfig struct { Value string `db:"value" json:"value"` } +type TailnetAgent struct { + ID uuid.UUID `db:"id" json:"id"` + CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + Node json.RawMessage `db:"node" json:"node"` +} + +type TailnetClient struct { + ID uuid.UUID `db:"id" json:"id"` + CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` + AgentID uuid.UUID `db:"agent_id" json:"agent_id"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + Node json.RawMessage `db:"node" json:"node"` +} + +type TailnetCoordinator struct { + ID uuid.UUID `db:"id" json:"id"` + HeartbeatAt time.Time `db:"heartbeat_at" json:"heartbeat_at"` +} + type Template struct { ID uuid.UUID `db:"id" json:"id"` CreatedAt time.Time `db:"created_at" json:"created_at"` diff --git a/coderd/database/querier.go b/coderd/database/querier.go index c427ac768c79f..890d8d1789105 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -29,6 +29,7 @@ type sqlcQuerier interface { DeleteAPIKeyByID(ctx context.Context, id string) error DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error + DeleteCoordinator(ctx context.Context, id uuid.UUID) error DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error DeleteGroupByID(ctx context.Context, id uuid.UUID) error DeleteGroupMemberFromGroup(ctx context.Context, arg DeleteGroupMemberFromGroupParams) error @@ -39,6 +40,8 @@ type sqlcQuerier interface { DeleteOldWorkspaceAgentStartupLogs(ctx context.Context) error DeleteOldWorkspaceAgentStats(ctx context.Context) 
error DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt time.Time) error + DeleteTailnetAgent(ctx context.Context, arg DeleteTailnetAgentParams) (DeleteTailnetAgentRow, error) + DeleteTailnetClient(ctx context.Context, arg DeleteTailnetClientParams) (DeleteTailnetClientRow, error) GetAPIKeyByID(ctx context.Context, id string) (APIKey, error) // there is no unique constraint on empty token names GetAPIKeyByName(ctx context.Context, arg GetAPIKeyByNameParams) (APIKey, error) @@ -96,6 +99,8 @@ type sqlcQuerier interface { GetQuotaConsumedForUser(ctx context.Context, ownerID uuid.UUID) (int64, error) GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]Replica, error) GetServiceBanner(ctx context.Context) (string, error) + GetTailnetAgents(ctx context.Context, id uuid.UUID) ([]TailnetAgent, error) + GetTailnetClientsForAgent(ctx context.Context, agentID uuid.UUID) ([]TailnetClient, error) GetTemplateAverageBuildTime(ctx context.Context, arg GetTemplateAverageBuildTimeParams) (GetTemplateAverageBuildTimeRow, error) GetTemplateByID(ctx context.Context, id uuid.UUID) (Template, error) GetTemplateByOrganizationAndName(ctx context.Context, arg GetTemplateByOrganizationAndNameParams) (Template, error) @@ -262,6 +267,9 @@ type sqlcQuerier interface { UpsertLastUpdateCheck(ctx context.Context, value string) error UpsertLogoURL(ctx context.Context, value string) error UpsertServiceBanner(ctx context.Context, value string) error + UpsertTailnetAgent(ctx context.Context, arg UpsertTailnetAgentParams) (TailnetAgent, error) + UpsertTailnetClient(ctx context.Context, arg UpsertTailnetClientParams) (TailnetClient, error) + UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (TailnetCoordinator, error) } var _ sqlcQuerier = (*sqlQuerier)(nil) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 2c71a96ccc4e0..6fbe1bf5d74ef 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -3178,6 +3178,239 @@ func (q *sqlQuerier) UpsertServiceBanner(ctx context.Context, value string) erro return err } +const deleteCoordinator = `-- name: DeleteCoordinator :exec +DELETE +FROM tailnet_coordinators +WHERE id = $1 +` + +func (q *sqlQuerier) DeleteCoordinator(ctx context.Context, id uuid.UUID) error { + _, err := q.db.ExecContext(ctx, deleteCoordinator, id) + return err +} + +const deleteTailnetAgent = `-- name: DeleteTailnetAgent :one +DELETE +FROM tailnet_agents +WHERE id = $1 and coordinator_id = $2 +RETURNING id, coordinator_id +` + +type DeleteTailnetAgentParams struct { + ID uuid.UUID `db:"id" json:"id"` + CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` +} + +type DeleteTailnetAgentRow struct { + ID uuid.UUID `db:"id" json:"id"` + CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` +} + +func (q *sqlQuerier) DeleteTailnetAgent(ctx context.Context, arg DeleteTailnetAgentParams) (DeleteTailnetAgentRow, error) { + row := q.db.QueryRowContext(ctx, deleteTailnetAgent, arg.ID, arg.CoordinatorID) + var i DeleteTailnetAgentRow + err := row.Scan(&i.ID, &i.CoordinatorID) + return i, err +} + +const deleteTailnetClient = `-- name: DeleteTailnetClient :one +DELETE +FROM tailnet_clients +WHERE id = $1 and coordinator_id = $2 +RETURNING id, coordinator_id +` + +type DeleteTailnetClientParams struct { + ID uuid.UUID `db:"id" json:"id"` + CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` +} + +type DeleteTailnetClientRow struct { + ID uuid.UUID `db:"id" json:"id"` + CoordinatorID uuid.UUID 
`db:"coordinator_id" json:"coordinator_id"` +} + +func (q *sqlQuerier) DeleteTailnetClient(ctx context.Context, arg DeleteTailnetClientParams) (DeleteTailnetClientRow, error) { + row := q.db.QueryRowContext(ctx, deleteTailnetClient, arg.ID, arg.CoordinatorID) + var i DeleteTailnetClientRow + err := row.Scan(&i.ID, &i.CoordinatorID) + return i, err +} + +const getTailnetAgents = `-- name: GetTailnetAgents :many +SELECT id, coordinator_id, updated_at, node +FROM tailnet_agents +WHERE id = $1 +` + +func (q *sqlQuerier) GetTailnetAgents(ctx context.Context, id uuid.UUID) ([]TailnetAgent, error) { + rows, err := q.db.QueryContext(ctx, getTailnetAgents, id) + if err != nil { + return nil, err + } + defer rows.Close() + var items []TailnetAgent + for rows.Next() { + var i TailnetAgent + if err := rows.Scan( + &i.ID, + &i.CoordinatorID, + &i.UpdatedAt, + &i.Node, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getTailnetClientsForAgent = `-- name: GetTailnetClientsForAgent :many +SELECT id, coordinator_id, agent_id, updated_at, node +FROM tailnet_clients +WHERE agent_id = $1 +` + +func (q *sqlQuerier) GetTailnetClientsForAgent(ctx context.Context, agentID uuid.UUID) ([]TailnetClient, error) { + rows, err := q.db.QueryContext(ctx, getTailnetClientsForAgent, agentID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []TailnetClient + for rows.Next() { + var i TailnetClient + if err := rows.Scan( + &i.ID, + &i.CoordinatorID, + &i.AgentID, + &i.UpdatedAt, + &i.Node, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const upsertTailnetAgent = `-- name: UpsertTailnetAgent :one +INSERT INTO + tailnet_agents ( + id, + coordinator_id, + node, + updated_at +) +VALUES + ($1, $2, $3, now() at time zone 'utc') +ON CONFLICT (id, coordinator_id) +DO UPDATE SET + id = $1, + coordinator_id = $2, + node = $3, + updated_at = now() at time zone 'utc' +RETURNING id, coordinator_id, updated_at, node +` + +type UpsertTailnetAgentParams struct { + ID uuid.UUID `db:"id" json:"id"` + CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` + Node json.RawMessage `db:"node" json:"node"` +} + +func (q *sqlQuerier) UpsertTailnetAgent(ctx context.Context, arg UpsertTailnetAgentParams) (TailnetAgent, error) { + row := q.db.QueryRowContext(ctx, upsertTailnetAgent, arg.ID, arg.CoordinatorID, arg.Node) + var i TailnetAgent + err := row.Scan( + &i.ID, + &i.CoordinatorID, + &i.UpdatedAt, + &i.Node, + ) + return i, err +} + +const upsertTailnetClient = `-- name: UpsertTailnetClient :one +INSERT INTO + tailnet_clients ( + id, + coordinator_id, + agent_id, + node, + updated_at +) +VALUES + ($1, $2, $3, $4, now() at time zone 'utc') +ON CONFLICT (id, coordinator_id) +DO UPDATE SET + id = $1, + coordinator_id = $2, + agent_id = $3, + node = $4, + updated_at = now() at time zone 'utc' +RETURNING id, coordinator_id, agent_id, updated_at, node +` + +type UpsertTailnetClientParams struct { + ID uuid.UUID `db:"id" json:"id"` + CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` + AgentID uuid.UUID `db:"agent_id" json:"agent_id"` + Node json.RawMessage `db:"node" json:"node"` +} + +func (q *sqlQuerier) UpsertTailnetClient(ctx context.Context, arg 
UpsertTailnetClientParams) (TailnetClient, error) { + row := q.db.QueryRowContext(ctx, upsertTailnetClient, + arg.ID, + arg.CoordinatorID, + arg.AgentID, + arg.Node, + ) + var i TailnetClient + err := row.Scan( + &i.ID, + &i.CoordinatorID, + &i.AgentID, + &i.UpdatedAt, + &i.Node, + ) + return i, err +} + +const upsertTailnetCoordinator = `-- name: UpsertTailnetCoordinator :one +INSERT INTO + tailnet_coordinators ( + id, + heartbeat_at +) +VALUES + ($1, now() at time zone 'utc') +ON CONFLICT (id) +DO UPDATE SET + id = $1, + heartbeat_at = now() at time zone 'utc' +RETURNING id, heartbeat_at +` + +func (q *sqlQuerier) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (TailnetCoordinator, error) { + row := q.db.QueryRowContext(ctx, upsertTailnetCoordinator, id) + var i TailnetCoordinator + err := row.Scan(&i.ID, &i.HeartbeatAt) + return i, err +} + const getTemplateAverageBuildTime = `-- name: GetTemplateAverageBuildTime :one WITH build_times AS ( SELECT diff --git a/coderd/database/queries/tailnet.sql b/coderd/database/queries/tailnet.sql new file mode 100644 index 0000000000000..e45cb480b1d01 --- /dev/null +++ b/coderd/database/queries/tailnet.sql @@ -0,0 +1,79 @@ +-- name: UpsertTailnetClient :one +INSERT INTO + tailnet_clients ( + id, + coordinator_id, + agent_id, + node, + updated_at +) +VALUES + ($1, $2, $3, $4, now() at time zone 'utc') +ON CONFLICT (id, coordinator_id) +DO UPDATE SET + id = $1, + coordinator_id = $2, + agent_id = $3, + node = $4, + updated_at = now() at time zone 'utc' +RETURNING *; + +-- name: UpsertTailnetAgent :one +INSERT INTO + tailnet_agents ( + id, + coordinator_id, + node, + updated_at +) +VALUES + ($1, $2, $3, now() at time zone 'utc') +ON CONFLICT (id, coordinator_id) +DO UPDATE SET + id = $1, + coordinator_id = $2, + node = $3, + updated_at = now() at time zone 'utc' +RETURNING *; + + +-- name: DeleteTailnetClient :one +DELETE +FROM tailnet_clients +WHERE id = $1 and coordinator_id = $2 +RETURNING id, coordinator_id; + +-- name: DeleteTailnetAgent :one +DELETE +FROM tailnet_agents +WHERE id = $1 and coordinator_id = $2 +RETURNING id, coordinator_id; + +-- name: DeleteCoordinator :exec +DELETE +FROM tailnet_coordinators +WHERE id = $1; + +-- name: GetTailnetAgents :many +SELECT * +FROM tailnet_agents +WHERE id = $1; + +-- name: GetTailnetClientsForAgent :many +SELECT * +FROM tailnet_clients +WHERE agent_id = $1; + +-- name: UpsertTailnetCoordinator :one +INSERT INTO + tailnet_coordinators ( + id, + heartbeat_at +) +VALUES + ($1, now() at time zone 'utc') +ON CONFLICT (id) +DO UPDATE SET + id = $1, + heartbeat_at = now() at time zone 'utc' +RETURNING *; diff --git a/coderd/rbac/object.go b/coderd/rbac/object.go index e867abfb69685..060f325c607a1 100644 --- a/coderd/rbac/object.go +++ b/coderd/rbac/object.go @@ -173,6 +173,11 @@ var ( ResourceSystem = Object{ Type: "system", } + + // ResourceTailnetCoordinator is a pseudo-resource for use by the tailnet coordinator + ResourceTailnetCoordinator = Object{ + Type: "tailnet_coordinator", + } ) // Object is used to create objects for authz checks when you have none in diff --git a/coderd/rbac/object_gen.go b/coderd/rbac/object_gen.go index 9af80010cf753..d0a7bb5e68193 100644 --- a/coderd/rbac/object_gen.go +++ b/coderd/rbac/object_gen.go @@ -18,6 +18,7 @@ func AllResources() []Object { ResourceReplicas, ResourceRoleAssignment, ResourceSystem, + ResourceTailnetCoordinator, ResourceTemplate, ResourceUser, ResourceUserData, diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go new file 
mode 100644 index 0000000000000..a1a14bd993bc8 --- /dev/null +++ b/enterprise/tailnet/pgcoord.go @@ -0,0 +1,1208 @@ +package tailnet + +import ( + "context" + "database/sql" + "encoding/json" + "io" + "net" + "net/http" + "strings" + "sync" + "time" + + "github.com/cenkalti/backoff/v4" + "github.com/google/uuid" + "golang.org/x/xerrors" + + "cdr.dev/slog" + + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/pubsub" + agpl "github.com/coder/coder/tailnet" +) + +const ( + EventHeartbeats = "tailnet_coordinator_heartbeat" + eventClientUpdate = "tailnet_client_update" + eventAgentUpdate = "tailnet_agent_update" + HeartbeatPeriod = time.Second * 2 + MissedHeartbeats = 3 + numQuerierWorkers = 10 + numBinderWorkers = 10 + dbMaxBackoff = 10 * time.Second +) + +// pgCoord is a postgres-backed coordinator +// +// ┌────────┐ ┌────────┐ ┌───────┐ +// │ connIO ├───────► binder ├────────► store │ +// └───▲────┘ │ │ │ │ +// │ └────────┘ ┌──────┤ │ +// │ │ └───────┘ +// │ │ +// │ ┌──────────▼┐ ┌────────┐ +// │ │ │ │ │ +// └────────────┤ querier ◄─────┤ pubsub │ +// │ │ │ │ +// └───────────┘ └────────┘ +// +// each incoming connection (websocket) from a client or agent is wrapped in a connIO which handles reading & writing +// from it. Node updates from a connIO are sent to the binder, which writes them to the database.Store. The querier +// is responsible for querying the store for the nodes the connection needs (e.g. for a client, the corresponding +// agent). The querier receives pubsub notifications about changes, which trigger queries for the latest state. +// +// The querier also sends the coordinator's heartbeat, and monitors the heartbeats of other coordinators. When +// heartbeats cease for a coordinator, it stops using any nodes discovered from that coordinator and pushes an update +// to affected connIOs. +// +// This package uses the term "binding" to mean the act of registering an association between some connection (client +// or agent) and an agpl.Node. It uses the term "mapping" to mean the act of determining the nodes that the connection +// needs to receive (e.g. for a client, the node bound to the corresponding agent, or for an agent, the nodes bound to +// all clients of the agent). +type pgCoord struct { + ctx context.Context + logger slog.Logger + pubsub pubsub.Pubsub + store database.Store + + bindings chan binding + newConnections chan *connIO + id uuid.UUID + + cancel context.CancelFunc + closeOnce sync.Once + closed chan struct{} + + binder *binder + querier *querier +} + +// NewPGCoord creates a high-availability coordinator that stores state in the PostgreSQL database and +// receives notifications of updates via the pubsub. +func NewPGCoord(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, store database.Store) (agpl.Coordinator, error) { + ctx, cancel := context.WithCancel(ctx) + id := uuid.New() + logger = logger.Named("pgcoord").With(slog.F("coordinator_id", id)) + bCh := make(chan binding) + cCh := make(chan *connIO) + // signals when first heartbeat has been sent, so it's safe to start binding. 
+ fHB := make(chan struct{}) + + c := &pgCoord{ + ctx: ctx, + cancel: cancel, + logger: logger, + pubsub: ps, + store: store, + binder: newBinder(ctx, logger, id, store, bCh, fHB), + bindings: bCh, + newConnections: cCh, + id: id, + querier: newQuerier(ctx, logger, ps, store, id, cCh, numQuerierWorkers, fHB), + closed: make(chan struct{}), + } + return c, nil +} + +func (*pgCoord) ServeHTTPDebug(w http.ResponseWriter, _ *http.Request) { + // TODO(spikecurtis) I'd like to hold off implementing this until after the rest of this is code reviewed. + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("Coder Enterprise PostgreSQL distributed tailnet coordinator")) +} + +func (c *pgCoord) Node(id uuid.UUID) *agpl.Node { + // In production, we only ever get this request for an agent. + // We're going to directly query the database, since we would only have the agent mapping stored locally if we had + // a client of that agent connected, which isn't always the case. + mappings, err := c.querier.queryAgent(id) + if err != nil { + c.logger.Error(c.ctx, "failed to query agents", slog.Error(err)) + } + mappings = c.querier.heartbeats.filter(mappings) + var bestT time.Time + var bestN *agpl.Node + for _, m := range mappings { + if m.updatedAt.After(bestT) { + bestN = m.node + bestT = m.updatedAt + } + } + return bestN +} + +func (c *pgCoord) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) (retErr error) { + defer func() { + err := conn.Close() + if err != nil { + c.logger.Debug(c.ctx, "closing client connection", + slog.F("client_id", id), + slog.F("agent_id", agent), + slog.Error(err)) + } + }() + cIO := newConnIO(c.ctx, c.logger, c.bindings, conn, id, agent) + if err := sendCtx(c.ctx, c.newConnections, cIO); err != nil { + // can only be a context error, no need to log here. + return err + } + <-cIO.ctx.Done() + return nil +} + +func (c *pgCoord) ServeAgent(conn net.Conn, id uuid.UUID, name string) (retErr error) { + defer func() { + err := conn.Close() + if err != nil { + c.logger.Debug(c.ctx, "closing agent connection", + slog.F("agent_id", id), + slog.Error(err)) + } + }() + logger := c.logger.With(slog.F("name", name)) + cIO := newConnIO(c.ctx, logger, c.bindings, conn, uuid.Nil, id) + if err := sendCtx(c.ctx, c.newConnections, cIO); err != nil { + // can only be a context error, no need to log here. + return err + } + <-cIO.ctx.Done() + return nil +} + +func (c *pgCoord) Close() error { + c.cancel() + // do we need to wait for the binder to complete? + c.closeOnce.Do(func() { close(c.closed) }) + return nil +} + +// connIO manages the reading and writing to a connected client or agent. Agent connIOs have their client field set to +// uuid.Nil. It reads node updates via its decoder, then pushes them onto the bindings channel. It receives mappings +// via its updates TrackedConn, which then writes them. 
+type connIO struct { + pCtx context.Context + ctx context.Context + cancel context.CancelFunc + logger slog.Logger + client uuid.UUID + agent uuid.UUID + decoder *json.Decoder + updates *agpl.TrackedConn + bindings chan<- binding +} + +func newConnIO( + pCtx context.Context, logger slog.Logger, bindings chan<- binding, conn net.Conn, client, agent uuid.UUID, +) *connIO { + ctx, cancel := context.WithCancel(pCtx) + id := agent + logger = logger.With(slog.F("agent_id", agent)) + if client != uuid.Nil { + logger = logger.With(slog.F("client_id", client)) + id = client + } + c := &connIO{ + pCtx: pCtx, + ctx: ctx, + cancel: cancel, + logger: logger, + client: client, + agent: agent, + decoder: json.NewDecoder(conn), + updates: agpl.NewTrackedConn(ctx, cancel, conn, id, logger, 0), + bindings: bindings, + } + go c.recvLoop() + go c.updates.SendUpdates() + logger.Info(ctx, "serving connection") + return c +} + +func (c *connIO) recvLoop() { + defer func() { + // withdraw bindings when we exit. We need to use the parent context here, since our own context might be + // canceled, but we still need to withdraw bindings. + b := binding{ + bKey: bKey{ + client: c.client, + agent: c.agent, + }, + } + if err := sendCtx(c.pCtx, c.bindings, b); err != nil { + c.logger.Debug(c.ctx, "parent context expired while withdrawing bindings", slog.Error(err)) + } + }() + defer c.cancel() + for { + var node agpl.Node + err := c.decoder.Decode(&node) + if err != nil { + if xerrors.Is(err, io.EOF) || xerrors.Is(err, io.ErrClosedPipe) || xerrors.Is(err, context.Canceled) { + c.logger.Debug(c.ctx, "exiting recvLoop", slog.Error(err)) + } else { + c.logger.Error(c.ctx, "failed to decode Node update", slog.Error(err)) + } + return + } + c.logger.Debug(c.ctx, "got node update", slog.F("node", node)) + b := binding{ + bKey: bKey{ + client: c.client, + agent: c.agent, + }, + node: &node, + } + if err := sendCtx(c.ctx, c.bindings, b); err != nil { + c.logger.Debug(c.ctx, "recvLoop ctx expired", slog.Error(err)) + return + } + } +} + +func sendCtx[A any](ctx context.Context, c chan<- A, a A) (err error) { + select { + case <-ctx.Done(): + return ctx.Err() + case c <- a: + return nil + } +} + +// bKey, or "binding key" identifies a client or agent in a binding. Agents have their client field set to uuid.Nil. +type bKey struct { + client uuid.UUID + agent uuid.UUID +} + +// binding represents an association between a client or agent and a Node. +type binding struct { + bKey + node *agpl.Node +} + +// binder reads node bindings from the channel and writes them to the database. It handles retries with a backoff. 
+type binder struct { + ctx context.Context + logger slog.Logger + coordinatorID uuid.UUID + store database.Store + bindings <-chan binding + + mu sync.Mutex + latest map[bKey]binding + workQ *workQ[bKey] +} + +func newBinder(ctx context.Context, logger slog.Logger, + id uuid.UUID, store database.Store, + bindings <-chan binding, startWorkers <-chan struct{}, +) *binder { + b := &binder{ + ctx: ctx, + logger: logger, + coordinatorID: id, + store: store, + bindings: bindings, + latest: make(map[bKey]binding), + workQ: newWorkQ[bKey](ctx), + } + go b.handleBindings() + go func() { + <-startWorkers + for i := 0; i < numBinderWorkers; i++ { + go b.worker() + } + }() + return b +} + +func (b *binder) handleBindings() { + for { + select { + case <-b.ctx.Done(): + b.logger.Debug(b.ctx, "binder exiting", slog.Error(b.ctx.Err())) + return + case bnd := <-b.bindings: + b.storeBinding(bnd) + b.workQ.enqueue(bnd.bKey) + } + } +} + +func (b *binder) worker() { + eb := backoff.NewExponentialBackOff() + eb.MaxElapsedTime = 0 // retry indefinitely + eb.MaxInterval = dbMaxBackoff + bkoff := backoff.WithContext(eb, b.ctx) + for { + bk, err := b.workQ.acquire() + if err != nil { + // context expired + return + } + err = backoff.Retry(func() error { + bnd := b.retrieveBinding(bk) + return b.writeOne(bnd) + }, bkoff) + if err != nil { + bkoff.Reset() + } + b.workQ.done(bk) + } +} + +func (b *binder) writeOne(bnd binding) error { + var nodeRaw json.RawMessage + var err error + if bnd.node != nil { + nodeRaw, err = json.Marshal(*bnd.node) + if err != nil { + // this is very bad news, but it should never happen because the node was Unmarshalled by this process + // earlier. + b.logger.Error(b.ctx, "failed to marshall node", slog.Error(err)) + return err + } + } + + switch { + case bnd.client == uuid.Nil && len(nodeRaw) > 0: + _, err = b.store.UpsertTailnetAgent(b.ctx, database.UpsertTailnetAgentParams{ + ID: bnd.agent, + CoordinatorID: b.coordinatorID, + Node: nodeRaw, + }) + case bnd.client == uuid.Nil && len(nodeRaw) == 0: + _, err = b.store.DeleteTailnetAgent(b.ctx, database.DeleteTailnetAgentParams{ + ID: bnd.agent, + CoordinatorID: b.coordinatorID, + }) + if xerrors.Is(err, sql.ErrNoRows) { + // treat deletes as idempotent + err = nil + } + case bnd.client != uuid.Nil && len(nodeRaw) > 0: + _, err = b.store.UpsertTailnetClient(b.ctx, database.UpsertTailnetClientParams{ + ID: bnd.client, + CoordinatorID: b.coordinatorID, + AgentID: bnd.agent, + Node: nodeRaw, + }) + case bnd.client != uuid.Nil && len(nodeRaw) == 0: + _, err = b.store.DeleteTailnetClient(b.ctx, database.DeleteTailnetClientParams{ + ID: bnd.client, + CoordinatorID: b.coordinatorID, + }) + if xerrors.Is(err, sql.ErrNoRows) { + // treat deletes as idempotent + err = nil + } + default: + panic("unhittable") + } + if err != nil { + b.logger.Error(b.ctx, "failed to write binding to database", + slog.F("client_id", bnd.client), + slog.F("agent_id", bnd.agent), + slog.F("node", string(nodeRaw)), + slog.Error(err)) + } + return err +} + +// storeBinding stores the latest binding, where we interpret node == nil as removing the binding. This keeps the map +// from growing without bound. +func (b *binder) storeBinding(bnd binding) { + b.mu.Lock() + defer b.mu.Unlock() + if bnd.node != nil { + b.latest[bnd.bKey] = bnd + } else { + // nil node is interpreted as removing binding + delete(b.latest, bnd.bKey) + } +} + +// retrieveBinding gets the latest binding for a key. 
+func (b *binder) retrieveBinding(bk bKey) binding { + b.mu.Lock() + defer b.mu.Unlock() + bnd, ok := b.latest[bk] + if !ok { + bnd = binding{ + bKey: bk, + node: nil, + } + } + return bnd +} + +// mapper tracks a single client or agent ID, and fans out updates to that ID->node mapping to every local connection +// that needs it. +type mapper struct { + ctx context.Context + logger slog.Logger + + add chan *connIO + del chan *connIO + + // reads from this channel trigger sending latest nodes to + // all connections. It is used when coordinators are added + // or removed + update chan struct{} + + mappings chan []mapping + + conns map[bKey]*connIO + latest []mapping + + heartbeats *heartbeats +} + +func newMapper(ctx context.Context, logger slog.Logger, mk mKey, h *heartbeats) *mapper { + logger = logger.With( + slog.F("agent_id", mk.agent), + slog.F("clients_of_agent", mk.clientsOfAgent), + ) + m := &mapper{ + ctx: ctx, + logger: logger, + add: make(chan *connIO), + del: make(chan *connIO), + update: make(chan struct{}), + conns: make(map[bKey]*connIO), + mappings: make(chan []mapping), + heartbeats: h, + } + go m.run() + return m +} + +func (m *mapper) run() { + for { + select { + case <-m.ctx.Done(): + return + case c := <-m.add: + m.conns[bKey{c.client, c.agent}] = c + nodes := m.mappingsToNodes(m.latest) + if len(nodes) == 0 { + m.logger.Debug(m.ctx, "skipping 0 length node update") + continue + } + if err := c.updates.Enqueue(nodes); err != nil { + m.logger.Error(m.ctx, "failed to enqueue node update", slog.Error(err)) + } + case c := <-m.del: + delete(m.conns, bKey{c.client, c.agent}) + case mappings := <-m.mappings: + m.latest = mappings + nodes := m.mappingsToNodes(mappings) + if len(nodes) == 0 { + m.logger.Debug(m.ctx, "skipping 0 length node update") + continue + } + for _, conn := range m.conns { + if err := conn.updates.Enqueue(nodes); err != nil { + m.logger.Error(m.ctx, "failed to enqueue node update", slog.Error(err)) + } + } + case <-m.update: + nodes := m.mappingsToNodes(m.latest) + if len(nodes) == 0 { + m.logger.Debug(m.ctx, "skipping 0 length node update") + continue + } + for _, conn := range m.conns { + if err := conn.updates.Enqueue(nodes); err != nil { + m.logger.Error(m.ctx, "failed to enqueue triggered node update", slog.Error(err)) + } + } + } + } +} + +// mappingsToNodes takes a set of mappings and resolves the best set of nodes. We may get several mappings for a +// particular connection, from different coordinators in the distributed system. Furthermore, some coordinators +// might be considered invalid on account of missing heartbeats. We take the most recent mapping from a valid +// coordinator as the "best" mapping. +func (m *mapper) mappingsToNodes(mappings []mapping) []*agpl.Node { + mappings = m.heartbeats.filter(mappings) + best := make(map[bKey]mapping, len(mappings)) + for _, m := range mappings { + bk := bKey{client: m.client, agent: m.agent} + bestM, ok := best[bk] + if !ok || m.updatedAt.After(bestM.updatedAt) { + best[bk] = m + } + } + nodes := make([]*agpl.Node, 0, len(best)) + for _, m := range best { + nodes = append(nodes, m.node) + } + return nodes +} + +// querier is responsible for monitoring pubsub notifications and querying the database for the mappings that all +// connnected clients and agents need. It also checks heartbeats and withdraws mappings from coordinators that have +// failed heartbeats. 
+type querier struct { + ctx context.Context + logger slog.Logger + pubsub pubsub.Pubsub + store database.Store + newConnections chan *connIO + + workQ *workQ[mKey] + heartbeats *heartbeats + updates <-chan struct{} + + mu sync.Mutex + mappers map[mKey]*countedMapper +} + +type countedMapper struct { + *mapper + count int + cancel context.CancelFunc +} + +func newQuerier( + ctx context.Context, logger slog.Logger, + ps pubsub.Pubsub, store database.Store, + self uuid.UUID, newConnections chan *connIO, numWorkers int, + firstHeartbeat chan<- struct{}, +) *querier { + updates := make(chan struct{}) + q := &querier{ + ctx: ctx, + logger: logger.Named("querier"), + pubsub: ps, + store: store, + newConnections: newConnections, + workQ: newWorkQ[mKey](ctx), + heartbeats: newHeartbeats(ctx, logger, ps, store, self, updates, firstHeartbeat), + mappers: make(map[mKey]*countedMapper), + updates: updates, + } + go q.subscribe() + go q.handleConnIO() + for i := 0; i < numWorkers; i++ { + go q.worker() + } + go q.handleUpdates() + return q +} + +func (q *querier) handleConnIO() { + for { + select { + case <-q.ctx.Done(): + return + case c := <-q.newConnections: + q.newConn(c) + } + } +} + +func (q *querier) newConn(c *connIO) { + q.mu.Lock() + defer q.mu.Unlock() + mk := mKey{ + agent: c.agent, + // if client is Nil, this is an agent connection, and it wants the mappings for all the clients of itself + clientsOfAgent: c.client == uuid.Nil, + } + cm, ok := q.mappers[mk] + if !ok { + ctx, cancel := context.WithCancel(q.ctx) + mpr := newMapper(ctx, q.logger, mk, q.heartbeats) + cm = &countedMapper{ + mapper: mpr, + count: 0, + cancel: cancel, + } + q.mappers[mk] = cm + // we don't have any mapping state for this key yet + q.workQ.enqueue(mk) + } + if err := sendCtx(cm.ctx, cm.add, c); err != nil { + return + } + cm.count++ + go q.waitForConn(c) +} + +func (q *querier) waitForConn(c *connIO) { + <-c.ctx.Done() + q.mu.Lock() + defer q.mu.Unlock() + mk := mKey{ + agent: c.agent, + // if client is Nil, this is an agent connection, and it wants the mappings for all the clients of itself + clientsOfAgent: c.client == uuid.Nil, + } + cm := q.mappers[mk] + if err := sendCtx(cm.ctx, cm.del, c); err != nil { + return + } + cm.count-- + if cm.count == 0 { + cm.cancel() + delete(q.mappers, mk) + } +} + +func (q *querier) worker() { + eb := backoff.NewExponentialBackOff() + eb.MaxElapsedTime = 0 // retry indefinitely + eb.MaxInterval = dbMaxBackoff + bkoff := backoff.WithContext(eb, q.ctx) + for { + mk, err := q.workQ.acquire() + if err != nil { + // context expired + return + } + err = backoff.Retry(func() error { + return q.query(mk) + }, bkoff) + if err != nil { + bkoff.Reset() + } + q.workQ.done(mk) + } +} + +func (q *querier) query(mk mKey) error { + var mappings []mapping + var err error + if mk.clientsOfAgent { + mappings, err = q.queryClientsOfAgent(mk.agent) + if err != nil { + return err + } + } else { + mappings, err = q.queryAgent(mk.agent) + if err != nil { + return err + } + } + q.mu.Lock() + mpr, ok := q.mappers[mk] + q.mu.Unlock() + if !ok { + q.logger.Debug(q.ctx, "query for missing mapper", + slog.F("agent_id", mk.agent), slog.F("clients_of_agent", mk.clientsOfAgent)) + return nil + } + mpr.mappings <- mappings + return nil +} + +func (q *querier) queryClientsOfAgent(agent uuid.UUID) ([]mapping, error) { + clients, err := q.store.GetTailnetClientsForAgent(q.ctx, agent) + if err != nil && !xerrors.Is(err, sql.ErrNoRows) { + return nil, err + } + mappings := make([]mapping, 0, len(clients)) + for _, 
client := range clients { + node := new(agpl.Node) + err := json.Unmarshal(client.Node, node) + if err != nil { + q.logger.Error(q.ctx, "failed to unmarshal node", slog.Error(err)) + return nil, backoff.Permanent(err) + } + mappings = append(mappings, mapping{ + client: client.ID, + agent: client.AgentID, + coordinator: client.CoordinatorID, + updatedAt: client.UpdatedAt, + node: node, + }) + } + return mappings, nil +} + +func (q *querier) queryAgent(agentID uuid.UUID) ([]mapping, error) { + agents, err := q.store.GetTailnetAgents(q.ctx, agentID) + if err != nil && !xerrors.Is(err, sql.ErrNoRows) { + return nil, err + } + mappings := make([]mapping, 0, len(agents)) + for _, agent := range agents { + node := new(agpl.Node) + err := json.Unmarshal(agent.Node, node) + if err != nil { + q.logger.Error(q.ctx, "failed to unmarshal node", slog.Error(err)) + return nil, backoff.Permanent(err) + } + mappings = append(mappings, mapping{ + agent: agent.ID, + coordinator: agent.CoordinatorID, + updatedAt: agent.UpdatedAt, + node: node, + }) + } + return mappings, nil +} + +func (q *querier) subscribe() { + eb := backoff.NewExponentialBackOff() + eb.MaxElapsedTime = 0 // retry indefinitely + eb.MaxInterval = dbMaxBackoff + bkoff := backoff.WithContext(eb, q.ctx) + var cancelClient context.CancelFunc + err := backoff.Retry(func() error { + cancelFn, err := q.pubsub.SubscribeWithErr(eventClientUpdate, q.listenClient) + if err != nil { + q.logger.Warn(q.ctx, "failed to subscribe to client updates", slog.Error(err)) + return err + } + cancelClient = cancelFn + return nil + }, bkoff) + if err != nil { + // this should only happen if context is canceled + return + } + defer cancelClient() + bkoff.Reset() + + var cancelAgent context.CancelFunc + err = backoff.Retry(func() error { + cancelFn, err := q.pubsub.SubscribeWithErr(eventAgentUpdate, q.listenAgent) + if err != nil { + q.logger.Warn(q.ctx, "failed to subscribe to agent updates", slog.Error(err)) + return err + } + cancelAgent = cancelFn + return nil + }, bkoff) + if err != nil { + // this should only happen if context is canceled + return + } + defer cancelAgent() + + // hold subscriptions open until context is canceled + <-q.ctx.Done() +} + +func (q *querier) listenClient(_ context.Context, msg []byte, err error) { + if xerrors.Is(err, pubsub.ErrDroppedMessages) { + q.logger.Warn(q.ctx, "pubsub may have dropped client updates") + // we need to schedule a full resync of client mappings + q.resyncClientMappings() + return + } + if err != nil { + q.logger.Warn(q.ctx, "unhandled pubsub error", slog.Error(err)) + } + client, agent, err := parseClientUpdate(string(msg)) + if err != nil { + q.logger.Error(q.ctx, "failed to parse client update", slog.F("msg", string(msg)), slog.Error(err)) + return + } + logger := q.logger.With(slog.F("client_id", client), slog.F("agent_id", agent)) + logger.Debug(q.ctx, "got client update") + mk := mKey{ + agent: agent, + clientsOfAgent: true, + } + q.mu.Lock() + _, ok := q.mappers[mk] + q.mu.Unlock() + if !ok { + logger.Debug(q.ctx, "ignoring update because we have no mapper") + return + } + q.workQ.enqueue(mk) +} + +func (q *querier) listenAgent(_ context.Context, msg []byte, err error) { + if xerrors.Is(err, pubsub.ErrDroppedMessages) { + q.logger.Warn(q.ctx, "pubsub may have dropped agent updates") + // we need to schedule a full resync of agent mappings + q.resyncAgentMappings() + return + } + if err != nil { + q.logger.Warn(q.ctx, "unhandled pubsub error", slog.Error(err)) + } + agent, err := 
parseAgentUpdate(string(msg)) + if err != nil { + q.logger.Error(q.ctx, "failed to parse agent update", slog.F("msg", string(msg)), slog.Error(err)) + return + } + logger := q.logger.With(slog.F("agent_id", agent)) + logger.Debug(q.ctx, "got agent update") + mk := mKey{ + agent: agent, + clientsOfAgent: false, + } + q.mu.Lock() + _, ok := q.mappers[mk] + q.mu.Unlock() + if !ok { + logger.Debug(q.ctx, "ignoring update because we have no mapper") + return + } + q.workQ.enqueue(mk) +} + +func (q *querier) resyncClientMappings() { + q.mu.Lock() + defer q.mu.Unlock() + for mk := range q.mappers { + if mk.clientsOfAgent { + q.workQ.enqueue(mk) + } + } +} + +func (q *querier) resyncAgentMappings() { + q.mu.Lock() + defer q.mu.Unlock() + for mk := range q.mappers { + if !mk.clientsOfAgent { + q.workQ.enqueue(mk) + } + } +} + +func (q *querier) handleUpdates() { + for { + select { + case <-q.ctx.Done(): + return + case <-q.updates: + q.updateAll() + } + } +} + +func (q *querier) updateAll() { + q.mu.Lock() + defer q.mu.Unlock() + + for _, cm := range q.mappers { + // send on goroutine to avoid holding the q.mu. Heartbeat failures come asynchronously with respect to + // other kinds of work, so it's fine to deliver the command to refresh async. + go func(m *mapper) { + // make sure we send on the _mapper_ context, not our own in case the mapper is + // shutting down or shut down. + _ = sendCtx(m.ctx, m.update, struct{}{}) + }(cm.mapper) + } +} + +func parseClientUpdate(msg string) (client, agent uuid.UUID, err error) { + parts := strings.Split(msg, ",") + if len(parts) != 2 { + return uuid.Nil, uuid.Nil, xerrors.Errorf("expected 2 parts separated by comma") + } + client, err = uuid.Parse(parts[0]) + if err != nil { + return uuid.Nil, uuid.Nil, xerrors.Errorf("failed to parse client UUID: %w", err) + } + agent, err = uuid.Parse(parts[1]) + if err != nil { + return uuid.Nil, uuid.Nil, xerrors.Errorf("failed to parse agent UUID: %w", err) + } + return client, agent, nil +} + +func parseAgentUpdate(msg string) (agent uuid.UUID, err error) { + agent, err = uuid.Parse(msg) + if err != nil { + return uuid.Nil, xerrors.Errorf("failed to parse agent UUID: %w", err) + } + return agent, nil +} + +// mKey identifies a set of node mappings we want to query. +type mKey struct { + agent uuid.UUID + // we always query based on the agent ID, but if we have client connection(s), we query the agent itself. If we + // have an agent connection, we need the node mappings for all clients of the agent. + clientsOfAgent bool +} + +// mapping associates a particular client or agent, and its respective coordinator with a node. It is generalized to +// include clients or agents: agent mappings will have client set to uuid.Nil. +type mapping struct { + client uuid.UUID + agent uuid.UUID + coordinator uuid.UUID + updatedAt time.Time + node *agpl.Node +} + +// workQ allows scheduling work based on a key. Multiple enqueue requests for the same key are coalesced, and +// only one in-progress job per key is scheduled. +type workQ[K mKey | bKey] struct { + ctx context.Context + + cond *sync.Cond + pending []K + inProgress map[K]bool +} + +func newWorkQ[K mKey | bKey](ctx context.Context) *workQ[K] { + q := &workQ[K]{ + ctx: ctx, + cond: sync.NewCond(&sync.Mutex{}), + inProgress: make(map[K]bool), + } + // wake up all waiting workers when context is done + go func() { + <-ctx.Done() + q.cond.L.Lock() + defer q.cond.L.Unlock() + q.cond.Broadcast() + }() + return q +} + +// enqueue adds the key to the workQ if it is not already pending. 
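+//
+// For illustration, enqueued keys are consumed with acquire/done; a sketch of the
+// consumer loop (essentially what querier.worker above already does):
+//
+//	for {
+//		key, err := q.acquire()
+//		if err != nil {
+//			return // workQ context canceled
+//		}
+//		// ... perform the work for key, retrying as needed ...
+//		q.done(key)
+//	}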
+func (q *workQ[K]) enqueue(key K) { + q.cond.L.Lock() + defer q.cond.L.Unlock() + for _, mk := range q.pending { + if mk == key { + // already pending, no-op + return + } + } + q.pending = append(q.pending, key) + q.cond.Signal() +} + +// acquire gets a new key to begin working on. This call blocks until work is available. After acquiring a key, the +// worker MUST call done() with the same key to mark it complete and allow new pending work to be acquired for the key. +// An error is returned if the workQ context is canceled to unblock waiting workers. +func (q *workQ[K]) acquire() (key K, err error) { + q.cond.L.Lock() + defer q.cond.L.Unlock() + for !q.workAvailable() && q.ctx.Err() == nil { + q.cond.Wait() + } + if q.ctx.Err() != nil { + return key, q.ctx.Err() + } + for i, mk := range q.pending { + _, ok := q.inProgress[mk] + if !ok { + q.pending = append(q.pending[:i], q.pending[i+1:]...) + q.inProgress[mk] = true + return mk, nil + } + } + // this should not be possible because we are holding the lock when we exit the loop that waits + panic("woke with no work available") +} + +// workAvailable returns true if there is work we can do. Must be called while holding q.cond.L +func (q workQ[K]) workAvailable() bool { + for _, mk := range q.pending { + _, ok := q.inProgress[mk] + if !ok { + return true + } + } + return false +} + +// done marks the key completed; MUST be called after acquire() for each key. +func (q *workQ[K]) done(key K) { + q.cond.L.Lock() + defer q.cond.L.Unlock() + delete(q.inProgress, key) + q.cond.Signal() +} + +// heartbeats sends heartbeats for this coordinator on a timer, and monitors heartbeats from other coordinators. If a +// coordinator misses their heartbeat, we remove it from our map of "valid" coordinators, such that we will filter out +// any mappings for it when filter() is called, and we send a signal on the update channel, which triggers all mappers +// to recompute their mappings and push them out to their connections. 
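+//
+// The expiry rule, as a sketch:
+//
+//	time.Since(lastHeartbeat) > MissedHeartbeats*HeartbeatPeriod
+//
+// e.g. with HeartbeatPeriod = 2s and MissedHeartbeats = 3 (illustrative values only), a
+// coordinator that has been silent for more than 6 seconds is withdrawn and an update is
+// signaled so mappers recompute and push fresh node sets.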
+type heartbeats struct { + ctx context.Context + logger slog.Logger + pubsub pubsub.Pubsub + store database.Store + self uuid.UUID + + update chan<- struct{} + firstHeartbeat chan<- struct{} + + lock sync.RWMutex + coordinators map[uuid.UUID]time.Time + timer *time.Timer +} + +func newHeartbeats( + ctx context.Context, logger slog.Logger, + ps pubsub.Pubsub, store database.Store, + self uuid.UUID, update chan<- struct{}, + firstHeartbeat chan<- struct{}, +) *heartbeats { + h := &heartbeats{ + ctx: ctx, + logger: logger, + pubsub: ps, + store: store, + self: self, + update: update, + firstHeartbeat: firstHeartbeat, + coordinators: make(map[uuid.UUID]time.Time), + } + go h.subscribe() + go h.sendBeats() + return h +} + +func (h *heartbeats) filter(mappings []mapping) []mapping { + out := make([]mapping, 0, len(mappings)) + h.lock.RLock() + defer h.lock.RUnlock() + for _, m := range mappings { + ok := m.coordinator == h.self + if !ok { + _, ok = h.coordinators[m.coordinator] + } + if ok { + out = append(out, m) + } + } + return out +} + +func (h *heartbeats) subscribe() { + eb := backoff.NewExponentialBackOff() + eb.MaxElapsedTime = 0 // retry indefinitely + eb.MaxInterval = dbMaxBackoff + bkoff := backoff.WithContext(eb, h.ctx) + var cancel context.CancelFunc + err := backoff.Retry(func() error { + cancelFn, err := h.pubsub.SubscribeWithErr(EventHeartbeats, h.listen) + if err != nil { + h.logger.Warn(h.ctx, "failed to subscribe to heartbeats", slog.Error(err)) + return err + } + cancel = cancelFn + return nil + }, bkoff) + if err != nil { + // this should only happen if context is canceled + return + } + // cancel subscription when context finishes + defer cancel() + <-h.ctx.Done() +} + +func (h *heartbeats) listen(_ context.Context, msg []byte, err error) { + if err != nil { + // in the context of heartbeats, if we miss some messages it will be OK as long + // as we aren't disconnected for multiple beats. Still, even if we are disconnected + // for longer, there isn't much to do except log. Once we reconnect we will reinstate + // any expired coordinators that are still alive and continue on. + h.logger.Warn(h.ctx, "heartbeat notification error", slog.Error(err)) + return + } + id, err := uuid.Parse(string(msg)) + if err != nil { + h.logger.Error(h.ctx, "unable to parse heartbeat", slog.F("msg", string(msg)), slog.Error(err)) + return + } + if id == h.self { + h.logger.Debug(h.ctx, "ignoring our own heartbeat") + return + } + h.recvBeat(id) +} + +func (h *heartbeats) recvBeat(id uuid.UUID) { + h.logger.Debug(h.ctx, "got heartbeat", slog.F("heartbeat_from", id)) + h.lock.Lock() + defer h.lock.Unlock() + var oldestTime time.Time + h.coordinators[id] = time.Now() + + if h.timer == nil { + // this can only happen for the very first beat + h.timer = time.AfterFunc(MissedHeartbeats*HeartbeatPeriod, h.checkExpiry) + h.logger.Debug(h.ctx, "set initial heartbeat timeout") + return + } + + for _, t := range h.coordinators { + if oldestTime.IsZero() || t.Before(oldestTime) { + oldestTime = t + } + } + d := time.Until(oldestTime.Add(MissedHeartbeats * HeartbeatPeriod)) + h.logger.Debug(h.ctx, "computed oldest heartbeat", slog.F("oldest", oldestTime), slog.F("time_to_expiry", d)) + // only reschedule if it's in the future. 
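+	// e.g. if the oldest recorded beat is 4s old and the expiry window is 6s (illustrative
+	// numbers), d is 2s and the timer fires again 2s from now.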
+ if d > 0 { + h.timer.Reset(d) + } +} + +func (h *heartbeats) checkExpiry() { + h.logger.Debug(h.ctx, "checking heartbeat expiry") + h.lock.Lock() + now := time.Now() + expired := false + for id, t := range h.coordinators { + lastHB := now.Sub(t) + h.logger.Debug(h.ctx, "last heartbeat from coordinator", slog.F("other_coordinator", id), slog.F("last heartbeat", lastHB)) + if lastHB > MissedHeartbeats*HeartbeatPeriod { + expired = true + delete(h.coordinators, id) + h.logger.Info(h.ctx, "coordinator failed heartbeat check", slog.F("other_coordinator", id), slog.F("last heartbeat", lastHB)) + } + } + h.lock.Unlock() + if expired { + _ = sendCtx(h.ctx, h.update, struct{}{}) + } +} + +func (h *heartbeats) sendBeats() { + // send an initial heartbeat so that other coordinators can start using our bindings right away. + h.sendBeat() + close(h.firstHeartbeat) // signal binder it can start writing + defer h.sendDelete() + tkr := time.NewTicker(HeartbeatPeriod) + defer tkr.Stop() + for { + select { + case <-h.ctx.Done(): + h.logger.Debug(h.ctx, "ending heartbeats", slog.Error(h.ctx.Err())) + return + case <-tkr.C: + h.sendBeat() + } + } +} + +func (h *heartbeats) sendBeat() { + _, err := h.store.UpsertTailnetCoordinator(h.ctx, h.self) + if err != nil { + // just log errors, heartbeats are rescheduled on a timer + h.logger.Error(h.ctx, "failed to send heartbeat", slog.Error(err)) + return + } + h.logger.Debug(h.ctx, "sent heartbeat") +} + +func (h *heartbeats) sendDelete() { + // here we don't want to use the main context, since it will have been c + err := h.store.DeleteCoordinator(context.Background(), h.self) + if err != nil { + h.logger.Error(h.ctx, "failed to send coordinator delete", slog.Error(err)) + return + } + h.logger.Debug(h.ctx, "deleted coordinator") +} diff --git a/enterprise/tailnet/pgcoord_test.go b/enterprise/tailnet/pgcoord_test.go new file mode 100644 index 0000000000000..e55560d7f3fb5 --- /dev/null +++ b/enterprise/tailnet/pgcoord_test.go @@ -0,0 +1,656 @@ +package tailnet_test + +import ( + "context" + "database/sql" + "encoding/json" + "io" + "net" + "os" + "sync" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + "golang.org/x/xerrors" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbtestutil" + "github.com/coder/coder/enterprise/tailnet" + agpl "github.com/coder/coder/tailnet" + "github.com/coder/coder/testutil" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestPGCoordinatorSingle_ClientWithoutAgent(t *testing.T) { + t.Parallel() + if os.Getenv("DB") == "" { + t.Skip("test only with postgres") + } + store, pubsub := dbtestutil.NewDB(t) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store) + require.NoError(t, err) + defer coordinator.Close() + + agentID := uuid.New() + client := newTestClient(t, coordinator, agentID) + defer client.close() + client.sendNode(&agpl.Node{PreferredDERP: 10}) + require.Eventually(t, func() bool { + clients, err := store.GetTailnetClientsForAgent(ctx, agentID) + if err != nil && !xerrors.Is(err, sql.ErrNoRows) { + t.Fatalf("database error: %v", err) + } + if len(clients) == 0 { + return false + } + var node agpl.Node + err = 
json.Unmarshal(clients[0].Node, &node) + assert.NoError(t, err) + assert.Equal(t, 10, node.PreferredDERP) + return true + }, testutil.WaitShort, testutil.IntervalFast) + + err = client.close() + require.NoError(t, err) + <-client.errChan + <-client.closeChan + assertEventuallyNoClientsForAgent(ctx, t, store, agentID) +} + +func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) { + t.Parallel() + if os.Getenv("DB") == "" { + t.Skip("test only with postgres") + } + store, pubsub := dbtestutil.NewDB(t) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store) + require.NoError(t, err) + defer coordinator.Close() + + agent := newTestAgent(t, coordinator) + defer agent.close() + agent.sendNode(&agpl.Node{PreferredDERP: 10}) + require.Eventually(t, func() bool { + agents, err := store.GetTailnetAgents(ctx, agent.id) + if err != nil && !xerrors.Is(err, sql.ErrNoRows) { + t.Fatalf("database error: %v", err) + } + if len(agents) == 0 { + return false + } + var node agpl.Node + err = json.Unmarshal(agents[0].Node, &node) + assert.NoError(t, err) + assert.Equal(t, 10, node.PreferredDERP) + return true + }, testutil.WaitShort, testutil.IntervalFast) + err = agent.close() + require.NoError(t, err) + <-agent.errChan + <-agent.closeChan + assertEventuallyNoAgents(ctx, t, store, agent.id) +} + +func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) { + t.Parallel() + if os.Getenv("DB") == "" { + t.Skip("test only with postgres") + } + store, pubsub := dbtestutil.NewDB(t) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store) + require.NoError(t, err) + defer coordinator.Close() + + agent := newTestAgent(t, coordinator) + defer agent.close() + agent.sendNode(&agpl.Node{PreferredDERP: 10}) + + client := newTestClient(t, coordinator, agent.id) + defer client.close() + + agentNodes := client.recvNodes(ctx, t) + require.Len(t, agentNodes, 1) + assert.Equal(t, 10, agentNodes[0].PreferredDERP) + client.sendNode(&agpl.Node{PreferredDERP: 11}) + clientNodes := agent.recvNodes(ctx, t) + require.Len(t, clientNodes, 1) + assert.Equal(t, 11, clientNodes[0].PreferredDERP) + + // Ensure an update to the agent node reaches the connIO! + agent.sendNode(&agpl.Node{PreferredDERP: 12}) + agentNodes = client.recvNodes(ctx, t) + require.Len(t, agentNodes, 1) + assert.Equal(t, 12, agentNodes[0].PreferredDERP) + + // Close the agent WebSocket so a new one can connect. + err = agent.close() + require.NoError(t, err) + _ = agent.recvErr(ctx, t) + agent.waitForClose(ctx, t) + + // Create a new agent connection. This is to simulate a reconnect! + agent = newTestAgent(t, coordinator, agent.id) + // Ensure the existing listening connIO sends its node immediately! + clientNodes = agent.recvNodes(ctx, t) + require.Len(t, clientNodes, 1) + assert.Equal(t, 11, clientNodes[0].PreferredDERP) + + // Send a bunch of updates in rapid succession, and test that we eventually get the latest. We don't want the + // coordinator accidentally reordering things. + for d := 13; d < 36; d++ { + agent.sendNode(&agpl.Node{PreferredDERP: d}) + } + for { + nodes := client.recvNodes(ctx, t) + if !assert.Len(t, nodes, 1) { + break + } + if nodes[0].PreferredDERP == 35 { + // got latest! 
+ break + } + } + + err = agent.close() + require.NoError(t, err) + _ = agent.recvErr(ctx, t) + agent.waitForClose(ctx, t) + + err = client.close() + require.NoError(t, err) + _ = client.recvErr(ctx, t) + client.waitForClose(ctx, t) + + assertEventuallyNoAgents(ctx, t, store, agent.id) + assertEventuallyNoClientsForAgent(ctx, t, store, agent.id) +} + +func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) { + t.Parallel() + if os.Getenv("DB") == "" { + t.Skip("test only with postgres") + } + store, pubsub := dbtestutil.NewDB(t) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store) + require.NoError(t, err) + defer coordinator.Close() + + agent := newTestAgent(t, coordinator) + defer agent.close() + agent.sendNode(&agpl.Node{PreferredDERP: 10}) + + client := newTestClient(t, coordinator, agent.id) + defer client.close() + + nodes := client.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 10) + client.sendNode(&agpl.Node{PreferredDERP: 11}) + nodes = agent.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 11) + + // simulate a second coordinator via DB calls only --- our goal is to test broken heart-beating, so we can't use a + // real coordinator + fCoord := &fakeCoordinator{ + ctx: ctx, + t: t, + store: store, + id: uuid.New(), + } + start := time.Now() + fCoord.heartbeat() + fCoord.agentNode(agent.id, &agpl.Node{PreferredDERP: 12}) + nodes = client.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 12) + + // when the fake coordinator misses enough heartbeats, the real coordinator should send an update with the old + // node for the agent. + nodes = client.recvNodes(ctx, t) + assert.Greater(t, time.Since(start), tailnet.HeartbeatPeriod*tailnet.MissedHeartbeats) + assertHasDERPs(t, nodes, 10) + + err = agent.close() + require.NoError(t, err) + _ = agent.recvErr(ctx, t) + agent.waitForClose(ctx, t) + + err = client.close() + require.NoError(t, err) + _ = client.recvErr(ctx, t) + client.waitForClose(ctx, t) + + assertEventuallyNoClientsForAgent(ctx, t, store, agent.id) +} + +func TestPGCoordinatorSingle_SendsHeartbeats(t *testing.T) { + t.Parallel() + if os.Getenv("DB") == "" { + t.Skip("test only with postgres") + } + store, pubsub := dbtestutil.NewDB(t) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + + mu := sync.Mutex{} + heartbeats := []time.Time{} + unsub, err := pubsub.SubscribeWithErr(tailnet.EventHeartbeats, func(_ context.Context, msg []byte, err error) { + assert.NoError(t, err) + mu.Lock() + defer mu.Unlock() + heartbeats = append(heartbeats, time.Now()) + }) + require.NoError(t, err) + defer unsub() + + start := time.Now() + coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store) + require.NoError(t, err) + defer coordinator.Close() + + require.Eventually(t, func() bool { + mu.Lock() + defer mu.Unlock() + if len(heartbeats) < 2 { + return false + } + require.Greater(t, heartbeats[0].Sub(start), time.Duration(0)) + require.Greater(t, heartbeats[1].Sub(start), time.Duration(0)) + return assert.Greater(t, heartbeats[1].Sub(heartbeats[0]), tailnet.HeartbeatPeriod*9/10) + }, testutil.WaitMedium, testutil.IntervalMedium) +} + +// TestPGCoordinatorDual_Mainline tests with 2 coordinators, one agent connected to each, and 2 clients per agent. 
+// +// +---------+ +// agent1 ---> | coord1 | <--- client11 (coord 1, agent 1) +// | | +// | | <--- client12 (coord 1, agent 2) +// +---------+ +// +---------+ +// agent2 ---> | coord2 | <--- client21 (coord 2, agent 1) +// | | +// | | <--- client22 (coord2, agent 2) +// +---------+ +func TestPGCoordinatorDual_Mainline(t *testing.T) { + t.Parallel() + if os.Getenv("DB") == "" { + t.Skip("test only with postgres") + } + store, pubsub := dbtestutil.NewDB(t) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + coord1, err := tailnet.NewPGCoord(ctx, logger, pubsub, store) + require.NoError(t, err) + defer coord1.Close() + coord2, err := tailnet.NewPGCoord(ctx, logger, pubsub, store) + require.NoError(t, err) + defer coord2.Close() + + agent1 := newTestAgent(t, coord1) + defer agent1.close() + agent2 := newTestAgent(t, coord2) + defer agent2.close() + + client11 := newTestClient(t, coord1, agent1.id) + defer client11.close() + client12 := newTestClient(t, coord1, agent2.id) + defer client12.close() + client21 := newTestClient(t, coord2, agent1.id) + defer client21.close() + client22 := newTestClient(t, coord2, agent2.id) + defer client22.close() + + client11.sendNode(&agpl.Node{PreferredDERP: 11}) + nodes := agent1.recvNodes(ctx, t) + assert.Len(t, nodes, 1) + assertHasDERPs(t, nodes, 11) + + client21.sendNode(&agpl.Node{PreferredDERP: 21}) + nodes = agent1.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 21, 11) + + client22.sendNode(&agpl.Node{PreferredDERP: 22}) + nodes = agent2.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 22) + + agent2.sendNode(&agpl.Node{PreferredDERP: 2}) + nodes = client22.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 2) + nodes = client12.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 2) + + client12.sendNode(&agpl.Node{PreferredDERP: 12}) + nodes = agent2.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 12, 22) + + agent1.sendNode(&agpl.Node{PreferredDERP: 1}) + nodes = client21.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 1) + nodes = client11.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 1) + + // let's close coord2 + err = coord2.Close() + require.NoError(t, err) + + // this closes agent2, client22, client21 + err = agent2.recvErr(ctx, t) + require.ErrorIs(t, err, io.EOF) + err = client22.recvErr(ctx, t) + require.ErrorIs(t, err, io.EOF) + err = client21.recvErr(ctx, t) + require.ErrorIs(t, err, io.EOF) + + // agent1 will see an update that drops client21. + // In this case the update is superfluous because client11's node hasn't changed, and agents don't deprogram clients + // from the dataplane even if they are missing. Suppressing this kind of update would require the coordinator to + // store all the data its sent to each connection, so we don't bother. + nodes = agent1.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 11) + + // note that although agent2 is disconnected, client12 does NOT get an update because we suppress empty updates. + // (Its easy to tell these are superfluous.) 
+ + assertEventuallyNoAgents(ctx, t, store, agent2.id) + + // Close coord1 + err = coord1.Close() + require.NoError(t, err) + // this closes agent1, client12, client11 + err = agent1.recvErr(ctx, t) + require.ErrorIs(t, err, io.EOF) + err = client12.recvErr(ctx, t) + require.ErrorIs(t, err, io.EOF) + err = client11.recvErr(ctx, t) + require.ErrorIs(t, err, io.EOF) + + // wait for all connections to close + err = agent1.close() + require.NoError(t, err) + agent1.waitForClose(ctx, t) + + err = agent2.close() + require.NoError(t, err) + agent2.waitForClose(ctx, t) + + err = client11.close() + require.NoError(t, err) + client11.waitForClose(ctx, t) + + err = client12.close() + require.NoError(t, err) + client12.waitForClose(ctx, t) + + err = client21.close() + require.NoError(t, err) + client21.waitForClose(ctx, t) + + err = client22.close() + require.NoError(t, err) + client22.waitForClose(ctx, t) + + assertEventuallyNoAgents(ctx, t, store, agent1.id) + assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) + assertEventuallyNoClientsForAgent(ctx, t, store, agent2.id) +} + +// TestPGCoordinator_MultiAgent tests when a single agent connects to multiple coordinators. +// We use two agent connections, but they share the same AgentID. This could happen due to a reconnection, +// or an infrastructure problem where an old workspace is not fully cleaned up before a new one started. +// +// +---------+ +// agent1 ---> | coord1 | +// +---------+ +// +---------+ +// agent2 ---> | coord2 | +// +---------+ +// +---------+ +// | coord3 | <--- client +// +---------+ +func TestPGCoordinator_MultiAgent(t *testing.T) { + t.Parallel() + if os.Getenv("DB") == "" { + t.Skip("test only with postgres") + } + store, pubsub := dbtestutil.NewDB(t) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + coord1, err := tailnet.NewPGCoord(ctx, logger, pubsub, store) + require.NoError(t, err) + defer coord1.Close() + coord2, err := tailnet.NewPGCoord(ctx, logger, pubsub, store) + require.NoError(t, err) + defer coord2.Close() + coord3, err := tailnet.NewPGCoord(ctx, logger, pubsub, store) + require.NoError(t, err) + defer coord3.Close() + + agent1 := newTestAgent(t, coord1) + defer agent1.close() + agent2 := newTestAgent(t, coord2, agent1.id) + defer agent2.close() + + client := newTestClient(t, coord3, agent1.id) + defer client.close() + + client.sendNode(&agpl.Node{PreferredDERP: 3}) + nodes := agent1.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 3) + nodes = agent2.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 3) + + agent1.sendNode(&agpl.Node{PreferredDERP: 1}) + nodes = client.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 1) + + // agent2's update overrides agent1 because it is newer + agent2.sendNode(&agpl.Node{PreferredDERP: 2}) + nodes = client.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 2) + + // agent2 disconnects, and we should revert back to agent1 + err = agent2.close() + require.NoError(t, err) + err = agent2.recvErr(ctx, t) + require.ErrorIs(t, err, io.ErrClosedPipe) + agent2.waitForClose(ctx, t) + nodes = client.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 1) + + agent1.sendNode(&agpl.Node{PreferredDERP: 11}) + nodes = client.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 11) + + client.sendNode(&agpl.Node{PreferredDERP: 31}) + nodes = agent1.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 31) + + err = agent1.close() + require.NoError(t, err) + err = agent1.recvErr(ctx, t) + require.ErrorIs(t, err, 
io.ErrClosedPipe) + agent1.waitForClose(ctx, t) + + err = client.close() + require.NoError(t, err) + err = client.recvErr(ctx, t) + require.ErrorIs(t, err, io.ErrClosedPipe) + client.waitForClose(ctx, t) + + assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) + assertEventuallyNoAgents(ctx, t, store, agent1.id) +} + +type testConn struct { + ws, serverWS net.Conn + nodeChan chan []*agpl.Node + sendNode func(node *agpl.Node) + errChan <-chan error + id uuid.UUID + closeChan chan struct{} +} + +func newTestConn(ids []uuid.UUID) *testConn { + a := &testConn{} + a.ws, a.serverWS = net.Pipe() + a.nodeChan = make(chan []*agpl.Node) + a.sendNode, a.errChan = agpl.ServeCoordinator(a.ws, func(nodes []*agpl.Node) error { + a.nodeChan <- nodes + return nil + }) + if len(ids) > 1 { + panic("too many") + } + if len(ids) == 1 { + a.id = ids[0] + } else { + a.id = uuid.New() + } + a.closeChan = make(chan struct{}) + return a +} + +func newTestAgent(t *testing.T, coord agpl.Coordinator, id ...uuid.UUID) *testConn { + a := newTestConn(id) + go func() { + err := coord.ServeAgent(a.serverWS, a.id, "") + assert.NoError(t, err) + close(a.closeChan) + }() + return a +} + +func (c *testConn) close() error { + return c.ws.Close() +} + +func (c *testConn) recvNodes(ctx context.Context, t *testing.T) []*agpl.Node { + t.Helper() + select { + case <-ctx.Done(): + t.Fatal("timeout receiving nodes") + return nil + case nodes := <-c.nodeChan: + return nodes + } +} + +func (c *testConn) recvErr(ctx context.Context, t *testing.T) error { + t.Helper() + select { + case <-ctx.Done(): + t.Fatal("timeout receiving error") + return ctx.Err() + case err := <-c.errChan: + return err + } +} + +func (c *testConn) waitForClose(ctx context.Context, t *testing.T) { + t.Helper() + select { + case <-ctx.Done(): + t.Fatal("timeout waiting for connection to close") + return + case <-c.closeChan: + return + } +} + +func newTestClient(t *testing.T, coord agpl.Coordinator, agentID uuid.UUID, id ...uuid.UUID) *testConn { + c := newTestConn(id) + go func() { + err := coord.ServeClient(c.serverWS, c.id, agentID) + assert.NoError(t, err) + close(c.closeChan) + }() + return c +} + +func assertHasDERPs(t *testing.T, nodes []*agpl.Node, expected ...int) { + if !assert.Len(t, nodes, len(expected), "expected %d node(s), got %d", len(expected), len(nodes)) { + return + } + derps := make([]int, 0, len(nodes)) + for _, n := range nodes { + derps = append(derps, n.PreferredDERP) + } + for _, e := range expected { + assert.Contains(t, derps, e, "expected DERP %v, got %v", e, derps) + } +} + +func assertEventuallyNoAgents(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) { + assert.Eventually(t, func() bool { + agents, err := store.GetTailnetAgents(ctx, agentID) + if xerrors.Is(err, sql.ErrNoRows) { + return true + } + if err != nil { + t.Fatal(err) + } + return len(agents) == 0 + }, testutil.WaitShort, testutil.IntervalFast) +} + +func assertEventuallyNoClientsForAgent(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) { + assert.Eventually(t, func() bool { + clients, err := store.GetTailnetClientsForAgent(ctx, agentID) + if xerrors.Is(err, sql.ErrNoRows) { + return true + } + if err != nil { + t.Fatal(err) + } + return len(clients) == 0 + }, testutil.WaitShort, testutil.IntervalFast) +} + +type fakeCoordinator struct { + ctx context.Context + t *testing.T + store database.Store + id uuid.UUID +} + +func (c *fakeCoordinator) heartbeat() { + c.t.Helper() + _, err := 
c.store.UpsertTailnetCoordinator(c.ctx, c.id) + require.NoError(c.t, err) +} + +func (c *fakeCoordinator) agentNode(agentID uuid.UUID, node *agpl.Node) { + c.t.Helper() + nodeRaw, err := json.Marshal(node) + require.NoError(c.t, err) + _, err = c.store.UpsertTailnetAgent(c.ctx, database.UpsertTailnetAgentParams{ + ID: agentID, + CoordinatorID: c.id, + Node: nodeRaw, + }) + require.NoError(c.t, err) +} diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index 5ee49cd194f16..f8e5476a9f7a2 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -1,6 +1,7 @@ package tailnet import ( + "bytes" "context" "encoding/json" "errors" @@ -174,11 +175,12 @@ func newCore(logger slog.Logger) *core { var ErrWouldBlock = xerrors.New("would block") type TrackedConn struct { - ctx context.Context - cancel func() - conn net.Conn - updates chan []*Node - logger slog.Logger + ctx context.Context + cancel func() + conn net.Conn + updates chan []*Node + logger slog.Logger + lastData []byte // ID is an ephemeral UUID used to uniquely identify the owner of the // connection. @@ -224,6 +226,10 @@ func (t *TrackedConn) SendUpdates() { t.logger.Error(t.ctx, "unable to marshal nodes update", slog.Error(err), slog.F("nodes", nodes)) return } + if bytes.Equal(t.lastData, data) { + t.logger.Debug(t.ctx, "skipping duplicate update", slog.F("nodes", nodes)) + continue + } // Set a deadline so that hung connections don't put back pressure on the system. // Node updates are tiny, so even the dinkiest connection can handle them if it's not hung. @@ -255,6 +261,7 @@ func (t *TrackedConn) SendUpdates() { _ = t.Close() return } + t.lastData = data } } } From d37787f370fc0f1ac8cdff338a40b0dffb4f75cf Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Thu, 15 Jun 2023 10:38:26 +0000 Subject: [PATCH 02/11] Fix db migration; tests Signed-off-by: Spike Curtis --- coderd/database/dbfake/dbfake.go | 18 ++++++++++-------- ...down.sql => 000127_ha_coordinator.down.sql} | 0 ...tor.up.sql => 000127_ha_coordinator.up.sql} | 0 enterprise/tailnet/coordinator_test.go | 12 ++++++------ enterprise/tailnet/pgcoord.go | 2 +- tailnet/coordinator_test.go | 12 ++++++------ 6 files changed, 23 insertions(+), 21 deletions(-) rename coderd/database/migrations/{000125_ha_coordinator.down.sql => 000127_ha_coordinator.down.sql} (100%) rename coderd/database/migrations/{000125_ha_coordinator.up.sql => 000127_ha_coordinator.up.sql} (100%) diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index f0bf58830f427..33d9157ff05fd 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -5197,34 +5197,36 @@ func (q *fakeQuerier) UpsertServiceBanner(_ context.Context, data string) error // API, and for that kind of test, the in-memory, AGPL tailnet coordinator is sufficient. Therefore, these methods // remain unimplemented in the fakeQuerier. 
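+// Tests that accidentally reach these paths can detect the sentinel with, e.g.,
+// xerrors.Is(err, dbfake.ErrUnimplemented) (illustrative usage; pgCoord itself is only
+// ever exercised against postgres).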
+var ErrUnimplemented = xerrors.New("unimplemented") + func (*fakeQuerier) UpsertTailnetClient(context.Context, database.UpsertTailnetClientParams) (database.TailnetClient, error) { - panic("unimplemented") + return database.TailnetClient{}, ErrUnimplemented } func (*fakeQuerier) UpsertTailnetAgent(context.Context, database.UpsertTailnetAgentParams) (database.TailnetAgent, error) { - panic("unimplemented") + return database.TailnetAgent{}, ErrUnimplemented } func (*fakeQuerier) UpsertTailnetCoordinator(context.Context, uuid.UUID) (database.TailnetCoordinator, error) { - panic("unimplemented") + return database.TailnetCoordinator{}, ErrUnimplemented } func (*fakeQuerier) DeleteTailnetClient(context.Context, database.DeleteTailnetClientParams) (database.DeleteTailnetClientRow, error) { - panic("unimplemented") + return database.DeleteTailnetClientRow{}, ErrUnimplemented } func (*fakeQuerier) DeleteTailnetAgent(context.Context, database.DeleteTailnetAgentParams) (database.DeleteTailnetAgentRow, error) { - panic("unimplemented") + return database.DeleteTailnetAgentRow{}, ErrUnimplemented } func (*fakeQuerier) DeleteCoordinator(context.Context, uuid.UUID) error { - panic("unimplemented") + return ErrUnimplemented } func (*fakeQuerier) GetTailnetAgents(context.Context, uuid.UUID) ([]database.TailnetAgent, error) { - panic("unimplemented") + return nil, ErrUnimplemented } func (*fakeQuerier) GetTailnetClientsForAgent(context.Context, uuid.UUID) ([]database.TailnetClient, error) { - panic("unimplemented") + return nil, ErrUnimplemented } diff --git a/coderd/database/migrations/000125_ha_coordinator.down.sql b/coderd/database/migrations/000127_ha_coordinator.down.sql similarity index 100% rename from coderd/database/migrations/000125_ha_coordinator.down.sql rename to coderd/database/migrations/000127_ha_coordinator.down.sql diff --git a/coderd/database/migrations/000125_ha_coordinator.up.sql b/coderd/database/migrations/000127_ha_coordinator.up.sql similarity index 100% rename from coderd/database/migrations/000125_ha_coordinator.up.sql rename to coderd/database/migrations/000127_ha_coordinator.up.sql diff --git a/enterprise/tailnet/coordinator_test.go b/enterprise/tailnet/coordinator_test.go index bcc3ddca34d05..a29bf2ad273a9 100644 --- a/enterprise/tailnet/coordinator_test.go +++ b/enterprise/tailnet/coordinator_test.go @@ -95,7 +95,7 @@ func TestCoordinatorSingle(t *testing.T) { assert.NoError(t, err) close(closeAgentChan) }() - sendAgentNode(&agpl.Node{}) + sendAgentNode(&agpl.Node{PreferredDERP: 1}) require.Eventually(t, func() bool { return coordinator.Node(agentID) != nil }, testutil.WaitShort, testutil.IntervalFast) @@ -117,12 +117,12 @@ func TestCoordinatorSingle(t *testing.T) { }() agentNodes := <-clientNodeChan require.Len(t, agentNodes, 1) - sendClientNode(&agpl.Node{}) + sendClientNode(&agpl.Node{PreferredDERP: 2}) clientNodes := <-agentNodeChan require.Len(t, clientNodes, 1) // Ensure an update to the agent node reaches the client! 
- sendAgentNode(&agpl.Node{}) + sendAgentNode(&agpl.Node{PreferredDERP: 3}) agentNodes = <-clientNodeChan require.Len(t, agentNodes, 1) @@ -188,7 +188,7 @@ func TestCoordinatorHA(t *testing.T) { assert.NoError(t, err) close(closeAgentChan) }() - sendAgentNode(&agpl.Node{}) + sendAgentNode(&agpl.Node{PreferredDERP: 1}) require.Eventually(t, func() bool { return coordinator1.Node(agentID) != nil }, testutil.WaitShort, testutil.IntervalFast) @@ -214,13 +214,13 @@ func TestCoordinatorHA(t *testing.T) { }() agentNodes := <-clientNodeChan require.Len(t, agentNodes, 1) - sendClientNode(&agpl.Node{}) + sendClientNode(&agpl.Node{PreferredDERP: 2}) _ = sendClientNode clientNodes := <-agentNodeChan require.Len(t, clientNodes, 1) // Ensure an update to the agent node reaches the client! - sendAgentNode(&agpl.Node{}) + sendAgentNode(&agpl.Node{PreferredDERP: 3}) agentNodes = <-clientNodeChan require.Len(t, agentNodes, 1) diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index a1a14bd993bc8..8f4c43269e602 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -543,7 +543,7 @@ func (m *mapper) mappingsToNodes(mappings []mapping) []*agpl.Node { } // querier is responsible for monitoring pubsub notifications and querying the database for the mappings that all -// connnected clients and agents need. It also checks heartbeats and withdraws mappings from coordinators that have +// connected clients and agents need. It also checks heartbeats and withdraws mappings from coordinators that have // failed heartbeats. type querier struct { ctx context.Context diff --git a/tailnet/coordinator_test.go b/tailnet/coordinator_test.go index 94c6f6da58341..300a89ad5f9e0 100644 --- a/tailnet/coordinator_test.go +++ b/tailnet/coordinator_test.go @@ -96,7 +96,7 @@ func TestCoordinator(t *testing.T) { assert.NoError(t, err) close(closeAgentChan) }() - sendAgentNode(&tailnet.Node{}) + sendAgentNode(&tailnet.Node{PreferredDERP: 1}) require.Eventually(t, func() bool { return coordinator.Node(agentID) != nil }, testutil.WaitShort, testutil.IntervalFast) @@ -122,7 +122,7 @@ func TestCoordinator(t *testing.T) { case <-ctx.Done(): t.Fatal("timed out") } - sendClientNode(&tailnet.Node{}) + sendClientNode(&tailnet.Node{PreferredDERP: 2}) clientNodes := <-agentNodeChan require.Len(t, clientNodes, 1) @@ -131,7 +131,7 @@ func TestCoordinator(t *testing.T) { time.Sleep(tailnet.WriteTimeout * 3 / 2) // Ensure an update to the agent node reaches the client! - sendAgentNode(&tailnet.Node{}) + sendAgentNode(&tailnet.Node{PreferredDERP: 3}) select { case agentNodes := <-clientNodeChan: require.Len(t, agentNodes, 1) @@ -193,7 +193,7 @@ func TestCoordinator(t *testing.T) { assert.NoError(t, err) close(closeAgentChan1) }() - sendAgentNode1(&tailnet.Node{}) + sendAgentNode1(&tailnet.Node{PreferredDERP: 1}) require.Eventually(t, func() bool { return coordinator.Node(agentID) != nil }, testutil.WaitShort, testutil.IntervalFast) @@ -215,12 +215,12 @@ func TestCoordinator(t *testing.T) { }() agentNodes := <-clientNodeChan require.Len(t, agentNodes, 1) - sendClientNode(&tailnet.Node{}) + sendClientNode(&tailnet.Node{PreferredDERP: 2}) clientNodes := <-agentNodeChan1 require.Len(t, clientNodes, 1) // Ensure an update to the agent node reaches the client! 
- sendAgentNode1(&tailnet.Node{}) + sendAgentNode1(&tailnet.Node{PreferredDERP: 3}) agentNodes = <-clientNodeChan require.Len(t, agentNodes, 1) From 513c3b7167e32eec7c76f14ceba46042bb369132 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Thu, 15 Jun 2023 11:27:08 +0000 Subject: [PATCH 03/11] Add fixture, regenerate Signed-off-by: Spike Curtis --- coderd/database/dbauthz/dbauthz.go | 56 +++++++++++++++ coderd/database/dbauthz/tailnetcoordinator.go | 66 ----------------- coderd/database/dbfake/dbfake.go | 64 ++++++++--------- coderd/database/dbmetrics/dbmetrics.go | 72 +++++++++---------- .../migrations/000127_ha_coordinator.up.sql | 1 + .../fixtures/000127_ha_coordinator.up.sql | 26 +++++++ 6 files changed, 151 insertions(+), 134 deletions(-) delete mode 100644 coderd/database/dbauthz/tailnetcoordinator.go create mode 100644 coderd/database/migrations/testdata/fixtures/000127_ha_coordinator.up.sql diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 30b8c5d33569c..1a9345c549653 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2507,3 +2507,59 @@ func (q *querier) UpsertServiceBanner(ctx context.Context, value string) error { } return q.db.UpsertServiceBanner(ctx, value) } + +func (q *querier) UpsertTailnetClient(ctx context.Context, arg database.UpsertTailnetClientParams) (database.TailnetClient, error) { + if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil { + return database.TailnetClient{}, err + } + return q.db.UpsertTailnetClient(ctx, arg) +} + +func (q *querier) UpsertTailnetAgent(ctx context.Context, arg database.UpsertTailnetAgentParams) (database.TailnetAgent, error) { + if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil { + return database.TailnetAgent{}, err + } + return q.db.UpsertTailnetAgent(ctx, arg) +} + +func (q *querier) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (database.TailnetCoordinator, error) { + if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil { + return database.TailnetCoordinator{}, err + } + return q.db.UpsertTailnetCoordinator(ctx, id) +} + +func (q *querier) DeleteTailnetClient(ctx context.Context, arg database.DeleteTailnetClientParams) (database.DeleteTailnetClientRow, error) { + if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil { + return database.DeleteTailnetClientRow{}, err + } + return q.db.DeleteTailnetClient(ctx, arg) +} + +func (q *querier) DeleteTailnetAgent(ctx context.Context, arg database.DeleteTailnetAgentParams) (database.DeleteTailnetAgentRow, error) { + if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil { + return database.DeleteTailnetAgentRow{}, err + } + return q.db.DeleteTailnetAgent(ctx, arg) +} + +func (q *querier) DeleteCoordinator(ctx context.Context, id uuid.UUID) error { + if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil { + return err + } + return q.db.DeleteCoordinator(ctx, id) +} + +func (q *querier) GetTailnetAgents(ctx context.Context, id uuid.UUID) ([]database.TailnetAgent, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTailnetCoordinator); err != nil { + return nil, err + } + return q.db.GetTailnetAgents(ctx, id) +} + +func (q *querier) GetTailnetClientsForAgent(ctx context.Context, agentID uuid.UUID) 
([]database.TailnetClient, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTailnetCoordinator); err != nil { + return nil, err + } + return q.db.GetTailnetClientsForAgent(ctx, agentID) +} diff --git a/coderd/database/dbauthz/tailnetcoordinator.go b/coderd/database/dbauthz/tailnetcoordinator.go deleted file mode 100644 index ddf924a498364..0000000000000 --- a/coderd/database/dbauthz/tailnetcoordinator.go +++ /dev/null @@ -1,66 +0,0 @@ -package dbauthz - -import ( - "context" - - "github.com/google/uuid" - - "github.com/coder/coder/coderd/database" - "github.com/coder/coder/coderd/rbac" -) - -func (q *querier) UpsertTailnetClient(ctx context.Context, arg database.UpsertTailnetClientParams) (database.TailnetClient, error) { - if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil { - return database.TailnetClient{}, err - } - return q.db.UpsertTailnetClient(ctx, arg) -} - -func (q *querier) UpsertTailnetAgent(ctx context.Context, arg database.UpsertTailnetAgentParams) (database.TailnetAgent, error) { - if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil { - return database.TailnetAgent{}, err - } - return q.db.UpsertTailnetAgent(ctx, arg) -} - -func (q *querier) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (database.TailnetCoordinator, error) { - if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil { - return database.TailnetCoordinator{}, err - } - return q.db.UpsertTailnetCoordinator(ctx, id) -} - -func (q *querier) DeleteTailnetClient(ctx context.Context, arg database.DeleteTailnetClientParams) (database.DeleteTailnetClientRow, error) { - if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil { - return database.DeleteTailnetClientRow{}, err - } - return q.db.DeleteTailnetClient(ctx, arg) -} - -func (q *querier) DeleteTailnetAgent(ctx context.Context, arg database.DeleteTailnetAgentParams) (database.DeleteTailnetAgentRow, error) { - if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil { - return database.DeleteTailnetAgentRow{}, err - } - return q.db.DeleteTailnetAgent(ctx, arg) -} - -func (q *querier) DeleteCoordinator(ctx context.Context, id uuid.UUID) error { - if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil { - return err - } - return q.db.DeleteCoordinator(ctx, id) -} - -func (q *querier) GetTailnetAgents(ctx context.Context, id uuid.UUID) ([]database.TailnetAgent, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTailnetCoordinator); err != nil { - return nil, err - } - return q.db.GetTailnetAgents(ctx, id) -} - -func (q *querier) GetTailnetClientsForAgent(ctx context.Context, agentID uuid.UUID) ([]database.TailnetClient, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTailnetCoordinator); err != nil { - return nil, err - } - return q.db.GetTailnetClientsForAgent(ctx, agentID) -} diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index 33d9157ff05fd..a4cd7542986e9 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -966,6 +966,15 @@ func isNotNull(v interface{}) bool { return reflect.ValueOf(v).FieldByName("Valid").Bool() } +// The remaining methods are only used by the enterprise/tailnet.pgCoord. 
This coordinator explicitly depends on +// postgres triggers that announce changes on the pubsub. Implementing support for this in the fake database would +// strongly couple the fakeQuerier to the pubsub, which is undesirable. Furthermore, it makes little sense to directly +// test the pgCoord against anything other than postgres. The fakeQuerier is designed to allow us to test the Coderd +// API, and for that kind of test, the in-memory, AGPL tailnet coordinator is sufficient. Therefore, these methods +// remain unimplemented in the fakeQuerier. + +var ErrUnimplemented = xerrors.New("unimplemented") + func (*fakeQuerier) AcquireLock(_ context.Context, _ int64) error { return xerrors.New("AcquireLock must only be called within a transaction") } @@ -1066,6 +1075,10 @@ func (q *fakeQuerier) DeleteApplicationConnectAPIKeysByUserID(_ context.Context, return nil } +func (*fakeQuerier) DeleteCoordinator(context.Context, uuid.UUID) error { + return ErrUnimplemented +} + func (q *fakeQuerier) DeleteGitSSHKey(_ context.Context, userID uuid.UUID) error { q.mutex.Lock() defer q.mutex.Unlock() @@ -1174,6 +1187,14 @@ func (q *fakeQuerier) DeleteReplicasUpdatedBefore(_ context.Context, before time return nil } +func (*fakeQuerier) DeleteTailnetAgent(context.Context, database.DeleteTailnetAgentParams) (database.DeleteTailnetAgentRow, error) { + return database.DeleteTailnetAgentRow{}, ErrUnimplemented +} + +func (*fakeQuerier) DeleteTailnetClient(context.Context, database.DeleteTailnetClientParams) (database.DeleteTailnetClientRow, error) { + return database.DeleteTailnetClientRow{}, ErrUnimplemented +} + func (q *fakeQuerier) GetAPIKeyByID(_ context.Context, id string) (database.APIKey, error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -2153,6 +2174,14 @@ func (q *fakeQuerier) GetServiceBanner(_ context.Context) (string, error) { return string(q.serviceBanner), nil } +func (*fakeQuerier) GetTailnetAgents(context.Context, uuid.UUID) ([]database.TailnetAgent, error) { + return nil, ErrUnimplemented +} + +func (*fakeQuerier) GetTailnetClientsForAgent(context.Context, uuid.UUID) ([]database.TailnetClient, error) { + return nil, ErrUnimplemented +} + func (q *fakeQuerier) GetTemplateAverageBuildTime(ctx context.Context, arg database.GetTemplateAverageBuildTimeParams) (database.GetTemplateAverageBuildTimeRow, error) { if err := validateDatabaseType(arg); err != nil { return database.GetTemplateAverageBuildTimeRow{}, err @@ -5190,43 +5219,14 @@ func (q *fakeQuerier) UpsertServiceBanner(_ context.Context, data string) error return nil } -// The remaining methods are only used by the enterprise/tailnet.pgCoord. This coordinator explicitly depends on -// postgres triggers that announce changes on the pubsub. Implementing support for this in the fake database would -// strongly couple the fakeQuerier to the pubsub, which is undesirable. Furthermore, it makes little sense to directly -// test the pgCoord against anything other than postgres. The fakeQuerier is designed to allow us to test the Coderd -// API, and for that kind of test, the in-memory, AGPL tailnet coordinator is sufficient. Therefore, these methods -// remain unimplemented in the fakeQuerier. 
- -var ErrUnimplemented = xerrors.New("unimplemented") +func (*fakeQuerier) UpsertTailnetAgent(context.Context, database.UpsertTailnetAgentParams) (database.TailnetAgent, error) { + return database.TailnetAgent{}, ErrUnimplemented +} func (*fakeQuerier) UpsertTailnetClient(context.Context, database.UpsertTailnetClientParams) (database.TailnetClient, error) { return database.TailnetClient{}, ErrUnimplemented } -func (*fakeQuerier) UpsertTailnetAgent(context.Context, database.UpsertTailnetAgentParams) (database.TailnetAgent, error) { - return database.TailnetAgent{}, ErrUnimplemented -} - func (*fakeQuerier) UpsertTailnetCoordinator(context.Context, uuid.UUID) (database.TailnetCoordinator, error) { return database.TailnetCoordinator{}, ErrUnimplemented } - -func (*fakeQuerier) DeleteTailnetClient(context.Context, database.DeleteTailnetClientParams) (database.DeleteTailnetClientRow, error) { - return database.DeleteTailnetClientRow{}, ErrUnimplemented -} - -func (*fakeQuerier) DeleteTailnetAgent(context.Context, database.DeleteTailnetAgentParams) (database.DeleteTailnetAgentRow, error) { - return database.DeleteTailnetAgentRow{}, ErrUnimplemented -} - -func (*fakeQuerier) DeleteCoordinator(context.Context, uuid.UUID) error { - return ErrUnimplemented -} - -func (*fakeQuerier) GetTailnetAgents(context.Context, uuid.UUID) ([]database.TailnetAgent, error) { - return nil, ErrUnimplemented -} - -func (*fakeQuerier) GetTailnetClientsForAgent(context.Context, uuid.UUID) ([]database.TailnetClient, error) { - return nil, ErrUnimplemented -} diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index dd281c6831914..5dc185b2f237a 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -143,6 +143,12 @@ func (m metricsStore) DeleteApplicationConnectAPIKeysByUserID(ctx context.Contex return err } +func (m metricsStore) DeleteCoordinator(ctx context.Context, id uuid.UUID) error { + start := time.Now() + defer m.queryLatencies.WithLabelValues("DeleteCoordinator").Observe(time.Since(start).Seconds()) + return m.s.DeleteCoordinator(ctx, id) +} + func (m metricsStore) DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error { start := time.Now() err := m.s.DeleteGitSSHKey(ctx, userID) @@ -199,6 +205,18 @@ func (m metricsStore) DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt return err } +func (m metricsStore) DeleteTailnetAgent(ctx context.Context, arg database.DeleteTailnetAgentParams) (database.DeleteTailnetAgentRow, error) { + start := time.Now() + defer m.queryLatencies.WithLabelValues("DeleteTailnetAgent").Observe(time.Since(start).Seconds()) + return m.s.DeleteTailnetAgent(ctx, arg) +} + +func (m metricsStore) DeleteTailnetClient(ctx context.Context, arg database.DeleteTailnetClientParams) (database.DeleteTailnetClientRow, error) { + start := time.Now() + defer m.queryLatencies.WithLabelValues("DeleteTailnetClient").Observe(time.Since(start).Seconds()) + return m.s.DeleteTailnetClient(ctx, arg) +} + func (m metricsStore) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) { start := time.Now() apiKey, err := m.s.GetAPIKeyByID(ctx, id) @@ -549,6 +567,18 @@ func (m metricsStore) GetServiceBanner(ctx context.Context) (string, error) { return banner, err } +func (m metricsStore) GetTailnetAgents(ctx context.Context, id uuid.UUID) ([]database.TailnetAgent, error) { + start := time.Now() + defer m.queryLatencies.WithLabelValues("GetTailnetAgents").Observe(time.Since(start).Seconds()) + return 
m.s.GetTailnetAgents(ctx, id) +} + +func (m metricsStore) GetTailnetClientsForAgent(ctx context.Context, agentID uuid.UUID) ([]database.TailnetClient, error) { + start := time.Now() + defer m.queryLatencies.WithLabelValues("GetTailnetClientsForAgent").Observe(time.Since(start).Seconds()) + return m.s.GetTailnetClientsForAgent(ctx, agentID) +} + func (m metricsStore) GetTemplateAverageBuildTime(ctx context.Context, arg database.GetTemplateAverageBuildTimeParams) (database.GetTemplateAverageBuildTimeRow, error) { start := time.Now() buildTime, err := m.s.GetTemplateAverageBuildTime(ctx, arg) @@ -1536,50 +1566,20 @@ func (m metricsStore) UpsertServiceBanner(ctx context.Context, value string) err return r0 } -func (m metricsStore) UpsertTailnetClient(ctx context.Context, arg database.UpsertTailnetClientParams) (database.TailnetClient, error) { - start := time.Now() - defer m.queryLatencies.WithLabelValues("UpsertTailnetClient").Observe(time.Since(start).Seconds()) - return m.s.UpsertTailnetClient(ctx, arg) -} - func (m metricsStore) UpsertTailnetAgent(ctx context.Context, arg database.UpsertTailnetAgentParams) (database.TailnetAgent, error) { start := time.Now() defer m.queryLatencies.WithLabelValues("UpsertTailnetAgent").Observe(time.Since(start).Seconds()) return m.s.UpsertTailnetAgent(ctx, arg) } -func (m metricsStore) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (database.TailnetCoordinator, error) { - start := time.Now() - defer m.queryLatencies.WithLabelValues("UpsertTailnetCoordinator").Observe(time.Since(start).Seconds()) - return m.s.UpsertTailnetCoordinator(ctx, id) -} - -func (m metricsStore) DeleteTailnetClient(ctx context.Context, arg database.DeleteTailnetClientParams) (database.DeleteTailnetClientRow, error) { - start := time.Now() - defer m.queryLatencies.WithLabelValues("DeleteTailnetClient").Observe(time.Since(start).Seconds()) - return m.s.DeleteTailnetClient(ctx, arg) -} - -func (m metricsStore) DeleteTailnetAgent(ctx context.Context, arg database.DeleteTailnetAgentParams) (database.DeleteTailnetAgentRow, error) { - start := time.Now() - defer m.queryLatencies.WithLabelValues("DeleteTailnetAgent").Observe(time.Since(start).Seconds()) - return m.s.DeleteTailnetAgent(ctx, arg) -} - -func (m metricsStore) DeleteCoordinator(ctx context.Context, id uuid.UUID) error { - start := time.Now() - defer m.queryLatencies.WithLabelValues("DeleteCoordinator").Observe(time.Since(start).Seconds()) - return m.s.DeleteCoordinator(ctx, id) -} - -func (m metricsStore) GetTailnetAgents(ctx context.Context, id uuid.UUID) ([]database.TailnetAgent, error) { +func (m metricsStore) UpsertTailnetClient(ctx context.Context, arg database.UpsertTailnetClientParams) (database.TailnetClient, error) { start := time.Now() - defer m.queryLatencies.WithLabelValues("GetTailnetAgents").Observe(time.Since(start).Seconds()) - return m.s.GetTailnetAgents(ctx, id) + defer m.queryLatencies.WithLabelValues("UpsertTailnetClient").Observe(time.Since(start).Seconds()) + return m.s.UpsertTailnetClient(ctx, arg) } -func (m metricsStore) GetTailnetClientsForAgent(ctx context.Context, agentID uuid.UUID) ([]database.TailnetClient, error) { +func (m metricsStore) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (database.TailnetCoordinator, error) { start := time.Now() - defer m.queryLatencies.WithLabelValues("GetTailnetClientsForAgent").Observe(time.Since(start).Seconds()) - return m.s.GetTailnetClientsForAgent(ctx, agentID) + defer 
m.queryLatencies.WithLabelValues("UpsertTailnetCoordinator").Observe(time.Since(start).Seconds()) + return m.s.UpsertTailnetCoordinator(ctx, id) } diff --git a/coderd/database/migrations/000127_ha_coordinator.up.sql b/coderd/database/migrations/000127_ha_coordinator.up.sql index 3b1431e173840..8200c9401506f 100644 --- a/coderd/database/migrations/000127_ha_coordinator.up.sql +++ b/coderd/database/migrations/000127_ha_coordinator.up.sql @@ -31,6 +31,7 @@ CREATE TABLE tailnet_agents ( FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE ); + -- For shutting down / GC a coordinator CREATE INDEX idx_tailnet_agents_coordinator ON tailnet_agents (coordinator_id); diff --git a/coderd/database/migrations/testdata/fixtures/000127_ha_coordinator.up.sql b/coderd/database/migrations/testdata/fixtures/000127_ha_coordinator.up.sql new file mode 100644 index 0000000000000..488859dda291f --- /dev/null +++ b/coderd/database/migrations/testdata/fixtures/000127_ha_coordinator.up.sql @@ -0,0 +1,26 @@ +INSERT INTO tailnet_coordinators + (id, heartbeat_at) +VALUES + ( + 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11', + '2023-06-15 10:23:54+00' + ); + +INSERT INTO tailnet_clients + (id, agent_id, coordinator_id, node) +VALUES + ( + 'b0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11', + 'c0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11', + 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11', + '{"preferred_derp": 12}'::json + ); + +INSERT INTO tailnet_agents +(id, coordinator_id, node) +VALUES + ( + 'c0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11', + 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11', + '{"preferred_derp": 13}'::json + ); From 7d751b39b97525432f912ce4d1c8d12346841a28 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Thu, 15 Jun 2023 11:37:07 +0000 Subject: [PATCH 04/11] Fix fixtures Signed-off-by: Spike Curtis --- coderd/database/dbfake/dbfake.go | 13 ++++++------- .../testdata/fixtures/000127_ha_coordinator.up.sql | 6 ++++-- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index a4cd7542986e9..a751883498967 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -966,13 +966,12 @@ func isNotNull(v interface{}) bool { return reflect.ValueOf(v).FieldByName("Valid").Bool() } -// The remaining methods are only used by the enterprise/tailnet.pgCoord. This coordinator explicitly depends on -// postgres triggers that announce changes on the pubsub. Implementing support for this in the fake database would -// strongly couple the fakeQuerier to the pubsub, which is undesirable. Furthermore, it makes little sense to directly -// test the pgCoord against anything other than postgres. The fakeQuerier is designed to allow us to test the Coderd -// API, and for that kind of test, the in-memory, AGPL tailnet coordinator is sufficient. Therefore, these methods -// remain unimplemented in the fakeQuerier. - +// ErrUnimplemented is returned by methods only used by the enterprise/tailnet.pgCoord. This coordinator explicitly +// depends on postgres triggers that announce changes on the pubsub. Implementing support for this in the fake +// database would strongly couple the fakeQuerier to the pubsub, which is undesirable. Furthermore, it makes little +// sense to directly test the pgCoord against anything other than postgres. The fakeQuerier is designed to allow us to +// test the Coderd API, and for that kind of test, the in-memory, AGPL tailnet coordinator is sufficient. Therefore, +// these methods remain unimplemented in the fakeQuerier. 
var ErrUnimplemented = xerrors.New("unimplemented") func (*fakeQuerier) AcquireLock(_ context.Context, _ int64) error { diff --git a/coderd/database/migrations/testdata/fixtures/000127_ha_coordinator.up.sql b/coderd/database/migrations/testdata/fixtures/000127_ha_coordinator.up.sql index 488859dda291f..8af4fa4827997 100644 --- a/coderd/database/migrations/testdata/fixtures/000127_ha_coordinator.up.sql +++ b/coderd/database/migrations/testdata/fixtures/000127_ha_coordinator.up.sql @@ -7,20 +7,22 @@ VALUES ); INSERT INTO tailnet_clients - (id, agent_id, coordinator_id, node) + (id, agent_id, coordinator_id, updated_at, node) VALUES ( 'b0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11', 'c0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11', 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11', + '2023-06-15 10:23:54+00', '{"preferred_derp": 12}'::json ); INSERT INTO tailnet_agents -(id, coordinator_id, node) +(id, coordinator_id, updated_at, node) VALUES ( 'c0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11', 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11', + '2023-06-15 10:23:54+00', '{"preferred_derp": 13}'::json ); From dc4b30cbb7743fcf3194933af4755539bd205488 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Thu, 15 Jun 2023 11:44:25 +0000 Subject: [PATCH 05/11] review comments, run clean gen Signed-off-by: Spike Curtis --- coderd/database/dbauthz/dbauthz.go | 82 +++++++++++++++--------------- enterprise/tailnet/pgcoord.go | 11 ++-- 2 files changed, 46 insertions(+), 47 deletions(-) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 1a9345c549653..75c78c39ef44d 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -707,6 +707,13 @@ func (q *querier) DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, u return q.db.DeleteApplicationConnectAPIKeysByUserID(ctx, userID) } +func (q *querier) DeleteCoordinator(ctx context.Context, id uuid.UUID) error { + if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil { + return err + } + return q.db.DeleteCoordinator(ctx, id) +} + func (q *querier) DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error { return deleteQ(q.log, q.auth, q.db.GetGitSSHKey, q.db.DeleteGitSSHKey)(ctx, userID) } @@ -765,6 +772,20 @@ func (q *querier) DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt tim return q.db.DeleteReplicasUpdatedBefore(ctx, updatedAt) } +func (q *querier) DeleteTailnetAgent(ctx context.Context, arg database.DeleteTailnetAgentParams) (database.DeleteTailnetAgentRow, error) { + if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil { + return database.DeleteTailnetAgentRow{}, err + } + return q.db.DeleteTailnetAgent(ctx, arg) +} + +func (q *querier) DeleteTailnetClient(ctx context.Context, arg database.DeleteTailnetClientParams) (database.DeleteTailnetClientRow, error) { + if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil { + return database.DeleteTailnetClientRow{}, err + } + return q.db.DeleteTailnetClient(ctx, arg) +} + func (q *querier) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) { return fetch(q.log, q.auth, q.db.GetAPIKeyByID)(ctx, id) } @@ -1132,6 +1153,20 @@ func (q *querier) GetServiceBanner(ctx context.Context) (string, error) { return q.db.GetServiceBanner(ctx) } +func (q *querier) GetTailnetAgents(ctx context.Context, id uuid.UUID) ([]database.TailnetAgent, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, 
rbac.ResourceTailnetCoordinator); err != nil { + return nil, err + } + return q.db.GetTailnetAgents(ctx, id) +} + +func (q *querier) GetTailnetClientsForAgent(ctx context.Context, agentID uuid.UUID) ([]database.TailnetClient, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTailnetCoordinator); err != nil { + return nil, err + } + return q.db.GetTailnetClientsForAgent(ctx, agentID) +} + // Only used by metrics cache. func (q *querier) GetTemplateAverageBuildTime(ctx context.Context, arg database.GetTemplateAverageBuildTimeParams) (database.GetTemplateAverageBuildTimeRow, error) { if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { @@ -2508,13 +2543,6 @@ func (q *querier) UpsertServiceBanner(ctx context.Context, value string) error { return q.db.UpsertServiceBanner(ctx, value) } -func (q *querier) UpsertTailnetClient(ctx context.Context, arg database.UpsertTailnetClientParams) (database.TailnetClient, error) { - if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil { - return database.TailnetClient{}, err - } - return q.db.UpsertTailnetClient(ctx, arg) -} - func (q *querier) UpsertTailnetAgent(ctx context.Context, arg database.UpsertTailnetAgentParams) (database.TailnetAgent, error) { if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil { return database.TailnetAgent{}, err @@ -2522,44 +2550,16 @@ func (q *querier) UpsertTailnetAgent(ctx context.Context, arg database.UpsertTai return q.db.UpsertTailnetAgent(ctx, arg) } -func (q *querier) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (database.TailnetCoordinator, error) { +func (q *querier) UpsertTailnetClient(ctx context.Context, arg database.UpsertTailnetClientParams) (database.TailnetClient, error) { if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil { - return database.TailnetCoordinator{}, err - } - return q.db.UpsertTailnetCoordinator(ctx, id) -} - -func (q *querier) DeleteTailnetClient(ctx context.Context, arg database.DeleteTailnetClientParams) (database.DeleteTailnetClientRow, error) { - if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil { - return database.DeleteTailnetClientRow{}, err + return database.TailnetClient{}, err } - return q.db.DeleteTailnetClient(ctx, arg) + return q.db.UpsertTailnetClient(ctx, arg) } -func (q *querier) DeleteTailnetAgent(ctx context.Context, arg database.DeleteTailnetAgentParams) (database.DeleteTailnetAgentRow, error) { +func (q *querier) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (database.TailnetCoordinator, error) { if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil { - return database.DeleteTailnetAgentRow{}, err - } - return q.db.DeleteTailnetAgent(ctx, arg) -} - -func (q *querier) DeleteCoordinator(ctx context.Context, id uuid.UUID) error { - if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil { - return err - } - return q.db.DeleteCoordinator(ctx, id) -} - -func (q *querier) GetTailnetAgents(ctx context.Context, id uuid.UUID) ([]database.TailnetAgent, error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTailnetCoordinator); err != nil { - return nil, err - } - return q.db.GetTailnetAgents(ctx, id) -} - -func (q *querier) GetTailnetClientsForAgent(ctx context.Context, agentID uuid.UUID) ([]database.TailnetClient, 
error) { - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTailnetCoordinator); err != nil { - return nil, err + return database.TailnetCoordinator{}, err } - return q.db.GetTailnetClientsForAgent(ctx, agentID) + return q.db.UpsertTailnetCoordinator(ctx, id) } diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 8f4c43269e602..5e789c3539ed6 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -131,7 +131,7 @@ func (c *pgCoord) Node(id uuid.UUID) *agpl.Node { return bestN } -func (c *pgCoord) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) (retErr error) { +func (c *pgCoord) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error { defer func() { err := conn.Close() if err != nil { @@ -150,7 +150,7 @@ func (c *pgCoord) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) (ret return nil } -func (c *pgCoord) ServeAgent(conn net.Conn, id uuid.UUID, name string) (retErr error) { +func (c *pgCoord) ServeAgent(conn net.Conn, id uuid.UUID, name string) error { defer func() { err := conn.Close() if err != nil { @@ -171,7 +171,6 @@ func (c *pgCoord) ServeAgent(conn net.Conn, id uuid.UUID, name string) (retErr e func (c *pgCoord) Close() error { c.cancel() - // do we need to wait for the binder to complete? c.closeOnce.Do(func() { close(c.closed) }) return nil } @@ -1083,7 +1082,7 @@ func (h *heartbeats) subscribe() { eb.MaxInterval = dbMaxBackoff bkoff := backoff.WithContext(eb, h.ctx) var cancel context.CancelFunc - err := backoff.Retry(func() error { + bErr := backoff.Retry(func() error { cancelFn, err := h.pubsub.SubscribeWithErr(EventHeartbeats, h.listen) if err != nil { h.logger.Warn(h.ctx, "failed to subscribe to heartbeats", slog.Error(err)) @@ -1092,7 +1091,7 @@ func (h *heartbeats) subscribe() { cancel = cancelFn return nil }, bkoff) - if err != nil { + if bErr != nil { // this should only happen if context is canceled return } @@ -1198,7 +1197,7 @@ func (h *heartbeats) sendBeat() { } func (h *heartbeats) sendDelete() { - // here we don't want to use the main context, since it will have been c + // here we don't want to use the main context, since it will have been canceled err := h.store.DeleteCoordinator(context.Background(), h.self) if err != nil { h.logger.Error(h.ctx, "failed to send coordinator delete", slog.Error(err)) From fb87ea485308033476e380c481bfb9de68e26362 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Tue, 20 Jun 2023 10:47:05 +0000 Subject: [PATCH 06/11] Rename waitForConn -> cleanupConn Signed-off-by: Spike Curtis --- enterprise/tailnet/pgcoord.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 5e789c3539ed6..bbf2526eaac55 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -628,10 +628,10 @@ func (q *querier) newConn(c *connIO) { return } cm.count++ - go q.waitForConn(c) + go q.cleanupConn(c) } -func (q *querier) waitForConn(c *connIO) { +func (q *querier) cleanupConn(c *connIO) { <-c.ctx.Done() q.mu.Lock() defer q.mu.Unlock() From 95d01cf1b0c7a132e3122b61e0592bede7c5ba70 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 21 Jun 2023 11:05:35 +0000 Subject: [PATCH 07/11] code review updates Signed-off-by: Spike Curtis --- coderd/database/dbtestutil/db.go | 7 ++++++- coderd/database/dump.sql | 14 +++++++------ .../migrations/000127_ha_coordinator.down.sql | 12 +++++------ .../migrations/000127_ha_coordinator.up.sql | 20 ++++++++++--------- 
coderd/database/models.go | 1 + enterprise/tailnet/pgcoord.go | 16 ++++++++++----- enterprise/tailnet/pgcoord_test.go | 15 +++++++------- 7 files changed, 50 insertions(+), 35 deletions(-) diff --git a/coderd/database/dbtestutil/db.go b/coderd/database/dbtestutil/db.go index 932e4aaf4739a..ad8cecf143240 100644 --- a/coderd/database/dbtestutil/db.go +++ b/coderd/database/dbtestutil/db.go @@ -14,12 +14,17 @@ import ( "github.com/coder/coder/coderd/database/pubsub" ) +// WillUsePostgres returns true if a call to NewDB() will return a real, postgres-backed Store and Pubsub. +func WillUsePostgres() bool { + return os.Getenv("DB") != "" +} + func NewDB(t testing.TB) (database.Store, pubsub.Pubsub) { t.Helper() db := dbfake.New() ps := pubsub.NewInMemory() - if os.Getenv("DB") != "" { + if WillUsePostgres() { connectionURL := os.Getenv("CODER_PG_CONNECTION_URL") if connectionURL == "" { var ( diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index e9f91c1a67d99..8a275171a9c66 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -171,7 +171,7 @@ BEGIN END; $$; -CREATE FUNCTION notify_agent_change() RETURNS trigger +CREATE FUNCTION tailnet_notify_agent_change() RETURNS trigger LANGUAGE plpgsql AS $$ BEGIN @@ -186,7 +186,7 @@ BEGIN END; $$; -CREATE FUNCTION notify_client_change() RETURNS trigger +CREATE FUNCTION tailnet_notify_client_change() RETURNS trigger LANGUAGE plpgsql AS $$ BEGIN @@ -201,7 +201,7 @@ BEGIN END; $$; -CREATE FUNCTION notify_coordinator_heartbeat() RETURNS trigger +CREATE FUNCTION tailnet_notify_coordinator_heartbeat() RETURNS trigger LANGUAGE plpgsql AS $$ BEGIN @@ -442,6 +442,8 @@ CREATE TABLE tailnet_coordinators ( heartbeat_at timestamp with time zone NOT NULL ); +COMMENT ON TABLE tailnet_coordinators IS 'We keep this separate from replicas in case we need to break the coordinator out into its own service'; + CREATE TABLE template_version_parameters ( template_version_id uuid NOT NULL, name text NOT NULL, @@ -1015,11 +1017,11 @@ CREATE INDEX workspace_resources_job_id_idx ON workspace_resources USING btree ( CREATE UNIQUE INDEX workspaces_owner_id_lower_idx ON workspaces USING btree (owner_id, lower((name)::text)) WHERE (deleted = false); -CREATE TRIGGER notify_agent_change AFTER INSERT OR DELETE OR UPDATE ON tailnet_agents FOR EACH ROW EXECUTE FUNCTION notify_agent_change(); +CREATE TRIGGER tailnet_notify_agent_change AFTER INSERT OR DELETE OR UPDATE ON tailnet_agents FOR EACH ROW EXECUTE FUNCTION tailnet_notify_agent_change(); -CREATE TRIGGER notify_client_change AFTER INSERT OR DELETE OR UPDATE ON tailnet_clients FOR EACH ROW EXECUTE FUNCTION notify_client_change(); +CREATE TRIGGER tailnet_notify_client_change AFTER INSERT OR DELETE OR UPDATE ON tailnet_clients FOR EACH ROW EXECUTE FUNCTION tailnet_notify_client_change(); -CREATE TRIGGER notify_coordinator_heartbeat AFTER INSERT OR UPDATE ON tailnet_coordinators FOR EACH ROW EXECUTE FUNCTION notify_coordinator_heartbeat(); +CREATE TRIGGER tailnet_notify_coordinator_heartbeat AFTER INSERT OR UPDATE ON tailnet_coordinators FOR EACH ROW EXECUTE FUNCTION tailnet_notify_coordinator_heartbeat(); CREATE TRIGGER trigger_insert_apikeys BEFORE INSERT ON api_keys FOR EACH ROW EXECUTE FUNCTION insert_apikey_fail_if_user_deleted(); diff --git a/coderd/database/migrations/000127_ha_coordinator.down.sql b/coderd/database/migrations/000127_ha_coordinator.down.sql index 5bf7b888b347b..54c8b0253902b 100644 --- a/coderd/database/migrations/000127_ha_coordinator.down.sql +++ 
b/coderd/database/migrations/000127_ha_coordinator.down.sql @@ -1,18 +1,18 @@ BEGIN; -DROP TRIGGER IF EXISTS notify_client_change ON tailnet_clients; -DROP FUNCTION IF EXISTS notify_client_change; +DROP TRIGGER IF EXISTS tailnet_notify_client_change ON tailnet_clients; +DROP FUNCTION IF EXISTS tailnet_notify_client_change; DROP INDEX IF EXISTS idx_tailnet_clients_agent; DROP INDEX IF EXISTS idx_tailnet_clients_coordinator; DROP TABLE tailnet_clients; -DROP TRIGGER IF EXISTS notify_agent_change ON tailnet_agents; -DROP FUNCTION IF EXISTS notify_agent_change; +DROP TRIGGER IF EXISTS tailnet_notify_agent_change ON tailnet_agents; +DROP FUNCTION IF EXISTS tailnet_notify_agent_change; DROP INDEX IF EXISTS idx_tailnet_agents_coordinator; DROP TABLE IF EXISTS tailnet_agents; -DROP TRIGGER IF EXISTS notify_coordinator_heartbeat ON tailnet_coordinators; -DROP FUNCTION IF EXISTS notify_coordinator_heartbeat; +DROP TRIGGER IF EXISTS tailnet_notify_coordinator_heartbeat ON tailnet_coordinators; +DROP FUNCTION IF EXISTS tailnet_notify_coordinator_heartbeat; DROP TABLE IF EXISTS tailnet_coordinators; COMMIT; diff --git a/coderd/database/migrations/000127_ha_coordinator.up.sql b/coderd/database/migrations/000127_ha_coordinator.up.sql index 8200c9401506f..f30bd077c798b 100644 --- a/coderd/database/migrations/000127_ha_coordinator.up.sql +++ b/coderd/database/migrations/000127_ha_coordinator.up.sql @@ -5,6 +5,8 @@ CREATE TABLE tailnet_coordinators ( heartbeat_at timestamp with time zone NOT NULL ); +COMMENT ON TABLE tailnet_coordinators IS 'We keep this separate from replicas in case we need to break the coordinator out into its own service'; + CREATE TABLE tailnet_clients ( id uuid NOT NULL, coordinator_id uuid NOT NULL, @@ -36,7 +38,7 @@ CREATE TABLE tailnet_agents ( CREATE INDEX idx_tailnet_agents_coordinator ON tailnet_agents (coordinator_id); -- Any time the tailnet_clients table changes, send an update with the affected client and agent IDs -CREATE FUNCTION notify_client_change() RETURNS trigger +CREATE FUNCTION tailnet_notify_client_change() RETURNS trigger LANGUAGE plpgsql AS $$ BEGIN @@ -51,13 +53,13 @@ BEGIN END; $$; -CREATE TRIGGER notify_client_change +CREATE TRIGGER tailnet_notify_client_change AFTER INSERT OR UPDATE OR DELETE ON tailnet_clients FOR EACH ROW -EXECUTE PROCEDURE notify_client_change(); +EXECUTE PROCEDURE tailnet_notify_client_change(); -- Any time tailnet_agents table changes, send an update with the affected agent ID. 
-CREATE FUNCTION notify_agent_change() RETURNS trigger +CREATE FUNCTION tailnet_notify_agent_change() RETURNS trigger LANGUAGE plpgsql AS $$ BEGIN @@ -72,13 +74,13 @@ BEGIN END; $$; -CREATE TRIGGER notify_agent_change +CREATE TRIGGER tailnet_notify_agent_change AFTER INSERT OR UPDATE OR DELETE ON tailnet_agents FOR EACH ROW -EXECUTE PROCEDURE notify_agent_change(); +EXECUTE PROCEDURE tailnet_notify_agent_change(); -- Send coordinator heartbeats -CREATE FUNCTION notify_coordinator_heartbeat() RETURNS trigger +CREATE FUNCTION tailnet_notify_coordinator_heartbeat() RETURNS trigger LANGUAGE plpgsql AS $$ BEGIN @@ -87,9 +89,9 @@ BEGIN END; $$; -CREATE TRIGGER notify_coordinator_heartbeat +CREATE TRIGGER tailnet_notify_coordinator_heartbeat AFTER INSERT OR UPDATE ON tailnet_coordinators FOR EACH ROW -EXECUTE PROCEDURE notify_coordinator_heartbeat(); +EXECUTE PROCEDURE tailnet_notify_coordinator_heartbeat(); COMMIT; diff --git a/coderd/database/models.go b/coderd/database/models.go index d23ff149eb873..310b5d387496b 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -1549,6 +1549,7 @@ type TailnetClient struct { Node json.RawMessage `db:"node" json:"node"` } +// We keep this separate from replicas in case we need to break the coordinator out into its own service type TailnetCoordinator struct { ID uuid.UUID `db:"id" json:"id"` HeartbeatAt time.Time `db:"heartbeat_at" json:"heartbeat_at"` diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index bbf2526eaac55..a8d62ea87e1ee 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -700,7 +700,7 @@ func (q *querier) query(mk mKey) error { func (q *querier) queryClientsOfAgent(agent uuid.UUID) ([]mapping, error) { clients, err := q.store.GetTailnetClientsForAgent(q.ctx, agent) - if err != nil && !xerrors.Is(err, sql.ErrNoRows) { + if err != nil { return nil, err } mappings := make([]mapping, 0, len(clients)) @@ -724,7 +724,7 @@ func (q *querier) queryClientsOfAgent(agent uuid.UUID) ([]mapping, error) { func (q *querier) queryAgent(agentID uuid.UUID) ([]mapping, error) { agents, err := q.store.GetTailnetAgents(q.ctx, agentID) - if err != nil && !xerrors.Is(err, sql.ErrNoRows) { + if err != nil { return nil, err } mappings := make([]mapping, 0, len(agents)) @@ -761,7 +761,9 @@ func (q *querier) subscribe() { return nil }, bkoff) if err != nil { - // this should only happen if context is canceled + if q.ctx.Err() == nil { + q.logger.Error(q.ctx, "code bug: retry failed before context canceled", slog.Error(err)) + } return } defer cancelClient() @@ -778,7 +780,9 @@ func (q *querier) subscribe() { return nil }, bkoff) if err != nil { - // this should only happen if context is canceled + if q.ctx.Err() == nil { + q.logger.Error(q.ctx, "code bug: retry failed before context canceled", slog.Error(err)) + } return } defer cancelAgent() @@ -1092,7 +1096,9 @@ func (h *heartbeats) subscribe() { return nil }, bkoff) if bErr != nil { - // this should only happen if context is canceled + if h.ctx.Err() == nil { + h.logger.Error(h.ctx, "code bug: retry failed before context canceled", slog.Error(bErr)) + } return } // cancel subscription when context finishes diff --git a/enterprise/tailnet/pgcoord_test.go b/enterprise/tailnet/pgcoord_test.go index e55560d7f3fb5..b2fbb8d0f9845 100644 --- a/enterprise/tailnet/pgcoord_test.go +++ b/enterprise/tailnet/pgcoord_test.go @@ -6,7 +6,6 @@ import ( "encoding/json" "io" "net" - "os" "sync" "testing" "time" @@ -33,7 +32,7 @@ func TestMain(m *testing.M) { 
func TestPGCoordinatorSingle_ClientWithoutAgent(t *testing.T) { t.Parallel() - if os.Getenv("DB") == "" { + if !dbtestutil.WillUsePostgres() { t.Skip("test only with postgres") } store, pubsub := dbtestutil.NewDB(t) @@ -72,7 +71,7 @@ func TestPGCoordinatorSingle_ClientWithoutAgent(t *testing.T) { func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) { t.Parallel() - if os.Getenv("DB") == "" { + if !dbtestutil.WillUsePostgres() { t.Skip("test only with postgres") } store, pubsub := dbtestutil.NewDB(t) @@ -109,7 +108,7 @@ func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) { func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) { t.Parallel() - if os.Getenv("DB") == "" { + if !dbtestutil.WillUsePostgres() { t.Skip("test only with postgres") } store, pubsub := dbtestutil.NewDB(t) @@ -186,7 +185,7 @@ func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) { func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) { t.Parallel() - if os.Getenv("DB") == "" { + if !dbtestutil.WillUsePostgres() { t.Skip("test only with postgres") } store, pubsub := dbtestutil.NewDB(t) @@ -245,7 +244,7 @@ func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) { func TestPGCoordinatorSingle_SendsHeartbeats(t *testing.T) { t.Parallel() - if os.Getenv("DB") == "" { + if !dbtestutil.WillUsePostgres() { t.Skip("test only with postgres") } store, pubsub := dbtestutil.NewDB(t) @@ -295,7 +294,7 @@ func TestPGCoordinatorSingle_SendsHeartbeats(t *testing.T) { // +---------+ func TestPGCoordinatorDual_Mainline(t *testing.T) { t.Parallel() - if os.Getenv("DB") == "" { + if !dbtestutil.WillUsePostgres() { t.Skip("test only with postgres") } store, pubsub := dbtestutil.NewDB(t) @@ -432,7 +431,7 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) { // +---------+ func TestPGCoordinator_MultiAgent(t *testing.T) { t.Parallel() - if os.Getenv("DB") == "" { + if !dbtestutil.WillUsePostgres() { t.Skip("test only with postgres") } store, pubsub := dbtestutil.NewDB(t) From 689d01459c3749cf49c4e365ccc085cb70d0085b Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 21 Jun 2023 11:10:53 +0000 Subject: [PATCH 08/11] db migration order Signed-off-by: Spike Curtis --- ...127_ha_coordinator.down.sql => 000130_ha_coordinator.down.sql} | 0 ...{000127_ha_coordinator.up.sql => 000130_ha_coordinator.up.sql} | 0 ...{000127_ha_coordinator.up.sql => 000130_ha_coordinator.up.sql} | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename coderd/database/migrations/{000127_ha_coordinator.down.sql => 000130_ha_coordinator.down.sql} (100%) rename coderd/database/migrations/{000127_ha_coordinator.up.sql => 000130_ha_coordinator.up.sql} (100%) rename coderd/database/migrations/testdata/fixtures/{000127_ha_coordinator.up.sql => 000130_ha_coordinator.up.sql} (100%) diff --git a/coderd/database/migrations/000127_ha_coordinator.down.sql b/coderd/database/migrations/000130_ha_coordinator.down.sql similarity index 100% rename from coderd/database/migrations/000127_ha_coordinator.down.sql rename to coderd/database/migrations/000130_ha_coordinator.down.sql diff --git a/coderd/database/migrations/000127_ha_coordinator.up.sql b/coderd/database/migrations/000130_ha_coordinator.up.sql similarity index 100% rename from coderd/database/migrations/000127_ha_coordinator.up.sql rename to coderd/database/migrations/000130_ha_coordinator.up.sql diff --git a/coderd/database/migrations/testdata/fixtures/000127_ha_coordinator.up.sql b/coderd/database/migrations/testdata/fixtures/000130_ha_coordinator.up.sql similarity index 100% rename 
from coderd/database/migrations/testdata/fixtures/000127_ha_coordinator.up.sql rename to coderd/database/migrations/testdata/fixtures/000130_ha_coordinator.up.sql From 197b5a6635db62c3e567398f1b545a681eea04e3 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 21 Jun 2023 11:26:27 +0000 Subject: [PATCH 09/11] fix log field name last_heartbeat Signed-off-by: Spike Curtis --- enterprise/tailnet/pgcoord.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index a8d62ea87e1ee..1168472b91ee4 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -1161,11 +1161,11 @@ func (h *heartbeats) checkExpiry() { expired := false for id, t := range h.coordinators { lastHB := now.Sub(t) - h.logger.Debug(h.ctx, "last heartbeat from coordinator", slog.F("other_coordinator", id), slog.F("last heartbeat", lastHB)) + h.logger.Debug(h.ctx, "last heartbeat from coordinator", slog.F("other_coordinator", id), slog.F("last_heartbeat", lastHB)) if lastHB > MissedHeartbeats*HeartbeatPeriod { expired = true delete(h.coordinators, id) - h.logger.Info(h.ctx, "coordinator failed heartbeat check", slog.F("other_coordinator", id), slog.F("last heartbeat", lastHB)) + h.logger.Info(h.ctx, "coordinator failed heartbeat check", slog.F("other_coordinator", id), slog.F("last_heartbeat", lastHB)) } } h.lock.Unlock() From 7edda4de2b32710485e57d5c1908a1732a681075 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 21 Jun 2023 11:27:47 +0000 Subject: [PATCH 10/11] fix heartbeat_from log field Signed-off-by: Spike Curtis --- enterprise/tailnet/pgcoord.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 1168472b91ee4..58bb297cd3091 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -1128,7 +1128,7 @@ func (h *heartbeats) listen(_ context.Context, msg []byte, err error) { } func (h *heartbeats) recvBeat(id uuid.UUID) { - h.logger.Debug(h.ctx, "got heartbeat", slog.F("heartbeat_from", id)) + h.logger.Debug(h.ctx, "got heartbeat", slog.F("heartbeat_from_id", id)) h.lock.Lock() defer h.lock.Unlock() var oldestTime time.Time From 2c757c21de178ccc86b1e1d5a72cb3c0c332259c Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 21 Jun 2023 11:33:34 +0000 Subject: [PATCH 11/11] fix slog fields for linting Signed-off-by: Spike Curtis --- enterprise/tailnet/pgcoord.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 58bb297cd3091..3fca584c28f97 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -1128,7 +1128,7 @@ func (h *heartbeats) listen(_ context.Context, msg []byte, err error) { } func (h *heartbeats) recvBeat(id uuid.UUID) { - h.logger.Debug(h.ctx, "got heartbeat", slog.F("heartbeat_from_id", id)) + h.logger.Debug(h.ctx, "got heartbeat", slog.F("other_coordinator_id", id)) h.lock.Lock() defer h.lock.Unlock() var oldestTime time.Time @@ -1161,11 +1161,11 @@ func (h *heartbeats) checkExpiry() { expired := false for id, t := range h.coordinators { lastHB := now.Sub(t) - h.logger.Debug(h.ctx, "last heartbeat from coordinator", slog.F("other_coordinator", id), slog.F("last_heartbeat", lastHB)) + h.logger.Debug(h.ctx, "last heartbeat from coordinator", slog.F("other_coordinator_id", id), slog.F("last_heartbeat", lastHB)) if lastHB > MissedHeartbeats*HeartbeatPeriod { expired = true 
delete(h.coordinators, id) - h.logger.Info(h.ctx, "coordinator failed heartbeat check", slog.F("other_coordinator", id), slog.F("last_heartbeat", lastHB)) + h.logger.Info(h.ctx, "coordinator failed heartbeat check", slog.F("other_coordinator_id", id), slog.F("last_heartbeat", lastHB)) } } h.lock.Unlock()