From 35b2fed6860da12707266979f14286f73fe1468a Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Thu, 22 Sep 2022 17:40:59 -0500 Subject: [PATCH 01/10] feat: HA tailnet coordinator --- agent/agent_test.go | 2 +- coderd/coderd.go | 4 +- coderd/database/pubsub_memory.go | 3 +- coderd/workspaceagents.go | 2 +- coderd/wsconncache/wsconncache_test.go | 2 +- codersdk/workspaceagents.go | 1 - enterprise/tailnet/coordinator.go | 426 +++++++++++++++++++++++++ enterprise/tailnet/coordinator_test.go | 267 ++++++++++++++++ tailnet/coordinator.go | 203 +++++++----- tailnet/coordinator_test.go | 10 +- 10 files changed, 834 insertions(+), 86 deletions(-) create mode 100644 enterprise/tailnet/coordinator.go create mode 100644 enterprise/tailnet/coordinator_test.go diff --git a/agent/agent_test.go b/agent/agent_test.go index afed644f78e5e..d6ff21cdcd33d 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -572,7 +572,7 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) if metadata.DERPMap == nil { metadata.DERPMap = tailnettest.RunDERPAndSTUN(t) } - coordinator := tailnet.NewCoordinator() + coordinator := tailnet.NewMemoryCoordinator() agentID := uuid.New() statsCh := make(chan *agent.Stats) closer := agent.New(agent.Options{ diff --git a/coderd/coderd.go b/coderd/coderd.go index 25ac1afec2f36..f183e4d9b9ab7 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -74,7 +74,7 @@ type Options struct { TracerProvider trace.TracerProvider AutoImportTemplates []AutoImportTemplate - TailnetCoordinator *tailnet.Coordinator + TailnetCoordinator tailnet.Coordinator DERPMap *tailcfg.DERPMap MetricsCacheRefreshInterval time.Duration @@ -121,7 +121,7 @@ func New(options *Options) *API { options.PrometheusRegistry = prometheus.NewRegistry() } if options.TailnetCoordinator == nil { - options.TailnetCoordinator = tailnet.NewCoordinator() + options.TailnetCoordinator = tailnet.NewMemoryCoordinator() } if options.Auditor == nil { options.Auditor = audit.NewNop() diff --git a/coderd/database/pubsub_memory.go b/coderd/database/pubsub_memory.go index 148d2f57b129f..de5a940414d6c 100644 --- a/coderd/database/pubsub_memory.go +++ b/coderd/database/pubsub_memory.go @@ -47,8 +47,9 @@ func (m *memoryPubsub) Publish(event string, message []byte) error { return nil } for _, listener := range listeners { - listener(context.Background(), message) + go listener(context.Background(), message) } + return nil } diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 6167790fb8bb7..dd777913c452d 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -447,7 +447,7 @@ func convertApps(dbApps []database.WorkspaceApp) []codersdk.WorkspaceApp { return apps } -func convertWorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator *tailnet.Coordinator, dbAgent database.WorkspaceAgent, apps []codersdk.WorkspaceApp, agentInactiveDisconnectTimeout time.Duration) (codersdk.WorkspaceAgent, error) { +func convertWorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator tailnet.Coordinator, dbAgent database.WorkspaceAgent, apps []codersdk.WorkspaceApp, agentInactiveDisconnectTimeout time.Duration) (codersdk.WorkspaceAgent, error) { var envs map[string]string if dbAgent.EnvironmentVariables.Valid { err := json.Unmarshal(dbAgent.EnvironmentVariables.RawMessage, &envs) diff --git a/coderd/wsconncache/wsconncache_test.go b/coderd/wsconncache/wsconncache_test.go index a9ea85a2492ac..e4c7d58413110 100644 --- a/coderd/wsconncache/wsconncache_test.go +++ b/coderd/wsconncache/wsconncache_test.go @@ 
-142,7 +142,7 @@ func TestCache(t *testing.T) { func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) *agent.Conn { metadata.DERPMap = tailnettest.RunDERPAndSTUN(t) - coordinator := tailnet.NewCoordinator() + coordinator := tailnet.NewMemoryCoordinator() agentID := uuid.New() closer := agent.New(agent.Options{ FetchMetadata: func(ctx context.Context) (agent.Metadata, error) { diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 46d8ead8d2d6d..72e9767713c7c 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -20,7 +20,6 @@ import ( "tailscale.com/tailcfg" "cdr.dev/slog" - "github.com/coder/coder/agent" "github.com/coder/coder/tailnet" "github.com/coder/retry" diff --git a/enterprise/tailnet/coordinator.go b/enterprise/tailnet/coordinator.go new file mode 100644 index 0000000000000..8824f584d60da --- /dev/null +++ b/enterprise/tailnet/coordinator.go @@ -0,0 +1,426 @@ +package tailnet + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "sync" + "time" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "cdr.dev/slog" + + "github.com/coder/coder/coderd/database" + agpl "github.com/coder/coder/tailnet" +) + +func NewHACoordinator(logger slog.Logger, pubsub database.Pubsub) (agpl.Coordinator, error) { + coord := &haCoordinator{ + id: uuid.New(), + log: logger, + pubsub: pubsub, + close: make(chan struct{}), + nodes: map[uuid.UUID]*agpl.Node{}, + agentSockets: map[uuid.UUID]net.Conn{}, + agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]net.Conn{}, + } + + if err := coord.runPubsub(); err != nil { + return nil, xerrors.Errorf("run coordinator pubsub: %w", err) + } + + return coord, nil +} + +type haCoordinator struct { + id uuid.UUID + log slog.Logger + mutex sync.RWMutex + pubsub database.Pubsub + close chan struct{} + + // nodes maps agent and connection IDs to their respective node. + nodes map[uuid.UUID]*agpl.Node + // agentSockets maps agent IDs to their open websocket. + agentSockets map[uuid.UUID]net.Conn + // agentToConnectionSockets maps agent IDs to connection IDs of conns that + // are subscribed to updates for that agent. + agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]net.Conn +} + +// Node returns an in-memory node by ID. +func (c *haCoordinator) Node(id uuid.UUID) *agpl.Node { + c.mutex.RLock() + defer c.mutex.RUnlock() + node := c.nodes[id] + return node +} + +// ServeClient accepts a WebSocket connection that wants to connect to an agent +// with the specified ID. +func (c *haCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error { + c.mutex.Lock() + // When a new connection is requested, we update it with the latest + // node of the agent. This allows the connection to establish. + node, ok := c.nodes[agent] + if ok { + data, err := json.Marshal([]*agpl.Node{node}) + if err != nil { + c.mutex.Unlock() + return xerrors.Errorf("marshal node: %w", err) + } + _, err = conn.Write(data) + if err != nil { + c.mutex.Unlock() + return xerrors.Errorf("write nodes: %w", err) + } + } + + connectionSockets, ok := c.agentToConnectionSockets[agent] + if !ok { + connectionSockets = map[uuid.UUID]net.Conn{} + c.agentToConnectionSockets[agent] = connectionSockets + } + + // Insert this connection into a map so the agent can publish node updates. + connectionSockets[id] = conn + c.mutex.Unlock() + + defer func() { + c.mutex.Lock() + defer c.mutex.Unlock() + // Clean all traces of this connection from the map.
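A review note on the bookkeeping above: the coordinator keeps one open socket per agent plus, for each agent, a map of client sockets subscribed to that agent's node updates. A minimal sketch of that registration pattern, using only the two maps from the struct above (the registry type and helper are illustrative, not part of the patch):

    package main

    import (
    	"fmt"
    	"net"

    	"github.com/google/uuid"
    )

    // registry mirrors the agentSockets / agentToConnectionSockets maps above.
    type registry struct {
    	agentSockets             map[uuid.UUID]net.Conn
    	agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]net.Conn
    }

    // subscribe registers a client conn for updates about an agent,
    // creating the inner map on first use, as ServeClient does above.
    func (r *registry) subscribe(agent, client uuid.UUID, conn net.Conn) {
    	sockets, ok := r.agentToConnectionSockets[agent]
    	if !ok {
    		sockets = map[uuid.UUID]net.Conn{}
    		r.agentToConnectionSockets[agent] = sockets
    	}
    	sockets[client] = conn
    }

    func main() {
    	r := &registry{
    		agentSockets:             map[uuid.UUID]net.Conn{},
    		agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]net.Conn{},
    	}
    	agent, client := uuid.New(), uuid.New()
    	clientConn, _ := net.Pipe()
    	r.subscribe(agent, client, clientConn)
    	fmt.Println("subscribers for agent:", len(r.agentToConnectionSockets[agent]))
    }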
+ delete(c.nodes, id) + connectionSockets, ok := c.agentToConnectionSockets[agent] + if !ok { + return + } + delete(connectionSockets, id) + if len(connectionSockets) != 0 { + return + } + delete(c.agentToConnectionSockets, agent) + }() + + decoder := json.NewDecoder(conn) + // Indefinitely handle messages from the client websocket. + for { + err := c.handleNextClientMessage(id, agent, decoder) + if err != nil { + if errors.Is(err, io.EOF) { + return nil + } + return xerrors.Errorf("handle next client message: %w", err) + } + } +} + +func (c *haCoordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json.Decoder) error { + var node agpl.Node + err := decoder.Decode(&node) + if err != nil { + return xerrors.Errorf("read json: %w", err) + } + + c.mutex.Lock() + defer c.mutex.Unlock() + + // Update the node of this client in our in-memory map. If an agent entirely + // shuts down and reconnects, it needs to be aware of all clients attempting + // to establish connections. + c.nodes[id] = &node + + // Write the new node from this client to the actively connected agent. + err = c.writeNodeToAgent(agent, &node) + if err != nil { + return xerrors.Errorf("write node to agent: %w", err) + } + + return nil +} + +func (c *haCoordinator) writeNodeToAgent(agent uuid.UUID, node *agpl.Node) error { + agentSocket, ok := c.agentSockets[agent] + if !ok { + // If we don't own the agent locally, send it over pubsub to a node that + // owns the agent. + err := c.publishNodeToAgent(agent, node) + if err != nil { + return xerrors.Errorf("publish node to agent: %w", err) + } + return nil + } + + // Write the new node from this client to the actively + // connected agent. + data, err := json.Marshal([]*agpl.Node{node}) + if err != nil { + return xerrors.Errorf("marshal nodes: %w", err) + } + + _, err = agentSocket.Write(data) + if err != nil { + if errors.Is(err, io.EOF) { + return nil + } + return xerrors.Errorf("write json: %w", err) + } + return nil +} + +// ServeAgent accepts a WebSocket connection to an agent that listens to +// incoming connections and publishes node updates. +func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID) error { + c.mutex.Lock() + sockets, ok := c.agentToConnectionSockets[id] + if ok { + // Publish all nodes that want to connect to the + // desired agent ID. + nodes := make([]*agpl.Node, 0, len(sockets)) + for targetID := range sockets { + node, ok := c.nodes[targetID] + if !ok { + continue + } + nodes = append(nodes, node) + } + data, err := json.Marshal(nodes) + if err != nil { + c.mutex.Unlock() + return xerrors.Errorf("marshal json: %w", err) + } + _, err = conn.Write(data) + if err != nil { + c.mutex.Unlock() + return xerrors.Errorf("write nodes: %w", err) + } + } + + // If an old agent socket is connected, we close it + // to avoid any leaks. This shouldn't ever occur because + // we expect only one agent to be running.
+ oldAgentSocket, ok := c.agentSockets[id] + if ok { + _ = oldAgentSocket.Close() + } + c.agentSockets[id] = conn + c.mutex.Unlock() + defer func() { + c.mutex.Lock() + defer c.mutex.Unlock() + delete(c.agentSockets, id) + delete(c.nodes, id) + }() + + decoder := json.NewDecoder(conn) + for { + err := c.handleAgentUpdate(id, decoder, false) + if err != nil { + if errors.Is(err, io.EOF) { + return nil + } + return xerrors.Errorf("handle next agent message: %w", err) + } + } +} + +func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder, fromPubsub bool) error { + var node agpl.Node + err := decoder.Decode(&node) + if err != nil { + return xerrors.Errorf("read json: %w", err) + } + + c.mutex.Lock() + defer c.mutex.Unlock() + + c.nodes[id] = &node + + // Don't send the agent back over pubsub if that's where we received it from! + if !fromPubsub { + err = c.publishAgentToNodes(id, &node) + if err != nil { + return xerrors.Errorf("publish agent to nodes: %w", err) + } + } + + connectionSockets, ok := c.agentToConnectionSockets[id] + if !ok { + return nil + } + + data, err := json.Marshal([]*agpl.Node{&node}) + if err != nil { + return xerrors.Errorf("marshal nodes: %w", err) + } + + // Publish the new node to every listening socket. + var wg sync.WaitGroup + wg.Add(len(connectionSockets)) + for _, connectionSocket := range connectionSockets { + connectionSocket := connectionSocket + go func() { + _ = connectionSocket.SetWriteDeadline(time.Now().Add(5 * time.Second)) + _, _ = connectionSocket.Write(data) + wg.Done() + }() + } + + wg.Wait() + return nil +} + +func (c *haCoordinator) Close() error { + close(c.close) + return nil +} + +func (c *haCoordinator) publishNodeToAgent(recipient uuid.UUID, node *agpl.Node) error { + msg, err := c.formatCallMeMaybe(recipient, node) + if err != nil { + return xerrors.Errorf("format publish message: %w", err) + } + + fmt.Println("publishing callmemaybe", c.id.String()) + err = c.pubsub.Publish("wireguard_peers", msg) + if err != nil { + return xerrors.Errorf("publish message: %w", err) + } + + return nil +} + +func (c *haCoordinator) publishAgentToNodes(id uuid.UUID, node *agpl.Node) error { + msg, err := c.formatAgentUpdate(id, node) + if err != nil { + return xerrors.Errorf("format publish message: %w", err) + } + + fmt.Println("publishing agentupdate", c.id.String()) + err = c.pubsub.Publish("wireguard_peers", msg) + if err != nil { + return xerrors.Errorf("publish message: %w", err) + } + + return nil +} + +func (c *haCoordinator) runPubsub() error { + cancelSub, err := c.pubsub.Subscribe("wireguard_peers", func(ctx context.Context, message []byte) { + sp := bytes.SplitN(message, []byte("|"), 4) + if len(sp) != 4 { + c.log.Error(ctx, "invalid wireguard peer message", slog.F("msg", string(message))) + return + } + + var ( + coordinatorID = sp[0] + eventType = sp[1] + agentID = sp[2] + nodeJSON = sp[3] + ) + + sender, err := uuid.ParseBytes(coordinatorID) + if err != nil { + c.log.Error(ctx, "invalid sender id", slog.F("id", string(coordinatorID)), slog.F("msg", string(message))) + return + } + + // We sent this message!
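A review note on the wire format handled above: every payload on the wireguard_peers channel is four pipe-delimited fields, <coordinator-id>|<event-type>|<agent-id>|<node-json>, built by the format helpers below and split apart here. A standalone round-trip sketch under that assumption (encodeEvent and the node type are illustrative, not from the patch):

    package main

    import (
    	"bytes"
    	"encoding/json"
    	"fmt"
    )

    // node stands in for tailnet.Node; only the JSON shape matters here.
    type node struct {
    	ID int `json:"id"`
    }

    // encodeEvent builds "<coordinator-id>|<event>|<agent-id>|<node-json>",
    // mirroring formatCallMeMaybe/formatAgentUpdate above.
    func encodeEvent(coordinatorID, event, agentID string, n *node) ([]byte, error) {
    	buf := bytes.Buffer{}
    	buf.WriteString(coordinatorID + "|")
    	buf.WriteString(event + "|")
    	buf.WriteString(agentID + "|")
    	if err := json.NewEncoder(&buf).Encode(n); err != nil {
    		return nil, err
    	}
    	return buf.Bytes(), nil
    }

    func main() {
    	msg, _ := encodeEvent("c0ffee", "agentupdate", "a9e0", &node{ID: 7})
    	// SplitN with a limit of 4 keeps any "|" inside the JSON payload intact.
    	sp := bytes.SplitN(msg, []byte("|"), 4)
    	fmt.Printf("sender=%s event=%s agent=%s json=%s", sp[0], sp[1], sp[2], sp[3])
    }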
+ if sender == c.id { + return + } + + switch string(eventType) { + case "callmemaybe": + agentUUID, err := uuid.ParseBytes(agentID) + if err != nil { + c.log.Error(ctx, "invalid agent id", slog.F("id", string(agentID))) + return + } + + fmt.Println("got callmemaybe", agentUUID.String()) + c.mutex.Lock() + defer c.mutex.Unlock() + + fmt.Println("process callmemaybe", agentUUID.String()) + agentSocket, ok := c.agentSockets[agentUUID] + if !ok { + fmt.Println("no socket") + return + } + + // We get a single node over pubsub, so turn into an array. + _, err = agentSocket.Write(bytes.Join([][]byte{[]byte("["), nodeJSON, []byte("]")}, []byte{})) + if err != nil { + if errors.Is(err, io.EOF) { + return + } + c.log.Error(ctx, "send callmemaybe to agent", slog.Error(err)) + return + } + fmt.Println("success callmemaybe", agentUUID.String()) + + case "agentupdate": + agentUUID, err := uuid.ParseBytes(agentID) + if err != nil { + c.log.Error(ctx, "invalid agent id", slog.F("id", string(agentID))) + } + + decoder := json.NewDecoder(bytes.NewReader(nodeJSON)) + err = c.handleAgentUpdate(agentUUID, decoder, true) + if err != nil { + c.log.Error(ctx, "handle agent update", slog.Error(err)) + } + + default: + c.log.Error(ctx, "unknown peer event", slog.F("name", string(eventType))) + } + }) + if err != nil { + return xerrors.Errorf("subscribe wireguard peers: %w", err) + } + + go func() { + defer cancelSub() + <-c.close + }() + + return nil +} + +// format: <coordinator-id>|callmemaybe|<recipient-id>|<node-json> +func (c *haCoordinator) formatCallMeMaybe(recipient uuid.UUID, node *agpl.Node) ([]byte, error) { + buf := bytes.Buffer{} + + buf.WriteString(c.id.String() + "|") + buf.WriteString("callmemaybe|") + buf.WriteString(recipient.String() + "|") + err := json.NewEncoder(&buf).Encode(node) + if err != nil { + return nil, xerrors.Errorf("encode node: %w", err) + } + + return buf.Bytes(), nil +} + +// format: <coordinator-id>|agentupdate|<agent-id>|<node-json> +func (c *haCoordinator) formatAgentUpdate(id uuid.UUID, node *agpl.Node) ([]byte, error) { + buf := bytes.Buffer{} + + buf.WriteString(c.id.String() + "|") + buf.WriteString("agentupdate|") + buf.WriteString(id.String() + "|") + err := json.NewEncoder(&buf).Encode(node) + if err != nil { + return nil, xerrors.Errorf("encode node: %w", err) + } + + return buf.Bytes(), nil +} diff --git a/enterprise/tailnet/coordinator_test.go b/enterprise/tailnet/coordinator_test.go new file mode 100644 index 0000000000000..48fce5bfd0f6f --- /dev/null +++ b/enterprise/tailnet/coordinator_test.go @@ -0,0 +1,267 @@ +package tailnet_test + +import ( + "fmt" + "net" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/sloggers/slogtest" + + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/enterprise/tailnet" + agpl "github.com/coder/coder/tailnet" + "github.com/coder/coder/testutil" +) + +func TestCoordinatorSingle(t *testing.T) { + t.Parallel() + t.Run("ClientWithoutAgent", func(t *testing.T) { + t.Parallel() + coordinator, err := tailnet.NewHACoordinator(slogtest.Make(t, nil), database.NewPubsubInMemory()) + require.NoError(t, err) + defer coordinator.Close() + + client, server := net.Pipe() + sendNode, errChan := agpl.ServeCoordinator(client, func(node []*agpl.Node) error { + return nil + }) + id := uuid.New() + closeChan := make(chan struct{}) + go func() { + err := coordinator.ServeClient(server, id, uuid.New()) + assert.NoError(t, err) + close(closeChan) + }() + sendNode(&agpl.Node{}) + require.Eventually(t, func() bool { + return
coordinator.Node(id) != nil + }, testutil.WaitShort, testutil.IntervalFast) + + err = client.Close() + require.NoError(t, err) + <-errChan + <-closeChan + }) + + t.Run("AgentWithoutClients", func(t *testing.T) { + t.Parallel() + coordinator, err := tailnet.NewHACoordinator(slogtest.Make(t, nil), database.NewPubsubInMemory()) + require.NoError(t, err) + defer coordinator.Close() + + client, server := net.Pipe() + sendNode, errChan := agpl.ServeCoordinator(client, func(node []*agpl.Node) error { + return nil + }) + id := uuid.New() + closeChan := make(chan struct{}) + go func() { + err := coordinator.ServeAgent(server, id) + assert.NoError(t, err) + close(closeChan) + }() + sendNode(&agpl.Node{}) + require.Eventually(t, func() bool { + return coordinator.Node(id) != nil + }, testutil.WaitShort, testutil.IntervalFast) + err = client.Close() + require.NoError(t, err) + <-errChan + <-closeChan + }) + + t.Run("AgentWithClient", func(t *testing.T) { + t.Parallel() + + coordinator, err := tailnet.NewHACoordinator(slogtest.Make(t, nil), database.NewPubsubInMemory()) + require.NoError(t, err) + defer coordinator.Close() + + agentWS, agentServerWS := net.Pipe() + defer agentWS.Close() + agentNodeChan := make(chan []*agpl.Node) + sendAgentNode, agentErrChan := agpl.ServeCoordinator(agentWS, func(nodes []*agpl.Node) error { + agentNodeChan <- nodes + return nil + }) + agentID := uuid.New() + closeAgentChan := make(chan struct{}) + go func() { + err := coordinator.ServeAgent(agentServerWS, agentID) + assert.NoError(t, err) + close(closeAgentChan) + }() + sendAgentNode(&agpl.Node{}) + require.Eventually(t, func() bool { + return coordinator.Node(agentID) != nil + }, testutil.WaitShort, testutil.IntervalFast) + + clientWS, clientServerWS := net.Pipe() + defer clientWS.Close() + defer clientServerWS.Close() + clientNodeChan := make(chan []*agpl.Node) + sendClientNode, clientErrChan := agpl.ServeCoordinator(clientWS, func(nodes []*agpl.Node) error { + clientNodeChan <- nodes + return nil + }) + clientID := uuid.New() + closeClientChan := make(chan struct{}) + go func() { + err := coordinator.ServeClient(clientServerWS, clientID, agentID) + assert.NoError(t, err) + close(closeClientChan) + }() + agentNodes := <-clientNodeChan + require.Len(t, agentNodes, 1) + sendClientNode(&agpl.Node{}) + clientNodes := <-agentNodeChan + require.Len(t, clientNodes, 1) + + // Ensure an update to the agent node reaches the client! + sendAgentNode(&agpl.Node{}) + agentNodes = <-clientNodeChan + require.Len(t, agentNodes, 1) + + // Close the agent WebSocket so a new one can connect. + err = agentWS.Close() + require.NoError(t, err) + <-agentErrChan + <-closeAgentChan + + // Create a new agent connection. This is to simulate a reconnect! + agentWS, agentServerWS = net.Pipe() + defer agentWS.Close() + agentNodeChan = make(chan []*agpl.Node) + _, agentErrChan = agpl.ServeCoordinator(agentWS, func(nodes []*agpl.Node) error { + agentNodeChan <- nodes + return nil + }) + closeAgentChan = make(chan struct{}) + go func() { + err := coordinator.ServeAgent(agentServerWS, agentID) + assert.NoError(t, err) + close(closeAgentChan) + }() + // Ensure the existing listening client sends its node immediately!
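A review note on why the reconnect sees the client node immediately: client nodes are keyed by connection ID in the coordinator's nodes map and are not deleted when an agent drops, so a reconnecting agent is replayed every node whose client is still subscribed. A minimal sketch of that replay step (the types here are stand-ins for the coordinator's maps):

    package main

    import (
    	"fmt"

    	"github.com/google/uuid"
    )

    type node struct{ addr string }

    // replay collects the retained node of every still-subscribed client,
    // which ServeAgent then marshals and writes to the reconnecting agent.
    func replay(nodes map[uuid.UUID]*node, subscribers map[uuid.UUID]bool) []*node {
    	out := make([]*node, 0, len(subscribers))
    	for id := range subscribers {
    		if n, ok := nodes[id]; ok {
    			out = append(out, n)
    		}
    	}
    	return out
    }

    func main() {
    	clientID := uuid.New()
    	nodes := map[uuid.UUID]*node{clientID: {addr: "100.64.0.2"}}
    	subs := map[uuid.UUID]bool{clientID: true}
    	fmt.Println(len(replay(nodes, subs)), "node(s) replayed to the reconnecting agent")
    }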
+ clientNodes = <-agentNodeChan + require.Len(t, clientNodes, 1) + + err = agentWS.Close() + require.NoError(t, err) + <-agentErrChan + <-closeAgentChan + + err = clientWS.Close() + require.NoError(t, err) + <-clientErrChan + <-closeClientChan + }) +} + +func TestCoordinatorHA(t *testing.T) { + t.Parallel() + + t.Run("AgentWithClient", func(t *testing.T) { + t.Parallel() + + pubsub := database.NewPubsubInMemory() + + coordinator1, err := tailnet.NewHACoordinator(slogtest.Make(t, nil), pubsub) + require.NoError(t, err) + defer coordinator1.Close() + + coordinator2, err := tailnet.NewHACoordinator(slogtest.Make(t, nil), pubsub) + require.NoError(t, err) + defer coordinator2.Close() + + agentWS, agentServerWS := net.Pipe() + defer agentWS.Close() + agentNodeChan := make(chan []*agpl.Node) + sendAgentNode, agentErrChan := agpl.ServeCoordinator(agentWS, func(nodes []*agpl.Node) error { + fmt.Println("got agent node") + agentNodeChan <- nodes + fmt.Println("sent agent node") + return nil + }) + agentID := uuid.New() + closeAgentChan := make(chan struct{}) + go func() { + err := coordinator1.ServeAgent(agentServerWS, agentID) + assert.NoError(t, err) + close(closeAgentChan) + }() + sendAgentNode(&agpl.Node{}) + require.Eventually(t, func() bool { + return coordinator1.Node(agentID) != nil + }, testutil.WaitShort, testutil.IntervalFast) + + clientWS, clientServerWS := net.Pipe() + defer clientWS.Close() + defer clientServerWS.Close() + clientNodeChan := make(chan []*agpl.Node) + sendClientNode, clientErrChan := agpl.ServeCoordinator(clientWS, func(nodes []*agpl.Node) error { + fmt.Println("got client node") + clientNodeChan <- nodes + fmt.Println("sent client node") + return nil + }) + clientID := uuid.New() + closeClientChan := make(chan struct{}) + go func() { + err := coordinator2.ServeClient(clientServerWS, clientID, agentID) + assert.NoError(t, err) + close(closeClientChan) + }() + agentNodes := <-clientNodeChan + require.Len(t, agentNodes, 1) + sendClientNode(&agpl.Node{}) + _ = sendClientNode + clientNodes := <-agentNodeChan + require.Len(t, clientNodes, 1) + + // Ensure an update to the agent node reaches the client! + sendAgentNode(&agpl.Node{}) + agentNodes = <-clientNodeChan + require.Len(t, agentNodes, 1) + + // Close the agent WebSocket so a new one can connect. + require.NoError(t, agentWS.Close()) + require.NoError(t, agentServerWS.Close()) + <-agentErrChan + <-closeAgentChan + + // Create a new agent connection. This is to simulate a reconnect! + agentWS, agentServerWS = net.Pipe() + defer agentWS.Close() + agentNodeChan = make(chan []*agpl.Node) + _, agentErrChan = agpl.ServeCoordinator(agentWS, func(nodes []*agpl.Node) error { + fmt.Println("got agent node") + agentNodeChan <- nodes + fmt.Println("sent agent node") + return nil + }) + closeAgentChan = make(chan struct{}) + go func() { + err := coordinator1.ServeAgent(agentServerWS, agentID) + assert.NoError(t, err) + close(closeAgentChan) + }() + // Ensure the existing listening client sends its node immediately!
+ clientNodes = <-agentNodeChan + require.Len(t, clientNodes, 1) + + err = agentWS.Close() + require.NoError(t, err) + <-agentErrChan + <-closeAgentChan + + err = clientWS.Close() + require.NoError(t, err) + <-clientErrChan + <-closeClientChan + }) +} diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index 95209d56559ff..af6a5fee58288 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -7,6 +7,7 @@ import ( "net" "net/netip" "sync" + "time" "github.com/google/uuid" "golang.org/x/xerrors" @@ -14,6 +15,24 @@ import ( "tailscale.com/types/key" ) +// Coordinator exchanges nodes with agents to establish connections. +// ┌──────────────────┐ ┌────────────────────┐ ┌───────────────────┐ ┌──────────────────┐ +// │tailnet.Coordinate├──►│tailnet.AcceptClient│◄─►│tailnet.AcceptAgent│◄──┤tailnet.Coordinate│ +// └──────────────────┘ └────────────────────┘ └───────────────────┘ └──────────────────┘ +// Coordinators have different guarantees for HA support. +type Coordinator interface { + // Node returns an in-memory node by ID. + Node(id uuid.UUID) *Node + // ServeClient accepts a WebSocket connection that wants to connect to an agent + // with the specified ID. + ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error + // ServeAgent accepts a WebSocket connection to an agent that listens to + // incoming connections and publishes node updates. + ServeAgent(conn net.Conn, id uuid.UUID) error + // Close closes the coordinator. + Close() error +} + // Node represents a node in the network. type Node struct { ID tailcfg.NodeID `json:"id"` @@ -64,44 +83,46 @@ func ServeCoordinator(conn net.Conn, updateNodes func(node []*Node) error) (func }, errChan } -// NewCoordinator constructs a new in-memory connection coordinator. -func NewCoordinator() *Coordinator { - return &Coordinator{ +// NewMemoryCoordinator constructs a new in-memory connection coordinator. This +// coordinator is incompatible with multiple Coder replicas as all node data is +// in-memory. +func NewMemoryCoordinator() Coordinator { + return &memoryCoordinator{ nodes: map[uuid.UUID]*Node{}, agentSockets: map[uuid.UUID]net.Conn{}, agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]net.Conn{}, } } -// Coordinator exchanges nodes with agents to establish connections. +// memoryCoordinator exchanges nodes with agents to establish connections. // ┌──────────────────┐ ┌────────────────────┐ ┌───────────────────┐ ┌──────────────────┐ // │tailnet.Coordinate├──►│tailnet.AcceptClient│◄─►│tailnet.AcceptAgent│◄──┤tailnet.Coordinate│ // └──────────────────┘ └────────────────────┘ └───────────────────┘ └──────────────────┘ // This coordinator is incompatible with multiple Coder // replicas as all node data is in-memory. -type Coordinator struct { +type memoryCoordinator struct { mutex sync.Mutex - // Maps agent and connection IDs to a node. + // nodes maps agent and connection IDs to their respective node. nodes map[uuid.UUID]*Node - // Maps agent ID to an open socket. + // agentSockets maps agent IDs to their open websocket. agentSockets map[uuid.UUID]net.Conn - // Maps agent ID to connection ID for sending - // new node data as it comes in! + // agentToConnectionSockets maps agent IDs to connection IDs of conns that + // are subscribed to updates for that agent. agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]net.Conn } // Node returns an in-memory node by ID.
-func (c *Coordinator) Node(id uuid.UUID) *Node { +func (c *memoryCoordinator) Node(id uuid.UUID) *Node { c.mutex.Lock() defer c.mutex.Unlock() node := c.nodes[id] return node } -// ServeClient accepts a WebSocket connection that wants to -// connect to an agent with the specified ID. -func (c *Coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error { +// ServeClient accepts a WebSocket connection that wants to connect to an agent +// with the specified ID. +func (c *memoryCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error { c.mutex.Lock() // When a new connection is requested, we update it with the latest // node of the agent. This allows the connection to establish. @@ -145,48 +166,67 @@ func (c *Coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) decoder := json.NewDecoder(conn) for { - var node Node - err := decoder.Decode(&node) - if errors.Is(err, io.EOF) { - return nil - } - if err != nil { - return xerrors.Errorf("read json: %w", err) - } - c.mutex.Lock() - // Update the node of this client in our in-memory map. - // If an agent entirely shuts down and reconnects, it - // needs to be aware of all clients attempting to - // establish connections. - c.nodes[id] = &node - agentSocket, ok := c.agentSockets[agent] - if !ok { - c.mutex.Unlock() - continue - } - // Write the new node from this client to the actively - // connected agent. - data, err := json.Marshal([]*Node{&node}) + err := c.handleNextClientMessage(id, agent, decoder) if err != nil { - c.mutex.Unlock() - return xerrors.Errorf("marshal nodes: %w", err) + if errors.Is(err, io.EOF) { + return nil + } + return xerrors.Errorf("handle next client message: %w", err) } - _, err = agentSocket.Write(data) + } +} + +func (c *memoryCoordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json.Decoder) error { + var node Node + err := decoder.Decode(&node) + if err != nil { + return xerrors.Errorf("read json: %w", err) + } + + c.mutex.Lock() + defer c.mutex.Unlock() + + // Update the node of this client in our in-memory map. If an agent + // entirely shuts down and reconnects, it needs to be aware of all clients + // attempting to establish connections. + c.nodes[id] = &node + + // Write the new node from this client to the actively + // connected agent. + err = c.writeNodeToAgent(agent, &node) + if err != nil { + return xerrors.Errorf("write node to agent: %w", err) + } + + return nil +} + +func (c *memoryCoordinator) writeNodeToAgent(agent uuid.UUID, node *Node) error { + agentSocket, ok := c.agentSockets[agent] + if !ok { + return nil + } + + // Write the new node from this client to the actively + // connected agent. + data, err := json.Marshal([]*Node{node}) + if err != nil { + return xerrors.Errorf("marshal nodes: %w", err) + } + + _, err = agentSocket.Write(data) + if err != nil { if errors.Is(err, io.EOF) { - c.mutex.Unlock() return nil } - if err != nil { - c.mutex.Unlock() - return xerrors.Errorf("write json: %w", err) - } - c.mutex.Unlock() + return xerrors.Errorf("write json: %w", err) } + return nil } // ServeAgent accepts a WebSocket connection to an agent that // listens to incoming connections and publishes node updates. 
-func (c *Coordinator) ServeAgent(conn net.Conn, id uuid.UUID) error { +func (c *memoryCoordinator) ServeAgent(conn net.Conn, id uuid.UUID) error { c.mutex.Lock() sockets, ok := c.agentToConnectionSockets[id] if ok { @@ -230,36 +270,51 @@ func (c *Coordinator) ServeAgent(conn net.Conn, id uuid.UUID) error { decoder := json.NewDecoder(conn) for { - var node Node - err := decoder.Decode(&node) - if errors.Is(err, io.EOF) { - return nil - } - if err != nil { - return xerrors.Errorf("read json: %w", err) - } - c.mutex.Lock() - c.nodes[id] = &node - connectionSockets, ok := c.agentToConnectionSockets[id] - if !ok { - c.mutex.Unlock() - continue - } - data, err := json.Marshal([]*Node{&node}) + err := c.handleNextAgentMessage(id, decoder) if err != nil { - return xerrors.Errorf("marshal nodes: %w", err) - } - // Publish the new node to every listening socket. - var wg sync.WaitGroup - wg.Add(len(connectionSockets)) - for _, connectionSocket := range connectionSockets { - connectionSocket := connectionSocket - go func() { - _, _ = connectionSocket.Write(data) - wg.Done() - }() + if errors.Is(err, io.EOF) { + return nil + } + return xerrors.Errorf("handle next agent message: %w", err) } - wg.Wait() - c.mutex.Unlock() } } + +func (c *memoryCoordinator) handleNextAgentMessage(id uuid.UUID, decoder *json.Decoder) error { + var node Node + err := decoder.Decode(&node) + if err != nil { + return xerrors.Errorf("read json: %w", err) + } + + c.mutex.Lock() + defer c.mutex.Unlock() + + c.nodes[id] = &node + connectionSockets, ok := c.agentToConnectionSockets[id] + if !ok { + return nil + } + + data, err := json.Marshal([]*Node{&node}) + if err != nil { + return xerrors.Errorf("marshal nodes: %w", err) + } + + // Publish the new node to every listening socket. + var wg sync.WaitGroup + wg.Add(len(connectionSockets)) + for _, connectionSocket := range connectionSockets { + connectionSocket := connectionSocket + go func() { + _ = connectionSocket.SetWriteDeadline(time.Now().Add(5 * time.Second)) + _, _ = connectionSocket.Write(data) + wg.Done() + }() + } + + wg.Wait() + return nil +} + +func (*memoryCoordinator) Close() error { return nil } diff --git a/tailnet/coordinator_test.go b/tailnet/coordinator_test.go index f3fdab88d5ef8..e0ed44420ede2 100644 --- a/tailnet/coordinator_test.go +++ b/tailnet/coordinator_test.go @@ -16,7 +16,7 @@ func TestCoordinator(t *testing.T) { t.Parallel() t.Run("ClientWithoutAgent", func(t *testing.T) { t.Parallel() - coordinator := tailnet.NewCoordinator() + coordinator := tailnet.NewMemoryCoordinator() client, server := net.Pipe() sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error { return nil @@ -32,15 +32,15 @@ func TestCoordinator(t *testing.T) { require.Eventually(t, func() bool { return coordinator.Node(id) != nil }, testutil.WaitShort, testutil.IntervalFast) - err := client.Close() - require.NoError(t, err) + require.NoError(t, client.Close()) + require.NoError(t, server.Close()) <-errChan <-closeChan }) t.Run("AgentWithoutClients", func(t *testing.T) { t.Parallel() - coordinator := tailnet.NewCoordinator() + coordinator := tailnet.NewMemoryCoordinator() client, server := net.Pipe() sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error { return nil @@ -64,7 +64,7 @@ func TestCoordinator(t *testing.T) { t.Run("AgentWithClient", func(t *testing.T) { t.Parallel() - coordinator := tailnet.NewCoordinator() + coordinator := tailnet.NewMemoryCoordinator() agentWS, agentServerWS := net.Pipe() defer 
agentWS.Close() From 68a812b134d43b3777d7173fdeadc503eea9ad4e Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Fri, 23 Sep 2022 13:26:25 -0500 Subject: [PATCH 02/10] fixup! feat: HA tailnet coordinator --- enterprise/tailnet/coordinator.go | 132 +++++++++++++++++++++--------- 1 file changed, 92 insertions(+), 40 deletions(-) diff --git a/enterprise/tailnet/coordinator.go b/enterprise/tailnet/coordinator.go index 8824f584d60da..6999fa7157d48 100644 --- a/enterprise/tailnet/coordinator.go +++ b/enterprise/tailnet/coordinator.go @@ -5,7 +5,6 @@ import ( "context" "encoding/json" "errors" - "fmt" "io" "net" "sync" @@ -150,7 +149,7 @@ func (c *haCoordinator) writeNodeToAgent(agent uuid.UUID, node *agpl.Node) error if !ok { // If we don't own the agent locally, send it over pubsub to a node that // owns the agent. - err := c.publishNodeToAgent(agent, node) + err := c.publishNodesToAgent(agent, []*agpl.Node{node}) if err != nil { return xerrors.Errorf("publish node to agent: %w", err) } @@ -178,18 +177,15 @@ func (c *haCoordinator) writeNodeToAgent(agent uuid.UUID, node *agpl.Node) error // incoming connections and publishes node updates. func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID) error { c.mutex.Lock() - sockets, ok := c.agentToConnectionSockets[id] - if ok { - // Publish all nodes that want to connect to the - // desired agent ID. - nodes := make([]*agpl.Node, 0, len(sockets)) - for targetID := range sockets { - node, ok := c.nodes[targetID] - if !ok { - continue - } - nodes = append(nodes, node) - } + + // Tell clients on other instances to send a callmemaybe to us. + err := c.publishAgentHello(id) + if err != nil { + c.mutex.Unlock() + return xerrors.Errorf("publish agent hello: %w", err) + } + + nodes := c.nodesSubscribedToAgent(id) + if len(nodes) > 0 { data, err := json.Marshal(nodes) if err != nil { c.mutex.Unlock() return xerrors.Errorf("marshal json: %w", err) } @@ -220,21 +216,46 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID) error { decoder := json.NewDecoder(conn) for { - err := c.handleAgentUpdate(id, decoder, false) + node, err := c.handleAgentUpdate(id, decoder) if err != nil { if errors.Is(err, io.EOF) { return nil } return xerrors.Errorf("handle next agent message: %w", err) } + + err = c.publishAgentToNodes(id, node) + if err != nil { + return xerrors.Errorf("publish agent to nodes: %w", err) + } } } +func (c *haCoordinator) nodesSubscribedToAgent(agentID uuid.UUID) []*agpl.Node { + sockets, ok := c.agentToConnectionSockets[agentID] + if !ok { + return nil + } + + // Publish all nodes that want to connect to the + // desired agent ID. + nodes := make([]*agpl.Node, 0, len(sockets)) + for targetID := range sockets { + node, ok := c.nodes[targetID] + if !ok { + continue + } + nodes = append(nodes, node) + } + + return nodes +} + +func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder) (*agpl.Node, error) { var node agpl.Node err := decoder.Decode(&node) if err != nil { - return xerrors.Errorf("read json: %w", err) + return nil, xerrors.Errorf("read json: %w", err) } c.mutex.Lock() @@ -242,22 +263,14 @@ func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder, f c.nodes[id] = &node - // Don't send the agent back over pubsub if that's where we received it from!
- if !fromPubsub { - err = c.publishAgentToNodes(id, &node) - if err != nil { - return xerrors.Errorf("publish agent to nodes: %w", err) - } - } - connectionSockets, ok := c.agentToConnectionSockets[id] if !ok { - return nil + return &node, nil } data, err := json.Marshal([]*agpl.Node{&node}) if err != nil { - return xerrors.Errorf("marshal nodes: %w", err) + return nil, xerrors.Errorf("marshal nodes: %w", err) } // Publish the new node to every listening socket. @@ -273,7 +286,7 @@ func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder, f } wg.Wait() - return nil + return &node, nil } func (c *haCoordinator) Close() error { @@ -281,13 +294,26 @@ func (c *haCoordinator) Close() error { return nil } -func (c *haCoordinator) publishNodeToAgent(recipient uuid.UUID, node *agpl.Node) error { - msg, err := c.formatCallMeMaybe(recipient, node) +func (c *haCoordinator) publishNodesToAgent(recipient uuid.UUID, nodes []*agpl.Node) error { + msg, err := c.formatCallMeMaybe(recipient, nodes) + if err != nil { + return xerrors.Errorf("format publish message: %w", err) + } + + err = c.pubsub.Publish("wireguard_peers", msg) + if err != nil { + return xerrors.Errorf("publish message: %w", err) + } + + return nil +} + +func (c *haCoordinator) publishAgentHello(id uuid.UUID) error { + msg, err := c.formatAgentHello(id) if err != nil { return xerrors.Errorf("format publish message: %w", err) } - fmt.Println("publishing callmemaybe", c.id.String()) err = c.pubsub.Publish("wireguard_peers", msg) if err != nil { return xerrors.Errorf("publish message: %w", err) @@ -302,7 +328,6 @@ func (c *haCoordinator) publishAgentToNodes(id uuid.UUID, node *agpl.Node) error return xerrors.Errorf("format publish message: %w", err) } - fmt.Println("publishing agentupdate", c.id.String()) err = c.pubsub.Publish("wireguard_peers", msg) if err != nil { return xerrors.Errorf("publish message: %w", err) @@ -345,19 +370,16 @@ func (c *haCoordinator) runPubsub() error { return } - fmt.Println("got callmemaybe", agentUUID.String()) c.mutex.Lock() defer c.mutex.Unlock() - fmt.Println("process callmemaybe", agentUUID.String()) agentSocket, ok := c.agentSockets[agentUUID] if !ok { - fmt.Println("no socket") return } - // We get a single node over pubsub, so turn into an array.
- _, err = agentSocket.Write(bytes.Join([][]byte{[]byte("["), nodeJSON, []byte("]")}, []byte{})) + // The payload is already a JSON-encoded array of nodes, so write it through as-is. + _, err = agentSocket.Write(nodeJSON) if err != nil { if errors.Is(err, io.EOF) { return @@ -365,18 +387,37 @@ func (c *haCoordinator) runPubsub() error { c.log.Error(ctx, "send callmemaybe to agent", slog.Error(err)) return } - fmt.Println("success callmemaybe", agentUUID.String()) + + case "agenthello": + agentUUID, err := uuid.ParseBytes(agentID) + if err != nil { + c.log.Error(ctx, "invalid agent id", slog.F("id", string(agentID))) + return + } + + c.mutex.Lock() + nodes := c.nodesSubscribedToAgent(agentUUID) + c.mutex.Unlock() + if len(nodes) > 0 { + err := c.publishNodesToAgent(agentUUID, nodes) + if err != nil { + c.log.Error(ctx, "publish nodes to agent", slog.Error(err)) + return + } + } case "agentupdate": agentUUID, err := uuid.ParseBytes(agentID) if err != nil { c.log.Error(ctx, "invalid agent id", slog.F("id", string(agentID))) + return } decoder := json.NewDecoder(bytes.NewReader(nodeJSON)) - err = c.handleAgentUpdate(agentUUID, decoder, true) + _, err = c.handleAgentUpdate(agentUUID, decoder) if err != nil { c.log.Error(ctx, "handle agent update", slog.Error(err)) + return } default: @@ -396,13 +437,13 @@ func (c *haCoordinator) runPubsub() error { } // format: <coordinator-id>|callmemaybe|<recipient-id>|<node-json> -func (c *haCoordinator) formatCallMeMaybe(recipient uuid.UUID, node *agpl.Node) ([]byte, error) { +func (c *haCoordinator) formatCallMeMaybe(recipient uuid.UUID, nodes []*agpl.Node) ([]byte, error) { buf := bytes.Buffer{} buf.WriteString(c.id.String() + "|") buf.WriteString("callmemaybe|") buf.WriteString(recipient.String() + "|") - err := json.NewEncoder(&buf).Encode(node) + err := json.NewEncoder(&buf).Encode(nodes) if err != nil { return nil, xerrors.Errorf("encode node: %w", err) } @@ -410,6 +451,17 @@ func (c *haCoordinator) formatCallMeMaybe(recipient uuid.UUID, node *agpl.Node) return buf.Bytes(), nil } +// format: <coordinator-id>|agenthello|<agent-id>| +func (c *haCoordinator) formatAgentHello(id uuid.UUID) ([]byte, error) { + buf := bytes.Buffer{} + + buf.WriteString(c.id.String() + "|") + buf.WriteString("agenthello|") + buf.WriteString(id.String() + "|") + + return buf.Bytes(), nil +} + // format: <coordinator-id>|agentupdate|<agent-id>|<node-json> func (c *haCoordinator) formatAgentUpdate(id uuid.UUID, node *agpl.Node) ([]byte, error) { buf := bytes.Buffer{} From 774c5dafe3cb41a9eca1819531f073fd8ff9c9b9 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Fri, 23 Sep 2022 13:29:40 -0500 Subject: [PATCH 03/10] fixup! feat: HA tailnet coordinator --- enterprise/tailnet/coordinator.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/enterprise/tailnet/coordinator.go b/enterprise/tailnet/coordinator.go index 6999fa7157d48..61b4bd5759ace 100644 --- a/enterprise/tailnet/coordinator.go +++ b/enterprise/tailnet/coordinator.go @@ -184,6 +184,7 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID) error { return xerrors.Errorf("publish agent hello: %w", err) } + // Publish all nodes on this instance that want to connect to this agent. nodes := c.nodesSubscribedToAgent(id) if len(nodes) > 0 { data, err := json.Marshal(nodes) if err != nil { c.mutex.Unlock() return xerrors.Errorf("marshal json: %w", err) } @@ -237,8 +238,6 @@ func (c *haCoordinator) nodesSubscribedToAgent(agentID uuid.UUID) []*agpl.Node { return nil } - // Publish all nodes that want to connect to the - // desired agent ID.
nodes := make([]*agpl.Node, 0, len(sockets)) for targetID := range sockets { node, ok := c.nodes[targetID] From bd82c5e36c79c080954b38255c8d198a0f0b925f Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Fri, 23 Sep 2022 13:30:58 -0500 Subject: [PATCH 04/10] remove printlns --- enterprise/tailnet/coordinator_test.go | 7 ------- 1 file changed, 7 deletions(-) diff --git a/enterprise/tailnet/coordinator_test.go b/enterprise/tailnet/coordinator_test.go index 48fce5bfd0f6f..4889cd1c8ba60 100644 --- a/enterprise/tailnet/coordinator_test.go +++ b/enterprise/tailnet/coordinator_test.go @@ -1,7 +1,6 @@ package tailnet_test import ( - "fmt" "net" "testing" @@ -182,9 +181,7 @@ func TestCoordinatorHA(t *testing.T) { defer agentWS.Close() agentNodeChan := make(chan []*agpl.Node) sendAgentNode, agentErrChan := agpl.ServeCoordinator(agentWS, func(nodes []*agpl.Node) error { - fmt.Println("got agent node") agentNodeChan <- nodes - fmt.Println("sent agent node") return nil }) agentID := uuid.New() @@ -204,9 +201,7 @@ func TestCoordinatorHA(t *testing.T) { defer clientServerWS.Close() clientNodeChan := make(chan []*agpl.Node) sendClientNode, clientErrChan := agpl.ServeCoordinator(clientWS, func(nodes []*agpl.Node) error { - fmt.Println("got client node") clientNodeChan <- nodes - fmt.Println("sent client node") return nil }) clientID := uuid.New() @@ -239,9 +234,7 @@ func TestCoordinatorHA(t *testing.T) { defer agentWS.Close() agentNodeChan = make(chan []*agpl.Node) _, agentErrChan = agpl.ServeCoordinator(agentWS, func(nodes []*agpl.Node) error { - fmt.Println("got agent node") agentNodeChan <- nodes - fmt.Println("sent agent node") return nil }) closeAgentChan = make(chan struct{}) From fbad8d075ddfb47d99c9cd7f2d1696ded78266ed Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Fri, 7 Oct 2022 11:49:19 -0500 Subject: [PATCH 05/10] close all connections on coordinator --- codersdk/features.go | 12 ++++--- enterprise/coderd/coderd.go | 8 +++++ enterprise/coderd/license/license.go | 25 +++++++++++---- enterprise/tailnet/coordinator.go | 30 ++++++++++++++++- tailnet/coordinator.go | 48 ++++++++++++++++++++++++++-- 5 files changed, 109 insertions(+), 14 deletions(-) diff --git a/codersdk/features.go b/codersdk/features.go index fe8673ef028fd..6884f44087629 100644 --- a/codersdk/features.go +++ b/codersdk/features.go @@ -15,11 +15,12 @@ const ( ) const ( - FeatureUserLimit = "user_limit" - FeatureAuditLog = "audit_log" - FeatureBrowserOnly = "browser_only" - FeatureSCIM = "scim" - FeatureWorkspaceQuota = "workspace_quota" + FeatureUserLimit = "user_limit" + FeatureAuditLog = "audit_log" + FeatureBrowserOnly = "browser_only" + FeatureSCIM = "scim" + FeatureWorkspaceQuota = "workspace_quota" + FeatureHighAvailability = "high_availability" ) var FeatureNames = []string{ @@ -28,6 +29,7 @@ var FeatureNames = []string{ FeatureBrowserOnly, FeatureSCIM, FeatureWorkspaceQuota, + FeatureHighAvailability, } type Feature struct { diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 11cceef98f0db..a6595e8bd6554 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -170,6 +170,14 @@ func (api *API) updateEntitlements(ctx context.Context) error { api.AGPL.WorkspaceQuotaEnforcer.Store(&enforcer) } + if changed, enabled := featureChanged(codersdk.FeatureHighAvailability); changed { + enforcer := workspacequota.NewNop() + if enabled { + enforcer = NewEnforcer(api.Options.UserWorkspaceQuota) + } + api.AGPL.WorkspaceQuotaEnforcer.Store(&enforcer) + } + api.entitlements = entitlements 
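A review note on the gating pattern above: the swap should only run when a feature's enablement actually transitions, so a routine license refresh does not tear down live state. A small sketch of the changed/enabled contract, assuming featureChanged compares the new entitlements against the previous snapshot (all names here are illustrative):

    package main

    import "fmt"

    type entitlements struct{ features map[string]bool }

    // featureChanged reports whether a feature's enablement differs from
    // the previous snapshot, and what the new state is.
    func featureChanged(prev, next entitlements, name string) (changed, enabled bool) {
    	p, n := prev.features[name], next.features[name]
    	return p != n, n
    }

    func main() {
    	prev := entitlements{features: map[string]bool{"high_availability": false}}
    	next := entitlements{features: map[string]bool{"high_availability": true}}

    	if changed, enabled := featureChanged(prev, next, "high_availability"); changed {
    		// Only reconfigure on a real transition; a no-op refresh
    		// leaves existing connections untouched.
    		fmt.Println("reconfigure coordinator, enabled:", enabled)
    	}
    }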
return nil diff --git a/enterprise/coderd/license/license.go b/enterprise/coderd/license/license.go index 55a62eee17eee..84d28dfcccb21 100644 --- a/enterprise/coderd/license/license.go +++ b/enterprise/coderd/license/license.go @@ -17,7 +17,13 @@ import ( ) // Entitlements processes licenses to return whether features are enabled or not. -func Entitlements(ctx context.Context, db database.Store, logger slog.Logger, keys map[string]ed25519.PublicKey, enablements map[string]bool) (codersdk.Entitlements, error) { +func Entitlements( + ctx context.Context, + db database.Store, + logger slog.Logger, + keys map[string]ed25519.PublicKey, + enablements map[string]bool, +) (codersdk.Entitlements, error) { now := time.Now() // Default all entitlements to be disabled. entitlements := codersdk.Entitlements{ @@ -96,6 +102,12 @@ func Entitlements(ctx context.Context, db database.Store, logger slog.Logger, ke Enabled: enablements[codersdk.FeatureWorkspaceQuota], } } + if claims.Features.HighAvailability > 0 { + entitlements.Features[codersdk.FeatureHighAvailability] = codersdk.Feature{ + Entitlement: entitlement, + Enabled: enablements[codersdk.FeatureHighAvailability], + } + } if claims.AllFeatures { allFeatures = true } @@ -165,11 +177,12 @@ var ( ) type Features struct { - UserLimit int64 `json:"user_limit"` - AuditLog int64 `json:"audit_log"` - BrowserOnly int64 `json:"browser_only"` - SCIM int64 `json:"scim"` - WorkspaceQuota int64 `json:"workspace_quota"` + UserLimit int64 `json:"user_limit"` + AuditLog int64 `json:"audit_log"` + BrowserOnly int64 `json:"browser_only"` + SCIM int64 `json:"scim"` + WorkspaceQuota int64 `json:"workspace_quota"` + HighAvailability int64 `json:"high_availability"` } type Claims struct { diff --git a/enterprise/tailnet/coordinator.go b/enterprise/tailnet/coordinator.go index 61b4bd5759ace..6bf2327507165 100644 --- a/enterprise/tailnet/coordinator.go +++ b/enterprise/tailnet/coordinator.go @@ -14,7 +14,6 @@ import ( "golang.org/x/xerrors" "cdr.dev/slog" - "github.com/coder/coder/coderd/database" agpl "github.com/coder/coder/tailnet" ) @@ -288,8 +287,37 @@ func (c *haCoordinator) hangleAgentUpdate(id uuid.UUID, decoder *json.Decoder) ( return &node, nil } +// Close closes all of the open connections in the coordinator and stops the +// coordinator from accepting new connections. func (c *haCoordinator) Close() error { + c.mutex.Lock() + defer c.mutex.Unlock() + close(c.close) + + wg := sync.WaitGroup{} + + wg.Add(len(c.agentSockets)) + for _, socket := range c.agentSockets { + socket := socket + go func() { + _ = socket.Close() + wg.Done() + }() + } + + for _, connMap := range c.agentToConnectionSockets { + wg.Add(len(connMap)) + for _, socket := range connMap { + socket := socket + go func() { + _ = socket.Close() + wg.Done() + }() + } + } + + wg.Wait() return nil } diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index d79ffa34a5a3b..150a323bcfe52 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -99,6 +99,7 @@ func ServeCoordinator(conn net.Conn, updateNodes func(node []*Node) error) (func // in-memory. func NewMemoryCoordinator() Coordinator { return &memoryCoordinator{ + closed: false, nodes: map[uuid.UUID]*Node{}, agentSockets: map[uuid.UUID]net.Conn{}, agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]net.Conn{}, @@ -112,7 +113,8 @@ func NewMemoryCoordinator() Coordinator { // This coordinator is incompatible with multiple Coder // replicas as all node data is in-memory. 
type memoryCoordinator struct { - mutex sync.Mutex + mutex sync.Mutex + closed bool // nodes maps agent and connection IDs to their respective node. nodes map[uuid.UUID]*Node @@ -135,6 +137,11 @@ func (c *memoryCoordinator) Node(id uuid.UUID) *Node { // with the specified ID. func (c *memoryCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error { c.mutex.Lock() + + if c.closed { + c.mutex.Unlock() + return xerrors.New("coordinator is closed") + } + // When a new connection is requested, we update it with the latest // node of the agent. This allows the connection to establish. node, ok := c.nodes[agent] @@ -229,6 +236,11 @@ func (c *memoryCoordinator) handleNextClientMessage(id, agent uuid.UUID, decoder // listens to incoming connections and publishes node updates. func (c *memoryCoordinator) ServeAgent(conn net.Conn, id uuid.UUID) error { c.mutex.Lock() + + if c.closed { + c.mutex.Unlock() + return xerrors.New("coordinator is closed") + } + sockets, ok := c.agentToConnectionSockets[id] if ok { // Publish all nodes that want to connect to the @@ -320,4 +332,36 @@ func (c *memoryCoordinator) handleNextAgentMessage(id uuid.UUID, decoder *json.D return nil } -func (*memoryCoordinator) Close() error { return nil } +// Close closes all of the open connections in the coordinator and stops the +// coordinator from accepting new connections. +func (c *memoryCoordinator) Close() error { + c.mutex.Lock() + defer c.mutex.Unlock() + + c.closed = true + + wg := sync.WaitGroup{} + + wg.Add(len(c.agentSockets)) + for _, socket := range c.agentSockets { + socket := socket + go func() { + _ = socket.Close() + wg.Done() + }() + } + + for _, connMap := range c.agentToConnectionSockets { + wg.Add(len(connMap)) + for _, socket := range connMap { + socket := socket + go func() { + _ = socket.Close() + wg.Done() + }() + } + } + + wg.Wait() + return nil +} From 46803aa38ba2d4189f687bda248f01bf933bf18e Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Fri, 7 Oct 2022 12:22:44 -0500 Subject: [PATCH 06/10] implement high availability feature --- coderd/coderd.go | 2 ++ coderd/provisionerjobs.go | 2 +- coderd/workspaceagents.go | 16 ++++++++-------- coderd/workspacebuilds.go | 2 +- enterprise/coderd/coderd.go | 24 +++++++++++++++++++++--- 5 files changed, 33 insertions(+), 13 deletions(-) diff --git a/coderd/coderd.go b/coderd/coderd.go index 58686ae66fbcd..f3cdab0caea04 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -158,6 +158,7 @@ func New(options *Options) *API { api.Auditor.Store(&options.Auditor) api.WorkspaceQuotaEnforcer.Store(&options.WorkspaceQuotaEnforcer) api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0) + api.TailnetCoordinator.Store(&options.TailnetCoordinator) api.derpServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger)) oauthConfigs := &httpmw.OAuth2Configs{ Github: options.GithubOAuth2Config, @@ -525,6 +526,7 @@ type API struct { Auditor atomic.Pointer[audit.Auditor] WorkspaceClientCoordinateOverride atomic.Pointer[func(rw http.ResponseWriter) bool] WorkspaceQuotaEnforcer atomic.Pointer[workspacequota.Enforcer] + TailnetCoordinator atomic.Pointer[tailnet.Coordinator] HTTPAuth *HTTPAuthorizer // APIHandler serves "/api/v2" diff --git a/coderd/provisionerjobs.go b/coderd/provisionerjobs.go index 56a825ea09a3a..68802df04e5ec 100644 --- a/coderd/provisionerjobs.go +++ b/coderd/provisionerjobs.go @@ -270,7 +270,7 @@ func (api *API) provisionerJobResources(rw http.ResponseWriter, r *http.Request, } } - apiAgent, err := convertWorkspaceAgent(api.DERPMap,
api.TailnetCoordinator, agent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout) + apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), agent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error reading job agent.", diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 247915db99592..29943c8701ec8 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -48,7 +48,7 @@ func (api *API) workspaceAgent(rw http.ResponseWriter, r *http.Request) { }) return } - apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout) + apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error reading workspace agent.", @@ -77,7 +77,7 @@ func (api *API) workspaceAgentApps(rw http.ResponseWriter, r *http.Request) { func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() workspaceAgent := httpmw.WorkspaceAgent(r) - apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) + apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error reading workspace agent.", @@ -97,7 +97,7 @@ func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) func (api *API) postWorkspaceAgentVersion(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() workspaceAgent := httpmw.WorkspaceAgent(r) - apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) + apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error reading workspace agent.", @@ -151,7 +151,7 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) { httpapi.ResourceNotFound(rw) return } - apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) + apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error reading workspace agent.", @@ -228,7 +228,7 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req return } - apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) + apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error reading workspace agent.", @@ -322,7 +322,7 @@ func (api *API) 
dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (* }) conn.SetNodeCallback(sendNodes) go func() { - err := api.TailnetCoordinator.ServeClient(serverConn, uuid.New(), agentID) + err := (*api.TailnetCoordinator.Load()).ServeClient(serverConn, uuid.New(), agentID) if err != nil { _ = conn.Close() } @@ -460,7 +460,7 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request closeChan := make(chan struct{}) go func() { defer close(closeChan) - err := api.TailnetCoordinator.ServeAgent(wsNetConn, workspaceAgent.ID) + err := (*api.TailnetCoordinator.Load()).ServeAgent(wsNetConn, workspaceAgent.ID) if err != nil { _ = conn.Close(websocket.StatusInternalError, err.Error()) return @@ -529,7 +529,7 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R go httpapi.Heartbeat(ctx, conn) defer conn.Close(websocket.StatusNormalClosure, "") - err = api.TailnetCoordinator.ServeClient(websocket.NetConn(ctx, conn, websocket.MessageBinary), uuid.New(), workspaceAgent.ID) + err = (*api.TailnetCoordinator.Load()).ServeClient(websocket.NetConn(ctx, conn, websocket.MessageBinary), uuid.New(), workspaceAgent.ID) if err != nil { _ = conn.Close(websocket.StatusInternalError, err.Error()) return diff --git a/coderd/workspacebuilds.go b/coderd/workspacebuilds.go index 6ece8d379b153..88e162fa7db94 100644 --- a/coderd/workspacebuilds.go +++ b/coderd/workspacebuilds.go @@ -831,7 +831,7 @@ func (api *API) convertWorkspaceBuild( apiAgents := make([]codersdk.WorkspaceAgent, 0) for _, agent := range agents { apps := appsByAgentID[agent.ID] - apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, agent, convertApps(apps), api.AgentInactiveDisconnectTimeout) + apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), agent, convertApps(apps), api.AgentInactiveDisconnectTimeout) if err != nil { return codersdk.WorkspaceBuild{}, xerrors.Errorf("converting workspace agent: %w", err) } diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index a6595e8bd6554..8eddcf42e325b 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -22,6 +22,8 @@ import ( "github.com/coder/coder/enterprise/audit" "github.com/coder/coder/enterprise/audit/backends" "github.com/coder/coder/enterprise/coderd/license" + "github.com/coder/coder/enterprise/tailnet" + agpltailnet "github.com/coder/coder/tailnet" ) // New constructs an Enterprise coderd API instance. @@ -171,11 +173,27 @@ func (api *API) updateEntitlements(ctx context.Context) error { } if changed, enabled := featureChanged(codersdk.FeatureHighAvailability); changed { - enforcer := workspacequota.NewNop() + coordinator := agpltailnet.NewMemoryCoordinator() if enabled { - enforcer = NewEnforcer(api.Options.UserWorkspaceQuota) + haCoordinator, err := tailnet.NewHACoordinator(api.Logger, api.Pubsub) + if err != nil { + api.Logger.Error(ctx, "unable to setup HA tailnet coordinator", slog.Error(err)) + // If the HA coordinator fails to set up, nothing + // actually changes. + changed = false + } else { + coordinator = haCoordinator + } + } + + // Recheck changed in case the HA coordinator failed to set up.
+		if changed {
+			oldCoordinator := *api.AGPL.TailnetCoordinator.Swap(&coordinator)
+			err := oldCoordinator.Close()
+			if err != nil {
+				api.Logger.Error(ctx, "unable to setup HA tailnet coordinator", slog.Error(err))
+			}
 		}
-		api.AGPL.WorkspaceQuotaEnforcer.Store(&enforcer)
 	}
 
 	api.entitlements = entitlements

From d38391e9f6ff27351e33017540efcc21f3dcc7d8 Mon Sep 17 00:00:00 2001
From: Colin Adler
Date: Fri, 7 Oct 2022 12:45:29 -0500
Subject: [PATCH 07/10] fixup! implement high availability feature

---
 enterprise/coderd/coderd.go | 2 +-
 .../coderd/coderdenttest/coderdenttest.go | 23 ++++++++++---------
 enterprise/coderd/license/license_test.go | 9 ++++----
 enterprise/coderd/licenses_test.go | 22 ++++++++++--------
 4 files changed, 30 insertions(+), 26 deletions(-)

diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go
index 8eddcf42e325b..d52596c547027 100644
--- a/enterprise/coderd/coderd.go
+++ b/enterprise/coderd/coderd.go
@@ -191,7 +191,7 @@ func (api *API) updateEntitlements(ctx context.Context) error {
 			oldCoordinator := *api.AGPL.TailnetCoordinator.Swap(&coordinator)
 			err := oldCoordinator.Close()
 			if err != nil {
-				api.Logger.Error(ctx, "unable to setup HA tailnet coordinator", slog.Error(err))
+				api.Logger.Error(ctx, "close old tailnet coordinator", slog.Error(err))
 			}
 		}
 
diff --git a/enterprise/coderd/coderdenttest/coderdenttest.go b/enterprise/coderd/coderdenttest/coderdenttest.go
index 90d09fd5c9c85..a9e08b4aac088 100644
--- a/enterprise/coderd/coderdenttest/coderdenttest.go
+++ b/enterprise/coderd/coderdenttest/coderdenttest.go
@@ -85,17 +85,18 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
 }
 
 type LicenseOptions struct {
-	AccountType    string
-	AccountID      string
-	Trial          bool
-	AllFeatures    bool
-	GraceAt        time.Time
-	ExpiresAt      time.Time
-	UserLimit      int64
-	AuditLog       bool
-	BrowserOnly    bool
-	SCIM           bool
-	WorkspaceQuota bool
+	AccountType      string
+	AccountID        string
+	Trial            bool
+	AllFeatures      bool
+	GraceAt          time.Time
+	ExpiresAt        time.Time
+	UserLimit        int64
+	AuditLog         bool
+	BrowserOnly      bool
+	SCIM             bool
+	WorkspaceQuota   bool
+	HighAvailability bool
 }
 
 // AddLicense generates a new license with the options provided and inserts it.
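A note on the hot-swap pattern used in `updateEntitlements` above: request handlers read the coordinator through `*api.TailnetCoordinator.Load()` (as in the earlier `workspaceagents.go` hunks), and the entitlement refresh replaces it with `Swap`, so the implementation can change at runtime without taking a lock on every request. Below is a minimal, self-contained sketch of that pattern, assuming the field is a Go 1.19 `atomic.Pointer` over the new `Coordinator` interface; the stand-in types are illustrative, not the real ones:

```go
package main

import (
	"io"
	"sync/atomic"
)

// Coordinator stands in for the tailnet.Coordinator interface this series
// introduces; the in-memory and HA implementations both satisfy it.
type Coordinator interface {
	io.Closer // ServeClient/ServeAgent omitted for brevity.
}

type memCoordinator struct{}

func (memCoordinator) Close() error { return nil }

type haCoordinator struct{}

func (haCoordinator) Close() error { return nil }

func main() {
	var tailnetCoordinator atomic.Pointer[Coordinator]

	// Boot with the in-memory coordinator.
	var coord Coordinator = memCoordinator{}
	tailnetCoordinator.Store(&coord)

	// A request handler takes a point-in-time snapshot; no mutex needed.
	current := *tailnetCoordinator.Load()
	_ = current

	// When the HA feature turns on, swap implementations and close the
	// coordinator that was active before, as the hunk above does.
	var ha Coordinator = haCoordinator{}
	if old := tailnetCoordinator.Swap(&ha); old != nil {
		_ = (*old).Close()
	}
}
```

Connections served before the swap keep whatever coordinator they loaded; only new requests observe the HA implementation, which is why the old coordinator is closed explicitly rather than simply dropped.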
diff --git a/enterprise/coderd/license/license_test.go b/enterprise/coderd/license/license_test.go
index 85958fbf4f60d..39d6e05fb50d3 100644
--- a/enterprise/coderd/license/license_test.go
+++ b/enterprise/coderd/license/license_test.go
@@ -20,10 +20,11 @@ import (
 func TestEntitlements(t *testing.T) {
 	t.Parallel()
 	all := map[string]bool{
-		codersdk.FeatureAuditLog:       true,
-		codersdk.FeatureBrowserOnly:    true,
-		codersdk.FeatureSCIM:           true,
-		codersdk.FeatureWorkspaceQuota: true,
+		codersdk.FeatureAuditLog:         true,
+		codersdk.FeatureBrowserOnly:      true,
+		codersdk.FeatureSCIM:             true,
+		codersdk.FeatureWorkspaceQuota:   true,
+		codersdk.FeatureHighAvailability: true,
 	}
 
 	t.Run("Defaults", func(t *testing.T) {
diff --git a/enterprise/coderd/licenses_test.go b/enterprise/coderd/licenses_test.go
index 59d36cc9157a6..5b4c89212578d 100644
--- a/enterprise/coderd/licenses_test.go
+++ b/enterprise/coderd/licenses_test.go
@@ -99,21 +99,23 @@ func TestGetLicense(t *testing.T) {
 		assert.Equal(t, int32(1), licenses[0].ID)
 		assert.Equal(t, "testing", licenses[0].Claims["account_id"])
 		assert.Equal(t, map[string]interface{}{
-			codersdk.FeatureUserLimit:      json.Number("0"),
-			codersdk.FeatureAuditLog:       json.Number("1"),
-			codersdk.FeatureSCIM:           json.Number("1"),
-			codersdk.FeatureBrowserOnly:    json.Number("1"),
-			codersdk.FeatureWorkspaceQuota: json.Number("0"),
+			codersdk.FeatureUserLimit:        json.Number("0"),
+			codersdk.FeatureAuditLog:         json.Number("1"),
+			codersdk.FeatureSCIM:             json.Number("1"),
+			codersdk.FeatureBrowserOnly:      json.Number("1"),
+			codersdk.FeatureWorkspaceQuota:   json.Number("0"),
+			codersdk.FeatureHighAvailability: json.Number("0"),
 		}, licenses[0].Claims["features"])
 		assert.Equal(t, int32(2), licenses[1].ID)
 		assert.Equal(t, "testing2", licenses[1].Claims["account_id"])
 		assert.Equal(t, true, licenses[1].Claims["trial"])
 		assert.Equal(t, map[string]interface{}{
-			codersdk.FeatureUserLimit:      json.Number("200"),
-			codersdk.FeatureAuditLog:       json.Number("1"),
-			codersdk.FeatureSCIM:           json.Number("1"),
-			codersdk.FeatureBrowserOnly:    json.Number("1"),
-			codersdk.FeatureWorkspaceQuota: json.Number("0"),
+			codersdk.FeatureUserLimit:        json.Number("200"),
+			codersdk.FeatureAuditLog:         json.Number("1"),
+			codersdk.FeatureSCIM:             json.Number("1"),
+			codersdk.FeatureBrowserOnly:      json.Number("1"),
+			codersdk.FeatureWorkspaceQuota:   json.Number("0"),
+			codersdk.FeatureHighAvailability: json.Number("0"),
 		}, licenses[1].Claims["features"])
 	})
 }

From a0bcd6464f16483c9524a69137de1dcc7d309095 Mon Sep 17 00:00:00 2001
From: Colin Adler
Date: Fri, 7 Oct 2022 12:53:18 -0500
Subject: [PATCH 08/10] fixup! implement high availability feature
---
 enterprise/coderd/license/license_test.go | 26 ++++++++++++-----------
 1 file changed, 14 insertions(+), 12 deletions(-)

diff --git a/enterprise/coderd/license/license_test.go b/enterprise/coderd/license/license_test.go
index 39d6e05fb50d3..204c6e7c3f5a2 100644
--- a/enterprise/coderd/license/license_test.go
+++ b/enterprise/coderd/license/license_test.go
@@ -60,11 +60,12 @@ func TestEntitlements(t *testing.T) {
 		db := databasefake.New()
 		db.InsertLicense(context.Background(), database.InsertLicenseParams{
 			JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
-				UserLimit:      100,
-				AuditLog:       true,
-				BrowserOnly:    true,
-				SCIM:           true,
-				WorkspaceQuota: true,
+				UserLimit:        100,
+				AuditLog:         true,
+				BrowserOnly:      true,
+				SCIM:             true,
+				WorkspaceQuota:   true,
+				HighAvailability: true,
 			}),
 			Exp: time.Now().Add(time.Hour),
 		})
@@ -81,13 +82,14 @@ func TestEntitlements(t *testing.T) {
 		db := databasefake.New()
 		db.InsertLicense(context.Background(), database.InsertLicenseParams{
 			JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
-				UserLimit:      100,
-				AuditLog:       true,
-				BrowserOnly:    true,
-				SCIM:           true,
-				WorkspaceQuota: true,
-				GraceAt:        time.Now().Add(-time.Hour),
-				ExpiresAt:      time.Now().Add(time.Hour),
+				UserLimit:        100,
+				AuditLog:         true,
+				BrowserOnly:      true,
+				SCIM:             true,
+				WorkspaceQuota:   true,
+				HighAvailability: true,
+				GraceAt:          time.Now().Add(-time.Hour),
+				ExpiresAt:        time.Now().Add(time.Hour),
 			}),
 			Exp: time.Now().Add(time.Hour),
 		})

From 1f33018bd1c586956c748e65c08e2049fcfdee78 Mon Sep 17 00:00:00 2001
From: Colin Adler
Date: Fri, 7 Oct 2022 13:02:40 -0500
Subject: [PATCH 09/10] fixup! implement high availability feature

---
 enterprise/coderd/coderdenttest/coderdenttest.go | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/enterprise/coderd/coderdenttest/coderdenttest.go b/enterprise/coderd/coderdenttest/coderdenttest.go
index a9e08b4aac088..2c4250325b567 100644
--- a/enterprise/coderd/coderdenttest/coderdenttest.go
+++ b/enterprise/coderd/coderdenttest/coderdenttest.go
@@ -132,6 +132,10 @@ func GenerateLicense(t *testing.T, options LicenseOptions) string {
 	if options.WorkspaceQuota {
 		workspaceQuota = 1
 	}
+	highAvailability := int64(0)
+	if options.HighAvailability {
+		highAvailability = 1
+	}
 
 	c := &license.Claims{
 		RegisteredClaims: jwt.RegisteredClaims{
@@ -147,11 +151,12 @@ func GenerateLicense(t *testing.T, options LicenseOptions) string {
 		Version:     license.CurrentVersion,
 		AllFeatures: options.AllFeatures,
 		Features: license.Features{
-			UserLimit:      options.UserLimit,
-			AuditLog:       auditLog,
-			BrowserOnly:    browserOnly,
-			SCIM:           scim,
-			WorkspaceQuota: workspaceQuota,
+			UserLimit:        options.UserLimit,
+			AuditLog:         auditLog,
+			BrowserOnly:      browserOnly,
+			SCIM:             scim,
+			WorkspaceQuota:   workspaceQuota,
+			HighAvailability: highAvailability,
 		},
 	}
 	tok := jwt.NewWithClaims(jwt.SigningMethodEdDSA, c)

From b6a507020417a5704d7d1336336cb5b961fa42eb Mon Sep 17 00:00:00 2001
From: Colin Adler
Date: Fri, 7 Oct 2022 13:11:20 -0500
Subject: [PATCH 10/10] fixup! implement high availability feature
---
 enterprise/cli/features_test.go | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/enterprise/cli/features_test.go b/enterprise/cli/features_test.go
index f5e7b1ff3520a..f892182f164fe 100644
--- a/enterprise/cli/features_test.go
+++ b/enterprise/cli/features_test.go
@@ -57,7 +57,7 @@ func TestFeaturesList(t *testing.T) {
 		var entitlements codersdk.Entitlements
 		err := json.Unmarshal(buf.Bytes(), &entitlements)
 		require.NoError(t, err, "unmarshal JSON output")
-		assert.Len(t, entitlements.Features, 5)
+		assert.Len(t, entitlements.Features, 6)
 		assert.Empty(t, entitlements.Warnings)
 		assert.Equal(t, codersdk.EntitlementNotEntitled,
 			entitlements.Features[codersdk.FeatureUserLimit].Entitlement)
@@ -69,6 +69,8 @@ func TestFeaturesList(t *testing.T) {
 			entitlements.Features[codersdk.FeatureWorkspaceQuota].Entitlement)
 		assert.Equal(t, codersdk.EntitlementNotEntitled,
 			entitlements.Features[codersdk.FeatureSCIM].Entitlement)
+		assert.Equal(t, codersdk.EntitlementNotEntitled,
+			entitlements.Features[codersdk.FeatureHighAvailability].Entitlement)
 		assert.False(t, entitlements.HasLicense)
 		assert.False(t, entitlements.Experimental)
 	})
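For context on what `tailnet.NewHACoordinator(api.Logger, api.Pubsub)` is for: a single in-memory coordinator only knows about the agents and clients connected to its own coderd process, so with multiple replicas each instance must broadcast the node updates it receives and apply the ones its peers publish. The toy below illustrates that pubsub fan-out only; it is not the `enterprise/tailnet` implementation, and the message format and `toyPubsub` type are invented for the example:

```go
package main

import (
	"fmt"
	"sync"
)

// toyPubsub mimics the Publish/Subscribe shape of coderd's database.Pubsub,
// heavily simplified: string payloads, no errors, no unsubscribe.
type toyPubsub struct {
	mu   sync.Mutex
	subs []func(update string)
}

func (p *toyPubsub) subscribe(fn func(update string)) {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.subs = append(p.subs, fn)
}

func (p *toyPubsub) publish(update string, wg *sync.WaitGroup) {
	p.mu.Lock()
	defer p.mu.Unlock()
	for _, fn := range p.subs {
		wg.Add(1)
		// Deliver asynchronously, mirroring the memoryPubsub change in
		// patch 01 that moved listener calls into goroutines.
		go func(f func(string)) {
			defer wg.Done()
			f(update)
		}(fn)
	}
}

func main() {
	ps := &toyPubsub{}
	var wg sync.WaitGroup

	// Two stand-ins for coderd replicas. Each applies updates published by
	// any replica, so a client on replica-b can reach an agent that dialed
	// replica-a. A real implementation would tag messages with a replica ID
	// to avoid re-applying updates it published itself.
	for _, name := range []string{"replica-a", "replica-b"} {
		name := name
		ps.subscribe(func(update string) {
			fmt.Printf("%s applied node update: %s\n", name, update)
		})
	}

	ps.publish(`{"agent":"1234","preferred_derp":999}`, &wg)
	wg.Wait()
}
```

Dispatching to listeners in goroutines keeps a slow subscriber from blocking the publisher, which is the same reason patch 01 wraps the memory pubsub's listener calls in `go`.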