9
9
import sys
10
10
import unittest
11
11
import warnings
12
+ try :
13
+ from contextlib import nullcontext
14
+ except ImportError :
15
+ from contextlib import ExitStack as nullcontext # Py3.6.
12
16
13
17
import matplotlib as mpl
14
18
import matplotlib .style
@@ -200,7 +204,7 @@ def copy_baseline(self, baseline, extension):
200
204
f"{ orig_expected_path } " ) from err
201
205
return expected_fname
202
206
203
- def compare (self , idx , baseline , extension ):
207
+ def compare (self , idx , baseline , extension , * , _lock = False ):
204
208
__tracebackhide__ = True
205
209
fignum = plt .get_fignums ()[idx ]
206
210
fig = plt .figure (fignum )
@@ -214,10 +218,12 @@ def compare(self, idx, baseline, extension):
214
218
kwargs .setdefault ('metadata' ,
215
219
{'Creator' : None , 'Producer' : None ,
216
220
'CreationDate' : None })
217
- fig .savefig (actual_path , ** kwargs )
218
221
219
- expected_path = self .copy_baseline (baseline , extension )
220
- _raise_on_image_difference (expected_path , actual_path , self .tol )
222
+ lock = cbook ._lock_path (actual_path ) if _lock else nullcontext ()
223
+ with lock :
224
+ fig .savefig (actual_path , ** kwargs )
225
+ expected_path = self .copy_baseline (baseline , extension )
226
+ _raise_on_image_difference (expected_path , actual_path , self .tol )
221
227
222
228
223
229
def _pytest_image_comparison (baseline_images , extensions , tol ,
@@ -227,43 +233,66 @@ def _pytest_image_comparison(baseline_images, extensions, tol,
227
233
Decorate function with image comparison for pytest.
228
234
229
235
This function creates a decorator that wraps a figure-generating function
230
- with image comparison code. Pytest can become confused if we change the
231
- signature of the function, so we indirectly pass anything we need via the
232
- `mpl_image_comparison_parameters` fixture and extra markers.
236
+ with image comparison code.
233
237
"""
234
238
import pytest
235
239
236
240
extensions = map (_mark_skip_if_format_is_uncomparable , extensions )
241
+ KEYWORD_ONLY = inspect .Parameter .KEYWORD_ONLY
237
242
238
243
def decorator (func ):
244
+ old_sig = inspect .signature (func )
245
+
239
246
@functools .wraps (func )
240
- # Parameter indirection; see docstring above and comment below.
241
- @pytest .mark .usefixtures ('mpl_image_comparison_parameters' )
242
247
@pytest .mark .parametrize ('extension' , extensions )
243
- @pytest .mark .baseline_images (baseline_images )
244
- # END Parameter indirection.
245
248
@pytest .mark .style (style )
246
249
@_checked_on_freetype_version (freetype_version )
247
250
@functools .wraps (func )
248
- def wrapper (* args , ** kwargs ):
251
+ def wrapper (* args , extension , request , ** kwargs ):
249
252
__tracebackhide__ = True
253
+ if 'extension' in old_sig .parameters :
254
+ kwargs ['extension' ] = extension
255
+ if 'request' in old_sig .parameters :
256
+ kwargs ['request' ] = request
257
+
250
258
img = _ImageComparisonBase (func , tol = tol , remove_text = remove_text ,
251
259
savefig_kwargs = savefig_kwargs )
252
260
matplotlib .testing .set_font_settings_for_testing ()
253
261
func (* args , ** kwargs )
254
262
255
- # Parameter indirection:
256
- # This is hacked on via the mpl_image_comparison_parameters fixture
257
- # so that we don't need to modify the function's real signature for
258
- # any parametrization. Modifying the signature is very very tricky
259
- # and likely to confuse pytest.
260
- baseline_images , extension = func .parameters
261
-
262
- assert len (plt .get_fignums ()) == len (baseline_images ), (
263
+ # If the test is parametrized in any way other than applied via
264
+ # this decorator, then we need to use a lock to prevent two
265
+ # processes from touching the same output file.
266
+ needs_lock = any (
267
+ marker .args [0 ] != 'extension'
268
+ for marker in request .node .iter_markers ('parametrize' ))
269
+
270
+ if baseline_images is not None :
271
+ our_baseline_images = baseline_images
272
+ else :
273
+ # Allow baseline image list to be produced on the fly based on
274
+ # current parametrization.
275
+ our_baseline_images = request .getfixturevalue (
276
+ 'baseline_images' )
277
+
278
+ assert len (plt .get_fignums ()) == len (our_baseline_images ), (
263
279
"Test generated {} images but there are {} baseline images"
264
- .format (len (plt .get_fignums ()), len (baseline_images )))
265
- for idx , baseline in enumerate (baseline_images ):
266
- img .compare (idx , baseline , extension )
280
+ .format (len (plt .get_fignums ()), len (our_baseline_images )))
281
+ for idx , baseline in enumerate (our_baseline_images ):
282
+ img .compare (idx , baseline , extension , _lock = needs_lock )
283
+
284
+ parameters = list (old_sig .parameters .values ())
285
+ if 'extension' not in old_sig .parameters :
286
+ parameters += [inspect .Parameter ('extension' , KEYWORD_ONLY )]
287
+ if 'request' not in old_sig .parameters :
288
+ parameters += [inspect .Parameter ("request" , KEYWORD_ONLY )]
289
+ new_sig = old_sig .replace (parameters = parameters )
290
+ wrapper .__signature__ = new_sig
291
+
292
+ # Reach a bit into pytest internals to hoist the marks from our wrapped
293
+ # function.
294
+ new_marks = getattr (func , 'pytestmark' , []) + wrapper .pytestmark
295
+ wrapper .pytestmark = new_marks
267
296
268
297
return wrapper
269
298
@@ -398,13 +427,11 @@ def decorator(func):
398
427
f"function has the signature { old_sig } " )
399
428
400
429
@pytest .mark .parametrize ("ext" , extensions )
401
- def wrapper (* args , ** kwargs ):
402
- ext = kwargs ['ext' ]
403
- if 'ext' not in old_sig .parameters :
404
- kwargs .pop ('ext' )
405
- request = kwargs ['request' ]
406
- if 'request' not in old_sig .parameters :
407
- kwargs .pop ('request' )
430
+ def wrapper (* args , ext , request , ** kwargs ):
431
+ if 'ext' in old_sig .parameters :
432
+ kwargs ['ext' ] = ext
433
+ if 'request' in old_sig .parameters :
434
+ kwargs ['request' ] = request
408
435
409
436
file_name = "" .join (c for c in request .node .name
410
437
if c in ALLOWED_CHARS )
0 commit comments