Skip to content

Commit 5ca26b1

Browse files
authored
Merge pull request #24770 from mdhaber/gh24680b
ENH: add parameter `strict` to `assert_equal`
2 parents b310253 + cef6970 commit 5ca26b1

File tree

4 files changed

+67
-10
lines changed

4 files changed

+67
-10
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
``strict`` option for `testing.assert_equal`
2+
--------------------------------------------
3+
The ``strict`` option is now available for `testing.assert_equal`.
4+
Setting ``strict=True`` will disable the broadcasting behaviour for scalars
5+
and ensure that input arrays have the same data type.

numpy/testing/_private/utils.py

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -209,18 +209,14 @@ def build_err_msg(arrays, err_msg, header='Items are not equal:',
209209
return '\n'.join(msg)
210210

211211

212-
def assert_equal(actual, desired, err_msg='', verbose=True):
212+
def assert_equal(actual, desired, err_msg='', verbose=True, *, strict=False):
213213
"""
214214
Raises an AssertionError if two objects are not equal.
215215
216216
Given two objects (scalars, lists, tuples, dictionaries or numpy arrays),
217217
check that all elements of these objects are equal. An exception is raised
218218
at the first conflicting values.
219219
220-
When one of `actual` and `desired` is a scalar and the other is array_like,
221-
the function checks that each element of the array_like object is equal to
222-
the scalar.
223-
224220
This function handles NaN comparisons as if NaN was a "normal" number.
225221
That is, AssertionError is not raised if both objects have NaNs in the same
226222
positions. This is in contrast to the IEEE standard on NaNs, which says
@@ -236,15 +232,34 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
236232
The error message to be printed in case of failure.
237233
verbose : bool, optional
238234
If True, the conflicting values are appended to the error message.
235+
strict : bool, optional
236+
If True and either of the `actual` and `desired` arguments is an array,
237+
raise an ``AssertionError`` when either the shape or the data type of
238+
the arguments does not match. If neither argument is an array, this
239+
parameter has no effect.
240+
241+
.. versionadded:: 2.0.0
239242
240243
Raises
241244
------
242245
AssertionError
243246
If actual and desired are not equal.
244247
248+
See Also
249+
--------
250+
assert_allclose
251+
assert_array_almost_equal_nulp,
252+
assert_array_max_ulp,
253+
254+
Notes
255+
-----
256+
By default, when one of `actual` and `desired` is a scalar and the other is
257+
an array, the function checks that each element of the array is equal to
258+
the scalar. This behaviour can be disabled by setting ``strict==True``.
259+
245260
Examples
246261
--------
247-
>>> np.testing.assert_equal([4,5], [4,6])
262+
>>> np.testing.assert_equal([4, 5], [4, 6])
248263
Traceback (most recent call last):
249264
...
250265
AssertionError:
@@ -258,6 +273,40 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
258273
259274
>>> np.testing.assert_equal(np.array([1.0, 2.0, np.nan]), [1, 2, np.nan])
260275
276+
As mentioned in the Notes section, `assert_equal` has special
277+
handling for scalars when one of the arguments is an array.
278+
Here, the test checks that each value in `x` is 3:
279+
280+
>>> x = np.full((2, 5), fill_value=3)
281+
>>> np.testing.assert_equal(x, 3)
282+
283+
Use `strict` to raise an AssertionError when comparing a scalar with an
284+
array of a different shape:
285+
286+
>>> np.testing.assert_equal(x, 3, strict=True)
287+
Traceback (most recent call last):
288+
...
289+
AssertionError:
290+
Arrays are not equal
291+
<BLANKLINE>
292+
(shapes (2, 5), () mismatch)
293+
x: array([[3, 3, 3, 3, 3],
294+
[3, 3, 3, 3, 3]])
295+
y: array(3)
296+
297+
The `strict` parameter also ensures that the array data types match:
298+
299+
>>> x = np.array([2, 2, 2])
300+
>>> y = np.array([2., 2., 2.], dtype=np.float32)
301+
>>> np.testing.assert_equal(x, y, strict=True)
302+
Traceback (most recent call last):
303+
...
304+
AssertionError:
305+
Arrays are not equal
306+
<BLANKLINE>
307+
(dtypes int64, float32 mismatch)
308+
x: array([2, 2, 2])
309+
y: array([2., 2., 2.], dtype=float32)
261310
"""
262311
__tracebackhide__ = True # Hide traceback for py.test
263312
if isinstance(desired, dict):
@@ -279,7 +328,8 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
279328
from numpy.core import ndarray, isscalar, signbit
280329
from numpy import iscomplexobj, real, imag
281330
if isinstance(actual, ndarray) or isinstance(desired, ndarray):
282-
return assert_array_equal(actual, desired, err_msg, verbose)
331+
return assert_array_equal(actual, desired, err_msg, verbose,
332+
strict=strict)
283333
msg = build_err_msg([actual, desired], err_msg, verbose=verbose)
284334

285335
# Handle complex numbers: separate into real/imag to handle

numpy/testing/_private/utils.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,8 @@ def assert_equal(
167167
desired: object,
168168
err_msg: str = ...,
169169
verbose: bool = ...,
170+
*,
171+
strict: bool = ...
170172
) -> None: ...
171173

172174
def print_assert_equal(

numpy/testing/tests/test_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -232,22 +232,22 @@ def test_array_vs_scalar_strict(self):
232232
b = 1.
233233

234234
with pytest.raises(AssertionError):
235-
assert_array_equal(a, b, strict=True)
235+
self._assert_func(a, b, strict=True)
236236

237237
def test_array_vs_array_strict(self):
238238
"""Test comparing two arrays with strict option."""
239239
a = np.array([1., 1., 1.])
240240
b = np.array([1., 1., 1.])
241241

242-
assert_array_equal(a, b, strict=True)
242+
self._assert_func(a, b, strict=True)
243243

244244
def test_array_vs_float_array_strict(self):
245245
"""Test comparing two arrays with strict option."""
246246
a = np.array([1, 1, 1])
247247
b = np.array([1., 1., 1.])
248248

249249
with pytest.raises(AssertionError):
250-
assert_array_equal(a, b, strict=True)
250+
self._assert_func(a, b, strict=True)
251251

252252

253253
class TestBuildErrorMessage:

0 commit comments

Comments
 (0)