diff --git a/doc/release/upcoming_changes/24770.new_feature.rst b/doc/release/upcoming_changes/24770.new_feature.rst new file mode 100644 index 000000000000..3e4779e607f1 --- /dev/null +++ b/doc/release/upcoming_changes/24770.new_feature.rst @@ -0,0 +1,5 @@ +``strict`` option for `testing.assert_equal` +-------------------------------------------- +The ``strict`` option is now available for `testing.assert_equal`. +Setting ``strict=True`` will disable the broadcasting behaviour for scalars +and ensure that input arrays have the same data type. \ No newline at end of file diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py index 4b6be79b75ba..bba359643d13 100644 --- a/numpy/testing/_private/utils.py +++ b/numpy/testing/_private/utils.py @@ -209,7 +209,7 @@ def build_err_msg(arrays, err_msg, header='Items are not equal:', return '\n'.join(msg) -def assert_equal(actual, desired, err_msg='', verbose=True): +def assert_equal(actual, desired, err_msg='', verbose=True, *, strict=False): """ Raises an AssertionError if two objects are not equal. @@ -217,10 +217,6 @@ def assert_equal(actual, desired, err_msg='', verbose=True): check that all elements of these objects are equal. An exception is raised at the first conflicting values. - When one of `actual` and `desired` is a scalar and the other is array_like, - the function checks that each element of the array_like object is equal to - the scalar. - This function handles NaN comparisons as if NaN was a "normal" number. That is, AssertionError is not raised if both objects have NaNs in the same positions. This is in contrast to the IEEE standard on NaNs, which says @@ -236,15 +232,34 @@ def assert_equal(actual, desired, err_msg='', verbose=True): The error message to be printed in case of failure. verbose : bool, optional If True, the conflicting values are appended to the error message. + strict : bool, optional + If True and either of the `actual` and `desired` arguments is an array, + raise an ``AssertionError`` when either the shape or the data type of + the arguments does not match. If neither argument is an array, this + parameter has no effect. + + .. versionadded:: 2.0.0 Raises ------ AssertionError If actual and desired are not equal. + See Also + -------- + assert_allclose + assert_array_almost_equal_nulp, + assert_array_max_ulp, + + Notes + ----- + By default, when one of `actual` and `desired` is a scalar and the other is + an array, the function checks that each element of the array is equal to + the scalar. This behaviour can be disabled by setting ``strict==True``. + Examples -------- - >>> np.testing.assert_equal([4,5], [4,6]) + >>> np.testing.assert_equal([4, 5], [4, 6]) Traceback (most recent call last): ... AssertionError: @@ -258,6 +273,40 @@ def assert_equal(actual, desired, err_msg='', verbose=True): >>> np.testing.assert_equal(np.array([1.0, 2.0, np.nan]), [1, 2, np.nan]) + As mentioned in the Notes section, `assert_equal` has special + handling for scalars when one of the arguments is an array. + Here, the test checks that each value in `x` is 3: + + >>> x = np.full((2, 5), fill_value=3) + >>> np.testing.assert_equal(x, 3) + + Use `strict` to raise an AssertionError when comparing a scalar with an + array of a different shape: + + >>> np.testing.assert_equal(x, 3, strict=True) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not equal + + (shapes (2, 5), () mismatch) + x: array([[3, 3, 3, 3, 3], + [3, 3, 3, 3, 3]]) + y: array(3) + + The `strict` parameter also ensures that the array data types match: + + >>> x = np.array([2, 2, 2]) + >>> y = np.array([2., 2., 2.], dtype=np.float32) + >>> np.testing.assert_equal(x, y, strict=True) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not equal + + (dtypes int64, float32 mismatch) + x: array([2, 2, 2]) + y: array([2., 2., 2.], dtype=float32) """ __tracebackhide__ = True # Hide traceback for py.test if isinstance(desired, dict): @@ -279,7 +328,8 @@ def assert_equal(actual, desired, err_msg='', verbose=True): from numpy.core import ndarray, isscalar, signbit from numpy import iscomplexobj, real, imag if isinstance(actual, ndarray) or isinstance(desired, ndarray): - return assert_array_equal(actual, desired, err_msg, verbose) + return assert_array_equal(actual, desired, err_msg, verbose, + strict=strict) msg = build_err_msg([actual, desired], err_msg, verbose=verbose) # Handle complex numbers: separate into real/imag to handle diff --git a/numpy/testing/_private/utils.pyi b/numpy/testing/_private/utils.pyi index 09fae5e2ff4d..648168fb5e00 100644 --- a/numpy/testing/_private/utils.pyi +++ b/numpy/testing/_private/utils.pyi @@ -167,6 +167,8 @@ def assert_equal( desired: object, err_msg: str = ..., verbose: bool = ..., + *, + strict: bool = ... ) -> None: ... def print_assert_equal( diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index 9d84edd104e8..11d0b577cfd9 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -232,14 +232,14 @@ def test_array_vs_scalar_strict(self): b = 1. with pytest.raises(AssertionError): - assert_array_equal(a, b, strict=True) + self._assert_func(a, b, strict=True) def test_array_vs_array_strict(self): """Test comparing two arrays with strict option.""" a = np.array([1., 1., 1.]) b = np.array([1., 1., 1.]) - assert_array_equal(a, b, strict=True) + self._assert_func(a, b, strict=True) def test_array_vs_float_array_strict(self): """Test comparing two arrays with strict option.""" @@ -247,7 +247,7 @@ def test_array_vs_float_array_strict(self): b = np.array([1., 1., 1.]) with pytest.raises(AssertionError): - assert_array_equal(a, b, strict=True) + self._assert_func(a, b, strict=True) class TestBuildErrorMessage: