Skip to content

Commit e965299

Browse files
committed
Add a simple mechanism to override DNS lookups in SimpleAsyncHTTPClient.
Intended for use in SSL unittests, where we will need to make requests to localhost using different domain names.
1 parent 88833c1 commit e965299

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

tornado/simple_httpclient.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ class SimpleAsyncHTTPClient(object):
4949

5050
def __new__(cls, io_loop=None, max_clients=10,
5151
max_simultaneous_connections=None,
52-
force_instance=False):
52+
force_instance=False,
53+
hostname_mapping=None):
5354
"""Creates a SimpleAsyncHTTPClient.
5455
5556
Only a single SimpleAsyncHTTPClient instance exists per IOLoop
@@ -61,6 +62,11 @@ def __new__(cls, io_loop=None, max_clients=10,
6162
only for compatibility with the curl-based AsyncHTTPClient. Note
6263
that these arguments are only used when the client is first created,
6364
and will be ignored when an existing client is reused.
65+
66+
hostname_mapping is a dictionary mapping hostnames to IP addresses.
67+
It can be used to make local DNS changes when modifying system-wide
68+
settings like /etc/hosts is not possible or desirable (e.g. in
69+
unittests).
6470
"""
6571
io_loop = io_loop or IOLoop.instance()
6672
if io_loop in cls._ASYNC_CLIENTS and not force_instance:
@@ -71,6 +77,7 @@ def __new__(cls, io_loop=None, max_clients=10,
7177
instance.max_clients = max_clients
7278
instance.queue = collections.deque()
7379
instance.active = {}
80+
instance.hostname_mapping = hostname_mapping
7481
if not force_instance:
7582
cls._ASYNC_CLIENTS[io_loop] = instance
7683
return instance
@@ -97,7 +104,7 @@ def _process_queue(self):
97104
request, callback = self.queue.popleft()
98105
key = object()
99106
self.active[key] = (request, callback)
100-
_HTTPConnection(self.io_loop, request,
107+
_HTTPConnection(self.io_loop, self, request,
101108
functools.partial(self._on_fetch_complete,
102109
key, callback))
103110

@@ -111,9 +118,10 @@ def _on_fetch_complete(self, key, callback, response):
111118
class _HTTPConnection(object):
112119
_SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE"])
113120

114-
def __init__(self, io_loop, request, callback):
121+
def __init__(self, io_loop, client, request, callback):
115122
self.start_time = time.time()
116123
self.io_loop = io_loop
124+
self.client = client
117125
self.request = request
118126
self.callback = callback
119127
self.code = None
@@ -130,6 +138,8 @@ def __init__(self, io_loop, request, callback):
130138
else:
131139
host = parsed.netloc
132140
port = 443 if parsed.scheme == "https" else 80
141+
if self.client.hostname_mapping is not None:
142+
host = self.client.hostname_mapping.get(host, host)
133143

134144
if parsed.scheme == "https":
135145
# TODO: cert verification, etc

0 commit comments

Comments
 (0)