Skip to content

ENH/TST: add deadband to image comparison #28923

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions lib/matplotlib/testing/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,16 +366,20 @@
return actual_image, expected_image


def calculate_rms(expected_image, actual_image):
def calculate_rms(expected_image, actual_image, *, deadband=0):
"""
Calculate the per-pixel errors, then compute the root mean square error.
"""
if expected_image.shape != actual_image.shape:
raise ImageComparisonFailure(
f"Image sizes do not match expected size: {expected_image.shape} "
f"actual size {actual_image.shape}")
diff = expected_image - actual_image
if deadband > 0:
# ignore small color differences
diff[np.abs(diff) <= deadband] = 0

Check warning on line 380 in lib/matplotlib/testing/compare.py

View check run for this annotation

Codecov / codecov/patch

lib/matplotlib/testing/compare.py#L380

Added line #L380 was not covered by tests
# Convert to float to avoid overflowing finite integer types.
return np.sqrt(((expected_image - actual_image).astype(float) ** 2).mean())
return np.sqrt(((diff).astype(float) ** 2).mean())


# NOTE: compare_image and save_diff_image assume that the image does not have
Expand All @@ -392,7 +396,7 @@
return np.asarray(img)


def compare_images(expected, actual, tol, in_decorator=False):
def compare_images(expected, actual, tol, in_decorator=False, *, deadband=0):
"""
Compare two "image" files checking differences within a tolerance.

Expand Down
12 changes: 8 additions & 4 deletions lib/matplotlib/testing/compare.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,21 @@ def convert(filename: str, cache: bool) -> str: ...
def crop_to_same(
actual_path: str, actual_image: NDArray, expected_path: str, expected_image: NDArray
) -> tuple[NDArray, NDArray]: ...
def calculate_rms(expected_image: NDArray, actual_image: NDArray) -> float: ...
def calculate_rms(expected_image: NDArray, actual_image: NDArray,
*, deadband: int | None = ...) -> float: ...
@overload
def compare_images(
expected: str, actual: str, tol: float, in_decorator: Literal[True]
expected: str, actual: str, tol: float, in_decorator: Literal[True],
*, deadband: int | None = ...
) -> None | dict[str, float | str]: ...
@overload
def compare_images(
expected: str, actual: str, tol: float, in_decorator: Literal[False]
expected: str, actual: str, tol: float, in_decorator: Literal[False],
*, deadband: int | None = ...
) -> None | str: ...
@overload
def compare_images(
expected: str, actual: str, tol: float, in_decorator: bool = ...
expected: str, actual: str, tol: float, in_decorator: bool = ...,
*, deadband: int | None = ...
) -> None | str | dict[str, float | str]: ...
def save_diff_image(expected: str, actual: str, output: str) -> None: ...
34 changes: 26 additions & 8 deletions lib/matplotlib/testing/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,10 @@ def _collect_new_figures():
new_figs[:] = [manager.canvas.figure for manager in new_managers]


def _raise_on_image_difference(expected, actual, tol):
def _raise_on_image_difference(expected, actual, tol, *, deadband):
__tracebackhide__ = True

err = compare_images(expected, actual, tol, in_decorator=True)
err = compare_images(expected, actual, tol, in_decorator=True, deadband=deadband)
if err:
for key in ["actual", "expected", "diff"]:
err[key] = os.path.relpath(err[key])
Expand All @@ -117,12 +117,13 @@ class _ImageComparisonBase:
any code that would be specific to any testing framework.
"""

def __init__(self, func, tol, remove_text, savefig_kwargs):
def __init__(self, func, tol, remove_text, savefig_kwargs, *, deadband=0):
self.func = func
self.baseline_dir, self.result_dir = _image_directories(func)
self.tol = tol
self.remove_text = remove_text
self.savefig_kwargs = savefig_kwargs
self.deadband = deadband

def copy_baseline(self, baseline, extension):
baseline_path = self.baseline_dir / baseline
Expand Down Expand Up @@ -171,12 +172,14 @@ def compare(self, fig, baseline, extension, *, _lock=False):
# makes things more convenient for third-party users.
plt.close(fig)
expected_path = self.copy_baseline(baseline, extension)
_raise_on_image_difference(expected_path, actual_path, self.tol)
_raise_on_image_difference(
expected_path, actual_path, self.tol, deadband=self.deadband
)


def _pytest_image_comparison(baseline_images, extensions, tol,
freetype_version, remove_text, savefig_kwargs,
style):
style, *, deadband=0):
"""
Decorate function with image comparison for pytest.

Expand Down Expand Up @@ -260,7 +263,9 @@ def image_comparison(baseline_images, extensions=None, tol=0,
freetype_version=None, remove_text=False,
savefig_kwarg=None,
# Default of mpl_test_settings fixture and cleanup too.
style=("classic", "_classic_test_patch")):
style=("classic", "_classic_test_patch"),
*,
deadband=0):
"""
Compare images generated by the test with those specified in
*baseline_images*, which must correspond, else an `.ImageComparisonFailure`
Expand Down Expand Up @@ -315,6 +320,19 @@ def image_comparison(baseline_images, extensions=None, tol=0,
The optional style(s) to apply to the image test. The test itself
can also apply additional styles if desired. Defaults to ``["classic",
"_classic_test_patch"]``.

deadband : int, default 0

Like *tol* this provides a way to allow slight changes in the images to
pass.

The most common change between architectures is that float math or
float-to-int may have slight differences in rounding that results in the
value in an 8bit color channel to change by +/- 1.

The per-channel differences must be greater than deadband to contribute
to the computed RMS.

"""

if baseline_images is not None:
Expand Down Expand Up @@ -346,7 +364,7 @@ def image_comparison(baseline_images, extensions=None, tol=0,
savefig_kwargs=savefig_kwarg, style=style)


def check_figures_equal(*, extensions=("png", "pdf", "svg"), tol=0):
def check_figures_equal(*, extensions=("png", "pdf", "svg"), tol=0, deadband=0):
"""
Decorator for test cases that generate and compare two figures.

Expand Down Expand Up @@ -420,7 +438,7 @@ def wrapper(*args, ext, request, **kwargs):
fig_test.savefig(test_image_path)
fig_ref.savefig(ref_image_path)
_raise_on_image_difference(
ref_image_path, test_image_path, tol=tol
ref_image_path, test_image_path, tol=tol, deadband=deadband
)
finally:
plt.close(fig_test)
Expand Down
5 changes: 4 additions & 1 deletion lib/matplotlib/testing/decorators.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ def image_comparison(
remove_text: bool = ...,
savefig_kwarg: dict[str, Any] | None = ...,
style: RcStyleType = ...,
*,
deadband: int | None = ...
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: ...
def check_figures_equal(
*, extensions: Sequence[str] = ..., tol: float = ...
*, extensions: Sequence[str] = ..., tol: float = ...,
deadband: int | None = ...
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: ...
def _image_directories(func: Callable) -> tuple[Path, Path]: ...
3 changes: 2 additions & 1 deletion lib/mpl_toolkits/mplot3d/tests/test_axes3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ def test_axes3d_repr():


@mpl3d_image_comparison(['axes3d_primary_views.png'], style='mpl20',
tol=0.05 if platform.machine() == "arm64" else 0)
tol=0.05 if platform.machine() == "arm64" else 0,
deadband=1)
def test_axes3d_primary_views():
# (elev, azim, roll)
views = [(90, -90, 0), # XY
Expand Down
Loading