From fe46cd691cde2a707ea727d65414cf7285969bbe Mon Sep 17 00:00:00 2001 From: Marten van Kerkwijk Date: Fri, 6 Jan 2017 17:12:26 -0500 Subject: [PATCH] BUG assert_almost_equal fails on subclasses that cannot handle bool gh-8410 breaks a large number of astropy tests, because it sets up a boolean array for values that should actually be compared (i.e., are not `nan` or `inf`) using `zeros_like`. The latter means that for subclasses, the boolean test array is not a plain `ndarray` but the subclass. But for astropy's `Quantity`, the `all` method is undefined. This commit ensures the test arrays from `isinf` and `isnan` are used directly. --- numpy/testing/tests/test_utils.py | 36 ++++++++++++++++++++++++++++--- numpy/testing/utils.py | 21 ++++++------------ 2 files changed, 40 insertions(+), 17 deletions(-) diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index 4ca6c63543e1..a05fc3bdb30f 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -299,9 +299,24 @@ def test_subclass(self): a = np.array([[1., 2.], [3., 4.]]) b = np.ma.masked_array([[1., 2.], [0., 4.]], [[False, False], [True, False]]) - assert_array_almost_equal(a, b) - assert_array_almost_equal(b, a) - assert_array_almost_equal(b, b) + self._assert_func(a, b) + self._assert_func(b, a) + self._assert_func(b, b) + + def test_subclass_that_cannot_be_bool(self): + # While we cannot guarantee testing functions will always work for + # subclasses, the tests should ideally rely only on subclasses having + # comparison operators, not on them being able to store booleans + # (which, e.g., astropy Quantity cannot usefully do). See gh-8452. + class MyArray(np.ndarray): + def __lt__(self, other): + return super(MyArray, self).__lt__(other).view(np.ndarray) + + def all(self, *args, **kwargs): + raise NotImplementedError + + a = np.array([1., 2.]).view(MyArray) + self._assert_func(a, a) class TestAlmostEqual(_GenericTest, unittest.TestCase): @@ -387,6 +402,21 @@ def test_error_message(self): # remove anything that's not the array string self.assertEqual(str(e).split('%)\n ')[1], b) + def test_subclass_that_cannot_be_bool(self): + # While we cannot guarantee testing functions will always work for + # subclasses, the tests should ideally rely only on subclasses having + # comparison operators, not on them being able to store booleans + # (which, e.g., astropy Quantity cannot usefully do). See gh-8452. + class MyArray(np.ndarray): + def __lt__(self, other): + return super(MyArray, self).__lt__(other).view(np.ndarray) + + def all(self, *args, **kwargs): + raise NotImplementedError + + a = np.array([1., 2.]).view(MyArray) + self._assert_func(a, a) + class TestApproxEqual(unittest.TestCase): diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index 7858eefac625..a44a51c812b6 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -669,8 +669,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='', precision=6, equal_nan=True, equal_inf=True): __tracebackhide__ = True # Hide traceback for py.test - from numpy.core import array, isnan, isinf, any, all, inf, zeros_like - from numpy.core.numerictypes import bool_ + from numpy.core import array, isnan, isinf, any, inf x = array(x, copy=False, subok=True) y = array(y, copy=False, subok=True) @@ -726,14 +725,13 @@ def chk_same_position(x_id, y_id, hasval='nan'): raise AssertionError(msg) if isnumber(x) and isnumber(y): - x_id, y_id = zeros_like(x, dtype=bool_), zeros_like(y, dtype=bool_) if equal_nan: x_isnan, y_isnan = isnan(x), isnan(y) # Validate that NaNs are in the same place if any(x_isnan) or any(y_isnan): chk_same_position(x_isnan, y_isnan, hasval='nan') - x_id |= x_isnan - y_id |= y_isnan + x = x[~x_isnan] + y = y[~y_isnan] if equal_inf: x_isinf, y_isinf = isinf(x), isinf(y) @@ -742,19 +740,14 @@ def chk_same_position(x_id, y_id, hasval='nan'): # Check +inf and -inf separately, since they are different chk_same_position(x == +inf, y == +inf, hasval='+inf') chk_same_position(x == -inf, y == -inf, hasval='-inf') - x_id |= x_isinf - y_id |= y_isinf + x = x[~x_isinf] + y = y[~y_isinf] # Only do the comparison if actual values are left - if all(x_id): + if x.size == 0: return - if any(x_id): - val = safe_comparison(x[~x_id], y[~y_id]) - else: - val = safe_comparison(x, y) - else: - val = safe_comparison(x, y) + val = safe_comparison(x, y) if isinstance(val, bool): cond = val