diff --git a/enterprise/tailnet/coordinator.go b/enterprise/tailnet/coordinator.go
index 83def00ba1c92..10605498df248 100644
--- a/enterprise/tailnet/coordinator.go
+++ b/enterprise/tailnet/coordinator.go
@@ -13,6 +13,7 @@ import (
"time"
"github.com/google/uuid"
+ lru "github.com/hashicorp/golang-lru/v2"
"golang.org/x/xerrors"
"cdr.dev/slog"
@@ -24,6 +25,12 @@ import (
// that uses PostgreSQL pubsub to exchange handshakes.
func NewCoordinator(logger slog.Logger, pubsub database.Pubsub) (agpl.Coordinator, error) {
ctx, cancelFunc := context.WithCancel(context.Background())
+
+ nameCache, err := lru.New[uuid.UUID, string](512)
+ if err != nil {
+ panic("make lru cache: " + err.Error())
+ }
+
coord := &haCoordinator{
id: uuid.New(),
log: logger,
@@ -31,8 +38,9 @@ func NewCoordinator(logger slog.Logger, pubsub database.Pubsub) (agpl.Coordinato
closeFunc: cancelFunc,
close: make(chan struct{}),
nodes: map[uuid.UUID]*agpl.Node{},
- agentSockets: map[uuid.UUID]net.Conn{},
- agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]net.Conn{},
+ agentSockets: map[uuid.UUID]*agpl.TrackedConn{},
+ agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]*agpl.TrackedConn{},
+ agentNameCache: nameCache,
}
if err := coord.runPubsub(ctx); err != nil {
@@ -53,10 +61,14 @@ type haCoordinator struct {
// nodes maps agent and connection IDs their respective node.
nodes map[uuid.UUID]*agpl.Node
// agentSockets maps agent IDs to their open websocket.
- agentSockets map[uuid.UUID]net.Conn
+ agentSockets map[uuid.UUID]*agpl.TrackedConn
// agentToConnectionSockets maps agent IDs to connection IDs of conns that
// are subscribed to updates for that agent.
- agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]net.Conn
+ agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]*agpl.TrackedConn
+
+ // agentNameCache holds a cache of agent names. If one of them disappears,
+ // it's helpful to have a name cached for debugging.
+ agentNameCache *lru.Cache[uuid.UUID, string]
}
// Node returns an in-memory node by ID.
@@ -94,12 +106,18 @@ func (c *haCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID
c.mutex.Lock()
connectionSockets, ok := c.agentToConnectionSockets[agent]
if !ok {
- connectionSockets = map[uuid.UUID]net.Conn{}
+ connectionSockets = map[uuid.UUID]*agpl.TrackedConn{}
c.agentToConnectionSockets[agent] = connectionSockets
}
- // Insert this connection into a map so the agent can publish node updates.
- connectionSockets[id] = conn
+ now := time.Now().Unix()
+ // Insert this connection into a map so the agent
+ // can publish node updates.
+ connectionSockets[id] = &agpl.TrackedConn{
+ Conn: conn,
+ Start: now,
+ LastWrite: now,
+ }
c.mutex.Unlock()
defer func() {
@@ -176,7 +194,9 @@ func (c *haCoordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *js
// ServeAgent accepts a WebSocket connection to an agent that listens to
// incoming connections and publishes node updates.
-func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, _ string) error {
+func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) error {
+ c.agentNameCache.Add(id, name)
+
// Tell clients on other instances to send a callmemaybe to us.
err := c.publishAgentHello(id)
if err != nil {
@@ -196,21 +216,41 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, _ string) error
}
}
+ // This uniquely identifies a connection that belongs to this goroutine.
+ unique := uuid.New()
+ now := time.Now().Unix()
+ overwrites := int64(0)
+
// If an old agent socket is connected, we close it
// to avoid any leaks. This shouldn't ever occur because
// we expect one agent to be running.
c.mutex.Lock()
oldAgentSocket, ok := c.agentSockets[id]
if ok {
+ overwrites = oldAgentSocket.Overwrites + 1
_ = oldAgentSocket.Close()
}
- c.agentSockets[id] = conn
+ c.agentSockets[id] = &agpl.TrackedConn{
+ ID: unique,
+ Conn: conn,
+
+ Name: name,
+ Start: now,
+ LastWrite: now,
+ Overwrites: overwrites,
+ }
c.mutex.Unlock()
+
defer func() {
c.mutex.Lock()
defer c.mutex.Unlock()
- delete(c.agentSockets, id)
- delete(c.nodes, id)
+
+ // Only delete the connection if it's ours. It could have been
+ // overwritten.
+ if idConn, ok := c.agentSockets[id]; ok && idConn.ID == unique {
+ delete(c.agentSockets, id)
+ delete(c.nodes, id)
+ }
}()
decoder := json.NewDecoder(conn)
@@ -576,8 +616,14 @@ func (c *haCoordinator) formatAgentUpdate(id uuid.UUID, node *agpl.Node) ([]byte
return buf.Bytes(), nil
}
-func (*haCoordinator) ServeHTTPDebug(w http.ResponseWriter, _ *http.Request) {
+func (c *haCoordinator) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
- fmt.Fprintf(w, "<h1>coordinator</h1>")
- fmt.Fprintf(w, "<h2>ha debug coming soon</h2>")
+
+ c.mutex.RLock()
+ defer c.mutex.RUnlock()
+
+ fmt.Fprintln(w, "<h1>high-availability wireguard coordinator debug</h1>")
+ fmt.Fprintln(w, "<h3 style=\"margin:0px\">warning: this only provides info from the node that served the request, if there are multiple replicas this data may be incomplete</h3>")
+
+ agpl.CoordinatorHTTPDebug(c.agentSockets, c.agentToConnectionSockets, c.agentNameCache)(w, r)
}
diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go
index 216c04fe70606..f1e6d664b8a11 100644
--- a/tailnet/coordinator.go
+++ b/tailnet/coordinator.go
@@ -110,7 +110,7 @@ func ServeCoordinator(conn net.Conn, updateNodes func(node []*Node) error) (func
// coordinator is incompatible with multiple Coder replicas as all node data is
// in-memory.
func NewCoordinator() Coordinator {
- cache, err := lru.New[uuid.UUID, string](512)
+ nameCache, err := lru.New[uuid.UUID, string](512)
if err != nil {
panic("make lru cache: " + err.Error())
}
@@ -118,9 +118,9 @@ func NewCoordinator() Coordinator {
return &coordinator{
closed: false,
nodes: map[uuid.UUID]*Node{},
- agentSockets: map[uuid.UUID]*trackedConn{},
- agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]*trackedConn{},
- agentNameCache: cache,
+ agentSockets: map[uuid.UUID]*TrackedConn{},
+ agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]*TrackedConn{},
+ agentNameCache: nameCache,
}
}
@@ -138,31 +138,31 @@ type coordinator struct {
// nodes maps agent and connection IDs their respective node.
nodes map[uuid.UUID]*Node
// agentSockets maps agent IDs to their open websocket.
- agentSockets map[uuid.UUID]*trackedConn
+ agentSockets map[uuid.UUID]*TrackedConn
// agentToConnectionSockets maps agent IDs to connection IDs of conns that
// are subscribed to updates for that agent.
- agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]*trackedConn
+ agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]*TrackedConn
// agentNameCache holds a cache of agent names. If one of them disappears,
// it's helpful to have a name cached for debugging.
agentNameCache *lru.Cache[uuid.UUID, string]
}
-type trackedConn struct {
+type TrackedConn struct {
net.Conn
- // id is an ephemeral UUID used to uniquely identify the owner of the
+ // ID is an ephemeral UUID used to uniquely identify the owner of the
// connection.
- id uuid.UUID
+ ID uuid.UUID
- name string
- start int64
- lastWrite int64
- overwrites int64
+ Name string
+ Start int64
+ LastWrite int64
+ Overwrites int64
}
-func (t *trackedConn) Write(b []byte) (n int, err error) {
- atomic.StoreInt64(&t.lastWrite, time.Now().Unix())
+func (t *TrackedConn) Write(b []byte) (n int, err error) {
+ atomic.StoreInt64(&t.LastWrite, time.Now().Unix())
return t.Conn.Write(b)
}
@@ -212,17 +212,17 @@ func (c *coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID)
c.mutex.Lock()
connectionSockets, ok := c.agentToConnectionSockets[agent]
if !ok {
- connectionSockets = map[uuid.UUID]*trackedConn{}
+ connectionSockets = map[uuid.UUID]*TrackedConn{}
c.agentToConnectionSockets[agent] = connectionSockets
}
now := time.Now().Unix()
// Insert this connection into a map so the agent
// can publish node updates.
- connectionSockets[id] = &trackedConn{
+ connectionSockets[id] = &TrackedConn{
Conn: conn,
- start: now,
- lastWrite: now,
+ Start: now,
+ LastWrite: now,
}
c.mutex.Unlock()
defer func() {
@@ -337,17 +337,17 @@ func (c *coordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) error
// dead.
oldAgentSocket, ok := c.agentSockets[id]
if ok {
- overwrites = oldAgentSocket.overwrites + 1
+ overwrites = oldAgentSocket.Overwrites + 1
_ = oldAgentSocket.Close()
}
- c.agentSockets[id] = &trackedConn{
- id: unique,
+ c.agentSockets[id] = &TrackedConn{
+ ID: unique,
Conn: conn,
- name: name,
- start: now,
- lastWrite: now,
- overwrites: overwrites,
+ Name: name,
+ Start: now,
+ LastWrite: now,
+ Overwrites: overwrites,
}
c.mutex.Unlock()
@@ -357,7 +357,7 @@ func (c *coordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) error
// Only delete the connection if it's ours. It could have been
// overwritten.
- if idConn, ok := c.agentSockets[id]; ok && idConn.id == unique {
+ if idConn, ok := c.agentSockets[id]; ok && idConn.ID == unique {
delete(c.agentSockets, id)
delete(c.nodes, id)
}
@@ -450,123 +450,134 @@ func (c *coordinator) Close() error {
return nil
}
-func (c *coordinator) ServeHTTPDebug(w http.ResponseWriter, _ *http.Request) {
+func (c *coordinator) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
- now := time.Now()
c.mutex.RLock()
defer c.mutex.RUnlock()
fmt.Fprintln(w, "<h1>in-memory wireguard coordinator debug</h1>")
- type idConn struct {
- id uuid.UUID
- conn *trackedConn
- }
-
- {
- fmt.Fprintf(w, "<h2 id=agents><a href=#agents>#</a> agents: total %d</h2>\n", len(c.agentSockets))
- fmt.Fprintln(w, "<ul>")
- agentSockets := make([]idConn, 0, len(c.agentSockets))
+ CoordinatorHTTPDebug(c.agentSockets, c.agentToConnectionSockets, c.agentNameCache)(w, r)
+}
- for id, conn := range c.agentSockets {
- agentSockets = append(agentSockets, idConn{id, conn})
+func CoordinatorHTTPDebug(
+ agentSocketsMap map[uuid.UUID]*TrackedConn,
+ agentToConnectionSocketsMap map[uuid.UUID]map[uuid.UUID]*TrackedConn,
+ agentNameCache *lru.Cache[uuid.UUID, string],
+) func(w http.ResponseWriter, _ *http.Request) {
+ return func(w http.ResponseWriter, _ *http.Request) {
+ now := time.Now()
+
+ type idConn struct {
+ id uuid.UUID
+ conn *TrackedConn
}
- slices.SortFunc(agentSockets, func(a, b idConn) bool {
- return a.conn.name < b.conn.name
- })
-
- for _, agent := range agentSockets {
- fmt.Fprintf(w, "<li style=\"margin-top:4px\"><b>%s</b> (<code>%s</code>): created %v ago, write %v ago, overwrites %d</li>\n",
- agent.conn.name,
- agent.id.String(),
- now.Sub(time.Unix(agent.conn.start, 0)).Round(time.Second),
- now.Sub(time.Unix(agent.conn.lastWrite, 0)).Round(time.Second),
- agent.conn.overwrites,
- )
-
- if conns := c.agentToConnectionSockets[agent.id]; len(conns) > 0 {
- fmt.Fprintf(w, "<h3 style=\"margin:0px;font-size:16px;font-weight:400\">connections: total %d</h3>\n", len(conns))
-
- connSockets := make([]idConn, 0, len(conns))
- for id, conn := range conns {
- connSockets = append(connSockets, idConn{id, conn})
- }
- slices.SortFunc(connSockets, func(a, b idConn) bool {
- return a.id.String() < b.id.String()
- })
+ {
+ fmt.Fprintf(w, "<h2 id=agents><a href=#agents>#</a> agents: total %d</h2>\n", len(agentSocketsMap))
+ fmt.Fprintln(w, "<ul>")
+ agentSockets := make([]idConn, 0, len(agentSocketsMap))
- fmt.Fprintln(w, "<ul>")
- for _, connSocket := range connSockets {
- fmt.Fprintf(w, "<li><b>%s</b> (<code>%s</code>): created %v ago, write %v ago</li>\n",
- connSocket.conn.name,
- connSocket.id.String(),
- now.Sub(time.Unix(connSocket.conn.start, 0)).Round(time.Second),
- now.Sub(time.Unix(connSocket.conn.lastWrite, 0)).Round(time.Second),
- )
- }
- fmt.Fprintln(w, "</ul>")
+ for id, conn := range agentSocketsMap {
+ agentSockets = append(agentSockets, idConn{id, conn})
}
- }
- fmt.Fprintln(w, "</ul>")
- }
+ slices.SortFunc(agentSockets, func(a, b idConn) bool {
+ return a.conn.Name < b.conn.Name
+ })
+
+ for _, agent := range agentSockets {
+ fmt.Fprintf(w, "<li style=\"margin-top:4px\"><b>%s</b> (<code>%s</code>): created %v ago, write %v ago, overwrites %d</li>\n",
+ agent.conn.Name,
+ agent.id.String(),
+ now.Sub(time.Unix(agent.conn.Start, 0)).Round(time.Second),
+ now.Sub(time.Unix(agent.conn.LastWrite, 0)).Round(time.Second),
+ agent.conn.Overwrites,
+ )
- {
- type agentConns struct {
- id uuid.UUID
- conns []idConn
+ if conns := agentToConnectionSocketsMap[agent.id]; len(conns) > 0 {
+ fmt.Fprintf(w, "<h3 style=\"margin:0px;font-size:16px;font-weight:400\">connections: total %d</h3>\n", len(conns))
+
+ connSockets := make([]idConn, 0, len(conns))
+ for id, conn := range conns {
+ connSockets = append(connSockets, idConn{id, conn})
+ }
+ slices.SortFunc(connSockets, func(a, b idConn) bool {
+ return a.id.String() < b.id.String()
+ })
+
+ fmt.Fprintln(w, "<ul>")
+ for _, connSocket := range connSockets {
+ fmt.Fprintf(w, "<li><b>%s</b> (<code>%s</code>): created %v ago, write %v ago</li>\n",
+ connSocket.conn.Name,
+ connSocket.id.String(),
+ now.Sub(time.Unix(connSocket.conn.Start, 0)).Round(time.Second),
+ now.Sub(time.Unix(connSocket.conn.LastWrite, 0)).Round(time.Second),
+ )
+ }
+ fmt.Fprintln(w, "</ul>")
+ }
+ }
+
+ fmt.Fprintln(w, "</ul>")
}
- missingAgents := []agentConns{}
- for agentID, conns := range c.agentToConnectionSockets {
- if len(conns) == 0 {
- continue
+ {
+ type agentConns struct {
+ id uuid.UUID
+ conns []idConn
}
- if _, ok := c.agentSockets[agentID]; !ok {
- connsSlice := make([]idConn, 0, len(conns))
- for id, conn := range conns {
- connsSlice = append(connsSlice, idConn{id, conn})
+ missingAgents := []agentConns{}
+ for agentID, conns := range agentToConnectionSocketsMap {
+ if len(conns) == 0 {
+ continue
}
- slices.SortFunc(connsSlice, func(a, b idConn) bool {
- return a.id.String() < b.id.String()
- })
-
- missingAgents = append(missingAgents, agentConns{agentID, connsSlice})
- }
- }
- slices.SortFunc(missingAgents, func(a, b agentConns) bool {
- return a.id.String() < b.id.String()
- })
- fmt.Fprintf(w, "<h2 id=missing-agents><a href=#missing-agents>#</a> missing agents: total %d</h2>\n", len(missingAgents))
- fmt.Fprintln(w, "<ul>")
+ if _, ok := agentSocketsMap[agentID]; !ok {
+ connsSlice := make([]idConn, 0, len(conns))
+ for id, conn := range conns {
+ connsSlice = append(connsSlice, idConn{id, conn})
+ }
+ slices.SortFunc(connsSlice, func(a, b idConn) bool {
+ return a.id.String() < b.id.String()
+ })
- for _, agentConns := range missingAgents {
- agentName, ok := c.agentNameCache.Get(agentConns.id)
- if !ok {
- agentName = "unknown"
+ missingAgents = append(missingAgents, agentConns{agentID, connsSlice})
+ }
}
+ slices.SortFunc(missingAgents, func(a, b agentConns) bool {
+ return a.id.String() < b.id.String()
+ })
- fmt.Fprintf(w, "<li style=\"margin-top:4px\"><b>%s</b> (<code>%s</code>): created ? ago, write ? ago, overwrites ?</li>\n",
- agentName,
- agentConns.id.String(),
- )
-
- fmt.Fprintf(w, "<h3 style=\"margin:0px;font-size:16px;font-weight:400\">connections: total %d</h3>\n", len(agentConns.conns))
+ fmt.Fprintf(w, "<h2 id=missing-agents><a href=#missing-agents>#</a> missing agents: total %d</h2>\n", len(missingAgents))
fmt.Fprintln(w, "<ul>")
- for _, agentConn := range agentConns.conns {
- fmt.Fprintf(w, "<li><b>%s</b> (<code>%s</code>): created %v ago, write %v ago</li>\n",
- agentConn.conn.name,
- agentConn.id.String(),
- now.Sub(time.Unix(agentConn.conn.start, 0)).Round(time.Second),
- now.Sub(time.Unix(agentConn.conn.lastWrite, 0)).Round(time.Second),
+
+ for _, agentConns := range missingAgents {
+ agentName, ok := agentNameCache.Get(agentConns.id)
+ if !ok {
+ agentName = "unknown"
+ }
+
+ fmt.Fprintf(w, "<li style=\"margin-top:4px\"><b>%s</b> (<code>%s</code>): created ? ago, write ? ago, overwrites ?</li>\n",
+ agentName,
+ agentConns.id.String(),
)
+
+ fmt.Fprintf(w, "<h3 style=\"margin:0px;font-size:16px;font-weight:400\">connections: total %d</h3>\n", len(agentConns.conns))
+ fmt.Fprintln(w, "<ul>")
+ for _, agentConn := range agentConns.conns {
+ fmt.Fprintf(w, "<li><b>%s</b> (<code>%s</code>): created %v ago, write %v ago</li>\n",
+ agentConn.conn.Name,
+ agentConn.id.String(),
+ now.Sub(time.Unix(agentConn.conn.Start, 0)).Round(time.Second),
+ now.Sub(time.Unix(agentConn.conn.LastWrite, 0)).Round(time.Second),
+ )
+ }
+ fmt.Fprintln(w, "</ul>")
}
fmt.Fprintln(w, "</ul>")
}
- fmt.Fprintln(w, "</ul>")
}
}