use multiagents for all clients
coadler committed Jun 29, 2023
commit 457470da10d883c7e10ce5ff7937dd1e8e09a450
54 changes: 42 additions & 12 deletions coderd/tailnet.go
@@ -34,8 +34,6 @@ func init() {
}
}

// TODO(coadler): ServerTailnet does not currently remove stale peers.

// NewServerTailnet creates a new tailnet intended for use by coderd. It
// automatically falls back to wsconncache if a legacy agent is encountered.
func NewServerTailnet(
Member
Although maybe a bit weird, I'd argue we should put this in workspaceagents.go since wsconncache will be going away.

@@ -102,14 +100,49 @@ func NewServerTailnet(
})

go tn.watchAgentUpdates()
go tn.expireOldAgents()
return tn, nil
}

func (s *ServerTailnet) expireOldAgents() {
const (
tick = 5 * time.Minute
cutoff = 30 * time.Minute
)

ticker := time.NewTicker(tick)
defer ticker.Stop()

for {
select {
case <-s.ctx.Done():
return
case <-ticker.C:
}

s.nodesMu.Lock()
agentConn := s.getAgentConn()
for agentID, node := range s.agentNodes {
if time.Since(node.lastConnection) > cutoff {
Contributor
This measures the last time we started a connection, not the last time the connection was used. If we proxy a long-lived connection like ReconnectingPTY or a devURL websocket, it could easily be in use for greater than 30 minutes.

We might need some ref-counting to keep track of the connections to each agent, so that we expire them when they are no longer used.
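
One rough sketch of what that ref-counting could look like, building on the tailnetNode and ServerTailnet fields visible in this diff (acquireAgent, refCount, and the exact locking shown here are hypothetical, not part of this PR):

type tailnetNode struct {
	lastConnection time.Time
	refCount       int
}

// acquireAgent would be called whenever a proxied connection (ReconnectingPTY,
// devURL websocket, ...) starts using an agent; the returned release func is
// deferred until that connection ends.
func (s *ServerTailnet) acquireAgent(agentID uuid.UUID) (release func()) {
	s.nodesMu.Lock()
	node, ok := s.agentNodes[agentID]
	if !ok {
		s.nodesMu.Unlock()
		return func() {}
	}
	node.refCount++
	s.nodesMu.Unlock()

	return func() {
		s.nodesMu.Lock()
		node.refCount--
		node.lastConnection = time.Now()
		s.nodesMu.Unlock()
	}
}

expireOldAgents would then only unsubscribe a node once its refCount is zero and lastConnection is older than the cutoff, so long-lived proxied connections are never expired while in use.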

err := agentConn.UnsubscribeAgent(agentID)
if err != nil {
s.logger.Error(s.ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
}
delete(s.agentNodes, agentID)

// TODO(coadler): actually remove from the netmap
Contributor Author
Will do this before merge

}
}
s.nodesMu.Unlock()
}
}

func (s *ServerTailnet) watchAgentUpdates() {
for {
nodes := s.getAgentConn().NextUpdate(s.ctx)
if nodes == nil {
if s.getAgentConn().IsClosed() && s.ctx.Err() == nil {
conn := s.getAgentConn()
nodes, ok := conn.NextUpdate(s.ctx)
if !ok {
if conn.IsClosed() && s.ctx.Err() == nil {
s.reinitCoordinator()
continue
}
@@ -129,24 +162,22 @@ func (s *ServerTailnet) getAgentConn() tailnet.MultiAgentConn {
}

func (s *ServerTailnet) reinitCoordinator() {
s.nodesMu.Lock()
agentConn := (*s.coord.Load()).ServeMultiAgent(uuid.New())
s.agentConn.Store(&agentConn)

s.nodesMu.Lock()
// Resubscribe to all of the agents we're tracking.
for agentID, agentNode := range s.agentNodes {
closer, err := agentConn.SubscribeAgent(agentID)
for agentID := range s.agentNodes {
err := agentConn.SubscribeAgent(agentID)
if err != nil {
s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(err), slog.F("agent_id", agentID))
}
agentNode.close = closer
}
s.nodesMu.Unlock()
}

type tailnetNode struct {
lastConnection time.Time
close func()
}

type ServerTailnet struct {
@@ -210,13 +241,12 @@ func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error {
// If we don't have the node, subscribe.
if !ok {
s.logger.Debug(s.ctx, "subscribing to agent", slog.F("agent_id", agentID))
closer, err := s.getAgentConn().SubscribeAgent(agentID)
err := s.getAgentConn().SubscribeAgent(agentID)
if err != nil {
return xerrors.Errorf("subscribe agent: %w", err)
}
tnode = &tailnetNode{
lastConnection: time.Now(),
close: closer,
}
s.agentNodes[agentID] = tnode
} else {
144 changes: 80 additions & 64 deletions enterprise/tailnet/coordinator.go
@@ -12,7 +12,6 @@ import (
"sync"

"github.com/google/uuid"
"github.com/hashicorp/go-multierror"
lru "github.com/hashicorp/golang-lru/v2"
"golang.org/x/xerrors"

@@ -42,6 +41,8 @@ func NewCoordinator(logger slog.Logger, ps pubsub.Pubsub) (agpl.Coordinator, err
agentSockets: map[uuid.UUID]agpl.Enqueueable{},
agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]agpl.Enqueueable{},
agentNameCache: nameCache,
clients: map[uuid.UUID]agpl.Enqueueable{},
clientsToAgents: map[uuid.UUID]map[uuid.UUID]struct{}{},
legacyAgents: map[uuid.UUID]struct{}{},
}

@@ -57,14 +58,22 @@ func (c *haCoordinator) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn {
ID: id,
Logger: c.log,
AgentIsLegacyFunc: c.agentIsLegacy,
OnSubscribe: c.multiAgentSubscribe,
OnNodeUpdate: c.multiAgentUpdate,
OnSubscribe: c.clientSubscribeToAgent,
OnNodeUpdate: c.clientNodeUpdate,
OnRemove: c.clientDisconnected,
}).Init()
c.mutex.Lock()
c.clients[id] = m
c.clientsToAgents[id] = map[uuid.UUID]struct{}{}
c.mutex.Unlock()
return m
}

func (c *haCoordinator) multiAgentSubscribe(enq agpl.Enqueueable, agentID uuid.UUID) (func(), error) {
func (c *haCoordinator) clientSubscribeToAgent(enq agpl.Enqueueable, agentID uuid.UUID) error {
c.mutex.Lock()
defer c.mutex.Unlock()

c.initOrSetAgentConnectionSocketLocked(agentID, enq)

node := c.nodes[enq.UniqueID()]

@@ -73,44 +82,43 @@ func (c *haCoordinator) multiAgentSubscribe(enq agpl.Enqueueable, agentID uuid.U
if ok {
err := enq.Enqueue([]*agpl.Node{agentNode})
if err != nil {
return nil, xerrors.Errorf("enqueue agent on subscribe: %w", err)
return xerrors.Errorf("enqueue agent on subscribe: %w", err)
}
} else {
// If we don't have the node locally, notify other coordinators.
c.mutex.Unlock()
err := c.publishClientHello(agentID)
if err != nil {
return nil, xerrors.Errorf("publish client hello: %w", err)
return xerrors.Errorf("publish client hello: %w", err)
}
}

if node != nil {
err := c.handleClientUpdate(enq.UniqueID(), agentID, node)
err := c.sendNodeToAgentLocked(agentID, node)
if err != nil {
return nil, xerrors.Errorf("handle client update: %w", err)
return xerrors.Errorf("handle client update: %w", err)
}
}

return c.cleanupClientConn(enq.UniqueID(), agentID), nil
}

func (c *haCoordinator) multiAgentUpdate(id uuid.UUID, agents []uuid.UUID, node *agpl.Node) error {
var errs *multierror.Error
// This isn't the most efficient, but this coordinator is being deprecated
// soon anyways.
for _, agent := range agents {
err := c.handleClientUpdate(id, agent, node)
if err != nil {
errs = multierror.Append(errs, err)
}
}
if errs != nil {
return errs
}

return nil
}

// func (c *haCoordinator) multiAgentUpdate(id uuid.UUID, agents []uuid.UUID, node *agpl.Node) error {
// var errs *multierror.Error
// // This isn't the most efficient, but this coordinator is being deprecated
// // soon anyways.
// for _, agent := range agents {
// err := c.handleClientUpdate(id, agent, node)
// if err != nil {
// errs = multierror.Append(errs, err)
// }
// }
// if errs != nil {
// return errs
// }

// return nil
// }

type haCoordinator struct {
id uuid.UUID
log slog.Logger
@@ -127,6 +135,9 @@ type haCoordinator struct {
// are subscribed to updates for that agent.
agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]agpl.Enqueueable

clients map[uuid.UUID]agpl.Enqueueable
clientsToAgents map[uuid.UUID]map[uuid.UUID]struct{}

// agentNameCache holds a cache of agent names. If one of them disappears,
// it's helpful to have a name cached for debugging.
agentNameCache *lru.Cache[uuid.UUID, string]
@@ -152,40 +163,25 @@ func (c *haCoordinator) agentLogger(agent uuid.UUID) slog.Logger {

// ServeClient accepts a WebSocket connection that wants to connect to an agent
// with the specified ID.
func (c *haCoordinator) ServeClient(conn net.Conn, id, agent uuid.UUID) error {
func (c *haCoordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
logger := c.clientLogger(id, agent)
logger := c.clientLogger(id, agentID)

c.mutex.Lock()

tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, 0)
c.initOrSetAgentConnectionSocketLocked(agent, tc)
ma := c.ServeMultiAgent(id)
defer ma.Close()

// When a new connection is requested, we update it with the latest
// node of the agent. This allows the connection to establish.
node, ok := c.nodes[agent]
if ok {
err := tc.Enqueue([]*agpl.Node{node})
c.mutex.Unlock()
if err != nil {
return xerrors.Errorf("enqueue node: %w", err)
}
} else {
c.mutex.Unlock()
err := c.publishClientHello(agent)
if err != nil {
return xerrors.Errorf("publish client hello: %w", err)
}
err := ma.SubscribeAgent(agentID)
if err != nil {
return xerrors.Errorf("subscribe agent: %w", err)
}
go tc.SendUpdates()

defer c.cleanupClientConn(id, agent)
go agpl.SendUpdatesToConn(ctx, logger, ma, conn)

decoder := json.NewDecoder(conn)
// Indefinitely handle messages from the client websocket.
for {
err := c.handleNextClientMessage(id, agent, decoder)
err := c.handleNextClientMessage(id, decoder)
if err != nil {
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
return nil
@@ -202,12 +198,14 @@ func (c *haCoordinator) initOrSetAgentConnectionSocketLocked(agentID uuid.UUID,
c.agentToConnectionSockets[agentID] = connectionSockets
}
connectionSockets[enq.UniqueID()] = enq
c.clientsToAgents[enq.UniqueID()][agentID] = struct{}{}
}

func (c *haCoordinator) cleanupClientConn(id, agentID uuid.UUID) func() {
return func() {
c.mutex.Lock()
defer c.mutex.Unlock()
func (c *haCoordinator) clientDisconnected(id uuid.UUID) {
c.mutex.Lock()
defer c.mutex.Unlock()

for agentID := range c.clientsToAgents[id] {
// Clean all traces of this connection from the map.
delete(c.nodes, id)
connectionSockets, ok := c.agentToConnectionSockets[agentID]
@@ -220,39 +218,52 @@ func (c *haCoordinator) cleanupClientConn(id, agentID uuid.UUID) func() {
}
delete(c.agentToConnectionSockets, agentID)
}

delete(c.clients, id)
delete(c.clientsToAgents, id)
}

func (c *haCoordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json.Decoder) error {
func (c *haCoordinator) handleNextClientMessage(id uuid.UUID, decoder *json.Decoder) error {
var node agpl.Node
err := decoder.Decode(&node)
if err != nil {
return xerrors.Errorf("read json: %w", err)
}

return c.handleClientUpdate(id, agent, &node)
return c.clientNodeUpdate(id, &node)
}

func (c *haCoordinator) handleClientUpdate(id, agent uuid.UUID, node *agpl.Node) error {
func (c *haCoordinator) clientNodeUpdate(id uuid.UUID, node *agpl.Node) error {
c.mutex.Lock()
defer c.mutex.Unlock()
// Update the node of this client in our in-memory map. If an agent entirely
// shuts down and reconnects, it needs to be aware of all clients attempting
// to establish connections.
c.nodes[id] = node

// Write the new node from this client to the actively connected agent.
agentSocket, ok := c.agentSockets[agent]
for agentID := range c.clientsToAgents[id] {
// Write the new node from this client to the actively connected agent.
err := c.sendNodeToAgentLocked(agentID, node)
if err != nil {
c.log.Error(context.Background(), "send node to agent", slog.Error(err), slog.F("agent_id", agentID))
}
}

return nil
}

func (c *haCoordinator) sendNodeToAgentLocked(agentID uuid.UUID, node *agpl.Node) error {
agentSocket, ok := c.agentSockets[agentID]
if !ok {
c.mutex.Unlock()
// If we don't own the agent locally, send it over pubsub to a node that
// owns the agent.
err := c.publishNodesToAgent(agent, []*agpl.Node{node})
err := c.publishNodesToAgent(agentID, []*agpl.Node{node})
if err != nil {
return xerrors.Errorf("publish node to agent")
}
return nil
}
err := agentSocket.Enqueue([]*agpl.Node{node})
c.mutex.Unlock()
if err != nil {
return xerrors.Errorf("enqueue node: %w", err)
}
@@ -422,7 +433,7 @@ func (c *haCoordinator) Close() error {
for _, socket := range c.agentSockets {
socket := socket
go func() {
_ = socket.Close()
_ = socket.CoordinatorClose()
wg.Done()
}()
}
Expand All @@ -432,12 +443,17 @@ func (c *haCoordinator) Close() error {
for _, socket := range connMap {
socket := socket
go func() {
_ = socket.Close()
_ = socket.CoordinatorClose()
wg.Done()
}()
}
}

// Ensure clients that have no subscriptions are properly closed.
for _, client := range c.clients {
_ = client.CoordinatorClose()
}

wg.Wait()
return nil
}