Skip to content

fix: coordinator node update race #7345

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 2, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Lint fixes, make core private
Signed-off-by: Spike Curtis <spike@coder.com>
  • Loading branch information
spikecurtis committed May 2, 2023
commit cb695701819cd3ded36d85d2a84ca28547dbb382
90 changes: 43 additions & 47 deletions tailnet/coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ const LoggerName = "coord"
// in-memory.
func NewCoordinator(logger slog.Logger) Coordinator {
return &coordinator{
core: NewCore(logger),
core: newCore(logger),
}
}

Expand All @@ -132,15 +132,12 @@ func NewCoordinator(logger slog.Logger) Coordinator {
// This coordinator is incompatible with multiple Coder
// replicas as all node data is in-memory.
type coordinator struct {
core *Core
core *core
}

// since there is only one coordinator in this implementation, use the zero-value UUID
var coordinatorID uuid.UUID

// Core is an in-memory structure of Node and TrackedConn mappings. Its methods may be called from multiple goroutines;
// core is an in-memory structure of Node and TrackedConn mappings. Its methods may be called from multiple goroutines;
// it is protected by a mutex to ensure data stay consistent.
type Core struct {
type core struct {
logger slog.Logger
mutex sync.RWMutex
closed bool
Expand All @@ -158,13 +155,13 @@ type Core struct {
agentNameCache *lru.Cache[uuid.UUID, string]
}

func NewCore(logger slog.Logger) *Core {
func newCore(logger slog.Logger) *core {
nameCache, err := lru.New[uuid.UUID, string](512)
if err != nil {
panic("make lru cache: " + err.Error())
}

return &Core{
return &core{
logger: logger,
closed: false,
nodes: make(map[uuid.UUID]*Node),
Expand Down Expand Up @@ -233,13 +230,12 @@ func (t *TrackedConn) SendUpdates() {
return
}
_, err = t.conn.Write(data)
if err == nil {
t.logger.Debug(t.ctx, "wrote nodes", slog.F("nodes", nodes))
} else {
if err != nil {
t.logger.Info(t.ctx, "could not write nodes to connection", slog.Error(err), slog.F("nodes", nodes))
_ = t.Close()
return
}
t.logger.Debug(t.ctx, "wrote nodes", slog.F("nodes", nodes))
}
}
}
Expand Down Expand Up @@ -267,30 +263,30 @@ func NewTrackedConn(ctx context.Context, cancel func(), conn net.Conn, id uuid.U
// Node returns an in-memory node by ID.
// If the node does not exist, nil is returned.
func (c *coordinator) Node(id uuid.UUID) *Node {
return c.core.Node(id)
return c.core.node(id)
}

func (c *Core) Node(id uuid.UUID) *Node {
func (c *core) node(id uuid.UUID) *Node {
c.mutex.Lock()
defer c.mutex.Unlock()
return c.nodes[id]
}

func (c *coordinator) NodeCount() int {
return c.core.NodeCount()
return c.core.nodeCount()
}

func (c *Core) NodeCount() int {
func (c *core) nodeCount() int {
c.mutex.Lock()
defer c.mutex.Unlock()
return len(c.nodes)
}

func (c *coordinator) AgentCount() int {
return c.core.AgentCount()
return c.core.agentCount()
}

func (c *Core) AgentCount() int {
func (c *core) agentCount() int {
c.mutex.Lock()
defer c.mutex.Unlock()
return len(c.agentSockets)
Expand All @@ -301,13 +297,13 @@ func (c *Core) AgentCount() int {
func (c *coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
logger := c.core.ClientLogger(id, agent)
logger := c.core.clientLogger(id, agent)
logger.Debug(ctx, "coordinating client")
tc, err := c.core.InitAndTrackClient(ctx, cancel, conn, id, agent)
tc, err := c.core.initAndTrackClient(ctx, cancel, conn, id, agent)
if err != nil {
return err
}
defer c.core.ClientDisconnected(id, agent)
defer c.core.clientDisconnected(id, agent)

// On this goroutine, we read updates from the client and publish them. We start a second goroutine
// to write updates back to the client.
Expand All @@ -326,19 +322,19 @@ func (c *coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID)
}
}

func (c *Core) ClientLogger(id, agent uuid.UUID) slog.Logger {
func (c *core) clientLogger(id, agent uuid.UUID) slog.Logger {
return c.logger.With(slog.F("client_id", id), slog.F("agent_id", agent))
}

// InitAndTrackClient creates a TrackedConn for the client, and sends any initial node updates if we have any. It is
// initAndTrackClient creates a TrackedConn for the client, and sends any initial Node updates if we have any. It is
// one function that does two things because it is critical that we hold the mutex for both things, lest we miss some
// updates.
func (c *Core) InitAndTrackClient(
func (c *core) initAndTrackClient(
ctx context.Context, cancel func(), conn net.Conn, id, agent uuid.UUID,
) (
*TrackedConn, error,
) {
logger := c.ClientLogger(id, agent)
logger := c.clientLogger(id, agent)
c.mutex.Lock()
defer c.mutex.Unlock()
if c.closed {
Expand Down Expand Up @@ -372,8 +368,8 @@ func (c *Core) InitAndTrackClient(
return tc, nil
}

func (c *Core) ClientDisconnected(id, agent uuid.UUID) {
logger := c.ClientLogger(id, agent)
func (c *core) clientDisconnected(id, agent uuid.UUID) {
logger := c.clientLogger(id, agent)
c.mutex.Lock()
defer c.mutex.Unlock()
// Clean all traces of this connection from the map.
Expand All @@ -393,18 +389,18 @@ func (c *Core) ClientDisconnected(id, agent uuid.UUID) {
}

func (c *coordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json.Decoder) error {
logger := c.core.ClientLogger(id, agent)
logger := c.core.clientLogger(id, agent)
var node Node
err := decoder.Decode(&node)
if err != nil {
return xerrors.Errorf("read json: %w", err)
}
logger.Debug(context.Background(), "got client node update", slog.F("node", node))
return c.core.ClientNodeUpdate(id, agent, &node)
return c.core.clientNodeUpdate(id, agent, &node)
}

func (c *Core) ClientNodeUpdate(id, agent uuid.UUID, node *Node) error {
logger := c.ClientLogger(id, agent)
func (c *core) clientNodeUpdate(id, agent uuid.UUID, node *Node) error {
logger := c.clientLogger(id, agent)
c.mutex.Lock()
defer c.mutex.Unlock()
// Update the node of this client in our in-memory map. If an agent entirely
Expand All @@ -426,7 +422,7 @@ func (c *Core) ClientNodeUpdate(id, agent uuid.UUID, node *Node) error {
return nil
}

func (c *Core) AgentLogger(id uuid.UUID) slog.Logger {
func (c *core) agentLogger(id uuid.UUID) slog.Logger {
return c.logger.With(slog.F("agent_id", id))
}

Expand All @@ -435,11 +431,11 @@ func (c *Core) AgentLogger(id uuid.UUID) slog.Logger {
func (c *coordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
logger := c.core.AgentLogger(id)
logger := c.core.agentLogger(id)
logger.Debug(context.Background(), "coordinating agent")
// This uniquely identifies a connection that belongs to this goroutine.
unique := uuid.New()
tc, err := c.core.InitAndTrackAgent(ctx, cancel, conn, id, unique, name)
tc, err := c.core.initAndTrackAgent(ctx, cancel, conn, id, unique, name)
if err != nil {
return err
}
Expand All @@ -448,7 +444,7 @@ func (c *coordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) error
// to write updates back to the agent.
go tc.SendUpdates()

defer c.core.AgentDisconnected(id, unique)
defer c.core.agentDisconnected(id, unique)

decoder := json.NewDecoder(conn)
for {
Expand All @@ -463,8 +459,8 @@ func (c *coordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) error
}
}

func (c *Core) AgentDisconnected(id, unique uuid.UUID) {
logger := c.AgentLogger(id)
func (c *core) agentDisconnected(id, unique uuid.UUID) {
logger := c.agentLogger(id)
c.mutex.Lock()
defer c.mutex.Unlock()

Expand All @@ -477,10 +473,10 @@ func (c *Core) AgentDisconnected(id, unique uuid.UUID) {
}
}

// InitAndTrackAgent creates a TrackedConn for the agent, and sends any initial nodes updates if we have any. It is
// initAndTrackAgent creates a TrackedConn for the agent, and sends any initial node updates if we have any. It is
// one function that does two things because it is critical that we hold the mutex for both things, lest we miss some
// updates.
func (c *Core) InitAndTrackAgent(ctx context.Context, cancel func(), conn net.Conn, id, unique uuid.UUID, name string) (*TrackedConn, error) {
func (c *core) initAndTrackAgent(ctx context.Context, cancel func(), conn net.Conn, id, unique uuid.UUID, name string) (*TrackedConn, error) {
logger := c.logger.With(slog.F("agent_id", id))
c.mutex.Lock()
defer c.mutex.Unlock()
Expand Down Expand Up @@ -531,18 +527,18 @@ func (c *Core) InitAndTrackAgent(ctx context.Context, cancel func(), conn net.Co
}

func (c *coordinator) handleNextAgentMessage(id uuid.UUID, decoder *json.Decoder) error {
logger := c.core.AgentLogger(id)
logger := c.core.agentLogger(id)
var node Node
err := decoder.Decode(&node)
if err != nil {
return xerrors.Errorf("read json: %w", err)
}
logger.Debug(context.Background(), "decoded agent node", slog.F("node", node))
return c.core.AgentNodeUpdate(id, &node)
return c.core.agentNodeUpdate(id, &node)
}

func (c *Core) AgentNodeUpdate(id uuid.UUID, node *Node) error {
logger := c.AgentLogger(id)
func (c *core) agentNodeUpdate(id uuid.UUID, node *Node) error {
logger := c.agentLogger(id)
c.mutex.Lock()
defer c.mutex.Unlock()
c.nodes[id] = node
Expand Down Expand Up @@ -571,10 +567,10 @@ func (c *Core) AgentNodeUpdate(id uuid.UUID, node *Node) error {
// Close closes all of the open connections in the coordinator and stops the
// coordinator from accepting new connections.
func (c *coordinator) Close() error {
return c.core.Close()
return c.core.close()
}

func (c *Core) Close() error {
func (c *core) close() error {
c.mutex.Lock()
if c.closed {
c.mutex.Unlock()
Expand Down Expand Up @@ -611,10 +607,10 @@ func (c *Core) Close() error {
}

func (c *coordinator) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) {
c.core.ServeHTTPDebug(w, r)
c.core.serveHTTPDebug(w, r)
}

func (c *Core) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) {
func (c *core) serveHTTPDebug(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")

c.mutex.RLock()
Expand Down
1 change: 1 addition & 0 deletions tailnet/coordinator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ func TestCoordinator_AgentUpdateWhileClientConnects(t *testing.T) {
aNode.PreferredDERP = 1
require.NoError(t, err)
aData, err = json.Marshal(&aNode)
require.NoError(t, err)
err = agentWS.SetWriteDeadline(time.Now().Add(testutil.WaitShort))
require.NoError(t, err)
_, err = agentWS.Write(aData)
Expand Down