Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions tailnet/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,10 @@ func NewConn(options *Options) (*Conn, error) {
logIPSet := netipx.IPSetBuilder{}
logIPs, _ := logIPSet.IPSet()
wireguardEngine.SetFilter(filter.New(netMap.PacketFilter, localIPs, logIPs, nil, Logger(options.Logger.Named("packet-filter"))))
dialContext, dialCancel := context.WithCancel(context.Background())
server := &Conn{
dialContext: dialContext,
dialCancel: dialCancel,
closed: make(chan struct{}),
logger: options.Logger,
magicConn: magicConn,
Expand Down Expand Up @@ -229,9 +232,11 @@ func IP() netip.Addr {

// Conn is an actively listening Wireguard connection.
type Conn struct {
mutex sync.Mutex
closed chan struct{}
logger slog.Logger
dialContext context.Context
dialCancel context.CancelFunc
mutex sync.Mutex
closed chan struct{}
logger slog.Logger

dialer *tsdial.Dialer
tunDevice *tstun.Wrapper
Expand Down Expand Up @@ -378,6 +383,7 @@ func (c *Conn) Close() error {
_ = l.closeNoLock()
}
c.mutex.Unlock()
c.dialCancel()
_ = c.dialer.Close()
_ = c.magicConn.Close()
_ = c.netStack.Close()
Expand Down Expand Up @@ -500,15 +506,12 @@ func (c *Conn) forwardTCP(conn net.Conn, port uint16) {
}

func (c *Conn) forwardTCPToLocal(conn net.Conn, port uint16) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
defer conn.Close()

dialAddrStr := net.JoinHostPort("127.0.0.1", strconv.Itoa(int(port)))
var stdDialer net.Dialer
server, err := stdDialer.DialContext(ctx, "tcp", dialAddrStr)
server, err := stdDialer.DialContext(c.dialContext, "tcp", dialAddrStr)
if err != nil {
c.logger.Debug(ctx, "dial local port", slog.F("port", port), slog.Error(err))
c.logger.Debug(c.dialContext, "dial local port", slog.F("port", port), slog.Error(err))
return
}
defer server.Close()
Expand All @@ -528,9 +531,9 @@ func (c *Conn) forwardTCPToLocal(conn net.Conn, port uint16) {
return
}
if err != nil {
c.logger.Debug(ctx, "proxy connection closed with error", slog.Error(err))
c.logger.Debug(c.dialContext, "proxy connection closed with error", slog.Error(err))
}
c.logger.Debug(ctx, "forwarded connection closed", slog.F("local_addr", dialAddrStr))
c.logger.Debug(c.dialContext, "forwarded connection closed", slog.F("local_addr", dialAddrStr))
}

type listenKey struct {
Expand Down