Skip to content

Commit 5a77e3d

Browse files
committed
fix #488 : added rtol and atol arguments to LArray.equals in order to test equality between two arrays within a relative or absolute tolerance
1 parent 34355aa commit 5a77e3d

File tree

2 files changed

+61
-5
lines changed

2 files changed

+61
-5
lines changed

doc/source/changes/version_0_28.rst.inc

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,29 @@ Miscellaneous improvements
231231

232232
>>> arr = read_csv('arr.csv', nb_axes=3)
233233

234-
Closes :issue:`548`:
234+
Closes :issue:`548`.
235+
236+
* added the relative tolerance `rtol` and the absolute tolerance `atol` arguments to the `LArray.equals` method.
237+
These two arguments can be used to test the equality between two arrays within a given relative or
238+
absolute tolerance:
239+
240+
>>> arr1 = LArray([6., 8.], "a=a0,a1")
241+
>>> arr1
242+
a a0 a1
243+
6.0 8.0
244+
>>> arr2 = LArray([5.999, 8.001], "a=a0,a1")
245+
>>> arr2
246+
a a0 a1
247+
5.999 8.001
248+
>>> arr1.equals(arr2)
249+
False
250+
>>> # equals returns True if abs(array1 - array2) <= (atol + rtol * abs(array2))
251+
>>> arr1.equals(arr2, atol=0.01)
252+
True
253+
>>> arr1.equals(arr2, rtol=0.01)
254+
True
255+
256+
Closes :issue:`488`.
235257

236258
* renamed argument `transpose` by `wide` in `to_csv` method.
237259

larray/core/array.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5225,14 +5225,18 @@ def __int__(self):
52255225
def __float__(self):
52265226
return self.data.__float__()
52275227

5228-
def equals(self, other, nan_equals=False):
5228+
def equals(self, other, rtol=0, atol=0, nan_equals=False):
52295229
"""
52305230
Compares self with another array and returns True if they have the same axes and elements, False otherwise.
52315231
52325232
Parameters
52335233
----------
52345234
other: LArray-like
52355235
Input array. aslarray() is used on a non-LArray input.
5236+
rtol : float or int, optional
5237+
The relative tolerance parameter (see Notes). Defaults to 0.
5238+
atol : float or int, optional
5239+
The absolute tolerance parameter (see Notes). Defaults to 0.
52365240
nan_equals: boolean, optional
52375241
Whether or not to consider nan values at the same positions in the two arrays as equal.
52385242
By default, an array containing nan values is never equal to another array, even if that other array
@@ -5244,6 +5248,15 @@ def equals(self, other, nan_equals=False):
52445248
bool
52455249
Returns True if self is equal to other.
52465250
5251+
Notes
5252+
-----
5253+
For finite values, equals uses the following equation to test whether two values are equal:
5254+
5255+
absolute(array1 - array2) <= (atol + rtol * absolute(array2))
5256+
5257+
The above equation is not symmetric in array1 and array2, so that equals(array1, array2) might be different
5258+
from equals(array2, array1) in some rare cases.
5259+
52475260
Examples
52485261
--------
52495262
>>> arr1 = ndtest((2, 3))
@@ -5261,6 +5274,24 @@ def equals(self, other, nan_equals=False):
52615274
>>> arr1.equals(arr3)
52625275
False
52635276
5277+
Test equality between two arrays within a given tolerance range.
5278+
Return True if absolute(array1 - array2) <= (atol + rtol * absolute(array2)).
5279+
5280+
>>> arr1 = LArray([6., 8.], "a=a0,a1")
5281+
>>> arr1
5282+
a a0 a1
5283+
6.0 8.0
5284+
>>> arr2 = LArray([5.999, 8.001], "a=a0,a1")
5285+
>>> arr2
5286+
a a0 a1
5287+
5.999 8.001
5288+
>>> arr1.equals(arr2)
5289+
False
5290+
>>> arr1.equals(arr2, atol=0.01)
5291+
True
5292+
>>> arr1.equals(arr2, rtol=0.01)
5293+
True
5294+
52645295
Arrays with nan values
52655296
52665297
>>> arr1 = ndtest((2, 3), dtype=float)
@@ -5283,10 +5314,13 @@ def equals(self, other, nan_equals=False):
52835314
other = aslarray(other)
52845315
except Exception:
52855316
return False
5286-
if nan_equals:
5287-
return self.axes == other.axes and all(nan_equal(self, other))
5317+
if rtol == 0 and atol == 0:
5318+
if nan_equals:
5319+
return self.axes == other.axes and all(nan_equal(self, other))
5320+
else:
5321+
return self.axes == other.axes and np.array_equal(np.asarray(self), np.asarray(other))
52885322
else:
5289-
return self.axes == other.axes and np.array_equal(np.asarray(self), np.asarray(other))
5323+
return self.axes == other.axes and np.allclose(np.asarray(self), np.asarray(other), rtol, atol, nan_equals)
52905324

52915325
def divnot0(self, other):
52925326
"""Divides array by other, but returns 0.0 where other is 0.

0 commit comments

Comments
 (0)