diff --git a/netconn.go b/netconn.go index 64aadf0b..0f9e7102 100644 --- a/netconn.go +++ b/netconn.go @@ -6,7 +6,9 @@ import ( "io" "math" "net" + "os" "sync" + "sync/atomic" "time" ) @@ -43,15 +45,26 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { msgType: msgType, } - var cancel context.CancelFunc - nc.writeContext, cancel = context.WithCancel(ctx) - nc.writeTimer = time.AfterFunc(math.MaxInt64, cancel) + var writeCancel context.CancelFunc + nc.writeContext, writeCancel = context.WithCancel(ctx) + nc.writeTimer = time.AfterFunc(math.MaxInt64, func() { + nc.afterWriteDeadline.Store(true) + if nc.writing.Load() { + writeCancel() + } + }) if !nc.writeTimer.Stop() { <-nc.writeTimer.C } - nc.readContext, cancel = context.WithCancel(ctx) - nc.readTimer = time.AfterFunc(math.MaxInt64, cancel) + var readCancel context.CancelFunc + nc.readContext, readCancel = context.WithCancel(ctx) + nc.readTimer = time.AfterFunc(math.MaxInt64, func() { + nc.afterReadDeadline.Store(true) + if nc.reading.Load() { + readCancel() + } + }) if !nc.readTimer.Stop() { <-nc.readTimer.C } @@ -63,11 +76,15 @@ type netConn struct { c *Conn msgType MessageType - writeTimer *time.Timer - writeContext context.Context + writeTimer *time.Timer + writeContext context.Context + writing atomic.Bool + afterWriteDeadline atomic.Bool - readTimer *time.Timer - readContext context.Context + readTimer *time.Timer + readContext context.Context + reading atomic.Bool + afterReadDeadline atomic.Bool readMu sync.Mutex eofed bool @@ -81,16 +98,34 @@ func (c *netConn) Close() error { } func (c *netConn) Write(p []byte) (int, error) { + if c.afterWriteDeadline.Load() { + return 0, os.ErrDeadlineExceeded + } + + if swapped := c.writing.CompareAndSwap(false, true); !swapped { + panic("Concurrent writes not allowed") + } + defer c.writing.Store(false) + err := c.c.Write(c.writeContext, c.msgType, p) if err != nil { return 0, err } + return len(p), nil } func (c *netConn) Read(p []byte) (int, error) { + if c.afterReadDeadline.Load() { + return 0, os.ErrDeadlineExceeded + } + c.readMu.Lock() defer c.readMu.Unlock() + if swapped := c.reading.CompareAndSwap(false, true); !swapped { + panic("Concurrent reads not allowed") + } + defer c.reading.Store(false) if c.eofed { return 0, io.EOF @@ -151,8 +186,9 @@ func (c *netConn) SetWriteDeadline(t time.Time) error { if t.IsZero() { c.writeTimer.Stop() } else { - c.writeTimer.Reset(t.Sub(time.Now())) + c.writeTimer.Reset(time.Until(t)) } + c.afterWriteDeadline.Store(false) return nil } @@ -160,7 +196,8 @@ func (c *netConn) SetReadDeadline(t time.Time) error { if t.IsZero() { c.readTimer.Stop() } else { - c.readTimer.Reset(t.Sub(time.Now())) + c.readTimer.Reset(time.Until(t)) } + c.afterReadDeadline.Store(false) return nil }