Skip to content

Commit 3f11652

Browse files
committed
gh-135056: Add a --cors CLI argument to http.server
Add a --cors command line argument to the stdlib http.server module, which will add an `Access-Control-Allow-Origin: *` header to all responses. As part of this implementation, add a `response_headers` argument to SimpleHTTPRequestHandler and HttpServer, which allows callers to add addition headers to the response.
1 parent b525e31 commit 3f11652

File tree

3 files changed

+82
-12
lines changed

3 files changed

+82
-12
lines changed

Doc/library/http.server.rst

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ instantiation, of which this module provides three different variants:
362362
delays, it now always returns the IP address.
363363

364364

365-
.. class:: SimpleHTTPRequestHandler(request, client_address, server, directory=None)
365+
.. class:: SimpleHTTPRequestHandler(request, client_address, server, directory=None, response_headers=None)
366366

367367
This class serves files from the directory *directory* and below,
368368
or the current directory if *directory* is not provided, directly
@@ -374,6 +374,10 @@ instantiation, of which this module provides three different variants:
374374
.. versionchanged:: 3.9
375375
The *directory* parameter accepts a :term:`path-like object`.
376376

377+
.. versionchanged:: 3.15
378+
The *response_headers* parameter accepts an optional dictionary of
379+
additional HTTP headers to add to each response.
380+
377381
A lot of the work, such as parsing the request, is done by the base class
378382
:class:`BaseHTTPRequestHandler`. This class implements the :func:`do_GET`
379383
and :func:`do_HEAD` functions.
@@ -428,6 +432,9 @@ instantiation, of which this module provides three different variants:
428432
followed by a ``'Content-Length:'`` header with the file's size and a
429433
``'Last-Modified:'`` header with the file's modification time.
430434

435+
The headers specified in the dictionary instance argument
436+
``response_headers`` are each individually sent in the response.
437+
431438
Then follows a blank line signifying the end of the headers, and then the
432439
contents of the file are output.
433440

@@ -437,6 +444,9 @@ instantiation, of which this module provides three different variants:
437444
.. versionchanged:: 3.7
438445
Support of the ``'If-Modified-Since'`` header.
439446

447+
.. versionchanged:: 3.15
448+
Support ``response_headers`` as an instance argument.
449+
440450
The :class:`SimpleHTTPRequestHandler` class can be used in the following
441451
manner in order to create a very basic webserver serving files relative to
442452
the current directory::
@@ -543,6 +553,14 @@ The following options are accepted:
543553

544554
.. versionadded:: 3.14
545555

556+
.. option:: --cors
557+
558+
Adds an additional CORS (Cross-Origin Resource sharing) header to each response::
559+
560+
Access-Control-Allow-Origin: *
561+
562+
.. versionadded:: 3.15
563+
546564

547565
.. _http.server-security:
548566

Lib/http/server.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,22 @@ class HTTPServer(socketserver.TCPServer):
117117
allow_reuse_address = True # Seems to make sense in testing environment
118118
allow_reuse_port = True
119119

120+
def __init__(self, *args, response_headers=None, **kwargs):
121+
self.response_headers = response_headers if response_headers is not None else {}
122+
super().__init__(*args, **kwargs)
123+
120124
def server_bind(self):
121125
"""Override server_bind to store the server name."""
122126
socketserver.TCPServer.server_bind(self)
123127
host, port = self.server_address[:2]
124128
self.server_name = socket.getfqdn(host)
125129
self.server_port = port
126130

131+
def finish_request(self, request, client_address):
132+
"""Finish one request by instantiating RequestHandlerClass."""
133+
args = (request, client_address, self)
134+
kwargs = dict(response_headers=self.response_headers) if self.response_headers else dict()
135+
self.RequestHandlerClass(*args, **kwargs)
127136

128137
class ThreadingHTTPServer(socketserver.ThreadingMixIn, HTTPServer):
129138
daemon_threads = True
@@ -132,7 +141,7 @@ class ThreadingHTTPServer(socketserver.ThreadingMixIn, HTTPServer):
132141
class HTTPSServer(HTTPServer):
133142
def __init__(self, server_address, RequestHandlerClass,
134143
bind_and_activate=True, *, certfile, keyfile=None,
135-
password=None, alpn_protocols=None):
144+
password=None, alpn_protocols=None, response_headers=None):
136145
try:
137146
import ssl
138147
except ImportError:
@@ -150,7 +159,8 @@ def __init__(self, server_address, RequestHandlerClass,
150159

151160
super().__init__(server_address,
152161
RequestHandlerClass,
153-
bind_and_activate)
162+
bind_and_activate,
163+
response_headers=response_headers)
154164

155165
def server_activate(self):
156166
"""Wrap the socket in SSLSocket."""
@@ -692,10 +702,11 @@ class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
692702
'.xz': 'application/x-xz',
693703
}
694704

695-
def __init__(self, *args, directory=None, **kwargs):
705+
def __init__(self, *args, directory=None, response_headers=None, **kwargs):
696706
if directory is None:
697707
directory = os.getcwd()
698708
self.directory = os.fspath(directory)
709+
self.response_headers = response_headers or {}
699710
super().__init__(*args, **kwargs)
700711

701712
def do_GET(self):
@@ -736,6 +747,10 @@ def send_head(self):
736747
new_url = urllib.parse.urlunsplit(new_parts)
737748
self.send_header("Location", new_url)
738749
self.send_header("Content-Length", "0")
750+
# User specified response_headers
751+
if self.response_headers is not None:
752+
for header, value in self.response_headers.items():
753+
self.send_header(header, value)
739754
self.end_headers()
740755
return None
741756
for index in self.index_pages:
@@ -795,6 +810,9 @@ def send_head(self):
795810
self.send_header("Content-Length", str(fs[6]))
796811
self.send_header("Last-Modified",
797812
self.date_time_string(fs.st_mtime))
813+
if self.response_headers is not None:
814+
for header, value in self.response_headers.items():
815+
self.send_header(header, value)
798816
self.end_headers()
799817
return f
800818
except:
@@ -970,7 +988,7 @@ def _get_best_family(*address):
970988
def test(HandlerClass=BaseHTTPRequestHandler,
971989
ServerClass=ThreadingHTTPServer,
972990
protocol="HTTP/1.0", port=8000, bind=None,
973-
tls_cert=None, tls_key=None, tls_password=None):
991+
tls_cert=None, tls_key=None, tls_password=None, response_headers=None):
974992
"""Test the HTTP request handler class.
975993
976994
This runs an HTTP server on port 8000 (or the port argument).
@@ -981,9 +999,10 @@ def test(HandlerClass=BaseHTTPRequestHandler,
981999

9821000
if tls_cert:
9831001
server = ServerClass(addr, HandlerClass, certfile=tls_cert,
984-
keyfile=tls_key, password=tls_password)
1002+
keyfile=tls_key, password=tls_password,
1003+
response_headers=response_headers)
9851004
else:
986-
server = ServerClass(addr, HandlerClass)
1005+
server = ServerClass(addr, HandlerClass, response_headers=response_headers)
9871006

9881007
with server as httpd:
9891008
host, port = httpd.socket.getsockname()[:2]
@@ -1024,6 +1043,8 @@ def _main(args=None):
10241043
parser.add_argument('port', default=8000, type=int, nargs='?',
10251044
help='bind to this port '
10261045
'(default: %(default)s)')
1046+
parser.add_argument('--cors', action='store_true',
1047+
help='Enable Access-Control-Allow-Origin: * header')
10271048
args = parser.parse_args(args)
10281049

10291050
if not args.tls_cert and args.tls_key:
@@ -1051,15 +1072,19 @@ def server_bind(self):
10511072
return super().server_bind()
10521073

10531074
def finish_request(self, request, client_address):
1054-
self.RequestHandlerClass(request, client_address, self,
1055-
directory=args.directory)
1075+
handler_args = (request, client_address, self)
1076+
handler_kwargs = dict(directory=args.directory)
1077+
if self.response_headers:
1078+
handler_kwargs['response_headers'] = self.response_headers
1079+
self.RequestHandlerClass(*handler_args, **handler_kwargs)
10561080

10571081
class HTTPDualStackServer(DualStackServerMixin, ThreadingHTTPServer):
10581082
pass
10591083
class HTTPSDualStackServer(DualStackServerMixin, ThreadingHTTPSServer):
10601084
pass
10611085

10621086
ServerClass = HTTPSDualStackServer if args.tls_cert else HTTPDualStackServer
1087+
response_headers = {'Access-Control-Allow-Origin': '*'} if args.cors else None
10631088

10641089
test(
10651090
HandlerClass=SimpleHTTPRequestHandler,
@@ -1070,6 +1095,7 @@ class HTTPSDualStackServer(DualStackServerMixin, ThreadingHTTPSServer):
10701095
tls_cert=args.tls_cert,
10711096
tls_key=args.tls_key,
10721097
tls_password=tls_key_password,
1098+
response_headers=response_headers
10731099
)
10741100

10751101

Lib/test/test_httpservers.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,12 @@ def test_https_server_raises_runtime_error(self):
8181

8282

8383
class TestServerThread(threading.Thread):
84-
def __init__(self, test_object, request_handler, tls=None):
84+
def __init__(self, test_object, request_handler, tls=None, server_kwargs=None):
8585
threading.Thread.__init__(self)
8686
self.request_handler = request_handler
8787
self.test_object = test_object
8888
self.tls = tls
89+
self.server_kwargs = server_kwargs or {}
8990

9091
def run(self):
9192
if self.tls:
@@ -95,7 +96,8 @@ def run(self):
9596
request_handler=self.request_handler,
9697
)
9798
else:
98-
self.server = HTTPServer(('localhost', 0), self.request_handler)
99+
self.server = HTTPServer(('localhost', 0), self.request_handler,
100+
**self.server_kwargs)
99101
self.test_object.HOST, self.test_object.PORT = self.server.socket.getsockname()
100102
self.test_object.server_started.set()
101103
self.test_object = None
@@ -113,12 +115,14 @@ class BaseTestCase(unittest.TestCase):
113115

114116
# Optional tuple (certfile, keyfile, password) to use for HTTPS servers.
115117
tls = None
118+
server_kwargs = None
116119

117120
def setUp(self):
118121
self._threads = threading_helper.threading_setup()
119122
os.environ = os_helper.EnvironmentVarGuard()
120123
self.server_started = threading.Event()
121-
self.thread = TestServerThread(self, self.request_handler, self.tls)
124+
self.thread = TestServerThread(self, self.request_handler, self.tls,
125+
self.server_kwargs)
122126
self.thread.start()
123127
self.server_started.wait()
124128

@@ -824,6 +828,16 @@ def test_path_without_leading_slash(self):
824828
self.tempdir_name + "/?hi=1")
825829

826830

831+
class CorsHTTPServerTestCase(SimpleHTTPServerTestCase):
832+
server_kwargs = dict(
833+
response_headers = {'Access-Control-Allow-Origin': '*'}
834+
)
835+
def test_cors(self):
836+
response = self.request(self.base_url + '/test')
837+
self.check_status_and_reason(response, HTTPStatus.OK)
838+
self.assertEqual(response.getheader('Access-Control-Allow-Origin'), '*')
839+
840+
827841
class SocketlessRequestHandler(SimpleHTTPRequestHandler):
828842
def __init__(self, directory=None):
829843
request = mock.Mock()
@@ -1306,6 +1320,7 @@ class CommandLineTestCase(unittest.TestCase):
13061320
'tls_cert': None,
13071321
'tls_key': None,
13081322
'tls_password': None,
1323+
'response_headers': None,
13091324
}
13101325

13111326
def setUp(self):
@@ -1371,6 +1386,17 @@ def test_protocol_flag(self, mock_func):
13711386
mock_func.assert_called_once_with(**call_args)
13721387
mock_func.reset_mock()
13731388

1389+
@mock.patch('http.server.test')
1390+
def test_cors_flag(self, mock_func):
1391+
self.invoke_httpd('--cors')
1392+
call_args = self.args | dict(
1393+
response_headers={
1394+
'Access-Control-Allow-Origin': '*'
1395+
}
1396+
)
1397+
mock_func.assert_called_once_with(**call_args)
1398+
mock_func.reset_mock()
1399+
13741400
@unittest.skipIf(ssl is None, "requires ssl")
13751401
@mock.patch('http.server.test')
13761402
def test_tls_cert_and_key_flags(self, mock_func):

0 commit comments

Comments
 (0)