Skip to content

Commit 48c0c3a

Browse files
committed
fix #172 : added arguments rtol, atol and nans_equal to the compare() function
1 parent 6d3fdd7 commit 48c0c3a

File tree

4 files changed

+53
-35
lines changed

4 files changed

+53
-35
lines changed

doc/source/changes/version_0_32.rst.inc

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,12 @@
11
.. py:currentmodule:: larray_editor
22

3-
Syntax changes
4-
^^^^^^^^^^^^^^
5-
6-
* renamed ``MappingEditor.old_method_name()`` to :py:obj:`MappingEditor.new_method_name()` (closes :editor_issue:`1`).
7-
8-
* renamed ``old_argument_name`` argument of :py:obj:`MappingEditor.method_name()` to ``new_argument_name``.
9-
10-
11-
Backward incompatible changes
12-
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
13-
14-
* other backward incompatible changes
15-
16-
17-
New features
18-
^^^^^^^^^^^^
19-
20-
* added a feature (see the :ref:`miscellaneous section <misc_editor>` for details).
21-
22-
* added another feature in the editor (closes :editor_issue:`1`).
23-
24-
.. note::
25-
26-
- It works for foo bar !
27-
- It does not work for foo baz !
28-
29-
303
.. _misc_editor:
314

325
Miscellaneous improvements
336
^^^^^^^^^^^^^^^^^^^^^^^^^^
347

35-
* improved something.
8+
* added keyword arguments ``rtol``, ``atol`` and ``nans_equal`` to the :py:obj:`compare()` function
9+
(closes :editor_issue:`172`).
3610

3711

3812
Fixes

larray_editor/api.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,22 @@ def compare(*args, **kwargs):
200200
display_caller_info: bool, optional
201201
Whether or not to display the filename and line number where the Editor has been called.
202202
Defaults to True.
203+
rtol : float or int, optional
204+
The relative tolerance parameter (see Notes). Defaults to 0.
205+
atol : float or int, optional
206+
The absolute tolerance parameter (see Notes). Defaults to 0.
207+
nans_equal : boolean, optional
208+
Whether or not to consider NaN values at the same positions in the two arrays as equal.
209+
By default, an array containing NaN values is never equal to another array, even if that other array
210+
also contains NaN values at the same positions. The reason is that a NaN value is different from
211+
*anything*, including itself. Defaults to True.
212+
213+
Notes
214+
-----
215+
For finite values, the following equation is used to test whether two values are equal:
216+
217+
absolute(array1 - array2) <= (atol + rtol * absolute(array2))
218+
203219
204220
Examples
205221
--------
@@ -214,6 +230,7 @@ def compare(*args, **kwargs):
214230
names = kwargs.pop('names', None)
215231
depth = kwargs.pop('depth', 0)
216232
display_caller_info = kwargs.pop('display_caller_info', True)
233+
217234
_app = QApplication.instance()
218235
if _app is None:
219236
_app = qapplication()

larray_editor/comparator.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
class ComparatorWidget(QWidget):
1515
"""Comparator Widget"""
16-
def __init__(self, parent=None, bg_gradient='red-white-blue'):
16+
def __init__(self, parent=None, bg_gradient='red-white-blue', rtol=0, atol=0, nans_equal=True, **kwargs):
1717
QWidget.__init__(self, parent)
1818

1919
layout = QVBoxLayout()
@@ -62,6 +62,17 @@ def __init__(self, parent=None, bg_gradient='red-white-blue'):
6262

6363
self.arraywidget.btn_layout.addLayout(tolerance_layout)
6464

65+
if rtol > 0 and atol > 0:
66+
raise ValueError("Arguments 'rtol' and 'atol' cannot be used together.")
67+
if rtol > 0:
68+
self.tolerance_combobox.setCurrentText("relative")
69+
self.tolerance_line_edit.setText(str(rtol))
70+
if atol > 0:
71+
self.tolerance_combobox.setCurrentText("absolute")
72+
self.tolerance_line_edit.setText(str(atol))
73+
74+
self.nans_equal = nans_equal
75+
6576
# add local arraywidget to layout
6677
self.arraywidget.btn_layout.addStretch()
6778
layout.addWidget(self.arraywidget)
@@ -104,7 +115,7 @@ def update_isequal(self):
104115
tol_str = self.tolerance_line_edit.text()
105116
tol = ast.literal_eval(tol_str) if tol_str else 0
106117
atol, rtol = (tol, 0) if self.tolerance_combobox.currentText() == "absolute" else (0, tol)
107-
self.isequal = self.array.eq(self.array0, rtol=rtol, atol=atol, nans_equal=True)
118+
self.isequal = self.array.eq(self.array0, rtol=rtol, atol=atol, nans_equal=self.nans_equal)
108119
except TypeError:
109120
self.isequal = self.array == self.array0
110121

@@ -171,16 +182,19 @@ def _setup_and_check(self, widget, data, title, readonly, **kwargs):
171182
readonly: bool
172183
kwargs:
173184
185+
* rtol: int or float
186+
* atol: int or float
187+
* nans_equal: bool
188+
* bg_gradient: str
174189
* names: list of str
175190
"""
176191
arrays = [aslarray(array) for array in data if isinstance(array, DISPLAY_IN_GRID)]
177192
names = kwargs.get('names', ["Array{}".format(i) for i in range(len(arrays))])
178-
bg_gradient = kwargs.get('bg_gradient', 'red-white-blue')
179193

180194
layout = QVBoxLayout()
181195
widget.setLayout(layout)
182196

183-
comparator_widget = ComparatorWidget(self, bg_gradient=bg_gradient)
197+
comparator_widget = ComparatorWidget(self, **kwargs)
184198
comparator_widget.set_data(arrays, Axis(names, 'array'))
185199
layout.addWidget(comparator_widget)
186200

@@ -213,12 +227,15 @@ def _setup_and_check(self, widget, data, title, readonly, **kwargs):
213227
readonly: bool
214228
kwargs:
215229
230+
* rtol: int or float
231+
* atol: int or float
232+
* nans_equal: bool
233+
* bg_gradient: str
216234
* names: list of str
217235
* colors: str
218236
"""
219237
sessions = data
220238
names = kwargs.get('names', ["Session{}".format(i) for i in range(len(sessions))])
221-
bg_gradient = kwargs.get('bg_gradient', 'red-white-blue')
222239

223240
assert all(isinstance(s, Session) for s in sessions)
224241
self.sessions = sessions
@@ -238,7 +255,7 @@ def _setup_and_check(self, widget, data, title, readonly, **kwargs):
238255
listwidget.item(i).setForeground(Qt.red)
239256
self.listwidget = listwidget
240257

241-
comparatorwidget = ComparatorWidget(self, bg_gradient=bg_gradient)
258+
comparatorwidget = ComparatorWidget(self, **kwargs)
242259
self.arraywidget = comparatorwidget
243260

244261
main_splitter = QSplitter(Qt.Horizontal)

larray_editor/tests/test_api_larray.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66

77
import numpy as np
8-
from larray import (Session, Axis, LArray, ndtest, zeros, from_lists, union,
8+
from larray import (Session, Axis, LArray, ndtest, zeros, ones, from_lists, union,
99
sin, cos, radians, maximum, sqrt, where, random, inf, nan)
1010

1111
from larray_editor.api import *
@@ -157,6 +157,16 @@ def make_demo(width=20, ball_radius=5, path_radius=5, steps=30):
157157
# print('\n open CSV')
158158
# edit('x_csv')
159159

160+
arr1 = ones((geo, sex))
161+
arr2 = random.normal(axes=(geo, sex))
162+
compare(arr1, arr2, atol=0.5)
163+
compare(arr1, arr2, rtol=0.3)
164+
165+
arr2 = where(arr2 > 1, arr1, -arr1)
166+
arr1['M'] = nan
167+
arr2['M'] = nan
168+
compare(arr1, arr2, nans_equal=False)
169+
160170
arr1 = ndtest((3, 3))
161171
arr2 = 2 * arr1
162172
arr3 = where(arr1 % 2 == 0, arr1, -arr1)

0 commit comments

Comments
 (0)