Skip to content

feat(coderd): expire agents from server tailnet #9092

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions coderd/coderd.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ func New(options *Options) *API {
return (*api.TailnetCoordinator.Load()).ServeMultiAgent(uuid.New()), nil
},
wsconncache.New(api._dialWorkspaceAgentTailnet, 0),
api.TracerProvider,
)
if err != nil {
panic("failed to setup server tailnet: " + err.Error())
Expand Down
114 changes: 78 additions & 36 deletions coderd/tailnet.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@ import (
"time"

"github.com/google/uuid"
"go.opentelemetry.io/otel/trace"
"golang.org/x/xerrors"
"tailscale.com/derp"
"tailscale.com/tailcfg"

"cdr.dev/slog"
"github.com/coder/coder/coderd/tracing"
"github.com/coder/coder/coderd/wsconncache"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/site"
Expand All @@ -45,6 +47,7 @@ func NewServerTailnet(
derpMap *tailcfg.DERPMap,
getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error),
cache *wsconncache.Cache,
traceProvider trace.TracerProvider,
) (*ServerTailnet, error) {
logger = logger.Named("servertailnet")
conn, err := tailnet.NewConn(&tailnet.Options{
Expand All @@ -58,15 +61,16 @@ func NewServerTailnet(

serverCtx, cancel := context.WithCancel(ctx)
tn := &ServerTailnet{
ctx: serverCtx,
cancel: cancel,
logger: logger,
conn: conn,
getMultiAgent: getMultiAgent,
cache: cache,
agentNodes: map[uuid.UUID]time.Time{},
agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{},
transport: tailnetTransport.Clone(),
ctx: serverCtx,
cancel: cancel,
logger: logger,
tracer: traceProvider.Tracer(tracing.TracerName),
conn: conn,
getMultiAgent: getMultiAgent,
cache: cache,
agentConnectionTimes: map[uuid.UUID]time.Time{},
agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{},
transport: tailnetTransport.Clone(),
}
tn.transport.DialContext = tn.dialContext
tn.transport.MaxIdleConnsPerHost = 10
Expand Down Expand Up @@ -139,25 +143,50 @@ func (s *ServerTailnet) expireOldAgents() {
case <-ticker.C:
}

s.nodesMu.Lock()
agentConn := s.getAgentConn()
for agentID, lastConnection := range s.agentNodes {
// If no one has connected since the cutoff and there are no active
// connections, remove the agent.
if time.Since(lastConnection) > cutoff && len(s.agentTickets[agentID]) == 0 {
_ = agentConn
// err := agentConn.UnsubscribeAgent(agentID)
// if err != nil {
// s.logger.Error(s.ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
// }
// delete(s.agentNodes, agentID)

// TODO(coadler): actually remove from the netmap, then reenable
// the above
s.doExpireOldAgents(cutoff)
}
}

// doExpireOldAgents runs a single pruning pass over the tracked agents. An
// agent is pruned when its last recorded connection time is older than cutoff
// and it has no outstanding tickets. A pruned agent is removed from the server
// tailnet as a peer, dropped from agentConnectionTimes, and unsubscribed from
// the multi-agent connection. The whole pass runs inside a tracing span and
// holds nodesMu while the maps are read and mutated.
func (s *ServerTailnet) doExpireOldAgents(cutoff time.Duration) {
	// TODO: add some attrs to this.
	ctx, span := s.tracer.Start(s.ctx, tracing.FuncName())
	defer span.End()

	start := time.Now()
	deletedCount := 0

	s.nodesMu.Lock()
	s.logger.Debug(ctx, "pruning inactive agents", slog.F("agent_count", len(s.agentConnectionTimes)))
	agentConn := s.getAgentConn()
	for agentID, lastConnection := range s.agentConnectionTimes {
		// If no one has connected since the cutoff and there are no active
		// connections, remove the agent.
		if time.Since(lastConnection) > cutoff && len(s.agentTickets[agentID]) == 0 {
			deleted, err := s.conn.RemovePeer(tailnet.PeerSelector{
				ID: tailnet.NodeID(agentID),
				IP: netip.PrefixFrom(tailnet.IPFromUUID(agentID), 128),
			})
			if err != nil {
				// Keep tracking the agent so the next pass retries it.
				s.logger.Warn(ctx, "failed to remove peer from server tailnet", slog.Error(err), slog.F("agent_id", agentID))
				continue
			}
			if !deleted {
				// err is nil on this path, so identify the agent instead of
				// logging an empty error.
				s.logger.Warn(ctx, "peer didn't exist in tailnet", slog.F("agent_id", agentID))
			}

			deletedCount++
			// Deleting the entry currently being visited is safe while
			// ranging over a map.
			delete(s.agentConnectionTimes, agentID)
			err = agentConn.UnsubscribeAgent(agentID)
			if err != nil {
				s.logger.Error(ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
			}
		}
	}
	// Unlock exactly once, after the loop; unlocking inside the loop would
	// unlock an already-unlocked mutex on the next iteration.
	s.nodesMu.Unlock()

	s.logger.Debug(ctx, "successfully pruned inactive agents",
		slog.F("deleted", deletedCount),
		slog.F("took", time.Since(start)),
	)
}

func (s *ServerTailnet) watchAgentUpdates() {
Expand Down Expand Up @@ -196,7 +225,7 @@ func (s *ServerTailnet) reinitCoordinator() {
s.agentConn.Store(&agentConn)

// Resubscribe to all of the agents we're tracking.
for agentID := range s.agentNodes {
for agentID := range s.agentConnectionTimes {
err := agentConn.SubscribeAgent(agentID)
if err != nil {
s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(err), slog.F("agent_id", agentID))
Expand All @@ -212,14 +241,16 @@ type ServerTailnet struct {
cancel func()

logger slog.Logger
tracer trace.Tracer
conn *tailnet.Conn
getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error)
agentConn atomic.Pointer[tailnet.MultiAgentConn]
cache *wsconncache.Cache
nodesMu sync.Mutex
// agentNodes is a map of agent tailnetNodes the server wants to keep a
// connection to. It contains the last time the agent was connected to.
agentNodes map[uuid.UUID]time.Time
// agentConnectionTimes is a map of agent tailnetNodes the server wants to
// keep a connection to. It contains the last time the agent was connected
// to.
agentConnectionTimes map[uuid.UUID]time.Time
// agentTickets holds a map of all open connections to an agent.
agentTickets map[uuid.UUID]map[uuid.UUID]struct{}

Expand Down Expand Up @@ -268,7 +299,7 @@ func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error {
s.nodesMu.Lock()
defer s.nodesMu.Unlock()

_, ok := s.agentNodes[agentID]
_, ok := s.agentConnectionTimes[agentID]
// If we don't have the node, subscribe.
if !ok {
s.logger.Debug(s.ctx, "subscribing to agent", slog.F("agent_id", agentID))
Expand All @@ -279,14 +310,27 @@ func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error {
s.agentTickets[agentID] = map[uuid.UUID]struct{}{}
}

s.agentNodes[agentID] = time.Now()
s.agentConnectionTimes[agentID] = time.Now()
return nil
}

// acquireTicket registers a new, uniquely identified ticket for agentID in
// agentTickets and returns a release function that removes it again. Both the
// registration and the release take nodesMu.
func (s *ServerTailnet) acquireTicket(agentID uuid.UUID) (release func()) {
	id := uuid.New()
	s.nodesMu.Lock()
	tickets := s.agentTickets[agentID]
	if tickets == nil {
		// Writing to a nil map panics; create the ticket set on demand in
		// case no set has been created for this agent yet.
		tickets = map[uuid.UUID]struct{}{}
		s.agentTickets[agentID] = tickets
	}
	tickets[id] = struct{}{}
	s.nodesMu.Unlock()

	return func() {
		s.nodesMu.Lock()
		// Look the set up again at release time rather than capturing it:
		// the entry may have been replaced since the ticket was acquired.
		// Deleting from a nil map is a no-op.
		delete(s.agentTickets[agentID], id)
		s.nodesMu.Unlock()
	}
}

func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, func(), error) {
var (
conn *codersdk.WorkspaceAgentConn
ret = func() {}
ret func()
)

if s.getAgentConn().AgentIsLegacy(agentID) {
Expand All @@ -299,12 +343,13 @@ func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*code
conn = cconn.WorkspaceAgentConn
ret = release
} else {
s.logger.Debug(s.ctx, "acquiring agent", slog.F("agent_id", agentID))
err := s.ensureAgent(agentID)
if err != nil {
return nil, nil, xerrors.Errorf("ensure agent: %w", err)
}
ret = s.acquireTicket(agentID)

s.logger.Debug(s.ctx, "acquiring agent", slog.F("agent_id", agentID))
conn = codersdk.NewWorkspaceAgentConn(s.conn, codersdk.WorkspaceAgentConnOptions{
AgentID: agentID,
CloseFunc: func() error { return codersdk.ErrSkipClose },
Expand All @@ -317,7 +362,6 @@ func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*code
reachable := conn.AwaitReachable(ctx)
if !reachable {
ret()
conn.Close()
return nil, nil, xerrors.New("agent is unreachable")
}

Expand All @@ -336,13 +380,11 @@ func (s *ServerTailnet) DialAgentNetConn(ctx context.Context, agentID uuid.UUID,
nc, err := conn.DialContext(ctx, network, addr)
if err != nil {
release()
conn.Close()
return nil, xerrors.Errorf("dial context: %w", err)
}

return &netConnCloser{Conn: nc, close: func() {
release()
conn.Close()
}}, err
}

Expand Down
2 changes: 2 additions & 0 deletions coderd/tailnet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/trace"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
Expand Down Expand Up @@ -232,6 +233,7 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A
manifest.DERPMap,
func(context.Context) (tailnet.MultiAgentConn, error) { return coord.ServeMultiAgent(uuid.New()), nil },
cache,
trace.NewNoopTracerProvider(),
)
require.NoError(t, err)

Expand Down
1 change: 1 addition & 0 deletions enterprise/wsproxy/wsproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ func New(ctx context.Context, opts *Options) (*Server, error) {
connInfo.DERPMap,
s.DialCoordinator,
wsconncache.New(s.DialWorkspaceAgent, 0),
s.TracerProvider,
)
if err != nil {
return nil, xerrors.Errorf("create server tailnet: %w", err)
Expand Down
85 changes: 79 additions & 6 deletions tailnet/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package tailnet

import (
"context"
"encoding/binary"
"errors"
"fmt"
"net"
Expand All @@ -19,7 +20,6 @@ import (
"golang.org/x/xerrors"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"tailscale.com/hostinfo"
"tailscale.com/ipn/ipnstate"
"tailscale.com/net/connstats"
"tailscale.com/net/dns"
Expand Down Expand Up @@ -67,6 +67,7 @@ func init() {
}

type Options struct {
ID uuid.UUID
Addresses []netip.Prefix
DERPMap *tailcfg.DERPMap
DERPHeader *http.Header
Expand All @@ -78,6 +79,18 @@ type Options struct {
ListenPort uint16
}

// NodeID creates a Tailscale NodeID from the last 8 bytes of a UUID. The
// bytes are read as a big-endian 64-bit integer and mapped to its absolute
// value, so the returned NodeID is never negative. It is zero when the last
// 8 bytes are all zero.
func NodeID(uid uuid.UUID) tailcfg.NodeID {
	id := int64(binary.BigEndian.Uint64(uid[8:]))

	// Branch-free absolute value: y is all ones when id is negative and zero
	// otherwise.
	y := id >> 63
	id = (id ^ y) - y

	// The absolute value of the minimum int64 overflows back to itself and
	// stays negative; clamp that single input to the maximum int64 so the
	// result is never negative.
	if id < 0 {
		id = 1<<63 - 1
	}

	return tailcfg.NodeID(id)
}

// NewConn constructs a new Wireguard server that will accept connections from the addresses provided.
func NewConn(options *Options) (conn *Conn, err error) {
if options == nil {
Expand Down Expand Up @@ -126,13 +139,23 @@ func NewConn(options *Options) (conn *Conn, err error) {
Caps: []filter.CapMatch{},
}},
}
nodeID, err := cryptorand.Int63()
if err != nil {
return nil, xerrors.Errorf("generate node id: %w", err)

var nodeID tailcfg.NodeID

// If we're provided with a UUID, use it to populate our node ID.
if options.ID != uuid.Nil {
nodeID = NodeID(options.ID)
} else {
uid, err := cryptorand.Int63()
if err != nil {
return nil, xerrors.Errorf("generate node id: %w", err)
}
nodeID = tailcfg.NodeID(uid)
}

// This is used by functions below to identify the node via key
netMap.SelfNode = &tailcfg.Node{
ID: tailcfg.NodeID(nodeID),
ID: nodeID,
Key: nodePublicKey,
Addresses: options.Addresses,
AllowedIPs: options.Addresses,
Expand Down Expand Up @@ -488,7 +511,7 @@ func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error {
AllowedIPs: node.AllowedIPs,
Endpoints: node.Endpoints,
DERP: fmt.Sprintf("%s:%d", tailcfg.DerpMagicIP, node.PreferredDERP),
Hostinfo: hostinfo.New().View(),
Hostinfo: (&tailcfg.Hostinfo{}).View(),
}
if c.blockEndpoints {
peerNode.Endpoints = nil
Expand All @@ -512,6 +535,56 @@ func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error {
return nil
}

// PeerSelector is used to select a peer from within a Tailnet. RemovePeer
// treats a peer as selected when either field matches.
type PeerSelector struct {
	// ID is compared for equality against each peer's node ID.
	ID tailcfg.NodeID
	// IP is compared against each of a peer's addresses; both the prefix
	// length and the address must be equal for a match.
	IP netip.Prefix
}

// RemovePeer deletes peers matching selector from the peer map, then rebuilds
// the network map's peer list from the remaining peers, hands a copy of the
// network map to the wireguard engine, and reconfigures.
//
// A peer matches when its ID equals selector.ID, or when one of its addresses
// has the same prefix length and address as selector.IP. The scan stops at the
// first ID match but continues past address matches.
//
// It returns false with a nil error when nothing matched, and false with an
// error when the connection is closed or reconfiguration fails.
func (c *Conn) RemovePeer(selector PeerSelector) (deleted bool, err error) {
	c.mutex.Lock()
	defer c.mutex.Unlock()

	if c.isClosed() {
		return false, xerrors.New("connection closed")
	}

peers:
	for _, candidate := range c.peerMap {
		if candidate.ID == selector.ID {
			delete(c.peerMap, candidate.ID)
			deleted = true
			break
		}

		for _, addr := range candidate.Addresses {
			if addr.Bits() != selector.IP.Bits() || addr.Addr().Compare(selector.IP.Addr()) != 0 {
				continue
			}
			delete(c.peerMap, candidate.ID)
			deleted = true
			continue peers
		}
	}
	if !deleted {
		return false, nil
	}

	// Rebuild the peer list from whatever is left in the peer map.
	remaining := make([]*tailcfg.Node, 0, len(c.peerMap))
	for _, p := range c.peerMap {
		remaining = append(remaining, p.Clone())
	}
	c.netMap.Peers = remaining

	// Hand the engine its own shallow copy of the network map.
	snapshot := *c.netMap
	c.logger.Debug(context.Background(), "updating network map")
	c.wireguardEngine.SetNetworkMap(&snapshot)
	if err = c.reconfig(); err != nil {
		return false, xerrors.Errorf("reconfig: %w", err)
	}

	return true, nil
}

func (c *Conn) reconfig() error {
cfg, err := nmcfg.WGCfg(c.netMap, Logger(c.logger.Named("net.wgconfig")), netmap.AllowSingleHosts, "")
if err != nil {
Expand Down