Skip to content

Commit 99013b3

Browse files
authored
chore: Close dials in tailnet conn on close (#4174)
Fixes a race seen in: https://github.com/coder/coder/actions/runs/3114263658/jobs/5049905647
1 parent 8cd5aea commit 99013b3

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

tailnet/conn.go

+13-10
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,10 @@ func NewConn(options *Options) (*Conn, error) {
173173
logIPSet := netipx.IPSetBuilder{}
174174
logIPs, _ := logIPSet.IPSet()
175175
wireguardEngine.SetFilter(filter.New(netMap.PacketFilter, localIPs, logIPs, nil, Logger(options.Logger.Named("packet-filter"))))
176+
dialContext, dialCancel := context.WithCancel(context.Background())
176177
server := &Conn{
178+
dialContext: dialContext,
179+
dialCancel: dialCancel,
177180
closed: make(chan struct{}),
178181
logger: options.Logger,
179182
magicConn: magicConn,
@@ -229,9 +232,11 @@ func IP() netip.Addr {
229232

230233
// Conn is an actively listening Wireguard connection.
231234
type Conn struct {
232-
mutex sync.Mutex
233-
closed chan struct{}
234-
logger slog.Logger
235+
dialContext context.Context
236+
dialCancel context.CancelFunc
237+
mutex sync.Mutex
238+
closed chan struct{}
239+
logger slog.Logger
235240

236241
dialer *tsdial.Dialer
237242
tunDevice *tstun.Wrapper
@@ -378,6 +383,7 @@ func (c *Conn) Close() error {
378383
_ = l.closeNoLock()
379384
}
380385
c.mutex.Unlock()
386+
c.dialCancel()
381387
_ = c.dialer.Close()
382388
_ = c.magicConn.Close()
383389
_ = c.netStack.Close()
@@ -500,15 +506,12 @@ func (c *Conn) forwardTCP(conn net.Conn, port uint16) {
500506
}
501507

502508
func (c *Conn) forwardTCPToLocal(conn net.Conn, port uint16) {
503-
ctx, cancel := context.WithCancel(context.Background())
504-
defer cancel()
505509
defer conn.Close()
506-
507510
dialAddrStr := net.JoinHostPort("127.0.0.1", strconv.Itoa(int(port)))
508511
var stdDialer net.Dialer
509-
server, err := stdDialer.DialContext(ctx, "tcp", dialAddrStr)
512+
server, err := stdDialer.DialContext(c.dialContext, "tcp", dialAddrStr)
510513
if err != nil {
511-
c.logger.Debug(ctx, "dial local port", slog.F("port", port), slog.Error(err))
514+
c.logger.Debug(c.dialContext, "dial local port", slog.F("port", port), slog.Error(err))
512515
return
513516
}
514517
defer server.Close()
@@ -528,9 +531,9 @@ func (c *Conn) forwardTCPToLocal(conn net.Conn, port uint16) {
528531
return
529532
}
530533
if err != nil {
531-
c.logger.Debug(ctx, "proxy connection closed with error", slog.Error(err))
534+
c.logger.Debug(c.dialContext, "proxy connection closed with error", slog.Error(err))
532535
}
533-
c.logger.Debug(ctx, "forwarded connection closed", slog.F("local_addr", dialAddrStr))
536+
c.logger.Debug(c.dialContext, "forwarded connection closed", slog.F("local_addr", dialAddrStr))
534537
}
535538

536539
type listenKey struct {

0 commit comments

Comments
 (0)