Skip to content

Commit 9fc9f7a

Browse files
committed
Ensure message order with a buffer
1 parent 480d0eb commit 9fc9f7a

File tree

3 files changed

+54
-20
lines changed

3 files changed

+54
-20
lines changed

conn.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,6 @@ func (c *Conn) Subprotocol() string {
120120
return c.subprotocol
121121
}
122122

123-
func (c *Conn) setCloseErr(err error) {
124-
c.closeErrOnce.Do(func() {
125-
c.closeErr = fmt.Errorf("websocket closed: %w", err)
126-
})
127-
}
128-
129123
func (c *Conn) close(err error) {
130124
c.closeOnce.Do(func() {
131125
runtime.SetFinalizer(c, nil)

conn_common.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,9 @@ func (c *Conn) CloseRead(ctx context.Context) context.Context {
202202
func (c *Conn) SetReadLimit(n int64) {
203203
c.msgReadLimit = n
204204
}
205+
206+
func (c *Conn) setCloseErr(err error) {
207+
c.closeErrOnce.Do(func() {
208+
c.closeErr = fmt.Errorf("websocket closed: %w", err)
209+
})
210+
}

websocket_js.go

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,29 +23,32 @@ type Conn struct {
2323

2424
msgReadLimit int64
2525

26-
readClosed int64
27-
closeOnce sync.Once
28-
closed chan struct{}
29-
closeErr error
26+
readClosed int64
27+
closeOnce sync.Once
28+
closed chan struct{}
29+
closeErrOnce sync.Once
30+
closeErr error
3031

3132
releaseOnClose func()
3233
releaseOnMessage func()
3334

34-
readch chan wsjs.MessageEvent
35+
readSignal chan struct{}
36+
readBufMu sync.Mutex
37+
readBuf []wsjs.MessageEvent
3538
}
3639

3740
func (c *Conn) close(err error) {
3841
c.closeOnce.Do(func() {
3942
runtime.SetFinalizer(c, nil)
4043

41-
c.closeErr = fmt.Errorf("websocket closed: %w", err)
44+
c.setCloseErr(err)
4245
close(c.closed)
4346
})
4447
}
4548

4649
func (c *Conn) init() {
4750
c.closed = make(chan struct{})
48-
c.readch = make(chan wsjs.MessageEvent, 1)
51+
c.readSignal = make(chan struct{}, 1)
4952
c.msgReadLimit = 32768
5053

5154
c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) {
@@ -61,15 +64,28 @@ func (c *Conn) init() {
6164
})
6265

6366
c.releaseOnMessage = c.ws.OnMessage(func(e wsjs.MessageEvent) {
64-
c.readch <- e
67+
c.readBufMu.Lock()
68+
defer c.readBufMu.Unlock()
69+
70+
c.readBuf = append(c.readBuf, e)
71+
72+
// Lets the read goroutine know there is definitely something in readBuf.
73+
select {
74+
case c.readSignal <- struct{}{}:
75+
default:
76+
}
6577
})
6678

6779
runtime.SetFinalizer(c, func(c *Conn) {
68-
c.ws.Close(int(StatusInternalError), "")
69-
c.close(errors.New("connection garbage collected"))
80+
c.setCloseErr(errors.New("connection garbage collected"))
81+
c.closeWithInternal()
7082
})
7183
}
7284

85+
func (c *Conn) closeWithInternal() {
86+
c.Close(StatusInternalError, "something went wrong")
87+
}
88+
7389
// Read attempts to read a message from the connection.
7490
// The maximum time spent waiting is bounded by the context.
7591
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
@@ -89,16 +105,32 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
89105
}
90106

91107
func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) {
92-
var me wsjs.MessageEvent
93108
select {
94109
case <-ctx.Done():
95110
c.Close(StatusPolicyViolation, "read timed out")
96111
return 0, nil, ctx.Err()
97-
case me = <-c.readch:
112+
case <-c.readSignal:
98113
case <-c.closed:
99114
return 0, nil, c.closeErr
100115
}
101116

117+
c.readBufMu.Lock()
118+
defer c.readBufMu.Unlock()
119+
120+
me := c.readBuf[0]
121+
// We copy the messages forward and decrease the size
122+
// of the slice to avoid reallocating.
123+
copy(c.readBuf, c.readBuf[1:])
124+
c.readBuf = c.readBuf[:len(c.readBuf)-1]
125+
126+
if len(c.readBuf) > 0 {
127+
// Next time we read, we'll grab the message.
128+
select {
129+
case c.readSignal <- struct{}{}:
130+
default:
131+
}
132+
}
133+
102134
switch p := me.Data.(type) {
103135
case string:
104136
return MessageText, []byte(p), nil
@@ -118,8 +150,10 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
118150
// to match the Go API. It can only error if the message type
119151
// is unexpected or the passed bytes contain invalid UTF-8 for
120152
// MessageText.
121-
c.Close(StatusInternalError, "something went wrong")
122-
return fmt.Errorf("failed to write: %w", err)
153+
err := fmt.Errorf("failed to write: %w", err)
154+
c.setCloseErr(err)
155+
c.closeWithInternal()
156+
return err
123157
}
124158
return nil
125159
}

0 commit comments

Comments
 (0)