Skip to content

Commit 1c2984a

Browse files
committed
Use context.AfterFunc() instead of a goroutine to track timeouts
1 parent 91013c1 commit 1c2984a

File tree

4 files changed

+38
-40
lines changed

4 files changed

+38
-40
lines changed

close.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -231,12 +231,6 @@ func (c *Conn) waitGoroutines() error {
231231
t := time.NewTimer(time.Second * 15)
232232
defer t.Stop()
233233

234-
select {
235-
case <-c.timeoutLoopDone:
236-
case <-t.C:
237-
return errors.New("failed to wait for timeoutLoop goroutine to exit")
238-
}
239-
240234
c.closeReadMu.Lock()
241235
closeRead := c.closeReadCtx != nil
242236
c.closeReadMu.Unlock()

conn.go

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,8 @@ type Conn struct {
5151
br *bufio.Reader
5252
bw *bufio.Writer
5353

54-
readTimeout chan context.Context
55-
writeTimeout chan context.Context
56-
timeoutLoopDone chan struct{}
54+
readTimeoutStop atomic.Pointer[func() bool]
55+
writeTimeoutStop atomic.Pointer[func() bool]
5756

5857
// Read state.
5958
readMu *mu
@@ -113,10 +112,6 @@ func newConn(cfg connConfig) *Conn {
113112
br: cfg.br,
114113
bw: cfg.bw,
115114

116-
readTimeout: make(chan context.Context),
117-
writeTimeout: make(chan context.Context),
118-
timeoutLoopDone: make(chan struct{}),
119-
120115
closed: make(chan struct{}),
121116
activePings: make(map[string]chan<- struct{}),
122117
onPingReceived: cfg.onPingReceived,
@@ -144,8 +139,6 @@ func newConn(cfg connConfig) *Conn {
144139
c.close()
145140
})
146141

147-
go c.timeoutLoop()
148-
149142
return c
150143
}
151144

@@ -175,27 +168,34 @@ func (c *Conn) close() error {
175168
return err
176169
}
177170

178-
func (c *Conn) timeoutLoop() {
179-
defer close(c.timeoutLoopDone)
171+
func (c *Conn) setupWriteTimeout(ctx context.Context) {
172+
stop := context.AfterFunc(ctx, func() {
173+
c.clearWriteTimeout()
174+
c.close()
175+
})
176+
swapTimeoutStop(&c.writeTimeoutStop, &stop)
177+
}
180178

181-
readCtx := context.Background()
182-
writeCtx := context.Background()
179+
func (c *Conn) clearWriteTimeout() {
180+
swapTimeoutStop(&c.writeTimeoutStop, nil)
181+
}
183182

184-
for {
185-
select {
186-
case <-c.closed:
187-
return
188-
189-
case writeCtx = <-c.writeTimeout:
190-
case readCtx = <-c.readTimeout:
191-
192-
case <-readCtx.Done():
193-
c.close()
194-
return
195-
case <-writeCtx.Done():
196-
c.close()
197-
return
198-
}
183+
func (c *Conn) setupReadTimeout(ctx context.Context) {
184+
stop := context.AfterFunc(ctx, func() {
185+
c.clearReadTimeout()
186+
c.close()
187+
})
188+
swapTimeoutStop(&c.readTimeoutStop, &stop)
189+
}
190+
191+
func (c *Conn) clearReadTimeout() {
192+
swapTimeoutStop(&c.readTimeoutStop, nil)
193+
}
194+
195+
func swapTimeoutStop(p *atomic.Pointer[func() bool], newStop *func() bool) {
196+
oldStop := p.Swap(newStop)
197+
if oldStop != nil {
198+
(*oldStop)()
199199
}
200200
}
201201

read.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -220,22 +220,24 @@ func (c *Conn) readLoop(ctx context.Context) (header, error) {
220220
// to be called after the read is done. It also returns an error if the
221221
// connection is closed. The reference to the error is used to assign
222222
// an error depending on if the connection closed or the context timed
223-
// out during use. Typically the referenced error is a named return
223+
// out during use. Typically, the referenced error is a named return
224224
// variable of the function calling this method.
225225
func (c *Conn) prepareRead(ctx context.Context, err *error) (func(), error) {
226226
select {
227227
case <-c.closed:
228228
return nil, net.ErrClosed
229-
case c.readTimeout <- ctx:
229+
default:
230230
}
231+
c.setupReadTimeout(ctx)
231232

232233
done := func() {
233234
select {
234235
case <-c.closed:
235236
if *err != nil {
236237
*err = net.ErrClosed
237238
}
238-
case c.readTimeout <- context.Background():
239+
default:
240+
c.clearReadTimeout()
239241
}
240242
if *err != nil && ctx.Err() != nil {
241243
*err = ctx.Err()
@@ -280,7 +282,7 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error
280282
return n, fmt.Errorf("failed to read frame payload: %w", err)
281283
}
282284

283-
return n, err
285+
return n, nil
284286
}
285287

286288
func (c *Conn) handleControl(ctx context.Context, h header) (err error) {

write.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,12 +271,14 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco
271271
select {
272272
case <-c.closed:
273273
return 0, net.ErrClosed
274-
case c.writeTimeout <- ctx:
274+
default:
275275
}
276+
c.setupWriteTimeout(ctx)
276277
defer func() {
277278
select {
278279
case <-c.closed:
279-
case c.writeTimeout <- context.Background():
280+
default:
281+
c.clearWriteTimeout()
280282
}
281283
}()
282284

0 commit comments

Comments
 (0)