Commit 457470d

use multiagents for all clients

1 parent: e55e146

File tree

5 files changed
+457 −411 lines changed

coderd/tailnet.go

Lines changed: 42 additions & 12 deletions
@@ -34,8 +34,6 @@ func init() {
 	}
 }
 
-// TODO(coadler): ServerTailnet does not currently remove stale peers.
-
 // NewServerTailnet creates a new tailnet intended for use by coderd. It
 // automatically falls back to wsconncache if a legacy agent is encountered.
 func NewServerTailnet(
@@ -102,14 +100,49 @@ func NewServerTailnet(
 	})
 
 	go tn.watchAgentUpdates()
+	go tn.expireOldAgents()
 	return tn, nil
 }
 
+func (s *ServerTailnet) expireOldAgents() {
+	const (
+		tick   = 5 * time.Minute
+		cutoff = 30 * time.Minute
+	)
+
+	ticker := time.NewTicker(tick)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-s.ctx.Done():
+			return
+		case <-ticker.C:
+		}
+
+		s.nodesMu.Lock()
+		agentConn := s.getAgentConn()
+		for agentID, node := range s.agentNodes {
+			if time.Since(node.lastConnection) > cutoff {
+				err := agentConn.UnsubscribeAgent(agentID)
+				if err != nil {
+					s.logger.Error(s.ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
+				}
+				delete(s.agentNodes, agentID)
+
+				// TODO(coadler): actually remove from the netmap
+			}
+		}
+		s.nodesMu.Unlock()
+	}
+}
+
 func (s *ServerTailnet) watchAgentUpdates() {
 	for {
-		nodes := s.getAgentConn().NextUpdate(s.ctx)
-		if nodes == nil {
-			if s.getAgentConn().IsClosed() && s.ctx.Err() == nil {
+		conn := s.getAgentConn()
+		nodes, ok := conn.NextUpdate(s.ctx)
+		if !ok {
+			if conn.IsClosed() && s.ctx.Err() == nil {
 				s.reinitCoordinator()
 				continue
 			}
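The expiry loop added above is a standard ticker-plus-cutoff sweep: wake on a fixed interval, drop entries whose last use is older than the cutoff, and stop when the context is canceled. A minimal standalone sketch of the same pattern (hypothetical names, not the Coder API):

package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// expireLoop mirrors the shape of expireOldAgents: a ticker drives a
// periodic sweep, entries idle longer than cutoff are dropped, and
// context cancellation terminates the loop.
func expireLoop(ctx context.Context, mu *sync.Mutex, seen map[string]time.Time, tick, cutoff time.Duration) {
	ticker := time.NewTicker(tick)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}
		mu.Lock()
		for id, last := range seen {
			if time.Since(last) > cutoff {
				// The real code calls UnsubscribeAgent here and logs failures.
				fmt.Println("expiring", id)
				delete(seen, id) // deleting the current key mid-range is safe in Go
			}
		}
		mu.Unlock()
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 350*time.Millisecond)
	defer cancel()
	var mu sync.Mutex
	seen := map[string]time.Time{"stale": time.Now().Add(-time.Hour), "fresh": time.Now()}
	expireLoop(ctx, &mu, seen, 100*time.Millisecond, 30*time.Minute)
}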
@@ -129,24 +162,22 @@ func (s *ServerTailnet) getAgentConn() tailnet.MultiAgentConn {
 }
 
 func (s *ServerTailnet) reinitCoordinator() {
+	s.nodesMu.Lock()
 	agentConn := (*s.coord.Load()).ServeMultiAgent(uuid.New())
 	s.agentConn.Store(&agentConn)
 
-	s.nodesMu.Lock()
 	// Resubscribe to all of the agents we're tracking.
-	for agentID, agentNode := range s.agentNodes {
-		closer, err := agentConn.SubscribeAgent(agentID)
+	for agentID := range s.agentNodes {
+		err := agentConn.SubscribeAgent(agentID)
 		if err != nil {
 			s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(err), slog.F("agent_id", agentID))
 		}
-		agentNode.close = closer
 	}
 	s.nodesMu.Unlock()
 }
 
 type tailnetNode struct {
 	lastConnection time.Time
-	close          func()
 }
 
 type ServerTailnet struct {
@@ -210,13 +241,12 @@ func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error {
 	// If we don't have the node, subscribe.
 	if !ok {
 		s.logger.Debug(s.ctx, "subscribing to agent", slog.F("agent_id", agentID))
-		closer, err := s.getAgentConn().SubscribeAgent(agentID)
+		err := s.getAgentConn().SubscribeAgent(agentID)
 		if err != nil {
 			return xerrors.Errorf("subscribe agent: %w", err)
 		}
		tnode = &tailnetNode{
 			lastConnection: time.Now(),
-			close:          closer,
 		}
 		s.agentNodes[agentID] = tnode
 	} else {
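Taken together, the tailnet.go call sites imply a narrower tailnet.MultiAgentConn surface: SubscribeAgent returns only an error, teardown goes through UnsubscribeAgent (which is why tailnetNode lost its close field), and NextUpdate signals liveness with an ok bool instead of a nil slice. A sketch reconstructed from these call sites only, not the package's actual definition:

package sketch

import (
	"context"

	"github.com/google/uuid"
)

// Node stands in for tailnet.Node, which carries a peer's connection
// information in the real package.
type Node struct{}

// MultiAgentConn as reconstructed from the call sites in this commit;
// the real interface in the tailnet package may declare more methods.
type MultiAgentConn interface {
	// SubscribeAgent starts tracking an agent. Teardown now goes through
	// UnsubscribeAgent instead of a closer returned at subscribe time.
	SubscribeAgent(agentID uuid.UUID) error
	UnsubscribeAgent(agentID uuid.UUID) error
	// NextUpdate blocks for the next batch of node updates; ok reports
	// whether the connection is still usable, replacing the old
	// nil-slice sentinel.
	NextUpdate(ctx context.Context) (nodes []*Node, ok bool)
	IsClosed() bool
	Close() error
}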

enterprise/tailnet/coordinator.go

Lines changed: 80 additions & 64 deletions
@@ -12,7 +12,6 @@ import (
 	"sync"
 
 	"github.com/google/uuid"
-	"github.com/hashicorp/go-multierror"
 	lru "github.com/hashicorp/golang-lru/v2"
 	"golang.org/x/xerrors"
 
@@ -42,6 +41,8 @@ func NewCoordinator(logger slog.Logger, ps pubsub.Pubsub) (agpl.Coordinator, err
 		agentSockets:             map[uuid.UUID]agpl.Enqueueable{},
 		agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]agpl.Enqueueable{},
 		agentNameCache:           nameCache,
+		clients:                  map[uuid.UUID]agpl.Enqueueable{},
+		clientsToAgents:          map[uuid.UUID]map[uuid.UUID]struct{}{},
 		legacyAgents:             map[uuid.UUID]struct{}{},
 	}
 
@@ -57,14 +58,22 @@ func (c *haCoordinator) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn {
 		ID:                id,
 		Logger:            c.log,
 		AgentIsLegacyFunc: c.agentIsLegacy,
-		OnSubscribe:       c.multiAgentSubscribe,
-		OnNodeUpdate:      c.multiAgentUpdate,
+		OnSubscribe:       c.clientSubscribeToAgent,
+		OnNodeUpdate:      c.clientNodeUpdate,
+		OnRemove:          c.clientDisconnected,
 	}).Init()
+	c.mutex.Lock()
+	c.clients[id] = m
+	c.clientsToAgents[id] = map[uuid.UUID]struct{}{}
+	c.mutex.Unlock()
 	return m
 }
 
-func (c *haCoordinator) multiAgentSubscribe(enq agpl.Enqueueable, agentID uuid.UUID) (func(), error) {
+func (c *haCoordinator) clientSubscribeToAgent(enq agpl.Enqueueable, agentID uuid.UUID) error {
 	c.mutex.Lock()
+	defer c.mutex.Unlock()
+
+	c.initOrSetAgentConnectionSocketLocked(agentID, enq)
 
 	node := c.nodes[enq.UniqueID()]
 
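ServeMultiAgent now registers every client in two maps up front: clients for lifecycle (so Close can reach clients that never subscribed to anything) and clientsToAgents as a per-client subscription set. Creating the empty set eagerly is what lets initOrSetAgentConnectionSocketLocked index clientsToAgents[enq.UniqueID()] later without a nil check. A minimal sketch of that bookkeeping invariant (hypothetical registry type, not the coordinator itself):

package main

import (
	"fmt"
	"sync"

	"github.com/google/uuid"
)

// registry mirrors the coordinator's two-map bookkeeping: one map for
// client lifecycle, one for each client's subscription set.
type registry struct {
	mu              sync.Mutex
	clients         map[uuid.UUID]struct{}
	clientsToAgents map[uuid.UUID]map[uuid.UUID]struct{}
}

// register eagerly creates the empty subscription set, so subscribe can
// index it without a nil check (the invariant the diff relies on).
func (r *registry) register(id uuid.UUID) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.clients[id] = struct{}{}
	r.clientsToAgents[id] = map[uuid.UUID]struct{}{}
}

func (r *registry) subscribe(id, agentID uuid.UUID) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.clientsToAgents[id][agentID] = struct{}{} // safe: set already exists
}

func (r *registry) disconnect(id uuid.UUID) {
	r.mu.Lock()
	defer r.mu.Unlock()
	delete(r.clients, id)
	delete(r.clientsToAgents, id)
}

func main() {
	r := &registry{
		clients:         map[uuid.UUID]struct{}{},
		clientsToAgents: map[uuid.UUID]map[uuid.UUID]struct{}{},
	}
	client, agent := uuid.New(), uuid.New()
	r.register(client)
	r.subscribe(client, agent)
	fmt.Println(len(r.clientsToAgents[client])) // 1
	r.disconnect(client)
}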
@@ -73,44 +82,43 @@
 	if ok {
 		err := enq.Enqueue([]*agpl.Node{agentNode})
 		if err != nil {
-			return nil, xerrors.Errorf("enqueue agent on subscribe: %w", err)
+			return xerrors.Errorf("enqueue agent on subscribe: %w", err)
 		}
 	} else {
 		// If we don't have the node locally, notify other coordinators.
-		c.mutex.Unlock()
 		err := c.publishClientHello(agentID)
 		if err != nil {
-			return nil, xerrors.Errorf("publish client hello: %w", err)
+			return xerrors.Errorf("publish client hello: %w", err)
 		}
 	}
 
 	if node != nil {
-		err := c.handleClientUpdate(enq.UniqueID(), agentID, node)
+		err := c.sendNodeToAgentLocked(agentID, node)
 		if err != nil {
-			return nil, xerrors.Errorf("handle client update: %w", err)
+			return xerrors.Errorf("handle client update: %w", err)
 		}
 	}
 
-	return c.cleanupClientConn(enq.UniqueID(), agentID), nil
-}
-
-func (c *haCoordinator) multiAgentUpdate(id uuid.UUID, agents []uuid.UUID, node *agpl.Node) error {
-	var errs *multierror.Error
-	// This isn't the most efficient, but this coordinator is being deprecated
-	// soon anyways.
-	for _, agent := range agents {
-		err := c.handleClientUpdate(id, agent, node)
-		if err != nil {
-			errs = multierror.Append(errs, err)
-		}
-	}
-	if errs != nil {
-		return errs
-	}
-
 	return nil
 }
 
+// func (c *haCoordinator) multiAgentUpdate(id uuid.UUID, agents []uuid.UUID, node *agpl.Node) error {
+// 	var errs *multierror.Error
+// 	// This isn't the most efficient, but this coordinator is being deprecated
+// 	// soon anyways.
+// 	for _, agent := range agents {
+// 		err := c.handleClientUpdate(id, agent, node)
+// 		if err != nil {
+// 			errs = multierror.Append(errs, err)
+// 		}
+// 	}
+// 	if errs != nil {
+// 		return errs
+// 	}
+
+// 	return nil
+// }
+
 type haCoordinator struct {
 	id  uuid.UUID
 	log slog.Logger
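With subscriptions now tracked coordinator-side in clientsToAgents, the OnNodeUpdate callback no longer receives an agent list, so the aggregating multiAgentUpdate is retired (left commented out above) and the go-multierror import dropped. If error aggregation were ever needed again, the standard library's errors.Join (Go 1.20+) covers the same ground without the dependency; a sketch, not code from this commit:

package main

import (
	"errors"
	"fmt"
)

// fanOut applies fn to every target and aggregates failures with
// errors.Join, the stdlib alternative to hashicorp/go-multierror.
func fanOut(targets []string, fn func(string) error) error {
	var errs []error
	for _, t := range targets {
		if err := fn(t); err != nil {
			errs = append(errs, fmt.Errorf("target %s: %w", t, err))
		}
	}
	return errors.Join(errs...) // nil when errs is empty
}

func main() {
	err := fanOut([]string{"a", "b"}, func(t string) error {
		if t == "b" {
			return errors.New("unreachable")
		}
		return nil
	})
	fmt.Println(err) // target b: unreachable
}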
@@ -127,6 +135,9 @@ type haCoordinator struct {
 	// are subscribed to updates for that agent.
 	agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]agpl.Enqueueable
 
+	clients         map[uuid.UUID]agpl.Enqueueable
+	clientsToAgents map[uuid.UUID]map[uuid.UUID]struct{}
+
 	// agentNameCache holds a cache of agent names. If one of them disappears,
 	// it's helpful to have a name cached for debugging.
 	agentNameCache *lru.Cache[uuid.UUID, string]
@@ -152,40 +163,25 @@ func (c *haCoordinator) agentLogger(agent uuid.UUID) slog.Logger {
 
 // ServeClient accepts a WebSocket connection that wants to connect to an agent
 // with the specified ID.
-func (c *haCoordinator) ServeClient(conn net.Conn, id, agent uuid.UUID) error {
+func (c *haCoordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
-	logger := c.clientLogger(id, agent)
+	logger := c.clientLogger(id, agentID)
 
-	c.mutex.Lock()
-
-	tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, 0)
-	c.initOrSetAgentConnectionSocketLocked(agent, tc)
+	ma := c.ServeMultiAgent(id)
+	defer ma.Close()
 
-	// When a new connection is requested, we update it with the latest
-	// node of the agent. This allows the connection to establish.
-	node, ok := c.nodes[agent]
-	if ok {
-		err := tc.Enqueue([]*agpl.Node{node})
-		c.mutex.Unlock()
-		if err != nil {
-			return xerrors.Errorf("enqueue node: %w", err)
-		}
-	} else {
-		c.mutex.Unlock()
-		err := c.publishClientHello(agent)
-		if err != nil {
-			return xerrors.Errorf("publish client hello: %w", err)
-		}
+	err := ma.SubscribeAgent(agentID)
+	if err != nil {
+		return xerrors.Errorf("subscribe agent: %w", err)
 	}
-	go tc.SendUpdates()
 
-	defer c.cleanupClientConn(id, agent)
+	go agpl.SendUpdatesToConn(ctx, logger, ma, conn)
 
 	decoder := json.NewDecoder(conn)
 	// Indefinitely handle messages from the client websocket.
 	for {
-		err := c.handleNextClientMessage(id, agent, decoder)
+		err := c.handleNextClientMessage(id, decoder)
 		if err != nil {
 			if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
 				return nil
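The net effect: ServeClient is now a thin wrapper over the multi-agent path, which is what the commit title means by using multiagents for all clients. Condensed to its shape below; this is a sketch, with the Coordinator interface and import paths assumed from the surrounding code, and error handling trimmed:

package sketch

import (
	"context"
	"net"

	"cdr.dev/slog"
	"github.com/google/uuid"
	"golang.org/x/xerrors"

	agpl "github.com/coder/coder/tailnet"
)

// serveSingleAgentClient condenses the new ServeClient shape: one
// MultiAgentConn per client, exactly one subscription, and the shared
// SendUpdatesToConn pump instead of a per-client TrackedConn.
func serveSingleAgentClient(ctx context.Context, coord agpl.Coordinator, logger slog.Logger, conn net.Conn, id, agentID uuid.UUID) error {
	ma := coord.ServeMultiAgent(id)
	defer ma.Close()

	if err := ma.SubscribeAgent(agentID); err != nil {
		return xerrors.Errorf("subscribe agent: %w", err)
	}

	// Pump agent node updates back out to the client connection.
	go agpl.SendUpdatesToConn(ctx, logger, ma, conn)

	// ...then read client node updates from conn, as in the loop above.
	return nil
}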
@@ -202,12 +198,14 @@ func (c *haCoordinator) initOrSetAgentConnectionSocketLocked(agentID uuid.UUID,
 		c.agentToConnectionSockets[agentID] = connectionSockets
 	}
 	connectionSockets[enq.UniqueID()] = enq
+	c.clientsToAgents[enq.UniqueID()][agentID] = struct{}{}
 }
 
-func (c *haCoordinator) cleanupClientConn(id, agentID uuid.UUID) func() {
-	return func() {
-		c.mutex.Lock()
-		defer c.mutex.Unlock()
+func (c *haCoordinator) clientDisconnected(id uuid.UUID) {
+	c.mutex.Lock()
+	defer c.mutex.Unlock()
+
+	for agentID := range c.clientsToAgents[id] {
 		// Clean all traces of this connection from the map.
 		delete(c.nodes, id)
 		connectionSockets, ok := c.agentToConnectionSockets[agentID]
@@ -220,39 +218,52 @@
 		}
 		delete(c.agentToConnectionSockets, agentID)
 	}
+
+	delete(c.clients, id)
+	delete(c.clientsToAgents, id)
 }
 
-func (c *haCoordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json.Decoder) error {
+func (c *haCoordinator) handleNextClientMessage(id uuid.UUID, decoder *json.Decoder) error {
 	var node agpl.Node
 	err := decoder.Decode(&node)
 	if err != nil {
 		return xerrors.Errorf("read json: %w", err)
 	}
 
-	return c.handleClientUpdate(id, agent, &node)
+	return c.clientNodeUpdate(id, &node)
 }
 
-func (c *haCoordinator) handleClientUpdate(id, agent uuid.UUID, node *agpl.Node) error {
+func (c *haCoordinator) clientNodeUpdate(id uuid.UUID, node *agpl.Node) error {
 	c.mutex.Lock()
+	defer c.mutex.Unlock()
 	// Update the node of this client in our in-memory map. If an agent entirely
 	// shuts down and reconnects, it needs to be aware of all clients attempting
 	// to establish connections.
 	c.nodes[id] = node
 
-	// Write the new node from this client to the actively connected agent.
-	agentSocket, ok := c.agentSockets[agent]
+	for agentID := range c.clientsToAgents[id] {
+		// Write the new node from this client to the actively connected agent.
+		err := c.sendNodeToAgentLocked(agentID, node)
+		if err != nil {
+			c.log.Error(context.Background(), "send node to agent", slog.Error(err), slog.F("agent_id", agentID))
+		}
+	}
+
+	return nil
+}
+
+func (c *haCoordinator) sendNodeToAgentLocked(agentID uuid.UUID, node *agpl.Node) error {
+	agentSocket, ok := c.agentSockets[agentID]
 	if !ok {
-		c.mutex.Unlock()
 		// If we don't own the agent locally, send it over pubsub to a node that
 		// owns the agent.
-		err := c.publishNodesToAgent(agent, []*agpl.Node{node})
+		err := c.publishNodesToAgent(agentID, []*agpl.Node{node})
 		if err != nil {
 			return xerrors.Errorf("publish node to agent")
 		}
 		return nil
 	}
 	err := agentSocket.Enqueue([]*agpl.Node{node})
-	c.mutex.Unlock()
 	if err != nil {
 		return xerrors.Errorf("enqueue node: %w", err)
 	}
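sendNodeToAgentLocked follows the Go convention of suffixing Locked onto helpers that require the caller to already hold the mutex: clientNodeUpdate takes c.mutex once, then fans out through the helper without re-locking, replacing the old pattern of unlocking mid-function before each send. A generic illustration of the convention (hypothetical type, not from this commit):

package main

import (
	"fmt"
	"sync"
)

type counter struct {
	mu sync.Mutex
	n  map[string]int
}

// addLocked requires c.mu to be held by the caller; the Locked suffix
// documents that it must not lock again (Go mutexes are not reentrant).
func (c *counter) addLocked(key string) { c.n[key]++ }

// AddAll is the public entry point: it takes the lock once, then calls
// the Locked helper in a loop, like clientNodeUpdate does with
// sendNodeToAgentLocked.
func (c *counter) AddAll(keys []string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	for _, k := range keys {
		c.addLocked(k)
	}
}

func main() {
	c := &counter{n: map[string]int{}}
	c.AddAll([]string{"a", "a", "b"})
	fmt.Println(c.n["a"], c.n["b"]) // 2 1
}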
@@ -422,7 +433,7 @@ func (c *haCoordinator) Close() error {
 	for _, socket := range c.agentSockets {
 		socket := socket
 		go func() {
-			_ = socket.Close()
+			_ = socket.CoordinatorClose()
 			wg.Done()
 		}()
 	}
@@ -432,12 +443,17 @@
 		for _, socket := range connMap {
 			socket := socket
 			go func() {
-				_ = socket.Close()
+				_ = socket.CoordinatorClose()
 				wg.Done()
 			}()
 		}
 	}
 
+	// Ensure clients that have no subscriptions are properly closed.
+	for _, client := range c.clients {
+		_ = client.CoordinatorClose()
+	}
+
 	wg.Wait()
 	return nil
 }
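A side note on the socket := socket lines visible in Close: that rebinding is the pre-Go-1.22 loop-variable capture idiom. Without it, every goroutine would observe the same iteration variable rather than its own copy (Go 1.22 changed loop-variable scoping, making the rebind unnecessary). A minimal demonstration:

package main

import (
	"fmt"
	"sync"
)

func main() {
	words := []string{"a", "b", "c"}
	var wg sync.WaitGroup

	for _, w := range words {
		w := w // rebind: give each goroutine its own copy (pre-Go 1.22)
		wg.Add(1)
		go func() {
			defer wg.Done()
			fmt.Println(w) // prints a, b, c in some order
		}()
	}
	wg.Wait()
}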
