Skip to content

fix(tailnet): Improve tailnet setup and agentconn stability #6292

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions coderd/coderd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net/http"
"net/netip"
"strconv"
"sync"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -78,8 +79,14 @@ func TestDERP(t *testing.T) {
DERPMap: derpMap,
})
require.NoError(t, err)

w2Ready := make(chan struct{}, 1)
w2ReadyOnce := sync.Once{}
w1.SetNodeCallback(func(node *tailnet.Node) {
w2.UpdateNodes([]*tailnet.Node{node})
w2ReadyOnce.Do(func() {
close(w2Ready)
})
})
w2.SetNodeCallback(func(node *tailnet.Node) {
w1.UpdateNodes([]*tailnet.Node{node})
Expand All @@ -98,6 +105,7 @@ func TestDERP(t *testing.T) {
}()

<-conn
<-w2Ready
nc, err := w2.DialContextTCP(context.Background(), netip.AddrPortFrom(w1IP, 35565))
require.NoError(t, err)
_ = nc.Close()
Expand Down
2 changes: 2 additions & 0 deletions coderd/workspaceagents_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,8 @@ func TestWorkspaceAgentListeningPorts(t *testing.T) {
t.Parallel()

setup := func(t *testing.T, apps []*proto.App) (*codersdk.Client, uint16, uuid.UUID) {
t.Helper()

client := coderdtest.New(t, &coderdtest.Options{
IncludeProvisionerDaemon: true,
})
Expand Down
11 changes: 11 additions & 0 deletions coderd/wsconncache/wsconncache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/coder/coder/codersdk/agentsdk"
"github.com/coder/coder/tailnet"
"github.com/coder/coder/tailnet/tailnettest"
"github.com/coder/coder/testutil"
)

func TestMain(m *testing.M) {
Expand Down Expand Up @@ -131,6 +132,14 @@ func TestCache(t *testing.T) {
return
}
defer release()

ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
if !conn.AwaitReachable(ctx) {
t.Error("agent not reachable")
return
}

transport := conn.HTTPTransport()
defer transport.CloseIdleConnections()
proxy.Transport = transport
Expand All @@ -146,6 +155,8 @@ func TestCache(t *testing.T) {
}

func setupAgent(t *testing.T, metadata agentsdk.Metadata, ptyTimeout time.Duration) *codersdk.WorkspaceAgentConn {
t.Helper()

metadata.DERPMap = tailnettest.RunDERPAndSTUN(t)

coordinator := tailnet.NewCoordinator()
Expand Down
21 changes: 18 additions & 3 deletions codersdk/workspaceagentconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,9 @@ type ReconnectingPTYRequest struct {
func (c *WorkspaceAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string) (net.Conn, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()

if !c.AwaitReachable(ctx) {
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
}
conn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentReconnectingPTYPort))
if err != nil {
return nil, err
Expand Down Expand Up @@ -207,6 +209,9 @@ func (c *WorkspaceAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID,
func (c *WorkspaceAgentConn) SSH(ctx context.Context) (net.Conn, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
if !c.AwaitReachable(ctx) {
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
}
return c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentSSHPort))
}

Expand Down Expand Up @@ -235,6 +240,9 @@ func (c *WorkspaceAgentConn) SSHClient(ctx context.Context) (*ssh.Client, error)
func (c *WorkspaceAgentConn) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
if !c.AwaitReachable(ctx) {
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
}
speedConn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentSpeedtestPort))
if err != nil {
return nil, xerrors.Errorf("dial speedtest: %w", err)
Expand All @@ -257,6 +265,9 @@ func (c *WorkspaceAgentConn) DialContext(ctx context.Context, network string, ad
_, rawPort, _ := net.SplitHostPort(addr)
port, _ := strconv.ParseUint(rawPort, 10, 16)
ipp := netip.AddrPortFrom(WorkspaceAgentIP, uint16(port))
if !c.AwaitReachable(ctx) {
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
}
if network == "udp" {
return c.Conn.DialContextUDP(ctx, ipp)
}
Expand Down Expand Up @@ -317,7 +328,7 @@ func (c *WorkspaceAgentConn) apiClient() *http.Client {
// Disable keep alives as we're usually only making a single
// request, and this triggers goleak in tests
DisableKeepAlives: true,
DialContext: func(_ context.Context, network, addr string) (net.Conn, error) {
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
if network != "tcp" {
return nil, xerrors.Errorf("network must be tcp")
}
Expand All @@ -331,7 +342,11 @@ func (c *WorkspaceAgentConn) apiClient() *http.Client {
return nil, xerrors.Errorf("request %q does not appear to be for http api", addr)
}

conn, err := c.DialContextTCP(context.Background(), netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentHTTPAPIServerPort))
if !c.AwaitReachable(ctx) {
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
}

conn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentHTTPAPIServerPort))
if err != nil {
return nil, xerrors.Errorf("dial http api: %w", err)
}
Expand Down
10 changes: 8 additions & 2 deletions codersdk/workspaceagents.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,13 +199,19 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti
return nil, err
}

return &WorkspaceAgentConn{
agentConn := &WorkspaceAgentConn{
Conn: conn,
CloseFunc: func() {
cancelFunc()
<-closed
},
}, nil
}
if !agentConn.AwaitReachable(ctx) {
_ = agentConn.Close()
return nil, xerrors.Errorf("timed out waiting for agent to become reachable: %w", ctx.Err())
}

return agentConn, nil
}

// WorkspaceAgent returns an agent by ID.
Expand Down
78 changes: 59 additions & 19 deletions tailnet/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ type Options struct {
}

// NewConn constructs a new Wireguard server that will accept connections from the addresses provided.
func NewConn(options *Options) (*Conn, error) {
func NewConn(options *Options) (conn *Conn, err error) {
if options == nil {
options = &Options{}
}
Expand Down Expand Up @@ -123,6 +123,11 @@ func NewConn(options *Options) (*Conn, error) {
if err != nil {
return nil, xerrors.Errorf("create wireguard link monitor: %w", err)
}
defer func() {
if err != nil {
wireguardMonitor.Close()
}
}()

dialer := &tsdial.Dialer{
Logf: Logger(options.Logger),
Expand All @@ -134,6 +139,11 @@ func NewConn(options *Options) (*Conn, error) {
if err != nil {
return nil, xerrors.Errorf("create wgengine: %w", err)
}
defer func() {
if err != nil {
wireguardEngine.Close()
}
}()
dialer.UseNetstackForIP = func(ip netip.Addr) bool {
_, ok := wireguardEngine.PeerForIP(ip)
return ok
Expand Down Expand Up @@ -166,10 +176,6 @@ func NewConn(options *Options) (*Conn, error) {
return netStack.DialContextTCP(ctx, dst)
}
netStack.ProcessLocalIPs = true
err = netStack.Start(nil)
if err != nil {
return nil, xerrors.Errorf("start netstack: %w", err)
}
wireguardEngine = wgengine.NewWatchdog(wireguardEngine)
wireguardEngine.SetDERPMap(options.DERPMap)
netMapCopy := *netMap
Expand Down Expand Up @@ -203,6 +209,11 @@ func NewConn(options *Options) (*Conn, error) {
},
wireguardEngine: wireguardEngine,
}
defer func() {
if err != nil {
_ = server.Close()
}
}()
wireguardEngine.SetStatusCallback(func(s *wgengine.Status, err error) {
server.logger.Debug(context.Background(), "wireguard status", slog.F("status", s), slog.F("err", err))
if err != nil {
Expand Down Expand Up @@ -236,6 +247,12 @@ func NewConn(options *Options) (*Conn, error) {
server.sendNode()
})
netStack.ForwardTCPIn = server.forwardTCP

err = netStack.Start(nil)
if err != nil {
return nil, xerrors.Errorf("start netstack: %w", err)
}

return server, nil
}

Expand Down Expand Up @@ -519,22 +536,35 @@ func (c *Conn) Close() error {
default:
}
close(c.closed)
for _, l := range c.listeners {
_ = l.closeNoLock()
}
c.mutex.Unlock()
c.dialCancel()
_ = c.dialer.Close()
_ = c.magicConn.Close()

var wg sync.WaitGroup
defer wg.Wait()

if c.trafficStats != nil {
wg.Add(1)
go func() {
defer wg.Done()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = c.trafficStats.Shutdown(ctx)
}()
}

_ = c.netStack.Close()
c.dialCancel()
_ = c.wireguardMonitor.Close()
_ = c.tunDevice.Close()
_ = c.dialer.Close()
// Stops internals, e.g. tunDevice, magicConn and dnsManager.
c.wireguardEngine.Close()
if c.trafficStats != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = c.trafficStats.Shutdown(ctx)

c.mutex.Lock()
for _, l := range c.listeners {
_ = l.closeNoLock()
}
c.listeners = nil
c.mutex.Unlock()
Copy link
Member Author

@mafredri mafredri Feb 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I doubt these changes are necessary, I simply tried to mimic the order a bit closer to what's done in: https://github.com/tailscale/tailscale/blob/cd18bb68a49608b86f60e22cb081f0156d3d11b5/tsnet/tsnet.go#L152-L192


return nil
}

Expand Down Expand Up @@ -714,16 +744,25 @@ func (c *Conn) forwardTCPToLocal(conn net.Conn, port uint16) {
func (c *Conn) SetConnStatsCallback(maxPeriod time.Duration, maxConns int, dump func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts)) {
connStats := connstats.NewStatistics(maxPeriod, maxConns, dump)

shutdown := func(s *connstats.Statistics) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.Shutdown(ctx)
}

c.mutex.Lock()
if c.isClosed() {
c.mutex.Unlock()
shutdown(connStats)
return
}
old := c.trafficStats
c.trafficStats = connStats
c.mutex.Unlock()

// Make sure to shutdown the old callback.
if old != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = old.Shutdown(ctx)
shutdown(old)
}

c.tunDevice.SetStatistics(connStats)
Expand Down Expand Up @@ -776,6 +815,7 @@ func (a addr) String() string { return a.ln.addr }
// Logger converts the Tailscale logging function to use slog.
func Logger(logger slog.Logger) tslogger.Logf {
return tslogger.Logf(func(format string, args ...any) {
slog.Helper()
logger.Debug(context.Background(), fmt.Sprintf(format, args...))
})
}
Expand Down