@@ -13,6 +13,7 @@ import (
13
13
"time"
14
14
15
15
"github.com/google/uuid"
16
+ lru "github.com/hashicorp/golang-lru/v2"
16
17
"golang.org/x/xerrors"
17
18
18
19
"cdr.dev/slog"
@@ -24,15 +25,22 @@ import (
24
25
// that uses PostgreSQL pubsub to exchange handshakes.
25
26
func NewCoordinator (logger slog.Logger , pubsub database.Pubsub ) (agpl.Coordinator , error ) {
26
27
ctx , cancelFunc := context .WithCancel (context .Background ())
28
+
29
+ nameCache , err := lru.New [uuid.UUID , string ](512 )
30
+ if err != nil {
31
+ panic ("make lru cache: " + err .Error ())
32
+ }
33
+
27
34
coord := & haCoordinator {
28
35
id : uuid .New (),
29
36
log : logger ,
30
37
pubsub : pubsub ,
31
38
closeFunc : cancelFunc ,
32
39
close : make (chan struct {}),
33
40
nodes : map [uuid.UUID ]* agpl.Node {},
34
- agentSockets : map [uuid.UUID ]net.Conn {},
35
- agentToConnectionSockets : map [uuid.UUID ]map [uuid.UUID ]net.Conn {},
41
+ agentSockets : map [uuid.UUID ]* agpl.TrackedConn {},
42
+ agentToConnectionSockets : map [uuid.UUID ]map [uuid.UUID ]* agpl.TrackedConn {},
43
+ agentNameCache : nameCache ,
36
44
}
37
45
38
46
if err := coord .runPubsub (ctx ); err != nil {
@@ -53,10 +61,14 @@ type haCoordinator struct {
53
61
// nodes maps agent and connection IDs to their respective node.
54
62
nodes map [uuid.UUID ]* agpl.Node
55
63
// agentSockets maps agent IDs to their open websocket.
56
- agentSockets map [uuid.UUID ]net. Conn
64
+ agentSockets map [uuid.UUID ]* agpl. TrackedConn
57
65
// agentToConnectionSockets maps agent IDs to connection IDs of conns that
58
66
// are subscribed to updates for that agent.
59
- agentToConnectionSockets map [uuid.UUID ]map [uuid.UUID ]net.Conn
67
+ agentToConnectionSockets map [uuid.UUID ]map [uuid.UUID ]* agpl.TrackedConn
68
+
69
+ // agentNameCache holds a cache of agent names. If one of them disappears,
70
+ // it's helpful to have a name cached for debugging.
71
+ agentNameCache * lru.Cache [uuid.UUID , string ]
60
72
}
61
73
62
74
// Node returns an in-memory node by ID.
@@ -94,12 +106,18 @@ func (c *haCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID
94
106
c .mutex .Lock ()
95
107
connectionSockets , ok := c .agentToConnectionSockets [agent ]
96
108
if ! ok {
97
- connectionSockets = map [uuid.UUID ]net. Conn {}
109
+ connectionSockets = map [uuid.UUID ]* agpl. TrackedConn {}
98
110
c .agentToConnectionSockets [agent ] = connectionSockets
99
111
}
100
112
101
- // Insert this connection into a map so the agent can publish node updates.
102
- connectionSockets [id ] = conn
113
+ now := time .Now ().Unix ()
114
+ // Insert this connection into a map so the agent
115
+ // can publish node updates.
116
+ connectionSockets [id ] = & agpl.TrackedConn {
117
+ Conn : conn ,
118
+ Start : now ,
119
+ LastWrite : now ,
120
+ }
103
121
c .mutex .Unlock ()
104
122
105
123
defer func () {
@@ -176,7 +194,9 @@ func (c *haCoordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *js
176
194
177
195
// ServeAgent accepts a WebSocket connection to an agent that listens to
178
196
// incoming connections and publishes node updates.
179
- func (c * haCoordinator ) ServeAgent (conn net.Conn , id uuid.UUID , _ string ) error {
197
+ func (c * haCoordinator ) ServeAgent (conn net.Conn , id uuid.UUID , name string ) error {
198
+ c .agentNameCache .Add (id , name )
199
+
180
200
// Tell clients on other instances to send a callmemaybe to us.
181
201
err := c .publishAgentHello (id )
182
202
if err != nil {
@@ -196,21 +216,41 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, _ string) error
196
216
}
197
217
}
198
218
219
+ // This uniquely identifies a connection that belongs to this goroutine.
220
+ unique := uuid .New ()
221
+ now := time .Now ().Unix ()
222
+ overwrites := int64 (0 )
223
+
199
224
// If an old agent socket is connected, we close it
200
225
// to avoid any leaks. This shouldn't ever occur because
201
226
// we expect one agent to be running.
202
227
c .mutex .Lock ()
203
228
oldAgentSocket , ok := c .agentSockets [id ]
204
229
if ok {
230
+ overwrites = oldAgentSocket .Overwrites + 1
205
231
_ = oldAgentSocket .Close ()
206
232
}
207
- c .agentSockets [id ] = conn
233
+ c .agentSockets [id ] = & agpl.TrackedConn {
234
+ ID : unique ,
235
+ Conn : conn ,
236
+
237
+ Name : name ,
238
+ Start : now ,
239
+ LastWrite : now ,
240
+ Overwrites : overwrites ,
241
+ }
208
242
c .mutex .Unlock ()
243
+
209
244
defer func () {
210
245
c .mutex .Lock ()
211
246
defer c .mutex .Unlock ()
212
- delete (c .agentSockets , id )
213
- delete (c .nodes , id )
247
+
248
+ // Only delete the connection if it's ours. It could have been
249
+ // overwritten.
250
+ if idConn , ok := c .agentSockets [id ]; ok && idConn .ID == unique {
251
+ delete (c .agentSockets , id )
252
+ delete (c .nodes , id )
253
+ }
214
254
}()
215
255
216
256
decoder := json .NewDecoder (conn )
@@ -576,8 +616,14 @@ func (c *haCoordinator) formatAgentUpdate(id uuid.UUID, node *agpl.Node) ([]byte
576
616
return buf .Bytes (), nil
577
617
}
578
618
579
- func (* haCoordinator ) ServeHTTPDebug (w http.ResponseWriter , _ * http.Request ) {
// ServeHTTPDebug writes an HTML debug page for this coordinator to w.
// It sets the Content-Type header, prints a heading and a warning that the
// data only reflects the replica that served the request, and then hands
// the request to the handler returned by agpl.CoordinatorHTTPDebug, built
// from this coordinator's agent sockets, per-agent connection sockets, and
// agent name cache.
619
+ func (c * haCoordinator ) ServeHTTPDebug (w http.ResponseWriter , r * http.Request ) {
580
620
w .Header ().Set ("Content-Type" , "text/html; charset=utf-8" )
581
- fmt .Fprintf (w , "<h1>coordinator</h1>" )
582
- fmt .Fprintf (w , "<h2>ha debug coming soon</h2>" )
621
+
622
// The read lock is held until this function returns, so it covers the
// reads of c.agentSockets and c.agentToConnectionSockets passed below.
+ c .mutex .RLock ()
623
+ defer c .mutex .RUnlock ()
624
+
625
+ fmt .Fprintln (w , "<h1>high-availability wireguard coordinator debug</h1>" )
626
+ fmt .Fprintln (w , "<h4 style=\" margin-top:-25px\" >warning: this only provides info from the node that served the request, if there are multiple replicas this data may be incomplete</h4>" )
627
+
628
// CoordinatorHTTPDebug returns a handler function; it is invoked
// immediately with the current writer and request.
+ agpl .CoordinatorHTTPDebug (c .agentSockets , c .agentToConnectionSockets , c .agentNameCache )(w , r )
583
629
}
0 commit comments