@@ -13,13 +13,16 @@ import (
13
13
"sync/atomic"
14
14
"syscall/js"
15
15
16
+ "nhooyr.io/websocket/internal/bpool"
16
17
"nhooyr.io/websocket/internal/wsjs"
17
18
)
18
19
19
20
// Conn provides a wrapper around the browser WebSocket API.
20
21
type Conn struct {
21
22
ws wsjs.WebSocket
22
23
24
+ msgReadLimit int64
25
+
23
26
readClosed int64
24
27
closeOnce sync.Once
25
28
closed chan struct {}
@@ -43,6 +46,7 @@ func (c *Conn) close(err error) {
43
46
func (c * Conn ) init () {
44
47
c .closed = make (chan struct {})
45
48
c .readch = make (chan wsjs.MessageEvent , 1 )
49
+ c .msgReadLimit = 32768
46
50
47
51
c .releaseOnClose = c .ws .OnClose (func (e wsjs.CloseEvent ) {
48
52
cerr := CloseError {
@@ -77,6 +81,10 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
77
81
if err != nil {
78
82
return 0 , nil , fmt .Errorf ("failed to read: %w" , err )
79
83
}
84
+ if len (p ) > c .msgReadLimit {
85
+ r .c .Close (StatusMessageTooBig , fmt .Sprintf ("read limited at %v bytes" , r .c .msgReadLimit ))
86
+ return 0 , nil , c .closeErr
87
+ }
80
88
return typ , p , nil
81
89
}
82
90
@@ -106,6 +114,11 @@ func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) {
106
114
func (c * Conn ) Write (ctx context.Context , typ MessageType , p []byte ) error {
107
115
err := c .write (ctx , typ , p )
108
116
if err != nil {
117
+ // Have to ensure the WebSocket is closed after a write error
118
+ // to match the Go API. It can only error if the message type
119
+ // is unexpected or the passed bytes contain invalid UTF-8 for
120
+ // MessageText.
121
+ c .Close (StatusInternalError , "something went wrong" )
109
122
return fmt .Errorf ("failed to write: %w" , err )
110
123
}
111
124
return nil
@@ -216,8 +229,10 @@ func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Resp
216
229
return c , & http.Response {}, nil
217
230
}
218
231
219
- func (c * netConn ) netConnReader (ctx context.Context ) (MessageType , io.Reader , error ) {
220
- typ , p , err := c .c .Read (ctx )
232
+ // Reader attempts to read a message from the connection.
233
+ // The maximum time spent waiting is bounded by the context.
234
+ func (c * Conn ) Reader (ctx context.Context ) (MessageType , io.Reader , error ) {
235
+ typ , p , err := c .Read (ctx )
221
236
if err != nil {
222
237
return 0 , nil , err
223
238
}
@@ -228,3 +243,60 @@ func (c *netConn) netConnReader(ctx context.Context) (MessageType, io.Reader, er
228
243
func (c * Conn ) reader (ctx context.Context ) {
229
244
c .read (ctx )
230
245
}
246
+
247
+ // Writer returns a writer to write a WebSocket data message to the connection.
248
+ // It buffers the entire message in memory and then sends it when the writer
249
+ // is closed.
250
+ func (c * Conn ) Writer (ctx context.Context , typ MessageType ) (io.WriteCloser , error ) {
251
+ return writer {
252
+ c : c ,
253
+ ctx : ctx ,
254
+ typ : typ ,
255
+ b : bpool .Get (),
256
+ }, nil
257
+ }
258
+
259
+ type writer struct {
260
+ closed bool
261
+
262
+ c * Conn
263
+ ctx context.Context
264
+ typ MessageType
265
+
266
+ b * bytes.Buffer
267
+ }
268
+
269
+ func (w writer ) Write (p []byte ) (int , error ) {
270
+ if w .closed {
271
+ return 0 , errors .New ("cannot write to closed writer" )
272
+ }
273
+ n , err := w .b .Write (p )
274
+ if err != nil {
275
+ return n , fmt .Errorf ("failed to write message: %w" , err )
276
+ }
277
+ return n , nil
278
+ }
279
+
280
+ func (w writer ) Close () error {
281
+ if w .closed {
282
+ return errors .New ("cannot close closed writer" )
283
+ }
284
+ w .closed = true
285
+ defer bpool .Put (w .b )
286
+
287
+ err := w .c .Write (w .ctx , w .typ , w .b .Bytes ())
288
+ if err != nil {
289
+ return fmt .Errorf ("failed to close writer: %w" , err )
290
+ }
291
+ return nil
292
+ }
293
+
294
+ // SetReadLimit sets the max number of bytes to read for a single message.
295
+ // It applies to the Reader and Read methods.
296
+ //
297
+ // By default, the connection has a message read limit of 32768 bytes.
298
+ //
299
+ // When the limit is hit, the connection will be closed with StatusMessageTooBig.
300
+ func (c * Conn ) SetReadLimit (n int64 ) {
301
+ c .readLimit = n
302
+ }
0 commit comments