@@ -174,18 +174,25 @@ func (h *Handler) handleUpgrade(w http.ResponseWriter, r *http.Request) {
174
174
h .logger .Error (ctx , "failed to accept websocket" , slog .Error (err ))
175
175
return
176
176
}
177
- defer conn .Close (websocket .StatusInternalError , "internal error" )
178
177
179
- // BackedPipe handles sequence numbers internally
180
- // No need to expose them through the API
178
+ // Create a context that we can cancel to clean up the connection
179
+ connCtx , cancel := context .WithCancel (ctx )
180
+ defer cancel ()
181
+
182
+ // Ensure WebSocket is closed when this function returns
183
+ defer func () {
184
+ conn .Close (websocket .StatusNormalClosure , "connection closed" )
185
+ }()
181
186
182
187
// Create a WebSocket adapter
183
188
wsConn := & wsConn {
184
189
conn : conn ,
185
190
logger : h .logger ,
191
+ ctx : connCtx ,
192
+ cancel : cancel ,
186
193
}
187
194
188
- // Handle the reconnection
195
+ // Handle the reconnection - this establishes the connection
189
196
// BackedPipe only needs the reader sequence number for replay
190
197
err = h .manager .HandleConnection (streamID , wsConn , readSeqNum )
191
198
if err != nil {
@@ -194,19 +201,26 @@ func (h *Handler) handleUpgrade(w http.ResponseWriter, r *http.Request) {
194
201
return
195
202
}
196
203
197
- // Keep the connection open until it's closed
198
- <- ctx .Done ()
204
+ // Keep the connection open until the context is cancelled
205
+ // The wsConn will handle connection closure through its Read/Write methods
206
+ // When the connection is closed, the backing pipe will detect it and the context should be cancelled
207
+ <- connCtx .Done ()
208
+ h .logger .Debug (ctx , "websocket connection handler exiting" )
199
209
}
200
210
201
211
// wsConn adapts a WebSocket connection to io.ReadWriteCloser
202
212
type wsConn struct {
203
213
conn * websocket.Conn
204
214
logger slog.Logger
215
+ ctx context.Context
216
+ cancel context.CancelFunc
205
217
}
206
218
207
219
func (c * wsConn ) Read (p []byte ) (n int , err error ) {
208
- typ , data , err := c .conn .Read (context . Background () )
220
+ typ , data , err := c .conn .Read (c . ctx )
209
221
if err != nil {
222
+ // Cancel the context when read fails (connection closed)
223
+ c .cancel ()
210
224
return 0 , err
211
225
}
212
226
if typ != websocket .MessageBinary {
@@ -217,14 +231,17 @@ func (c *wsConn) Read(p []byte) (n int, err error) {
217
231
}
218
232
219
233
func (c * wsConn ) Write (p []byte ) (n int , err error ) {
220
- err = c .conn .Write (context . Background () , websocket .MessageBinary , p )
234
+ err = c .conn .Write (c . ctx , websocket .MessageBinary , p )
221
235
if err != nil {
236
+ // Cancel the context when write fails (connection closed)
237
+ c .cancel ()
222
238
return 0 , err
223
239
}
224
240
return len (p ), nil
225
241
}
226
242
227
243
func (c * wsConn ) Close () error {
244
+ c .cancel () // Cancel the context when explicitly closed
228
245
return c .conn .Close (websocket .StatusNormalClosure , "" )
229
246
}
230
247
0 commit comments