From 39708de939d09751c2203a306bbfd553d80c93e5 Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Sun, 16 May 2021 22:39:54 -0700 Subject: [PATCH 01/28] Support helpers for running on non-IPv4 IPv6-only hosts. --- Lib/test/support/socket_helper.py | 106 ++++++++++++++++++++++++++++-- 1 file changed, 100 insertions(+), 6 deletions(-) diff --git a/Lib/test/support/socket_helper.py b/Lib/test/support/socket_helper.py index e78712b74b1377..12f3784a3c861f 100644 --- a/Lib/test/support/socket_helper.py +++ b/Lib/test/support/socket_helper.py @@ -12,7 +12,7 @@ HOSTv6 = "::1" -def find_unused_port(family=socket.AF_INET, socktype=socket.SOCK_STREAM): +def find_unused_port(family=None, socktype=socket.SOCK_STREAM): """Returns an unused port that should be suitable for binding. This is achieved by creating a temporary socket with the same family and type as the 'sock' parameter (default is AF_INET, SOCK_STREAM), and binding it to @@ -20,6 +20,9 @@ def find_unused_port(family=socket.AF_INET, socktype=socket.SOCK_STREAM): eliciting an unused ephemeral port from the OS. The temporary socket is then closed and deleted, and the ephemeral port is returned. + When family is None it will use whichever of socket.AF_INET or + socket.AF_INET6 makes sense, finding a port available on both if possible. + Either this method or bind_port() should be used for any tests where a server socket needs to be bound to a particular port for the duration of the test. Which one to use depends on whether the calling code is creating @@ -66,11 +69,43 @@ def find_unused_port(family=socket.AF_INET, socktype=socket.SOCK_STREAM): other process when we close and delete our temporary socket but before our calling code has a chance to bind the returned port. We can deal with this issue if/when we come across it. + + TODO(gpshead): We should support a https://pypi.org/project/portpicker/ + portserver or equivalent process running on our buildbot hosts and use that + that portpicker library... """ - with socket.socket(family, socktype) as tempsock: - port = bind_port(tempsock) - del tempsock + if isinstance(family, int): + with socket.socket(family, socktype) as tempsock: + port = bind_port(tempsock) + del tempsock + else: + if family is not None: # Assume it's a sequence, it wasn't int|None. + families = family + else: + families = [] + if IPV4_ENABLED: + families.append(socket.AF_INET) + if IPV6_ENABLED: + families.append(socket.AF_INET6) + assert families, "At least one of IPv4 or IPv6 must be enabled." + port = 0 + errors = {} + for family in families: + try: + with socket.socket(family, socktype) as tempsock: + if not port: + port = bind_port(tempsock) + else: + sock.bind((host, 0)) + port = sock.getsockname()[1] + except OSError as err: + errors[family] = err + port = 0 + del tempsock + if not port: + raise support.TestFailed( + f"Could not bind to a port: {errors}") return port def bind_port(sock, host=HOST): @@ -78,7 +113,7 @@ def bind_port(sock, host=HOST): ephemeral ports in order to ensure we are using an unbound port. This is important as many tests may be running simultaneously, especially in a buildbot environment. This method raises an exception if the sock.family - is AF_INET and sock.type is SOCK_STREAM, *and* the socket has SO_REUSEADDR + is AF_INET* and sock.type is SOCK_STREAM, *and* the socket has SO_REUSEADDR or SO_REUSEPORT set on it. Tests should *never* set these socket options for TCP/IP sockets. The only case for setting these options is testing multicasting via multiple UDP sockets. @@ -88,7 +123,8 @@ def bind_port(sock, host=HOST): from bind()'ing to our host/port for the duration of the test. """ - if sock.family == socket.AF_INET and sock.type == socket.SOCK_STREAM: + if (sock.family in {socket.AF_INET, socket.AF_INET6} and + sock.type == socket.SOCK_STREAM): if hasattr(socket, 'SO_REUSEADDR'): if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) == 1: raise support.TestFailed("tests should never set the " @@ -112,6 +148,48 @@ def bind_port(sock, host=HOST): port = sock.getsockname()[1] return port + +def get_bound_ip_socket_and_port(*, hostname=HOST, socktype=socket.SOCK_STREAM): + """Get an IP socket bound to a port as a sock, port tuple. + + Creates a socket of socktype bound to hostname using whichever of IPv6 or + IPv4 is available. Context is a (socket, port) tuple. Exiting the context + closes the socket. + + Prefer the bind_ip_socket_and_port context manager when possible. + """ + if IPV6_ENABLED: + family = socket.AF_INET6 + sock = socket.socket(socket.AF_INET6, socktype) + elif IPV4_ENABLED: + family = socket.AF_INET + else: + raise support.TestFailed( + "At least one of IPv4 or IPv6 must be enabled.") + sock = socket.socket(family, socktype) + try: + port = bind_port(sock) + except support.TestFailed: + sock.close() + raise + return sock, port + + +@contextlib.contextmanager +def bind_ip_socket_and_port(*, hostname=HOST, socktype=socket.SOCK_STREAM): + """ + A context manager that creates a socket of socktype bound to hostname + using whichever of IPv6 or IPv4 is available. Context is a (socket, port) + tuple. Exiting the context closes the socket. + """ + sock, port = get_bound_ip_socket_and_port( + hostname=hostname, socktype=socktype) + try: + yield sock, port + finally: + sock.close() + + def bind_unix_socket(sock, addr): """Bind a unix socket, raising SkipTest if PermissionError is raised.""" assert sock.family == socket.AF_UNIX @@ -139,6 +217,22 @@ def _is_ipv6_enabled(): IPV6_ENABLED = _is_ipv6_enabled() +def _is_ipv4_enabled(): + """Check whether IPv4 is enabled on this host.""" + sock = None + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind((HOSTv4, 0)) + return True + except OSError: + return False + finally: + if sock: + sock.close() + +IPV4_ENABLED = _is_ipv4_enabled() + + _bind_nix_socket_error = None def skip_unless_bind_unix_socket(test): """Decorator for tests requiring a functional bind() for unix sockets.""" From aa9e7c5281b0941063a49595d08902404cce0df0 Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Sun, 16 May 2021 22:40:37 -0700 Subject: [PATCH 02/28] Fixes eight testsuites to work on IPv6-only hosts. --- Lib/test/_test_eintr.py | 44 ++++++++++++++-------------- Lib/test/test_docxmlrpc.py | 13 +++++++-- Lib/test/test_epoll.py | 8 ++++-- Lib/test/test_ftplib.py | 24 ++++++++++------ Lib/test/test_imaplib.py | 56 +++++++++++++++++++++--------------- Lib/test/test_largefile.py | 7 +++-- Lib/test/test_nntplib.py | 4 +-- Lib/test/test_robotparser.py | 10 ++++++- 8 files changed, 102 insertions(+), 64 deletions(-) diff --git a/Lib/test/_test_eintr.py b/Lib/test/_test_eintr.py index e43b59d064f55a..dd73a8f7c66bdb 100644 --- a/Lib/test/_test_eintr.py +++ b/Lib/test/_test_eintr.py @@ -285,28 +285,28 @@ def test_sendmsg(self): self._test_send(lambda sock, data: sock.sendmsg([data])) def test_accept(self): - sock = socket.create_server((socket_helper.HOST, 0)) - self.addCleanup(sock.close) - port = sock.getsockname()[1] - - code = '\n'.join(( - 'import socket, time', - '', - 'host = %r' % socket_helper.HOST, - 'port = %s' % port, - 'sleep_time = %r' % self.sleep_time, - '', - '# let parent block on accept()', - 'time.sleep(sleep_time)', - 'with socket.create_connection((host, port)):', - ' time.sleep(sleep_time)', - )) - - proc = self.subprocess(code) - with kill_on_error(proc): - client_sock, _ = sock.accept() - client_sock.close() - self.assertEqual(proc.wait(), 0) + with socket_helper.bind_ip_socket_and_port() as sock_port: + sock, port = sock_port + sock.listen() + + code = '\n'.join(( + 'import socket, time', + '', + 'host = %r' % socket_helper.HOST, + 'port = %s' % port, + 'sleep_time = %r' % self.sleep_time, + '', + '# let parent block on accept()', + 'time.sleep(sleep_time)', + 'with socket.create_connection((host, port)):', + ' time.sleep(sleep_time)', + )) + + proc = self.subprocess(code) + with kill_on_error(proc): + client_sock, _ = sock.accept() + client_sock.close() + self.assertEqual(proc.wait(), 0) # Issue #25122: There is a race condition in the FreeBSD kernel on # handling signals in the FIFO device. Skip the test until the bug is diff --git a/Lib/test/test_docxmlrpc.py b/Lib/test/test_docxmlrpc.py index 7d3e30cbee964a..f5ff914701b590 100644 --- a/Lib/test/test_docxmlrpc.py +++ b/Lib/test/test_docxmlrpc.py @@ -1,7 +1,9 @@ from xmlrpc.server import DocXMLRPCServer import http.client import re +import socket import sys +from test.support import socket_helper import threading import unittest @@ -20,7 +22,14 @@ def make_request_and_skip(self): def make_server(): - serv = DocXMLRPCServer(("localhost", 0), logRequests=False) + try: + serv = DocXMLRPCServer((socket_helper.HOST, 0), logRequests=False) + except OSError: + if not socket_helper.IPV6_ENABLED: + raise + class IPv6DocXMLRPCServer(DocXMLRPCServer): + address_family = socket.AF_INET6 + serv = IPv6DocXMLRPCServer((socket_helper.HOST, 0), logRequests=False) try: # Add some documentation @@ -74,7 +83,7 @@ def setUp(self): self.thread.start() PORT = self.serv.server_address[1] - self.client = http.client.HTTPConnection("localhost:%d" % PORT) + self.client = http.client.HTTPConnection(f"{socket_helper.HOST}:{PORT}") def tearDown(self): self.client.close() diff --git a/Lib/test/test_epoll.py b/Lib/test/test_epoll.py index b623852f9eb4ee..f035e6e754e727 100644 --- a/Lib/test/test_epoll.py +++ b/Lib/test/test_epoll.py @@ -25,6 +25,7 @@ import os import select import socket +from test.support import socket_helper import time import unittest @@ -41,7 +42,8 @@ class TestEPoll(unittest.TestCase): def setUp(self): - self.serverSocket = socket.create_server(('127.0.0.1', 0)) + self.serverSocket, _ = socket_helper.get_bound_ip_socket_and_port() + self.serverSocket.listen() self.connections = [self.serverSocket] def tearDown(self): @@ -49,10 +51,10 @@ def tearDown(self): skt.close() def _connected_pair(self): - client = socket.socket() + client = socket.socket(self.serverSocket.family) client.setblocking(False) try: - client.connect(('127.0.0.1', self.serverSocket.getsockname()[1])) + client.connect((socket_helper.HOST, self.serverSocket.getsockname()[1])) except OSError as e: self.assertEqual(e.args[0], errno.EINPROGRESS) else: diff --git a/Lib/test/test_ftplib.py b/Lib/test/test_ftplib.py index a48b429ca38027..53a8485be1256e 100644 --- a/Lib/test/test_ftplib.py +++ b/Lib/test/test_ftplib.py @@ -265,9 +265,15 @@ class DummyFTPServer(asyncore.dispatcher, threading.Thread): handler = DummyFTPHandler - def __init__(self, address, af=socket.AF_INET, encoding=DEFAULT_ENCODING): + def __init__(self, address, af=None, encoding=DEFAULT_ENCODING): threading.Thread.__init__(self) asyncore.dispatcher.__init__(self) + if af is None and address[0] == socket_helper.HOST: + if socket_helper.IPV4_ENABLED: + af = socket.AF_INET + else: + assert socket_helper.IPV6_ENABLED, 'no IPv4 or IPv6?' + af = socket.AF_INET6 self.daemon = True self.create_socket(af, socket.SOCK_STREAM) self.bind(address) @@ -699,19 +705,22 @@ def test_entry(line, type=None, perm=None, unique=None, name=None): for x in self.client.mlsd(): self.fail("unexpected data %s" % x) - def test_makeport(self): + @skipUnless(socket_helper.IPV4_ENABLED, "IPv4 required") + def test_makeport_ipv4(self): with self.client.makeport(): # IPv4 is in use, just make sure send_eprt has not been used self.assertEqual(self.server.handler_instance.last_received_cmd, - 'port') + 'port') - def test_makepasv(self): + @skipUnless(socket_helper.IPV4_ENABLED, "IPv4 required") + def test_makepasv_ipv4(self): host, port = self.client.makepasv() conn = socket.create_connection((host, port), timeout=TIMEOUT) conn.close() # IPv4 is in use, just make sure send_epsv has not been used self.assertEqual(self.server.handler_instance.last_received_cmd, 'pasv') + @skipUnless(socket_helper.IPV4_ENABLED, "IPv4 required") def test_makepasv_issue43285_security_disabled(self): """Test the opt-in to the old vulnerable behavior.""" self.client.trust_server_pasv_ipv4_address = True @@ -779,7 +788,7 @@ def is_client_connected(): def test_source_address(self): self.client.quit() - port = socket_helper.find_unused_port() + port = socket_helper.find_unused_port(family=None) try: self.client.connect(self.server.host, self.server.port, source_address=(HOST, port)) @@ -791,7 +800,7 @@ def test_source_address(self): raise def test_source_address_passive_connection(self): - port = socket_helper.find_unused_port() + port = socket_helper.find_unused_port(family=None) self.client.source_address = (HOST, port) try: with self.client.transfercmd('list') as sock: @@ -1033,9 +1042,8 @@ class TestTimeouts(TestCase): def setUp(self): self.evt = threading.Event() - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.sock, self.port = socket_helper.get_bound_ip_socket_and_port() self.sock.settimeout(20) - self.port = socket_helper.bind_port(self.sock) self.server_thread = threading.Thread(target=self.server) self.server_thread.daemon = True self.server_thread.start() diff --git a/Lib/test/test_imaplib.py b/Lib/test/test_imaplib.py index c2b935f58164e5..8fb3403f0b41d3 100644 --- a/Lib/test/test_imaplib.py +++ b/Lib/test/test_imaplib.py @@ -13,6 +13,7 @@ from test.support import (verbose, run_with_tz, run_with_locale, cpython_only) from test.support import hashlib_helper +from test.support import socket_helper from test.support import threading_helper from test.support import warnings_helper import unittest @@ -27,6 +28,15 @@ CAFILE = os.path.join(os.path.dirname(__file__) or os.curdir, "pycacert.pem") +if socket_helper.IPV4_ENABLED: + TCPServer = socketserver.TCPServer +elif socket_helper.IPV6_ENABLED: + class TCPServer(socketserver.TCPServer): + address_family = socket.AF_INET6 +else: + raise unittest.SkipTest('IPv4 or IPv6 required.') + + class TestImaplib(unittest.TestCase): def test_Internaldate2tuple(self): @@ -92,7 +102,7 @@ def test_imap4_host_default_value(self): if ssl: - class SecureTCPServer(socketserver.TCPServer): + class SecureTCPServer(TCPServer): def get_request(self): newsocket, fromaddr = self.socket.accept() @@ -238,7 +248,7 @@ def handle_error(self, request, client_address): self.thread.start() if connect: - self.client = self.imap_class(*self.server.server_address) + self.client = self.imap_class(*self.server.server_address[:2]) return self.client, self.server @@ -265,7 +275,7 @@ def handle(self): self.wfile.write(b'* OK') _, server = self._setup(EOFHandler, connect=False) self.assertRaises(imaplib.IMAP4.abort, self.imap_class, - *server.server_address) + *server.server_address[:2]) def test_line_termination(self): class BadNewlineHandler(SimpleIMAPHandler): @@ -274,7 +284,7 @@ def cmd_CAPABILITY(self, tag, args): self._send_tagged(tag, 'OK', 'CAPABILITY completed') _, server = self._setup(BadNewlineHandler, connect=False) self.assertRaises(imaplib.IMAP4.abort, self.imap_class, - *server.server_address) + *server.server_address[:2]) def test_enable_raises_error_if_not_AUTH(self): class EnableHandler(SimpleIMAPHandler): @@ -449,11 +459,11 @@ def handle(self): _, server = self._setup(TooLongHandler, connect=False) with self.assertRaisesRegex(imaplib.IMAP4.error, 'got more than 10 bytes'): - self.imap_class(*server.server_address) + self.imap_class(*server.server_address[:2]) def test_simple_with_statement(self): _, server = self._setup(SimpleIMAPHandler, connect=False) - with self.imap_class(*server.server_address): + with self.imap_class(*server.server_address[:2]): pass def test_imaplib_timeout_test(self): @@ -481,7 +491,7 @@ def handle(self): def test_with_statement(self): _, server = self._setup(SimpleIMAPHandler, connect=False) - with self.imap_class(*server.server_address) as imap: + with self.imap_class(*server.server_address[:2]) as imap: imap.login('user', 'pass') self.assertEqual(server.logged, 'user') self.assertIsNone(server.logged) @@ -489,7 +499,7 @@ def test_with_statement(self): def test_with_statement_logout(self): # It is legal to log out explicitly inside the with block _, server = self._setup(SimpleIMAPHandler, connect=False) - with self.imap_class(*server.server_address) as imap: + with self.imap_class(*server.server_address[:2]) as imap: imap.login('user', 'pass') self.assertEqual(server.logged, 'user') imap.logout() @@ -541,7 +551,7 @@ def test_unselect(self): class NewIMAPTests(NewIMAPTestsMixin, unittest.TestCase): imap_class = imaplib.IMAP4 - server_class = socketserver.TCPServer + server_class = TCPServer @unittest.skipUnless(ssl, "SSL not available") @@ -557,9 +567,9 @@ def test_ssl_raises(self): with self.assertRaisesRegex(ssl.CertificateError, "IP address mismatch, certificate is not valid for " - "'127.0.0.1'"): + f"'({socket_helper.HOSTv4}|{socket_helper.HOSTv6})'"): _, server = self._setup(SimpleIMAPHandler) - client = self.imap_class(*server.server_address, + client = self.imap_class(*server.server_address[:2], ssl_context=ssl_context) client.shutdown() @@ -582,7 +592,7 @@ def test_certfile_arg_warn(self): self.imap_class('localhost', 143, certfile=CERTFILE) class ThreadedNetworkedTests(unittest.TestCase): - server_class = socketserver.TCPServer + server_class = TCPServer imap_class = imaplib.IMAP4 def make_server(self, addr, hdlr): @@ -637,7 +647,7 @@ def reaped_server(self, hdlr): @contextmanager def reaped_pair(self, hdlr): with self.reaped_server(hdlr) as server: - client = self.imap_class(*server.server_address) + client = self.imap_class(*server.server_address[:2]) try: yield server, client finally: @@ -646,7 +656,7 @@ def reaped_pair(self, hdlr): @threading_helper.reap_threads def test_connect(self): with self.reaped_server(SimpleIMAPHandler) as server: - client = self.imap_class(*server.server_address) + client = self.imap_class(*server.server_address[:2]) client.shutdown() @threading_helper.reap_threads @@ -708,7 +718,7 @@ def handle(self): with self.reaped_server(EOFHandler) as server: self.assertRaises(imaplib.IMAP4.abort, - self.imap_class, *server.server_address) + self.imap_class, *server.server_address[:2]) @threading_helper.reap_threads def test_line_termination(self): @@ -721,7 +731,7 @@ def cmd_CAPABILITY(self, tag, args): with self.reaped_server(BadNewlineHandler) as server: self.assertRaises(imaplib.IMAP4.abort, - self.imap_class, *server.server_address) + self.imap_class, *server.server_address[:2]) class UTF8Server(SimpleIMAPHandler): capabilities = 'AUTH ENABLE UTF8=ACCEPT' @@ -906,19 +916,19 @@ def handle(self): with self.reaped_server(TooLongHandler) as server: self.assertRaises(imaplib.IMAP4.error, - self.imap_class, *server.server_address) + self.imap_class, *server.server_address[:2]) @threading_helper.reap_threads def test_simple_with_statement(self): # simplest call with self.reaped_server(SimpleIMAPHandler) as server: - with self.imap_class(*server.server_address): + with self.imap_class(*server.server_address[:2]): pass @threading_helper.reap_threads def test_with_statement(self): with self.reaped_server(SimpleIMAPHandler) as server: - with self.imap_class(*server.server_address) as imap: + with self.imap_class(*server.server_address[:2]) as imap: imap.login('user', 'pass') self.assertEqual(server.logged, 'user') self.assertIsNone(server.logged) @@ -927,7 +937,7 @@ def test_with_statement(self): def test_with_statement_logout(self): # what happens if already logout in the block? with self.reaped_server(SimpleIMAPHandler) as server: - with self.imap_class(*server.server_address) as imap: + with self.imap_class(*server.server_address[:2]) as imap: imap.login('user', 'pass') self.assertEqual(server.logged, 'user') imap.logout() @@ -941,7 +951,7 @@ def test_dump_ur(self): untagged_resp_dict = {'READ-WRITE': [b'']} with self.reaped_server(SimpleIMAPHandler) as server: - with self.imap_class(*server.server_address) as imap: + with self.imap_class(*server.server_address[:2]) as imap: with mock.patch.object(imap, '_mesg') as mock_mesg: imap._dump_ur(untagged_resp_dict) mock_mesg.assert_called_with( @@ -962,9 +972,9 @@ def test_ssl_verified(self): with self.assertRaisesRegex( ssl.CertificateError, "IP address mismatch, certificate is not valid for " - "'127.0.0.1'"): + f"'({socket_helper.HOSTv4}|{socket_helper.HOSTv6})'"): with self.reaped_server(SimpleIMAPHandler) as server: - client = self.imap_class(*server.server_address, + client = self.imap_class(*server.server_address[:2], ssl_context=ssl_context) client.shutdown() diff --git a/Lib/test/test_largefile.py b/Lib/test/test_largefile.py index 8f6bec16200534..ce2920dc6702f3 100644 --- a/Lib/test/test_largefile.py +++ b/Lib/test/test_largefile.py @@ -221,10 +221,11 @@ def run(sock): # bit more tolerance. @skip_no_disk_space(TESTFN, size * 2.5) def test_it(self): - port = socket_helper.find_unused_port() - with socket.create_server(("", port)) as sock: + with socket_helper.bind_ip_socket_and_port() as sock_port: + sock, port = sock_port + sock.listen() self.tcp_server(sock) - with socket.create_connection(("127.0.0.1", port)) as client: + with socket.create_connection((socket_helper.HOST, port)) as client: with open(TESTFN, 'rb') as f: client.sendfile(f) self.tearDown() diff --git a/Lib/test/test_nntplib.py b/Lib/test/test_nntplib.py index 4f0592188f8443..f565cc9ba27f2f 100644 --- a/Lib/test/test_nntplib.py +++ b/Lib/test/test_nntplib.py @@ -1585,8 +1585,8 @@ def nntp_class(*pos, **kw): class LocalServerTests(unittest.TestCase): def setUp(self): - sock = socket.socket() - port = socket_helper.bind_port(sock) + sock, port = socket_helper.get_bound_ip_socket_and_port() + self.addCleanup(sock.close) sock.listen() self.background = threading.Thread( target=self.run_server, args=(sock,)) diff --git a/Lib/test/test_robotparser.py b/Lib/test/test_robotparser.py index b0bed431d4b059..b877092ae15442 100644 --- a/Lib/test/test_robotparser.py +++ b/Lib/test/test_robotparser.py @@ -1,5 +1,6 @@ import io import os +import socket import threading import unittest import urllib.robotparser @@ -314,7 +315,14 @@ def setUp(self): # clear _opener global variable self.addCleanup(urllib.request.urlcleanup) - self.server = HTTPServer((socket_helper.HOST, 0), RobotHandler) + try: + self.server = HTTPServer((socket_helper.HOST, 0), RobotHandler) + except OSError: + if not socket_helper.IPV6_ENABLED: + raise + class IPv6HTTPServer(HTTPServer): + address_family = socket.AF_INET6 + self.server = IPv6HTTPServer((socket_helper.HOST, 0), RobotHandler) self.t = threading.Thread( name='HTTPServer serving', From ee00e8bd8d9dfcf548db79096800f3855b2c64e4 Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Tue, 18 May 2021 01:09:14 -0700 Subject: [PATCH 03/28] Make test_support pass on IPv6 only. --- Lib/test/test_support.py | 34 +++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/Lib/test/test_support.py b/Lib/test/test_support.py index 55d78b733353d2..d4167fd31082fc 100644 --- a/Lib/test/test_support.py +++ b/Lib/test/test_support.py @@ -94,20 +94,36 @@ def test_forget(self): os_helper.unlink(mod_filename) os_helper.rmtree('__pycache__') - def test_HOST(self): - s = socket.create_server((socket_helper.HOST, 0)) - s.close() + def test_bind_ip_socket_and_port_HOST(self): + """This also tests get_bound_ip_socket_and_port() indirectly.""" + with socket_helper.bind_ip_socket_and_port( + hostname=socket_helper.HOST): + pass - def test_find_unused_port(self): + @unittest.skipUnless(socket_helper.IPV4_ENABLED, "IPv4 required") + def test_find_unused_port_ipv4(self): port = socket_helper.find_unused_port() s = socket.create_server((socket_helper.HOST, port)) s.close() - def test_bind_port(self): - s = socket.socket() - socket_helper.bind_port(s) - s.listen() - s.close() + @unittest.skipUnless(socket_helper.IPV6_ENABLED, "IPv6 required") + def test_find_unused_port_ipv6(self): + port = socket_helper.find_unused_port() + with socket.socket(socket.AF_INET6) as s: + s.bind((socket_helper.HOST, port)) + s.listen() + + @unittest.skipUnless(socket_helper.IPV4_ENABLED, "IPv4 required") + def test_bind_port_ipv4(self): + with socket.socket(socket.AF_INET) as s: + socket_helper.bind_port(s) + s.listen() + + @unittest.skipUnless(socket_helper.IPV6_ENABLED, "IPv6 required") + def test_bind_port_ipv6(self): + with socket.socket(socket.AF_INET6) as s: + socket_helper.bind_port(s) + s.listen() # Tests for temp_dir() From 29efd49babb33020293c060dd525adbb1bf00822 Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Tue, 18 May 2021 01:14:06 -0700 Subject: [PATCH 04/28] Make test_os work on IPv6-only. --- Lib/test/test_os.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/Lib/test/test_os.py b/Lib/test/test_os.py index 8b3d1feb78fe36..74e486efef8fa6 100644 --- a/Lib/test/test_os.py +++ b/Lib/test/test_os.py @@ -3234,7 +3234,11 @@ def handle_error(self): def __init__(self, address): threading.Thread.__init__(self) asyncore.dispatcher.__init__(self) - self.create_socket(socket.AF_INET, socket.SOCK_STREAM) + if socket_helper.IPV4_ENABLED: + family = socket.AF_INET + elif socket_helper.IPV6_ENABLED: + family = socket.AF_INET6 + self.create_socket(family, socket.SOCK_STREAM) self.bind(address) self.listen(5) self.host, self.port = self.socket.getsockname()[:2] @@ -3316,7 +3320,7 @@ def tearDownClass(cls): def setUp(self): self.server = SendfileTestServer((socket_helper.HOST, 0)) self.server.start() - self.client = socket.socket() + self.client = socket.socket(self.server.socket.family) self.client.connect((self.server.host, self.server.port)) self.client.settimeout(1) # synchronize by waiting for "220 ready" response From d51dc792e131a008f90837b3917609a82ea74acd Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Tue, 18 May 2021 01:21:10 -0700 Subject: [PATCH 05/28] fix test_telnetlib IPv6-only, introduce get_family helper. --- Lib/test/support/socket_helper.py | 20 ++++++++++++-------- Lib/test/test_telnetlib.py | 3 ++- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/Lib/test/support/socket_helper.py b/Lib/test/support/socket_helper.py index 12f3784a3c861f..891692b27aa4ed 100644 --- a/Lib/test/support/socket_helper.py +++ b/Lib/test/support/socket_helper.py @@ -149,6 +149,17 @@ def bind_port(sock, host=HOST): return port +def get_family(): + """Get a host appropriate socket AF_INET or AF_INET6 family.""" + if IPV6_ENABLED: + return socket.AF_INET6 + elif IPV4_ENABLED: + return socket.AF_INET + else: + raise support.TestFailed( + "At least one of IPv4 or IPv6 must be enabled.") + + def get_bound_ip_socket_and_port(*, hostname=HOST, socktype=socket.SOCK_STREAM): """Get an IP socket bound to a port as a sock, port tuple. @@ -158,14 +169,7 @@ def get_bound_ip_socket_and_port(*, hostname=HOST, socktype=socket.SOCK_STREAM): Prefer the bind_ip_socket_and_port context manager when possible. """ - if IPV6_ENABLED: - family = socket.AF_INET6 - sock = socket.socket(socket.AF_INET6, socktype) - elif IPV4_ENABLED: - family = socket.AF_INET - else: - raise support.TestFailed( - "At least one of IPv4 or IPv6 must be enabled.") + family = get_family() sock = socket.socket(family, socktype) try: port = bind_port(sock) diff --git a/Lib/test/test_telnetlib.py b/Lib/test/test_telnetlib.py index 41c4fcd4195e3a..7a7212be7c1252 100644 --- a/Lib/test/test_telnetlib.py +++ b/Lib/test/test_telnetlib.py @@ -25,7 +25,8 @@ class GeneralTests(unittest.TestCase): def setUp(self): self.evt = threading.Event() - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + family = socket_helper.get_family() + self.sock = socket.socket(family, socket.SOCK_STREAM) self.sock.settimeout(60) # Safety net. Look issue 11812 self.port = socket_helper.bind_port(self.sock) self.thread = threading.Thread(target=server, args=(self.evt,self.sock)) From 0c95594cd3104dc559ced4ea8954fec1e31e0221 Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Tue, 18 May 2021 01:28:56 -0700 Subject: [PATCH 06/28] Fix test_urllib2_localnet for IPv6-only. --- Lib/test/ssl_servers.py | 2 ++ Lib/test/test_urllib2_localnet.py | 8 +++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/Lib/test/ssl_servers.py b/Lib/test/ssl_servers.py index a4bd7455d47e76..d3b0f4315573dd 100644 --- a/Lib/test/ssl_servers.py +++ b/Lib/test/ssl_servers.py @@ -20,6 +20,8 @@ class HTTPSServer(_HTTPServer): + address_family = socket_helper.get_family() + def __init__(self, server_address, handler_class, context): _HTTPServer.__init__(self, server_address, handler_class) self.context = context diff --git a/Lib/test/test_urllib2_localnet.py b/Lib/test/test_urllib2_localnet.py index ebb43c30b4d505..bb4da9e813528e 100644 --- a/Lib/test/test_urllib2_localnet.py +++ b/Lib/test/test_urllib2_localnet.py @@ -9,6 +9,7 @@ import hashlib from test.support import hashlib_helper +from test.support import socket_helper from test.support import threading_helper from test.support import warnings_helper @@ -30,6 +31,7 @@ class LoopbackHttpServer(http.server.HTTPServer): """HTTP server w/ a few modifications that make it useful for loopback testing purposes. """ + address_family = socket_helper.get_family() def __init__(self, server_address, RequestHandlerClass): http.server.HTTPServer.__init__(self, @@ -60,7 +62,7 @@ def __init__(self, request_handler): self._stop_server = False self.ready = threading.Event() request_handler.protocol_version = "HTTP/1.0" - self.httpd = LoopbackHttpServer(("127.0.0.1", 0), + self.httpd = LoopbackHttpServer((socket_helper.HOST, 0), request_handler) self.port = self.httpd.server_port @@ -290,7 +292,7 @@ def http_server_with_basic_auth_handler(*args, **kwargs): return BasicAuthHandler(*args, **kwargs) self.server = LoopbackHttpServerThread(http_server_with_basic_auth_handler) self.addCleanup(self.stop_server) - self.server_url = 'http://127.0.0.1:%s' % self.server.port + self.server_url = f'http://{socket_helper.HOST}:{self.server.port}' self.server.start() self.server.ready.wait() @@ -346,7 +348,7 @@ def create_fake_proxy_handler(*args, **kwargs): self.addCleanup(self.stop_server) self.server.start() self.server.ready.wait() - proxy_url = "http://127.0.0.1:%d" % self.server.port + proxy_url = f"http://{socket_helper.HOST}:{self.server.port}" handler = urllib.request.ProxyHandler({"http" : proxy_url}) self.proxy_digest_handler = urllib.request.ProxyDigestAuthHandler() self.opener = urllib.request.build_opener( From 7938fdc119eea4359ed43d3740ee87dfb551b138 Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Tue, 18 May 2021 02:28:04 -0700 Subject: [PATCH 07/28] tcp_socket() to replace socket.socket() in tests. --- Lib/test/support/socket_helper.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/Lib/test/support/socket_helper.py b/Lib/test/support/socket_helper.py index 891692b27aa4ed..d46b30fb5473e7 100644 --- a/Lib/test/support/socket_helper.py +++ b/Lib/test/support/socket_helper.py @@ -160,6 +160,17 @@ def get_family(): "At least one of IPv4 or IPv6 must be enabled.") +def tcp_socket(): + """Get a new host appropriate IPv4 or IPv6 TCP STREAM socket.socket().""" + if IPV4_ENABLED: + return socket.socket(socket.AF_INET, socket.SOCK_STREAM) + elif IPV6_ENABLED: + return socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + else: + raise support.TestFailed( + "At least one of IPv4 or IPv6 must be enabled.") + + def get_bound_ip_socket_and_port(*, hostname=HOST, socktype=socket.SOCK_STREAM): """Get an IP socket bound to a port as a sock, port tuple. From 643675782a769841583f98d37aee901116b00659 Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Tue, 18 May 2021 02:28:36 -0700 Subject: [PATCH 08/28] make test_ssl pass on IPv6-only. --- Lib/test/test_ssl.py | 248 ++++++++++++++++++++++--------------------- 1 file changed, 127 insertions(+), 121 deletions(-) diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index 00d5eff81537d1..78a2bdfca8eb9d 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -366,7 +366,7 @@ def test_ssl_types(self): def test_private_init(self): with self.assertRaisesRegex(TypeError, "public constructor"): - with socket.socket() as s: + with socket_helper.tcp_socket() as s: ssl.SSLSocket(s) def test_str_for_enums(self): @@ -550,7 +550,7 @@ def test_openssl_version(self): def test_refcycle(self): # Issue #7943: an SSL object doesn't create reference cycles with # itself. - s = socket.socket(socket.AF_INET) + s = socket.socket(socket_helper.get_family()) ss = test_wrap_socket(s) wr = weakref.ref(ss) with warnings_helper.check_warnings(("", ResourceWarning)): @@ -560,7 +560,7 @@ def test_refcycle(self): def test_wrapped_unconnected(self): # Methods on an unconnected SSLSocket propagate the original # OSError raise by the underlying socket object. - s = socket.socket(socket.AF_INET) + s = socket.socket(socket_helper.get_family()) with test_wrap_socket(s) as ss: self.assertRaises(OSError, ss.recv, 1) self.assertRaises(OSError, ss.recv_into, bytearray(b'x')) @@ -579,14 +579,14 @@ def test_timeout(self): # Issue #8524: when creating an SSL socket, the timeout of the # original socket should be retained. for timeout in (None, 0.0, 5.0): - s = socket.socket(socket.AF_INET) + s = socket.socket(socket_helper.get_family()) s.settimeout(timeout) with test_wrap_socket(s) as ss: self.assertEqual(timeout, ss.gettimeout()) @ignore_deprecation def test_errors_sslwrap(self): - sock = socket.socket() + sock = socket_helper.tcp_socket() self.assertRaisesRegex(ValueError, "certfile must be specified", ssl.wrap_socket, sock, keyfile=CERTFILE) @@ -600,16 +600,16 @@ def test_errors_sslwrap(self): self.assertRaisesRegex(ValueError, "can't connect in server-side mode", s.connect, (HOST, 8080)) with self.assertRaises(OSError) as cm: - with socket.socket() as sock: + with socket_helper.tcp_socket() as sock: ssl.wrap_socket(sock, certfile=NONEXISTINGCERT) self.assertEqual(cm.exception.errno, errno.ENOENT) with self.assertRaises(OSError) as cm: - with socket.socket() as sock: + with socket_helper.tcp_socket() as sock: ssl.wrap_socket(sock, certfile=CERTFILE, keyfile=NONEXISTINGCERT) self.assertEqual(cm.exception.errno, errno.ENOENT) with self.assertRaises(OSError) as cm: - with socket.socket() as sock: + with socket_helper.tcp_socket() as sock: ssl.wrap_socket(sock, certfile=NONEXISTINGCERT, keyfile=NONEXISTINGCERT) self.assertEqual(cm.exception.errno, errno.ENOENT) @@ -618,7 +618,7 @@ def bad_cert_test(self, certfile): """Check that trying to use the given client certificate fails""" certfile = os.path.join(os.path.dirname(__file__) or os.curdir, certfile) - sock = socket.socket() + sock = socket_helper.tcp_socket() self.addCleanup(sock.close) with self.assertRaises(ssl.SSLError): test_wrap_socket(sock, @@ -838,34 +838,35 @@ def fail(cert, hostname): def test_server_side(self): # server_hostname doesn't work for server sockets ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - with socket.socket() as sock: + with socket_helper.tcp_socket() as sock: self.assertRaises(ValueError, ctx.wrap_socket, sock, True, server_hostname="some.hostname") def test_unknown_channel_binding(self): # should raise ValueError for unknown type - s = socket.create_server(('127.0.0.1', 0)) - c = socket.socket(socket.AF_INET) - c.connect(s.getsockname()) - with test_wrap_socket(c, do_handshake_on_connect=False) as ss: - with self.assertRaises(ValueError): - ss.get_channel_binding("unknown-type") - s.close() + with socket_helper.bind_ip_socket_and_port() as sock_port: + s = sock_port[0] + s.listen() + c = socket.socket(s.family) + c.connect(s.getsockname()) + with test_wrap_socket(c, do_handshake_on_connect=False) as ss: + with self.assertRaises(ValueError): + ss.get_channel_binding("unknown-type") @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES, "'tls-unique' channel binding not available") def test_tls_unique_channel_binding(self): # unconnected should return None for known type - s = socket.socket(socket.AF_INET) + s = socket.socket(socket_helper.get_family()) with test_wrap_socket(s) as ss: self.assertIsNone(ss.get_channel_binding("tls-unique")) # the same for server-side - s = socket.socket(socket.AF_INET) + s = socket.socket(socket_helper.get_family()) with test_wrap_socket(s, server_side=True, certfile=CERTFILE) as ss: self.assertIsNone(ss.get_channel_binding("tls-unique")) def test_dealloc_warn(self): - ss = test_wrap_socket(socket.socket(socket.AF_INET)) + ss = test_wrap_socket(socket.socket(socket_helper.get_family())) r = repr(ss) with self.assertWarns(ResourceWarning) as cm: ss = None @@ -981,7 +982,7 @@ def test_purpose_enum(self): '1.3.6.1.5.5.7.3.2') def test_unsupported_dtls(self): - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s = socket.socket(socket_helper.get_family(), socket.SOCK_DGRAM) self.addCleanup(s.close) with self.assertRaises(NotImplementedError) as cx: test_wrap_socket(s, cert_reqs=ssl.CERT_NONE) @@ -1057,10 +1058,10 @@ def local_february_name(): self.cert_time_fail(local_february_name() + " 9 00:00:00 2007 GMT") def test_connect_ex_error(self): - server = socket.socket(socket.AF_INET) + server = socket.socket(socket_helper.get_family()) self.addCleanup(server.close) port = socket_helper.bind_port(server) # Reserve port but don't listen - s = test_wrap_socket(socket.socket(socket.AF_INET), + s = test_wrap_socket(socket.socket(server.family), cert_reqs=ssl.CERT_REQUIRED) self.addCleanup(s.close) rc = s.connect_ex((HOST, port)) @@ -1077,7 +1078,7 @@ def test_read_write_zero(self): client_context, server_context, hostname = testing_context() server = ThreadedEchoServer(context=server_context) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) self.assertEqual(s.recv(0), b"") @@ -1752,7 +1753,7 @@ class MySSLObject(ssl.SSLObject): ctx.sslsocket_class = MySSLSocket ctx.sslobject_class = MySSLObject - with ctx.wrap_socket(socket.socket(), server_side=True) as sock: + with ctx.wrap_socket(socket_helper.tcp_socket(), server_side=True) as sock: self.assertIsInstance(sock, MySSLSocket) obj = ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO()) self.assertIsInstance(obj, MySSLObject) @@ -1804,8 +1805,10 @@ def test_subclass(self): ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.check_hostname = False ctx.verify_mode = ssl.CERT_NONE - with socket.create_server(("127.0.0.1", 0)) as s: - c = socket.create_connection(s.getsockname()) + with socket_helper.bind_ip_socket_and_port() as sock_port: + s = sock_port[0] + s.listen() + c = socket.create_connection(s.getsockname()[:2]) c.setblocking(False) with ctx.wrap_socket(c, False, do_handshake_on_connect=False) as c: with self.assertRaises(ssl.SSLWantReadError) as cm: @@ -1947,19 +1950,20 @@ def setUp(self): self.server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) self.server_context.load_cert_chain(SIGNED_CERTFILE) server = ThreadedEchoServer(context=self.server_context) + self.family = server.sock.family self.server_addr = (HOST, server.port) server.__enter__() self.addCleanup(server.__exit__, None, None, None) def test_connect(self): - with test_wrap_socket(socket.socket(socket.AF_INET), + with test_wrap_socket(socket.socket(self.family), cert_reqs=ssl.CERT_NONE) as s: s.connect(self.server_addr) self.assertEqual({}, s.getpeercert()) self.assertFalse(s.server_side) # this should succeed because we specify the root cert - with test_wrap_socket(socket.socket(socket.AF_INET), + with test_wrap_socket(socket.socket(self.family), cert_reqs=ssl.CERT_REQUIRED, ca_certs=SIGNING_CA) as s: s.connect(self.server_addr) @@ -1970,7 +1974,7 @@ def test_connect_fail(self): # This should fail because we have no verification certs. Connection # failure crashes ThreadedEchoServer, so run this in an independent # test method. - s = test_wrap_socket(socket.socket(socket.AF_INET), + s = test_wrap_socket(socket.socket(self.family), cert_reqs=ssl.CERT_REQUIRED) self.addCleanup(s.close) self.assertRaisesRegex(ssl.SSLError, "certificate verify failed", @@ -1978,7 +1982,7 @@ def test_connect_fail(self): def test_connect_ex(self): # Issue #11326: check connect_ex() implementation - s = test_wrap_socket(socket.socket(socket.AF_INET), + s = test_wrap_socket(socket.socket(self.family), cert_reqs=ssl.CERT_REQUIRED, ca_certs=SIGNING_CA) self.addCleanup(s.close) @@ -1988,7 +1992,7 @@ def test_connect_ex(self): def test_non_blocking_connect_ex(self): # Issue #11326: non-blocking connect_ex() should allow handshake # to proceed after the socket gets ready. - s = test_wrap_socket(socket.socket(socket.AF_INET), + s = test_wrap_socket(socket.socket(self.family), cert_reqs=ssl.CERT_REQUIRED, ca_certs=SIGNING_CA, do_handshake_on_connect=False) @@ -2016,17 +2020,17 @@ def test_connect_with_context(self): ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.check_hostname = False ctx.verify_mode = ssl.CERT_NONE - with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s: + with ctx.wrap_socket(socket.socket(self.family)) as s: s.connect(self.server_addr) self.assertEqual({}, s.getpeercert()) # Same with a server hostname - with ctx.wrap_socket(socket.socket(socket.AF_INET), + with ctx.wrap_socket(socket.socket(self.family), server_hostname="dummy") as s: s.connect(self.server_addr) ctx.verify_mode = ssl.CERT_REQUIRED # This should succeed because we specify the root cert ctx.load_verify_locations(SIGNING_CA) - with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s: + with ctx.wrap_socket(socket.socket(self.family)) as s: s.connect(self.server_addr) cert = s.getpeercert() self.assertTrue(cert) @@ -2037,7 +2041,7 @@ def test_connect_with_context_fail(self): # test method. ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) s = ctx.wrap_socket( - socket.socket(socket.AF_INET), + socket.socket(self.family), server_hostname=SIGNED_CERTFILE_HOSTNAME ) self.addCleanup(s.close) @@ -2052,7 +2056,7 @@ def test_connect_capath(self): # filename) for this test to be portable across OpenSSL releases. ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.load_verify_locations(capath=CAPATH) - with ctx.wrap_socket(socket.socket(socket.AF_INET), + with ctx.wrap_socket(socket.socket(self.family), server_hostname=SIGNED_CERTFILE_HOSTNAME) as s: s.connect(self.server_addr) cert = s.getpeercert() @@ -2061,7 +2065,7 @@ def test_connect_capath(self): # Same with a bytes `capath` argument ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.load_verify_locations(capath=BYTES_CAPATH) - with ctx.wrap_socket(socket.socket(socket.AF_INET), + with ctx.wrap_socket(socket.socket(self.family), server_hostname=SIGNED_CERTFILE_HOSTNAME) as s: s.connect(self.server_addr) cert = s.getpeercert() @@ -2073,7 +2077,7 @@ def test_connect_cadata(self): der = ssl.PEM_cert_to_DER_cert(pem) ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.load_verify_locations(cadata=pem) - with ctx.wrap_socket(socket.socket(socket.AF_INET), + with ctx.wrap_socket(socket.socket(self.family), server_hostname=SIGNED_CERTFILE_HOSTNAME) as s: s.connect(self.server_addr) cert = s.getpeercert() @@ -2082,7 +2086,7 @@ def test_connect_cadata(self): # same with DER ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.load_verify_locations(cadata=der) - with ctx.wrap_socket(socket.socket(socket.AF_INET), + with ctx.wrap_socket(socket.socket(self.family), server_hostname=SIGNED_CERTFILE_HOSTNAME) as s: s.connect(self.server_addr) cert = s.getpeercert() @@ -2093,7 +2097,7 @@ def test_makefile_close(self): # Issue #5238: creating a file-like object with makefile() shouldn't # delay closing the underlying "real socket" (here tested with its # file descriptor, hence skipping the test under Windows). - ss = test_wrap_socket(socket.socket(socket.AF_INET)) + ss = test_wrap_socket(socket.socket(self.family)) ss.connect(self.server_addr) fd = ss.fileno() f = ss.makefile() @@ -2108,7 +2112,7 @@ def test_makefile_close(self): self.assertEqual(e.exception.errno, errno.EBADF) def test_non_blocking_handshake(self): - s = socket.socket(socket.AF_INET) + s = socket.socket(self.family) s.connect(self.server_addr) s.setblocking(False) s = test_wrap_socket(s, @@ -2167,15 +2171,15 @@ def servername_cb(ssl_sock, server_name, initial_context): timeout=0.1) def test_ciphers(self): - with test_wrap_socket(socket.socket(socket.AF_INET), + with test_wrap_socket(socket.socket(self.family), cert_reqs=ssl.CERT_NONE, ciphers="ALL") as s: s.connect(self.server_addr) - with test_wrap_socket(socket.socket(socket.AF_INET), + with test_wrap_socket(socket.socket(self.family), cert_reqs=ssl.CERT_NONE, ciphers="DEFAULT") as s: s.connect(self.server_addr) # Error checking can happen at instantiation or when connecting with self.assertRaisesRegex(ssl.SSLError, "No cipher can be selected"): - with socket.socket(socket.AF_INET) as sock: + with socket.socket(self.family) as sock: s = test_wrap_socket(sock, cert_reqs=ssl.CERT_NONE, ciphers="^$:,;?*'dorothyx") s.connect(self.server_addr) @@ -2185,7 +2189,7 @@ def test_get_ca_certs_capath(self): ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.load_verify_locations(capath=CAPATH) self.assertEqual(ctx.get_ca_certs(), []) - with ctx.wrap_socket(socket.socket(socket.AF_INET), + with ctx.wrap_socket(socket.socket(self.family), server_hostname='localhost') as s: s.connect(self.server_addr) cert = s.getpeercert() @@ -2198,7 +2202,7 @@ def test_context_setget(self): ctx1.load_verify_locations(capath=CAPATH) ctx2 = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx2.load_verify_locations(capath=CAPATH) - s = socket.socket(socket.AF_INET) + s = socket.socket(self.family) with ctx1.wrap_socket(s, server_hostname='localhost') as ss: ss.connect(self.server_addr) self.assertIs(ss.context, ctx1) @@ -2245,7 +2249,7 @@ def ssl_io_loop(self, sock, incoming, outgoing, func, *args, **kwargs): return ret def test_bio_handshake(self): - sock = socket.socket(socket.AF_INET) + sock = socket.socket(self.family) self.addCleanup(sock.close) sock.connect(self.server_addr) incoming = ssl.MemoryBIO() @@ -2279,7 +2283,7 @@ def test_bio_handshake(self): self.assertRaises(ssl.SSLError, sslobj.write, b'foo') def test_bio_read_write_data(self): - sock = socket.socket(socket.AF_INET) + sock = socket.socket(self.family) self.addCleanup(sock.close) sock.connect(self.server_addr) incoming = ssl.MemoryBIO() @@ -2298,6 +2302,9 @@ def test_bio_read_write_data(self): class NetworkedTests(unittest.TestCase): + @unittest.skipUnless( + socket_helper.IPV4_ENABLED, + f"{REMOTE_HOST} was IPv4 only at the time of this writing.") def test_timeout_connect_ex(self): # Issue #12065: on a timeout, connect_ex() should return the original # errno (mimicking the behaviour of non-SSL sockets). @@ -2563,8 +2570,7 @@ def __init__(self, certificate=None, ssl_version=None, self.chatty = chatty self.connectionchatty = connectionchatty self.starttls_server = starttls_server - self.sock = socket.socket() - self.port = socket_helper.bind_port(self.sock) + self.sock, self.port = socket_helper.get_bound_ip_socket_and_port() self.flag = None self.active = False self.selected_alpn_protocols = [] @@ -2681,17 +2687,18 @@ def handle_error(self): def __init__(self, certfile): self.certfile = certfile - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.port = socket_helper.bind_port(sock, '') + sock, self.port = socket_helper.get_bound_ip_socket_and_port() asyncore.dispatcher.__init__(self, sock) self.listen(5) def handle_accepted(self, sock_obj, addr): if support.verbose: - sys.stdout.write(" server: new connection from %s:%s\n" %addr) + sys.stdout.write(" server: new connection from %s:%s\n" % addr[:2]) self.ConnectionHandler(sock_obj, self.certfile) def handle_error(self): + if support.verbose: + sys.stdout.write(" server: error:\n%s\n" % traceback.format_exc()) raise def __init__(self, certfile): @@ -2752,7 +2759,7 @@ def server_params_test(client_context, server_context, indata=b"FOO\n", chatty=chatty, connectionchatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=sni_name, session=session) as s: s.connect((HOST, server.port)) for arg in [indata, bytearray(indata), memoryview(indata)]: @@ -2914,7 +2921,7 @@ def test_getpeercert(self): client_context, server_context, hostname = testing_context() server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), do_handshake_on_connect=False, server_hostname=hostname) as s: s.connect((HOST, server.port)) @@ -2955,7 +2962,7 @@ def test_crl_check(self): # VERIFY_DEFAULT should pass server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) cert = s.getpeercert() @@ -2966,7 +2973,7 @@ def test_crl_check(self): server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: with self.assertRaisesRegex(ssl.SSLError, "certificate verify failed"): @@ -2977,7 +2984,7 @@ def test_crl_check(self): server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) cert = s.getpeercert() @@ -2992,7 +2999,7 @@ def test_check_hostname(self): # correct hostname should verify server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) cert = s.getpeercert() @@ -3001,7 +3008,7 @@ def test_check_hostname(self): # incorrect hostname should raise an exception server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname="invalid") as s: with self.assertRaisesRegex( ssl.CertificateError, @@ -3011,7 +3018,7 @@ def test_check_hostname(self): # missing server_hostname arg should cause an exception, too server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with socket.socket() as s: + with socket_helper.tcp_socket() as s: with self.assertRaisesRegex(ValueError, "check_hostname requires server_hostname"): client_context.wrap_socket(s) @@ -3027,7 +3034,7 @@ def test_hostname_checks_common_name(self): # default cert has a SAN server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) @@ -3035,7 +3042,7 @@ def test_hostname_checks_common_name(self): client_context.hostname_checks_common_name = False server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: with self.assertRaises(ssl.SSLCertVerificationError): s.connect((HOST, server.port)) @@ -3053,7 +3060,7 @@ def test_ecc_cert(self): # correct hostname should verify server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) cert = s.getpeercert() @@ -3079,7 +3086,7 @@ def test_dual_rsa_ecc(self): # correct hostname should verify server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) cert = s.getpeercert() @@ -3127,7 +3134,7 @@ def test_check_hostname_idn(self): for server_hostname, expected_hostname in idn_hostnames: server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with context.wrap_socket(socket.socket(), + with context.wrap_socket(socket_helper.tcp_socket(), server_hostname=server_hostname) as s: self.assertEqual(s.server_hostname, expected_hostname) s.connect((HOST, server.port)) @@ -3138,7 +3145,7 @@ def test_check_hostname_idn(self): # incorrect hostname should raise an exception server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with context.wrap_socket(socket.socket(), + with context.wrap_socket(socket_helper.tcp_socket(), server_hostname="python.example.org") as s: with self.assertRaises(ssl.CertificateError): s.connect((HOST, server.port)) @@ -3162,7 +3169,7 @@ def test_wrong_cert_tls12(self): ) with server, \ - client_context.wrap_socket(socket.socket(), + client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: try: # Expect either an SSL error about the server rejecting @@ -3193,7 +3200,7 @@ def test_wrong_cert_tls13(self): context=server_context, chatty=True, connectionchatty=True, ) with server, \ - client_context.wrap_socket(socket.socket(), + client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: # TLS 1.3 perform client cert exchange after handshake s.connect((HOST, server.port)) @@ -3226,7 +3233,7 @@ def test_rude_shutdown(self): listener_ready = threading.Event() listener_gone = threading.Event() - s = socket.socket() + s = socket_helper.tcp_socket() port = socket_helper.bind_port(s, HOST) # `listener` runs in a thread. It sits in an accept() until @@ -3243,7 +3250,7 @@ def listener(): def connector(): listener_ready.wait() - with socket.socket() as c: + with socket_helper.tcp_socket() as c: c.connect((HOST, port)) listener_gone.wait() try: @@ -3271,7 +3278,7 @@ def test_ssl_cert_verify_error(self): server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with context.wrap_socket(socket.socket(), + with context.wrap_socket(socket_helper.tcp_socket(), server_hostname=SIGNED_CERTFILE_HOSTNAME) as s: try: s.connect((HOST, server.port)) @@ -3422,7 +3429,7 @@ def test_starttls(self): connectionchatty=True) wrapped = False with server: - s = socket.socket() + s = socket_helper.tcp_socket() s.setblocking(True) s.connect((HOST, server.port)) if support.verbose: @@ -3503,8 +3510,8 @@ def test_asyncore_server(self): indata = b"FOO\n" server = AsyncoreEchoServer(CERTFILE) with server: - s = test_wrap_socket(socket.socket()) - s.connect(('127.0.0.1', server.port)) + s = test_wrap_socket(socket_helper.tcp_socket()) + s.connect((socket_helper.HOST, server.port)) if support.verbose: sys.stdout.write( " client: sending %r...\n" % indata) @@ -3536,7 +3543,7 @@ def test_recv_send(self): chatty=True, connectionchatty=False) with server: - s = test_wrap_socket(socket.socket(), + s = test_wrap_socket(socket_helper.tcp_socket(), server_side=False, certfile=CERTFILE, ca_certs=CERTFILE, @@ -3688,7 +3695,7 @@ def test_nonblocking_send(self): chatty=True, connectionchatty=False) with server: - s = test_wrap_socket(socket.socket(), + s = test_wrap_socket(socket_helper.tcp_socket(), server_side=False, certfile=CERTFILE, ca_certs=CERTFILE, @@ -3711,9 +3718,8 @@ def fill_buffer(): def test_handshake_timeout(self): # Issue #5103: SSL handshake must respect the socket timeout - server = socket.socket(socket.AF_INET) - host = "127.0.0.1" - port = socket_helper.bind_port(server) + host = socket_helper.HOST + server, port = socket_helper.get_bound_ip_socket_and_port(hostname=host) started = threading.Event() finish = False @@ -3736,7 +3742,7 @@ def serve(): try: try: - c = socket.socket(socket.AF_INET) + c = socket.socket(server.family) c.settimeout(0.2) c.connect((host, port)) # Will attempt handshake and time out @@ -3745,7 +3751,7 @@ def serve(): finally: c.close() try: - c = socket.socket(socket.AF_INET) + c = socket.socket(server.family) c = test_wrap_socket(c) c.settimeout(0.2) # Will attempt handshake and time out @@ -3762,9 +3768,9 @@ def test_server_accept(self): # Issue #16357: accept() on a SSLSocket created through # SSLContext.wrap_socket(). client_ctx, server_ctx, hostname = testing_context() - server = socket.socket(socket.AF_INET) - host = "127.0.0.1" - port = socket_helper.bind_port(server) + host = socket_helper.HOST + server, port = socket_helper.get_bound_ip_socket_and_port( + hostname=host) server = server_ctx.wrap_socket(server, server_side=True) self.assertTrue(server.server_side) @@ -3784,7 +3790,7 @@ def serve(): # Client wait until server setup and perform a connect. evt.wait() client = client_ctx.wrap_socket( - socket.socket(), server_hostname=hostname + socket_helper.tcp_socket(), server_hostname=hostname ) client.connect((hostname, port)) client.send(b'data') @@ -3801,7 +3807,7 @@ def serve(): def test_getpeercert_enotconn(self): context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) context.check_hostname = False - with context.wrap_socket(socket.socket()) as sock: + with context.wrap_socket(socket_helper.tcp_socket()) as sock: with self.assertRaises(OSError) as cm: sock.getpeercert() self.assertEqual(cm.exception.errno, errno.ENOTCONN) @@ -3809,7 +3815,7 @@ def test_getpeercert_enotconn(self): def test_do_handshake_enotconn(self): context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) context.check_hostname = False - with context.wrap_socket(socket.socket()) as sock: + with context.wrap_socket(socket_helper.tcp_socket()) as sock: with self.assertRaises(OSError) as cm: sock.do_handshake() self.assertEqual(cm.exception.errno, errno.ENOTCONN) @@ -3822,7 +3828,7 @@ def test_no_shared_ciphers(self): client_context.set_ciphers("AES128") server_context.set_ciphers("AES256") with ThreadedEchoServer(context=server_context) as server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: with self.assertRaises(OSError): s.connect((HOST, server.port)) @@ -3839,7 +3845,7 @@ def test_version_basic(self): with ThreadedEchoServer(CERTFILE, ssl_version=ssl.PROTOCOL_TLS_SERVER, chatty=False) as server: - with context.wrap_socket(socket.socket()) as s: + with context.wrap_socket(socket_helper.tcp_socket()) as s: self.assertIs(s.version(), None) self.assertIs(s._sslobj, None) s.connect((HOST, server.port)) @@ -3852,7 +3858,7 @@ def test_tls1_3(self): client_context, server_context, hostname = testing_context() client_context.minimum_version = ssl.TLSVersion.TLSv1_3 with ThreadedEchoServer(context=server_context) as server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) self.assertIn(s.cipher()[0], { @@ -3875,7 +3881,7 @@ def test_min_max_version_tlsv1_2(self): server_context.maximum_version = ssl.TLSVersion.TLSv1_2 with ThreadedEchoServer(context=server_context) as server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) self.assertEqual(s.version(), 'TLSv1.2') @@ -3892,7 +3898,7 @@ def test_min_max_version_tlsv1_1(self): seclevel_workaround(client_context, server_context) with ThreadedEchoServer(context=server_context) as server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) self.assertEqual(s.version(), 'TLSv1.1') @@ -3910,7 +3916,7 @@ def test_min_max_version_mismatch(self): seclevel_workaround(client_context, server_context) with ThreadedEchoServer(context=server_context) as server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: with self.assertRaises(ssl.SSLError) as e: s.connect((HOST, server.port)) @@ -3925,7 +3931,7 @@ def test_min_max_version_sslv3(self): seclevel_workaround(client_context, server_context) with ThreadedEchoServer(context=server_context) as server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) self.assertEqual(s.version(), 'SSLv3') @@ -3942,7 +3948,7 @@ def test_default_ecdh_curve(self): # our default cipher list should prefer ECDH-based ciphers # automatically. with ThreadedEchoServer(context=server_context) as server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) self.assertIn("ECDH", s.cipher()[0]) @@ -3962,7 +3968,7 @@ def test_tls_unique_channel_binding(self): with server: with client_context.wrap_socket( - socket.socket(), + socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) # get the data @@ -3986,7 +3992,7 @@ def test_tls_unique_channel_binding(self): # now, again with client_context.wrap_socket( - socket.socket(), + socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) new_cb_data = s.get_channel_binding("tls-unique") @@ -4255,7 +4261,7 @@ def test_read_write_after_close_raises_valuerror(self): server = ThreadedEchoServer(context=server_context, chatty=False) with server: - s = client_context.wrap_socket(socket.socket(), + s = client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) s.connect((HOST, server.port)) s.close() @@ -4271,7 +4277,7 @@ def test_sendfile(self): client_context, server_context, hostname = testing_context() server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) with open(os_helper.TESTFN, 'rb') as file: @@ -4345,7 +4351,7 @@ def test_session_handling(self): server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: # session is None before handshake self.assertEqual(s.session, None) @@ -4357,7 +4363,7 @@ def test_session_handling(self): s.session = object self.assertEqual(str(e.exception), 'Value is not a SSLSession.') - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) # cannot set session after handshake @@ -4366,7 +4372,7 @@ def test_session_handling(self): self.assertEqual(str(e.exception), 'Cannot set session after handshake.') - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: # can set session before handshake and before the # connection was established @@ -4376,7 +4382,7 @@ def test_session_handling(self): self.assertEqual(s.session, session) self.assertEqual(s.session_reused, True) - with client_context2.wrap_socket(socket.socket(), + with client_context2.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: # cannot re-use session with a different SSLContext with self.assertRaises(ValueError) as e: @@ -4421,7 +4427,7 @@ def test_pha_required(self): server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) s.write(b'HASCERT') @@ -4453,7 +4459,7 @@ def msg_cb(conn, direction, version, content_type, msg_type, data): server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) s.write(b'PHA') @@ -4484,7 +4490,7 @@ def test_pha_optional(self): server_context.verify_mode = ssl.CERT_OPTIONAL server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) s.write(b'HASCERT') @@ -4505,7 +4511,7 @@ def test_pha_optional_nocert(self): server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) s.write(b'HASCERT') @@ -4524,7 +4530,7 @@ def test_pha_no_pha_client(self): server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) with self.assertRaisesRegex(ssl.SSLError, 'not server'): @@ -4541,7 +4547,7 @@ def test_pha_no_pha_server(self): server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) s.write(b'HASCERT') @@ -4562,7 +4568,7 @@ def test_pha_not_tls13(self): server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) # PHA fails for TLS != 1.3 @@ -4588,7 +4594,7 @@ def test_bpo37428_pha_cert_none(self): server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) s.write(b'HASCERT') @@ -4607,7 +4613,7 @@ def test_internal_chain_client(self): server = ThreadedEchoServer(context=server_context, chatty=False) with server: with client_context.wrap_socket( - socket.socket(), + socket_helper.tcp_socket(), server_hostname=hostname ) as s: s.connect((HOST, server.port)) @@ -4646,7 +4652,7 @@ def test_internal_chain_server(self): server = ThreadedEchoServer(context=server_context, chatty=False) with server: with client_context.wrap_socket( - socket.socket(), + socket_helper.tcp_socket(), server_hostname=hostname ) as s: s.connect((HOST, server.port)) @@ -4701,7 +4707,7 @@ def test_keylog_filename(self): client_context.keylog_filename = os_helper.TESTFN server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) # header, 5 lines for TLS 1.3 @@ -4711,7 +4717,7 @@ def test_keylog_filename(self): server_context.keylog_filename = os_helper.TESTFN server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) self.assertGreaterEqual(self.keylog_lines(), 11) @@ -4720,7 +4726,7 @@ def test_keylog_filename(self): server_context.keylog_filename = os_helper.TESTFN server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) self.assertGreaterEqual(self.keylog_lines(), 21) @@ -4775,7 +4781,7 @@ def msg_cb(conn, direction, version, content_type, msg_type, data): server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) @@ -4805,10 +4811,10 @@ def sni_cb(sock, servername, ctx): server = ThreadedEchoServer(context=server_context, chatty=False) with server: - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) - with client_context.wrap_socket(socket.socket(), + with client_context.wrap_socket(socket_helper.tcp_socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) From d752f1ae6f5a0d0807d3f993288fd4042fdd6faf Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Tue, 18 May 2021 11:58:21 -0700 Subject: [PATCH 09/28] cleanup a bit. --- Lib/test/support/socket_helper.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/Lib/test/support/socket_helper.py b/Lib/test/support/socket_helper.py index d46b30fb5473e7..5a9d625b60818f 100644 --- a/Lib/test/support/socket_helper.py +++ b/Lib/test/support/socket_helper.py @@ -151,24 +151,16 @@ def bind_port(sock, host=HOST): def get_family(): """Get a host appropriate socket AF_INET or AF_INET6 family.""" + if IPV4_ENABLED: + return socket.AF_INET if IPV6_ENABLED: return socket.AF_INET6 - elif IPV4_ENABLED: - return socket.AF_INET - else: - raise support.TestFailed( - "At least one of IPv4 or IPv6 must be enabled.") + raise support.TestFailed("At least one of IPv4 or IPv6 must be enabled.") def tcp_socket(): """Get a new host appropriate IPv4 or IPv6 TCP STREAM socket.socket().""" - if IPV4_ENABLED: - return socket.socket(socket.AF_INET, socket.SOCK_STREAM) - elif IPV6_ENABLED: - return socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - else: - raise support.TestFailed( - "At least one of IPv4 or IPv6 must be enabled.") + return socket.socket(get_family(), socket.SOCK_STREAM) def get_bound_ip_socket_and_port(*, hostname=HOST, socktype=socket.SOCK_STREAM): From b7267560ad639f724777b1d7d1b5ba6984af43c8 Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Tue, 18 May 2021 11:58:34 -0700 Subject: [PATCH 10/28] Make test_httplib pass IPv6-only. --- Lib/test/test_httplib.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/Lib/test/test_httplib.py b/Lib/test/test_httplib.py index e9272569ecc531..a6921be808983d 100644 --- a/Lib/test/test_httplib.py +++ b/Lib/test/test_httplib.py @@ -1331,7 +1331,8 @@ def test_read1_bound_content_length(self): def test_response_fileno(self): # Make sure fd returned by fileno is valid. - serv = socket.create_server((HOST, 0)) + serv = socket_helper.get_bound_ip_socket_and_port()[0] + serv.listen() self.addCleanup(serv.close) result = None @@ -1350,7 +1351,7 @@ def run_server(): thread = threading.Thread(target=run_server) thread.start() self.addCleanup(thread.join, float(1)) - conn = client.HTTPConnection(*serv.getsockname()) + conn = client.HTTPConnection(*serv.getsockname()[:2]) conn.request("CONNECT", "dummy:1234") response = conn.getresponse() try: @@ -1673,8 +1674,7 @@ def test_client_constants(self): class SourceAddressTest(TestCase): def setUp(self): - self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.port = socket_helper.bind_port(self.serv) + self.serv, self.port = socket_helper.get_bound_ip_socket_and_port() self.source_port = socket_helper.find_unused_port() self.serv.listen() self.conn = None @@ -1706,8 +1706,8 @@ class TimeoutTest(TestCase): PORT = None def setUp(self): - self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - TimeoutTest.PORT = socket_helper.bind_port(self.serv) + self.serv = socket_helper.tcp_socket() + self.PORT = socket_helper.bind_port(self.serv) self.serv.listen() def tearDown(self): @@ -1722,7 +1722,7 @@ def testTimeoutAttribute(self): self.assertIsNone(socket.getdefaulttimeout()) socket.setdefaulttimeout(30) try: - httpConn = client.HTTPConnection(HOST, TimeoutTest.PORT) + httpConn = client.HTTPConnection(HOST, self.PORT) httpConn.connect() finally: socket.setdefaulttimeout(None) @@ -1733,7 +1733,7 @@ def testTimeoutAttribute(self): self.assertIsNone(socket.getdefaulttimeout()) socket.setdefaulttimeout(30) try: - httpConn = client.HTTPConnection(HOST, TimeoutTest.PORT, + httpConn = client.HTTPConnection(HOST, self.PORT, timeout=None) httpConn.connect() finally: @@ -1742,7 +1742,7 @@ def testTimeoutAttribute(self): httpConn.close() # a value - httpConn = client.HTTPConnection(HOST, TimeoutTest.PORT, timeout=30) + httpConn = client.HTTPConnection(HOST, self.PORT, timeout=30) httpConn.connect() self.assertEqual(httpConn.sock.gettimeout(), 30) httpConn.close() From 66054f4eaf1082d2d5a625364b609f19d815c335 Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Tue, 18 May 2021 12:03:43 -0700 Subject: [PATCH 11/28] Make test_wsgiref pass on IPv6-only. --- Lib/test/test_wsgiref.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/Lib/test/test_wsgiref.py b/Lib/test/test_wsgiref.py index 93ca6b99a92c9c..a92e2250b5a2c1 100644 --- a/Lib/test/test_wsgiref.py +++ b/Lib/test/test_wsgiref.py @@ -265,7 +265,12 @@ def app(environ, start_response): class WsgiHandler(NoLogRequestHandler, WSGIRequestHandler): pass - server = make_server(socket_helper.HOST, 0, app, handler_class=WsgiHandler) + class IPStackWSGIServer(WSGIServer): + address_family = socket_helper.get_family() + + server = make_server( + socket_helper.HOST, 0, app, + server_class=IPStackWSGIServer, handler_class=WsgiHandler) self.addCleanup(server.server_close) interrupted = threading.Event() @@ -278,7 +283,7 @@ def signal_handler(signum, frame): main_thread = threading.get_ident() def run_client(): - http = HTTPConnection(*server.server_address) + http = HTTPConnection(*server.server_address[:2]) http.request("GET", "/") with http.getresponse() as response: response.read(100) From f20c40a5d1ae7e428c322268999b389b2d7db4f6 Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Tue, 18 May 2021 12:09:14 -0700 Subject: [PATCH 12/28] fix test_smtplib for IPv6-only. --- Lib/test/test_smtplib.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/Lib/test/test_smtplib.py b/Lib/test/test_smtplib.py index f3d33ab0772dd3..d800b37c3ead5c 100644 --- a/Lib/test/test_smtplib.py +++ b/Lib/test/test_smtplib.py @@ -721,9 +721,8 @@ def setUp(self): sys.stdout = self.output self.evt = threading.Event() - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.sock, self.port = socket_helper.get_bound_ip_socket_and_port() self.sock.settimeout(15) - self.port = socket_helper.bind_port(self.sock) servargs = (self.evt, self.respdata, self.sock) self.thread = threading.Thread(target=server, args=servargs) self.thread.start() @@ -739,8 +738,8 @@ def tearDown(self): threading_helper.threading_cleanup(*self.thread_key) def testLineTooLong(self): - self.assertRaises(smtplib.SMTPResponseException, smtplib.SMTP, - HOST, self.port, 'localhost', 3) + with self.assertRaises(smtplib.SMTPResponseException): + smtplib.SMTP(HOST, self.port, 'localhost', 3) sim_users = {'Mr.A@somewhere.com':'John A', From 8859b2a8a93a9e4f91733d6552cc55fe8bdd1022 Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Tue, 18 May 2021 12:11:47 -0700 Subject: [PATCH 13/28] Fix test_poplib for IPv6-only. --- Lib/test/test_poplib.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/Lib/test/test_poplib.py b/Lib/test/test_poplib.py index c5ae9f77e4f006..b538c9ec6651d4 100644 --- a/Lib/test/test_poplib.py +++ b/Lib/test/test_poplib.py @@ -204,7 +204,7 @@ class DummyPOP3Server(asyncore.dispatcher, threading.Thread): handler = DummyPOP3Handler - def __init__(self, address, af=socket.AF_INET): + def __init__(self, address, af=socket_helper.get_family()): threading.Thread.__init__(self) asyncore.dispatcher.__init__(self) self.daemon = True @@ -481,9 +481,8 @@ class TestTimeouts(TestCase): def setUp(self): self.evt = threading.Event() - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.sock, self.port = socket_helper.get_bound_ip_socket_and_port() self.sock.settimeout(60) # Safety net. Look issue 11812 - self.port = socket_helper.bind_port(self.sock) self.thread = threading.Thread(target=self.server, args=(self.evt, self.sock)) self.thread.daemon = True self.thread.start() From d485dc379ba190518d5dfc4abd41835c1b731659 Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Tue, 18 May 2021 12:18:49 -0700 Subject: [PATCH 14/28] Make test_asyncore IPv6-only friendly. --- Lib/test/test_asyncore.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Lib/test/test_asyncore.py b/Lib/test/test_asyncore.py index 3bd904d1774bc3..47a8f8da057c6e 100644 --- a/Lib/test/test_asyncore.py +++ b/Lib/test/test_asyncore.py @@ -329,9 +329,8 @@ def tearDown(self): @threading_helper.reap_threads def test_send(self): evt = threading.Event() - sock = socket.socket() + sock, port = socket_helper.get_bound_ip_socket_and_port() sock.settimeout(3) - port = socket_helper.bind_port(sock) cap = BytesIO() args = (evt, cap, sock) @@ -344,7 +343,7 @@ def test_send(self): data = b"Suppose there isn't a 16-ton weight?" d = dispatcherwithsend_noread() - d.create_socket() + d.create_socket(family=sock.family) d.connect((socket_helper.HOST, port)) # give time for socket to connect @@ -793,6 +792,7 @@ def test_quick_connect(self): finally: threading_helper.join_thread(t) +@unittest.skipUnless(socket_helper.IPV4_ENABLED, 'IPv4 support required') class TestAPI_UseIPv4Sockets(BaseTestAPI): family = socket.AF_INET addr = (socket_helper.HOST, 0) From 77ed056ece2fe0c524e3709c5cb836106bc361eb Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Tue, 18 May 2021 12:20:15 -0700 Subject: [PATCH 15/28] make test_asynchat IPv6-only friendly. --- Lib/test/test_asynchat.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/Lib/test/test_asynchat.py b/Lib/test/test_asynchat.py index b32edddc7d5505..bafab802a5593e 100644 --- a/Lib/test/test_asynchat.py +++ b/Lib/test/test_asynchat.py @@ -26,8 +26,7 @@ class echo_server(threading.Thread): def __init__(self, event): threading.Thread.__init__(self) self.event = event - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.port = socket_helper.bind_port(self.sock) + self.sock, self.port = socket_helper.get_bound_ip_socket_and_port() # This will be set if the client wants us to wait before echoing # data back. self.start_resend_event = None @@ -69,7 +68,7 @@ class echo_client(asynchat.async_chat): def __init__(self, terminator, server_port): asynchat.async_chat.__init__(self) self.contents = [] - self.create_socket(socket.AF_INET, socket.SOCK_STREAM) + self.create_socket(socket_helper.get_family(), socket.SOCK_STREAM) self.connect((HOST, server_port)) self.set_terminator(terminator) self.buffer = b"" From 99e06657d03e85878aa5754b865e2bd41c76b60a Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Tue, 18 May 2021 16:12:48 -0700 Subject: [PATCH 16/28] Add udp_socket(), use SkipTest. --- Lib/test/support/socket_helper.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/Lib/test/support/socket_helper.py b/Lib/test/support/socket_helper.py index 5a9d625b60818f..2683628253b6d1 100644 --- a/Lib/test/support/socket_helper.py +++ b/Lib/test/support/socket_helper.py @@ -155,7 +155,7 @@ def get_family(): return socket.AF_INET if IPV6_ENABLED: return socket.AF_INET6 - raise support.TestFailed("At least one of IPv4 or IPv6 must be enabled.") + raise unittest.SkipTest('Neither IPv4 or IPv6 is enabled.') def tcp_socket(): @@ -163,6 +163,11 @@ def tcp_socket(): return socket.socket(get_family(), socket.SOCK_STREAM) +def udp_socket(proto=-1): + """Get a new host appropriate IPv4 or IPv6 UDP DGRAM socket.socket().""" + return socket.socket(get_family(), socket.SOCK_DGRAM, proto) + + def get_bound_ip_socket_and_port(*, hostname=HOST, socktype=socket.SOCK_STREAM): """Get an IP socket bound to a port as a sock, port tuple. @@ -170,7 +175,7 @@ def get_bound_ip_socket_and_port(*, hostname=HOST, socktype=socket.SOCK_STREAM): IPv4 is available. Context is a (socket, port) tuple. Exiting the context closes the socket. - Prefer the bind_ip_socket_and_port context manager when possible. + Prefer the bind_ip_socket_and_port context manager within a test method. """ family = get_family() sock = socket.socket(family, socktype) From 03aac813f196c90979bdbe99c03d458f2091b48b Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Tue, 18 May 2021 16:13:16 -0700 Subject: [PATCH 17/28] Allow test_socket to work on IPv6-only hosts. Tests that are agnostic about the IP protocol used will use whichever the socket_helper module deems appropriate for the host. --- Lib/test/test_socket.py | 210 +++++++++++++++++++++++----------------- 1 file changed, 121 insertions(+), 89 deletions(-) diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index 828d1f3dcc6701..7d0af45eb3fdf7 100755 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -175,8 +175,8 @@ def socket_setdefaulttimeout(timeout): class SocketTCPTest(unittest.TestCase): def setUp(self): - self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.port = socket_helper.bind_port(self.serv) + self.serv, self.port = socket_helper.get_bound_ip_socket_and_port( + socktype=socket.SOCK_STREAM) self.serv.listen() def tearDown(self): @@ -186,8 +186,8 @@ def tearDown(self): class SocketUDPTest(unittest.TestCase): def setUp(self): - self.serv = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - self.port = socket_helper.bind_port(self.serv) + self.serv, self.port = socket_helper.get_bound_ip_socket_and_port( + socktype=socket.SOCK_DGRAM) def tearDown(self): self.serv.close() @@ -196,7 +196,7 @@ def tearDown(self): class SocketUDPLITETest(SocketUDPTest): def setUp(self): - self.serv = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDPLITE) + self.serv = socket_helper.udp_socket(socket.IPPROTO_UDPLITE) self.port = socket_helper.bind_port(self.serv) class ThreadSafeCleanupTestCase(unittest.TestCase): @@ -409,7 +409,7 @@ def __init__(self, methodName='runTest'): ThreadableTest.__init__(self) def clientSetUp(self): - self.cli = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.cli = socket.socket(socket_helper.get_family(), socket.SOCK_STREAM) def clientTearDown(self): self.cli.close() @@ -423,7 +423,7 @@ def __init__(self, methodName='runTest'): ThreadableTest.__init__(self) def clientSetUp(self): - self.cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.cli = socket.socket(socket_helper.get_family(), socket.SOCK_DGRAM) def clientTearDown(self): self.cli.close() @@ -720,23 +720,39 @@ def bindSock(self, sock): socket_helper.bind_port(sock, host=self.host) class TCPTestBase(InetTestBase): + """Base class for TCP tests.""" + + def newSocket(self): + return socket_helper.tcp_socket() + +@unittest.skipUnless(socket_helper.IPV4_ENABLED, 'Requires IPv4') +class TCP4TestBase(InetTestBase): """Base class for TCP-over-IPv4 tests.""" def newSocket(self): return socket.socket(socket.AF_INET, socket.SOCK_STREAM) class UDPTestBase(InetTestBase): + """Base class for UDP tests.""" + + def newSocket(self): + return socket_helper.udp_socket() + +@unittest.skipUnless(socket_helper.IPV4_ENABLED, 'Requires IPv4') +class UDP4TestBase(InetTestBase): """Base class for UDP-over-IPv4 tests.""" def newSocket(self): return socket.socket(socket.AF_INET, socket.SOCK_DGRAM) +@unittest.skipUnless(socket_helper.IPV4_ENABLED, 'Requires IPv4') class UDPLITETestBase(InetTestBase): """Base class for UDPLITE-over-IPv4 tests.""" def newSocket(self): return socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDPLITE) +@unittest.skipUnless(socket_helper.IPV4_ENABLED, 'Requires IPv4') class SCTPStreamBase(InetTestBase): """Base class for SCTP tests in one-to-one (SOCK_STREAM) mode.""" @@ -839,14 +855,15 @@ def test_SocketType_is_socketobject(self): s.close() def test_repr(self): - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + family = socket_helper.get_family() + s = socket.socket(family, socket.SOCK_STREAM) with s: self.assertIn('fd=%i' % s.fileno(), repr(s)) - self.assertIn('family=%s' % socket.AF_INET, repr(s)) + self.assertIn('family=%s' % family, repr(s)) self.assertIn('type=%s' % socket.SOCK_STREAM, repr(s)) self.assertIn('proto=0', repr(s)) self.assertNotIn('raddr', repr(s)) - s.bind(('127.0.0.1', 0)) + s.bind((socket_helper.HOST, 0)) self.assertIn('laddr', repr(s)) self.assertIn(str(s.getsockname()), repr(s)) self.assertIn('[closed]', repr(s)) @@ -854,7 +871,8 @@ def test_repr(self): @unittest.skipUnless(_socket is not None, 'need _socket module') def test_csocket_repr(self): - s = _socket.socket(_socket.AF_INET, _socket.SOCK_STREAM) + family = socket_helper.get_family() + s = _socket.socket(family, _socket.SOCK_STREAM) try: expected = ('' % (s.fileno(), s.family, s.type, s.proto)) @@ -866,7 +884,7 @@ def test_csocket_repr(self): self.assertEqual(repr(s), expected) def test_weakref(self): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + with socket_helper.tcp_socket() as s: p = proxy(s) self.assertEqual(p.fileno(), s.fileno()) s = None @@ -889,10 +907,10 @@ def testSocketError(self): def testSendtoErrors(self): # Testing that sendto doesn't mask failures. See #10169. - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s = socket_helper.udp_socket() self.addCleanup(s.close) s.bind(('', 0)) - sockname = s.getsockname() + sockname = s.getsockname()[:2] # 2 args with self.assertRaises(TypeError) as cm: s.sendto('\u2620', sockname) @@ -1009,6 +1027,7 @@ def testHostnameRes(self): if not fqhn in all_host_names: self.fail("Error testing host resolution mechanisms. (fqdn: %s, all: %s)" % (fqhn, repr(all_host_names))) + @unittest.skipUnless(socket_helper.IPV4_ENABLED, 'IPv4 required') def test_host_resolution(self): for addr in [socket_helper.HOSTv4, '10.0.0.1', '255.255.255.255']: self.assertEqual(socket.gethostbyname(addr), addr) @@ -1375,6 +1394,7 @@ def testStringToIPv6(self): # XXX The following don't test module-level functionality... + @unittest.skipUnless(socket_helper.IPV4_ENABLED, 'IPv4 required') def testSockName(self): # Testing getsockname() port = socket_helper.find_unused_port() @@ -1416,8 +1436,8 @@ def testSendAfterClose(self): self.assertRaises(OSError, sock.send, b"spam") def testCloseException(self): - sock = socket.socket() - sock.bind((socket._LOCALHOST, 0)) + sock = socket_helper.tcp_socket() + sock.bind((socket_helper.HOST, 0)) socket.socket(fileno=sock.fileno()).close() try: sock.close() @@ -1430,8 +1450,9 @@ def testCloseException(self): def testNewAttributes(self): # testing .family, .type and .protocol - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - self.assertEqual(sock.family, socket.AF_INET) + family = socket_helper.get_family() + with socket.socket(family, socket.SOCK_STREAM) as sock: + self.assertEqual(sock.family, family) if hasattr(socket, 'SOCK_CLOEXEC'): self.assertIn(sock.type, (socket.SOCK_STREAM | socket.SOCK_CLOEXEC, @@ -1441,13 +1462,15 @@ def testNewAttributes(self): self.assertEqual(sock.proto, 0) def test_getsockaddrarg(self): - sock = socket.socket() + sock = socket_helper.tcp_socket() self.addCleanup(sock.close) port = socket_helper.find_unused_port() big_port = port + 65536 neg_port = port - 65536 - self.assertRaises(OverflowError, sock.bind, (HOST, big_port)) - self.assertRaises(OverflowError, sock.bind, (HOST, neg_port)) + with self.assertRaises(OverflowError): + sock.bind((HOST, big_port)) + with self.assertRaises(OverflowError): + sock.bind((HOST, neg_port)) # Since find_unused_port() is inherently subject to race conditions, we # call it a couple times if necessary. for i in itertools.count(): @@ -1488,6 +1511,7 @@ def test_sio_loopback_fast_path(self): raise self.assertRaises(TypeError, s.ioctl, socket.SIO_LOOPBACK_FAST_PATH, None) + @unittest.skipUnless(socket_helper.IPV4_ENABLED, 'IPv4 required') def testGetaddrinfo(self): try: socket.getaddrinfo('localhost', 80) @@ -1625,14 +1649,14 @@ def test_sendall_interrupted_with_timeout(self): self.check_sendall_interrupted(True) def test_dealloc_warn(self): - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock = socket_helper.tcp_socket() r = repr(sock) with self.assertWarns(ResourceWarning) as cm: sock = None support.gc_collect() self.assertIn(r, str(cm.warning.args[0])) # An open socket file object gets dereferenced after the socket - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock = socket_helper.tcp_socket() f = sock.makefile('rb') r = repr(sock) sock = None @@ -1642,13 +1666,13 @@ def test_dealloc_warn(self): support.gc_collect() def test_name_closed_socketio(self): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + with socket_helper.tcp_socket() as sock: fp = sock.makefile("rb") fp.close() self.assertEqual(repr(fp), "<_io.BufferedReader name=-1>") def test_unusable_closed_socketio(self): - with socket.socket() as sock: + with socket_helper.tcp_socket() as sock: fp = sock.makefile("rb", buffering=0) self.assertTrue(fp.readable()) self.assertFalse(fp.writable()) @@ -1659,7 +1683,7 @@ def test_unusable_closed_socketio(self): self.assertRaises(ValueError, fp.seekable) def test_socket_close(self): - sock = socket.socket() + sock = socket_helper.tcp_socket() try: sock.bind((HOST, 0)) socket.close(sock.fileno()) @@ -1702,11 +1726,11 @@ def test_pickle(self): def test_listen_backlog(self): for backlog in 0, -1: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv: + with socket_helper.tcp_socket() as srv: srv.bind((HOST, 0)) srv.listen(backlog) - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv: + with socket_helper.tcp_socket() as srv: srv.bind((HOST, 0)) srv.listen() @@ -1714,7 +1738,7 @@ def test_listen_backlog(self): def test_listen_backlog_overflow(self): # Issue 15989 import _testcapi - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv: + with socket_helper.tcp_socket() as srv: srv.bind((HOST, 0)) self.assertRaises(OverflowError, srv.listen, _testcapi.INT_MAX + 1) @@ -1861,7 +1885,7 @@ def _test_socket_fileno(self, s, family, stype): self.assertEqual(s.type, stype) fd = s.fileno() - s2 = socket.socket(fileno=fd) + s2 = socket.socket(family, fileno=fd) self.addCleanup(s2.close) # detach old fd to avoid double close s.detach() @@ -1869,36 +1893,32 @@ def _test_socket_fileno(self, s, family, stype): self.assertEqual(s2.type, stype) self.assertEqual(s2.fileno(), fd) - def test_socket_fileno(self): - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + def test_socket_fileno_tcp(self): + s = socket_helper.tcp_socket() self.addCleanup(s.close) s.bind((socket_helper.HOST, 0)) - self._test_socket_fileno(s, socket.AF_INET, socket.SOCK_STREAM) + self._test_socket_fileno(s, s.family, socket.SOCK_STREAM) - if hasattr(socket, "SOCK_DGRAM"): - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - self.addCleanup(s.close) - s.bind((socket_helper.HOST, 0)) - self._test_socket_fileno(s, socket.AF_INET, socket.SOCK_DGRAM) + @unittest.skipUnless(hasattr(socket, "SOCK_DGRAM"), "SOCK_DGRAM required") + def test_socket_fileno_udp(self): + s = socket_helper.udp_socket() + self.addCleanup(s.close) + s.bind((socket_helper.HOST, 0)) + self._test_socket_fileno(s, s.family, socket.SOCK_DGRAM) - if socket_helper.IPV6_ENABLED: - s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - self.addCleanup(s.close) - s.bind((socket_helper.HOSTv6, 0, 0, 0)) - self._test_socket_fileno(s, socket.AF_INET6, socket.SOCK_STREAM) - - if hasattr(socket, "AF_UNIX"): - tmpdir = tempfile.mkdtemp() - self.addCleanup(shutil.rmtree, tmpdir) - s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - self.addCleanup(s.close) - try: - s.bind(os.path.join(tmpdir, 'socket')) - except PermissionError: - pass - else: - self._test_socket_fileno(s, socket.AF_UNIX, - socket.SOCK_STREAM) + @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "AF_UNIX required") + def test_socket_fileno_unix(self): + tmpdir = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, tmpdir) + s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.addCleanup(s.close) + try: + s.bind(os.path.join(tmpdir, 'socket')) + except PermissionError: + pass + else: + self._test_socket_fileno(s, socket.AF_UNIX, + socket.SOCK_STREAM) def test_socket_fileno_rejects_float(self): with self.assertRaises(TypeError): @@ -2514,7 +2534,7 @@ def _testSendAll(self): def testFromFd(self): # Testing fromfd() fd = self.cli_conn.fileno() - sock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM) + sock = socket.fromfd(fd, socket_helper.get_family(), socket.SOCK_STREAM) self.addCleanup(sock.close) self.assertIsInstance(sock, socket.socket) msg = sock.recv(1024) @@ -2570,7 +2590,7 @@ def testDetach(self): self.cli_conn.close() # ...but we can create another socket using the (still open) # file descriptor - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, fileno=f) + sock = socket.socket(socket_helper.get_family(), socket.SOCK_STREAM, fileno=f) self.addCleanup(sock.close) msg = sock.recv(1024) self.assertEqual(msg, MSG) @@ -4380,13 +4400,13 @@ class SendrecvmsgSCTPStreamTestBase(SendrecvmsgSCTPFlagsBase, @requireAttrs(socket.socket, "sendmsg") @unittest.skipIf(AIX, "IPPROTO_SCTP: [Errno 62] Protocol not supported on AIX") -@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP") +@requireSocket(socket_helper.get_family(), "SOCK_STREAM", "IPPROTO_SCTP") class SendmsgSCTPStreamTest(SendmsgStreamTests, SendrecvmsgSCTPStreamTestBase): pass @requireAttrs(socket.socket, "recvmsg") @unittest.skipIf(AIX, "IPPROTO_SCTP: [Errno 62] Protocol not supported on AIX") -@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP") +@requireSocket(socket_helper.get_family(), "SOCK_STREAM", "IPPROTO_SCTP") class RecvmsgSCTPStreamTest(RecvmsgTests, RecvmsgGenericStreamTests, SendrecvmsgSCTPStreamTestBase): @@ -4400,7 +4420,7 @@ def testRecvmsgEOF(self): @requireAttrs(socket.socket, "recvmsg_into") @unittest.skipIf(AIX, "IPPROTO_SCTP: [Errno 62] Protocol not supported on AIX") -@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP") +@requireSocket(socket_helper.get_family(), "SOCK_STREAM", "IPPROTO_SCTP") class RecvmsgIntoSCTPStreamTest(RecvmsgIntoTests, RecvmsgGenericStreamTests, SendrecvmsgSCTPStreamTestBase): @@ -5142,7 +5162,7 @@ def mocked_socket_module(self): def test_connect(self): port = socket_helper.find_unused_port() - cli = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + cli = socket_helper.tcp_socket() self.addCleanup(cli.close) with self.assertRaises(OSError) as cm: cli.connect((HOST, port)) @@ -5210,7 +5230,7 @@ def _testFamily(self): self.cli = socket.create_connection((HOST, self.port), timeout=support.LOOPBACK_TIMEOUT) self.addCleanup(self.cli.close) - self.assertEqual(self.cli.family, 2) + self.assertEqual(self.cli.family, socket_helper.get_family()) testSourceAddress = _justAccept def _testSourceAddress(self): @@ -5724,7 +5744,7 @@ def testCreateConnectionBase(self): conn.sendall(data) def _testCreateConnectionBase(self): - address = self.serv.getsockname() + address = self.serv.getsockname()[:2] with socket.create_connection(address) as sock: self.assertFalse(sock._closed) sock.sendall(b'foo') @@ -5738,7 +5758,7 @@ def testCreateConnectionClose(self): conn.sendall(data) def _testCreateConnectionClose(self): - address = self.serv.getsockname() + address = self.serv.getsockname()[:2] with socket.create_connection(address) as sock: sock.close() self.assertTrue(sock._closed) @@ -6034,7 +6054,7 @@ def meth_from_sock(self, sock): # regular file def _testRegularFile(self): - address = self.serv.getsockname() + address = self.serv.getsockname()[:2] file = open(os_helper.TESTFN, 'rb') with socket.create_connection(address) as sock, file as file: meth = self.meth_from_sock(sock) @@ -6051,7 +6071,7 @@ def testRegularFile(self): # non regular file def _testNonRegularFile(self): - address = self.serv.getsockname() + address = self.serv.getsockname()[:2] file = io.BytesIO(self.FILEDATA) with socket.create_connection(address) as sock, file as file: sent = sock.sendfile(file) @@ -6069,7 +6089,7 @@ def testNonRegularFile(self): # empty file def _testEmptyFileSend(self): - address = self.serv.getsockname() + address = self.serv.getsockname()[:2] filename = os_helper.TESTFN + "2" with open(filename, 'wb'): self.addCleanup(os_helper.unlink, filename) @@ -6088,7 +6108,7 @@ def testEmptyFileSend(self): # offset def _testOffset(self): - address = self.serv.getsockname() + address = self.serv.getsockname()[:2] file = open(os_helper.TESTFN, 'rb') with socket.create_connection(address) as sock, file as file: meth = self.meth_from_sock(sock) @@ -6105,7 +6125,7 @@ def testOffset(self): # count def _testCount(self): - address = self.serv.getsockname() + address = self.serv.getsockname()[:2] file = open(os_helper.TESTFN, 'rb') sock = socket.create_connection(address, timeout=support.LOOPBACK_TIMEOUT) @@ -6126,7 +6146,7 @@ def testCount(self): # count small def _testCountSmall(self): - address = self.serv.getsockname() + address = self.serv.getsockname()[:2] file = open(os_helper.TESTFN, 'rb') sock = socket.create_connection(address, timeout=support.LOOPBACK_TIMEOUT) @@ -6147,7 +6167,7 @@ def testCountSmall(self): # count + offset def _testCountWithOffset(self): - address = self.serv.getsockname() + address = self.serv.getsockname()[:2] file = open(os_helper.TESTFN, 'rb') with socket.create_connection(address, timeout=2) as sock, file as file: count = 100007 @@ -6166,7 +6186,7 @@ def testCountWithOffset(self): # non blocking sockets are not supposed to work def _testNonBlocking(self): - address = self.serv.getsockname() + address = self.serv.getsockname()[:2] file = open(os_helper.TESTFN, 'rb') with socket.create_connection(address) as sock, file as file: sock.setblocking(False) @@ -6182,7 +6202,7 @@ def testNonBlocking(self): # timeout (non-triggered) def _testWithTimeout(self): - address = self.serv.getsockname() + address = self.serv.getsockname()[:2] file = open(os_helper.TESTFN, 'rb') sock = socket.create_connection(address, timeout=support.LOOPBACK_TIMEOUT) @@ -6200,7 +6220,7 @@ def testWithTimeout(self): # timeout (triggered) def _testWithTimeoutTriggeredSend(self): - address = self.serv.getsockname() + address = self.serv.getsockname()[:2] with open(os_helper.TESTFN, 'rb') as file: with socket.create_connection(address) as sock: sock.settimeout(0.01) @@ -6471,35 +6491,45 @@ def test_new_tcp_flags(self): class CreateServerTest(unittest.TestCase): - def test_address(self): + @unittest.skipUnless(socket_helper.IPV4_ENABLED, 'IPv4 required') + def test_address_ipv4(self): port = socket_helper.find_unused_port() with socket.create_server(("127.0.0.1", port)) as sock: self.assertEqual(sock.getsockname()[0], "127.0.0.1") self.assertEqual(sock.getsockname()[1], port) - if socket_helper.IPV6_ENABLED: - with socket.create_server(("::1", port), - family=socket.AF_INET6) as sock: - self.assertEqual(sock.getsockname()[0], "::1") - self.assertEqual(sock.getsockname()[1], port) - def test_family_and_type(self): + @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required') + def test_address_ipv6(self): + port = socket_helper.find_unused_port() + with socket.create_server(("::1", port), + family=socket.AF_INET6) as sock: + self.assertEqual(sock.getsockname()[0], "::1") + self.assertEqual(sock.getsockname()[1], port) + + @unittest.skipUnless(socket_helper.IPV4_ENABLED, 'IPv4 required') + def test_family_and_type_ipv4(self): with socket.create_server(("127.0.0.1", 0)) as sock: self.assertEqual(sock.family, socket.AF_INET) self.assertEqual(sock.type, socket.SOCK_STREAM) - if socket_helper.IPV6_ENABLED: - with socket.create_server(("::1", 0), family=socket.AF_INET6) as s: - self.assertEqual(s.family, socket.AF_INET6) - self.assertEqual(sock.type, socket.SOCK_STREAM) + + @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required') + def test_family_and_type_ipv6(self): + with socket.create_server(("::1", 0), family=socket.AF_INET6) as sock: + self.assertEqual(sock.family, socket.AF_INET6) + self.assertEqual(sock.type, socket.SOCK_STREAM) def test_reuse_port(self): + fam = socket_helper.get_family() if not hasattr(socket, "SO_REUSEPORT"): with self.assertRaises(ValueError): - socket.create_server(("localhost", 0), reuse_port=True) + socket.create_server( + ("localhost", 0), family=fam,reuse_port=True) else: - with socket.create_server(("localhost", 0)) as sock: + with socket.create_server(("localhost", 0), family=fam) as sock: opt = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) self.assertEqual(opt, 0) - with socket.create_server(("localhost", 0), reuse_port=True) as sock: + with socket.create_server( + ("localhost", 0), family=fam, reuse_port=True) as sock: opt = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) self.assertNotEqual(opt, 0) @@ -6554,6 +6584,7 @@ def echo_client(self, addr, family): sock.sendall(b'foo') self.assertEqual(sock.recv(1024), b'foo') + @unittest.skipUnless(socket_helper.IPV4_ENABLED, 'IPv4 required') def test_tcp4(self): port = socket_helper.find_unused_port() with socket.create_server(("", port)) as sock: @@ -6573,6 +6604,7 @@ def test_tcp6(self): @unittest.skipIf(not socket.has_dualstack_ipv6(), "dualstack_ipv6 not supported") @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test') + @unittest.skipUnless(socket_helper.IPV4_ENABLED, 'IPv4 required.') def test_dual_stack_client_v4(self): port = socket_helper.find_unused_port() with socket.create_server(("", port), family=socket.AF_INET6, From 56d179ebf154aaeb837d145c26377828505d0a73 Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Tue, 18 May 2021 16:27:58 -0700 Subject: [PATCH 18/28] cleanup test names to clarify IPv4 status. there's a _lot_ more cleanup that could be done in here. the capability skip decorators are applied in an inconsistent manner. some on the base class, others way down on the leaf Test classes. yuck. --- Lib/test/test_socket.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index 7d0af45eb3fdf7..3d6b575581a248 100755 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -746,7 +746,7 @@ def newSocket(self): return socket.socket(socket.AF_INET, socket.SOCK_DGRAM) @unittest.skipUnless(socket_helper.IPV4_ENABLED, 'Requires IPv4') -class UDPLITETestBase(InetTestBase): +class UDPLITE4TestBase(InetTestBase): """Base class for UDPLITE-over-IPv4 tests.""" def newSocket(self): @@ -4228,21 +4228,21 @@ def _testSecondCmsgTruncInData(self): # Derive concrete test classes for different socket types. -class SendrecvmsgUDPTestBase(SendrecvmsgDgramFlagsBase, +class SendrecvmsgUDP4TestBase(SendrecvmsgDgramFlagsBase, SendrecvmsgConnectionlessBase, - ThreadedSocketTestMixin, UDPTestBase): + ThreadedSocketTestMixin, UDP4TestBase): pass @requireAttrs(socket.socket, "sendmsg") -class SendmsgUDPTest(SendmsgConnectionlessTests, SendrecvmsgUDPTestBase): +class SendmsgUDP4Test(SendmsgConnectionlessTests, SendrecvmsgUDP4TestBase): pass @requireAttrs(socket.socket, "recvmsg") -class RecvmsgUDPTest(RecvmsgTests, SendrecvmsgUDPTestBase): +class RecvmsgUDP4Test(RecvmsgTests, SendrecvmsgUDP4TestBase): pass @requireAttrs(socket.socket, "recvmsg_into") -class RecvmsgIntoUDPTest(RecvmsgIntoTests, SendrecvmsgUDPTestBase): +class RecvmsgIntoUDP4Test(RecvmsgIntoTests, SendrecvmsgUDP4TestBase): pass @@ -4293,27 +4293,27 @@ class RecvmsgIntoRFC3542AncillaryUDP6Test(RecvmsgIntoMixin, @unittest.skipUnless(HAVE_SOCKET_UDPLITE, 'UDPLITE sockets required for this test.') -class SendrecvmsgUDPLITETestBase(SendrecvmsgDgramFlagsBase, +class SendrecvmsgUDPLITE4TestBase(SendrecvmsgDgramFlagsBase, SendrecvmsgConnectionlessBase, - ThreadedSocketTestMixin, UDPLITETestBase): + ThreadedSocketTestMixin, UDPLITE4TestBase): pass @unittest.skipUnless(HAVE_SOCKET_UDPLITE, 'UDPLITE sockets required for this test.') @requireAttrs(socket.socket, "sendmsg") -class SendmsgUDPLITETest(SendmsgConnectionlessTests, SendrecvmsgUDPLITETestBase): +class SendmsgUDPLITE4Test(SendmsgConnectionlessTests, SendrecvmsgUDPLITE4TestBase): pass @unittest.skipUnless(HAVE_SOCKET_UDPLITE, 'UDPLITE sockets required for this test.') @requireAttrs(socket.socket, "recvmsg") -class RecvmsgUDPLITETest(RecvmsgTests, SendrecvmsgUDPLITETestBase): +class RecvmsgUDPLITE4Test(RecvmsgTests, SendrecvmsgUDPLITE4TestBase): pass @unittest.skipUnless(HAVE_SOCKET_UDPLITE, 'UDPLITE sockets required for this test.') @requireAttrs(socket.socket, "recvmsg_into") -class RecvmsgIntoUDPLITETest(RecvmsgIntoTests, SendrecvmsgUDPLITETestBase): +class RecvmsgIntoUDPLITE4Test(RecvmsgIntoTests, SendrecvmsgUDPLITE4TestBase): pass @@ -6701,17 +6701,17 @@ def test_main(): tests.append(BasicBluetoothTest) tests.extend([ CmsgMacroTests, - SendmsgUDPTest, - RecvmsgUDPTest, - RecvmsgIntoUDPTest, + SendmsgUDP4Test, + RecvmsgUDP4Test, + RecvmsgIntoUDP4Test, SendmsgUDP6Test, RecvmsgUDP6Test, RecvmsgRFC3542AncillaryUDP6Test, RecvmsgIntoRFC3542AncillaryUDP6Test, RecvmsgIntoUDP6Test, - SendmsgUDPLITETest, - RecvmsgUDPLITETest, - RecvmsgIntoUDPLITETest, + SendmsgUDPLITE4Test, + RecvmsgUDPLITE4Test, + RecvmsgIntoUDPLITE4Test, SendmsgUDPLITE6Test, RecvmsgUDPLITE6Test, RecvmsgRFC3542AncillaryUDPLITE6Test, From 7d55a83b10fa6825df450fc4a2363a4b51228fd6 Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Tue, 18 May 2021 17:16:55 -0700 Subject: [PATCH 19/28] Add AF_INET6 support to multiprocessing IPC. Also makes the test_multiprocessing suites work on IPv6-only hosts. --- Lib/multiprocessing/connection.py | 14 +++++++++++-- Lib/test/_test_multiprocessing.py | 35 ++++++++++++++++++++----------- 2 files changed, 35 insertions(+), 14 deletions(-) diff --git a/Lib/multiprocessing/connection.py b/Lib/multiprocessing/connection.py index 510e4b5aba44a6..85f371bb47c024 100644 --- a/Lib/multiprocessing/connection.py +++ b/Lib/multiprocessing/connection.py @@ -51,6 +51,9 @@ default_family = 'AF_UNIX' families += ['AF_UNIX'] +if hasattr(socket, 'AF_INET6') and socket.has_ipv6: + families.append('AF_INET6') + if sys.platform == 'win32': default_family = 'AF_PIPE' families += ['AF_PIPE'] @@ -70,7 +73,7 @@ def arbitrary_address(family): ''' Return an arbitrary free address for the given family ''' - if family == 'AF_INET': + if family in {'AF_INET', 'AF_INET6'}: return ('localhost', 0) elif family == 'AF_UNIX': # Prefer abstract sockets if possible to avoid problems with the address @@ -101,9 +104,16 @@ def address_type(address): ''' Return the types of the address - This can be 'AF_INET', 'AF_UNIX', or 'AF_PIPE' + This can be 'AF_INET', 'AF_INET6', 'AF_UNIX', or 'AF_PIPE' ''' if type(address) == tuple: + if '.' in address[0]: + return 'AF_INET' + if ':' in address[0]: + return 'AF_INET6' + addr_info = socket.getaddrinfo(*address[:2]) + if addr_info: + return addr_info[0][0].name return 'AF_INET' elif type(address) is str and address.startswith('\\\\'): return 'AF_PIPE' diff --git a/Lib/test/_test_multiprocessing.py b/Lib/test/_test_multiprocessing.py index ead92cfa2abfea..02b24f41a428c3 100644 --- a/Lib/test/_test_multiprocessing.py +++ b/Lib/test/_test_multiprocessing.py @@ -184,6 +184,14 @@ class BaseTestCase(object): ALLOWED_TYPES = ('processes', 'manager', 'threads') + def get_families(self): + fams = set(self.connection.families) + if not socket_helper.IPV6_ENABLED: + fams -= {'AF_INET6'} + if not socket_helper.IPV4_ENABLED: + fams -= {'AF_INET'} + return fams + def assertTimingAlmostEqual(self, a, b): if CHECK_TIMINGS: self.assertAlmostEqual(a, b, 1) @@ -3284,7 +3292,7 @@ class _TestListener(BaseTestCase): ALLOWED_TYPES = ('processes',) def test_multiple_bind(self): - for family in self.connection.families: + for family in self.get_families(): l = self.connection.Listener(family=family) self.addCleanup(l.close) self.assertRaises(OSError, self.connection.Listener, @@ -3324,7 +3332,7 @@ def _test(cls, address): conn.close() def test_listener_client(self): - for family in self.connection.families: + for family in self.get_families(): l = self.connection.Listener(family=family) p = self.Process(target=self._test, args=(l.address,)) p.daemon = True @@ -3351,7 +3359,7 @@ def test_issue14725(self): l.close() def test_issue16955(self): - for fam in self.connection.families: + for fam in self.get_families(): l = self.connection.Listener(family=fam) c = self.connection.Client(l.address) a = l.accept() @@ -3464,7 +3472,8 @@ def _listener(cls, conn, families): new_conn.close() l.close() - l = socket.create_server((socket_helper.HOST, 0)) + l = socket.create_server((socket_helper.HOST, 0), + family=socket_helper.get_family()) conn.send(l.getsockname()) new_conn, addr = l.accept() conn.send(new_conn) @@ -3481,7 +3490,7 @@ def _remote(cls, conn): client.close() address, msg = conn.recv() - client = socket.socket() + client = socket_helper.tcp_socket() client.connect(address) client.sendall(msg.upper()) client.close() @@ -3489,7 +3498,7 @@ def _remote(cls, conn): conn.close() def test_pickling(self): - families = self.connection.families + families = self.get_families() lconn, lconn0 = self.Pipe() lp = self.Process(target=self._listener, args=(lconn0, families)) @@ -4638,7 +4647,7 @@ def test_wait(self, slow=False): @classmethod def _child_test_wait_socket(cls, address, slow): - s = socket.socket() + s = socket_helper.tcp_socket() s.connect(address) for i in range(10): if slow: @@ -4648,7 +4657,8 @@ def _child_test_wait_socket(cls, address, slow): def test_wait_socket(self, slow=False): from multiprocessing.connection import wait - l = socket.create_server((socket_helper.HOST, 0)) + l = socket.create_server((socket_helper.HOST, 0), + family=socket_helper.get_family()) addr = l.getsockname() readers = [] procs = [] @@ -4836,7 +4846,8 @@ def test_timeout(self): try: socket.setdefaulttimeout(0.1) parent, child = multiprocessing.Pipe(duplex=True) - l = multiprocessing.connection.Listener(family='AF_INET') + l = multiprocessing.connection.Listener( + family=socket_helper.get_family().name) p = multiprocessing.Process(target=self._test_timeout, args=(child, l.address)) p.start() @@ -4910,11 +4921,11 @@ def get_high_socket_fd(self): # The child process will not have any socket handles, so # calling socket.fromfd() should produce WSAENOTSOCK even # if there is a handle of the same number. - return socket.socket().detach() + return socket_helper.tcp_socket().detach() else: # We want to produce a socket with an fd high enough that a # freshly created child process will not have any fds as high. - fd = socket.socket().detach() + fd = socket_helper.tcp_socket().detach() to_close = [] while fd < 50: to_close.append(fd) @@ -4925,7 +4936,7 @@ def get_high_socket_fd(self): def close(self, fd): if WIN32: - socket.socket(socket.AF_INET, socket.SOCK_STREAM, fileno=fd).close() + socket.socket(socket_helper.get_family(), socket.SOCK_STREAM, fileno=fd).close() else: os.close(fd) From 1a721295cfe96c9ca37a3b6cd246226394506380 Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Tue, 18 May 2021 18:43:54 -0700 Subject: [PATCH 20/28] Fix test_httpservers for IPv6-only hosts. --- Lib/test/test_httpservers.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/Lib/test/test_httpservers.py b/Lib/test/test_httpservers.py index cb0a3aa9e40451..eceb364a607d82 100644 --- a/Lib/test/test_httpservers.py +++ b/Lib/test/test_httpservers.py @@ -31,6 +31,7 @@ import unittest from test import support from test.support import os_helper +from test.support import socket_helper from test.support import threading_helper @@ -42,6 +43,9 @@ def log_message(self, *args): def read(self, n=None): return '' +class IPvWhateverHTTPServer(HTTPServer): + address_family = socket_helper.get_family() + class TestServerThread(threading.Thread): def __init__(self, test_object, request_handler): @@ -50,8 +54,8 @@ def __init__(self, test_object, request_handler): self.test_object = test_object def run(self): - self.server = HTTPServer(('localhost', 0), self.request_handler) - self.test_object.HOST, self.test_object.PORT = self.server.socket.getsockname() + self.server = IPvWhateverHTTPServer(('localhost', 0), self.request_handler) + self.test_object.HOST, self.test_object.PORT = self.server.socket.getsockname()[:2] self.test_object.server_started.set() self.test_object = None try: From bc6d51a25e57cfeed4846d0c25059aa054497760 Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Tue, 18 May 2021 18:55:50 -0700 Subject: [PATCH 21/28] If starting a logging config server on AF_INET fails, try AF_INET6. --- Lib/logging/config.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/Lib/logging/config.py b/Lib/logging/config.py index 3bc63b78621aba..752c8cbf025ca8 100644 --- a/Lib/logging/config.py +++ b/Lib/logging/config.py @@ -29,6 +29,7 @@ import logging import logging.handlers import re +import socket import struct import sys import threading @@ -885,7 +886,11 @@ class ConfigSocketReceiver(ThreadingTCPServer): def __init__(self, host='localhost', port=DEFAULT_LOGGING_CONFIG_PORT, handler=None, ready=None, verify=None): - ThreadingTCPServer.__init__(self, (host, port), handler) + try: + ThreadingTCPServer.__init__(self, (host, port), handler) + except OSError as err: + self.address_family = socket.AF_INET6 + ThreadingTCPServer.__init__(self, (host, port), handler) logging._acquireLock() self.abort = 0 logging._releaseLock() From 1317e7a571dd827010a8a990e11c0009af17f2df Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Tue, 18 May 2021 18:56:31 -0700 Subject: [PATCH 22/28] Fix test_logging for use on IPv6-only hosts. --- Lib/test/test_logging.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py index ee00a32026f65e..6c50763ba10ae3 100644 --- a/Lib/test/test_logging.py +++ b/Lib/test/test_logging.py @@ -823,6 +823,8 @@ class TestSMTPServer(smtpd.SMTPServer): :mod:`asyncore` module's global state. """ + address_family = socket_helper.get_family() + def __init__(self, addr, handler, poll_interval, sockmap): smtpd.SMTPServer.__init__(self, addr, None, map=sockmap, decode_data=True) @@ -937,6 +939,9 @@ class TestHTTPServer(ControlMixin, HTTPServer): :param poll_interval: The polling interval in seconds. :param log: Pass ``True`` to enable log messages. """ + + address_family = socket_helper.get_family() + def __init__(self, addr, handler, poll_interval=0.5, log=False, sslctx=None): class DelegatingHTTPRequestHandler(BaseHTTPRequestHandler): @@ -3231,9 +3236,9 @@ def setup_via_listener(self, text, verify=None): port = t.port t.ready.clear() try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock = socket_helper.tcp_socket() sock.settimeout(2.0) - sock.connect(('localhost', port)) + sock.connect((socket_helper.HOST, port)) slen = struct.pack('>L', len(text)) s = slen + text From 6d50fa509ef967295a242b92175e97b74cc9629b Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Tue, 18 May 2021 19:05:38 -0700 Subject: [PATCH 23/28] Make test_xmlrpc work on IPv6-only hosts rather than hang. --- Lib/test/test_xmlrpc.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/Lib/test/test_xmlrpc.py b/Lib/test/test_xmlrpc.py index a9f67466071bc6..34b8a154af3f9a 100644 --- a/Lib/test/test_xmlrpc.py +++ b/Lib/test/test_xmlrpc.py @@ -337,7 +337,10 @@ def run_server(): server.handle_request() # First request and attempt at second server.handle_request() # Retried second request - server = http.server.HTTPServer((socket_helper.HOST, 0), RequestHandler) + class IPvWhateverHTTPServer(http.server.HTTPServer): + address_family = socket_helper.get_family() + + server = IPvWhateverHTTPServer((socket_helper.HOST, 0), RequestHandler) self.addCleanup(server.server_close) thread = threading.Thread(target=run_server) thread.start() @@ -606,6 +609,9 @@ def getData(): return '42' class MyXMLRPCServer(xmlrpc.server.SimpleXMLRPCServer): + + address_family = socket_helper.get_family() + def get_request(self): # Ensure the socket is always non-blocking. On Linux, socket # attributes are not inherited like they are on *BSD and Windows. @@ -615,13 +621,13 @@ def get_request(self): if not requestHandler: requestHandler = xmlrpc.server.SimpleXMLRPCRequestHandler - serv = MyXMLRPCServer(("localhost", 0), requestHandler, + serv = MyXMLRPCServer((socket_helper.HOST, 0), requestHandler, encoding=encoding, logRequests=False, bind_and_activate=False) try: serv.server_bind() global ADDR, PORT, URL - ADDR, PORT = serv.socket.getsockname() + ADDR, PORT = serv.socket.getsockname()[:2] #connect to IP address directly. This avoids socket.create_connection() #trying to connect to "localhost" using all address families, which #causes slowdown e.g. on vista which supports AF_INET6. The server listens @@ -669,6 +675,9 @@ def my_function(): return True class MyXMLRPCServer(xmlrpc.server.MultiPathXMLRPCServer): + + address_family = socket_helper.get_family() + def get_request(self): # Ensure the socket is always non-blocking. On Linux, socket # attributes are not inherited like they are on *BSD and Windows. @@ -685,13 +694,13 @@ class BrokenDispatcher: def _marshaled_dispatch(self, data, dispatch_method=None, path=None): raise RuntimeError("broken dispatcher") - serv = MyXMLRPCServer(("localhost", 0), MyRequestHandler, + serv = MyXMLRPCServer((socket_helper.HOST, 0), MyRequestHandler, logRequests=False, bind_and_activate=False) serv.socket.settimeout(3) serv.server_bind() try: global ADDR, PORT, URL - ADDR, PORT = serv.socket.getsockname() + ADDR, PORT = serv.socket.getsockname()[:2] #connect to IP address directly. This avoids socket.create_connection() #trying to connect to "localhost" using all address families, which #causes slowdown e.g. on vista which supports AF_INET6. The server listens @@ -1498,7 +1507,11 @@ def test_cgihandler_has_use_builtin_types_flag(self): self.assertTrue(handler.use_builtin_types) def test_xmlrpcserver_has_use_builtin_types_flag(self): - server = xmlrpc.server.SimpleXMLRPCServer(("localhost", 0), + + class IPvWhateverSimpleXMLRPCServer(xmlrpc.server.SimpleXMLRPCServer): + address_family = socket_helper.get_family() + + server = IPvWhateverSimpleXMLRPCServer((socket_helper.HOST, 0), use_builtin_types=True) server.server_close() self.assertTrue(server.use_builtin_types) From be88847bf2af269c4460af5ce548bb23c6b36342 Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Tue, 18 May 2021 19:12:30 -0700 Subject: [PATCH 24/28] Prevent test_asyncio from hanging on an IPv6-only host. --- Lib/test/test_asyncio/test_streams.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 6eaa2899442184..a53dec889e8b15 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -815,8 +815,10 @@ def test_drain_raises(self): def server(): # Runs in a separate thread. - with socket.create_server(('localhost', 0)) as sock: - addr = sock.getsockname() + with socket_helper.bind_ip_socket_and_port() as sock_port: + sock = sock_port[0] + sock.listen() + addr = sock.getsockname()[:2] q.put(addr) clt, _ = sock.accept() clt.close() From 5d0f8d786bc92c3804022b522f131991f9208894 Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Tue, 18 May 2021 23:24:01 -0700 Subject: [PATCH 25/28] Fix test_asyncio.test_streams to work on IPv6-only. --- Lib/test/test_asyncio/test_streams.py | 30 +++++++++++++++------------ Lib/test/test_asyncio/utils.py | 8 ++++--- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index a53dec889e8b15..b8a49699b4b956 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -56,7 +56,7 @@ def _basetest_open_connection(self, open_connection_fut): def test_open_connection(self): with test_utils.run_test_server() as httpd: - conn_fut = asyncio.open_connection(*httpd.address) + conn_fut = asyncio.open_connection(*httpd.address[:2]) self._basetest_open_connection(conn_fut) @socket_helper.skip_unless_bind_unix_socket @@ -84,8 +84,8 @@ def _basetest_open_connection_no_loop_ssl(self, open_connection_fut): def test_open_connection_no_loop_ssl(self): with test_utils.run_test_server(use_ssl=True) as httpd: conn_fut = asyncio.open_connection( - *httpd.address, - ssl=test_utils.dummy_ssl_context()) + *httpd.address[:2], + ssl=test_utils.dummy_ssl_context()) self._basetest_open_connection_no_loop_ssl(conn_fut) @@ -115,7 +115,7 @@ def _basetest_open_connection_error(self, open_connection_fut): def test_open_connection_error(self): with test_utils.run_test_server() as httpd: - conn_fut = asyncio.open_connection(*httpd.address) + conn_fut = asyncio.open_connection(*httpd.address[:2]) self._basetest_open_connection_error(conn_fut) @socket_helper.skip_unless_bind_unix_socket @@ -582,19 +582,23 @@ async def handle_client(self, client_reader, client_writer): await client_writer.wait_closed() def start(self): - sock = socket.create_server(('127.0.0.1', 0)) + sock = socket.create_server( + (socket_helper.HOST, 0), + family=socket_helper.get_family()) self.server = self.loop.run_until_complete( asyncio.start_server(self.handle_client, sock=sock)) - return sock.getsockname() + return sock.getsockname()[:2] def handle_client_callback(self, client_reader, client_writer): self.loop.create_task(self.handle_client(client_reader, client_writer)) def start_callback(self): - sock = socket.create_server(('127.0.0.1', 0)) - addr = sock.getsockname() + sock = socket.create_server( + (socket_helper.HOST, 0), + family=socket_helper.get_family()) + addr = sock.getsockname()[:2] sock.close() self.server = self.loop.run_until_complete( asyncio.start_server(self.handle_client_callback, @@ -909,7 +913,7 @@ def test_LimitOverrunError_pickleable(self): def test_wait_closed_on_close(self): with test_utils.run_test_server() as httpd: rd, wr = self.loop.run_until_complete( - asyncio.open_connection(*httpd.address)) + asyncio.open_connection(*httpd.address[:2])) wr.write(b'GET / HTTP/1.0\r\n\r\n') f = rd.readline() @@ -926,7 +930,7 @@ def test_wait_closed_on_close(self): def test_wait_closed_on_close_with_unread_data(self): with test_utils.run_test_server() as httpd: rd, wr = self.loop.run_until_complete( - asyncio.open_connection(*httpd.address)) + asyncio.open_connection(*httpd.address[:2])) wr.write(b'GET / HTTP/1.0\r\n\r\n') f = rd.readline() @@ -937,7 +941,7 @@ def test_wait_closed_on_close_with_unread_data(self): def test_async_writer_api(self): async def inner(httpd): - rd, wr = await asyncio.open_connection(*httpd.address) + rd, wr = await asyncio.open_connection(*httpd.address[:2]) wr.write(b'GET / HTTP/1.0\r\n\r\n') data = await rd.readline() @@ -957,7 +961,7 @@ async def inner(httpd): def test_async_writer_api_exception_after_close(self): async def inner(httpd): - rd, wr = await asyncio.open_connection(*httpd.address) + rd, wr = await asyncio.open_connection(*httpd.address[:2]) wr.write(b'GET / HTTP/1.0\r\n\r\n') data = await rd.readline() @@ -984,7 +988,7 @@ def test_eof_feed_when_closing_writer(self): with test_utils.run_test_server() as httpd: rd, wr = self.loop.run_until_complete( - asyncio.open_connection(*httpd.address)) + asyncio.open_connection(*httpd.address[:2])) wr.close() f = wr.wait_closed() diff --git a/Lib/test/test_asyncio/utils.py b/Lib/test/test_asyncio/utils.py index 3765194cd0dd27..9b5e4ceb0e05a3 100644 --- a/Lib/test/test_asyncio/utils.py +++ b/Lib/test/test_asyncio/utils.py @@ -18,6 +18,7 @@ import weakref from unittest import mock +from test.support import socket_helper from http.server import HTTPServer from wsgiref.simple_server import WSGIRequestHandler, WSGIServer @@ -140,6 +141,7 @@ def log_message(self, format, *args): class SilentWSGIServer(WSGIServer): + address_family = socket_helper.get_family() request_timeout = support.LOOPBACK_TIMEOUT def get_request(self): @@ -215,7 +217,7 @@ class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer): def server_bind(self): socketserver.UnixStreamServer.server_bind(self) - self.server_name = '127.0.0.1' + self.server_name = socket_helper.HOST self.server_port = 80 @@ -236,7 +238,7 @@ def get_request(self): # as the second return value will be a path; # hence we return some fake data sufficient # to get the tests going - return request, ('127.0.0.1', '') + return request, (socket_helper.HOST, '') class SilentUnixWSGIServer(UnixWSGIServer): @@ -275,7 +277,7 @@ def run_test_unix_server(*, use_ssl=False): @contextlib.contextmanager -def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): +def run_test_server(*, host=socket_helper.HOST, port=0, use_ssl=False): yield from _run_test_server(address=(host, port), use_ssl=use_ssl, server_cls=SilentWSGIServer, server_ssl_cls=SSLWSGIServer) From 0b556daf63b0690866117cd914b502562d853c4a Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Wed, 19 May 2021 00:24:46 -0700 Subject: [PATCH 26/28] Fix socket_helper sock vs tempsock paste error.. --- Lib/test/support/socket_helper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Lib/test/support/socket_helper.py b/Lib/test/support/socket_helper.py index 2683628253b6d1..d673a6ddc6e591 100644 --- a/Lib/test/support/socket_helper.py +++ b/Lib/test/support/socket_helper.py @@ -97,8 +97,8 @@ def find_unused_port(family=None, socktype=socket.SOCK_STREAM): if not port: port = bind_port(tempsock) else: - sock.bind((host, 0)) - port = sock.getsockname()[1] + tempsock.bind((host, 0)) + port = tempsock.getsockname()[1] except OSError as err: errors[family] = err port = 0 From b7688e418723693a5dc35262f72ca66a18e718d9 Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Wed, 19 May 2021 09:10:46 -0700 Subject: [PATCH 27/28] socket_helper typo (yes, find_unused_port is too ugly...) --- Lib/test/support/socket_helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/test/support/socket_helper.py b/Lib/test/support/socket_helper.py index d673a6ddc6e591..99e2d7e87573d3 100644 --- a/Lib/test/support/socket_helper.py +++ b/Lib/test/support/socket_helper.py @@ -97,7 +97,7 @@ def find_unused_port(family=None, socktype=socket.SOCK_STREAM): if not port: port = bind_port(tempsock) else: - tempsock.bind((host, 0)) + tempsock.bind((HOST, 0)) port = tempsock.getsockname()[1] except OSError as err: errors[family] = err From 2965d32982d8a86806b1738dcfe396d5e8b783e2 Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Thu, 20 May 2021 00:32:18 -0700 Subject: [PATCH 28/28] Undo find_unused_port complexity, use get_family() --- Lib/test/support/socket_helper.py | 65 ++++++++++--------------------- Lib/test/test_ftplib.py | 4 +- Lib/test/test_httplib.py | 2 +- Lib/test/test_logging.py | 2 +- Lib/test/test_smtplib.py | 2 +- Lib/test/test_socket.py | 22 +++++------ Lib/test/test_support.py | 16 ++++++-- 7 files changed, 48 insertions(+), 65 deletions(-) diff --git a/Lib/test/support/socket_helper.py b/Lib/test/support/socket_helper.py index 99e2d7e87573d3..c3e430f8ec1fed 100644 --- a/Lib/test/support/socket_helper.py +++ b/Lib/test/support/socket_helper.py @@ -20,20 +20,20 @@ def find_unused_port(family=None, socktype=socket.SOCK_STREAM): eliciting an unused ephemeral port from the OS. The temporary socket is then closed and deleted, and the ephemeral port is returned. - When family is None it will use whichever of socket.AF_INET or - socket.AF_INET6 makes sense, finding a port available on both if possible. + When family is None, we use to the result of get_family() instead. Either this method or bind_port() should be used for any tests where a server socket needs to be bound to a particular port for the duration of the test. Which one to use depends on whether the calling code is creating a python socket, or if an unused port needs to be provided in a constructor or passed to an external program (i.e. the -accept argument to openssl's - s_server mode). Always prefer bind_port() over find_unused_port() where - possible. Hard coded ports should *NEVER* be used. As soon as a server - socket is bound to a hard coded port, the ability to run multiple instances - of the test simultaneously on the same host is compromised, which makes the - test a ticking time bomb in a buildbot environment. On Unix buildbots, this - may simply manifest as a failed test, which can be recovered from without + s_server mode). Always prefer bind_port(), bind_ip_socket_and_port(), + and get_bound_ip_socket_and_port() over find_unused_port() where possible. + Hard coded ports should *NEVER* be used. As soon as a server socket is + bound to a hard coded port, the ability to run multiple instances of the + test simultaneously on the same host is compromised, which makes the test a + ticking time bomb in a buildbot environment. On Unix buildbots, this may + simply manifest as a failed test, which can be recovered from without intervention in most cases, but on Windows, the entire python process can completely and utterly wedge, requiring someone to log in to the buildbot and manually kill the affected process. @@ -71,43 +71,18 @@ def find_unused_port(family=None, socktype=socket.SOCK_STREAM): issue if/when we come across it. TODO(gpshead): We should support a https://pypi.org/project/portpicker/ - portserver or equivalent process running on our buildbot hosts and use that - that portpicker library... + portserver or equivalent running on our buildbot workers and use that + that for more reliability at avoiding conflicts between parallel tests. """ - if isinstance(family, int): - with socket.socket(family, socktype) as tempsock: - port = bind_port(tempsock) - del tempsock - else: - if family is not None: # Assume it's a sequence, it wasn't int|None. - families = family - else: - families = [] - if IPV4_ENABLED: - families.append(socket.AF_INET) - if IPV6_ENABLED: - families.append(socket.AF_INET6) - assert families, "At least one of IPv4 or IPv6 must be enabled." - port = 0 - errors = {} - for family in families: - try: - with socket.socket(family, socktype) as tempsock: - if not port: - port = bind_port(tempsock) - else: - tempsock.bind((HOST, 0)) - port = tempsock.getsockname()[1] - except OSError as err: - errors[family] = err - port = 0 - del tempsock - if not port: - raise support.TestFailed( - f"Could not bind to a port: {errors}") + if family is None: + family = get_family() + with socket.socket(family, socktype) as tempsock: + port = bind_port(tempsock) + del tempsock return port + def bind_port(sock, host=HOST): """Bind the socket to a free port and return the port number. Relies on ephemeral ports in order to ensure we are using an unbound port. This is @@ -189,10 +164,10 @@ def get_bound_ip_socket_and_port(*, hostname=HOST, socktype=socket.SOCK_STREAM): @contextlib.contextmanager def bind_ip_socket_and_port(*, hostname=HOST, socktype=socket.SOCK_STREAM): - """ - A context manager that creates a socket of socktype bound to hostname - using whichever of IPv6 or IPv4 is available. Context is a (socket, port) - tuple. Exiting the context closes the socket. + """A context manager that creates a socket of socktype. + + It uses whichever of IPv6 or IPv4 is available based on get_family(). + Context is a (socket, port) tuple. The socket is closed on context exit. """ sock, port = get_bound_ip_socket_and_port( hostname=hostname, socktype=socktype) diff --git a/Lib/test/test_ftplib.py b/Lib/test/test_ftplib.py index 53a8485be1256e..93e64b1ec27841 100644 --- a/Lib/test/test_ftplib.py +++ b/Lib/test/test_ftplib.py @@ -788,7 +788,7 @@ def is_client_connected(): def test_source_address(self): self.client.quit() - port = socket_helper.find_unused_port(family=None) + port = socket_helper.find_unused_port() try: self.client.connect(self.server.host, self.server.port, source_address=(HOST, port)) @@ -800,7 +800,7 @@ def test_source_address(self): raise def test_source_address_passive_connection(self): - port = socket_helper.find_unused_port(family=None) + port = socket_helper.find_unused_port() self.client.source_address = (HOST, port) try: with self.client.transfercmd('list') as sock: diff --git a/Lib/test/test_httplib.py b/Lib/test/test_httplib.py index a6921be808983d..790d8feb842a87 100644 --- a/Lib/test/test_httplib.py +++ b/Lib/test/test_httplib.py @@ -1675,7 +1675,7 @@ def test_client_constants(self): class SourceAddressTest(TestCase): def setUp(self): self.serv, self.port = socket_helper.get_bound_ip_socket_and_port() - self.source_port = socket_helper.find_unused_port() + self.source_port = socket_helper.find_unused_port(family=self.serv.family) self.serv.listen() self.conn = None diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py index 6c50763ba10ae3..21323c722db925 100644 --- a/Lib/test/test_logging.py +++ b/Lib/test/test_logging.py @@ -939,7 +939,7 @@ class TestHTTPServer(ControlMixin, HTTPServer): :param poll_interval: The polling interval in seconds. :param log: Pass ``True`` to enable log messages. """ - + address_family = socket_helper.get_family() def __init__(self, addr, handler, poll_interval=0.5, diff --git a/Lib/test/test_smtplib.py b/Lib/test/test_smtplib.py index d800b37c3ead5c..37ace956c9a152 100644 --- a/Lib/test/test_smtplib.py +++ b/Lib/test/test_smtplib.py @@ -283,7 +283,7 @@ def testBasic(self): def testSourceAddress(self): # connect - src_port = socket_helper.find_unused_port() + src_port = socket_helper.find_unused_port(family=self.serv.socket.family) try: smtp = smtplib.SMTP(self.host, self.port, local_hostname='localhost', timeout=support.LOOPBACK_TIMEOUT, diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index 3d6b575581a248..949105783ac5f8 100755 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -1397,7 +1397,7 @@ def testStringToIPv6(self): @unittest.skipUnless(socket_helper.IPV4_ENABLED, 'IPv4 required') def testSockName(self): # Testing getsockname() - port = socket_helper.find_unused_port() + port = socket_helper.find_unused_port(family=socket.AF_INET) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.addCleanup(sock.close) sock.bind(("0.0.0.0", port)) @@ -1464,7 +1464,7 @@ def testNewAttributes(self): def test_getsockaddrarg(self): sock = socket_helper.tcp_socket() self.addCleanup(sock.close) - port = socket_helper.find_unused_port() + port = socket_helper.find_unused_port(family=sock.family) big_port = port + 65536 neg_port = port - 65536 with self.assertRaises(OverflowError): @@ -1474,7 +1474,7 @@ def test_getsockaddrarg(self): # Since find_unused_port() is inherently subject to race conditions, we # call it a couple times if necessary. for i in itertools.count(): - port = socket_helper.find_unused_port() + port = socket_helper.find_unused_port(family=sock.family) try: sock.bind((HOST, port)) except OSError as e: @@ -6493,14 +6493,14 @@ class CreateServerTest(unittest.TestCase): @unittest.skipUnless(socket_helper.IPV4_ENABLED, 'IPv4 required') def test_address_ipv4(self): - port = socket_helper.find_unused_port() + port = socket_helper.find_unused_port(family=socket.AF_INET) with socket.create_server(("127.0.0.1", port)) as sock: self.assertEqual(sock.getsockname()[0], "127.0.0.1") self.assertEqual(sock.getsockname()[1], port) @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required') def test_address_ipv6(self): - port = socket_helper.find_unused_port() + port = socket_helper.find_unused_port(family=socket.AF_INET6) with socket.create_server(("::1", port), family=socket.AF_INET6) as sock: self.assertEqual(sock.getsockname()[0], "::1") @@ -6586,14 +6586,14 @@ def echo_client(self, addr, family): @unittest.skipUnless(socket_helper.IPV4_ENABLED, 'IPv4 required') def test_tcp4(self): - port = socket_helper.find_unused_port() + port = socket_helper.find_unused_port(family=socket.AF_INET) with socket.create_server(("", port)) as sock: self.echo_server(sock) self.echo_client(("127.0.0.1", port), socket.AF_INET) @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test') def test_tcp6(self): - port = socket_helper.find_unused_port() + port = socket_helper.find_unused_port(family=socket.AF_INET6) with socket.create_server(("", port), family=socket.AF_INET6) as sock: self.echo_server(sock) @@ -6606,9 +6606,9 @@ def test_tcp6(self): @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test') @unittest.skipUnless(socket_helper.IPV4_ENABLED, 'IPv4 required.') def test_dual_stack_client_v4(self): - port = socket_helper.find_unused_port() - with socket.create_server(("", port), family=socket.AF_INET6, + with socket.create_server(("", 0), family=socket.AF_INET6, dualstack_ipv6=True) as sock: + port = sock.getsockname()[1] self.echo_server(sock) self.echo_client(("127.0.0.1", port), socket.AF_INET) @@ -6616,9 +6616,9 @@ def test_dual_stack_client_v4(self): "dualstack_ipv6 not supported") @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test') def test_dual_stack_client_v6(self): - port = socket_helper.find_unused_port() - with socket.create_server(("", port), family=socket.AF_INET6, + with socket.create_server(("", 0), family=socket.AF_INET6, dualstack_ipv6=True) as sock: + port = sock.getsockname()[1] self.echo_server(sock) self.echo_client(("::1", port), socket.AF_INET6) diff --git a/Lib/test/test_support.py b/Lib/test/test_support.py index d4167fd31082fc..d0231e783e2834 100644 --- a/Lib/test/test_support.py +++ b/Lib/test/test_support.py @@ -102,16 +102,24 @@ def test_bind_ip_socket_and_port_HOST(self): @unittest.skipUnless(socket_helper.IPV4_ENABLED, "IPv4 required") def test_find_unused_port_ipv4(self): - port = socket_helper.find_unused_port() + port = socket_helper.find_unused_port(family=socket.AF_INET) s = socket.create_server((socket_helper.HOST, port)) s.close() @unittest.skipUnless(socket_helper.IPV6_ENABLED, "IPv6 required") def test_find_unused_port_ipv6(self): + port = socket_helper.find_unused_port(family=socket.AF_INET6) + s = socket.create_server( + (socket_helper.HOST, port), + family=socket.AF_INET6) + s.close() + + def test_find_unused_port_noargs(self): port = socket_helper.find_unused_port() - with socket.socket(socket.AF_INET6) as s: - s.bind((socket_helper.HOST, port)) - s.listen() + s = socket.create_server( + (socket_helper.HOST, port), + family=socket_helper.get_family()) + s.close() @unittest.skipUnless(socket_helper.IPV4_ENABLED, "IPv4 required") def test_bind_port_ipv4(self):