Skip to content

Commit 302c503

Browse files
committed
Allow calls to SSLIOStream.write while the connection is in progress.
Skip fast-path writes while connecting, and rework the interaction between base class and subclass to avoid the possibility of doubly-wrapped sockets. Closes tornadoweb#587.
1 parent 0147ac2 commit 302c503

File tree

2 files changed

+55
-9
lines changed

2 files changed

+55
-9
lines changed

tornado/iostream.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -207,10 +207,11 @@ def write(self, data, callback=None):
207207
else:
208208
self._write_buffer.append(data)
209209
self._write_callback = stack_context.wrap(callback)
210-
self._handle_write()
211-
if self._write_buffer:
212-
self._add_io_state(self.io_loop.WRITE)
213-
self._maybe_add_error_listener()
210+
if not self._connecting:
211+
self._handle_write()
212+
if self._write_buffer:
213+
self._add_io_state(self.io_loop.WRITE)
214+
self._maybe_add_error_listener()
214215

215216
def set_close_callback(self, callback):
216217
"""Call the given callback when the stream is closed."""
@@ -626,6 +627,7 @@ def __init__(self, *args, **kwargs):
626627
self._ssl_accepting = True
627628
self._handshake_reading = False
628629
self._handshake_writing = False
630+
self._ssl_connect_callback = None
629631

630632
def reading(self):
631633
return self._handshake_reading or super(SSLIOStream, self).reading()
@@ -663,7 +665,11 @@ def _do_ssl_handshake(self):
663665
return self.close()
664666
else:
665667
self._ssl_accepting = False
666-
super(SSLIOStream, self)._handle_connect()
668+
if self._ssl_connect_callback is not None:
669+
callback = self._ssl_connect_callback
670+
self._ssl_connect_callback = None
671+
self._run_callback(callback)
672+
667673

668674
def _handle_read(self):
669675
if self._ssl_accepting:
@@ -677,14 +683,23 @@ def _handle_write(self):
677683
return
678684
super(SSLIOStream, self)._handle_write()
679685

686+
def connect(self, address, callback=None):
687+
# Save the user's callback and run it after the ssl handshake
688+
# has completed.
689+
self._ssl_connect_callback = callback
690+
super(SSLIOStream, self).connect(address, callback=None)
691+
680692
def _handle_connect(self):
693+
# When the connection is complete, wrap the socket for SSL
694+
# traffic. Note that we do this by overriding _handle_connect
695+
# instead of by passing a callback to super().connect because
696+
# user callbacks are enqueued asynchronously on the IOLoop,
697+
# but since _handle_events calls _handle_connect immediately
698+
# followed by _handle_write we need this to be synchronous.
681699
self.socket = ssl.wrap_socket(self.socket,
682700
do_handshake_on_connect=False,
683701
**self._ssl_options)
684-
# Don't call the superclass's _handle_connect (which is responsible
685-
# for telling the application that the connection is complete)
686-
# until we've completed the SSL handshake (so certificates are
687-
# available, etc).
702+
super(SSLIOStream, self)._handle_connect()
688703

689704
def _read_from_socket(self):
690705
if self._ssl_accepting:

tornado/test/iostream_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from tornado.util import b
77
from tornado.web import RequestHandler, Application
88
import errno
9+
import logging
910
import os
1011
import platform
1112
import socket
@@ -75,6 +76,36 @@ def test_read_zero_bytes(self):
7576

7677
self.stream.close()
7778

79+
def test_write_while_connecting(self):
80+
stream = self._make_client_iostream()
81+
connected = [False]
82+
def connected_callback():
83+
connected[0] = True
84+
self.stop()
85+
stream.connect(("localhost", self.get_http_port()),
86+
callback=connected_callback)
87+
# unlike the previous tests, try to write before the connection
88+
# is complete.
89+
written = [False]
90+
def write_callback():
91+
written[0] = True
92+
self.stop()
93+
stream.write(b("GET / HTTP/1.0\r\nConnection: close\r\n\r\n"),
94+
callback=write_callback)
95+
self.assertTrue(not connected[0])
96+
# by the time the write has flushed, the connection callback has
97+
# also run
98+
try:
99+
self.wait(lambda: connected[0] and written[0])
100+
finally:
101+
logging.info((connected, written))
102+
103+
stream.read_until_close(self.stop)
104+
data = self.wait()
105+
self.assertTrue(data.endswith(b("Hello")))
106+
107+
stream.close()
108+
78109

79110
class TestIOStreamMixin(object):
80111
def _make_server_iostream(self, connection, **kwargs):

0 commit comments

Comments
 (0)