Skip to content

Commit 442b49f

Browse files
committed
Refactor websocket close logic; remove dependency on singleton IOLoop.
1 parent 5a18d50 commit 442b49f

File tree

1 file changed

+35
-22
lines changed

1 file changed

+35
-22
lines changed

tornado/websocket.py

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -186,15 +186,10 @@ def _not_supported(self, *args, **kwargs):
186186

187187
def on_connection_close(self):
188188
if self.ws_connection:
189-
self.ws_connection.client_terminated = True
189+
self.ws_connection.on_connection_close()
190+
self.ws_connection = None
190191
self.on_close()
191192

192-
def _set_client_terminated(self, value):
193-
self.ws_connection.client_terminated = value
194-
195-
client_terminated = property(lambda self: self.ws_connection.client_terminated,
196-
_set_client_terminated)
197-
198193

199194
for method in ["write", "redirect", "set_header", "send_error", "set_cookie",
200195
"set_status", "flush", "finish"]:
@@ -209,6 +204,7 @@ def __init__(self, handler):
209204
self.request = handler.request
210205
self.stream = handler.stream
211206
self.client_terminated = False
207+
self.server_terminated = False
212208

213209
def async_callback(self, callback, *args, **kwargs):
214210
"""Wrap callbacks with this if they are used on asynchronous requests.
@@ -227,10 +223,15 @@ def wrapper(*args, **kwargs):
227223
self._abort()
228224
return wrapper
229225

226+
def on_connection_close(self):
227+
self._abort()
228+
230229
def _abort(self):
231230
"""Instantly aborts the WebSocket connection by closing the socket"""
232231
self.client_terminated = True
233-
self.stream.close()
232+
self.server_terminated = True
233+
self.stream.close() # forcibly tear down the connection
234+
self.close() # let the subclass cleanup
234235

235236

236237
class WebSocketProtocol76(WebSocketProtocol):
@@ -384,14 +385,18 @@ def write_message(self, message, binary=False):
384385

385386
def close(self):
386387
"""Closes the WebSocket connection."""
387-
if self.client_terminated and self._waiting:
388-
tornado.ioloop.IOLoop.instance().remove_timeout(self._waiting)
388+
if not self.server_terminated:
389+
if not self.stream.closed():
390+
self.stream.write("\xff\x00")
391+
self.server_terminated = True
392+
if self.client_terminated:
393+
if self._waiting is not None:
394+
self.stream.io_loop.remove_timeout(self._waiting)
389395
self._waiting = None
390396
self.stream.close()
391-
elif not self.stream.closed():
392-
self.stream.write("\xff\x00")
393-
self._waiting = tornado.ioloop.IOLoop.instance().add_timeout(
394-
time.time() + 5, self._abort)
397+
elif self._waiting is None:
398+
self._waiting = self.stream.io_loop.add_timeout(
399+
time.time() + 5, self._abort)
395400

396401

397402
class WebSocketProtocol13(WebSocketProtocol):
@@ -408,7 +413,7 @@ def __init__(self, handler):
408413
self._frame_length = None
409414
self._fragmented_message_buffer = None
410415
self._fragmented_message_opcode = None
411-
self._started_closing_handshake = False
416+
self._waiting = None
412417

413418
def accept_connection(self):
414419
try:
@@ -589,9 +594,7 @@ def _handle_message(self, opcode, data):
589594
elif opcode == 0x8:
590595
# Close
591596
self.client_terminated = True
592-
if not self._started_closing_handshake:
593-
self._write_frame(True, 0x8, b(""))
594-
self.stream.close()
597+
self.close()
595598
elif opcode == 0x9:
596599
# Ping
597600
self._write_frame(True, 0xA, data)
@@ -603,7 +606,17 @@ def _handle_message(self, opcode, data):
603606

604607
def close(self):
605608
"""Closes the WebSocket connection."""
606-
if self.stream.closed(): return
607-
self._write_frame(True, 0x8, b(""))
608-
self._started_closing_handshake = True
609-
self._waiting = tornado.ioloop.IOLoop.instance().add_timeout(time.time() + 5, self._abort)
609+
if not self.server_terminated:
610+
if not self.stream.closed():
611+
self._write_frame(True, 0x8, b(""))
612+
self.server_terminated = True
613+
if self.client_terminated:
614+
if self._waiting is not None:
615+
self.stream.io_loop.remove_timeout(self._waiting)
616+
self._waiting = None
617+
self.stream.close()
618+
elif self._waiting is None:
619+
# Give the client a few seconds to complete a clean shutdown,
620+
# otherwise just close the connection.
621+
self._waiting = self.stream.io_loop.add_timeout(
622+
time.time() + 5, self._abort)

0 commit comments

Comments
 (0)