@@ -23,29 +23,32 @@ type Conn struct {
23
23
24
24
msgReadLimit int64
25
25
26
- readClosed int64
27
- closeOnce sync.Once
28
- closed chan struct {}
29
- closeErr error
26
+ readClosed int64
27
+ closeOnce sync.Once
28
+ closed chan struct {}
29
+ closeErrOnce sync.Once
30
+ closeErr error
30
31
31
32
releaseOnClose func ()
32
33
releaseOnMessage func ()
33
34
34
- readch chan wsjs.MessageEvent
35
+ readSignal chan struct {}
36
+ readBufMu sync.Mutex
37
+ readBuf []wsjs.MessageEvent
35
38
}
36
39
37
40
func (c * Conn ) close (err error ) {
38
41
c .closeOnce .Do (func () {
39
42
runtime .SetFinalizer (c , nil )
40
43
41
- c .closeErr = fmt . Errorf ( "websocket closed: %w" , err )
44
+ c .setCloseErr ( err )
42
45
close (c .closed )
43
46
})
44
47
}
45
48
46
49
func (c * Conn ) init () {
47
50
c .closed = make (chan struct {})
48
- c .readch = make (chan wsjs. MessageEvent , 1 )
51
+ c .readSignal = make (chan struct {} , 1 )
49
52
c .msgReadLimit = 32768
50
53
51
54
c .releaseOnClose = c .ws .OnClose (func (e wsjs.CloseEvent ) {
@@ -61,15 +64,28 @@ func (c *Conn) init() {
61
64
})
62
65
63
66
c .releaseOnMessage = c .ws .OnMessage (func (e wsjs.MessageEvent ) {
64
- c .readch <- e
67
+ c .readBufMu .Lock ()
68
+ defer c .readBufMu .Unlock ()
69
+
70
+ c .readBuf = append (c .readBuf , e )
71
+
72
+ // Lets the read goroutine know there is definitely something in readBuf.
73
+ select {
74
+ case c .readSignal <- struct {}{}:
75
+ default :
76
+ }
65
77
})
66
78
67
79
runtime .SetFinalizer (c , func (c * Conn ) {
68
- c .ws . Close ( int ( StatusInternalError ), "" )
69
- c .close ( errors . New ( "connection garbage collected" ) )
80
+ c .setCloseErr ( errors . New ( "connection garbage collected" ) )
81
+ c .closeWithInternal ( )
70
82
})
71
83
}
72
84
85
+ func (c * Conn ) closeWithInternal () {
86
+ c .Close (StatusInternalError , "something went wrong" )
87
+ }
88
+
73
89
// Read attempts to read a message from the connection.
74
90
// The maximum time spent waiting is bounded by the context.
75
91
func (c * Conn ) Read (ctx context.Context ) (MessageType , []byte , error ) {
@@ -89,16 +105,32 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
89
105
}
90
106
91
107
func (c * Conn ) read (ctx context.Context ) (MessageType , []byte , error ) {
92
- var me wsjs.MessageEvent
93
108
select {
94
109
case <- ctx .Done ():
95
110
c .Close (StatusPolicyViolation , "read timed out" )
96
111
return 0 , nil , ctx .Err ()
97
- case me = <- c .readch :
112
+ case <- c .readSignal :
98
113
case <- c .closed :
99
114
return 0 , nil , c .closeErr
100
115
}
101
116
117
+ c .readBufMu .Lock ()
118
+ defer c .readBufMu .Unlock ()
119
+
120
+ me := c .readBuf [0 ]
121
+ // We copy the messages forward and decrease the size
122
+ // of the slice to avoid reallocating.
123
+ copy (c .readBuf , c .readBuf [1 :])
124
+ c .readBuf = c .readBuf [:len (c .readBuf )- 1 ]
125
+
126
+ if len (c .readBuf ) > 0 {
127
+ // Next time we read, we'll grab the message.
128
+ select {
129
+ case c .readSignal <- struct {}{}:
130
+ default :
131
+ }
132
+ }
133
+
102
134
switch p := me .Data .(type ) {
103
135
case string :
104
136
return MessageText , []byte (p ), nil
@@ -118,8 +150,10 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
118
150
// to match the Go API. It can only error if the message type
119
151
// is unexpected or the passed bytes contain invalid UTF-8 for
120
152
// MessageText.
121
- c .Close (StatusInternalError , "something went wrong" )
122
- return fmt .Errorf ("failed to write: %w" , err )
153
+ err := fmt .Errorf ("failed to write: %w" , err )
154
+ c .setCloseErr (err )
155
+ c .closeWithInternal ()
156
+ return err
123
157
}
124
158
return nil
125
159
}
0 commit comments