
Commit c3731a1

fix: ensure agent websocket only removes its own conn (#5828)
1 parent 443e218 commit c3731a1

File tree: 2 files changed, +140 −13

  tailnet/coordinator.go
  tailnet/coordinator_test.go

tailnet/coordinator.go (+46 −13)

@@ -104,7 +104,7 @@ func NewCoordinator() Coordinator {
 	return &coordinator{
 		closed:                   false,
 		nodes:                    map[uuid.UUID]*Node{},
-		agentSockets:             map[uuid.UUID]net.Conn{},
+		agentSockets:             map[uuid.UUID]idConn{},
 		agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]net.Conn{},
 	}
 }
@@ -123,12 +123,19 @@ type coordinator struct {
 	// nodes maps agent and connection IDs their respective node.
 	nodes map[uuid.UUID]*Node
 	// agentSockets maps agent IDs to their open websocket.
-	agentSockets map[uuid.UUID]net.Conn
+	agentSockets map[uuid.UUID]idConn
 	// agentToConnectionSockets maps agent IDs to connection IDs of conns that
 	// are subscribed to updates for that agent.
 	agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]net.Conn
 }
 
+type idConn struct {
+	// id is an ephemeral UUID used to uniquely identify the owner of the
+	// connection.
+	id   uuid.UUID
+	conn net.Conn
+}
+
 // Node returns an in-memory node by ID.
 // If the node does not exist, nil is returned.
 func (c *coordinator) Node(id uuid.UUID) *Node {
@@ -137,6 +144,18 @@ func (c *coordinator) Node(id uuid.UUID) *Node {
 	return c.nodes[id]
 }
 
+func (c *coordinator) NodeCount() int {
+	c.mutex.Lock()
+	defer c.mutex.Unlock()
+	return len(c.nodes)
+}
+
+func (c *coordinator) AgentCount() int {
+	c.mutex.Lock()
+	defer c.mutex.Unlock()
+	return len(c.agentSockets)
+}
+
 // ServeClient accepts a WebSocket connection that wants to connect to an agent
 // with the specified ID.
 func (c *coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
@@ -224,9 +243,9 @@ func (c *coordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json
 		return xerrors.Errorf("marshal nodes: %w", err)
 	}
 
-	_, err = agentSocket.Write(data)
+	_, err = agentSocket.conn.Write(data)
 	if err != nil {
-		if errors.Is(err, io.EOF) {
+		if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) {
 			return nil
 		}
 		return xerrors.Errorf("write json: %w", err)
@@ -268,27 +287,41 @@ func (c *coordinator) ServeAgent(conn net.Conn, id uuid.UUID) error {
 		c.mutex.Lock()
 	}
 
-	// If an old agent socket is connected, we close it
-	// to avoid any leaks. This shouldn't ever occur because
-	// we expect one agent to be running.
+	// This uniquely identifies a connection that belongs to this goroutine.
+	unique := uuid.New()
+
+	// If an old agent socket is connected, we close it to avoid any leaks. This
+	// shouldn't ever occur because we expect one agent to be running, but it's
+	// possible for a race condition to happen when an agent is disconnected and
+	// attempts to reconnect before the server realizes the old connection is
+	// dead.
 	oldAgentSocket, ok := c.agentSockets[id]
 	if ok {
-		_ = oldAgentSocket.Close()
+		_ = oldAgentSocket.conn.Close()
+	}
+	c.agentSockets[id] = idConn{
+		id:   unique,
+		conn: conn,
 	}
-	c.agentSockets[id] = conn
+
 	c.mutex.Unlock()
 	defer func() {
 		c.mutex.Lock()
 		defer c.mutex.Unlock()
-		delete(c.agentSockets, id)
-		delete(c.nodes, id)
+
+		// Only delete the connection if it's ours. It could have been
+		// overwritten.
+		if idConn := c.agentSockets[id]; idConn.id == unique {
+			delete(c.agentSockets, id)
+			delete(c.nodes, id)
+		}
 	}()
 
 	decoder := json.NewDecoder(conn)
 	for {
 		err := c.handleNextAgentMessage(id, decoder)
 		if err != nil {
-			if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
+			if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) {
 				return nil
 			}
 			return xerrors.Errorf("handle next agent message: %w", err)
@@ -349,7 +382,7 @@ func (c *coordinator) Close() error {
 	for _, socket := range c.agentSockets {
 		socket := socket
 		go func() {
-			_ = socket.conn.Close()
			wg.Done()
 		}()
 	}
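
For context, the following is a minimal, self-contained sketch of the ownership-guard pattern this diff applies. The registry type, its register/unregister methods, and the use of net.Pipe here are illustrative only and are not the coordinator's actual API; the point is the same ephemeral-UUID comparison that keeps a reconnecting agent's socket from being deleted by the goroutine still cleaning up after its stale predecessor.

// ownership_guard_sketch.go: hypothetical stand-in for the coordinator's
// agentSockets map, not the real tailnet code.
package main

import (
	"fmt"
	"net"
	"sync"

	"github.com/google/uuid"
)

// registry keys connections by agent ID and remembers which goroutine
// registered each one via an ephemeral owner UUID.
type registry struct {
	mu    sync.Mutex
	conns map[uuid.UUID]ownedConn
}

type ownedConn struct {
	owner uuid.UUID
	conn  net.Conn
}

// register closes any previous connection for agentID (mirroring how the diff
// handles a stale socket) and returns the owner token the caller must present
// later to unregister.
func (r *registry) register(agentID uuid.UUID, conn net.Conn) uuid.UUID {
	owner := uuid.New()
	r.mu.Lock()
	defer r.mu.Unlock()
	if old, ok := r.conns[agentID]; ok {
		_ = old.conn.Close()
	}
	r.conns[agentID] = ownedConn{owner: owner, conn: conn}
	return owner
}

// unregister removes the entry only if it is still owned by the caller, so a
// goroutine serving a dead connection cannot delete its replacement.
func (r *registry) unregister(agentID, owner uuid.UUID) {
	r.mu.Lock()
	defer r.mu.Unlock()
	if cur, ok := r.conns[agentID]; ok && cur.owner == owner {
		delete(r.conns, agentID)
	}
}

func main() {
	r := &registry{conns: map[uuid.UUID]ownedConn{}}
	agentID := uuid.New()

	c1, _ := net.Pipe()
	c2, _ := net.Pipe()

	token1 := r.register(agentID, c1) // first agent connection
	token2 := r.register(agentID, c2) // reconnect races in and closes c1

	// The goroutine serving the first (now closed) connection cleans up, but
	// its token no longer matches, so the second connection survives.
	r.unregister(agentID, token1)
	fmt.Println("entries after stale unregister:", len(r.conns)) // 1

	r.unregister(agentID, token2)
	fmt.Println("entries after owner unregister:", len(r.conns)) // 0
}

Either goroutine can call unregister safely; only the call holding the current owner token mutates the map, which is exactly the property the new test below exercises.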

tailnet/coordinator_test.go (+94)

@@ -145,4 +145,98 @@ func TestCoordinator(t *testing.T) {
 		<-clientErrChan
 		<-closeClientChan
 	})
+
+	t.Run("AgentDoubleConnect", func(t *testing.T) {
+		t.Parallel()
+		coordinator := tailnet.NewCoordinator()
+
+		agentWS1, agentServerWS1 := net.Pipe()
+		defer agentWS1.Close()
+		agentNodeChan1 := make(chan []*tailnet.Node)
+		sendAgentNode1, agentErrChan1 := tailnet.ServeCoordinator(agentWS1, func(nodes []*tailnet.Node) error {
+			agentNodeChan1 <- nodes
+			return nil
+		})
+		agentID := uuid.New()
+		closeAgentChan1 := make(chan struct{})
+		go func() {
+			err := coordinator.ServeAgent(agentServerWS1, agentID)
+			assert.NoError(t, err)
+			close(closeAgentChan1)
+		}()
+		sendAgentNode1(&tailnet.Node{})
+		require.Eventually(t, func() bool {
+			return coordinator.Node(agentID) != nil
+		}, testutil.WaitShort, testutil.IntervalFast)
+
+		clientWS, clientServerWS := net.Pipe()
+		defer clientWS.Close()
+		defer clientServerWS.Close()
+		clientNodeChan := make(chan []*tailnet.Node)
+		sendClientNode, clientErrChan := tailnet.ServeCoordinator(clientWS, func(nodes []*tailnet.Node) error {
+			clientNodeChan <- nodes
+			return nil
+		})
+		clientID := uuid.New()
+		closeClientChan := make(chan struct{})
+		go func() {
+			err := coordinator.ServeClient(clientServerWS, clientID, agentID)
+			assert.NoError(t, err)
+			close(closeClientChan)
+		}()
+		agentNodes := <-clientNodeChan
+		require.Len(t, agentNodes, 1)
+		sendClientNode(&tailnet.Node{})
+		clientNodes := <-agentNodeChan1
+		require.Len(t, clientNodes, 1)
+
+		// Ensure an update to the agent node reaches the client!
+		sendAgentNode1(&tailnet.Node{})
+		agentNodes = <-clientNodeChan
+		require.Len(t, agentNodes, 1)
+
+		// Create a new agent connection without disconnecting the old one.
+		agentWS2, agentServerWS2 := net.Pipe()
+		defer agentWS2.Close()
+		agentNodeChan2 := make(chan []*tailnet.Node)
+		_, agentErrChan2 := tailnet.ServeCoordinator(agentWS2, func(nodes []*tailnet.Node) error {
+			agentNodeChan2 <- nodes
+			return nil
+		})
+		closeAgentChan2 := make(chan struct{})
+		go func() {
+			err := coordinator.ServeAgent(agentServerWS2, agentID)
+			assert.NoError(t, err)
+			close(closeAgentChan2)
+		}()
+
+		// Ensure the existing listening client sends it's node immediately!
+		clientNodes = <-agentNodeChan2
+		require.Len(t, clientNodes, 1)
+
+		counts, ok := coordinator.(interface {
+			NodeCount() int
+			AgentCount() int
+		})
+		if !ok {
+			t.Fatal("coordinator should have NodeCount() and AgentCount()")
+		}
+
+		assert.Equal(t, 2, counts.NodeCount())
+		assert.Equal(t, 1, counts.AgentCount())
+
+		err := agentWS2.Close()
+		require.NoError(t, err)
+		<-agentErrChan2
+		<-closeAgentChan2
+
+		err = clientWS.Close()
+		require.NoError(t, err)
+		<-clientErrChan
+		<-closeClientChan
+
+		// This original agent websocket should've been closed forcefully.
+		<-agentErrChan1
+		<-closeAgentChan1
+	})
 }
