@@ -10,7 +10,6 @@ import (
10
10
"net"
11
11
"net/http"
12
12
"sync"
13
- "time"
14
13
15
14
"github.com/google/uuid"
16
15
lru "github.com/hashicorp/golang-lru/v2"
@@ -79,44 +78,50 @@ func (c *haCoordinator) Node(id uuid.UUID) *agpl.Node {
79
78
return node
80
79
}
81
80
81
+ func (c * haCoordinator ) clientLogger (id , agent uuid.UUID ) slog.Logger {
82
+ return c .log .With (slog .F ("client_id" , id ), slog .F ("agent_id" , agent ))
83
+ }
84
+
85
+ func (c * haCoordinator ) agentLogger (agent uuid.UUID ) slog.Logger {
86
+ return c .log .With (slog .F ("agent_id" , agent ))
87
+ }
88
+
82
89
// ServeClient accepts a WebSocket connection that wants to connect to an agent
83
90
// with the specified ID.
84
91
func (c * haCoordinator ) ServeClient (conn net.Conn , id uuid.UUID , agent uuid.UUID ) error {
92
+ ctx , cancel := context .WithCancel (context .Background ())
93
+ defer cancel ()
94
+ logger := c .clientLogger (id , agent )
95
+
85
96
c .mutex .Lock ()
86
97
connectionSockets , ok := c .agentToConnectionSockets [agent ]
87
98
if ! ok {
88
99
connectionSockets = map [uuid.UUID ]* agpl.TrackedConn {}
89
100
c .agentToConnectionSockets [agent ] = connectionSockets
90
101
}
91
102
92
- now := time . Now (). Unix ( )
103
+ tc := agpl . NewTrackedConn ( ctx , cancel , conn , id , logger , 0 )
93
104
// Insert this connection into a map so the agent
94
105
// can publish node updates.
95
- connectionSockets [id ] = & agpl.TrackedConn {
96
- Conn : conn ,
97
- Start : now ,
98
- LastWrite : now ,
99
- }
106
+ connectionSockets [id ] = tc
100
107
101
108
// When a new connection is requested, we update it with the latest
102
109
// node of the agent. This allows the connection to establish.
103
110
node , ok := c .nodes [agent ]
104
- c .mutex .Unlock ()
105
111
if ok {
106
- data , err := json .Marshal ([]* agpl.Node {node })
107
- if err != nil {
108
- return xerrors .Errorf ("marshal node: %w" , err )
109
- }
110
- _ , err = conn .Write (data )
112
+ err := tc .Enqueue ([]* agpl.Node {node })
113
+ c .mutex .Unlock ()
111
114
if err != nil {
112
- return xerrors .Errorf ("write nodes : %w" , err )
115
+ return xerrors .Errorf ("enqueue node : %w" , err )
113
116
}
114
117
} else {
118
+ c .mutex .Unlock ()
115
119
err := c .publishClientHello (agent )
116
120
if err != nil {
117
121
return xerrors .Errorf ("publish client hello: %w" , err )
118
122
}
119
123
}
124
+ go tc .SendUpdates ()
120
125
121
126
defer func () {
122
127
c .mutex .Lock ()
@@ -161,8 +166,9 @@ func (c *haCoordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *js
161
166
c .nodes [id ] = & node
162
167
// Write the new node from this client to the actively connected agent.
163
168
agentSocket , ok := c .agentSockets [agent ]
164
- c . mutex . Unlock ()
169
+
165
170
if ! ok {
171
+ c .mutex .Unlock ()
166
172
// If we don't own the agent locally, send it over pubsub to a node that
167
173
// owns the agent.
168
174
err := c .publishNodesToAgent (agent , []* agpl.Node {& node })
@@ -171,67 +177,50 @@ func (c *haCoordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *js
171
177
}
172
178
return nil
173
179
}
174
-
175
- // Write the new node from this client to the actively
176
- // connected agent.
177
- data , err := json .Marshal ([]* agpl.Node {& node })
178
- if err != nil {
179
- return xerrors .Errorf ("marshal nodes: %w" , err )
180
- }
181
-
182
- _ , err = agentSocket .Write (data )
180
+ err = agentSocket .Enqueue ([]* agpl.Node {& node })
181
+ c .mutex .Unlock ()
183
182
if err != nil {
184
- if errors .Is (err , io .EOF ) || errors .Is (err , io .ErrClosedPipe ) {
185
- return nil
186
- }
187
- return xerrors .Errorf ("write json: %w" , err )
183
+ return xerrors .Errorf ("enqueu nodes: %w" , err )
188
184
}
189
-
190
185
return nil
191
186
}
192
187
193
188
// ServeAgent accepts a WebSocket connection to an agent that listens to
194
189
// incoming connections and publishes node updates.
195
190
func (c * haCoordinator ) ServeAgent (conn net.Conn , id uuid.UUID , name string ) error {
191
+ ctx , cancel := context .WithCancel (context .Background ())
192
+ defer cancel ()
193
+ logger := c .agentLogger (id )
196
194
c .agentNameCache .Add (id , name )
197
195
198
- // Publish all nodes on this instance that want to connect to this agent.
199
- nodes := c .nodesSubscribedToAgent (id )
200
- if len (nodes ) > 0 {
201
- data , err := json .Marshal (nodes )
202
- if err != nil {
203
- return xerrors .Errorf ("marshal json: %w" , err )
204
- }
205
- _ , err = conn .Write (data )
206
- if err != nil {
207
- return xerrors .Errorf ("write nodes: %w" , err )
208
- }
209
- }
210
-
211
- // This uniquely identifies a connection that belongs to this goroutine.
212
- unique := uuid .New ()
213
- now := time .Now ().Unix ()
214
- overwrites := int64 (0 )
215
-
216
- // If an old agent socket is connected, we close it
217
- // to avoid any leaks. This shouldn't ever occur because
218
- // we expect one agent to be running.
219
196
c .mutex .Lock ()
197
+ overwrites := int64 (0 )
198
+ // If an old agent socket is connected, we Close it to avoid any leaks. This
199
+ // shouldn't ever occur because we expect one agent to be running, but it's
200
+ // possible for a race condition to happen when an agent is disconnected and
201
+ // attempts to reconnect before the server realizes the old connection is
202
+ // dead.
220
203
oldAgentSocket , ok := c .agentSockets [id ]
221
204
if ok {
222
205
overwrites = oldAgentSocket .Overwrites + 1
223
206
_ = oldAgentSocket .Close ()
224
207
}
225
- c . agentSockets [ id ] = & agpl. TrackedConn {
226
- ID : unique ,
227
- Conn : conn ,
208
+ // This uniquely identifies a connection that belongs to this goroutine.
209
+ unique := uuid . New ()
210
+ tc := agpl . NewTrackedConn ( ctx , cancel , conn , unique , logger , overwrites )
228
211
229
- Name : name ,
230
- Start : now ,
231
- LastWrite : now ,
232
- Overwrites : overwrites ,
212
+ // Publish all nodes on this instance that want to connect to this agent.
213
+ nodes := c .nodesSubscribedToAgent (id )
214
+ if len (nodes ) > 0 {
215
+ err := tc .Enqueue (nodes )
216
+ if err != nil {
217
+ c .mutex .Unlock ()
218
+ return xerrors .Errorf ("enqueue nodes: %w" , err )
219
+ }
233
220
}
221
+ c .agentSockets [id ] = tc
234
222
c .mutex .Unlock ()
223
+ go tc .SendUpdates ()
235
224
236
225
// Tell clients on other instances to send a callmemaybe to us.
237
226
err := c .publishAgentHello (id )
@@ -269,8 +258,6 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) err
269
258
}
270
259
271
260
func (c * haCoordinator ) nodesSubscribedToAgent (agentID uuid.UUID ) []* agpl.Node {
272
- c .mutex .Lock ()
273
- defer c .mutex .Unlock ()
274
261
sockets , ok := c .agentToConnectionSockets [agentID ]
275
262
if ! ok {
276
263
return nil
@@ -320,25 +307,11 @@ func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder) (
320
307
return & node , nil
321
308
}
322
309
323
- data , err := json .Marshal ([]* agpl.Node {& node })
324
- if err != nil {
325
- c .mutex .Unlock ()
326
- return nil , xerrors .Errorf ("marshal nodes: %w" , err )
327
- }
328
-
329
310
// Publish the new node to every listening socket.
330
- var wg sync.WaitGroup
331
- wg .Add (len (connectionSockets ))
332
311
for _ , connectionSocket := range connectionSockets {
333
- connectionSocket := connectionSocket
334
- go func () {
335
- defer wg .Done ()
336
- _ = connectionSocket .SetWriteDeadline (time .Now ().Add (5 * time .Second ))
337
- _ , _ = connectionSocket .Write (data )
338
- }()
312
+ _ = connectionSocket .Enqueue ([]* agpl.Node {& node })
339
313
}
340
314
c .mutex .Unlock ()
341
- wg .Wait ()
342
315
return & node , nil
343
316
}
344
317
@@ -502,18 +475,19 @@ func (c *haCoordinator) handlePubsubMessage(ctx context.Context, message []byte)
502
475
503
476
c .mutex .Lock ()
504
477
agentSocket , ok := c .agentSockets [agentUUID ]
478
+ c .mutex .Unlock ()
505
479
if ! ok {
506
- c .mutex .Unlock ()
507
480
return
508
481
}
509
- c .mutex .Unlock ()
510
482
511
- // We get a single node over pubsub, so turn into an array.
512
- _ , err = agentSocket .Write (nodeJSON )
483
+ // Socket takes a slice of Nodes, so we need to parse the JSON here.
484
+ var nodes []* agpl.Node
485
+ err = json .Unmarshal (nodeJSON , & nodes )
486
+ if err != nil {
487
+ c .log .Error (ctx , "invalid nodes JSON" , slog .F ("id" , agentID ), slog .Error (err ), slog .F ("node" , string (nodeJSON )))
488
+ }
489
+ err = agentSocket .Enqueue (nodes )
513
490
if err != nil {
514
- if errors .Is (err , io .EOF ) || errors .Is (err , io .ErrClosedPipe ) {
515
- return
516
- }
517
491
c .log .Error (ctx , "send callmemaybe to agent" , slog .Error (err ))
518
492
return
519
493
}
@@ -536,7 +510,9 @@ func (c *haCoordinator) handlePubsubMessage(ctx context.Context, message []byte)
536
510
return
537
511
}
538
512
513
+ c .mutex .RLock ()
539
514
nodes := c .nodesSubscribedToAgent (agentUUID )
515
+ c .mutex .RUnlock ()
540
516
if len (nodes ) > 0 {
541
517
err := c .publishNodesToAgent (agentUUID , nodes )
542
518
if err != nil {
0 commit comments