Skip to content

Commit a988fef

Browse files
committed
address review comments
1 parent f9040fc commit a988fef

File tree

19 files changed

+441
-178
lines changed

19 files changed

+441
-178
lines changed

agent/agent.go

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ func (a *agent) run(ctx context.Context) error {
605605
network := a.network
606606
a.closeMutex.Unlock()
607607
if network == nil {
608-
network, err = a.createTailnet(ctx, manifest.DERPMap, manifest.DisableDirectConnections)
608+
network, err = a.createTailnet(ctx, manifest.AgentID, manifest.DERPMap, manifest.DisableDirectConnections)
609609
if err != nil {
610610
return xerrors.Errorf("create tailnet: %w", err)
611611
}
@@ -623,6 +623,11 @@ func (a *agent) run(ctx context.Context) error {
623623

624624
a.startReportingConnectionStats(ctx)
625625
} else {
626+
// Update the wireguard IPs if the agent ID changed.
627+
err := network.SetAddresses(a.wireguardAddresses(manifest.AgentID))
628+
if err != nil {
629+
a.logger.Error(ctx, "update tailnet addresses", slog.Error(err))
630+
}
626631
// Update the DERP map and allow/disallow direct connections.
627632
network.SetDERPMap(manifest.DERPMap)
628633
network.SetBlockEndpoints(manifest.DisableDirectConnections)
@@ -636,6 +641,20 @@ func (a *agent) run(ctx context.Context) error {
636641
return nil
637642
}
638643

644+
func (a *agent) wireguardAddresses(agentID uuid.UUID) []netip.Prefix {
645+
if len(a.addresses) == 0 {
646+
return []netip.Prefix{
647+
// This is the IP that should be used primarily.
648+
netip.PrefixFrom(tailnet.IPFromUUID(agentID), 128),
649+
// We also listen on the legacy codersdk.WorkspaceAgentIP. This
650+
// allows for a transition away from wsconncache.
651+
netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128),
652+
}
653+
}
654+
655+
return a.addresses
656+
}
657+
639658
func (a *agent) trackConnGoroutine(fn func()) error {
640659
a.closeMutex.Lock()
641660
defer a.closeMutex.Unlock()
@@ -650,9 +669,9 @@ func (a *agent) trackConnGoroutine(fn func()) error {
650669
return nil
651670
}
652671

653-
func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap, disableDirectConnections bool) (_ *tailnet.Conn, err error) {
672+
func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *tailcfg.DERPMap, disableDirectConnections bool) (_ *tailnet.Conn, err error) {
654673
network, err := tailnet.NewConn(&tailnet.Options{
655-
Addresses: a.addresses,
674+
Addresses: a.wireguardAddresses(agentID),
656675
DERPMap: derpMap,
657676
Logger: a.logger.Named("tailnet"),
658677
ListenPort: a.tailnetListenPort,

coderd/apidoc/docs.go

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/apidoc/swagger.json

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/tailnet.go

Lines changed: 67 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ func init() {
3535
}
3636
}
3737

38-
// TODO: ServerTailnet does not currently remove stale peers.
38+
// TODO(coadler): ServerTailnet does not currently remove stale peers.
3939

4040
// NewServerTailnet creates a new tailnet intended for use by coderd. It
4141
// automatically falls back to wsconncache if a legacy agent is encountered.
@@ -56,10 +56,17 @@ func NewServerTailnet(
5656
return nil, xerrors.Errorf("create tailnet conn: %w", err)
5757
}
5858

59+
id := uuid.New()
60+
ma := (*coord.Load()).ServeMultiAgent(id)
61+
62+
serverCtx, cancel := context.WithCancel(ctx)
5963
tn := &ServerTailnet{
64+
ctx: serverCtx,
65+
cancel: cancel,
6066
logger: logger,
6167
conn: conn,
6268
coordinator: coord,
69+
agentConn: ma,
6370
cache: cache,
6471
agentNodes: map[uuid.UUID]*tailnetNode{},
6572
transport: tailnetTransport.Clone(),
@@ -69,16 +76,9 @@ func NewServerTailnet(
6976
tn.transport.MaxIdleConns = 0
7077

7178
conn.SetNodeCallback(func(node *tailnet.Node) {
72-
tn.nodesMu.Lock()
73-
ids := make([]uuid.UUID, 0, len(tn.agentNodes))
74-
for id := range tn.agentNodes {
75-
ids = append(ids, id)
76-
}
77-
tn.nodesMu.Unlock()
78-
79-
err := (*tn.coordinator.Load()).BroadcastToAgents(ids, node)
79+
err := tn.agentConn.UpdateSelf(node)
8080
if err != nil {
81-
tn.logger.Error(context.Background(), "broadcast server node to agents", slog.Error(err))
81+
tn.logger.Warn(context.Background(), "broadcast server node to agents", slog.Error(err))
8282
}
8383
})
8484

@@ -99,19 +99,52 @@ func NewServerTailnet(
9999
return left
100100
})
101101

102+
go tn.watchAgentUpdates()
102103
return tn, nil
103104
}
104105

106+
func (s *ServerTailnet) watchAgentUpdates() {
107+
for {
108+
nodes := s.agentConn.NextUpdate(s.ctx)
109+
if nodes == nil {
110+
return
111+
}
112+
113+
toUpdate := make([]*tailnet.Node, 0)
114+
115+
s.nodesMu.Lock()
116+
for _, node := range nodes {
117+
cached, ok := s.agentNodes[node.AgentID]
118+
if ok {
119+
cached.node = node.Node
120+
toUpdate = append(toUpdate, node.Node)
121+
}
122+
}
123+
s.nodesMu.Unlock()
124+
125+
if len(toUpdate) > 0 {
126+
err := s.conn.UpdateNodes(toUpdate, false)
127+
if err != nil {
128+
s.logger.Error(context.Background(), "update node in server tailnet", slog.Error(err))
129+
return
130+
}
131+
}
132+
}
133+
}
134+
105135
type tailnetNode struct {
106136
node *tailnet.Node
107137
lastConnection time.Time
108-
stop func()
109138
}
110139

111140
type ServerTailnet struct {
141+
ctx context.Context
142+
cancel func()
143+
112144
logger slog.Logger
113145
conn *tailnet.Conn
114146
coordinator *atomic.Pointer[tailnet.Coordinator]
147+
agentConn tailnet.MultiAgentConn
115148
cache *wsconncache.Cache
116149
nodesMu sync.Mutex
117150
// agentNodes is a map of agent tailnetNodes the server wants to keep a
@@ -121,23 +154,6 @@ type ServerTailnet struct {
121154
transport *http.Transport
122155
}
123156

124-
func (s *ServerTailnet) updateNode(id uuid.UUID, node *tailnet.Node) {
125-
s.nodesMu.Lock()
126-
cached, ok := s.agentNodes[id]
127-
if ok {
128-
cached.node = node
129-
}
130-
s.nodesMu.Unlock()
131-
132-
if ok {
133-
err := s.conn.UpdateNodes([]*tailnet.Node{node}, false)
134-
if err != nil {
135-
s.logger.Error(context.Background(), "update node in server tailnet", slog.Error(err))
136-
return
137-
}
138-
}
139-
}
140-
141157
func (s *ServerTailnet) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) (_ *httputil.ReverseProxy, release func(), _ error) {
142158
proxy := httputil.NewSingleHostReverseProxy(targetURL)
143159
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
@@ -188,18 +204,16 @@ func (s *ServerTailnet) getNode(agentID uuid.UUID) (*tailnet.Node, error) {
188204
s.nodesMu.Unlock()
189205
return nil, xerrors.Errorf("node %q not found", agentID.String())
190206
}
191-
stop := coord.SubscribeAgent(agentID, s.updateNode)
207+
208+
err := s.agentConn.SubscribeAgent(agentID, s.conn.Node())
209+
if err != nil {
210+
return nil, xerrors.Errorf("subscribe agent: %w", err)
211+
}
192212
tnode = &tailnetNode{
193213
node: node,
194214
lastConnection: time.Now(),
195-
stop: stop,
196215
}
197216
s.agentNodes[agentID] = tnode
198-
199-
err := coord.BroadcastToAgents([]uuid.UUID{agentID}, s.conn.Node())
200-
if err != nil {
201-
s.logger.Debug(context.Background(), "broadcast server node to agents", slog.Error(err))
202-
}
203217
}
204218
s.nodesMu.Unlock()
205219

@@ -257,7 +271,7 @@ func (*ServerTailnet) nodeIsLegacy(node *tailnet.Node) bool {
257271
return node.Addresses[0].Addr() == codersdk.WorkspaceAgentIP
258272
}
259273

260-
func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (_ *codersdk.WorkspaceAgentConn, release func(), _ error) {
274+
func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, func(), error) {
261275
node, err := s.awaitNodeExists(ctx, agentID, 5*time.Second)
262276
if err != nil {
263277
return nil, nil, xerrors.Errorf("get agent node: %w", err)
@@ -284,8 +298,13 @@ func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (_ *co
284298
})
285299
}
286300

301+
// Since we now have an open conn, be careful to close it if we error
302+
// without returning it to the user.
303+
287304
reachable := conn.AwaitReachable(ctx)
288305
if !reachable {
306+
ret()
307+
conn.Close()
289308
return nil, nil, xerrors.New("agent is unreachable")
290309
}
291310

@@ -298,8 +317,13 @@ func (s *ServerTailnet) DialAgentNetConn(ctx context.Context, agentID uuid.UUID,
298317
return nil, xerrors.Errorf("acquire agent conn: %w", err)
299318
}
300319

320+
// Since we now have an open conn, be careful to close it if we error
321+
// without returning it to the user.
322+
301323
node, err := s.getNode(agentID)
302324
if err != nil {
325+
release()
326+
conn.Close()
303327
return nil, xerrors.New("get agent node")
304328
}
305329

@@ -308,11 +332,14 @@ func (s *ServerTailnet) DialAgentNetConn(ctx context.Context, agentID uuid.UUID,
308332
ipp := netip.AddrPortFrom(node.Addresses[0].Addr(), uint16(port))
309333

310334
var nc net.Conn
311-
if network == "tcp" {
335+
switch network {
336+
case "tcp":
312337
nc, err = conn.DialContextTCP(ctx, ipp)
313-
} else if network == "udp" {
338+
case "udp":
314339
nc, err = conn.DialContextUDP(ctx, ipp)
315-
} else {
340+
default:
341+
release()
342+
conn.Close()
316343
return nil, xerrors.Errorf("unknown network %q", network)
317344
}
318345

@@ -333,6 +360,7 @@ func (c *netConnCloser) Close() error {
333360
}
334361

335362
func (s *ServerTailnet) Close() error {
363+
s.cancel()
336364
_ = s.cache.Close()
337365
_ = s.conn.Close()
338366
s.transport.CloseIdleConnections()

coderd/tailnet_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A
183183
GetNode: func(agentID uuid.UUID) (*tailnet.Node, error) {
184184
node := coordinator.Node(agentID)
185185
if node == nil {
186-
return nil, xerrors.Errorf("node not found %q", err)
186+
return nil, xerrors.Errorf("node not found %q", agentID)
187187
}
188188
return node, nil
189189
},

coderd/workspaceagents.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ func (api *API) workspaceAgentManifest(rw http.ResponseWriter, r *http.Request)
161161
}
162162

163163
httpapi.Write(ctx, rw, http.StatusOK, agentsdk.Manifest{
164+
AgentID: apiAgent.ID,
164165
Apps: convertApps(dbApps),
165166
DERPMap: api.DERPMap,
166167
GitAuthConfigs: len(api.GitAuthConfigs),

coderd/wsconncache/wsconncache.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// Package wsconncache caches workspace agent connections by UUID.
2-
// Deprecated
2+
// Deprecated: Use ServerTailnet instead.
33
package wsconncache
44

55
import (

coderd/wsconncache/wsconncache_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Durati
200200
GetNode: func(agentID uuid.UUID) (*tailnet.Node, error) {
201201
node := coordinator.Node(agentID)
202202
if node == nil {
203-
return nil, xerrors.Errorf("node not found %q", err)
203+
return nil, xerrors.Errorf("node not found %q", agentID)
204204
}
205205
return node, nil
206206
},

codersdk/agentsdk/agentsdk.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ func (c *Client) PostMetadata(ctx context.Context, key string, req PostMetadataR
8484
}
8585

8686
type Manifest struct {
87+
AgentID uuid.UUID `json:"agent_id"`
8788
// GitAuthConfigs stores the number of Git configurations
8889
// the Coder deployment has. If this number is >0, we
8990
// set up special configuration in the workspace.

codersdk/workspaceagentconn.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,12 @@ func (c *WorkspaceAgentConn) AwaitReachable(ctx context.Context) bool {
167167
defer span.End()
168168

169169
var (
170-
addr netip.Addr
171-
err error
170+
addr netip.Addr
171+
err error
172+
ticker = time.NewTicker(10 * time.Millisecond)
172173
)
174+
defer ticker.Stop()
175+
173176
for {
174177
addr, err = c.getAgentAddress()
175178
if err == nil {
@@ -179,7 +182,7 @@ func (c *WorkspaceAgentConn) AwaitReachable(ctx context.Context) bool {
179182
select {
180183
case <-ctx.Done():
181184
return false
182-
case <-time.After(10 * time.Millisecond):
185+
case <-ticker.C:
183186
continue
184187
}
185188
}

0 commit comments

Comments
 (0)