diff --git a/session.go b/session.go index 046a3d3..ae81cff 100644 --- a/session.go +++ b/session.go @@ -2,6 +2,7 @@ package yamux import ( "bufio" + "bytes" "fmt" "io" "io/ioutil" @@ -63,7 +64,7 @@ type Session struct { // sendCh is used to mark a stream as ready to send, // or to send a header out directly. - sendCh chan sendReady + sendCh chan *sendReady // recvDoneCh is closed when recv() exits to avoid a race // between stream registration and stream shutdown @@ -80,6 +81,7 @@ type Session struct { // or to directly send a header type sendReady struct { Hdr []byte + mu sync.Mutex // Protects Body from unsafe reads. Body []byte Err chan error } @@ -101,7 +103,7 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { inflight: make(map[uint32]struct{}), synCh: make(chan struct{}, config.AcceptBacklog), acceptCh: make(chan *Stream, config.AcceptBacklog), - sendCh: make(chan sendReady, 64), + sendCh: make(chan *sendReady, 64), recvDoneCh: make(chan struct{}), shutdownCh: make(chan struct{}), } @@ -373,7 +375,7 @@ func (s *Session) waitForSendErr(hdr header, body []byte, errCh chan error) erro timerPool.Put(t) }() - ready := sendReady{Hdr: hdr, Body: body, Err: errCh} + ready := &sendReady{Hdr: hdr, Body: body, Err: errCh} select { case s.sendCh <- ready: case <-s.shutdownCh: @@ -382,12 +384,34 @@ func (s *Session) waitForSendErr(hdr header, body []byte, errCh chan error) erro return ErrConnectionWriteTimeout } + bodyCopy := func() { + if body == nil { + return // A nil body is ignored. + } + + // In the event of session shutdown or connection write timeout, + // we need to prevent `send` from reading the body buffer after + // returning from this function since the caller may re-use the + // underlying array. + ready.mu.Lock() + defer ready.mu.Unlock() + + if ready.Body == nil { + return // Body was already copied in `send`. + } + newBody := make([]byte, len(body)) + copy(newBody, body) + ready.Body = newBody + } + select { case err := <-errCh: return err case <-s.shutdownCh: + bodyCopy() return ErrSessionShutdown case <-timer.C: + bodyCopy() return ErrConnectionWriteTimeout } } @@ -409,7 +433,7 @@ func (s *Session) sendNoWait(hdr header) error { }() select { - case s.sendCh <- sendReady{Hdr: hdr}: + case s.sendCh <- &sendReady{Hdr: hdr}: return nil case <-s.shutdownCh: return ErrSessionShutdown @@ -420,7 +444,10 @@ func (s *Session) sendNoWait(hdr header) error { // send is a long running goroutine that sends data func (s *Session) send() { + var bodyBuf bytes.Buffer for { + bodyBuf.Reset() + select { case ready := <-s.sendCh: // Send a header if ready @@ -438,9 +465,26 @@ func (s *Session) send() { } } - // Send data from a body if given + ready.mu.Lock() if ready.Body != nil { - _, err := s.conn.Write(ready.Body) + // Copy the body into the buffer to avoid + // holding a mutex lock during the write. + _, err := bodyBuf.Write(ready.Body) + if err != nil { + ready.Body = nil + ready.mu.Unlock() + s.logger.Printf("[ERR] yamux: Failed to copy body into buffer: %v", err) + asyncSendErr(ready.Err, err) + s.exitErr(err) + return + } + ready.Body = nil + } + ready.mu.Unlock() + + if bodyBuf.Len() > 0 { + // Send data from a body if given + _, err := s.conn.Write(bodyBuf.Bytes()) if err != nil { s.logger.Printf("[ERR] yamux: Failed to write body: %v", err) asyncSendErr(ready.Err, err) @@ -639,8 +683,9 @@ func (s *Session) incomingStream(id uint32) error { // Backlog exceeded! RST the stream s.logger.Printf("[WARN] yamux: backlog exceeded, forcing connection reset") delete(s.streams, id) - stream.sendHdr.encode(typeWindowUpdate, flagRST, id, 0) - return s.sendNoWait(stream.sendHdr) + hdr := header(make([]byte, headerSize)) + hdr.encode(typeWindowUpdate, flagRST, id, 0) + return s.sendNoWait(hdr) } } diff --git a/stream.go b/stream.go index f444bdc..d197d28 100644 --- a/stream.go +++ b/stream.go @@ -2,6 +2,7 @@ package yamux import ( "bytes" + "errors" "io" "sync" "sync/atomic" @@ -200,6 +201,10 @@ START: // Send the header s.sendHdr.encode(typeData, flags, s.id, max) if err = s.session.waitForSendErr(s.sendHdr, body, s.sendErr); err != nil { + if errors.Is(err, ErrSessionShutdown) || errors.Is(err, ErrConnectionWriteTimeout) { + // Message left in ready queue, header re-use is unsafe. + s.sendHdr = header(make([]byte, headerSize)) + } return 0, err } @@ -273,6 +278,10 @@ func (s *Stream) sendWindowUpdate() error { // Send the header s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta) if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil { + if errors.Is(err, ErrSessionShutdown) || errors.Is(err, ErrConnectionWriteTimeout) { + // Message left in ready queue, header re-use is unsafe. + s.controlHdr = header(make([]byte, headerSize)) + } return err } return nil @@ -287,6 +296,10 @@ func (s *Stream) sendClose() error { flags |= flagFIN s.controlHdr.encode(typeWindowUpdate, flags, s.id, 0) if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil { + if errors.Is(err, ErrSessionShutdown) || errors.Is(err, ErrConnectionWriteTimeout) { + // Message left in ready queue, header re-use is unsafe. + s.controlHdr = header(make([]byte, headerSize)) + } return err } return nil @@ -362,8 +375,9 @@ func (s *Stream) closeTimeout() { // Send a RST so the remote side closes too. s.sendLock.Lock() defer s.sendLock.Unlock() - s.sendHdr.encode(typeWindowUpdate, flagRST, s.id, 0) - s.session.sendNoWait(s.sendHdr) + hdr := header(make([]byte, headerSize)) + hdr.encode(typeWindowUpdate, flagRST, s.id, 0) + s.session.sendNoWait(hdr) } // forceClose is used for when the session is exiting @@ -465,6 +479,7 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error { if length > s.recvWindow { s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length) + s.recvLock.Unlock() return ErrRecvWindowExceeded }