diff --git a/tailnet/conn.go b/tailnet/conn.go index 4a10d0df8dba0..61eb6db17ba87 100644 --- a/tailnet/conn.go +++ b/tailnet/conn.go @@ -173,7 +173,10 @@ func NewConn(options *Options) (*Conn, error) { logIPSet := netipx.IPSetBuilder{} logIPs, _ := logIPSet.IPSet() wireguardEngine.SetFilter(filter.New(netMap.PacketFilter, localIPs, logIPs, nil, Logger(options.Logger.Named("packet-filter")))) + dialContext, dialCancel := context.WithCancel(context.Background()) server := &Conn{ + dialContext: dialContext, + dialCancel: dialCancel, closed: make(chan struct{}), logger: options.Logger, magicConn: magicConn, @@ -229,9 +232,11 @@ func IP() netip.Addr { // Conn is an actively listening Wireguard connection. type Conn struct { - mutex sync.Mutex - closed chan struct{} - logger slog.Logger + dialContext context.Context + dialCancel context.CancelFunc + mutex sync.Mutex + closed chan struct{} + logger slog.Logger dialer *tsdial.Dialer tunDevice *tstun.Wrapper @@ -378,6 +383,7 @@ func (c *Conn) Close() error { _ = l.closeNoLock() } c.mutex.Unlock() + c.dialCancel() _ = c.dialer.Close() _ = c.magicConn.Close() _ = c.netStack.Close() @@ -500,15 +506,12 @@ func (c *Conn) forwardTCP(conn net.Conn, port uint16) { } func (c *Conn) forwardTCPToLocal(conn net.Conn, port uint16) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() defer conn.Close() - dialAddrStr := net.JoinHostPort("127.0.0.1", strconv.Itoa(int(port))) var stdDialer net.Dialer - server, err := stdDialer.DialContext(ctx, "tcp", dialAddrStr) + server, err := stdDialer.DialContext(c.dialContext, "tcp", dialAddrStr) if err != nil { - c.logger.Debug(ctx, "dial local port", slog.F("port", port), slog.Error(err)) + c.logger.Debug(c.dialContext, "dial local port", slog.F("port", port), slog.Error(err)) return } defer server.Close() @@ -528,9 +531,9 @@ func (c *Conn) forwardTCPToLocal(conn net.Conn, port uint16) { return } if err != nil { - c.logger.Debug(ctx, "proxy connection closed with error", slog.Error(err)) + c.logger.Debug(c.dialContext, "proxy connection closed with error", slog.Error(err)) } - c.logger.Debug(ctx, "forwarded connection closed", slog.F("local_addr", dialAddrStr)) + c.logger.Debug(c.dialContext, "forwarded connection closed", slog.F("local_addr", dialAddrStr)) } type listenKey struct {