Skip to content

Commit bd63011

Browse files
authored
fix: coordinator node update race (#7345)
* fix: coordinator node update race Signed-off-by: Spike Curtis <spike@coder.com> * Lint fixes, make core private Signed-off-by: Spike Curtis <spike@coder.com> * Don't log broken connections as errors Signed-off-by: Spike Curtis <spike@coder.com> --------- Signed-off-by: Spike Curtis <spike@coder.com>
1 parent 0e78d0a commit bd63011

File tree

3 files changed

+437
-245
lines changed

3 files changed

+437
-245
lines changed

enterprise/tailnet/coordinator.go

+57-81
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"net"
1111
"net/http"
1212
"sync"
13-
"time"
1413

1514
"github.com/google/uuid"
1615
lru "github.com/hashicorp/golang-lru/v2"
@@ -79,44 +78,50 @@ func (c *haCoordinator) Node(id uuid.UUID) *agpl.Node {
7978
return node
8079
}
8180

81+
func (c *haCoordinator) clientLogger(id, agent uuid.UUID) slog.Logger {
82+
return c.log.With(slog.F("client_id", id), slog.F("agent_id", agent))
83+
}
84+
85+
func (c *haCoordinator) agentLogger(agent uuid.UUID) slog.Logger {
86+
return c.log.With(slog.F("agent_id", agent))
87+
}
88+
8289
// ServeClient accepts a WebSocket connection that wants to connect to an agent
8390
// with the specified ID.
8491
func (c *haCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
92+
ctx, cancel := context.WithCancel(context.Background())
93+
defer cancel()
94+
logger := c.clientLogger(id, agent)
95+
8596
c.mutex.Lock()
8697
connectionSockets, ok := c.agentToConnectionSockets[agent]
8798
if !ok {
8899
connectionSockets = map[uuid.UUID]*agpl.TrackedConn{}
89100
c.agentToConnectionSockets[agent] = connectionSockets
90101
}
91102

92-
now := time.Now().Unix()
103+
tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, 0)
93104
// Insert this connection into a map so the agent
94105
// can publish node updates.
95-
connectionSockets[id] = &agpl.TrackedConn{
96-
Conn: conn,
97-
Start: now,
98-
LastWrite: now,
99-
}
106+
connectionSockets[id] = tc
100107

101108
// When a new connection is requested, we update it with the latest
102109
// node of the agent. This allows the connection to establish.
103110
node, ok := c.nodes[agent]
104-
c.mutex.Unlock()
105111
if ok {
106-
data, err := json.Marshal([]*agpl.Node{node})
107-
if err != nil {
108-
return xerrors.Errorf("marshal node: %w", err)
109-
}
110-
_, err = conn.Write(data)
112+
err := tc.Enqueue([]*agpl.Node{node})
113+
c.mutex.Unlock()
111114
if err != nil {
112-
return xerrors.Errorf("write nodes: %w", err)
115+
return xerrors.Errorf("enqueue node: %w", err)
113116
}
114117
} else {
118+
c.mutex.Unlock()
115119
err := c.publishClientHello(agent)
116120
if err != nil {
117121
return xerrors.Errorf("publish client hello: %w", err)
118122
}
119123
}
124+
go tc.SendUpdates()
120125

121126
defer func() {
122127
c.mutex.Lock()
@@ -161,8 +166,9 @@ func (c *haCoordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *js
161166
c.nodes[id] = &node
162167
// Write the new node from this client to the actively connected agent.
163168
agentSocket, ok := c.agentSockets[agent]
164-
c.mutex.Unlock()
169+
165170
if !ok {
171+
c.mutex.Unlock()
166172
// If we don't own the agent locally, send it over pubsub to a node that
167173
// owns the agent.
168174
err := c.publishNodesToAgent(agent, []*agpl.Node{&node})
@@ -171,67 +177,50 @@ func (c *haCoordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *js
171177
}
172178
return nil
173179
}
174-
175-
// Write the new node from this client to the actively
176-
// connected agent.
177-
data, err := json.Marshal([]*agpl.Node{&node})
178-
if err != nil {
179-
return xerrors.Errorf("marshal nodes: %w", err)
180-
}
181-
182-
_, err = agentSocket.Write(data)
180+
err = agentSocket.Enqueue([]*agpl.Node{&node})
181+
c.mutex.Unlock()
183182
if err != nil {
184-
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
185-
return nil
186-
}
187-
return xerrors.Errorf("write json: %w", err)
183+
return xerrors.Errorf("enqueue nodes: %w", err)
188184
}
189-
190185
return nil
191186
}
192187

193188
// ServeAgent accepts a WebSocket connection to an agent that listens to
194189
// incoming connections and publishes node updates.
195190
func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) error {
191+
ctx, cancel := context.WithCancel(context.Background())
192+
defer cancel()
193+
logger := c.agentLogger(id)
196194
c.agentNameCache.Add(id, name)
197195

198-
// Publish all nodes on this instance that want to connect to this agent.
199-
nodes := c.nodesSubscribedToAgent(id)
200-
if len(nodes) > 0 {
201-
data, err := json.Marshal(nodes)
202-
if err != nil {
203-
return xerrors.Errorf("marshal json: %w", err)
204-
}
205-
_, err = conn.Write(data)
206-
if err != nil {
207-
return xerrors.Errorf("write nodes: %w", err)
208-
}
209-
}
210-
211-
// This uniquely identifies a connection that belongs to this goroutine.
212-
unique := uuid.New()
213-
now := time.Now().Unix()
214-
overwrites := int64(0)
215-
216-
// If an old agent socket is connected, we close it
217-
// to avoid any leaks. This shouldn't ever occur because
218-
// we expect one agent to be running.
219196
c.mutex.Lock()
197+
overwrites := int64(0)
198+
// If an old agent socket is connected, we Close it to avoid any leaks. This
199+
// shouldn't ever occur because we expect one agent to be running, but it's
200+
// possible for a race condition to happen when an agent is disconnected and
201+
// attempts to reconnect before the server realizes the old connection is
202+
// dead.
220203
oldAgentSocket, ok := c.agentSockets[id]
221204
if ok {
222205
overwrites = oldAgentSocket.Overwrites + 1
223206
_ = oldAgentSocket.Close()
224207
}
225-
c.agentSockets[id] = &agpl.TrackedConn{
226-
ID: unique,
227-
Conn: conn,
208+
// This uniquely identifies a connection that belongs to this goroutine.
209+
unique := uuid.New()
210+
tc := agpl.NewTrackedConn(ctx, cancel, conn, unique, logger, overwrites)
228211

229-
Name: name,
230-
Start: now,
231-
LastWrite: now,
232-
Overwrites: overwrites,
212+
// Publish all nodes on this instance that want to connect to this agent.
213+
nodes := c.nodesSubscribedToAgent(id)
214+
if len(nodes) > 0 {
215+
err := tc.Enqueue(nodes)
216+
if err != nil {
217+
c.mutex.Unlock()
218+
return xerrors.Errorf("enqueue nodes: %w", err)
219+
}
233220
}
221+
c.agentSockets[id] = tc
234222
c.mutex.Unlock()
223+
go tc.SendUpdates()
235224

236225
// Tell clients on other instances to send a callmemaybe to us.
237226
err := c.publishAgentHello(id)
@@ -269,8 +258,6 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) err
269258
}
270259

271260
func (c *haCoordinator) nodesSubscribedToAgent(agentID uuid.UUID) []*agpl.Node {
272-
c.mutex.Lock()
273-
defer c.mutex.Unlock()
274261
sockets, ok := c.agentToConnectionSockets[agentID]
275262
if !ok {
276263
return nil
@@ -320,25 +307,11 @@ func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder) (
320307
return &node, nil
321308
}
322309

323-
data, err := json.Marshal([]*agpl.Node{&node})
324-
if err != nil {
325-
c.mutex.Unlock()
326-
return nil, xerrors.Errorf("marshal nodes: %w", err)
327-
}
328-
329310
// Publish the new node to every listening socket.
330-
var wg sync.WaitGroup
331-
wg.Add(len(connectionSockets))
332311
for _, connectionSocket := range connectionSockets {
333-
connectionSocket := connectionSocket
334-
go func() {
335-
defer wg.Done()
336-
_ = connectionSocket.SetWriteDeadline(time.Now().Add(5 * time.Second))
337-
_, _ = connectionSocket.Write(data)
338-
}()
312+
_ = connectionSocket.Enqueue([]*agpl.Node{&node})
339313
}
340314
c.mutex.Unlock()
341-
wg.Wait()
342315
return &node, nil
343316
}
344317

@@ -502,18 +475,19 @@ func (c *haCoordinator) handlePubsubMessage(ctx context.Context, message []byte)
502475

503476
c.mutex.Lock()
504477
agentSocket, ok := c.agentSockets[agentUUID]
478+
c.mutex.Unlock()
505479
if !ok {
506-
c.mutex.Unlock()
507480
return
508481
}
509-
c.mutex.Unlock()
510482

511-
// We get a single node over pubsub, so turn into an array.
512-
_, err = agentSocket.Write(nodeJSON)
483+
// Socket takes a slice of Nodes, so we need to parse the JSON here.
484+
var nodes []*agpl.Node
485+
err = json.Unmarshal(nodeJSON, &nodes)
486+
if err != nil {
487+
c.log.Error(ctx, "invalid nodes JSON", slog.F("id", agentID), slog.Error(err), slog.F("node", string(nodeJSON)))
488+
}
489+
err = agentSocket.Enqueue(nodes)
513490
if err != nil {
514-
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
515-
return
516-
}
517491
c.log.Error(ctx, "send callmemaybe to agent", slog.Error(err))
518492
return
519493
}
@@ -536,7 +510,9 @@ func (c *haCoordinator) handlePubsubMessage(ctx context.Context, message []byte)
536510
return
537511
}
538512

513+
c.mutex.RLock()
539514
nodes := c.nodesSubscribedToAgent(agentUUID)
515+
c.mutex.RUnlock()
540516
if len(nodes) > 0 {
541517
err := c.publishNodesToAgent(agentUUID, nodes)
542518
if err != nil {

0 commit comments

Comments
 (0)