@@ -117,13 +117,22 @@ class HTTPServer(socketserver.TCPServer):
117
117
allow_reuse_address = True # Seems to make sense in testing environment
118
118
allow_reuse_port = True
119
119
120
+ def __init__ (self , * args , response_headers = None , ** kwargs ):
121
+ self .response_headers = response_headers if response_headers is not None else {}
122
+ super ().__init__ (* args , ** kwargs )
123
+
120
124
def server_bind (self ):
121
125
"""Override server_bind to store the server name."""
122
126
socketserver .TCPServer .server_bind (self )
123
127
host , port = self .server_address [:2 ]
124
128
self .server_name = socket .getfqdn (host )
125
129
self .server_port = port
126
130
131
+ def finish_request (self , request , client_address ):
132
+ """Finish one request by instantiating RequestHandlerClass."""
133
+ args = (request , client_address , self )
134
+ kwargs = dict (response_headers = self .response_headers ) if self .response_headers else dict ()
135
+ self .RequestHandlerClass (* args , ** kwargs )
127
136
128
137
class ThreadingHTTPServer (socketserver .ThreadingMixIn , HTTPServer ):
129
138
daemon_threads = True
@@ -132,7 +141,7 @@ class ThreadingHTTPServer(socketserver.ThreadingMixIn, HTTPServer):
132
141
class HTTPSServer (HTTPServer ):
133
142
def __init__ (self , server_address , RequestHandlerClass ,
134
143
bind_and_activate = True , * , certfile , keyfile = None ,
135
- password = None , alpn_protocols = None ):
144
+ password = None , alpn_protocols = None , response_headers = None ):
136
145
try :
137
146
import ssl
138
147
except ImportError :
@@ -150,7 +159,8 @@ def __init__(self, server_address, RequestHandlerClass,
150
159
151
160
super ().__init__ (server_address ,
152
161
RequestHandlerClass ,
153
- bind_and_activate )
162
+ bind_and_activate ,
163
+ response_headers = response_headers )
154
164
155
165
def server_activate (self ):
156
166
"""Wrap the socket in SSLSocket."""
@@ -692,10 +702,11 @@ class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
692
702
'.xz' : 'application/x-xz' ,
693
703
}
694
704
695
- def __init__ (self , * args , directory = None , ** kwargs ):
705
+ def __init__ (self , * args , directory = None , response_headers = None , ** kwargs ):
696
706
if directory is None :
697
707
directory = os .getcwd ()
698
708
self .directory = os .fspath (directory )
709
+ self .response_headers = response_headers or {}
699
710
super ().__init__ (* args , ** kwargs )
700
711
701
712
def do_GET (self ):
@@ -736,6 +747,10 @@ def send_head(self):
736
747
new_url = urllib .parse .urlunsplit (new_parts )
737
748
self .send_header ("Location" , new_url )
738
749
self .send_header ("Content-Length" , "0" )
750
+ # User specified response_headers
751
+ if self .response_headers is not None :
752
+ for header , value in self .response_headers .items ():
753
+ self .send_header (header , value )
739
754
self .end_headers ()
740
755
return None
741
756
for index in self .index_pages :
@@ -795,6 +810,9 @@ def send_head(self):
795
810
self .send_header ("Content-Length" , str (fs [6 ]))
796
811
self .send_header ("Last-Modified" ,
797
812
self .date_time_string (fs .st_mtime ))
813
+ if self .response_headers is not None :
814
+ for header , value in self .response_headers .items ():
815
+ self .send_header (header , value )
798
816
self .end_headers ()
799
817
return f
800
818
except :
@@ -970,7 +988,7 @@ def _get_best_family(*address):
970
988
def test (HandlerClass = BaseHTTPRequestHandler ,
971
989
ServerClass = ThreadingHTTPServer ,
972
990
protocol = "HTTP/1.0" , port = 8000 , bind = None ,
973
- tls_cert = None , tls_key = None , tls_password = None ):
991
+ tls_cert = None , tls_key = None , tls_password = None , response_headers = None ):
974
992
"""Test the HTTP request handler class.
975
993
976
994
This runs an HTTP server on port 8000 (or the port argument).
@@ -981,9 +999,10 @@ def test(HandlerClass=BaseHTTPRequestHandler,
981
999
982
1000
if tls_cert :
983
1001
server = ServerClass (addr , HandlerClass , certfile = tls_cert ,
984
- keyfile = tls_key , password = tls_password )
1002
+ keyfile = tls_key , password = tls_password ,
1003
+ response_headers = response_headers )
985
1004
else :
986
- server = ServerClass (addr , HandlerClass )
1005
+ server = ServerClass (addr , HandlerClass , response_headers = response_headers )
987
1006
988
1007
with server as httpd :
989
1008
host , port = httpd .socket .getsockname ()[:2 ]
@@ -1024,6 +1043,8 @@ def _main(args=None):
1024
1043
parser .add_argument ('port' , default = 8000 , type = int , nargs = '?' ,
1025
1044
help = 'bind to this port '
1026
1045
'(default: %(default)s)' )
1046
+ parser .add_argument ('--cors' , action = 'store_true' ,
1047
+ help = 'Enable Access-Control-Allow-Origin: * header' )
1027
1048
args = parser .parse_args (args )
1028
1049
1029
1050
if not args .tls_cert and args .tls_key :
@@ -1051,15 +1072,19 @@ def server_bind(self):
1051
1072
return super ().server_bind ()
1052
1073
1053
1074
def finish_request (self , request , client_address ):
1054
- self .RequestHandlerClass (request , client_address , self ,
1055
- directory = args .directory )
1075
+ handler_args = (request , client_address , self )
1076
+ handler_kwargs = dict (directory = args .directory )
1077
+ if self .response_headers :
1078
+ handler_kwargs ['response_headers' ] = self .response_headers
1079
+ self .RequestHandlerClass (* handler_args , ** handler_kwargs )
1056
1080
1057
1081
class HTTPDualStackServer (DualStackServerMixin , ThreadingHTTPServer ):
1058
1082
pass
1059
1083
class HTTPSDualStackServer (DualStackServerMixin , ThreadingHTTPSServer ):
1060
1084
pass
1061
1085
1062
1086
ServerClass = HTTPSDualStackServer if args .tls_cert else HTTPDualStackServer
1087
+ response_headers = {'Access-Control-Allow-Origin' : '*' } if args .cors else None
1063
1088
1064
1089
test (
1065
1090
HandlerClass = SimpleHTTPRequestHandler ,
@@ -1070,6 +1095,7 @@ class HTTPSDualStackServer(DualStackServerMixin, ThreadingHTTPSServer):
1070
1095
tls_cert = args .tls_cert ,
1071
1096
tls_key = args .tls_key ,
1072
1097
tls_password = tls_key_password ,
1098
+ response_headers = response_headers
1073
1099
)
1074
1100
1075
1101
0 commit comments