Skip to content

Commit a975e48

Browse files
authored
Merge pull request #17267 from QuLogic/testing-lock
Improve image comparison decorator
2 parents 92d4b16 + e00b630 commit a975e48

File tree

3 files changed

+64
-48
lines changed

3 files changed

+64
-48
lines changed

lib/matplotlib/testing/decorators.py

Lines changed: 57 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99
import sys
1010
import unittest
1111
import warnings
12+
try:
13+
from contextlib import nullcontext
14+
except ImportError:
15+
from contextlib import ExitStack as nullcontext # Py3.6.
1216

1317
import matplotlib as mpl
1418
import matplotlib.style
@@ -200,7 +204,7 @@ def copy_baseline(self, baseline, extension):
200204
f"{orig_expected_path}") from err
201205
return expected_fname
202206

203-
def compare(self, idx, baseline, extension):
207+
def compare(self, idx, baseline, extension, *, _lock=False):
204208
__tracebackhide__ = True
205209
fignum = plt.get_fignums()[idx]
206210
fig = plt.figure(fignum)
@@ -214,10 +218,12 @@ def compare(self, idx, baseline, extension):
214218
kwargs.setdefault('metadata',
215219
{'Creator': None, 'Producer': None,
216220
'CreationDate': None})
217-
fig.savefig(actual_path, **kwargs)
218221

219-
expected_path = self.copy_baseline(baseline, extension)
220-
_raise_on_image_difference(expected_path, actual_path, self.tol)
222+
lock = cbook._lock_path(actual_path) if _lock else nullcontext()
223+
with lock:
224+
fig.savefig(actual_path, **kwargs)
225+
expected_path = self.copy_baseline(baseline, extension)
226+
_raise_on_image_difference(expected_path, actual_path, self.tol)
221227

222228

223229
def _pytest_image_comparison(baseline_images, extensions, tol,
@@ -227,43 +233,66 @@ def _pytest_image_comparison(baseline_images, extensions, tol,
227233
Decorate function with image comparison for pytest.
228234
229235
This function creates a decorator that wraps a figure-generating function
230-
with image comparison code. Pytest can become confused if we change the
231-
signature of the function, so we indirectly pass anything we need via the
232-
`mpl_image_comparison_parameters` fixture and extra markers.
236+
with image comparison code.
233237
"""
234238
import pytest
235239

236240
extensions = map(_mark_skip_if_format_is_uncomparable, extensions)
241+
KEYWORD_ONLY = inspect.Parameter.KEYWORD_ONLY
237242

238243
def decorator(func):
244+
old_sig = inspect.signature(func)
245+
239246
@functools.wraps(func)
240-
# Parameter indirection; see docstring above and comment below.
241-
@pytest.mark.usefixtures('mpl_image_comparison_parameters')
242247
@pytest.mark.parametrize('extension', extensions)
243-
@pytest.mark.baseline_images(baseline_images)
244-
# END Parameter indirection.
245248
@pytest.mark.style(style)
246249
@_checked_on_freetype_version(freetype_version)
247250
@functools.wraps(func)
248-
def wrapper(*args, **kwargs):
251+
def wrapper(*args, extension, request, **kwargs):
249252
__tracebackhide__ = True
253+
if 'extension' in old_sig.parameters:
254+
kwargs['extension'] = extension
255+
if 'request' in old_sig.parameters:
256+
kwargs['request'] = request
257+
250258
img = _ImageComparisonBase(func, tol=tol, remove_text=remove_text,
251259
savefig_kwargs=savefig_kwargs)
252260
matplotlib.testing.set_font_settings_for_testing()
253261
func(*args, **kwargs)
254262

255-
# Parameter indirection:
256-
# This is hacked on via the mpl_image_comparison_parameters fixture
257-
# so that we don't need to modify the function's real signature for
258-
# any parametrization. Modifying the signature is very very tricky
259-
# and likely to confuse pytest.
260-
baseline_images, extension = func.parameters
261-
262-
assert len(plt.get_fignums()) == len(baseline_images), (
263+
# If the test is parametrized in any way other than applied via
264+
# this decorator, then we need to use a lock to prevent two
265+
# processes from touching the same output file.
266+
needs_lock = any(
267+
marker.args[0] != 'extension'
268+
for marker in request.node.iter_markers('parametrize'))
269+
270+
if baseline_images is not None:
271+
our_baseline_images = baseline_images
272+
else:
273+
# Allow baseline image list to be produced on the fly based on
274+
# current parametrization.
275+
our_baseline_images = request.getfixturevalue(
276+
'baseline_images')
277+
278+
assert len(plt.get_fignums()) == len(our_baseline_images), (
263279
"Test generated {} images but there are {} baseline images"
264-
.format(len(plt.get_fignums()), len(baseline_images)))
265-
for idx, baseline in enumerate(baseline_images):
266-
img.compare(idx, baseline, extension)
280+
.format(len(plt.get_fignums()), len(our_baseline_images)))
281+
for idx, baseline in enumerate(our_baseline_images):
282+
img.compare(idx, baseline, extension, _lock=needs_lock)
283+
284+
parameters = list(old_sig.parameters.values())
285+
if 'extension' not in old_sig.parameters:
286+
parameters += [inspect.Parameter('extension', KEYWORD_ONLY)]
287+
if 'request' not in old_sig.parameters:
288+
parameters += [inspect.Parameter("request", KEYWORD_ONLY)]
289+
new_sig = old_sig.replace(parameters=parameters)
290+
wrapper.__signature__ = new_sig
291+
292+
# Reach a bit into pytest internals to hoist the marks from our wrapped
293+
# function.
294+
new_marks = getattr(func, 'pytestmark', []) + wrapper.pytestmark
295+
wrapper.pytestmark = new_marks
267296

268297
return wrapper
269298

@@ -398,13 +427,11 @@ def decorator(func):
398427
f"function has the signature {old_sig}")
399428

400429
@pytest.mark.parametrize("ext", extensions)
401-
def wrapper(*args, **kwargs):
402-
ext = kwargs['ext']
403-
if 'ext' not in old_sig.parameters:
404-
kwargs.pop('ext')
405-
request = kwargs['request']
406-
if 'request' not in old_sig.parameters:
407-
kwargs.pop('request')
430+
def wrapper(*args, ext, request, **kwargs):
431+
if 'ext' in old_sig.parameters:
432+
kwargs['ext'] = ext
433+
if 'request' in old_sig.parameters:
434+
kwargs['request'] = request
408435

409436
file_name = "".join(c for c in request.node.name
410437
if c in ALLOWED_CHARS)

lib/matplotlib/tests/test_axes.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3150,15 +3150,10 @@ def test_hist_stacked_weighted():
31503150
ax.hist((d1, d2), weights=(w1, w2), histtype="stepfilled", stacked=True)
31513151

31523152

3153-
@image_comparison(['stem.png', 'stem.png'], style='mpl20', remove_text=True)
3154-
def test_stem():
3155-
# Note, we don't use @pytest.mark.parametrize, because in parallel this
3156-
# might cause one process result to overwrite another's.
3157-
for use_line_collection in [True, False]:
3158-
_test_stem(use_line_collection)
3159-
3160-
3161-
def _test_stem(use_line_collection):
3153+
@pytest.mark.parametrize("use_line_collection", [True, False],
3154+
ids=['w/ line collection', 'w/o line collection'])
3155+
@image_comparison(['stem.png'], style='mpl20', remove_text=True)
3156+
def test_stem(use_line_collection):
31623157
x = np.linspace(0.1, 2 * np.pi, 100)
31633158
args = (x, np.cos(x))
31643159
# Label is a single space to force a legend to be drawn, but to avoid any

lib/mpl_toolkits/tests/test_axes_grid.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,10 @@
1212
# The original version of this test relied on mpl_toolkits's slightly different
1313
# colorbar implementation; moving to matplotlib's own colorbar implementation
1414
# caused the small image comparison error.
15-
@image_comparison(['imagegrid_cbar_mode.png', 'imagegrid_cbar_mode.png'],
15+
@pytest.mark.parametrize("legacy_colorbar", [False, True])
16+
@image_comparison(['imagegrid_cbar_mode.png'],
1617
remove_text=True, style='mpl20', tol=0.3)
17-
def test_imagegrid_cbar_mode_edge():
18-
# Note, we don't use @pytest.mark.parametrize, because in parallel this
19-
# might cause one process result to overwrite another's.
20-
for legacy_colorbar in [False, True]:
21-
_test_imagegrid_cbar_mode_edge(legacy_colorbar)
22-
23-
24-
def _test_imagegrid_cbar_mode_edge(legacy_colorbar):
18+
def test_imagegrid_cbar_mode_edge(legacy_colorbar):
2519
mpl.rcParams["mpl_toolkits.legacy_colorbar"] = legacy_colorbar
2620

2721
X, Y = np.meshgrid(np.linspace(0, 6, 30), np.linspace(0, 6, 30))

0 commit comments

Comments
 (0)