@@ -159,23 +159,70 @@ func TestWebsocketCloseMsg(t *testing.T) {
159
159
})
160
160
}
161
161
162
- type mockHijacker struct {
163
- http.ResponseWriter
164
- serverConn net.Conn
165
- clientConn net.Conn
166
- rw * bufio.ReadWriter
162
+ // Our WebSocket library accepts any arbitrary ResponseWriter at the type level,
163
+ // but it must also implement http.Hijack
164
+ type mockWsResponseWriter struct {
165
+ recorder http.ResponseWriter
166
+ serverConn net.Conn
167
+ clientConn net.Conn
168
+ serverReadWriter * bufio.ReadWriter
167
169
}
168
170
169
- func (m mockHijacker ) Hijack () (net.Conn , * bufio.ReadWriter , error ) {
170
- return m .serverConn , m .rw , nil
171
+ func (m mockWsResponseWriter ) Hijack () (net.Conn , * bufio.ReadWriter , error ) {
172
+ return m .serverConn , m .serverReadWriter , nil
171
173
}
172
174
173
- func (m mockHijacker ) Flush () {
174
- if f , ok := m .ResponseWriter .(http.Flusher ); ok {
175
+ func (m mockWsResponseWriter ) Flush () {
176
+ if f , ok := m .recorder .(http.Flusher ); ok {
175
177
f .Flush ()
176
178
}
177
179
}
178
180
181
+ func (m mockWsResponseWriter ) Header () http.Header {
182
+ return m .recorder .Header ()
183
+ }
184
+
185
+ func (m mockWsResponseWriter ) Write (b []byte ) (int , error ) {
186
+ return m .serverReadWriter .Write (b )
187
+ }
188
+
189
+ func (m mockWsResponseWriter ) WriteHeader (code int ) {
190
+ m .recorder .WriteHeader (code )
191
+ }
192
+
193
+ type mockWsResponseWrite func (b []byte ) (int , error )
194
+
195
+ func (w mockWsResponseWrite ) Write (b []byte ) (int , error ) {
196
+ return w (b )
197
+ }
198
+
199
+ func newMockWebsocketWriter () mockWsResponseWriter {
200
+ server , client := net .Pipe ()
201
+ recorder := httptest .NewRecorder ()
202
+
203
+ var write mockWsResponseWrite = func (b []byte ) (int , error ) {
204
+ serverCount , err := server .Write (b )
205
+ if err != nil {
206
+ return serverCount , err
207
+ }
208
+ recorderCount , err := recorder .Write (b )
209
+ if serverCount < recorderCount {
210
+ return serverCount , err
211
+ }
212
+ return recorderCount , err
213
+ }
214
+
215
+ return mockWsResponseWriter {
216
+ serverConn : server ,
217
+ clientConn : client ,
218
+ recorder : recorder ,
219
+ serverReadWriter : bufio .NewReadWriter (
220
+ bufio .NewReader (server ),
221
+ bufio .NewWriter (write ),
222
+ ),
223
+ }
224
+ }
225
+
179
226
func TestOneWayWebSocket (t * testing.T ) {
180
227
t .Parallel ()
181
228
@@ -194,21 +241,6 @@ func TestOneWayWebSocket(t *testing.T) {
194
241
return req
195
242
}
196
243
197
- newMockHijacker := func () mockHijacker {
198
- server , client := net .Pipe ()
199
- reader := bufio .NewReader (strings .NewReader ("" ))
200
- recorder := httptest .NewRecorder ()
201
- writer := bufio .NewWriter (recorder )
202
- readWriter := bufio .NewReadWriter (reader , writer )
203
-
204
- return mockHijacker {
205
- serverConn : server ,
206
- clientConn : client ,
207
- ResponseWriter : recorder ,
208
- rw : readWriter ,
209
- }
210
- }
211
-
212
244
t .Run ("Produces an error if the socket connection could not be established" , func (t * testing.T ) {
213
245
t .Parallel ()
214
246
@@ -226,19 +258,18 @@ func TestOneWayWebSocket(t *testing.T) {
226
258
req .ProtoMinor = p .minor
227
259
req .Proto = p .proto
228
260
229
- _ , _ , err := httpapi .OneWayWebSocket [any ](httptest .NewRecorder (), req )
261
+ writer := newMockWebsocketWriter ()
262
+ _ , _ , err := httpapi .OneWayWebSocket [any ](writer , req )
230
263
require .ErrorContains (t , err , p .proto )
231
264
}
232
265
})
233
266
234
267
t .Run ("Returned callback can publish a new event to the WebSocket connection" , func (t * testing.T ) {
235
268
t .Parallel ()
236
269
237
- mock := newMockHijacker ()
238
- send , _ , err := httpapi .OneWayWebSocket [codersdk.ServerSentEvent ](
239
- mock ,
240
- createBaseRequest (t ),
241
- )
270
+ req := createBaseRequest (t )
271
+ writer := newMockWebsocketWriter ()
272
+ send , _ , err := httpapi .OneWayWebSocket [codersdk.ServerSentEvent ](writer , req )
242
273
require .NoError (t , err )
243
274
244
275
err = send (codersdk.ServerSentEvent {
0 commit comments