Skip to content

Commit 677721e

Browse files
authored
fix(tailnet): Skip nodes without DERP, avoid use of RemoveAllPeers (#6320)
* fix(tailnet): Skip nodes without DERP, avoid use of RemoveAllPeers
1 parent a414de9 commit 677721e

File tree

10 files changed

+89
-44
lines changed

10 files changed

+89
-44
lines changed

agent/agent.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,9 @@ func (a *agent) runCoordinator(ctx context.Context, network *tailnet.Conn) error
601601
}
602602
defer coordinator.Close()
603603
a.logger.Info(ctx, "connected to coordination server")
604-
sendNodes, errChan := tailnet.ServeCoordinator(coordinator, network.UpdateNodes)
604+
sendNodes, errChan := tailnet.ServeCoordinator(coordinator, func(nodes []*tailnet.Node) error {
605+
return network.UpdateNodes(nodes, false)
606+
})
605607
network.SetNodeCallback(sendNodes)
606608
select {
607609
case <-ctx.Done():

agent/agent_test.go

+12-3
Original file line numberDiff line numberDiff line change
@@ -1179,12 +1179,21 @@ func setupAgent(t *testing.T, metadata agentsdk.Metadata, ptyTimeout time.Durati
11791179
coordinator.ServeClient(serverConn, uuid.New(), agentID)
11801180
}()
11811181
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
1182-
return conn.UpdateNodes(node)
1182+
return conn.UpdateNodes(node, false)
11831183
})
11841184
conn.SetNodeCallback(sendNode)
1185-
return &codersdk.WorkspaceAgentConn{
1185+
agentConn := &codersdk.WorkspaceAgentConn{
11861186
Conn: conn,
1187-
}, c, statsCh, fs
1187+
}
1188+
t.Cleanup(func() {
1189+
_ = agentConn.Close()
1190+
})
1191+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
1192+
defer cancel()
1193+
if !agentConn.AwaitReachable(ctx) {
1194+
t.Fatal("agent not reachable")
1195+
}
1196+
return agentConn, c, statsCh, fs
11881197
}
11891198

11901199
var dialTestPayload = []byte("dean-was-here123")

cli/speedtest_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"github.com/stretchr/testify/assert"
88
"github.com/stretchr/testify/require"
99

10+
"cdr.dev/slog"
1011
"cdr.dev/slog/sloggers/slogtest"
1112
"github.com/coder/coder/agent"
1213
"github.com/coder/coder/cli/clitest"
@@ -28,7 +29,7 @@ func TestSpeedtest(t *testing.T) {
2829
agentClient.SetSessionToken(agentToken)
2930
agentCloser := agent.New(agent.Options{
3031
Client: agentClient,
31-
Logger: slogtest.Make(t, nil).Named("agent"),
32+
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
3233
})
3334
defer agentCloser.Close()
3435
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)

cli/ssh_test.go

+2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"golang.org/x/crypto/ssh"
2525
gosshagent "golang.org/x/crypto/ssh/agent"
2626

27+
"cdr.dev/slog"
2728
"cdr.dev/slog/sloggers/slogtest"
2829

2930
"github.com/coder/coder/agent"
@@ -47,6 +48,7 @@ func setupWorkspaceForAgent(t *testing.T, mutate func([]*proto.Agent) []*proto.A
4748
}
4849
}
4950
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
51+
client.Logger = slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug)
5052
user := coderdtest.CreateFirstUser(t, client)
5153
agentToken := uuid.NewString()
5254
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{

coderd/coderd_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -80,16 +80,16 @@ func TestDERP(t *testing.T) {
8080
})
8181
require.NoError(t, err)
8282

83-
w2Ready := make(chan struct{}, 1)
83+
w2Ready := make(chan struct{})
8484
w2ReadyOnce := sync.Once{}
8585
w1.SetNodeCallback(func(node *tailnet.Node) {
86-
w2.UpdateNodes([]*tailnet.Node{node})
86+
w2.UpdateNodes([]*tailnet.Node{node}, false)
8787
w2ReadyOnce.Do(func() {
8888
close(w2Ready)
8989
})
9090
})
9191
w2.SetNodeCallback(func(node *tailnet.Node) {
92-
w1.UpdateNodes([]*tailnet.Node{node})
92+
w1.UpdateNodes([]*tailnet.Node{node}, false)
9393
})
9494

9595
conn := make(chan struct{})

coderd/workspaceagents.go

+15-14
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,7 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req
404404
}
405405

406406
func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
407+
ctx := r.Context()
407408
clientConn, serverConn := net.Pipe()
408409

409410
derpMap := api.DERPMap.Clone()
@@ -453,32 +454,32 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*
453454
}
454455

455456
sendNodes, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
456-
err := conn.RemoveAllPeers()
457-
if err != nil {
458-
return xerrors.Errorf("remove all peers: %w", err)
459-
}
460-
461-
err = conn.UpdateNodes(node)
457+
err = conn.UpdateNodes(node, true)
462458
if err != nil {
463459
return xerrors.Errorf("update nodes: %w", err)
464460
}
465461
return nil
466462
})
467463
conn.SetNodeCallback(sendNodes)
464+
agentConn := &codersdk.WorkspaceAgentConn{
465+
Conn: conn,
466+
CloseFunc: func() {
467+
_ = clientConn.Close()
468+
_ = serverConn.Close()
469+
},
470+
}
468471
go func() {
469472
err := (*api.TailnetCoordinator.Load()).ServeClient(serverConn, uuid.New(), agentID)
470473
if err != nil {
471474
api.Logger.Warn(r.Context(), "tailnet coordinator client error", slog.Error(err))
472-
_ = conn.Close()
475+
_ = agentConn.Close()
473476
}
474477
}()
475-
return &codersdk.WorkspaceAgentConn{
476-
Conn: conn,
477-
CloseFunc: func() {
478-
_ = clientConn.Close()
479-
_ = serverConn.Close()
480-
},
481-
}, nil
478+
if !agentConn.AwaitReachable(ctx) {
479+
_ = agentConn.Close()
480+
return nil, xerrors.Errorf("agent not reachable")
481+
}
482+
return agentConn, nil
482483
}
483484

484485
// @Summary Get connection info for workspace agent

coderd/wsconncache/wsconncache_test.go

+11-2
Original file line numberDiff line numberDiff line change
@@ -191,12 +191,21 @@ func setupAgent(t *testing.T, metadata agentsdk.Metadata, ptyTimeout time.Durati
191191
})
192192
go coordinator.ServeClient(serverConn, uuid.New(), agentID)
193193
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
194-
return conn.UpdateNodes(node)
194+
return conn.UpdateNodes(node, false)
195195
})
196196
conn.SetNodeCallback(sendNode)
197-
return &codersdk.WorkspaceAgentConn{
197+
agentConn := &codersdk.WorkspaceAgentConn{
198198
Conn: conn,
199199
}
200+
t.Cleanup(func() {
201+
_ = agentConn.Close()
202+
})
203+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
204+
defer cancel()
205+
if !agentConn.AwaitReachable(ctx) {
206+
t.Fatal("agent not reachable")
207+
}
208+
return agentConn
200209
}
201210

202211
type client struct {

codersdk/workspaceagents.go

+15-7
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ type DialWorkspaceAgentOptions struct {
100100
BlockEndpoints bool
101101
}
102102

103-
func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *DialWorkspaceAgentOptions) (*WorkspaceAgentConn, error) {
103+
func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *DialWorkspaceAgentOptions) (agentConn *WorkspaceAgentConn, err error) {
104104
if options == nil {
105105
options = &DialWorkspaceAgentOptions{}
106106
}
@@ -128,6 +128,11 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti
128128
if err != nil {
129129
return nil, xerrors.Errorf("create tailnet: %w", err)
130130
}
131+
defer func() {
132+
if err != nil {
133+
_ = conn.Close()
134+
}
135+
}()
131136

132137
coordinateURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/workspaceagents/%s/coordinate", agentID))
133138
if err != nil {
@@ -145,7 +150,12 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti
145150
Jar: jar,
146151
Transport: c.HTTPClient.Transport,
147152
}
148-
ctx, cancelFunc := context.WithCancel(ctx)
153+
ctx, cancel := context.WithCancel(ctx)
154+
defer func() {
155+
if err != nil {
156+
cancel()
157+
}
158+
}()
149159
closed := make(chan struct{})
150160
first := make(chan error)
151161
go func() {
@@ -175,7 +185,7 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti
175185
continue
176186
}
177187
sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(node []*tailnet.Node) error {
178-
return conn.UpdateNodes(node)
188+
return conn.UpdateNodes(node, false)
179189
})
180190
conn.SetNodeCallback(sendNode)
181191
options.Logger.Debug(ctx, "serving coordinator")
@@ -194,15 +204,13 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti
194204
}()
195205
err = <-first
196206
if err != nil {
197-
cancelFunc()
198-
_ = conn.Close()
199207
return nil, err
200208
}
201209

202-
agentConn := &WorkspaceAgentConn{
210+
agentConn = &WorkspaceAgentConn{
203211
Conn: conn,
204212
CloseFunc: func() {
205-
cancelFunc()
213+
cancel()
206214
<-closed
207215
},
208216
}

tailnet/conn.go

+22-9
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ func NewConn(options *Options) (conn *Conn, err error) {
130130
}()
131131

132132
dialer := &tsdial.Dialer{
133-
Logf: Logger(options.Logger),
133+
Logf: Logger(options.Logger.Named("tsdial")),
134134
}
135135
wireguardEngine, err := wgengine.NewUserspaceEngine(Logger(options.Logger.Named("wgengine")), wgengine.Config{
136136
LinkMonitor: wireguardMonitor,
@@ -179,6 +179,7 @@ func NewConn(options *Options) (conn *Conn, err error) {
179179
wireguardEngine = wgengine.NewWatchdog(wireguardEngine)
180180
wireguardEngine.SetDERPMap(options.DERPMap)
181181
netMapCopy := *netMap
182+
options.Logger.Debug(context.Background(), "updating network map", slog.F("net_map", netMapCopy))
182183
wireguardEngine.SetNetworkMap(&netMapCopy)
183184

184185
localIPSet := netipx.IPSetBuilder{}
@@ -329,9 +330,11 @@ func (c *Conn) SetDERPMap(derpMap *tailcfg.DERPMap) {
329330
c.mutex.Lock()
330331
defer c.mutex.Unlock()
331332
c.logger.Debug(context.Background(), "updating derp map", slog.F("derp_map", derpMap))
332-
c.netMap.DERPMap = derpMap
333-
c.wireguardEngine.SetNetworkMap(c.netMap)
334333
c.wireguardEngine.SetDERPMap(derpMap)
334+
c.netMap.DERPMap = derpMap
335+
netMapCopy := *c.netMap
336+
c.logger.Debug(context.Background(), "updating network map", slog.F("net_map", netMapCopy))
337+
c.wireguardEngine.SetNetworkMap(&netMapCopy)
335338
}
336339

337340
func (c *Conn) RemoveAllPeers() error {
@@ -341,6 +344,7 @@ func (c *Conn) RemoveAllPeers() error {
341344
c.netMap.Peers = []*tailcfg.Node{}
342345
c.peerMap = map[tailcfg.NodeID]*tailcfg.Node{}
343346
netMapCopy := *c.netMap
347+
c.logger.Debug(context.Background(), "updating network map", slog.F("net_map", netMapCopy))
344348
c.wireguardEngine.SetNetworkMap(&netMapCopy)
345349
cfg, err := nmcfg.WGCfg(c.netMap, Logger(c.logger.Named("wgconfig")), netmap.AllowSingleHosts, "")
346350
if err != nil {
@@ -360,11 +364,18 @@ func (c *Conn) RemoveAllPeers() error {
360364
}
361365

362366
// UpdateNodes connects with a set of peers. This can be constantly updated,
363-
// and peers will continually be reconnected as necessary.
364-
func (c *Conn) UpdateNodes(nodes []*Node) error {
367+
// and peers will continually be reconnected as necessary. If replacePeers is
368+
// true, all peers will be removed before adding the new ones.
369+
//
370+
//nolint:revive // Complains about replacePeers.
371+
func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error {
365372
c.mutex.Lock()
366373
defer c.mutex.Unlock()
367374
status := c.Status()
375+
if replacePeers {
376+
c.netMap.Peers = []*tailcfg.Node{}
377+
c.peerMap = map[tailcfg.NodeID]*tailcfg.Node{}
378+
}
368379
for _, peer := range c.netMap.Peers {
369380
peerStatus, ok := status.Peer[peer.Key]
370381
if !ok {
@@ -384,6 +395,11 @@ func (c *Conn) UpdateNodes(nodes []*Node) error {
384395
delete(c.peerMap, peer.ID)
385396
}
386397
for _, node := range nodes {
398+
// If no preferred DERP is provided, we can't reach the node.
399+
if node.PreferredDERP == 0 {
400+
c.logger.Debug(context.Background(), "no preferred DERP, skipping node", slog.F("node", node))
401+
continue
402+
}
387403
c.logger.Debug(context.Background(), "adding node", slog.F("node", node))
388404

389405
peerStatus, ok := status.Peer[node.Key]
@@ -402,10 +418,6 @@ func (c *Conn) UpdateNodes(nodes []*Node) error {
402418
// reason. TODO: @kylecarbs debug this!
403419
KeepAlive: ok && peerStatus.Active,
404420
}
405-
// If no preferred DERP is provided, don't set an IP!
406-
if node.PreferredDERP == 0 {
407-
peerNode.DERP = ""
408-
}
409421
if c.blockEndpoints {
410422
peerNode.Endpoints = nil
411423
}
@@ -416,6 +428,7 @@ func (c *Conn) UpdateNodes(nodes []*Node) error {
416428
c.netMap.Peers = append(c.netMap.Peers, peer.Clone())
417429
}
418430
netMapCopy := *c.netMap
431+
c.logger.Debug(context.Background(), "updating network map", slog.F("net_map", netMapCopy))
419432
c.wireguardEngine.SetNetworkMap(&netMapCopy)
420433
cfg, err := nmcfg.WGCfg(c.netMap, Logger(c.logger.Named("wgconfig")), netmap.AllowSingleHosts, "")
421434
if err != nil {

tailnet/conn_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,12 @@ func TestTailnet(t *testing.T) {
5555
_ = w2.Close()
5656
})
5757
w1.SetNodeCallback(func(node *tailnet.Node) {
58-
err := w2.UpdateNodes([]*tailnet.Node{node})
59-
require.NoError(t, err)
58+
err := w2.UpdateNodes([]*tailnet.Node{node}, false)
59+
assert.NoError(t, err)
6060
})
6161
w2.SetNodeCallback(func(node *tailnet.Node) {
62-
err := w1.UpdateNodes([]*tailnet.Node{node})
63-
require.NoError(t, err)
62+
err := w1.UpdateNodes([]*tailnet.Node{node}, false)
63+
assert.NoError(t, err)
6464
})
6565
require.True(t, w2.AwaitReachable(context.Background(), w1IP))
6666
conn := make(chan struct{})

0 commit comments

Comments
 (0)