|
5 | 5 | import os
|
6 | 6 | from pathlib import Path
|
7 | 7 | import shutil
|
| 8 | +import string |
8 | 9 | import sys
|
9 | 10 | import unittest
|
10 | 11 | import warnings
|
|
17 | 18 | from matplotlib import ft2font
|
18 | 19 | from matplotlib import pyplot as plt
|
19 | 20 | from matplotlib import ticker
|
20 |
| -from . import is_called_from_pytest |
| 21 | + |
21 | 22 | from .compare import comparable_formats, compare_images, make_test_filename
|
22 | 23 | from .exceptions import ImageComparisonFailure
|
23 | 24 |
|
@@ -381,34 +382,50 @@ def test_plot(fig_test, fig_ref):
|
381 | 382 | fig_test.subplots().plot([1, 3, 5])
|
382 | 383 | fig_ref.subplots().plot([0, 1, 2], [1, 3, 5])
|
383 | 384 | """
|
384 |
| - POSITIONAL_OR_KEYWORD = inspect.Parameter.POSITIONAL_OR_KEYWORD |
| 385 | + ALLOWED_CHARS = set(string.digits + string.ascii_letters + '_-[]()') |
| 386 | + KEYWORD_ONLY = inspect.Parameter.KEYWORD_ONLY |
385 | 387 | def decorator(func):
|
386 | 388 | import pytest
|
387 | 389 |
|
388 | 390 | _, result_dir = _image_directories(func)
|
| 391 | + old_sig = inspect.signature(func) |
389 | 392 |
|
390 | 393 | @pytest.mark.parametrize("ext", extensions)
|
391 |
| - def wrapper(*args, ext, **kwargs): |
392 |
| - fig_test = plt.figure("test") |
393 |
| - fig_ref = plt.figure("reference") |
394 |
| - func(*args, fig_test=fig_test, fig_ref=fig_ref, **kwargs) |
395 |
| - test_image_path = result_dir / (func.__name__ + "." + ext) |
396 |
| - ref_image_path = result_dir / ( |
397 |
| - func.__name__ + "-expected." + ext |
398 |
| - ) |
399 |
| - fig_test.savefig(test_image_path) |
400 |
| - fig_ref.savefig(ref_image_path) |
401 |
| - _raise_on_image_difference( |
402 |
| - ref_image_path, test_image_path, tol=tol |
403 |
| - ) |
404 |
| - |
405 |
| - sig = inspect.signature(func) |
406 |
| - new_sig = sig.replace( |
407 |
| - parameters=([param |
408 |
| - for param in sig.parameters.values() |
409 |
| - if param.name not in {"fig_test", "fig_ref"}] |
410 |
| - + [inspect.Parameter("ext", POSITIONAL_OR_KEYWORD)]) |
411 |
| - ) |
| 394 | + def wrapper(*args, **kwargs): |
| 395 | + ext = kwargs['ext'] |
| 396 | + if 'ext' not in old_sig.parameters: |
| 397 | + kwargs.pop('ext') |
| 398 | + request = kwargs['request'] |
| 399 | + if 'request' not in old_sig.parameters: |
| 400 | + kwargs.pop('request') |
| 401 | + |
| 402 | + file_name = "".join(c for c in request.node.name |
| 403 | + if c in ALLOWED_CHARS) |
| 404 | + try: |
| 405 | + fig_test = plt.figure("test") |
| 406 | + fig_ref = plt.figure("reference") |
| 407 | + func(*args, fig_test=fig_test, fig_ref=fig_ref, **kwargs) |
| 408 | + test_image_path = result_dir / (file_name + "." + ext) |
| 409 | + ref_image_path = result_dir / (file_name + "-expected." + ext) |
| 410 | + fig_test.savefig(test_image_path) |
| 411 | + fig_ref.savefig(ref_image_path) |
| 412 | + _raise_on_image_difference( |
| 413 | + ref_image_path, test_image_path, tol=tol |
| 414 | + ) |
| 415 | + finally: |
| 416 | + plt.close(fig_test) |
| 417 | + plt.close(fig_ref) |
| 418 | + |
| 419 | + parameters = [ |
| 420 | + param |
| 421 | + for param in old_sig.parameters.values() |
| 422 | + if param.name not in {"fig_test", "fig_ref"} |
| 423 | + ] |
| 424 | + if 'ext' not in old_sig.parameters: |
| 425 | + parameters += [inspect.Parameter("ext", KEYWORD_ONLY)] |
| 426 | + if 'request' not in old_sig.parameters: |
| 427 | + parameters += [inspect.Parameter("request", KEYWORD_ONLY)] |
| 428 | + new_sig = old_sig.replace(parameters=parameters) |
412 | 429 | wrapper.__signature__ = new_sig
|
413 | 430 |
|
414 | 431 | # reach a bit into pytest internals to hoist the marks from
|
|
0 commit comments