From bda160cacf2242555d6fd1f87309e13f6db3b9b9 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Mon, 23 Jan 2023 15:17:11 -0600 Subject: [PATCH 1/4] fix: ensure agent websocket only removes its own conn from the map --- tailnet/coordinator.go | 42 +++++++++++++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index dbd70ead1a778..800e0603acf3a 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -104,7 +104,7 @@ func NewCoordinator() Coordinator { return &coordinator{ closed: false, nodes: map[uuid.UUID]*Node{}, - agentSockets: map[uuid.UUID]net.Conn{}, + agentSockets: map[uuid.UUID]idConn{}, agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]net.Conn{}, } } @@ -123,12 +123,19 @@ type coordinator struct { // nodes maps agent and connection IDs their respective node. nodes map[uuid.UUID]*Node // agentSockets maps agent IDs to their open websocket. - agentSockets map[uuid.UUID]net.Conn + agentSockets map[uuid.UUID]idConn // agentToConnectionSockets maps agent IDs to connection IDs of conns that // are subscribed to updates for that agent. agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]net.Conn } +type idConn struct { + // id is an ephemeral UUID used to uniquely identify the owner of the + // connection. + id uuid.UUID + conn net.Conn +} + // Node returns an in-memory node by ID. // If the node does not exist, nil is returned. func (c *coordinator) Node(id uuid.UUID) *Node { @@ -224,7 +231,7 @@ func (c *coordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json return xerrors.Errorf("marshal nodes: %w", err) } - _, err = agentSocket.Write(data) + _, err = agentSocket.conn.Write(data) if err != nil { if errors.Is(err, io.EOF) { return nil @@ -268,20 +275,33 @@ func (c *coordinator) ServeAgent(conn net.Conn, id uuid.UUID) error { c.mutex.Lock() } - // If an old agent socket is connected, we close it - // to avoid any leaks. This shouldn't ever occur because - // we expect one agent to be running. + // This uniquely identifies a connection that belongs to this goroutine. + unique := uuid.New() + + // If an old agent socket is connected, we close it to avoid any leaks. This + // shouldn't ever occur because we expect one agent to be running, but it's + // possible for a race condition to happen when an agent is disconnected and + // attempts to reconnect before the server realizes the old connection is + // dead. oldAgentSocket, ok := c.agentSockets[id] if ok { - _ = oldAgentSocket.Close() + _ = oldAgentSocket.conn.Close() + } + c.agentSockets[id] = idConn{ + id: unique, + conn: conn, } - c.agentSockets[id] = conn c.mutex.Unlock() defer func() { c.mutex.Lock() defer c.mutex.Unlock() - delete(c.agentSockets, id) - delete(c.nodes, id) + + // Only delete the connection if it's ours. It could have been + // overwritten. + if idConn := c.agentSockets[id]; idConn.id == unique { + delete(c.agentSockets, id) + delete(c.nodes, id) + } }() decoder := json.NewDecoder(conn) @@ -349,7 +369,7 @@ func (c *coordinator) Close() error { for _, socket := range c.agentSockets { socket := socket go func() { - _ = socket.Close() + _ = socket.conn.Close() wg.Done() }() } From 4b67adf64300e337b16fff931305dd3631cc912b Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Mon, 23 Jan 2023 16:00:32 -0600 Subject: [PATCH 2/4] add test --- tailnet/coordinator.go | 4 +- tailnet/coordinator_test.go | 83 +++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 2 deletions(-) diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index 800e0603acf3a..fd1a31a29708f 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -233,7 +233,7 @@ func (c *coordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json _, err = agentSocket.conn.Write(data) if err != nil { - if errors.Is(err, io.EOF) { + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) { return nil } return xerrors.Errorf("write json: %w", err) @@ -308,7 +308,7 @@ func (c *coordinator) ServeAgent(conn net.Conn, id uuid.UUID) error { for { err := c.handleNextAgentMessage(id, decoder) if err != nil { - if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) { + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) { return nil } return xerrors.Errorf("handle next agent message: %w", err) diff --git a/tailnet/coordinator_test.go b/tailnet/coordinator_test.go index a4a020deadf93..2028675c17077 100644 --- a/tailnet/coordinator_test.go +++ b/tailnet/coordinator_test.go @@ -145,4 +145,87 @@ func TestCoordinator(t *testing.T) { <-clientErrChan <-closeClientChan }) + + t.Run("AgentDoubleConnect", func(t *testing.T) { + t.Parallel() + coordinator := tailnet.NewCoordinator() + + agentWS1, agentServerWS1 := net.Pipe() + defer agentWS1.Close() + agentNodeChan1 := make(chan []*tailnet.Node) + sendAgentNode1, agentErrChan1 := tailnet.ServeCoordinator(agentWS1, func(nodes []*tailnet.Node) error { + agentNodeChan1 <- nodes + return nil + }) + agentID := uuid.New() + closeAgentChan1 := make(chan struct{}) + go func() { + err := coordinator.ServeAgent(agentServerWS1, agentID) + assert.NoError(t, err) + close(closeAgentChan1) + }() + sendAgentNode1(&tailnet.Node{}) + require.Eventually(t, func() bool { + return coordinator.Node(agentID) != nil + }, testutil.WaitShort, testutil.IntervalFast) + + clientWS, clientServerWS := net.Pipe() + defer clientWS.Close() + defer clientServerWS.Close() + clientNodeChan := make(chan []*tailnet.Node) + sendClientNode, clientErrChan := tailnet.ServeCoordinator(clientWS, func(nodes []*tailnet.Node) error { + clientNodeChan <- nodes + return nil + }) + clientID := uuid.New() + closeClientChan := make(chan struct{}) + go func() { + err := coordinator.ServeClient(clientServerWS, clientID, agentID) + assert.NoError(t, err) + close(closeClientChan) + }() + agentNodes := <-clientNodeChan + require.Len(t, agentNodes, 1) + sendClientNode(&tailnet.Node{}) + clientNodes := <-agentNodeChan1 + require.Len(t, clientNodes, 1) + + // Ensure an update to the agent node reaches the client! + sendAgentNode1(&tailnet.Node{}) + agentNodes = <-clientNodeChan + require.Len(t, agentNodes, 1) + + // Create a new agent connection without disconnecting the old one. + agentWS2, agentServerWS2 := net.Pipe() + defer agentWS2.Close() + agentNodeChan2 := make(chan []*tailnet.Node) + _, agentErrChan2 := tailnet.ServeCoordinator(agentWS2, func(nodes []*tailnet.Node) error { + agentNodeChan2 <- nodes + return nil + }) + closeAgentChan2 := make(chan struct{}) + go func() { + err := coordinator.ServeAgent(agentServerWS2, agentID) + assert.NoError(t, err) + close(closeAgentChan2) + }() + + // Ensure the existing listening client sends it's node immediately! + clientNodes = <-agentNodeChan2 + require.Len(t, clientNodes, 1) + + err := agentWS2.Close() + require.NoError(t, err) + <-agentErrChan2 + <-closeAgentChan2 + + err = clientWS.Close() + require.NoError(t, err) + <-clientErrChan + <-closeClientChan + + // This original agent websocket should've been closed forcefully. + <-agentErrChan1 + <-closeAgentChan1 + }) } From c080b7c8678bdd5b8e03943be0ce63428b1896c8 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Mon, 23 Jan 2023 16:07:14 -0600 Subject: [PATCH 3/4] fixup! add test --- tailnet/coordinator.go | 8 ++++++++ tailnet/coordinator_test.go | 11 +++++++++++ 2 files changed, 19 insertions(+) diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index fd1a31a29708f..ffa25fa327392 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -144,6 +144,14 @@ func (c *coordinator) Node(id uuid.UUID) *Node { return c.nodes[id] } +func (c *coordinator) NodeCount() int { + return len(c.nodes) +} + +func (c *coordinator) AgentCount() int { + return len(c.agentSockets) +} + // ServeClient accepts a WebSocket connection that wants to connect to an agent // with the specified ID. func (c *coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error { diff --git a/tailnet/coordinator_test.go b/tailnet/coordinator_test.go index 2028675c17077..60d909f7150b3 100644 --- a/tailnet/coordinator_test.go +++ b/tailnet/coordinator_test.go @@ -214,6 +214,17 @@ func TestCoordinator(t *testing.T) { clientNodes = <-agentNodeChan2 require.Len(t, clientNodes, 1) + counts, ok := coordinator.(interface { + NodeCount() int + AgentCount() int + }) + if !ok { + t.Fatal("coordinator should have NodeCount() and AgentCount()") + } + + assert.Equal(t, 2, counts.NodeCount()) + assert.Equal(t, 1, counts.AgentCount()) + err := agentWS2.Close() require.NoError(t, err) <-agentErrChan2 From fd0509d47f6a39b26e131139c2a006514fe53b44 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Mon, 23 Jan 2023 16:23:33 -0600 Subject: [PATCH 4/4] fixup! add test --- tailnet/coordinator.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index ffa25fa327392..7c3f48c9ea060 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -145,10 +145,14 @@ func (c *coordinator) Node(id uuid.UUID) *Node { } func (c *coordinator) NodeCount() int { + c.mutex.Lock() + defer c.mutex.Unlock() return len(c.nodes) } func (c *coordinator) AgentCount() int { + c.mutex.Lock() + defer c.mutex.Unlock() return len(c.agentSockets) } @@ -299,6 +303,7 @@ func (c *coordinator) ServeAgent(conn net.Conn, id uuid.UUID) error { id: unique, conn: conn, } + c.mutex.Unlock() defer func() { c.mutex.Lock()