@@ -23,24 +23,26 @@ type Conn struct {
23
23
24
24
msgReadLimit int64
25
25
26
- readClosed int64
27
- closeOnce sync.Once
28
- closed chan struct {}
29
- closeErr error
26
+ readClosed int64
27
+ closeOnce sync.Once
28
+ closed chan struct {}
29
+ closeErrOnce sync.Once
30
+ closeErr error
30
31
31
32
releaseOnClose func ()
32
33
releaseOnMessage func ()
33
34
34
35
readSignal chan struct {}
35
36
readBufMu sync.Mutex
36
- readBuf []wsjs.MessageEvent
37
+ // Max size of readBuf is 32.
38
+ readBuf []wsjs.MessageEvent
37
39
}
38
40
39
41
func (c * Conn ) close (err error ) {
40
42
c .closeOnce .Do (func () {
41
43
runtime .SetFinalizer (c , nil )
42
44
43
- c .closeErr = fmt . Errorf ( "websocket closed: %w" , err )
45
+ c .setCloseErr ( err )
44
46
close (c .closed )
45
47
})
46
48
}
@@ -49,6 +51,8 @@ func (c *Conn) init() {
49
51
c .closed = make (chan struct {})
50
52
c .readSignal = make (chan struct {}, 1 )
51
53
c .msgReadLimit = 32768
54
+ // Capacity limited to 32 messages.
55
+ c .readBuf = make ([]wsjs.MessageEvent , 0 , 32 )
52
56
53
57
c .releaseOnClose = c .ws .OnClose (func (e wsjs.CloseEvent ) {
54
58
cerr := CloseError {
@@ -66,6 +70,12 @@ func (c *Conn) init() {
66
70
c .readBufMu .Lock ()
67
71
defer c .readBufMu .Unlock ()
68
72
73
+ if len (c .readBuf ) == cap (c .readBuf ) {
74
+ c .setCloseErr (fmt .Errorf ("too many messages in buffer, cannot keep up: %v" , len (c .readBuf )))
75
+ c .Close (StatusPolicyViolation , "unable to read fast enough" )
76
+ return
77
+ }
78
+
69
79
c .readBuf = append (c .readBuf , e )
70
80
71
81
// Lets the read goroutine know there is definitely something in readBuf.
@@ -76,11 +86,15 @@ func (c *Conn) init() {
76
86
})
77
87
78
88
runtime .SetFinalizer (c , func (c * Conn ) {
79
- c .ws . Close ( int ( StatusInternalError ), "" )
80
- c .close ( errors . New ( "connection garbage collected" ) )
89
+ c .setCloseErr ( errors . New ( "connection garbage collected" ) )
90
+ c .closeWithInternal ( )
81
91
})
82
92
}
83
93
94
+ func (c * Conn ) closeWithInternal () {
95
+ c .Close (StatusInternalError , "something went wrong" )
96
+ }
97
+
84
98
// Read attempts to read a message from the connection.
85
99
// The maximum time spent waiting is bounded by the context.
86
100
func (c * Conn ) Read (ctx context.Context ) (MessageType , []byte , error ) {
@@ -113,11 +127,8 @@ func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) {
113
127
defer c .readBufMu .Unlock ()
114
128
115
129
me := c .readBuf [0 ]
116
- // Ensures GC can collect the message event.
117
- c .readBuf [0 ] = wsjs.MessageEvent {}
118
- // We do not shrink the array since it will be resized
119
- // as appropriate by append in the OnMessage callback.
120
- c .readBuf = c .readBuf [1 :]
130
+ copy (c .readBuf , c .readBuf [1 :])
131
+ c .readBuf = c .readBuf [:len (c .readBuf )- 1 ]
121
132
122
133
if len (c .readBuf ) > 0 {
123
134
// Next time we read, we'll grab the message.
@@ -146,8 +157,10 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
146
157
// to match the Go API. It can only error if the message type
147
158
// is unexpected or the passed bytes contain invalid UTF-8 for
148
159
// MessageText.
149
- c .Close (StatusInternalError , "something went wrong" )
150
- return fmt .Errorf ("failed to write: %w" , err )
160
+ err := fmt .Errorf ("failed to write: %w" , err )
161
+ c .setCloseErr (err )
162
+ c .closeWithInternal ()
163
+ return err
151
164
}
152
165
return nil
153
166
}
0 commit comments