@@ -14,11 +14,13 @@ import (
14
14
"time"
15
15
16
16
"github.com/google/uuid"
17
+ "go.opentelemetry.io/otel/trace"
17
18
"golang.org/x/xerrors"
18
19
"tailscale.com/derp"
19
20
"tailscale.com/tailcfg"
20
21
21
22
"cdr.dev/slog"
23
+ "github.com/coder/coder/coderd/tracing"
22
24
"github.com/coder/coder/coderd/wsconncache"
23
25
"github.com/coder/coder/codersdk"
24
26
"github.com/coder/coder/site"
@@ -45,6 +47,7 @@ func NewServerTailnet(
45
47
derpMap * tailcfg.DERPMap ,
46
48
getMultiAgent func (context.Context ) (tailnet.MultiAgentConn , error ),
47
49
cache * wsconncache.Cache ,
50
+ traceProvider trace.TracerProvider ,
48
51
) (* ServerTailnet , error ) {
49
52
logger = logger .Named ("servertailnet" )
50
53
conn , err := tailnet .NewConn (& tailnet.Options {
@@ -58,15 +61,16 @@ func NewServerTailnet(
58
61
59
62
serverCtx , cancel := context .WithCancel (ctx )
60
63
tn := & ServerTailnet {
61
- ctx : serverCtx ,
62
- cancel : cancel ,
63
- logger : logger ,
64
- conn : conn ,
65
- getMultiAgent : getMultiAgent ,
66
- cache : cache ,
67
- agentNodes : map [uuid.UUID ]time.Time {},
68
- agentTickets : map [uuid.UUID ]map [uuid.UUID ]struct {}{},
69
- transport : tailnetTransport .Clone (),
64
+ ctx : serverCtx ,
65
+ cancel : cancel ,
66
+ logger : logger ,
67
+ tracer : traceProvider .Tracer (tracing .TracerName ),
68
+ conn : conn ,
69
+ getMultiAgent : getMultiAgent ,
70
+ cache : cache ,
71
+ agentConnectionTimes : map [uuid.UUID ]time.Time {},
72
+ agentTickets : map [uuid.UUID ]map [uuid.UUID ]struct {}{},
73
+ transport : tailnetTransport .Clone (),
70
74
}
71
75
tn .transport .DialContext = tn .dialContext
72
76
tn .transport .MaxIdleConnsPerHost = 10
@@ -139,25 +143,50 @@ func (s *ServerTailnet) expireOldAgents() {
139
143
case <- ticker .C :
140
144
}
141
145
142
- s .nodesMu .Lock ()
143
- agentConn := s .getAgentConn ()
144
- for agentID , lastConnection := range s .agentNodes {
145
- // If no one has connected since the cutoff and there are no active
146
- // connections, remove the agent.
147
- if time .Since (lastConnection ) > cutoff && len (s .agentTickets [agentID ]) == 0 {
148
- _ = agentConn
149
- // err := agentConn.UnsubscribeAgent(agentID)
150
- // if err != nil {
151
- // s.logger.Error(s.ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
152
- // }
153
- // delete(s.agentNodes, agentID)
154
-
155
- // TODO(coadler): actually remove from the netmap, then reenable
156
- // the above
146
+ s .doExpireOldAgents (cutoff )
147
+ }
148
+ }
149
+
150
+ func (s * ServerTailnet ) doExpireOldAgents (cutoff time.Duration ) {
151
+ // TODO: add some attrs to this.
152
+ ctx , span := s .tracer .Start (s .ctx , tracing .FuncName ())
153
+ defer span .End ()
154
+
155
+ start := time .Now ()
156
+ deletedCount := 0
157
+
158
+ s .nodesMu .Lock ()
159
+ s .logger .Debug (ctx , "pruning inactive agents" , slog .F ("agent_count" , len (s .agentConnectionTimes )))
160
+ agentConn := s .getAgentConn ()
161
+ for agentID , lastConnection := range s .agentConnectionTimes {
162
+ // If no one has connected since the cutoff and there are no active
163
+ // connections, remove the agent.
164
+ if time .Since (lastConnection ) > cutoff && len (s .agentTickets [agentID ]) == 0 {
165
+ deleted , err := s .conn .RemovePeer (tailnet.PeerSelector {
166
+ ID : tailnet .NodeID (agentID ),
167
+ IP : netip .PrefixFrom (tailnet .IPFromUUID (agentID ), 128 ),
168
+ })
169
+ if err != nil {
170
+ s .logger .Warn (ctx , "failed to remove peer from server tailnet" , slog .Error (err ))
171
+ continue
172
+ }
173
+ if ! deleted {
174
+ s .logger .Warn (ctx , "peer didn't exist in tailnet" , slog .Error (err ))
175
+ }
176
+
177
+ deletedCount ++
178
+ delete (s .agentConnectionTimes , agentID )
179
+ err = agentConn .UnsubscribeAgent (agentID )
180
+ if err != nil {
181
+ s .logger .Error (ctx , "unsubscribe expired agent" , slog .Error (err ), slog .F ("agent_id" , agentID ))
157
182
}
158
183
}
159
- s .nodesMu .Unlock ()
160
184
}
185
+ s .nodesMu .Unlock ()
186
+ s .logger .Debug (s .ctx , "successfully pruned inactive agents" ,
187
+ slog .F ("deleted" , deletedCount ),
188
+ slog .F ("took" , time .Since (start )),
189
+ )
161
190
}
162
191
163
192
func (s * ServerTailnet ) watchAgentUpdates () {
@@ -196,7 +225,7 @@ func (s *ServerTailnet) reinitCoordinator() {
196
225
s .agentConn .Store (& agentConn )
197
226
198
227
// Resubscribe to all of the agents we're tracking.
199
- for agentID := range s .agentNodes {
228
+ for agentID := range s .agentConnectionTimes {
200
229
err := agentConn .SubscribeAgent (agentID )
201
230
if err != nil {
202
231
s .logger .Warn (s .ctx , "resubscribe to agent" , slog .Error (err ), slog .F ("agent_id" , agentID ))
@@ -212,14 +241,16 @@ type ServerTailnet struct {
212
241
cancel func ()
213
242
214
243
logger slog.Logger
244
+ tracer trace.Tracer
215
245
conn * tailnet.Conn
216
246
getMultiAgent func (context.Context ) (tailnet.MultiAgentConn , error )
217
247
agentConn atomic.Pointer [tailnet.MultiAgentConn ]
218
248
cache * wsconncache.Cache
219
249
nodesMu sync.Mutex
220
- // agentNodes is a map of agent tailnetNodes the server wants to keep a
221
- // connection to. It contains the last time the agent was connected to.
222
- agentNodes map [uuid.UUID ]time.Time
250
+ // agentConnectionTimes is a map of agent tailnetNodes the server wants to
251
+ // keep a connection to. It contains the last time the agent was connected
252
+ // to.
253
+ agentConnectionTimes map [uuid.UUID ]time.Time
223
254
// agentTickets holds a map of all open connections to an agent.
224
255
agentTickets map [uuid.UUID ]map [uuid.UUID ]struct {}
225
256
@@ -268,7 +299,7 @@ func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error {
268
299
s .nodesMu .Lock ()
269
300
defer s .nodesMu .Unlock ()
270
301
271
- _ , ok := s .agentNodes [agentID ]
302
+ _ , ok := s .agentConnectionTimes [agentID ]
272
303
// If we don't have the node, subscribe.
273
304
if ! ok {
274
305
s .logger .Debug (s .ctx , "subscribing to agent" , slog .F ("agent_id" , agentID ))
@@ -279,14 +310,27 @@ func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error {
279
310
s .agentTickets [agentID ] = map [uuid.UUID ]struct {}{}
280
311
}
281
312
282
- s .agentNodes [agentID ] = time .Now ()
313
+ s .agentConnectionTimes [agentID ] = time .Now ()
283
314
return nil
284
315
}
285
316
317
+ func (s * ServerTailnet ) acquireTicket (agentID uuid.UUID ) (release func ()) {
318
+ id := uuid .New ()
319
+ s .nodesMu .Lock ()
320
+ s.agentTickets [agentID ][id ] = struct {}{}
321
+ s .nodesMu .Unlock ()
322
+
323
+ return func () {
324
+ s .nodesMu .Lock ()
325
+ delete (s .agentTickets [agentID ], id )
326
+ s .nodesMu .Unlock ()
327
+ }
328
+ }
329
+
286
330
func (s * ServerTailnet ) AgentConn (ctx context.Context , agentID uuid.UUID ) (* codersdk.WorkspaceAgentConn , func (), error ) {
287
331
var (
288
332
conn * codersdk.WorkspaceAgentConn
289
- ret = func () {}
333
+ ret func ()
290
334
)
291
335
292
336
if s .getAgentConn ().AgentIsLegacy (agentID ) {
@@ -299,12 +343,13 @@ func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*code
299
343
conn = cconn .WorkspaceAgentConn
300
344
ret = release
301
345
} else {
346
+ s .logger .Debug (s .ctx , "acquiring agent" , slog .F ("agent_id" , agentID ))
302
347
err := s .ensureAgent (agentID )
303
348
if err != nil {
304
349
return nil , nil , xerrors .Errorf ("ensure agent: %w" , err )
305
350
}
351
+ ret = s .acquireTicket (agentID )
306
352
307
- s .logger .Debug (s .ctx , "acquiring agent" , slog .F ("agent_id" , agentID ))
308
353
conn = codersdk .NewWorkspaceAgentConn (s .conn , codersdk.WorkspaceAgentConnOptions {
309
354
AgentID : agentID ,
310
355
CloseFunc : func () error { return codersdk .ErrSkipClose },
@@ -317,7 +362,6 @@ func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*code
317
362
reachable := conn .AwaitReachable (ctx )
318
363
if ! reachable {
319
364
ret ()
320
- conn .Close ()
321
365
return nil , nil , xerrors .New ("agent is unreachable" )
322
366
}
323
367
@@ -336,13 +380,11 @@ func (s *ServerTailnet) DialAgentNetConn(ctx context.Context, agentID uuid.UUID,
336
380
nc , err := conn .DialContext (ctx , network , addr )
337
381
if err != nil {
338
382
release ()
339
- conn .Close ()
340
383
return nil , xerrors .Errorf ("dial context: %w" , err )
341
384
}
342
385
343
386
return & netConnCloser {Conn : nc , close : func () {
344
387
release ()
345
- conn .Close ()
346
388
}}, err
347
389
}
348
390
0 commit comments