Skip to content

Commit 3ee24e5

Browse files
authored
Merge pull request #16797 from QuLogic/backport-checkfigeq
Backport #15589 and #16693, fixes for check_figures_equal
2 parents e935d85 + 5d69b1e commit 3ee24e5

File tree

1 file changed

+40
-23
lines changed

1 file changed

+40
-23
lines changed

lib/matplotlib/testing/decorators.py

+40-23
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import os
66
from pathlib import Path
77
import shutil
8+
import string
89
import sys
910
import unittest
1011
import warnings
@@ -17,7 +18,7 @@
1718
from matplotlib import ft2font
1819
from matplotlib import pyplot as plt
1920
from matplotlib import ticker
20-
from . import is_called_from_pytest
21+
2122
from .compare import comparable_formats, compare_images, make_test_filename
2223
from .exceptions import ImageComparisonFailure
2324

@@ -381,34 +382,50 @@ def test_plot(fig_test, fig_ref):
381382
fig_test.subplots().plot([1, 3, 5])
382383
fig_ref.subplots().plot([0, 1, 2], [1, 3, 5])
383384
"""
384-
POSITIONAL_OR_KEYWORD = inspect.Parameter.POSITIONAL_OR_KEYWORD
385+
ALLOWED_CHARS = set(string.digits + string.ascii_letters + '_-[]()')
386+
KEYWORD_ONLY = inspect.Parameter.KEYWORD_ONLY
385387
def decorator(func):
386388
import pytest
387389

388390
_, result_dir = _image_directories(func)
391+
old_sig = inspect.signature(func)
389392

390393
@pytest.mark.parametrize("ext", extensions)
391-
def wrapper(*args, ext, **kwargs):
392-
fig_test = plt.figure("test")
393-
fig_ref = plt.figure("reference")
394-
func(*args, fig_test=fig_test, fig_ref=fig_ref, **kwargs)
395-
test_image_path = result_dir / (func.__name__ + "." + ext)
396-
ref_image_path = result_dir / (
397-
func.__name__ + "-expected." + ext
398-
)
399-
fig_test.savefig(test_image_path)
400-
fig_ref.savefig(ref_image_path)
401-
_raise_on_image_difference(
402-
ref_image_path, test_image_path, tol=tol
403-
)
404-
405-
sig = inspect.signature(func)
406-
new_sig = sig.replace(
407-
parameters=([param
408-
for param in sig.parameters.values()
409-
if param.name not in {"fig_test", "fig_ref"}]
410-
+ [inspect.Parameter("ext", POSITIONAL_OR_KEYWORD)])
411-
)
394+
def wrapper(*args, **kwargs):
395+
ext = kwargs['ext']
396+
if 'ext' not in old_sig.parameters:
397+
kwargs.pop('ext')
398+
request = kwargs['request']
399+
if 'request' not in old_sig.parameters:
400+
kwargs.pop('request')
401+
402+
file_name = "".join(c for c in request.node.name
403+
if c in ALLOWED_CHARS)
404+
try:
405+
fig_test = plt.figure("test")
406+
fig_ref = plt.figure("reference")
407+
func(*args, fig_test=fig_test, fig_ref=fig_ref, **kwargs)
408+
test_image_path = result_dir / (file_name + "." + ext)
409+
ref_image_path = result_dir / (file_name + "-expected." + ext)
410+
fig_test.savefig(test_image_path)
411+
fig_ref.savefig(ref_image_path)
412+
_raise_on_image_difference(
413+
ref_image_path, test_image_path, tol=tol
414+
)
415+
finally:
416+
plt.close(fig_test)
417+
plt.close(fig_ref)
418+
419+
parameters = [
420+
param
421+
for param in old_sig.parameters.values()
422+
if param.name not in {"fig_test", "fig_ref"}
423+
]
424+
if 'ext' not in old_sig.parameters:
425+
parameters += [inspect.Parameter("ext", KEYWORD_ONLY)]
426+
if 'request' not in old_sig.parameters:
427+
parameters += [inspect.Parameter("request", KEYWORD_ONLY)]
428+
new_sig = old_sig.replace(parameters=parameters)
412429
wrapper.__signature__ = new_sig
413430

414431
# reach a bit into pytest internals to hoist the marks from

0 commit comments

Comments
 (0)