@@ -12,7 +12,6 @@ import (
12
12
"sync"
13
13
14
14
"github.com/google/uuid"
15
- "github.com/hashicorp/go-multierror"
16
15
lru "github.com/hashicorp/golang-lru/v2"
17
16
"golang.org/x/xerrors"
18
17
@@ -42,6 +41,8 @@ func NewCoordinator(logger slog.Logger, ps pubsub.Pubsub) (agpl.Coordinator, err
42
41
agentSockets : map [uuid.UUID ]agpl.Enqueueable {},
43
42
agentToConnectionSockets : map [uuid.UUID ]map [uuid.UUID ]agpl.Enqueueable {},
44
43
agentNameCache : nameCache ,
44
+ clients : map [uuid.UUID ]agpl.Enqueueable {},
45
+ clientsToAgents : map [uuid.UUID ]map [uuid.UUID ]struct {}{},
45
46
legacyAgents : map [uuid.UUID ]struct {}{},
46
47
}
47
48
@@ -57,14 +58,22 @@ func (c *haCoordinator) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn {
57
58
ID : id ,
58
59
Logger : c .log ,
59
60
AgentIsLegacyFunc : c .agentIsLegacy ,
60
- OnSubscribe : c .multiAgentSubscribe ,
61
- OnNodeUpdate : c .multiAgentUpdate ,
61
+ OnSubscribe : c .clientSubscribeToAgent ,
62
+ OnNodeUpdate : c .clientNodeUpdate ,
63
+ OnRemove : c .clientDisconnected ,
62
64
}).Init ()
65
+ c .mutex .Lock ()
66
+ c .clients [id ] = m
67
+ c .clientsToAgents [id ] = map [uuid.UUID ]struct {}{}
68
+ c .mutex .Unlock ()
63
69
return m
64
70
}
65
71
66
- func (c * haCoordinator ) multiAgentSubscribe (enq agpl.Enqueueable , agentID uuid.UUID ) ( func (), error ) {
72
+ func (c * haCoordinator ) clientSubscribeToAgent (enq agpl.Enqueueable , agentID uuid.UUID ) error {
67
73
c .mutex .Lock ()
74
+ defer c .mutex .Unlock ()
75
+
76
+ c .initOrSetAgentConnectionSocketLocked (agentID , enq )
68
77
69
78
node := c .nodes [enq .UniqueID ()]
70
79
@@ -73,44 +82,43 @@ func (c *haCoordinator) multiAgentSubscribe(enq agpl.Enqueueable, agentID uuid.U
73
82
if ok {
74
83
err := enq .Enqueue ([]* agpl.Node {agentNode })
75
84
if err != nil {
76
- return nil , xerrors .Errorf ("enqueue agent on subscribe: %w" , err )
85
+ return xerrors .Errorf ("enqueue agent on subscribe: %w" , err )
77
86
}
78
87
} else {
79
88
// If we don't have the node locally, notify other coordinators.
80
- c .mutex .Unlock ()
81
89
err := c .publishClientHello (agentID )
82
90
if err != nil {
83
- return nil , xerrors .Errorf ("publish client hello: %w" , err )
91
+ return xerrors .Errorf ("publish client hello: %w" , err )
84
92
}
85
93
}
86
94
87
95
if node != nil {
88
- err := c .handleClientUpdate ( enq . UniqueID (), agentID , node )
96
+ err := c .sendNodeToAgentLocked ( agentID , node )
89
97
if err != nil {
90
- return nil , xerrors .Errorf ("handle client update: %w" , err )
98
+ return xerrors .Errorf ("handle client update: %w" , err )
91
99
}
92
100
}
93
101
94
- return c .cleanupClientConn (enq .UniqueID (), agentID ), nil
95
- }
96
-
97
- func (c * haCoordinator ) multiAgentUpdate (id uuid.UUID , agents []uuid.UUID , node * agpl.Node ) error {
98
- var errs * multierror.Error
99
- // This isn't the most efficient, but this coordinator is being deprecated
100
- // soon anyways.
101
- for _ , agent := range agents {
102
- err := c .handleClientUpdate (id , agent , node )
103
- if err != nil {
104
- errs = multierror .Append (errs , err )
105
- }
106
- }
107
- if errs != nil {
108
- return errs
109
- }
110
-
111
102
return nil
112
103
}
113
104
105
+ // func (c *haCoordinator) multiAgentUpdate(id uuid.UUID, agents []uuid.UUID, node *agpl.Node) error {
106
+ // var errs *multierror.Error
107
+ // // This isn't the most efficient, but this coordinator is being deprecated
108
+ // // soon anyways.
109
+ // for _, agent := range agents {
110
+ // err := c.handleClientUpdate(id, agent, node)
111
+ // if err != nil {
112
+ // errs = multierror.Append(errs, err)
113
+ // }
114
+ // }
115
+ // if errs != nil {
116
+ // return errs
117
+ // }
118
+
119
+ // return nil
120
+ // }
121
+
114
122
type haCoordinator struct {
115
123
id uuid.UUID
116
124
log slog.Logger
@@ -127,6 +135,9 @@ type haCoordinator struct {
127
135
// are subscribed to updates for that agent.
128
136
agentToConnectionSockets map [uuid.UUID ]map [uuid.UUID ]agpl.Enqueueable
129
137
138
+ clients map [uuid.UUID ]agpl.Enqueueable
139
+ clientsToAgents map [uuid.UUID ]map [uuid.UUID ]struct {}
140
+
130
141
// agentNameCache holds a cache of agent names. If one of them disappears,
131
142
// it's helpful to have a name cached for debugging.
132
143
agentNameCache * lru.Cache [uuid.UUID , string ]
@@ -152,40 +163,25 @@ func (c *haCoordinator) agentLogger(agent uuid.UUID) slog.Logger {
152
163
153
164
// ServeClient accepts a WebSocket connection that wants to connect to an agent
154
165
// with the specified ID.
155
- func (c * haCoordinator ) ServeClient (conn net.Conn , id , agent uuid.UUID ) error {
166
+ func (c * haCoordinator ) ServeClient (conn net.Conn , id , agentID uuid.UUID ) error {
156
167
ctx , cancel := context .WithCancel (context .Background ())
157
168
defer cancel ()
158
- logger := c .clientLogger (id , agent )
169
+ logger := c .clientLogger (id , agentID )
159
170
160
- c .mutex .Lock ()
161
-
162
- tc := agpl .NewTrackedConn (ctx , cancel , conn , id , logger , 0 )
163
- c .initOrSetAgentConnectionSocketLocked (agent , tc )
171
+ ma := c .ServeMultiAgent (id )
172
+ defer ma .Close ()
164
173
165
- // When a new connection is requested, we update it with the latest
166
- // node of the agent. This allows the connection to establish.
167
- node , ok := c .nodes [agent ]
168
- if ok {
169
- err := tc .Enqueue ([]* agpl.Node {node })
170
- c .mutex .Unlock ()
171
- if err != nil {
172
- return xerrors .Errorf ("enqueue node: %w" , err )
173
- }
174
- } else {
175
- c .mutex .Unlock ()
176
- err := c .publishClientHello (agent )
177
- if err != nil {
178
- return xerrors .Errorf ("publish client hello: %w" , err )
179
- }
174
+ err := ma .SubscribeAgent (agentID )
175
+ if err != nil {
176
+ return xerrors .Errorf ("subscribe agent: %w" , err )
180
177
}
181
- go tc .SendUpdates ()
182
178
183
- defer c . cleanupClientConn ( id , agent )
179
+ go agpl . SendUpdatesToConn ( ctx , logger , ma , conn )
184
180
185
181
decoder := json .NewDecoder (conn )
186
182
// Indefinitely handle messages from the client websocket.
187
183
for {
188
- err := c .handleNextClientMessage (id , agent , decoder )
184
+ err := c .handleNextClientMessage (id , decoder )
189
185
if err != nil {
190
186
if errors .Is (err , io .EOF ) || errors .Is (err , io .ErrClosedPipe ) {
191
187
return nil
@@ -202,12 +198,14 @@ func (c *haCoordinator) initOrSetAgentConnectionSocketLocked(agentID uuid.UUID,
202
198
c .agentToConnectionSockets [agentID ] = connectionSockets
203
199
}
204
200
connectionSockets [enq .UniqueID ()] = enq
201
+ c .clientsToAgents [enq .UniqueID ()][agentID ] = struct {}{}
205
202
}
206
203
207
- func (c * haCoordinator ) cleanupClientConn (id , agentID uuid.UUID ) func () {
208
- return func () {
209
- c .mutex .Lock ()
210
- defer c .mutex .Unlock ()
204
+ func (c * haCoordinator ) clientDisconnected (id uuid.UUID ) {
205
+ c .mutex .Lock ()
206
+ defer c .mutex .Unlock ()
207
+
208
+ for agentID := range c .clientsToAgents [id ] {
211
209
// Clean all traces of this connection from the map.
212
210
delete (c .nodes , id )
213
211
connectionSockets , ok := c .agentToConnectionSockets [agentID ]
@@ -220,39 +218,52 @@ func (c *haCoordinator) cleanupClientConn(id, agentID uuid.UUID) func() {
220
218
}
221
219
delete (c .agentToConnectionSockets , agentID )
222
220
}
221
+
222
+ delete (c .clients , id )
223
+ delete (c .clientsToAgents , id )
223
224
}
224
225
225
- func (c * haCoordinator ) handleNextClientMessage (id , agent uuid.UUID , decoder * json.Decoder ) error {
226
+ func (c * haCoordinator ) handleNextClientMessage (id uuid.UUID , decoder * json.Decoder ) error {
226
227
var node agpl.Node
227
228
err := decoder .Decode (& node )
228
229
if err != nil {
229
230
return xerrors .Errorf ("read json: %w" , err )
230
231
}
231
232
232
- return c .handleClientUpdate (id , agent , & node )
233
+ return c .clientNodeUpdate (id , & node )
233
234
}
234
235
235
- func (c * haCoordinator ) handleClientUpdate (id , agent uuid.UUID , node * agpl.Node ) error {
236
+ func (c * haCoordinator ) clientNodeUpdate (id uuid.UUID , node * agpl.Node ) error {
236
237
c .mutex .Lock ()
238
+ defer c .mutex .Unlock ()
237
239
// Update the node of this client in our in-memory map. If an agent entirely
238
240
// shuts down and reconnects, it needs to be aware of all clients attempting
239
241
// to establish connections.
240
242
c .nodes [id ] = node
241
243
242
- // Write the new node from this client to the actively connected agent.
243
- agentSocket , ok := c .agentSockets [agent ]
244
+ for agentID := range c .clientsToAgents [id ] {
245
+ // Write the new node from this client to the actively connected agent.
246
+ err := c .sendNodeToAgentLocked (agentID , node )
247
+ if err != nil {
248
+ c .log .Error (context .Background (), "send node to agent" , slog .Error (err ), slog .F ("agent_id" , agentID ))
249
+ }
250
+ }
251
+
252
+ return nil
253
+ }
254
+
255
+ func (c * haCoordinator ) sendNodeToAgentLocked (agentID uuid.UUID , node * agpl.Node ) error {
256
+ agentSocket , ok := c .agentSockets [agentID ]
244
257
if ! ok {
245
- c .mutex .Unlock ()
246
258
// If we don't own the agent locally, send it over pubsub to a node that
247
259
// owns the agent.
248
- err := c .publishNodesToAgent (agent , []* agpl.Node {node })
260
+ err := c .publishNodesToAgent (agentID , []* agpl.Node {node })
249
261
if err != nil {
250
262
return xerrors .Errorf ("publish node to agent" )
251
263
}
252
264
return nil
253
265
}
254
266
err := agentSocket .Enqueue ([]* agpl.Node {node })
255
- c .mutex .Unlock ()
256
267
if err != nil {
257
268
return xerrors .Errorf ("enqueue node: %w" , err )
258
269
}
@@ -422,7 +433,7 @@ func (c *haCoordinator) Close() error {
422
433
for _ , socket := range c .agentSockets {
423
434
socket := socket
424
435
go func () {
425
- _ = socket .Close ()
436
+ _ = socket .CoordinatorClose ()
426
437
wg .Done ()
427
438
}()
428
439
}
@@ -432,12 +443,17 @@ func (c *haCoordinator) Close() error {
432
443
for _ , socket := range connMap {
433
444
socket := socket
434
445
go func () {
435
- _ = socket .Close ()
446
+ _ = socket .CoordinatorClose ()
436
447
wg .Done ()
437
448
}()
438
449
}
439
450
}
440
451
452
+ // Ensure clients that have no subscriptions are properly closed.
453
+ for _ , client := range c .clients {
454
+ _ = client .CoordinatorClose ()
455
+ }
456
+
441
457
wg .Wait ()
442
458
return nil
443
459
}
0 commit comments