fix: coordinator node update race #7345

Merged 3 commits on May 2, 2023
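
The change replaces direct conn.Write calls, which the coordinator had to perform while juggling its mutex, with a per-connection write queue: each connection is wrapped in an agpl.TrackedConn, producers call Enqueue (cheap, no I/O) while holding the lock, and a single SendUpdates goroutine per connection drains the queue onto the wire. TrackedConn itself is defined outside this diff, so the following is only a minimal sketch of the pattern; the names, the queue size, and the error handling are all assumptions:

```go
package tailnetsketch

import (
	"context"
	"encoding/json"
	"net"
)

// trackedConn serializes all writes to a net.Conn through a buffered
// channel, so callers can enqueue updates while holding a mutex without
// ever blocking on network I/O.
type trackedConn struct {
	ctx     context.Context
	cancel  context.CancelFunc
	conn    net.Conn
	updates chan []byte // marshaled node payloads for the writer goroutine
}

func newTrackedConn(ctx context.Context, cancel context.CancelFunc, conn net.Conn) *trackedConn {
	return &trackedConn{ctx: ctx, cancel: cancel, conn: conn, updates: make(chan []byte, 64)}
}

// enqueue marshals nodes and hands them to the writer goroutine. It never
// performs I/O itself, so it is cheap to call under a lock.
func (t *trackedConn) enqueue(nodes []any) error {
	data, err := json.Marshal(nodes)
	if err != nil {
		return err
	}
	select {
	case t.updates <- data:
		return nil
	case <-t.ctx.Done():
		return t.ctx.Err()
	}
}

// sendUpdates is the single writer: it drains the queue until the
// connection's context is canceled, and tears the connection down on error.
func (t *trackedConn) sendUpdates() {
	for {
		select {
		case data := <-t.updates:
			if _, err := t.conn.Write(data); err != nil {
				t.cancel()
				return
			}
		case <-t.ctx.Done():
			return
		}
	}
}
```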
138 changes: 57 additions & 81 deletions enterprise/tailnet/coordinator.go
@@ -10,7 +10,6 @@ import (
 	"net"
 	"net/http"
 	"sync"
-	"time"
 
 	"github.com/google/uuid"
 	lru "github.com/hashicorp/golang-lru/v2"
@@ -79,44 +78,50 @@ func (c *haCoordinator) Node(id uuid.UUID) *agpl.Node {
 	return node
 }
 
+func (c *haCoordinator) clientLogger(id, agent uuid.UUID) slog.Logger {
+	return c.log.With(slog.F("client_id", id), slog.F("agent_id", agent))
+}
+
+func (c *haCoordinator) agentLogger(agent uuid.UUID) slog.Logger {
+	return c.log.With(slog.F("agent_id", agent))
+}
+
 // ServeClient accepts a WebSocket connection that wants to connect to an agent
 // with the specified ID.
 func (c *haCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	logger := c.clientLogger(id, agent)
+
 	c.mutex.Lock()
 	connectionSockets, ok := c.agentToConnectionSockets[agent]
 	if !ok {
 		connectionSockets = map[uuid.UUID]*agpl.TrackedConn{}
 		c.agentToConnectionSockets[agent] = connectionSockets
 	}
 
-	now := time.Now().Unix()
+	tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, 0)
 	// Insert this connection into a map so the agent
 	// can publish node updates.
-	connectionSockets[id] = &agpl.TrackedConn{
-		Conn:      conn,
-		Start:     now,
-		LastWrite: now,
-	}
+	connectionSockets[id] = tc
 
 	// When a new connection is requested, we update it with the latest
 	// node of the agent. This allows the connection to establish.
 	node, ok := c.nodes[agent]
-	c.mutex.Unlock()
 	if ok {
-		data, err := json.Marshal([]*agpl.Node{node})
-		if err != nil {
-			return xerrors.Errorf("marshal node: %w", err)
-		}
-		_, err = conn.Write(data)
+		err := tc.Enqueue([]*agpl.Node{node})
+		c.mutex.Unlock()
 		if err != nil {
-			return xerrors.Errorf("write nodes: %w", err)
+			return xerrors.Errorf("enqueue node: %w", err)
 		}
 	} else {
+		c.mutex.Unlock()
 		err := c.publishClientHello(agent)
 		if err != nil {
 			return xerrors.Errorf("publish client hello: %w", err)
 		}
 	}
+	go tc.SendUpdates()
 
 	defer func() {
 		c.mutex.Lock()
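
The ordering in the hunk above is the heart of the race fix: the client's initial snapshot of the agent node is enqueued before c.mutex is released, so any concurrent agent update, which must acquire the same mutex first, can only land in the queue after the snapshot, and the single writer goroutine preserves that order on the wire. Continuing the earlier sketch (the coordinator type and its fields here are illustrative assumptions):

```go
import (
	"context"
	"net"
	"sync"
)

type coordinator struct {
	mu    sync.RWMutex
	nodes map[string]any // latest node per agent ID
}

// connectClient queues the agent's current node while still inside the
// critical section, then starts the lone writer once bookkeeping is done.
func (c *coordinator) connectClient(conn net.Conn, agent string) (*trackedConn, error) {
	ctx, cancel := context.WithCancel(context.Background())
	tc := newTrackedConn(ctx, cancel, conn)

	c.mu.Lock()
	if node, ok := c.nodes[agent]; ok {
		if err := tc.enqueue([]any{node}); err != nil {
			c.mu.Unlock()
			return nil, err
		}
	}
	c.mu.Unlock()

	go tc.sendUpdates()
	return tc, nil
}
```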
@@ -161,8 +166,9 @@ func (c *haCoordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json.Decoder) error {
 	c.nodes[id] = &node
+	// Write the new node from this client to the actively connected agent.
 	agentSocket, ok := c.agentSockets[agent]
-	c.mutex.Unlock()
 
 	if !ok {
+		c.mutex.Unlock()
 		// If we don't own the agent locally, send it over pubsub to a node that
 		// owns the agent.
 		err := c.publishNodesToAgent(agent, []*agpl.Node{&node})
@@ -171,67 +177,50 @@
 		}
 		return nil
 	}
-
-	// Write the new node from this client to the actively
-	// connected agent.
-	data, err := json.Marshal([]*agpl.Node{&node})
-	if err != nil {
-		return xerrors.Errorf("marshal nodes: %w", err)
-	}
-
-	_, err = agentSocket.Write(data)
+	err = agentSocket.Enqueue([]*agpl.Node{&node})
+	c.mutex.Unlock()
 	if err != nil {
 		if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
 			return nil
 		}
-		return xerrors.Errorf("write json: %w", err)
+		return xerrors.Errorf("enqueue nodes: %w", err)
 	}
 
 	return nil
 }
 
 // ServeAgent accepts a WebSocket connection to an agent that listens to
 // incoming connections and publishes node updates.
 func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) error {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	logger := c.agentLogger(id)
 	c.agentNameCache.Add(id, name)
 
-	// Publish all nodes on this instance that want to connect to this agent.
-	nodes := c.nodesSubscribedToAgent(id)
-	if len(nodes) > 0 {
-		data, err := json.Marshal(nodes)
-		if err != nil {
-			return xerrors.Errorf("marshal json: %w", err)
-		}
-		_, err = conn.Write(data)
-		if err != nil {
-			return xerrors.Errorf("write nodes: %w", err)
-		}
-	}
-
-	// This uniquely identifies a connection that belongs to this goroutine.
-	unique := uuid.New()
-	now := time.Now().Unix()
-	overwrites := int64(0)
-
-	// If an old agent socket is connected, we close it
-	// to avoid any leaks. This shouldn't ever occur because
-	// we expect one agent to be running.
 	c.mutex.Lock()
+	overwrites := int64(0)
+	// If an old agent socket is connected, we Close it to avoid any leaks. This
+	// shouldn't ever occur because we expect one agent to be running, but it's
+	// possible for a race condition to happen when an agent is disconnected and
+	// attempts to reconnect before the server realizes the old connection is
+	// dead.
 	oldAgentSocket, ok := c.agentSockets[id]
 	if ok {
 		overwrites = oldAgentSocket.Overwrites + 1
 		_ = oldAgentSocket.Close()
 	}
-	c.agentSockets[id] = &agpl.TrackedConn{
-		ID:   unique,
-		Conn: conn,
+	// This uniquely identifies a connection that belongs to this goroutine.
+	unique := uuid.New()
+	tc := agpl.NewTrackedConn(ctx, cancel, conn, unique, logger, overwrites)
 
-		Name:       name,
-		Start:      now,
-		LastWrite:  now,
-		Overwrites: overwrites,
+	// Publish all nodes on this instance that want to connect to this agent.
+	nodes := c.nodesSubscribedToAgent(id)
+	if len(nodes) > 0 {
+		err := tc.Enqueue(nodes)
+		if err != nil {
+			c.mutex.Unlock()
+			return xerrors.Errorf("enqueue nodes: %w", err)
+		}
 	}
+	c.agentSockets[id] = tc
 	c.mutex.Unlock()
+	go tc.SendUpdates()
 
 	// Tell clients on other instances to send a callmemaybe to us.
 	err := c.publishAgentHello(id)
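
The rewritten comment in ServeAgent names a second race this PR hardens against: an agent that reconnects before the server notices its old connection died. The swap now happens inside a single critical section, and the fresh unique ID presumably lets the superseded handler's deferred cleanup (outside this diff) recognize that the map entry no longer belongs to it. A sketch of that shape, reusing the assumed types from above:

```go
import "sync"

// agentConn augments the earlier trackedConn sketch with identity fields,
// assumed analogues of agpl.TrackedConn's ID and Overwrites.
type agentConn struct {
	*trackedConn
	uniqueID   string
	overwrites int64
}

type agentRegistry struct {
	mu     sync.Mutex
	agents map[string]*agentConn
}

// replace installs a new connection for id in one critical section: the old
// socket is closed so its handler unblocks, and the overwrite counter is
// carried forward.
func (r *agentRegistry) replace(id string, tc *agentConn) {
	r.mu.Lock()
	defer r.mu.Unlock()
	if old, ok := r.agents[id]; ok {
		tc.overwrites = old.overwrites + 1
		_ = old.conn.Close()
	}
	r.agents[id] = tc
}

// cleanup runs when a handler exits; the unique ID check keeps a superseded
// handler from deleting the entry that now belongs to its replacement.
func (r *agentRegistry) cleanup(id, uniqueID string) {
	r.mu.Lock()
	defer r.mu.Unlock()
	if cur, ok := r.agents[id]; ok && cur.uniqueID == uniqueID {
		delete(r.agents, id)
	}
}
```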
@@ -269,8 +258,6 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) error {
 }
 
 func (c *haCoordinator) nodesSubscribedToAgent(agentID uuid.UUID) []*agpl.Node {
-	c.mutex.Lock()
-	defer c.mutex.Unlock()
 	sockets, ok := c.agentToConnectionSockets[agentID]
 	if !ok {
 		return nil
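
nodesSubscribedToAgent loses its internal locking because ServeAgent now calls it while already holding c.mutex, which would self-deadlock with the old code; lock ownership moves to the callers, and the pubsub path in the last hunk below compensates with a read lock. The same convention in miniature, still on assumed types:

```go
import "sync"

type subscriberIndex struct {
	mu   sync.RWMutex
	subs map[string][]*trackedConn
}

// subscribersLocked assumes idx.mu is already held (read or write); it is
// what ServeAgent-style callers use inside their own critical section.
func (idx *subscriberIndex) subscribersLocked(agent string) []*trackedConn {
	return idx.subs[agent]
}

// snapshot is for callers that hold nothing yet, like the pubsub handler:
// a read lock suffices because nothing is mutated.
func (idx *subscriberIndex) snapshot(agent string) []*trackedConn {
	idx.mu.RLock()
	defer idx.mu.RUnlock()
	return idx.subscribersLocked(agent)
}
```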
@@ -320,25 +307,11 @@ func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder) (*agpl.Node, error) {
 		return &node, nil
 	}
 
-	data, err := json.Marshal([]*agpl.Node{&node})
-	if err != nil {
-		c.mutex.Unlock()
-		return nil, xerrors.Errorf("marshal nodes: %w", err)
-	}
-
 	// Publish the new node to every listening socket.
-	var wg sync.WaitGroup
-	wg.Add(len(connectionSockets))
 	for _, connectionSocket := range connectionSockets {
-		connectionSocket := connectionSocket
-		go func() {
-			defer wg.Done()
-			_ = connectionSocket.SetWriteDeadline(time.Now().Add(5 * time.Second))
-			_, _ = connectionSocket.Write(data)
-		}()
+		_ = connectionSocket.Enqueue([]*agpl.Node{&node})
 	}
 	c.mutex.Unlock()
-	wg.Wait()
 	return &node, nil
 }
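
handleAgentUpdate previously marshaled once, spawned a goroutine per client socket with a five-second write deadline, and waited on a WaitGroup before returning; with queued writers the broadcast collapses into a plain loop that never blocks on I/O. A sketch on the assumed trackedConn, where the full-queue policy (cancel the slow peer) is an assumption of this sketch rather than the library's documented behavior:

```go
// broadcast fans a payload out to every subscriber without doing any I/O:
// each connection's own writer goroutine picks the payload up later.
func broadcast(subs []*trackedConn, payload []byte) {
	for _, tc := range subs {
		select {
		case tc.updates <- payload:
			// Queued; tc.sendUpdates will write it out in order.
		default:
			// Queue full: the peer is too slow. Cancel it instead of
			// blocking the coordinator (policy assumed for this sketch).
			tc.cancel()
		}
	}
}
```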

@@ -502,18 +475,19 @@ func (c *haCoordinator) handlePubsubMessage(ctx context.Context, message []byte) {
 
 	c.mutex.Lock()
 	agentSocket, ok := c.agentSockets[agentUUID]
-	c.mutex.Unlock()
 	if !ok {
+		c.mutex.Unlock()
 		return
 	}
+	c.mutex.Unlock()
 
-	// We get a single node over pubsub, so turn into an array.
-	_, err = agentSocket.Write(nodeJSON)
+	// Socket takes a slice of Nodes, so we need to parse the JSON here.
+	var nodes []*agpl.Node
+	err = json.Unmarshal(nodeJSON, &nodes)
 	if err != nil {
-		if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
-			return
-		}
+		c.log.Error(ctx, "invalid nodes JSON", slog.F("id", agentID), slog.Error(err), slog.F("node", string(nodeJSON)))
+	}
+	err = agentSocket.Enqueue(nodes)
+	if err != nil {
 		c.log.Error(ctx, "send callmemaybe to agent", slog.Error(err))
 		return
 	}
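
Since Enqueue takes a slice of nodes rather than raw bytes, the pubsub handler now decodes the payload before queueing it instead of relaying the JSON straight into the socket. Roughly, with the same assumed types:

```go
import (
	"encoding/json"
	"fmt"
)

// relayNodeJSON decodes a pubsub payload into a node slice, then hands it
// to the connection's write queue; no direct socket writes remain.
func relayNodeJSON(tc *trackedConn, nodeJSON []byte) error {
	var nodes []any
	if err := json.Unmarshal(nodeJSON, &nodes); err != nil {
		return fmt.Errorf("invalid nodes JSON: %w", err)
	}
	return tc.enqueue(nodes)
}
```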
@@ -536,7 +510,9 @@ func (c *haCoordinator) handlePubsubMessage(ctx context.Context, message []byte) {
 		return
 	}
 
+	c.mutex.RLock()
 	nodes := c.nodesSubscribedToAgent(agentUUID)
+	c.mutex.RUnlock()
 	if len(nodes) > 0 {
 		err := c.publishNodesToAgent(agentUUID, nodes)
 		if err != nil {