Skip to content

fix(tailnet): Skip nodes without DERP, avoid use of RemoveAllPeers #6320

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged 3 commits on Feb 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,9 @@ func (a *agent) runCoordinator(ctx context.Context, network *tailnet.Conn) error
}
defer coordinator.Close()
a.logger.Info(ctx, "connected to coordination server")
sendNodes, errChan := tailnet.ServeCoordinator(coordinator, network.UpdateNodes)
sendNodes, errChan := tailnet.ServeCoordinator(coordinator, func(nodes []*tailnet.Node) error {
return network.UpdateNodes(nodes, false)
})
network.SetNodeCallback(sendNodes)
select {
case <-ctx.Done():
Expand Down
15 changes: 12 additions & 3 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1179,12 +1179,21 @@ func setupAgent(t *testing.T, metadata agentsdk.Metadata, ptyTimeout time.Durati
coordinator.ServeClient(serverConn, uuid.New(), agentID)
}()
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
return conn.UpdateNodes(node)
return conn.UpdateNodes(node, false)
})
conn.SetNodeCallback(sendNode)
return &codersdk.WorkspaceAgentConn{
agentConn := &codersdk.WorkspaceAgentConn{
Conn: conn,
}, c, statsCh, fs
}
t.Cleanup(func() {
_ = agentConn.Close()
})
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
if !agentConn.AwaitReachable(ctx) {
t.Fatal("agent not reachable")
}
return agentConn, c, statsCh, fs
}

var dialTestPayload = []byte("dean-was-here123")
Expand Down
3 changes: 2 additions & 1 deletion cli/speedtest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/agent"
"github.com/coder/coder/cli/clitest"
Expand All @@ -28,7 +29,7 @@ func TestSpeedtest(t *testing.T) {
agentClient.SetSessionToken(agentToken)
agentCloser := agent.New(agent.Options{
Client: agentClient,
Logger: slogtest.Make(t, nil).Named("agent"),
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
})
defer agentCloser.Close()
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
Expand Down
2 changes: 2 additions & 0 deletions cli/ssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"golang.org/x/crypto/ssh"
gosshagent "golang.org/x/crypto/ssh/agent"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"

"github.com/coder/coder/agent"
Expand All @@ -47,6 +48,7 @@ func setupWorkspaceForAgent(t *testing.T, mutate func([]*proto.Agent) []*proto.A
}
}
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
client.Logger = slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug)
user := coderdtest.CreateFirstUser(t, client)
agentToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Expand Down
6 changes: 3 additions & 3 deletions coderd/coderd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,16 @@ func TestDERP(t *testing.T) {
})
require.NoError(t, err)

w2Ready := make(chan struct{}, 1)
w2Ready := make(chan struct{})
w2ReadyOnce := sync.Once{}
w1.SetNodeCallback(func(node *tailnet.Node) {
w2.UpdateNodes([]*tailnet.Node{node})
w2.UpdateNodes([]*tailnet.Node{node}, false)
w2ReadyOnce.Do(func() {
close(w2Ready)
})
})
w2.SetNodeCallback(func(node *tailnet.Node) {
w1.UpdateNodes([]*tailnet.Node{node})
w1.UpdateNodes([]*tailnet.Node{node}, false)
})

conn := make(chan struct{})
Expand Down
29 changes: 15 additions & 14 deletions coderd/workspaceagents.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req
}

func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
ctx := r.Context()
clientConn, serverConn := net.Pipe()

derpMap := api.DERPMap.Clone()
Expand Down Expand Up @@ -453,32 +454,32 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*
}

sendNodes, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
err := conn.RemoveAllPeers()
if err != nil {
return xerrors.Errorf("remove all peers: %w", err)
}

err = conn.UpdateNodes(node)
err = conn.UpdateNodes(node, true)
if err != nil {
return xerrors.Errorf("update nodes: %w", err)
}
return nil
})
conn.SetNodeCallback(sendNodes)
agentConn := &codersdk.WorkspaceAgentConn{
Conn: conn,
CloseFunc: func() {
_ = clientConn.Close()
_ = serverConn.Close()
},
}
go func() {
err := (*api.TailnetCoordinator.Load()).ServeClient(serverConn, uuid.New(), agentID)
if err != nil {
api.Logger.Warn(r.Context(), "tailnet coordinator client error", slog.Error(err))
_ = conn.Close()
_ = agentConn.Close()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change OK? And should this always happen if ServeClient exits?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking more about this, it seems good!

}
}()
return &codersdk.WorkspaceAgentConn{
Conn: conn,
CloseFunc: func() {
_ = clientConn.Close()
_ = serverConn.Close()
},
}, nil
if !agentConn.AwaitReachable(ctx) {
_ = agentConn.Close()
return nil, xerrors.Errorf("agent not reachable")
}
return agentConn, nil
}

// @Summary Get connection info for workspace agent
Expand Down
13 changes: 11 additions & 2 deletions coderd/wsconncache/wsconncache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,21 @@ func setupAgent(t *testing.T, metadata agentsdk.Metadata, ptyTimeout time.Durati
})
go coordinator.ServeClient(serverConn, uuid.New(), agentID)
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
return conn.UpdateNodes(node)
return conn.UpdateNodes(node, false)
})
conn.SetNodeCallback(sendNode)
return &codersdk.WorkspaceAgentConn{
agentConn := &codersdk.WorkspaceAgentConn{
Conn: conn,
}
t.Cleanup(func() {
_ = agentConn.Close()
})
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
if !agentConn.AwaitReachable(ctx) {
t.Fatal("agent not reachable")
}
return agentConn
}

type client struct {
Expand Down
22 changes: 15 additions & 7 deletions codersdk/workspaceagents.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ type DialWorkspaceAgentOptions struct {
BlockEndpoints bool
}

func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *DialWorkspaceAgentOptions) (*WorkspaceAgentConn, error) {
func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *DialWorkspaceAgentOptions) (agentConn *WorkspaceAgentConn, err error) {
if options == nil {
options = &DialWorkspaceAgentOptions{}
}
Expand Down Expand Up @@ -128,6 +128,11 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti
if err != nil {
return nil, xerrors.Errorf("create tailnet: %w", err)
}
defer func() {
if err != nil {
_ = conn.Close()
}
}()

coordinateURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/workspaceagents/%s/coordinate", agentID))
if err != nil {
Expand All @@ -145,7 +150,12 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti
Jar: jar,
Transport: c.HTTPClient.Transport,
}
ctx, cancelFunc := context.WithCancel(ctx)
ctx, cancel := context.WithCancel(ctx)
defer func() {
if err != nil {
cancel()
}
}()
closed := make(chan struct{})
first := make(chan error)
go func() {
Expand Down Expand Up @@ -175,7 +185,7 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti
continue
}
sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(node []*tailnet.Node) error {
return conn.UpdateNodes(node)
return conn.UpdateNodes(node, false)
})
conn.SetNodeCallback(sendNode)
options.Logger.Debug(ctx, "serving coordinator")
Expand All @@ -194,15 +204,13 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti
}()
err = <-first
if err != nil {
cancelFunc()
_ = conn.Close()
return nil, err
}

agentConn := &WorkspaceAgentConn{
agentConn = &WorkspaceAgentConn{
Conn: conn,
CloseFunc: func() {
cancelFunc()
cancel()
<-closed
},
}
Expand Down
31 changes: 22 additions & 9 deletions tailnet/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func NewConn(options *Options) (conn *Conn, err error) {
}()

dialer := &tsdial.Dialer{
Logf: Logger(options.Logger),
Logf: Logger(options.Logger.Named("tsdial")),
}
wireguardEngine, err := wgengine.NewUserspaceEngine(Logger(options.Logger.Named("wgengine")), wgengine.Config{
LinkMonitor: wireguardMonitor,
Expand Down Expand Up @@ -179,6 +179,7 @@ func NewConn(options *Options) (conn *Conn, err error) {
wireguardEngine = wgengine.NewWatchdog(wireguardEngine)
wireguardEngine.SetDERPMap(options.DERPMap)
netMapCopy := *netMap
options.Logger.Debug(context.Background(), "updating network map", slog.F("net_map", netMapCopy))
wireguardEngine.SetNetworkMap(&netMapCopy)

localIPSet := netipx.IPSetBuilder{}
Expand Down Expand Up @@ -329,9 +330,11 @@ func (c *Conn) SetDERPMap(derpMap *tailcfg.DERPMap) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.logger.Debug(context.Background(), "updating derp map", slog.F("derp_map", derpMap))
c.netMap.DERPMap = derpMap
c.wireguardEngine.SetNetworkMap(c.netMap)
c.wireguardEngine.SetDERPMap(derpMap)
c.netMap.DERPMap = derpMap
netMapCopy := *c.netMap
c.logger.Debug(context.Background(), "updating network map", slog.F("net_map", netMapCopy))
c.wireguardEngine.SetNetworkMap(&netMapCopy)
}

func (c *Conn) RemoveAllPeers() error {
Expand All @@ -341,6 +344,7 @@ func (c *Conn) RemoveAllPeers() error {
c.netMap.Peers = []*tailcfg.Node{}
c.peerMap = map[tailcfg.NodeID]*tailcfg.Node{}
netMapCopy := *c.netMap
c.logger.Debug(context.Background(), "updating network map", slog.F("net_map", netMapCopy))
c.wireguardEngine.SetNetworkMap(&netMapCopy)
cfg, err := nmcfg.WGCfg(c.netMap, Logger(c.logger.Named("wgconfig")), netmap.AllowSingleHosts, "")
if err != nil {
Expand All @@ -360,11 +364,18 @@ func (c *Conn) RemoveAllPeers() error {
}

// UpdateNodes connects with a set of peers. This can be constantly updated,
// and peers will continually be reconnected as necessary.
func (c *Conn) UpdateNodes(nodes []*Node) error {
// and peers will continually be reconnected as necessary. If replacePeers is
// true, all peers will be removed before adding the new ones.
//
//nolint:revive // Complains about replacePeers.
func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error {
c.mutex.Lock()
defer c.mutex.Unlock()
status := c.Status()
if replacePeers {
c.netMap.Peers = []*tailcfg.Node{}
c.peerMap = map[tailcfg.NodeID]*tailcfg.Node{}
}
for _, peer := range c.netMap.Peers {
peerStatus, ok := status.Peer[peer.Key]
if !ok {
Expand All @@ -384,6 +395,11 @@ func (c *Conn) UpdateNodes(nodes []*Node) error {
delete(c.peerMap, peer.ID)
}
for _, node := range nodes {
// If no preferred DERP is provided, we can't reach the node.
if node.PreferredDERP == 0 {
c.logger.Debug(context.Background(), "no preferred DERP, skipping node", slog.F("node", node))
continue
}
c.logger.Debug(context.Background(), "adding node", slog.F("node", node))

peerStatus, ok := status.Peer[node.Key]
Expand All @@ -402,10 +418,6 @@ func (c *Conn) UpdateNodes(nodes []*Node) error {
// reason. TODO: @kylecarbs debug this!
KeepAlive: ok && peerStatus.Active,
}
// If no preferred DERP is provided, don't set an IP!
if node.PreferredDERP == 0 {
peerNode.DERP = ""
}
if c.blockEndpoints {
peerNode.Endpoints = nil
}
Expand All @@ -416,6 +428,7 @@ func (c *Conn) UpdateNodes(nodes []*Node) error {
c.netMap.Peers = append(c.netMap.Peers, peer.Clone())
}
netMapCopy := *c.netMap
c.logger.Debug(context.Background(), "updating network map", slog.F("net_map", netMapCopy))
c.wireguardEngine.SetNetworkMap(&netMapCopy)
cfg, err := nmcfg.WGCfg(c.netMap, Logger(c.logger.Named("wgconfig")), netmap.AllowSingleHosts, "")
if err != nil {
Expand Down
8 changes: 4 additions & 4 deletions tailnet/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@ func TestTailnet(t *testing.T) {
_ = w2.Close()
})
w1.SetNodeCallback(func(node *tailnet.Node) {
err := w2.UpdateNodes([]*tailnet.Node{node})
require.NoError(t, err)
err := w2.UpdateNodes([]*tailnet.Node{node}, false)
assert.NoError(t, err)
})
w2.SetNodeCallback(func(node *tailnet.Node) {
err := w1.UpdateNodes([]*tailnet.Node{node})
require.NoError(t, err)
err := w1.UpdateNodes([]*tailnet.Node{node}, false)
assert.NoError(t, err)
})
require.True(t, w2.AwaitReachable(context.Background(), w1IP))
conn := make(chan struct{})
Expand Down