Skip to content

Commit 9ca4a93

Browse files
committed
fix #131: added absolute and relative tolerance when comparing float arrays
1 parent 332e7e6 commit 9ca4a93

File tree

2 files changed

+60
-11
lines changed

2 files changed

+60
-11
lines changed

larray_editor/comparator.py

+55-9
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
import ast
12
import numpy as np
23
from qtpy.QtCore import Qt
4+
from qtpy.QtGui import QDoubleValidator
35
from qtpy.QtWidgets import (QWidget, QVBoxLayout, QListWidget, QSplitter, QDialogButtonBox, QHBoxLayout,
4-
QDialog, QLabel, QCheckBox)
6+
QDialog, QLabel, QCheckBox, QLineEdit, QComboBox)
57

6-
from larray import (LArray, Session, Axis, X, stack, full, full_like, zeros_like,
7-
nan, isnan, larray_nan_equal, nan_equal)
8+
from larray import LArray, Session, Axis, stack, full_like, nan, larray_nan_equal, element_equal
89
from larray_editor.utils import ima, replace_inf, _
910
from larray_editor.arraywidget import ArrayEditorWidget
1011

@@ -16,6 +17,7 @@ def __init__(self, parent=None):
1617

1718
layout = QVBoxLayout()
1819
self.setLayout(layout)
20+
1921
# max diff label
2022
maxdiff_layout = QHBoxLayout()
2123
maxdiff_layout.addWidget(QLabel('maximum absolute relative difference:'))
@@ -26,36 +28,80 @@ def __init__(self, parent=None):
2628

2729
self.arraywidget = ArrayEditorWidget(self, data=None, readonly=True, bg_gradient='red-white-blue')
2830

31+
# show difference only
2932
diff_checkbox = QCheckBox(_('Differences Only'))
3033
diff_checkbox.stateChanged.connect(self.display)
3134
self.diff_checkbox = diff_checkbox
3235
self.arraywidget.btn_layout.addWidget(diff_checkbox)
3336

37+
# absolute/relative tolerance
38+
tolerance_layout = QHBoxLayout()
39+
tooltip = """Element i of two arrays are considered as equal if they satisfy the following equation:
40+
abs(array1[i] - array2[i]) <= (absolute_tol + relative_tol * abs(array2[i]))"""
41+
42+
tolerance_label = QLabel("tolerance:")
43+
tolerance_label.setToolTip(tooltip)
44+
self.arraywidget.btn_layout.addWidget(tolerance_label)
45+
46+
tolerance_combobox = QComboBox()
47+
tolerance_combobox.addItems(["absolute", "relative"])
48+
tolerance_combobox.setToolTip(tooltip)
49+
tolerance_combobox.currentTextChanged.connect(self.update_isequal)
50+
tolerance_layout.addWidget(tolerance_combobox)
51+
self.tolerance_combobox = tolerance_combobox
52+
53+
tolerance_line_edit = QLineEdit()
54+
tolerance_line_edit.setValidator(QDoubleValidator())
55+
tolerance_line_edit.setPlaceholderText("1e-8")
56+
tolerance_line_edit.setMaximumWidth(80)
57+
tolerance_line_edit.setToolTip("Press Enter to activate the new tolerance value")
58+
tolerance_line_edit.editingFinished.connect(self.update_isequal)
59+
tolerance_layout.addWidget(tolerance_line_edit)
60+
self.tolerance_line_edit = tolerance_line_edit
61+
62+
self.arraywidget.btn_layout.addLayout(tolerance_layout)
63+
64+
# add local arraywidget to layout
65+
self.arraywidget.btn_layout.addStretch()
3466
layout.addWidget(self.arraywidget)
3567

3668
self.array = None
69+
self.array0 = None
3770
self.isequal = None
3871
self.bg_value = None
3972
self.stack_axis = None
4073

74+
# override keyPressEvent to prevent pressing Enter after changing the tolerance value
75+
# in associated QLineEdit to close the parent dialog box
76+
def keyPressEvent(self, event):
77+
if event.key() == Qt.Key_Enter or event.key() == Qt.Key_Return:
78+
return
79+
QWidget.keyPressEvent(self, event)
80+
4181
def set_data(self, arrays, stack_axis):
4282
assert all(np.isscalar(a) or isinstance(a, LArray) for a in arrays)
4383
self.stack_axis = stack_axis
4484
try:
4585
self.array = stack(arrays, stack_axis)
46-
array0 = self.array[stack_axis.i[0]]
86+
self.array0 = self.array[stack_axis.i[0]]
4787
except Exception as e:
4888
self.array = LArray(str(e))
49-
array0 = self.array
89+
self.array0 = self.array
90+
self.update_isequal()
91+
92+
def update_isequal(self):
5093
try:
51-
self.isequal = nan_equal(self.array, array0)
94+
tol_str = self.tolerance_line_edit.text()
95+
tol = ast.literal_eval(tol_str) if tol_str else 0
96+
atol, rtol = (tol, 0) if self.tolerance_combobox.currentText() == "absolute" else (0, tol)
97+
self.isequal = element_equal(self.array, self.array0, rtol=rtol, atol=atol, nan_equals=True)
5298
except TypeError:
53-
self.isequal = self.array == array0
99+
self.isequal = self.array == self.array0
54100

55101
try:
56102
with np.errstate(divide='ignore', invalid='ignore'):
57-
diff = self.array - array0
58-
reldiff = diff / array0
103+
diff = self.array - self.array0
104+
reldiff = diff / self.array0
59105
# this is necessary for nan, inf and -inf (because inf - inf = nan, not 0)
60106
# this is more precise than divnot0, it only ignore 0 / 0, not x / 0
61107
reldiff[self.isequal] = 0

larray_editor/tests/test_api_larray.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import numpy as np
1010
from larray import (Session, Axis, LArray, ndtest, ndrange, zeros, from_lists, union,
11-
sin, cos, radians, maximum, sqrt)
11+
sin, cos, radians, maximum, sqrt, where)
1212

1313

1414
logger.setLevel(logging.DEBUG)
@@ -154,4 +154,7 @@ def make_demo(width=20, ball_radius=5, path_radius=5, steps=30):
154154
# print('\n open CSV')
155155
# edit('x_csv')
156156

157-
# compare(arr3, arr4, arr5, arr6)
157+
arr1 = ndtest((3, 3))
158+
arr2 = 2 * arr1
159+
arr3 = where(arr1 % 2 == 0, arr1, -arr1)
160+
compare(arr1, arr2, arr3)

0 commit comments

Comments
 (0)