Skip to content

BUG: assert_almost_equal fails on subclasses that cannot handle bool #8452

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 10, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 33 additions & 3 deletions numpy/testing/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,24 @@ def test_subclass(self):
a = np.array([[1., 2.], [3., 4.]])
b = np.ma.masked_array([[1., 2.], [0., 4.]],
[[False, False], [True, False]])
assert_array_almost_equal(a, b)
assert_array_almost_equal(b, a)
assert_array_almost_equal(b, b)
self._assert_func(a, b)
self._assert_func(b, a)
self._assert_func(b, b)

def test_subclass_that_cannot_be_bool(self):
# While we cannot guarantee testing functions will always work for
# subclasses, the tests should ideally rely only on subclasses having
# comparison operators, not on them being able to store booleans
# (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
class MyArray(np.ndarray):
def __lt__(self, other):
return super(MyArray, self).__lt__(other).view(np.ndarray)

def all(self, *args, **kwargs):
raise NotImplementedError

a = np.array([1., 2.]).view(MyArray)
self._assert_func(a, a)


class TestAlmostEqual(_GenericTest, unittest.TestCase):
Expand Down Expand Up @@ -387,6 +402,21 @@ def test_error_message(self):
# remove anything that's not the array string
self.assertEqual(str(e).split('%)\n ')[1], b)

def test_subclass_that_cannot_be_bool(self):
# While we cannot guarantee testing functions will always work for
# subclasses, the tests should ideally rely only on subclasses having
# comparison operators, not on them being able to store booleans
# (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
class MyArray(np.ndarray):
def __lt__(self, other):
return super(MyArray, self).__lt__(other).view(np.ndarray)

def all(self, *args, **kwargs):
raise NotImplementedError

a = np.array([1., 2.]).view(MyArray)
self._assert_func(a, a)


class TestApproxEqual(unittest.TestCase):

Expand Down
21 changes: 7 additions & 14 deletions numpy/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,8 +669,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
header='', precision=6, equal_nan=True,
equal_inf=True):
__tracebackhide__ = True # Hide traceback for py.test
from numpy.core import array, isnan, isinf, any, all, inf, zeros_like
from numpy.core.numerictypes import bool_
from numpy.core import array, isnan, isinf, any, inf
x = array(x, copy=False, subok=True)
y = array(y, copy=False, subok=True)

Expand Down Expand Up @@ -726,14 +725,13 @@ def chk_same_position(x_id, y_id, hasval='nan'):
raise AssertionError(msg)

if isnumber(x) and isnumber(y):
x_id, y_id = zeros_like(x, dtype=bool_), zeros_like(y, dtype=bool_)
if equal_nan:
x_isnan, y_isnan = isnan(x), isnan(y)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A more trivial change would replace just this line with

x_id = zeros(x.shape, dtype=bool_)
y_id = zeros(y.shape, dtype=bool_)

but it seems unnecessary to create these arrays in the first place -- and indeed the code becomes quite a bit shorter and, to me at least, clearer by just removing elements from the arrays.

# Validate that NaNs are in the same place
if any(x_isnan) or any(y_isnan):
chk_same_position(x_isnan, y_isnan, hasval='nan')
x_id |= x_isnan
y_id |= y_isnan
x = x[~x_isnan]
y = y[~y_isnan]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I am lazy now, but do these assert functions check the shape exactly? Because otherwise, I think there may be a problem with broadcasting if x and y are only broadcastable to one another.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the chk_same_position ensures that, when broadcast, x_isnan and y_isnan take out the same elements. As a result, x and y will still broadcast against each other.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, shape is checked exactly in any case, so no worries.


if equal_inf:
x_isinf, y_isinf = isinf(x), isinf(y)
Expand All @@ -742,19 +740,14 @@ def chk_same_position(x_id, y_id, hasval='nan'):
# Check +inf and -inf separately, since they are different
chk_same_position(x == +inf, y == +inf, hasval='+inf')
chk_same_position(x == -inf, y == -inf, hasval='-inf')
x_id |= x_isinf
y_id |= y_isinf
x = x[~x_isinf]
y = y[~y_isinf]

# Only do the comparison if actual values are left
if all(x_id):
if x.size == 0:
return

if any(x_id):
val = safe_comparison(x[~x_id], y[~y_id])
else:
val = safe_comparison(x, y)
else:
val = safe_comparison(x, y)
val = safe_comparison(x, y)

if isinstance(val, bool):
cond = val
Expand Down