Skip to content

MNT Use a pytest-like context manager in estimator_checks.py #18418

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
Sep 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,9 @@ def test_strict_mode_check_estimator():

# MyNMF will fail check_fit_non_negative() in strict mode because it yields
# a bad error message
with pytest.raises(AssertionError, match='does not match'):
with pytest.raises(
AssertionError, match="The error message should contain"
):
check_estimator(MyNMFWithBadErrorMessage(), strict_mode=True)
# However, it should pass the test suite in non-strict mode because when
# strict mode is off, check_fit_non_negative() will not check the exact
Expand Down
93 changes: 93 additions & 0 deletions sklearn/utils/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
import tempfile
from subprocess import check_output, STDOUT, CalledProcessError
from subprocess import TimeoutExpired
import re
import contextlib
from collections.abc import Iterable

import scipy as sp
from functools import wraps
Expand Down Expand Up @@ -769,3 +772,93 @@ def _convert_container(container, constructor_name, columns_name=None):
return pd.Index(container)
elif constructor_name == 'slice':
return slice(container[0], container[1])


def raises(expected_exc_type, match=None, may_pass=False, err_msg=None):
"""Context manager to ensure exceptions are raised within a code block.

This is similar to and inspired from pytest.raises, but supports a few
other cases.

This is only intended to be used in estimator_checks.py where we don't
want to use pytest. In the rest of the code base, just use pytest.raises
instead.

Parameters
----------
excepted_exc_type : Exception or list of Exception
The exception that should be raised by the block. If a list, the block
should raise one of the exceptions.
match : str or list of str, default=None
A regex that the exception message should match. If a list, one of
the entries must match. If None, match isn't enforced.
may_pass : bool, default=False
If True, the block is allowed to not raise an exception. Useful in
cases where some estimators may support a feature but others must
fail with an appropriate error message. By default, the context
manager will raise an exception if the block does not raise an
exception.
err_msg : str, default=None
If the context manager fails (e.g. the block fails to raise the
proper exception, or fails to match), then an AssertionError is
raised with this message. By default, an AssertionError is raised
with a default error message (depends on the kind of failure). Use
this to indicate how users should fix their estimators to pass the
checks.

Attributes
----------
raised_and_matched : bool
True if an exception was raised and a match was found, False otherwise.
"""
return _Raises(expected_exc_type, match, may_pass, err_msg)


class _Raises(contextlib.AbstractContextManager):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you explore the contextlib.contextmanager implementation and found this to be cleaner?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't believe @contextlib.contextmanager allows to define the __exit__ function, which is the core of the CM for catching exceptions

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For ref this is the pytest implem: https://github.com/pytest-dev/pytest/blob/master/src/_pytest/python_api.py#L702 (I kept it quite similar)

# see raises() for parameters
def __init__(self, expected_exc_type, match, may_pass, err_msg):
self.expected_exc_types = (
expected_exc_type
if isinstance(expected_exc_type, Iterable)
else [expected_exc_type]
)
self.matches = [match] if isinstance(match, str) else match
self.may_pass = may_pass
self.err_msg = err_msg
self.raised_and_matched = False

def __exit__(self, exc_type, exc_value, _):
# see
# https://docs.python.org/2.5/whatsnew/pep-343.html#SECTION000910000000000000000

if exc_type is None: # No exception was raised in the block
if self.may_pass:
return True # CM is happy
else:
err_msg = (
self.err_msg or f"Did not raise: {self.expected_exc_types}"
)
raise AssertionError(err_msg)

if not any(
issubclass(exc_type, expected_type)
for expected_type in self.expected_exc_types
):
if self.err_msg is not None:
raise AssertionError(self.err_msg) from exc_value
else:
return False # will re-raise the original exception

if self.matches is not None:
err_msg = self.err_msg or (
"The error message should contain one of the following "
"patterns:\n{}\nGot {}".format(
"\n".join(self.matches), str(exc_value)
)
)
if not any(re.search(match, str(exc_value))
for match in self.matches):
raise AssertionError(err_msg) from exc_value
self.raised_and_matched = True

return True
Loading