From a37705c436f670ddfc62b94c99e5d252f7f96899 Mon Sep 17 00:00:00 2001 From: Mihai Parparita Date: Fri, 14 Oct 2022 12:28:39 -0700 Subject: [PATCH] Fix net.Conn deadlines tearing down the connection when no reads or writes are active We were unconditionally canceling the read/write contexts when the deadline timer fired, even if there were no active reads or writes. We instead only cancel the context if there is an active operation, otherwise we set a flag so that future calls (without the deadline being reset) will fail. Updates tailscale/tailscale#5921 --- netconn.go | 59 ++++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 48 insertions(+), 11 deletions(-) diff --git a/netconn.go b/netconn.go index 64aadf0b..0f9e7102 100644 --- a/netconn.go +++ b/netconn.go @@ -6,7 +6,9 @@ import ( "io" "math" "net" + "os" "sync" + "sync/atomic" "time" ) @@ -43,15 +45,26 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { msgType: msgType, } - var cancel context.CancelFunc - nc.writeContext, cancel = context.WithCancel(ctx) - nc.writeTimer = time.AfterFunc(math.MaxInt64, cancel) + var writeCancel context.CancelFunc + nc.writeContext, writeCancel = context.WithCancel(ctx) + nc.writeTimer = time.AfterFunc(math.MaxInt64, func() { + nc.afterWriteDeadline.Store(true) + if nc.writing.Load() { + writeCancel() + } + }) if !nc.writeTimer.Stop() { <-nc.writeTimer.C } - nc.readContext, cancel = context.WithCancel(ctx) - nc.readTimer = time.AfterFunc(math.MaxInt64, cancel) + var readCancel context.CancelFunc + nc.readContext, readCancel = context.WithCancel(ctx) + nc.readTimer = time.AfterFunc(math.MaxInt64, func() { + nc.afterReadDeadline.Store(true) + if nc.reading.Load() { + readCancel() + } + }) if !nc.readTimer.Stop() { <-nc.readTimer.C } @@ -63,11 +76,15 @@ type netConn struct { c *Conn msgType MessageType - writeTimer *time.Timer - writeContext context.Context + writeTimer *time.Timer + writeContext context.Context + writing atomic.Bool + afterWriteDeadline atomic.Bool - readTimer *time.Timer - readContext context.Context + readTimer *time.Timer + readContext context.Context + reading atomic.Bool + afterReadDeadline atomic.Bool readMu sync.Mutex eofed bool @@ -81,16 +98,34 @@ func (c *netConn) Close() error { } func (c *netConn) Write(p []byte) (int, error) { + if c.afterWriteDeadline.Load() { + return 0, os.ErrDeadlineExceeded + } + + if swapped := c.writing.CompareAndSwap(false, true); !swapped { + panic("Concurrent writes not allowed") + } + defer c.writing.Store(false) + err := c.c.Write(c.writeContext, c.msgType, p) if err != nil { return 0, err } + return len(p), nil } func (c *netConn) Read(p []byte) (int, error) { + if c.afterReadDeadline.Load() { + return 0, os.ErrDeadlineExceeded + } + c.readMu.Lock() defer c.readMu.Unlock() + if swapped := c.reading.CompareAndSwap(false, true); !swapped { + panic("Concurrent reads not allowed") + } + defer c.reading.Store(false) if c.eofed { return 0, io.EOF @@ -151,8 +186,9 @@ func (c *netConn) SetWriteDeadline(t time.Time) error { if t.IsZero() { c.writeTimer.Stop() } else { - c.writeTimer.Reset(t.Sub(time.Now())) + c.writeTimer.Reset(time.Until(t)) } + c.afterWriteDeadline.Store(false) return nil } @@ -160,7 +196,8 @@ func (c *netConn) SetReadDeadline(t time.Time) error { if t.IsZero() { c.readTimer.Stop() } else { - c.readTimer.Reset(t.Sub(time.Now())) + c.readTimer.Reset(time.Until(t)) } + c.afterReadDeadline.Store(false) return nil }