Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fix(tailnet): Improve start and close to detect connection races
  • Loading branch information
mafredri committed Feb 21, 2023
commit 03b09175eb90a810f3c3bfe9316c77fed5514c6b
62 changes: 46 additions & 16 deletions tailnet/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ type Options struct {
}

// NewConn constructs a new Wireguard server that will accept connections from the addresses provided.
func NewConn(options *Options) (*Conn, error) {
func NewConn(options *Options) (conn *Conn, err error) {
if options == nil {
options = &Options{}
}
Expand Down Expand Up @@ -123,6 +123,11 @@ func NewConn(options *Options) (*Conn, error) {
if err != nil {
return nil, xerrors.Errorf("create wireguard link monitor: %w", err)
}
defer func() {
if err != nil {
wireguardMonitor.Close()
}
}()

dialer := &tsdial.Dialer{
Logf: Logger(options.Logger),
Expand All @@ -134,6 +139,11 @@ func NewConn(options *Options) (*Conn, error) {
if err != nil {
return nil, xerrors.Errorf("create wgengine: %w", err)
}
defer func() {
if err != nil {
wireguardEngine.Close()
}
}()
dialer.UseNetstackForIP = func(ip netip.Addr) bool {
_, ok := wireguardEngine.PeerForIP(ip)
return ok
Expand Down Expand Up @@ -166,10 +176,6 @@ func NewConn(options *Options) (*Conn, error) {
return netStack.DialContextTCP(ctx, dst)
}
netStack.ProcessLocalIPs = true
err = netStack.Start(nil)
if err != nil {
return nil, xerrors.Errorf("start netstack: %w", err)
}
wireguardEngine = wgengine.NewWatchdog(wireguardEngine)
wireguardEngine.SetDERPMap(options.DERPMap)
netMapCopy := *netMap
Expand Down Expand Up @@ -203,6 +209,11 @@ func NewConn(options *Options) (*Conn, error) {
},
wireguardEngine: wireguardEngine,
}
defer func() {
if err != nil {
_ = server.Close()
}
}()
wireguardEngine.SetStatusCallback(func(s *wgengine.Status, err error) {
server.logger.Debug(context.Background(), "wireguard status", slog.F("status", s), slog.F("err", err))
if err != nil {
Expand Down Expand Up @@ -236,6 +247,12 @@ func NewConn(options *Options) (*Conn, error) {
server.sendNode()
})
netStack.ForwardTCPIn = server.forwardTCP

err = netStack.Start(nil)
if err != nil {
return nil, xerrors.Errorf("start netstack: %w", err)
}

return server, nil
}

Expand Down Expand Up @@ -519,22 +536,35 @@ func (c *Conn) Close() error {
default:
}
close(c.closed)
for _, l := range c.listeners {
_ = l.closeNoLock()
}
c.mutex.Unlock()
c.dialCancel()
_ = c.dialer.Close()
_ = c.magicConn.Close()

var wg sync.WaitGroup
defer wg.Wait()

if c.trafficStats != nil {
wg.Add(1)
go func() {
defer wg.Done()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = c.trafficStats.Shutdown(ctx)
}()
}

_ = c.netStack.Close()
c.dialCancel()
_ = c.wireguardMonitor.Close()
_ = c.tunDevice.Close()
_ = c.dialer.Close()
// Stops internals, e.g. tunDevice, magicConn and dnsManager.
c.wireguardEngine.Close()
if c.trafficStats != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = c.trafficStats.Shutdown(ctx)

c.mutex.Lock()
for _, l := range c.listeners {
_ = l.closeNoLock()
}
c.listeners = nil
c.mutex.Unlock()
Copy link
Member Author

@mafredri mafredri Feb 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I doubt these changes are necessary, I simply tried to mimic the order a bit closer to what's done in: https://github.com/tailscale/tailscale/blob/cd18bb68a49608b86f60e22cb081f0156d3d11b5/tsnet/tsnet.go#L152-L192


return nil
}

Expand Down