@@ -209,18 +209,14 @@ def build_err_msg(arrays, err_msg, header='Items are not equal:',
209
209
return '\n ' .join (msg )
210
210
211
211
212
- def assert_equal (actual , desired , err_msg = '' , verbose = True ):
212
+ def assert_equal (actual , desired , err_msg = '' , verbose = True , * , strict = False ):
213
213
"""
214
214
Raises an AssertionError if two objects are not equal.
215
215
216
216
Given two objects (scalars, lists, tuples, dictionaries or numpy arrays),
217
217
check that all elements of these objects are equal. An exception is raised
218
218
at the first conflicting values.
219
219
220
- When one of `actual` and `desired` is a scalar and the other is array_like,
221
- the function checks that each element of the array_like object is equal to
222
- the scalar.
223
-
224
220
This function handles NaN comparisons as if NaN was a "normal" number.
225
221
That is, AssertionError is not raised if both objects have NaNs in the same
226
222
positions. This is in contrast to the IEEE standard on NaNs, which says
@@ -236,15 +232,34 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
236
232
The error message to be printed in case of failure.
237
233
verbose : bool, optional
238
234
If True, the conflicting values are appended to the error message.
235
+ strict : bool, optional
236
+ If True and either of the `actual` and `desired` arguments is an array,
237
+ raise an ``AssertionError`` when either the shape or the data type of
238
+ the arguments does not match. If neither argument is an array, this
239
+ parameter has no effect.
240
+
241
+ .. versionadded:: 2.0.0
239
242
240
243
Raises
241
244
------
242
245
AssertionError
243
246
If actual and desired are not equal.
244
247
248
+ See Also
249
+ --------
250
+ assert_allclose
251
+ assert_array_almost_equal_nulp,
252
+ assert_array_max_ulp,
253
+
254
+ Notes
255
+ -----
256
+ By default, when one of `actual` and `desired` is a scalar and the other is
257
+ an array, the function checks that each element of the array is equal to
258
+ the scalar. This behaviour can be disabled by setting ``strict==True``.
259
+
245
260
Examples
246
261
--------
247
- >>> np.testing.assert_equal([4,5], [4,6])
262
+ >>> np.testing.assert_equal([4, 5], [4, 6])
248
263
Traceback (most recent call last):
249
264
...
250
265
AssertionError:
@@ -258,6 +273,40 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
258
273
259
274
>>> np.testing.assert_equal(np.array([1.0, 2.0, np.nan]), [1, 2, np.nan])
260
275
276
+ As mentioned in the Notes section, `assert_equal` has special
277
+ handling for scalars when one of the arguments is an array.
278
+ Here, the test checks that each value in `x` is 3:
279
+
280
+ >>> x = np.full((2, 5), fill_value=3)
281
+ >>> np.testing.assert_equal(x, 3)
282
+
283
+ Use `strict` to raise an AssertionError when comparing a scalar with an
284
+ array of a different shape:
285
+
286
+ >>> np.testing.assert_equal(x, 3, strict=True)
287
+ Traceback (most recent call last):
288
+ ...
289
+ AssertionError:
290
+ Arrays are not equal
291
+ <BLANKLINE>
292
+ (shapes (2, 5), () mismatch)
293
+ x: array([[3, 3, 3, 3, 3],
294
+ [3, 3, 3, 3, 3]])
295
+ y: array(3)
296
+
297
+ The `strict` parameter also ensures that the array data types match:
298
+
299
+ >>> x = np.array([2, 2, 2])
300
+ >>> y = np.array([2., 2., 2.], dtype=np.float32)
301
+ >>> np.testing.assert_equal(x, y, strict=True)
302
+ Traceback (most recent call last):
303
+ ...
304
+ AssertionError:
305
+ Arrays are not equal
306
+ <BLANKLINE>
307
+ (dtypes int64, float32 mismatch)
308
+ x: array([2, 2, 2])
309
+ y: array([2., 2., 2.], dtype=float32)
261
310
"""
262
311
__tracebackhide__ = True # Hide traceback for py.test
263
312
if isinstance (desired , dict ):
@@ -279,7 +328,8 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
279
328
from numpy .core import ndarray , isscalar , signbit
280
329
from numpy import iscomplexobj , real , imag
281
330
if isinstance (actual , ndarray ) or isinstance (desired , ndarray ):
282
- return assert_array_equal (actual , desired , err_msg , verbose )
331
+ return assert_array_equal (actual , desired , err_msg , verbose ,
332
+ strict = strict )
283
333
msg = build_err_msg ([actual , desired ], err_msg , verbose = verbose )
284
334
285
335
# Handle complex numbers: separate into real/imag to handle
0 commit comments