Skip to content

Commit 4da1223

Browse files
authored
fix: pass OnSubscribe to HA MultiAgent (#9947)
Fixes #9929
1 parent 61154a6 commit 4da1223

File tree

1 file changed

+28
-13
lines changed

1 file changed

+28
-13
lines changed

enterprise/tailnet/coordinator.go

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,9 @@ func (c *haCoordinator) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn {
5757
ID: id,
5858
AgentIsLegacyFunc: c.agentIsLegacy,
5959
OnSubscribe: c.clientSubscribeToAgent,
60+
OnUnsubscribe: c.clientUnsubscribeFromAgent,
6061
OnNodeUpdate: c.clientNodeUpdate,
61-
OnRemove: func(enq agpl.Queue) { c.clientDisconnected(enq.UniqueID()) },
62+
OnRemove: c.clientDisconnected,
6263
}).Init()
6364
c.addClient(id, m)
6465
return m
@@ -101,6 +102,22 @@ func (c *haCoordinator) clientSubscribeToAgent(enq agpl.Queue, agentID uuid.UUID
101102
return nil, nil
102103
}
103104

105+
func (c *haCoordinator) clientUnsubscribeFromAgent(enq agpl.Queue, agentID uuid.UUID) error {
106+
c.mutex.Lock()
107+
defer c.mutex.Unlock()
108+
109+
connectionSockets, ok := c.agentToConnectionSockets[agentID]
110+
if !ok {
111+
return nil
112+
}
113+
delete(connectionSockets, enq.UniqueID())
114+
if len(connectionSockets) == 0 {
115+
delete(c.agentToConnectionSockets, agentID)
116+
}
117+
118+
return nil
119+
}
120+
104121
type haCoordinator struct {
105122
id uuid.UUID
106123
log slog.Logger
@@ -161,7 +178,7 @@ func (c *haCoordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error
161178
defer tc.Close()
162179

163180
c.addClient(id, tc)
164-
defer c.clientDisconnected(id)
181+
defer c.clientDisconnected(tc)
165182

166183
agentNode, err := c.clientSubscribeToAgent(tc, agentID)
167184
if err != nil {
@@ -200,26 +217,24 @@ func (c *haCoordinator) initOrSetAgentConnectionSocketLocked(agentID uuid.UUID,
200217
c.clientsToAgents[enq.UniqueID()][agentID] = c.agentSockets[agentID]
201218
}
202219

203-
func (c *haCoordinator) clientDisconnected(id uuid.UUID) {
220+
func (c *haCoordinator) clientDisconnected(enq agpl.Queue) {
204221
c.mutex.Lock()
205222
defer c.mutex.Unlock()
206223

207-
for agentID := range c.clientsToAgents[id] {
208-
// Clean all traces of this connection from the map.
209-
delete(c.nodes, id)
224+
for agentID := range c.clientsToAgents[enq.UniqueID()] {
210225
connectionSockets, ok := c.agentToConnectionSockets[agentID]
211226
if !ok {
212-
return
227+
continue
213228
}
214-
delete(connectionSockets, id)
215-
if len(connectionSockets) != 0 {
216-
return
229+
delete(connectionSockets, enq.UniqueID())
230+
if len(connectionSockets) == 0 {
231+
delete(c.agentToConnectionSockets, agentID)
217232
}
218-
delete(c.agentToConnectionSockets, agentID)
219233
}
220234

221-
delete(c.clients, id)
222-
delete(c.clientsToAgents, id)
235+
delete(c.nodes, enq.UniqueID())
236+
delete(c.clients, enq.UniqueID())
237+
delete(c.clientsToAgents, enq.UniqueID())
223238
}
224239

225240
func (c *haCoordinator) handleNextClientMessage(id uuid.UUID, decoder *json.Decoder) error {

0 commit comments

Comments
 (0)