Skip to content

Commit 344d32b

Browse files
authored
feat(coderd): expire agents from server tailnet (#9092)
1 parent a08f7b8 commit 344d32b

File tree

6 files changed

+162
-42
lines changed

6 files changed

+162
-42
lines changed

agent/agent.go

+1
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,7 @@ func (a *agent) trackConnGoroutine(fn func()) error {
758758

759759
func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *tailcfg.DERPMap, disableDirectConnections bool) (_ *tailnet.Conn, err error) {
760760
network, err := tailnet.NewConn(&tailnet.Options{
761+
ID: agentID,
761762
Addresses: a.wireguardAddresses(agentID),
762763
DERPMap: derpMap,
763764
Logger: a.logger.Named("net.tailnet"),

coderd/coderd.go

+1
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,7 @@ func New(options *Options) *API {
407407
return (*api.TailnetCoordinator.Load()).ServeMultiAgent(uuid.New()), nil
408408
},
409409
wsconncache.New(api._dialWorkspaceAgentTailnet, 0),
410+
api.TracerProvider,
410411
)
411412
if err != nil {
412413
panic("failed to setup server tailnet: " + err.Error())

coderd/tailnet.go

+78-36
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@ import (
1414
"time"
1515

1616
"github.com/google/uuid"
17+
"go.opentelemetry.io/otel/trace"
1718
"golang.org/x/xerrors"
1819
"tailscale.com/derp"
1920
"tailscale.com/tailcfg"
2021

2122
"cdr.dev/slog"
23+
"github.com/coder/coder/coderd/tracing"
2224
"github.com/coder/coder/coderd/wsconncache"
2325
"github.com/coder/coder/codersdk"
2426
"github.com/coder/coder/site"
@@ -45,6 +47,7 @@ func NewServerTailnet(
4547
derpMap *tailcfg.DERPMap,
4648
getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error),
4749
cache *wsconncache.Cache,
50+
traceProvider trace.TracerProvider,
4851
) (*ServerTailnet, error) {
4952
logger = logger.Named("servertailnet")
5053
conn, err := tailnet.NewConn(&tailnet.Options{
@@ -58,15 +61,16 @@ func NewServerTailnet(
5861

5962
serverCtx, cancel := context.WithCancel(ctx)
6063
tn := &ServerTailnet{
61-
ctx: serverCtx,
62-
cancel: cancel,
63-
logger: logger,
64-
conn: conn,
65-
getMultiAgent: getMultiAgent,
66-
cache: cache,
67-
agentNodes: map[uuid.UUID]time.Time{},
68-
agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{},
69-
transport: tailnetTransport.Clone(),
64+
ctx: serverCtx,
65+
cancel: cancel,
66+
logger: logger,
67+
tracer: traceProvider.Tracer(tracing.TracerName),
68+
conn: conn,
69+
getMultiAgent: getMultiAgent,
70+
cache: cache,
71+
agentConnectionTimes: map[uuid.UUID]time.Time{},
72+
agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{},
73+
transport: tailnetTransport.Clone(),
7074
}
7175
tn.transport.DialContext = tn.dialContext
7276
tn.transport.MaxIdleConnsPerHost = 10
@@ -139,25 +143,50 @@ func (s *ServerTailnet) expireOldAgents() {
139143
case <-ticker.C:
140144
}
141145

142-
s.nodesMu.Lock()
143-
agentConn := s.getAgentConn()
144-
for agentID, lastConnection := range s.agentNodes {
145-
// If no one has connected since the cutoff and there are no active
146-
// connections, remove the agent.
147-
if time.Since(lastConnection) > cutoff && len(s.agentTickets[agentID]) == 0 {
148-
_ = agentConn
149-
// err := agentConn.UnsubscribeAgent(agentID)
150-
// if err != nil {
151-
// s.logger.Error(s.ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
152-
// }
153-
// delete(s.agentNodes, agentID)
154-
155-
// TODO(coadler): actually remove from the netmap, then reenable
156-
// the above
146+
s.doExpireOldAgents(cutoff)
147+
}
148+
}
149+
150+
func (s *ServerTailnet) doExpireOldAgents(cutoff time.Duration) {
151+
// TODO: add some attrs to this.
152+
ctx, span := s.tracer.Start(s.ctx, tracing.FuncName())
153+
defer span.End()
154+
155+
start := time.Now()
156+
deletedCount := 0
157+
158+
s.nodesMu.Lock()
159+
s.logger.Debug(ctx, "pruning inactive agents", slog.F("agent_count", len(s.agentConnectionTimes)))
160+
agentConn := s.getAgentConn()
161+
for agentID, lastConnection := range s.agentConnectionTimes {
162+
// If no one has connected since the cutoff and there are no active
163+
// connections, remove the agent.
164+
if time.Since(lastConnection) > cutoff && len(s.agentTickets[agentID]) == 0 {
165+
deleted, err := s.conn.RemovePeer(tailnet.PeerSelector{
166+
ID: tailnet.NodeID(agentID),
167+
IP: netip.PrefixFrom(tailnet.IPFromUUID(agentID), 128),
168+
})
169+
if err != nil {
170+
s.logger.Warn(ctx, "failed to remove peer from server tailnet", slog.Error(err))
171+
continue
172+
}
173+
if !deleted {
174+
s.logger.Warn(ctx, "peer didn't exist in tailnet", slog.Error(err))
175+
}
176+
177+
deletedCount++
178+
delete(s.agentConnectionTimes, agentID)
179+
err = agentConn.UnsubscribeAgent(agentID)
180+
if err != nil {
181+
s.logger.Error(ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
157182
}
158183
}
159-
s.nodesMu.Unlock()
160184
}
185+
s.nodesMu.Unlock()
186+
s.logger.Debug(s.ctx, "successfully pruned inactive agents",
187+
slog.F("deleted", deletedCount),
188+
slog.F("took", time.Since(start)),
189+
)
161190
}
162191

163192
func (s *ServerTailnet) watchAgentUpdates() {
@@ -196,7 +225,7 @@ func (s *ServerTailnet) reinitCoordinator() {
196225
s.agentConn.Store(&agentConn)
197226

198227
// Resubscribe to all of the agents we're tracking.
199-
for agentID := range s.agentNodes {
228+
for agentID := range s.agentConnectionTimes {
200229
err := agentConn.SubscribeAgent(agentID)
201230
if err != nil {
202231
s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(err), slog.F("agent_id", agentID))
@@ -212,14 +241,16 @@ type ServerTailnet struct {
212241
cancel func()
213242

214243
logger slog.Logger
244+
tracer trace.Tracer
215245
conn *tailnet.Conn
216246
getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error)
217247
agentConn atomic.Pointer[tailnet.MultiAgentConn]
218248
cache *wsconncache.Cache
219249
nodesMu sync.Mutex
220-
// agentNodes is a map of agent tailnetNodes the server wants to keep a
221-
// connection to. It contains the last time the agent was connected to.
222-
agentNodes map[uuid.UUID]time.Time
250+
// agentConnectionTimes is a map of agent tailnetNodes the server wants to
251+
// keep a connection to. It contains the last time the agent was connected
252+
// to.
253+
agentConnectionTimes map[uuid.UUID]time.Time
223254
// agentTockets holds a map of all open connections to an agent.
224255
agentTickets map[uuid.UUID]map[uuid.UUID]struct{}
225256

@@ -268,7 +299,7 @@ func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error {
268299
s.nodesMu.Lock()
269300
defer s.nodesMu.Unlock()
270301

271-
_, ok := s.agentNodes[agentID]
302+
_, ok := s.agentConnectionTimes[agentID]
272303
// If we don't have the node, subscribe.
273304
if !ok {
274305
s.logger.Debug(s.ctx, "subscribing to agent", slog.F("agent_id", agentID))
@@ -279,14 +310,27 @@ func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error {
279310
s.agentTickets[agentID] = map[uuid.UUID]struct{}{}
280311
}
281312

282-
s.agentNodes[agentID] = time.Now()
313+
s.agentConnectionTimes[agentID] = time.Now()
283314
return nil
284315
}
285316

317+
func (s *ServerTailnet) acquireTicket(agentID uuid.UUID) (release func()) {
318+
id := uuid.New()
319+
s.nodesMu.Lock()
320+
s.agentTickets[agentID][id] = struct{}{}
321+
s.nodesMu.Unlock()
322+
323+
return func() {
324+
s.nodesMu.Lock()
325+
delete(s.agentTickets[agentID], id)
326+
s.nodesMu.Unlock()
327+
}
328+
}
329+
286330
func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, func(), error) {
287331
var (
288332
conn *codersdk.WorkspaceAgentConn
289-
ret = func() {}
333+
ret func()
290334
)
291335

292336
if s.getAgentConn().AgentIsLegacy(agentID) {
@@ -299,12 +343,13 @@ func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*code
299343
conn = cconn.WorkspaceAgentConn
300344
ret = release
301345
} else {
346+
s.logger.Debug(s.ctx, "acquiring agent", slog.F("agent_id", agentID))
302347
err := s.ensureAgent(agentID)
303348
if err != nil {
304349
return nil, nil, xerrors.Errorf("ensure agent: %w", err)
305350
}
351+
ret = s.acquireTicket(agentID)
306352

307-
s.logger.Debug(s.ctx, "acquiring agent", slog.F("agent_id", agentID))
308353
conn = codersdk.NewWorkspaceAgentConn(s.conn, codersdk.WorkspaceAgentConnOptions{
309354
AgentID: agentID,
310355
CloseFunc: func() error { return codersdk.ErrSkipClose },
@@ -317,7 +362,6 @@ func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*code
317362
reachable := conn.AwaitReachable(ctx)
318363
if !reachable {
319364
ret()
320-
conn.Close()
321365
return nil, nil, xerrors.New("agent is unreachable")
322366
}
323367

@@ -336,13 +380,11 @@ func (s *ServerTailnet) DialAgentNetConn(ctx context.Context, agentID uuid.UUID,
336380
nc, err := conn.DialContext(ctx, network, addr)
337381
if err != nil {
338382
release()
339-
conn.Close()
340383
return nil, xerrors.Errorf("dial context: %w", err)
341384
}
342385

343386
return &netConnCloser{Conn: nc, close: func() {
344387
release()
345-
conn.Close()
346388
}}, err
347389
}
348390

coderd/tailnet_test.go

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"github.com/spf13/afero"
1515
"github.com/stretchr/testify/assert"
1616
"github.com/stretchr/testify/require"
17+
"go.opentelemetry.io/otel/trace"
1718

1819
"cdr.dev/slog"
1920
"cdr.dev/slog/sloggers/slogtest"
@@ -232,6 +233,7 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A
232233
manifest.DERPMap,
233234
func(context.Context) (tailnet.MultiAgentConn, error) { return coord.ServeMultiAgent(uuid.New()), nil },
234235
cache,
236+
trace.NewNoopTracerProvider(),
235237
)
236238
require.NoError(t, err)
237239

enterprise/wsproxy/wsproxy.go

+1
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ func New(ctx context.Context, opts *Options) (*Server, error) {
250250
connInfo.DERPMap,
251251
s.DialCoordinator,
252252
wsconncache.New(s.DialWorkspaceAgent, 0),
253+
s.TracerProvider,
253254
)
254255
if err != nil {
255256
return nil, xerrors.Errorf("create server tailnet: %w", err)

tailnet/conn.go

+79-6
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package tailnet
22

33
import (
44
"context"
5+
"encoding/binary"
56
"errors"
67
"fmt"
78
"net"
@@ -19,7 +20,6 @@ import (
1920
"golang.org/x/xerrors"
2021
"gvisor.dev/gvisor/pkg/tcpip"
2122
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
22-
"tailscale.com/hostinfo"
2323
"tailscale.com/ipn/ipnstate"
2424
"tailscale.com/net/connstats"
2525
"tailscale.com/net/dns"
@@ -67,6 +67,7 @@ func init() {
6767
}
6868

6969
type Options struct {
70+
ID uuid.UUID
7071
Addresses []netip.Prefix
7172
DERPMap *tailcfg.DERPMap
7273
DERPHeader *http.Header
@@ -78,6 +79,18 @@ type Options struct {
7879
ListenPort uint16
7980
}
8081

82+
// NodeID creates a Tailscale NodeID from the last 8 bytes of a UUID. It ensures
83+
// the returned NodeID is always positive.
84+
func NodeID(uid uuid.UUID) tailcfg.NodeID {
85+
id := int64(binary.BigEndian.Uint64(uid[8:]))
86+
87+
// ensure id is positive
88+
y := id >> 63
89+
id = (id ^ y) - y
90+
91+
return tailcfg.NodeID(id)
92+
}
93+
8194
// NewConn constructs a new Wireguard server that will accept connections from the addresses provided.
8295
func NewConn(options *Options) (conn *Conn, err error) {
8396
if options == nil {
@@ -126,13 +139,23 @@ func NewConn(options *Options) (conn *Conn, err error) {
126139
Caps: []filter.CapMatch{},
127140
}},
128141
}
129-
nodeID, err := cryptorand.Int63()
130-
if err != nil {
131-
return nil, xerrors.Errorf("generate node id: %w", err)
142+
143+
var nodeID tailcfg.NodeID
144+
145+
// If we're provided with a UUID, use it to populate our node ID.
146+
if options.ID != uuid.Nil {
147+
nodeID = NodeID(options.ID)
148+
} else {
149+
uid, err := cryptorand.Int63()
150+
if err != nil {
151+
return nil, xerrors.Errorf("generate node id: %w", err)
152+
}
153+
nodeID = tailcfg.NodeID(uid)
132154
}
155+
133156
// This is used by functions below to identify the node via key
134157
netMap.SelfNode = &tailcfg.Node{
135-
ID: tailcfg.NodeID(nodeID),
158+
ID: nodeID,
136159
Key: nodePublicKey,
137160
Addresses: options.Addresses,
138161
AllowedIPs: options.Addresses,
@@ -488,7 +511,7 @@ func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error {
488511
AllowedIPs: node.AllowedIPs,
489512
Endpoints: node.Endpoints,
490513
DERP: fmt.Sprintf("%s:%d", tailcfg.DerpMagicIP, node.PreferredDERP),
491-
Hostinfo: hostinfo.New().View(),
514+
Hostinfo: (&tailcfg.Hostinfo{}).View(),
492515
}
493516
if c.blockEndpoints {
494517
peerNode.Endpoints = nil
@@ -512,6 +535,56 @@ func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error {
512535
return nil
513536
}
514537

538+
// PeerSelector is used to select a peer from within a Tailnet.
539+
type PeerSelector struct {
540+
ID tailcfg.NodeID
541+
IP netip.Prefix
542+
}
543+
544+
func (c *Conn) RemovePeer(selector PeerSelector) (deleted bool, err error) {
545+
c.mutex.Lock()
546+
defer c.mutex.Unlock()
547+
548+
if c.isClosed() {
549+
return false, xerrors.New("connection closed")
550+
}
551+
552+
deleted = false
553+
for _, peer := range c.peerMap {
554+
if peer.ID == selector.ID {
555+
delete(c.peerMap, peer.ID)
556+
deleted = true
557+
break
558+
}
559+
560+
for _, peerIP := range peer.Addresses {
561+
if peerIP.Bits() == selector.IP.Bits() && peerIP.Addr().Compare(selector.IP.Addr()) == 0 {
562+
delete(c.peerMap, peer.ID)
563+
deleted = true
564+
break
565+
}
566+
}
567+
}
568+
if !deleted {
569+
return false, nil
570+
}
571+
572+
c.netMap.Peers = make([]*tailcfg.Node, 0, len(c.peerMap))
573+
for _, peer := range c.peerMap {
574+
c.netMap.Peers = append(c.netMap.Peers, peer.Clone())
575+
}
576+
577+
netMapCopy := *c.netMap
578+
c.logger.Debug(context.Background(), "updating network map")
579+
c.wireguardEngine.SetNetworkMap(&netMapCopy)
580+
err = c.reconfig()
581+
if err != nil {
582+
return false, xerrors.Errorf("reconfig: %w", err)
583+
}
584+
585+
return true, nil
586+
}
587+
515588
func (c *Conn) reconfig() error {
516589
cfg, err := nmcfg.WGCfg(c.netMap, Logger(c.logger.Named("net.wgconfig")), netmap.AllowSingleHosts, "")
517590
if err != nil {

0 commit comments

Comments
 (0)