@@ -104,7 +104,7 @@ func NewCoordinator() Coordinator {
104
104
return & coordinator {
105
105
closed : false ,
106
106
nodes : map [uuid.UUID ]* Node {},
107
- agentSockets : map [uuid.UUID ]net. Conn {},
107
+ agentSockets : map [uuid.UUID ]idConn {},
108
108
agentToConnectionSockets : map [uuid.UUID ]map [uuid.UUID ]net.Conn {},
109
109
}
110
110
}
@@ -123,12 +123,19 @@ type coordinator struct {
123
123
// nodes maps agent and connection IDs to their respective node.
124
124
nodes map [uuid.UUID ]* Node
125
125
// agentSockets maps agent IDs to their open websocket.
126
- agentSockets map [uuid.UUID ]net. Conn
126
+ agentSockets map [uuid.UUID ]idConn
127
127
// agentToConnectionSockets maps agent IDs to connection IDs of conns that
128
128
// are subscribed to updates for that agent.
129
129
agentToConnectionSockets map [uuid.UUID ]map [uuid.UUID ]net.Conn
130
130
}
131
131
132
+ type idConn struct {
133
+ // id is an ephemeral UUID used to uniquely identify the owner of the
134
+ // connection.
135
+ id uuid.UUID
136
+ conn net.Conn
137
+ }
138
+
132
139
// Node returns an in-memory node by ID.
133
140
// If the node does not exist, nil is returned.
134
141
func (c * coordinator ) Node (id uuid.UUID ) * Node {
@@ -137,6 +144,18 @@ func (c *coordinator) Node(id uuid.UUID) *Node {
137
144
return c .nodes [id ]
138
145
}
139
146
147
+ func (c * coordinator ) NodeCount () int {
148
+ c .mutex .Lock ()
149
+ defer c .mutex .Unlock ()
150
+ return len (c .nodes )
151
+ }
152
+
153
+ func (c * coordinator ) AgentCount () int {
154
+ c .mutex .Lock ()
155
+ defer c .mutex .Unlock ()
156
+ return len (c .agentSockets )
157
+ }
158
+
140
159
// ServeClient accepts a WebSocket connection that wants to connect to an agent
141
160
// with the specified ID.
142
161
func (c * coordinator ) ServeClient (conn net.Conn , id uuid.UUID , agent uuid.UUID ) error {
@@ -224,9 +243,9 @@ func (c *coordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json
224
243
return xerrors .Errorf ("marshal nodes: %w" , err )
225
244
}
226
245
227
- _ , err = agentSocket .Write (data )
246
+ _ , err = agentSocket .conn . Write (data )
228
247
if err != nil {
229
- if errors .Is (err , io .EOF ) {
248
+ if errors .Is (err , io .EOF ) || errors . Is ( err , io . ErrClosedPipe ) || errors . Is ( err , context . Canceled ) {
230
249
return nil
231
250
}
232
251
return xerrors .Errorf ("write json: %w" , err )
@@ -268,27 +287,41 @@ func (c *coordinator) ServeAgent(conn net.Conn, id uuid.UUID) error {
268
287
c .mutex .Lock ()
269
288
}
270
289
271
- // If an old agent socket is connected, we close it
272
- // to avoid any leaks. This shouldn't ever occur because
273
- // we expect one agent to be running.
290
+ // This uniquely identifies a connection that belongs to this goroutine.
291
+ unique := uuid .New ()
292
+
293
+ // If an old agent socket is connected, we close it to avoid any leaks. This
294
+ // shouldn't ever occur because we expect one agent to be running, but it's
295
+ // possible for a race condition to happen when an agent is disconnected and
296
+ // attempts to reconnect before the server realizes the old connection is
297
+ // dead.
274
298
oldAgentSocket , ok := c .agentSockets [id ]
275
299
if ok {
276
- _ = oldAgentSocket .Close ()
300
+ _ = oldAgentSocket .conn .Close ()
301
+ }
302
+ c .agentSockets [id ] = idConn {
303
+ id : unique ,
304
+ conn : conn ,
277
305
}
278
- c . agentSockets [ id ] = conn
306
+
279
307
c .mutex .Unlock ()
280
308
defer func () {
281
309
c .mutex .Lock ()
282
310
defer c .mutex .Unlock ()
283
- delete (c .agentSockets , id )
284
- delete (c .nodes , id )
311
+
312
+ // Only delete the connection if it's ours. It could have been
313
+ // overwritten.
314
+ if idConn := c .agentSockets [id ]; idConn .id == unique {
315
+ delete (c .agentSockets , id )
316
+ delete (c .nodes , id )
317
+ }
285
318
}()
286
319
287
320
decoder := json .NewDecoder (conn )
288
321
for {
289
322
err := c .handleNextAgentMessage (id , decoder )
290
323
if err != nil {
291
- if errors .Is (err , io .EOF ) || errors .Is (err , context .Canceled ) {
324
+ if errors .Is (err , io .EOF ) || errors .Is (err , io . ErrClosedPipe ) || errors . Is ( err , context .Canceled ) {
292
325
return nil
293
326
}
294
327
return xerrors .Errorf ("handle next agent message: %w" , err )
@@ -349,7 +382,7 @@ func (c *coordinator) Close() error {
349
382
for _ , socket := range c .agentSockets {
350
383
socket := socket
351
384
go func () {
352
- _ = socket .Close ()
385
+ _ = socket .conn . Close ()
353
386
wg .Done ()
354
387
}()
355
388
}
0 commit comments