Skip to content

Commit ec23976

Browse files
committed
feat(coderd): expire agents from server tailnet
We don't want the list of peers in the server tailnet to grow without bound, so we expire agents that have had no connections for some time.
1 parent 320de18 commit ec23976

File tree

5 files changed

+161
-42
lines changed

5 files changed

+161
-42
lines changed

coderd/coderd.go

+1
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,7 @@ func New(options *Options) *API {
407407
return (*api.TailnetCoordinator.Load()).ServeMultiAgent(uuid.New()), nil
408408
},
409409
wsconncache.New(api._dialWorkspaceAgentTailnet, 0),
410+
api.TracerProvider,
410411
)
411412
if err != nil {
412413
panic("failed to setup server tailnet: " + err.Error())

coderd/tailnet.go

+78-36
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@ import (
1414
"time"
1515

1616
"github.com/google/uuid"
17+
"go.opentelemetry.io/otel/trace"
1718
"golang.org/x/xerrors"
1819
"tailscale.com/derp"
1920
"tailscale.com/tailcfg"
2021

2122
"cdr.dev/slog"
23+
"github.com/coder/coder/coderd/tracing"
2224
"github.com/coder/coder/coderd/wsconncache"
2325
"github.com/coder/coder/codersdk"
2426
"github.com/coder/coder/site"
@@ -45,6 +47,7 @@ func NewServerTailnet(
4547
derpMap *tailcfg.DERPMap,
4648
getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error),
4749
cache *wsconncache.Cache,
50+
traceProvider trace.TracerProvider,
4851
) (*ServerTailnet, error) {
4952
logger = logger.Named("servertailnet")
5053
conn, err := tailnet.NewConn(&tailnet.Options{
@@ -58,15 +61,16 @@ func NewServerTailnet(
5861

5962
serverCtx, cancel := context.WithCancel(ctx)
6063
tn := &ServerTailnet{
61-
ctx: serverCtx,
62-
cancel: cancel,
63-
logger: logger,
64-
conn: conn,
65-
getMultiAgent: getMultiAgent,
66-
cache: cache,
67-
agentNodes: map[uuid.UUID]time.Time{},
68-
agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{},
69-
transport: tailnetTransport.Clone(),
64+
ctx: serverCtx,
65+
cancel: cancel,
66+
logger: logger,
67+
tracer: traceProvider.Tracer(tracing.TracerName),
68+
conn: conn,
69+
getMultiAgent: getMultiAgent,
70+
cache: cache,
71+
agentConnectionTimes: map[uuid.UUID]time.Time{},
72+
agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{},
73+
transport: tailnetTransport.Clone(),
7074
}
7175
tn.transport.DialContext = tn.dialContext
7276
tn.transport.MaxIdleConnsPerHost = 10
@@ -139,25 +143,50 @@ func (s *ServerTailnet) expireOldAgents() {
139143
case <-ticker.C:
140144
}
141145

142-
s.nodesMu.Lock()
143-
agentConn := s.getAgentConn()
144-
for agentID, lastConnection := range s.agentNodes {
145-
// If no one has connected since the cutoff and there are no active
146-
// connections, remove the agent.
147-
if time.Since(lastConnection) > cutoff && len(s.agentTickets[agentID]) == 0 {
148-
_ = agentConn
149-
// err := agentConn.UnsubscribeAgent(agentID)
150-
// if err != nil {
151-
// s.logger.Error(s.ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
152-
// }
153-
// delete(s.agentNodes, agentID)
154-
155-
// TODO(coadler): actually remove from the netmap, then reenable
156-
// the above
146+
s.doExpireOldAgents(cutoff)
147+
}
148+
}
149+
150+
func (s *ServerTailnet) doExpireOldAgents(cutoff time.Duration) {
151+
// TODO: add some attrs to this.
152+
ctx, span := s.tracer.Start(s.ctx, tracing.FuncName())
153+
defer span.End()
154+
155+
start := time.Now()
156+
deletedCount := 0
157+
158+
s.nodesMu.Lock()
159+
s.logger.Debug(ctx, "pruning inactive agents", slog.F("agent_count", len(s.agentConnectionTimes)))
160+
agentConn := s.getAgentConn()
161+
for agentID, lastConnection := range s.agentConnectionTimes {
162+
// If no one has connected since the cutoff and there are no active
163+
// connections, remove the agent.
164+
if time.Since(lastConnection) > cutoff && len(s.agentTickets[agentID]) == 0 {
165+
deleted, err := s.conn.RemovePeer(tailnet.PeerSelector{
166+
ID: tailnet.NodeID(agentID),
167+
IP: netip.PrefixFrom(tailnet.IPFromUUID(agentID), 128),
168+
})
169+
if err != nil {
170+
s.logger.Warn(ctx, "failed to remove peer from server tailnet", slog.Error(err))
171+
continue
172+
}
173+
if !deleted {
174+
s.logger.Warn(ctx, "peer didn't exist in tailnet", slog.Error(err))
175+
}
176+
177+
deletedCount++
178+
delete(s.agentConnectionTimes, agentID)
179+
err = agentConn.UnsubscribeAgent(agentID)
180+
if err != nil {
181+
s.logger.Error(ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
157182
}
158183
}
159-
s.nodesMu.Unlock()
160184
}
185+
s.nodesMu.Unlock()
186+
s.logger.Debug(s.ctx, "successfully pruned inactive agents",
187+
slog.F("deleted", deletedCount),
188+
slog.F("took", time.Since(start)),
189+
)
161190
}
162191

163192
func (s *ServerTailnet) watchAgentUpdates() {
@@ -196,7 +225,7 @@ func (s *ServerTailnet) reinitCoordinator() {
196225
s.agentConn.Store(&agentConn)
197226

198227
// Resubscribe to all of the agents we're tracking.
199-
for agentID := range s.agentNodes {
228+
for agentID := range s.agentConnectionTimes {
200229
err := agentConn.SubscribeAgent(agentID)
201230
if err != nil {
202231
s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(err), slog.F("agent_id", agentID))
@@ -212,14 +241,16 @@ type ServerTailnet struct {
212241
cancel func()
213242

214243
logger slog.Logger
244+
tracer trace.Tracer
215245
conn *tailnet.Conn
216246
getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error)
217247
agentConn atomic.Pointer[tailnet.MultiAgentConn]
218248
cache *wsconncache.Cache
219249
nodesMu sync.Mutex
220-
// agentNodes is a map of agent tailnetNodes the server wants to keep a
221-
// connection to. It contains the last time the agent was connected to.
222-
agentNodes map[uuid.UUID]time.Time
250+
// agentConnectionTimes is a map of agent tailnetNodes the server wants to
251+
// keep a connection to. It contains the last time the agent was connected
252+
// to.
253+
agentConnectionTimes map[uuid.UUID]time.Time
223254
// agentTickets holds a map of all open connections to an agent.
224255
agentTickets map[uuid.UUID]map[uuid.UUID]struct{}
225256

@@ -268,7 +299,7 @@ func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error {
268299
s.nodesMu.Lock()
269300
defer s.nodesMu.Unlock()
270301

271-
_, ok := s.agentNodes[agentID]
302+
_, ok := s.agentConnectionTimes[agentID]
272303
// If we don't have the node, subscribe.
273304
if !ok {
274305
s.logger.Debug(s.ctx, "subscribing to agent", slog.F("agent_id", agentID))
@@ -279,14 +310,27 @@ func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error {
279310
s.agentTickets[agentID] = map[uuid.UUID]struct{}{}
280311
}
281312

282-
s.agentNodes[agentID] = time.Now()
313+
s.agentConnectionTimes[agentID] = time.Now()
283314
return nil
284315
}
285316

317+
func (s *ServerTailnet) acquireTicket(agentID uuid.UUID) (release func()) {
318+
id := uuid.New()
319+
s.nodesMu.Lock()
320+
s.agentTickets[agentID][id] = struct{}{}
321+
s.nodesMu.Unlock()
322+
323+
return func() {
324+
s.nodesMu.Lock()
325+
delete(s.agentTickets[agentID], id)
326+
s.nodesMu.Unlock()
327+
}
328+
}
329+
286330
func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, func(), error) {
287331
var (
288332
conn *codersdk.WorkspaceAgentConn
289-
ret = func() {}
333+
ret func()
290334
)
291335

292336
if s.getAgentConn().AgentIsLegacy(agentID) {
@@ -299,12 +343,13 @@ func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*code
299343
conn = cconn.WorkspaceAgentConn
300344
ret = release
301345
} else {
346+
s.logger.Debug(s.ctx, "acquiring agent", slog.F("agent_id", agentID))
302347
err := s.ensureAgent(agentID)
303348
if err != nil {
304349
return nil, nil, xerrors.Errorf("ensure agent: %w", err)
305350
}
351+
ret = s.acquireTicket(agentID)
306352

307-
s.logger.Debug(s.ctx, "acquiring agent", slog.F("agent_id", agentID))
308353
conn = codersdk.NewWorkspaceAgentConn(s.conn, codersdk.WorkspaceAgentConnOptions{
309354
AgentID: agentID,
310355
CloseFunc: func() error { return codersdk.ErrSkipClose },
@@ -317,7 +362,6 @@ func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*code
317362
reachable := conn.AwaitReachable(ctx)
318363
if !reachable {
319364
ret()
320-
conn.Close()
321365
return nil, nil, xerrors.New("agent is unreachable")
322366
}
323367

@@ -336,13 +380,11 @@ func (s *ServerTailnet) DialAgentNetConn(ctx context.Context, agentID uuid.UUID,
336380
nc, err := conn.DialContext(ctx, network, addr)
337381
if err != nil {
338382
release()
339-
conn.Close()
340383
return nil, xerrors.Errorf("dial context: %w", err)
341384
}
342385

343386
return &netConnCloser{Conn: nc, close: func() {
344387
release()
345-
conn.Close()
346388
}}, err
347389
}
348390

coderd/tailnet_test.go

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"github.com/spf13/afero"
1515
"github.com/stretchr/testify/assert"
1616
"github.com/stretchr/testify/require"
17+
"go.opentelemetry.io/otel/trace"
1718

1819
"cdr.dev/slog"
1920
"cdr.dev/slog/sloggers/slogtest"
@@ -232,6 +233,7 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A
232233
manifest.DERPMap,
233234
func(context.Context) (tailnet.MultiAgentConn, error) { return coord.ServeMultiAgent(uuid.New()), nil },
234235
cache,
236+
trace.NewNoopTracerProvider(),
235237
)
236238
require.NoError(t, err)
237239

enterprise/wsproxy/wsproxy.go

+1
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ func New(ctx context.Context, opts *Options) (*Server, error) {
250250
connInfo.DERPMap,
251251
s.DialCoordinator,
252252
wsconncache.New(s.DialWorkspaceAgent, 0),
253+
s.TracerProvider,
253254
)
254255
if err != nil {
255256
return nil, xerrors.Errorf("create server tailnet: %w", err)

tailnet/conn.go

+79-6
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package tailnet
22

33
import (
44
"context"
5+
"encoding/binary"
56
"errors"
67
"fmt"
78
"net"
@@ -19,7 +20,6 @@ import (
1920
"golang.org/x/xerrors"
2021
"gvisor.dev/gvisor/pkg/tcpip"
2122
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
22-
"tailscale.com/hostinfo"
2323
"tailscale.com/ipn/ipnstate"
2424
"tailscale.com/net/connstats"
2525
"tailscale.com/net/dns"
@@ -67,6 +67,7 @@ func init() {
6767
}
6868

6969
type Options struct {
70+
ID uuid.UUID
7071
Addresses []netip.Prefix
7172
DERPMap *tailcfg.DERPMap
7273
DERPHeader *http.Header
@@ -78,6 +79,18 @@ type Options struct {
7879
ListenPort uint16
7980
}
8081

82+
// NodeID creates a Tailscale NodeID from the last 8 bytes of a UUID. It ensures
83+
// the returned NodeID is always positive.
84+
func NodeID(uid uuid.UUID) tailcfg.NodeID {
85+
id := int64(binary.BigEndian.Uint64(uid[8:]))
86+
87+
// ensure id is positive
88+
y := id >> 63
89+
id = (id ^ y) - y
90+
91+
return tailcfg.NodeID(id)
92+
}
93+
8194
// NewConn constructs a new Wireguard server that will accept connections from the addresses provided.
8295
func NewConn(options *Options) (conn *Conn, err error) {
8396
if options == nil {
@@ -126,13 +139,23 @@ func NewConn(options *Options) (conn *Conn, err error) {
126139
Caps: []filter.CapMatch{},
127140
}},
128141
}
129-
nodeID, err := cryptorand.Int63()
130-
if err != nil {
131-
return nil, xerrors.Errorf("generate node id: %w", err)
142+
143+
var nodeID tailcfg.NodeID
144+
145+
// If we're provided with a UUID, use it to populate our node ID.
146+
if options.ID != uuid.Nil {
147+
nodeID = NodeID(options.ID)
148+
} else {
149+
uid, err := cryptorand.Int63()
150+
if err != nil {
151+
return nil, xerrors.Errorf("generate node id: %w", err)
152+
}
153+
nodeID = tailcfg.NodeID(uid)
132154
}
155+
133156
// This is used by functions below to identify the node via key
134157
netMap.SelfNode = &tailcfg.Node{
135-
ID: tailcfg.NodeID(nodeID),
158+
ID: nodeID,
136159
Key: nodePublicKey,
137160
Addresses: options.Addresses,
138161
AllowedIPs: options.Addresses,
@@ -488,7 +511,7 @@ func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error {
488511
AllowedIPs: node.AllowedIPs,
489512
Endpoints: node.Endpoints,
490513
DERP: fmt.Sprintf("%s:%d", tailcfg.DerpMagicIP, node.PreferredDERP),
491-
Hostinfo: hostinfo.New().View(),
514+
Hostinfo: (&tailcfg.Hostinfo{}).View(),
492515
}
493516
if c.blockEndpoints {
494517
peerNode.Endpoints = nil
@@ -512,6 +535,56 @@ func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error {
512535
return nil
513536
}
514537

538+
// PeerSelector is used to select a peer from within a Tailnet.
539+
type PeerSelector struct {
540+
ID tailcfg.NodeID
541+
IP netip.Prefix
542+
}
543+
544+
func (c *Conn) RemovePeer(selector PeerSelector) (deleted bool, err error) {
545+
c.mutex.Lock()
546+
defer c.mutex.Unlock()
547+
548+
if c.isClosed() {
549+
return false, xerrors.New("connection closed")
550+
}
551+
552+
deleted = false
553+
for _, peer := range c.peerMap {
554+
if peer.ID == selector.ID {
555+
delete(c.peerMap, peer.ID)
556+
deleted = true
557+
break
558+
}
559+
560+
for _, peerIP := range peer.Addresses {
561+
if peerIP.Bits() == selector.IP.Bits() && peerIP.Addr().Compare(selector.IP.Addr()) == 0 {
562+
delete(c.peerMap, peer.ID)
563+
deleted = true
564+
break
565+
}
566+
}
567+
}
568+
if !deleted {
569+
return false, nil
570+
}
571+
572+
c.netMap.Peers = make([]*tailcfg.Node, 0, len(c.peerMap))
573+
for _, peer := range c.peerMap {
574+
c.netMap.Peers = append(c.netMap.Peers, peer.Clone())
575+
}
576+
577+
netMapCopy := *c.netMap
578+
c.logger.Debug(context.Background(), "updating network map")
579+
c.wireguardEngine.SetNetworkMap(&netMapCopy)
580+
err = c.reconfig()
581+
if err != nil {
582+
return false, xerrors.Errorf("reconfig: %w", err)
583+
}
584+
585+
return true, nil
586+
}
587+
515588
func (c *Conn) reconfig() error {
516589
cfg, err := nmcfg.WGCfg(c.netMap, Logger(c.logger.Named("net.wgconfig")), netmap.AllowSingleHosts, "")
517590
if err != nil {

0 commit comments

Comments
 (0)