Skip to content

Commit cc5e65e

Browse files
committed
ENH/TST: add deadband to image comparison
This allows us to ignore small images in colors using a more precise threshold than tol provides.
1 parent 54d718e commit cc5e65e

File tree

5 files changed

+47
-17
lines changed

5 files changed

+47
-17
lines changed

lib/matplotlib/testing/compare.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -366,16 +366,20 @@ def crop_to_same(actual_path, actual_image, expected_path, expected_image):
366366
return actual_image, expected_image
367367

368368

369-
def calculate_rms(expected_image, actual_image):
369+
def calculate_rms(expected_image, actual_image, *, deadband=0):
370370
"""
371371
Calculate the per-pixel errors, then compute the root mean square error.
372372
"""
373373
if expected_image.shape != actual_image.shape:
374374
raise ImageComparisonFailure(
375375
f"Image sizes do not match expected size: {expected_image.shape} "
376376
f"actual size {actual_image.shape}")
377+
diff = expected_image - actual_image
378+
if deadband > 0:
379+
# ignore small color differences
380+
diff[np.abs(diff) <= deadband] = 0
377381
# Convert to float to avoid overflowing finite integer types.
378-
return np.sqrt(((expected_image - actual_image).astype(float) ** 2).mean())
382+
return np.sqrt(((diff).astype(float) ** 2).mean())
379383

380384

381385
# NOTE: compare_image and save_diff_image assume that the image does not have
@@ -392,7 +396,7 @@ def _load_image(path):
392396
return np.asarray(img)
393397

394398

395-
def compare_images(expected, actual, tol, in_decorator=False):
399+
def compare_images(expected, actual, tol, in_decorator=False, *, deadband=0):
396400
"""
397401
Compare two "image" files checking differences within a tolerance.
398402

lib/matplotlib/testing/compare.pyi

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,21 @@ def convert(filename: str, cache: bool) -> str: ...
1616
def crop_to_same(
1717
actual_path: str, actual_image: NDArray, expected_path: str, expected_image: NDArray
1818
) -> tuple[NDArray, NDArray]: ...
19-
def calculate_rms(expected_image: NDArray, actual_image: NDArray) -> float: ...
19+
def calculate_rms(expected_image: NDArray, actual_image: NDArray,
20+
*, deadband: int | None = ...) -> float: ...
2021
@overload
2122
def compare_images(
22-
expected: str, actual: str, tol: float, in_decorator: Literal[True]
23+
expected: str, actual: str, tol: float, in_decorator: Literal[True],
24+
*, deadband: int | None = ...
2325
) -> None | dict[str, float | str]: ...
2426
@overload
2527
def compare_images(
26-
expected: str, actual: str, tol: float, in_decorator: Literal[False]
28+
expected: str, actual: str, tol: float, in_decorator: Literal[False],
29+
*, deadband: int | None = ...
2730
) -> None | str: ...
2831
@overload
2932
def compare_images(
30-
expected: str, actual: str, tol: float, in_decorator: bool = ...
33+
expected: str, actual: str, tol: float, in_decorator: bool = ...,
34+
*, deadband: int | None = ...
3135
) -> None | str | dict[str, float | str]: ...
3236
def save_diff_image(expected: str, actual: str, output: str) -> None: ...

lib/matplotlib/testing/decorators.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,10 @@ def _collect_new_figures():
9797
new_figs[:] = [manager.canvas.figure for manager in new_managers]
9898

9999

100-
def _raise_on_image_difference(expected, actual, tol):
100+
def _raise_on_image_difference(expected, actual, tol, *, deadband):
101101
__tracebackhide__ = True
102102

103-
err = compare_images(expected, actual, tol, in_decorator=True)
103+
err = compare_images(expected, actual, tol, in_decorator=True, deadband=deadband)
104104
if err:
105105
for key in ["actual", "expected", "diff"]:
106106
err[key] = os.path.relpath(err[key])
@@ -117,12 +117,13 @@ class _ImageComparisonBase:
117117
any code that would be specific to any testing framework.
118118
"""
119119

120-
def __init__(self, func, tol, remove_text, savefig_kwargs):
120+
def __init__(self, func, tol, remove_text, savefig_kwargs, *, deadband=0):
121121
self.func = func
122122
self.baseline_dir, self.result_dir = _image_directories(func)
123123
self.tol = tol
124124
self.remove_text = remove_text
125125
self.savefig_kwargs = savefig_kwargs
126+
self.deadband = deadband
126127

127128
def copy_baseline(self, baseline, extension):
128129
baseline_path = self.baseline_dir / baseline
@@ -171,12 +172,14 @@ def compare(self, fig, baseline, extension, *, _lock=False):
171172
# makes things more convenient for third-party users.
172173
plt.close(fig)
173174
expected_path = self.copy_baseline(baseline, extension)
174-
_raise_on_image_difference(expected_path, actual_path, self.tol)
175+
_raise_on_image_difference(
176+
expected_path, actual_path, self.tol, deadband=self.deadband
177+
)
175178

176179

177180
def _pytest_image_comparison(baseline_images, extensions, tol,
178181
freetype_version, remove_text, savefig_kwargs,
179-
style):
182+
style, *, deadband=0):
180183
"""
181184
Decorate function with image comparison for pytest.
182185
@@ -260,7 +263,9 @@ def image_comparison(baseline_images, extensions=None, tol=0,
260263
freetype_version=None, remove_text=False,
261264
savefig_kwarg=None,
262265
# Default of mpl_test_settings fixture and cleanup too.
263-
style=("classic", "_classic_test_patch")):
266+
style=("classic", "_classic_test_patch"),
267+
*,
268+
deadband=0):
264269
"""
265270
Compare images generated by the test with those specified in
266271
*baseline_images*, which must correspond, else an `.ImageComparisonFailure`
@@ -315,6 +320,19 @@ def image_comparison(baseline_images, extensions=None, tol=0,
315320
The optional style(s) to apply to the image test. The test itself
316321
can also apply additional styles if desired. Defaults to ``["classic",
317322
"_classic_test_patch"]``.
323+
324+
deadband : int, default 0
325+
326+
Like *tol* this provides a way to allow slight changes in the images to
327+
pass.
328+
329+
The most common change between architectures is that float math or
330+
float-to-int may have slight differences in rounding that results in the
331+
value in an 8bit color channel to change by +/- 1.
332+
333+
The per-channel differences must be greater than deadband to contribute
334+
to the computed RMS.
335+
318336
"""
319337

320338
if baseline_images is not None:
@@ -346,7 +364,7 @@ def image_comparison(baseline_images, extensions=None, tol=0,
346364
savefig_kwargs=savefig_kwarg, style=style)
347365

348366

349-
def check_figures_equal(*, extensions=("png", "pdf", "svg"), tol=0):
367+
def check_figures_equal(*, extensions=("png", "pdf", "svg"), tol=0, deadband=0):
350368
"""
351369
Decorator for test cases that generate and compare two figures.
352370
@@ -420,7 +438,7 @@ def wrapper(*args, ext, request, **kwargs):
420438
fig_test.savefig(test_image_path)
421439
fig_ref.savefig(ref_image_path)
422440
_raise_on_image_difference(
423-
ref_image_path, test_image_path, tol=tol
441+
ref_image_path, test_image_path, tol=tol, deadband=deadband
424442
)
425443
finally:
426444
plt.close(fig_test)

lib/matplotlib/testing/decorators.pyi

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,11 @@ def image_comparison(
1818
remove_text: bool = ...,
1919
savefig_kwarg: dict[str, Any] | None = ...,
2020
style: RcStyleType = ...,
21+
*,
22+
deadband: int | None = ...
2123
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: ...
2224
def check_figures_equal(
23-
*, extensions: Sequence[str] = ..., tol: float = ...
25+
*, extensions: Sequence[str] = ..., tol: float = ...,
26+
deadband: int | None = ...
2427
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: ...
2528
def _image_directories(func: Callable) -> tuple[Path, Path]: ...

lib/mpl_toolkits/mplot3d/tests/test_axes3d.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ def test_axes3d_repr():
115115

116116

117117
@mpl3d_image_comparison(['axes3d_primary_views.png'], style='mpl20',
118-
tol=0.05 if platform.machine() == "arm64" else 0)
118+
tol=0.05 if platform.machine() == "arm64" else 0,
119+
deadband=1)
119120
def test_axes3d_primary_views():
120121
# (elev, azim, roll)
121122
views = [(90, -90, 0), # XY

0 commit comments

Comments
 (0)