diff --git a/enterprise/tailnet/coordinator.go b/enterprise/tailnet/coordinator.go index 03450f6057d04..c25a9c2f773f3 100644 --- a/enterprise/tailnet/coordinator.go +++ b/enterprise/tailnet/coordinator.go @@ -10,7 +10,6 @@ import ( "net" "net/http" "sync" - "time" "github.com/google/uuid" lru "github.com/hashicorp/golang-lru/v2" @@ -79,9 +78,21 @@ func (c *haCoordinator) Node(id uuid.UUID) *agpl.Node { return node } +func (c *haCoordinator) clientLogger(id, agent uuid.UUID) slog.Logger { + return c.log.With(slog.F("client_id", id), slog.F("agent_id", agent)) +} + +func (c *haCoordinator) agentLogger(agent uuid.UUID) slog.Logger { + return c.log.With(slog.F("agent_id", agent)) +} + // ServeClient accepts a WebSocket connection that wants to connect to an agent // with the specified ID. func (c *haCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + logger := c.clientLogger(id, agent) + c.mutex.Lock() connectionSockets, ok := c.agentToConnectionSockets[agent] if !ok { @@ -89,34 +100,28 @@ func (c *haCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID c.agentToConnectionSockets[agent] = connectionSockets } - now := time.Now().Unix() + tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, 0) // Insert this connection into a map so the agent // can publish node updates. - connectionSockets[id] = &agpl.TrackedConn{ - Conn: conn, - Start: now, - LastWrite: now, - } + connectionSockets[id] = tc // When a new connection is requested, we update it with the latest // node of the agent. This allows the connection to establish. node, ok := c.nodes[agent] - c.mutex.Unlock() if ok { - data, err := json.Marshal([]*agpl.Node{node}) - if err != nil { - return xerrors.Errorf("marshal node: %w", err) - } - _, err = conn.Write(data) + err := tc.Enqueue([]*agpl.Node{node}) + c.mutex.Unlock() if err != nil { - return xerrors.Errorf("write nodes: %w", err) + return xerrors.Errorf("enqueue node: %w", err) } } else { + c.mutex.Unlock() err := c.publishClientHello(agent) if err != nil { return xerrors.Errorf("publish client hello: %w", err) } } + go tc.SendUpdates() defer func() { c.mutex.Lock() @@ -161,8 +166,9 @@ func (c *haCoordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *js c.nodes[id] = &node // Write the new node from this client to the actively connected agent. agentSocket, ok := c.agentSockets[agent] - c.mutex.Unlock() + if !ok { + c.mutex.Unlock() // If we don't own the agent locally, send it over pubsub to a node that // owns the agent. err := c.publishNodesToAgent(agent, []*agpl.Node{&node}) @@ -171,67 +177,50 @@ func (c *haCoordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *js } return nil } - - // Write the new node from this client to the actively - // connected agent. - data, err := json.Marshal([]*agpl.Node{&node}) - if err != nil { - return xerrors.Errorf("marshal nodes: %w", err) - } - - _, err = agentSocket.Write(data) + err = agentSocket.Enqueue([]*agpl.Node{&node}) + c.mutex.Unlock() if err != nil { - if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) { - return nil - } - return xerrors.Errorf("write json: %w", err) + return xerrors.Errorf("enqueu nodes: %w", err) } - return nil } // ServeAgent accepts a WebSocket connection to an agent that listens to // incoming connections and publishes node updates. 
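
The hunks above lean on agpl.NewTrackedConn, Enqueue, and SendUpdates, which this diff introduces further down in tailnet/coordinator.go. As a rough sketch of the pattern (simplified; the names, buffer handling, and error handling here are illustrative only), each tracked connection owns a buffered channel that callers fill without blocking, even while holding the coordinator mutex, and a dedicated goroutine drains the channel onto the socket:

// Simplified sketch of the queue-backed writer used above; the real
// implementation is agpl.TrackedConn in tailnet/coordinator.go below.
package sketch

import (
	"encoding/json"
	"errors"
	"net"
)

var errWouldBlock = errors.New("would block")

type queuedConn struct {
	conn    net.Conn
	updates chan []any // buffered, so enqueuing never blocks the caller
}

// enqueue hands an update to the writer goroutine; it is safe to call
// while holding a mutex because it never performs network I/O.
func (q *queuedConn) enqueue(nodes []any) error {
	select {
	case q.updates <- nodes:
		return nil
	default:
		return errWouldBlock // queue full: report back-pressure, don't block
	}
}

// sendUpdates is the only place that touches the socket, so a slow or
// hung peer stalls this goroutine rather than the coordinator.
func (q *queuedConn) sendUpdates() {
	for nodes := range q.updates {
		data, err := json.Marshal(nodes)
		if err != nil {
			return
		}
		if _, err := q.conn.Write(data); err != nil {
			return
		}
	}
}
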
func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + logger := c.agentLogger(id) c.agentNameCache.Add(id, name) - // Publish all nodes on this instance that want to connect to this agent. - nodes := c.nodesSubscribedToAgent(id) - if len(nodes) > 0 { - data, err := json.Marshal(nodes) - if err != nil { - return xerrors.Errorf("marshal json: %w", err) - } - _, err = conn.Write(data) - if err != nil { - return xerrors.Errorf("write nodes: %w", err) - } - } - - // This uniquely identifies a connection that belongs to this goroutine. - unique := uuid.New() - now := time.Now().Unix() - overwrites := int64(0) - - // If an old agent socket is connected, we close it - // to avoid any leaks. This shouldn't ever occur because - // we expect one agent to be running. c.mutex.Lock() + overwrites := int64(0) + // If an old agent socket is connected, we Close it to avoid any leaks. This + // shouldn't ever occur because we expect one agent to be running, but it's + // possible for a race condition to happen when an agent is disconnected and + // attempts to reconnect before the server realizes the old connection is + // dead. oldAgentSocket, ok := c.agentSockets[id] if ok { overwrites = oldAgentSocket.Overwrites + 1 _ = oldAgentSocket.Close() } - c.agentSockets[id] = &agpl.TrackedConn{ - ID: unique, - Conn: conn, + // This uniquely identifies a connection that belongs to this goroutine. + unique := uuid.New() + tc := agpl.NewTrackedConn(ctx, cancel, conn, unique, logger, overwrites) - Name: name, - Start: now, - LastWrite: now, - Overwrites: overwrites, + // Publish all nodes on this instance that want to connect to this agent. + nodes := c.nodesSubscribedToAgent(id) + if len(nodes) > 0 { + err := tc.Enqueue(nodes) + if err != nil { + c.mutex.Unlock() + return xerrors.Errorf("enqueue nodes: %w", err) + } } + c.agentSockets[id] = tc c.mutex.Unlock() + go tc.SendUpdates() // Tell clients on other instances to send a callmemaybe to us. err := c.publishAgentHello(id) @@ -269,8 +258,6 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) err } func (c *haCoordinator) nodesSubscribedToAgent(agentID uuid.UUID) []*agpl.Node { - c.mutex.Lock() - defer c.mutex.Unlock() sockets, ok := c.agentToConnectionSockets[agentID] if !ok { return nil @@ -320,25 +307,11 @@ func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder) ( return &node, nil } - data, err := json.Marshal([]*agpl.Node{&node}) - if err != nil { - c.mutex.Unlock() - return nil, xerrors.Errorf("marshal nodes: %w", err) - } - // Publish the new node to every listening socket. - var wg sync.WaitGroup - wg.Add(len(connectionSockets)) for _, connectionSocket := range connectionSockets { - connectionSocket := connectionSocket - go func() { - defer wg.Done() - _ = connectionSocket.SetWriteDeadline(time.Now().Add(5 * time.Second)) - _, _ = connectionSocket.Write(data) - }() + _ = connectionSocket.Enqueue([]*agpl.Node{&node}) } c.mutex.Unlock() - wg.Wait() return &node, nil } @@ -502,18 +475,19 @@ func (c *haCoordinator) handlePubsubMessage(ctx context.Context, message []byte) c.mutex.Lock() agentSocket, ok := c.agentSockets[agentUUID] + c.mutex.Unlock() if !ok { - c.mutex.Unlock() return } - c.mutex.Unlock() - // We get a single node over pubsub, so turn into an array. - _, err = agentSocket.Write(nodeJSON) + // Socket takes a slice of Nodes, so we need to parse the JSON here. 
+ var nodes []*agpl.Node + err = json.Unmarshal(nodeJSON, &nodes) + if err != nil { + c.log.Error(ctx, "invalid nodes JSON", slog.F("id", agentID), slog.Error(err), slog.F("node", string(nodeJSON))) + } + err = agentSocket.Enqueue(nodes) if err != nil { - if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) { - return - } c.log.Error(ctx, "send callmemaybe to agent", slog.Error(err)) return } @@ -536,7 +510,9 @@ func (c *haCoordinator) handlePubsubMessage(ctx context.Context, message []byte) return } + c.mutex.RLock() nodes := c.nodesSubscribedToAgent(agentUUID) + c.mutex.RUnlock() if len(nodes) > 0 { err := c.publishNodesToAgent(agentUUID, nodes) if err != nil { diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index 0fc790053a822..2f11566ded9a1 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -113,24 +113,14 @@ func ServeCoordinator(conn net.Conn, updateNodes func(node []*Node) error) (func }, errChan } -const loggerName = "coord" +const LoggerName = "coord" // NewCoordinator constructs a new in-memory connection coordinator. This // coordinator is incompatible with multiple Coder replicas as all node data is // in-memory. func NewCoordinator(logger slog.Logger) Coordinator { - nameCache, err := lru.New[uuid.UUID, string](512) - if err != nil { - panic("make lru cache: " + err.Error()) - } - return &coordinator{ - logger: logger.Named(loggerName), - closed: false, - nodes: map[uuid.UUID]*Node{}, - agentSockets: map[uuid.UUID]*TrackedConn{}, - agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]*TrackedConn{}, - agentNameCache: nameCache, + core: newCore(logger), } } @@ -142,6 +132,12 @@ func NewCoordinator(logger slog.Logger) Coordinator { // This coordinator is incompatible with multiple Coder // replicas as all node data is in-memory. type coordinator struct { + core *core +} + +// core is an in-memory structure of Node and TrackedConn mappings. Its methods may be called from multiple goroutines; +// it is protected by a mutex to ensure data stay consistent. +type core struct { logger slog.Logger mutex sync.RWMutex closed bool @@ -159,8 +155,30 @@ type coordinator struct { agentNameCache *lru.Cache[uuid.UUID, string] } +func newCore(logger slog.Logger) *core { + nameCache, err := lru.New[uuid.UUID, string](512) + if err != nil { + panic("make lru cache: " + err.Error()) + } + + return &core{ + logger: logger, + closed: false, + nodes: make(map[uuid.UUID]*Node), + agentSockets: map[uuid.UUID]*TrackedConn{}, + agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]*TrackedConn{}, + agentNameCache: nameCache, + } +} + +var ErrWouldBlock = xerrors.New("would block") + type TrackedConn struct { - net.Conn + ctx context.Context + cancel func() + conn net.Conn + updates chan []*Node + logger slog.Logger // ID is an ephemeral UUID used to uniquely identify the owner of the // connection. @@ -172,26 +190,105 @@ type TrackedConn struct { Overwrites int64 } -func (t *TrackedConn) Write(b []byte) (n int, err error) { +func (t *TrackedConn) Enqueue(n []*Node) (err error) { atomic.StoreInt64(&t.LastWrite, time.Now().Unix()) - return t.Conn.Write(b) + select { + case t.updates <- n: + return nil + default: + return ErrWouldBlock + } +} + +// Close the connection and cancel the context for reading node updates from the queue +func (t *TrackedConn) Close() error { + t.cancel() + return t.conn.Close() +} + +// SendUpdates reads node updates and writes them to the connection. Ends when writes hit an error or context is +// canceled. 
+func (t *TrackedConn) SendUpdates() { + for { + select { + case <-t.ctx.Done(): + t.logger.Debug(t.ctx, "done sending updates") + return + case nodes := <-t.updates: + data, err := json.Marshal(nodes) + if err != nil { + t.logger.Error(t.ctx, "unable to marshal nodes update", slog.Error(err), slog.F("nodes", nodes)) + return + } + + // Set a deadline so that hung connections don't put back pressure on the system. + // Node updates are tiny, so even the dinkiest connection can handle them if it's not hung. + err = t.conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + if err != nil { + // often, this is just because the connection is closed/broken, so only log at debug. + t.logger.Debug(t.ctx, "unable to set write deadline", slog.Error(err)) + _ = t.Close() + return + } + _, err = t.conn.Write(data) + if err != nil { + // often, this is just because the connection is closed/broken, so only log at debug. + t.logger.Debug(t.ctx, "could not write nodes to connection", slog.Error(err), slog.F("nodes", nodes)) + _ = t.Close() + return + } + t.logger.Debug(t.ctx, "wrote nodes", slog.F("nodes", nodes)) + } + } +} + +func NewTrackedConn(ctx context.Context, cancel func(), conn net.Conn, id uuid.UUID, logger slog.Logger, overwrites int64) *TrackedConn { + // buffer updates so they don't block, since we hold the + // coordinator mutex while queuing. Node updates don't + // come quickly, so 512 should be plenty for all but + // the most pathological cases. + updates := make(chan []*Node, 512) + now := time.Now().Unix() + return &TrackedConn{ + ctx: ctx, + conn: conn, + cancel: cancel, + updates: updates, + logger: logger, + ID: id, + Start: now, + LastWrite: now, + Overwrites: overwrites, + } } // Node returns an in-memory node by ID. // If the node does not exist, nil is returned. func (c *coordinator) Node(id uuid.UUID) *Node { + return c.core.node(id) +} + +func (c *core) node(id uuid.UUID) *Node { c.mutex.Lock() defer c.mutex.Unlock() return c.nodes[id] } func (c *coordinator) NodeCount() int { + return c.core.nodeCount() +} + +func (c *core) nodeCount() int { c.mutex.Lock() defer c.mutex.Unlock() return len(c.nodes) } func (c *coordinator) AgentCount() int { + return c.core.agentCount() +} + +func (c *core) agentCount() int { c.mutex.Lock() defer c.mutex.Unlock() return len(c.agentSockets) @@ -200,129 +297,207 @@ func (c *coordinator) AgentCount() int { // ServeClient accepts a WebSocket connection that wants to connect to an agent // with the specified ID. func (c *coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error { - logger := c.logger.With(slog.F("client_id", id), slog.F("agent_id", agent)) - logger.Debug(context.Background(), "coordinating client") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + logger := c.core.clientLogger(id, agent) + logger.Debug(ctx, "coordinating client") + tc, err := c.core.initAndTrackClient(ctx, cancel, conn, id, agent) + if err != nil { + return err + } + defer c.core.clientDisconnected(id, agent) + + // On this goroutine, we read updates from the client and publish them. We start a second goroutine + // to write updates back to the client. 
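
For orientation: each write that SendUpdates performs is one JSON-encoded []*Node, and the peer on the other end of conn stream-decodes these back-to-back arrays (that peer is ServeCoordinator earlier in this file, which this diff does not change). Roughly, in package tailnet terms, the read side has the following shape; this is a sketch, not the actual ServeCoordinator code:

// Sketch of the read side of the protocol; illustrative only.
func readNodeUpdates(conn net.Conn, updateNodes func([]*Node) error) error {
	// A streaming decoder handles concatenated JSON arrays on one connection.
	decoder := json.NewDecoder(conn)
	for {
		var nodes []*Node
		if err := decoder.Decode(&nodes); err != nil {
			return xerrors.Errorf("read nodes: %w", err)
		}
		if err := updateNodes(nodes); err != nil {
			return xerrors.Errorf("update nodes: %w", err)
		}
	}
}
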
+ go tc.SendUpdates() + + decoder := json.NewDecoder(conn) + for { + err := c.handleNextClientMessage(id, agent, decoder) + if err != nil { + logger.Debug(ctx, "unable to read client update; closed conn?", slog.Error(err)) + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) { + return nil + } + return xerrors.Errorf("handle next client message: %w", err) + } + } +} + +func (c *core) clientLogger(id, agent uuid.UUID) slog.Logger { + return c.logger.With(slog.F("client_id", id), slog.F("agent_id", agent)) +} + +// initAndTrackClient creates a TrackedConn for the client, and sends any initial Node updates if we have any. It is +// one function that does two things because it is critical that we hold the mutex for both things, lest we miss some +// updates. +func (c *core) initAndTrackClient( + ctx context.Context, cancel func(), conn net.Conn, id, agent uuid.UUID, +) ( + *TrackedConn, error, +) { + logger := c.clientLogger(id, agent) c.mutex.Lock() + defer c.mutex.Unlock() if c.closed { - c.mutex.Unlock() - return xerrors.New("coordinator is closed") + return nil, xerrors.New("coordinator is closed") } + tc := NewTrackedConn(ctx, cancel, conn, id, logger, 0) // When a new connection is requested, we update it with the latest // node of the agent. This allows the connection to establish. node, ok := c.nodes[agent] - c.mutex.Unlock() if ok { - data, err := json.Marshal([]*Node{node}) - if err != nil { - return xerrors.Errorf("marshal node: %w", err) - } - _, err = conn.Write(data) - logger.Debug(context.Background(), "wrote initial node") + err := tc.Enqueue([]*Node{node}) + // this should never error since we're still the only goroutine that + // knows about the TrackedConn. If we hit an error something really + // wrong is happening if err != nil { - return xerrors.Errorf("write nodes: %w", err) + logger.Critical(ctx, "unable to queue initial node", slog.Error(err)) + return nil, err } } - c.mutex.Lock() + + // Insert this connection into a map so the agent + // can publish node updates. connectionSockets, ok := c.agentToConnectionSockets[agent] if !ok { connectionSockets = map[uuid.UUID]*TrackedConn{} c.agentToConnectionSockets[agent] = connectionSockets } + connectionSockets[id] = tc + logger.Debug(ctx, "added tracked connection") + return tc, nil +} - now := time.Now().Unix() - // Insert this connection into a map so the agent - // can publish node updates. - connectionSockets[id] = &TrackedConn{ - Conn: conn, - Start: now, - LastWrite: now, +func (c *core) clientDisconnected(id, agent uuid.UUID) { + logger := c.clientLogger(id, agent) + c.mutex.Lock() + defer c.mutex.Unlock() + // Clean all traces of this connection from the map. + delete(c.nodes, id) + logger.Debug(context.Background(), "deleted client node") + connectionSockets, ok := c.agentToConnectionSockets[agent] + if !ok { + return } - c.mutex.Unlock() - logger.Debug(context.Background(), "added tracked connection") - defer func() { - c.mutex.Lock() - defer c.mutex.Unlock() - // Clean all traces of this connection from the map. 
- delete(c.nodes, id) - logger.Debug(context.Background(), "deleted client node") - connectionSockets, ok := c.agentToConnectionSockets[agent] - if !ok { - return - } - delete(connectionSockets, id) - logger.Debug(context.Background(), "deleted client connectionSocket from map") - if len(connectionSockets) != 0 { - return - } - delete(c.agentToConnectionSockets, agent) - logger.Debug(context.Background(), "deleted last client connectionSocket from map") - }() - - decoder := json.NewDecoder(conn) - for { - err := c.handleNextClientMessage(id, agent, decoder) - if err != nil { - if errors.Is(err, io.EOF) { - return nil - } - return xerrors.Errorf("handle next client message: %w", err) - } + delete(connectionSockets, id) + logger.Debug(context.Background(), "deleted client connectionSocket from map") + if len(connectionSockets) != 0 { + return } + delete(c.agentToConnectionSockets, agent) + logger.Debug(context.Background(), "deleted last client connectionSocket from map") } func (c *coordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json.Decoder) error { - logger := c.logger.With(slog.F("client_id", id), slog.F("agent_id", agent)) + logger := c.core.clientLogger(id, agent) var node Node err := decoder.Decode(&node) if err != nil { return xerrors.Errorf("read json: %w", err) } logger.Debug(context.Background(), "got client node update", slog.F("node", node)) + return c.core.clientNodeUpdate(id, agent, &node) +} +func (c *core) clientNodeUpdate(id, agent uuid.UUID, node *Node) error { + logger := c.clientLogger(id, agent) c.mutex.Lock() + defer c.mutex.Unlock() // Update the node of this client in our in-memory map. If an agent entirely // shuts down and reconnects, it needs to be aware of all clients attempting // to establish connections. - c.nodes[id] = &node + c.nodes[id] = node agentSocket, ok := c.agentSockets[agent] if !ok { - c.mutex.Unlock() logger.Debug(context.Background(), "no agent socket, unable to send node") return nil } - c.mutex.Unlock() - // Write the new node from this client to the actively connected agent. - data, err := json.Marshal([]*Node{&node}) + err := agentSocket.Enqueue([]*Node{node}) if err != nil { - return xerrors.Errorf("marshal nodes: %w", err) + return xerrors.Errorf("Enqueue node: %w", err) } + logger.Debug(context.Background(), "enqueued node to agent") + return nil +} + +func (c *core) agentLogger(id uuid.UUID) slog.Logger { + return c.logger.With(slog.F("agent_id", id)) +} - _, err = agentSocket.Write(data) +// ServeAgent accepts a WebSocket connection to an agent that +// listens to incoming connections and publishes node updates. +func (c *coordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + logger := c.core.agentLogger(id) + logger.Debug(context.Background(), "coordinating agent") + // This uniquely identifies a connection that belongs to this goroutine. + unique := uuid.New() + tc, err := c.core.initAndTrackAgent(ctx, cancel, conn, id, unique, name) if err != nil { - if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) { - return nil + return err + } + + // On this goroutine, we read updates from the agent and publish them. We start a second goroutine + // to write updates back to the agent. 
+ go tc.SendUpdates() + + defer c.core.agentDisconnected(id, unique) + + decoder := json.NewDecoder(conn) + for { + err := c.handleNextAgentMessage(id, decoder) + if err != nil { + logger.Debug(ctx, "unable to read agent update; closed conn?", slog.Error(err)) + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) { + return nil + } + return xerrors.Errorf("handle next agent message: %w", err) } - return xerrors.Errorf("write json: %w", err) } - logger.Debug(context.Background(), "sent client node to agent") +} - return nil +func (c *core) agentDisconnected(id, unique uuid.UUID) { + logger := c.agentLogger(id) + c.mutex.Lock() + defer c.mutex.Unlock() + + // Only delete the connection if it's ours. It could have been + // overwritten. + if idConn, ok := c.agentSockets[id]; ok && idConn.ID == unique { + delete(c.agentSockets, id) + delete(c.nodes, id) + logger.Debug(context.Background(), "deleted agent socket and node") + } } -// ServeAgent accepts a WebSocket connection to an agent that -// listens to incoming connections and publishes node updates. -func (c *coordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) error { +// initAndTrackAgent creates a TrackedConn for the agent, and sends any initial nodes updates if we have any. It is +// one function that does two things because it is critical that we hold the mutex for both things, lest we miss some +// updates. +func (c *core) initAndTrackAgent(ctx context.Context, cancel func(), conn net.Conn, id, unique uuid.UUID, name string) (*TrackedConn, error) { logger := c.logger.With(slog.F("agent_id", id)) - logger.Debug(context.Background(), "coordinating agent") c.mutex.Lock() + defer c.mutex.Unlock() if c.closed { - c.mutex.Unlock() - return xerrors.New("coordinator is closed") + return nil, xerrors.New("coordinator is closed") } + overwrites := int64(0) + // If an old agent socket is connected, we Close it to avoid any leaks. This + // shouldn't ever occur because we expect one agent to be running, but it's + // possible for a race condition to happen when an agent is disconnected and + // attempts to reconnect before the server realizes the old connection is + // dead. + oldAgentSocket, ok := c.agentSockets[id] + if ok { + overwrites = oldAgentSocket.Overwrites + 1 + _ = oldAgentSocket.Close() + } + tc := NewTrackedConn(ctx, cancel, conn, unique, logger, overwrites) c.agentNameCache.Add(id, name) sockets, ok := c.agentToConnectionSockets[id] @@ -337,117 +512,67 @@ func (c *coordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) error } nodes = append(nodes, node) } - c.mutex.Unlock() - data, err := json.Marshal(nodes) + err := tc.Enqueue(nodes) + // this should never error since we're still the only goroutine that + // knows about the TrackedConn. If we hit an error something really + // wrong is happening if err != nil { - return xerrors.Errorf("marshal json: %w", err) + logger.Critical(ctx, "unable to queue initial nodes", slog.Error(err)) + return nil, err } - _, err = conn.Write(data) - logger.Debug(context.Background(), "wrote initial client(s) to agent", slog.F("nodes", nodes)) - if err != nil { - return xerrors.Errorf("write nodes: %w", err) - } - c.mutex.Lock() + logger.Debug(ctx, "wrote initial client(s) to agent", slog.F("nodes", nodes)) } - // This uniquely identifies a connection that belongs to this goroutine. - unique := uuid.New() - now := time.Now().Unix() - overwrites := int64(0) - - // If an old agent socket is connected, we close it to avoid any leaks. 
This - // shouldn't ever occur because we expect one agent to be running, but it's - // possible for a race condition to happen when an agent is disconnected and - // attempts to reconnect before the server realizes the old connection is - // dead. - oldAgentSocket, ok := c.agentSockets[id] - if ok { - overwrites = oldAgentSocket.Overwrites + 1 - _ = oldAgentSocket.Close() - } - c.agentSockets[id] = &TrackedConn{ - ID: unique, - Conn: conn, - - Name: name, - Start: now, - LastWrite: now, - Overwrites: overwrites, - } - - c.mutex.Unlock() - logger.Debug(context.Background(), "added agent socket") - defer func() { - c.mutex.Lock() - defer c.mutex.Unlock() - - // Only delete the connection if it's ours. It could have been - // overwritten. - if idConn, ok := c.agentSockets[id]; ok && idConn.ID == unique { - delete(c.agentSockets, id) - delete(c.nodes, id) - logger.Debug(context.Background(), "deleted agent socket") - } - }() - - decoder := json.NewDecoder(conn) - for { - err := c.handleNextAgentMessage(id, decoder) - if err != nil { - if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) { - return nil - } - return xerrors.Errorf("handle next agent message: %w", err) - } - } + c.agentSockets[id] = tc + logger.Debug(ctx, "added agent socket") + return tc, nil } func (c *coordinator) handleNextAgentMessage(id uuid.UUID, decoder *json.Decoder) error { - logger := c.logger.With(slog.F("agent_id", id)) + logger := c.core.agentLogger(id) var node Node err := decoder.Decode(&node) if err != nil { return xerrors.Errorf("read json: %w", err) } logger.Debug(context.Background(), "decoded agent node", slog.F("node", node)) + return c.core.agentNodeUpdate(id, &node) +} +func (c *core) agentNodeUpdate(id uuid.UUID, node *Node) error { + logger := c.agentLogger(id) c.mutex.Lock() - c.nodes[id] = &node + defer c.mutex.Unlock() + c.nodes[id] = node connectionSockets, ok := c.agentToConnectionSockets[id] if !ok { - c.mutex.Unlock() logger.Debug(context.Background(), "no client sockets; unable to send node") return nil } - data, err := json.Marshal([]*Node{&node}) - if err != nil { - c.mutex.Unlock() - return xerrors.Errorf("marshal nodes: %w", err) - } // Publish the new node to every listening socket. - var wg sync.WaitGroup - wg.Add(len(connectionSockets)) for clientID, connectionSocket := range connectionSockets { - clientID := clientID - connectionSocket := connectionSocket - go func() { - _ = connectionSocket.SetWriteDeadline(time.Now().Add(5 * time.Second)) - _, err := connectionSocket.Write(data) - logger.Debug(context.Background(), "sent agent node to client", + err := connectionSocket.Enqueue([]*Node{node}) + if err == nil { + logger.Debug(context.Background(), "enqueued agent node to client", + slog.F("client_id", clientID)) + } else { + // queue is backed up for some reason. This is bad, but we don't want to drop + // updates to other clients over it. Log and move on. + logger.Error(context.Background(), "failed to Enqueue", slog.F("client_id", clientID), slog.Error(err)) - wg.Done() - }() + } } - - c.mutex.Unlock() - wg.Wait() return nil } // Close closes all of the open connections in the coordinator and stops the // coordinator from accepting new connections. 
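
The overwrite path above (a reconnecting agent with the same ID closes the stale socket, and the deferred cleanup only removes the entry when its unique ID still matches) could be pinned down by a test along these lines. This is a hypothetical sketch, not part of this change, and it assumes the helpers and imports already used by coordinator_test.go:

// Hypothetical sketch only: exercises the overwrite behaviour of
// initAndTrackAgent via the exported Coordinator API.
func TestCoordinator_AgentReconnectOverwrites(t *testing.T) {
	t.Parallel()
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	coordinator := tailnet.NewCoordinator(logger)
	agentID := uuid.New()

	firstWS, firstServerWS := net.Pipe()
	defer firstWS.Close()
	firstDone := make(chan error, 1)
	go func() {
		firstDone <- coordinator.ServeAgent(firstServerWS, agentID, "agent")
	}()
	require.Eventually(t, func() bool {
		return coordinator.AgentCount() == 1
	}, testutil.WaitShort, testutil.IntervalFast)

	// Reconnect with the same agent ID; the coordinator should close the old
	// socket rather than leak it, and the first ServeAgent should unwind.
	secondWS, secondServerWS := net.Pipe()
	defer secondWS.Close()
	go func() {
		_ = coordinator.ServeAgent(secondServerWS, agentID, "agent")
	}()

	select {
	case err := <-firstDone:
		require.NoError(t, err)
	case <-time.After(testutil.WaitShort):
		t.Fatal("first ServeAgent did not exit after being overwritten")
	}
	require.Equal(t, 1, coordinator.AgentCount())
}
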
func (c *coordinator) Close() error { + return c.core.close() +} + +func (c *core) close() error { c.mutex.Lock() if c.closed { c.mutex.Unlock() @@ -484,6 +609,10 @@ func (c *coordinator) Close() error { } func (c *coordinator) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) { + c.core.serveHTTPDebug(w, r) +} + +func (c *core) serveHTTPDebug(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html; charset=utf-8") c.mutex.RLock() diff --git a/tailnet/coordinator_test.go b/tailnet/coordinator_test.go index 61117751cfc96..407f5bb2cf767 100644 --- a/tailnet/coordinator_test.go +++ b/tailnet/coordinator_test.go @@ -1,8 +1,10 @@ package tailnet_test import ( + "encoding/json" "net" "testing" + "time" "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" @@ -247,3 +249,88 @@ func TestCoordinator(t *testing.T) { <-closeAgentChan1 }) } + +// TestCoordinator_AgentUpdateWhileClientConnects tests for regression on +// https://github.com/coder/coder/issues/7295 +func TestCoordinator_AgentUpdateWhileClientConnects(t *testing.T) { + t.Parallel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + coordinator := tailnet.NewCoordinator(logger) + agentWS, agentServerWS := net.Pipe() + defer agentWS.Close() + + agentID := uuid.New() + go func() { + err := coordinator.ServeAgent(agentServerWS, agentID, "") + assert.NoError(t, err) + }() + + // send an agent update before the client connects so that there is + // node data available to send right away. + aNode := tailnet.Node{PreferredDERP: 0} + aData, err := json.Marshal(&aNode) + require.NoError(t, err) + err = agentWS.SetWriteDeadline(time.Now().Add(testutil.WaitShort)) + require.NoError(t, err) + _, err = agentWS.Write(aData) + require.NoError(t, err) + + require.Eventually(t, func() bool { + return coordinator.Node(agentID) != nil + }, testutil.WaitShort, testutil.IntervalFast) + + // Connect from the client + clientWS, clientServerWS := net.Pipe() + defer clientWS.Close() + clientID := uuid.New() + go func() { + err := coordinator.ServeClient(clientServerWS, clientID, agentID) + assert.NoError(t, err) + }() + + // peek one byte from the node update, so we know the coordinator is + // trying to write to the client. + // buffer needs to be 2 characters longer because return value is a list + // so, it needs [ and ] + buf := make([]byte, len(aData)+2) + err = clientWS.SetReadDeadline(time.Now().Add(testutil.WaitShort)) + require.NoError(t, err) + n, err := clientWS.Read(buf[:1]) + require.NoError(t, err) + require.Equal(t, 1, n) + + // send a second update + aNode.PreferredDERP = 1 + require.NoError(t, err) + aData, err = json.Marshal(&aNode) + require.NoError(t, err) + err = agentWS.SetWriteDeadline(time.Now().Add(testutil.WaitShort)) + require.NoError(t, err) + _, err = agentWS.Write(aData) + require.NoError(t, err) + + // read the rest of the update from the client, should be initial node. + err = clientWS.SetReadDeadline(time.Now().Add(testutil.WaitShort)) + require.NoError(t, err) + n, err = clientWS.Read(buf[1:]) + require.NoError(t, err) + require.Equal(t, len(buf)-1, n) + var cNodes []*tailnet.Node + err = json.Unmarshal(buf, &cNodes) + require.NoError(t, err) + require.Len(t, cNodes, 1) + require.Equal(t, 0, cNodes[0].PreferredDERP) + + // read second update + // without a fix for https://github.com/coder/coder/issues/7295 our + // read would time out here. 
+ err = clientWS.SetReadDeadline(time.Now().Add(testutil.WaitShort)) + require.NoError(t, err) + n, err = clientWS.Read(buf) + require.NoError(t, err) + require.Equal(t, len(buf), n) + err = json.Unmarshal(buf, &cNodes) + require.NoError(t, err) + require.Len(t, cNodes, 1) + require.Equal(t, 1, cNodes[0].PreferredDERP) +}
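
A note on why the one-byte peek above is enough: net.Pipe is fully synchronous, so the coordinator's SendUpdates goroutine stays parked inside conn.Write until the test reads the rest, and that is exactly the window in which the removed broadcast path could drop a racing agent update (it wrote directly to each client socket under a 5-second deadline and only logged or ignored the write error). A companion test along the following lines would pin down the non-blocking contract of Enqueue itself; it is a hypothetical sketch, not part of this change, and it assumes this file's imports plus "context" and the 512-entry buffer set in NewTrackedConn:

// Hypothetical sketch only: exercises the non-blocking contract of
// TrackedConn.Enqueue. Assumes the 512-entry buffer stays as written above.
func TestTrackedConn_EnqueueDoesNotBlock(t *testing.T) {
	t.Parallel()
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	client, server := net.Pipe()
	defer client.Close()
	defer server.Close()

	// Deliberately never start SendUpdates, so nothing drains the queue.
	tc := tailnet.NewTrackedConn(ctx, cancel, server, uuid.New(), logger, 0)
	update := []*tailnet.Node{{PreferredDERP: 1}}
	for i := 0; i < 512; i++ {
		require.NoError(t, tc.Enqueue(update))
	}
	// The 513th update reports back-pressure instead of blocking the caller.
	require.ErrorIs(t, tc.Enqueue(update), tailnet.ErrWouldBlock)
}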