Skip to content

gh-110012: Fix RuntimeWarning in test_socket #110013

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions Lib/test/test_socket.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import unittest
from test import support
from test.support import (
is_apple, os_helper, refleak_helper, socket_helper, threading_helper
is_apple, os_helper, refleak_helper, socket_helper, threading_helper,
warnings_helper,
)
import _thread as thread
import array
Expand Down Expand Up @@ -198,6 +199,16 @@ def socket_setdefaulttimeout(timeout):
socket.setdefaulttimeout(old_timeout)


@contextlib.contextmanager
def catch_malformed_data_warning(quiet=True):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it catch them or ignore them? What is the behavior of quiet=False?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

def _filterwarnings(filters, quiet=False):
"""Catch the warnings, then check if all the expected
warnings have been raised and re-raise unexpected warnings.
If 'quiet' is True, only re-raise the unexpected warnings.
"""
# Clear the warning registry of the calling module
# in order to re-raise the warnings.
frame = sys._getframe(2)
registry = frame.f_globals.get('__warningregistry__')
if registry:
registry.clear()
with warnings.catch_warnings(record=True) as w:
# Set filter "always" to record all warnings. Because
# test_warnings swap the module, we need to look up in
# the sys.modules dictionary.
sys.modules['warnings'].simplefilter("always")
yield WarningsRecorder(w)
# Filter the recorded warnings
reraise = list(w)
missing = []
for msg, cat in filters:
seen = False
for w in reraise[:]:
warning = w.message
# Filter out the matching messages
if (re.match(msg, str(warning), re.I) and
issubclass(warning.__class__, cat)):
seen = True
reraise.remove(w)
if not seen and not quiet:
# This filter caught nothing
missing.append((msg, cat.__name__))
if reraise:
raise AssertionError("unhandled warning %s" % reraise[0])
if missing:
raise AssertionError("filter (%r, %s) did not catch any warning" %
missing[0])

  1. It catches warnings with warnings.catch_warnings(record=True)
  2. Then it filters out ones that matched ("received malformed or improperly-truncated ancillary data", RuntimeWarning)
  3. If there are any left, the test fails
  4. quiet=True handles the cases where no warning was raised - just by ignoring the last check, if quiet=False we are required to have at least one matched warning

# This warning happens on macos and win, but does not always happen on linux.
with warnings_helper.check_warnings(
("received malformed or improperly-truncated ancillary data", RuntimeWarning),
quiet=quiet,
):
yield


HAVE_SOCKET_CAN = _have_socket_can()

HAVE_SOCKET_CAN_ISOTP = _have_socket_can_isotp()
Expand Down Expand Up @@ -3946,8 +3957,9 @@ def checkTruncatedArray(self, ancbuf, maxdata, mindata=0):
# mindata and maxdata bytes when received with buffer size
# ancbuf, and that any complete file descriptor numbers are
# valid.
msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
len(MSG), ancbuf)
with catch_malformed_data_warning():
msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
len(MSG), ancbuf)
self.assertEqual(msg, MSG)
self.checkRecvmsgAddress(addr, self.cli_addr)
self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC)
Expand Down Expand Up @@ -4298,8 +4310,9 @@ def testSingleCmsgTruncInData(self):
self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
socket.IPV6_RECVHOPLIMIT, 1)
self.misc_event.set()
msg, ancdata, flags, addr = self.doRecvmsg(
self.serv_sock, len(MSG), socket.CMSG_LEN(SIZEOF_INT) - 1)
with catch_malformed_data_warning():
msg, ancdata, flags, addr = self.doRecvmsg(
self.serv_sock, len(MSG), socket.CMSG_LEN(SIZEOF_INT) - 1)

self.assertEqual(msg, MSG)
self.checkRecvmsgAddress(addr, self.cli_addr)
Expand Down Expand Up @@ -4402,9 +4415,10 @@ def testSecondCmsgTruncInData(self):
self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
socket.IPV6_RECVTCLASS, 1)
self.misc_event.set()
msg, ancdata, flags, addr = self.doRecvmsg(
self.serv_sock, len(MSG),
socket.CMSG_SPACE(SIZEOF_INT) + socket.CMSG_LEN(SIZEOF_INT) - 1)
with catch_malformed_data_warning():
msg, ancdata, flags, addr = self.doRecvmsg(
self.serv_sock, len(MSG),
socket.CMSG_SPACE(SIZEOF_INT) + socket.CMSG_LEN(SIZEOF_INT) - 1)

self.assertEqual(msg, MSG)
self.checkRecvmsgAddress(addr, self.cli_addr)
Expand Down
Loading