Skip to content

Commit 6ac2c58

Browse files
committed
Add follow_redirects support to SimpleAsyncHTTPClient.
1 parent ab217b6 commit 6ac2c58

File tree

2 files changed

+54
-10
lines changed

2 files changed

+54
-10
lines changed

tornado/simple_httpclient.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import collections
1212
import contextlib
13+
import copy
1314
import errno
1415
import functools
1516
import logging
@@ -279,8 +280,23 @@ def _on_body(self, data):
279280
buffer = StringIO()
280281
else:
281282
buffer = StringIO(data) # TODO: don't require one big string?
282-
response = HTTPResponse(self.request, self.code, headers=self.headers,
283-
buffer=buffer)
283+
original_request = getattr(self.request, "original_request",
284+
self.request)
285+
if (self.request.follow_redirects and
286+
self.request.max_redirects > 0 and
287+
self.code in (301, 302)):
288+
new_request = copy.copy(self.request)
289+
new_request.url = urlparse.urljoin(self.request.url,
290+
self.headers["Location"])
291+
new_request.max_redirects -= 1
292+
new_request.original_request = original_request
293+
self.client.fetch(new_request, self.callback)
294+
self.callback = None
295+
return
296+
response = HTTPResponse(original_request,
297+
self.code, headers=self.headers,
298+
buffer=buffer,
299+
effective_url=self.request.url)
284300
self.callback(response)
285301
self.callback = None
286302

tornado/test/simple_httpclient_test.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from tornado.ioloop import IOLoop
1212
from tornado.simple_httpclient import SimpleAsyncHTTPClient
1313
from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase, get_unused_port
14-
from tornado.web import Application, RequestHandler, asynchronous
14+
from tornado.web import Application, RequestHandler, asynchronous, url
1515

1616
class HelloWorldHandler(RequestHandler):
1717
def get(self):
@@ -50,18 +50,27 @@ def get(self):
5050
self.queue.append(self.finish)
5151
self.wake_callback()
5252

53+
class CountdownHandler(RequestHandler):
54+
def get(self, count):
55+
count = int(count)
56+
if count > 0:
57+
self.redirect(self.reverse_url("countdown", count - 1))
58+
else:
59+
self.write("Zero")
60+
5361
class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase):
5462
def get_app(self):
5563
# callable objects to finish pending /trigger requests
5664
self.triggers = collections.deque()
5765
return Application([
58-
("/hello", HelloWorldHandler),
59-
("/post", PostHandler),
60-
("/chunk", ChunkHandler),
61-
("/auth", AuthHandler),
62-
("/hang", HangHandler),
63-
("/trigger", TriggerHandler, dict(queue=self.triggers,
64-
wake_callback=self.stop)),
66+
url("/hello", HelloWorldHandler),
67+
url("/post", PostHandler),
68+
url("/chunk", ChunkHandler),
69+
url("/auth", AuthHandler),
70+
url("/hang", HangHandler),
71+
url("/trigger", TriggerHandler, dict(queue=self.triggers,
72+
wake_callback=self.stop)),
73+
url("/countdown/([0-9]+)", CountdownHandler, name="countdown"),
6574
], gzip=True)
6675

6776
def setUp(self):
@@ -176,3 +185,22 @@ def test_connection_limit(self):
176185
self.assertEqual(seen, [0, 1])
177186
self.assertEqual(len(client.queue), 0)
178187

188+
def test_follow_redirect(self):
189+
response = self.fetch("/countdown/2", follow_redirects=False)
190+
self.assertEqual(302, response.code)
191+
self.assertTrue(response.headers["Location"].endswith("/countdown/1"))
192+
193+
response = self.fetch("/countdown/2")
194+
self.assertEqual(200, response.code)
195+
self.assertTrue(response.effective_url.endswith("/countdown/0"))
196+
self.assertEqual("Zero", response.body)
197+
198+
def test_max_redirects(self):
199+
response = self.fetch("/countdown/5", max_redirects=3)
200+
self.assertEqual(302, response.code)
201+
# We requested 5, followed three redirects for 4, 3, 2, then the last
202+
# unfollowed redirect is to 1.
203+
self.assertTrue(response.request.url.endswith("/countdown/5"))
204+
self.assertTrue(response.effective_url.endswith("/countdown/2"))
205+
self.assertTrue(response.headers["Location"].endswith("/countdown/1"))
206+

0 commit comments

Comments
 (0)