Skip to content

Commit 72a1ae5

Browse files
committed
chore: simplify mock interface definition
1 parent 27818d2 commit 72a1ae5

File tree

1 file changed

+61
-30
lines changed

1 file changed

+61
-30
lines changed

coderd/httpapi/httpapi_test.go

Lines changed: 61 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -159,23 +159,70 @@ func TestWebsocketCloseMsg(t *testing.T) {
159159
})
160160
}
161161

162-
type mockHijacker struct {
163-
http.ResponseWriter
164-
serverConn net.Conn
165-
clientConn net.Conn
166-
rw *bufio.ReadWriter
162+
// Our WebSocket library accepts any arbitrary ResponseWriter at the type level,
163+
// but it must also implement http.Hijack
164+
type mockWsResponseWriter struct {
165+
recorder http.ResponseWriter
166+
serverConn net.Conn
167+
clientConn net.Conn
168+
serverReadWriter *bufio.ReadWriter
167169
}
168170

169-
func (m mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
170-
return m.serverConn, m.rw, nil
171+
func (m mockWsResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
172+
return m.serverConn, m.serverReadWriter, nil
171173
}
172174

173-
func (m mockHijacker) Flush() {
174-
if f, ok := m.ResponseWriter.(http.Flusher); ok {
175+
func (m mockWsResponseWriter) Flush() {
176+
if f, ok := m.recorder.(http.Flusher); ok {
175177
f.Flush()
176178
}
177179
}
178180

181+
func (m mockWsResponseWriter) Header() http.Header {
182+
return m.recorder.Header()
183+
}
184+
185+
func (m mockWsResponseWriter) Write(b []byte) (int, error) {
186+
return m.serverReadWriter.Write(b)
187+
}
188+
189+
func (m mockWsResponseWriter) WriteHeader(code int) {
190+
m.recorder.WriteHeader(code)
191+
}
192+
193+
type mockWsResponseWrite func(b []byte) (int, error)
194+
195+
func (w mockWsResponseWrite) Write(b []byte) (int, error) {
196+
return w(b)
197+
}
198+
199+
func newMockWebsocketWriter() mockWsResponseWriter {
200+
server, client := net.Pipe()
201+
recorder := httptest.NewRecorder()
202+
203+
var write mockWsResponseWrite = func(b []byte) (int, error) {
204+
serverCount, err := server.Write(b)
205+
if err != nil {
206+
return serverCount, err
207+
}
208+
recorderCount, err := recorder.Write(b)
209+
if serverCount < recorderCount {
210+
return serverCount, err
211+
}
212+
return recorderCount, err
213+
}
214+
215+
return mockWsResponseWriter{
216+
serverConn: server,
217+
clientConn: client,
218+
recorder: recorder,
219+
serverReadWriter: bufio.NewReadWriter(
220+
bufio.NewReader(server),
221+
bufio.NewWriter(write),
222+
),
223+
}
224+
}
225+
179226
func TestOneWayWebSocket(t *testing.T) {
180227
t.Parallel()
181228

@@ -194,21 +241,6 @@ func TestOneWayWebSocket(t *testing.T) {
194241
return req
195242
}
196243

197-
newMockHijacker := func() mockHijacker {
198-
server, client := net.Pipe()
199-
reader := bufio.NewReader(strings.NewReader(""))
200-
recorder := httptest.NewRecorder()
201-
writer := bufio.NewWriter(recorder)
202-
readWriter := bufio.NewReadWriter(reader, writer)
203-
204-
return mockHijacker{
205-
serverConn: server,
206-
clientConn: client,
207-
ResponseWriter: recorder,
208-
rw: readWriter,
209-
}
210-
}
211-
212244
t.Run("Produces an error if the socket connection could not be established", func(t *testing.T) {
213245
t.Parallel()
214246

@@ -226,19 +258,18 @@ func TestOneWayWebSocket(t *testing.T) {
226258
req.ProtoMinor = p.minor
227259
req.Proto = p.proto
228260

229-
_, _, err := httpapi.OneWayWebSocket[any](httptest.NewRecorder(), req)
261+
writer := newMockWebsocketWriter()
262+
_, _, err := httpapi.OneWayWebSocket[any](writer, req)
230263
require.ErrorContains(t, err, p.proto)
231264
}
232265
})
233266

234267
t.Run("Returned callback can publish a new event to the WebSocket connection", func(t *testing.T) {
235268
t.Parallel()
236269

237-
mock := newMockHijacker()
238-
send, _, err := httpapi.OneWayWebSocket[codersdk.ServerSentEvent](
239-
mock,
240-
createBaseRequest(t),
241-
)
270+
req := createBaseRequest(t)
271+
writer := newMockWebsocketWriter()
272+
send, _, err := httpapi.OneWayWebSocket[codersdk.ServerSentEvent](writer, req)
242273
require.NoError(t, err)
243274

244275
err = send(codersdk.ServerSentEvent{

0 commit comments

Comments
 (0)