Skip to content

Commit a5fb373

Browse files
committed
MNT: refactor compare / generate into stand alone methods
Also eliminate an enum only used in one place
1 parent 6df6b17 commit a5fb373

File tree

5 files changed

+98
-69
lines changed

5 files changed

+98
-69
lines changed

lib/matplotlib/testing/decorators.py

Lines changed: 94 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import contextlib
2-
import enum
32
import functools
43
import inspect
54
import json
@@ -123,15 +122,6 @@ class _ImageComparisonBase:
123122
This class provides *just* the comparison-related functionality and avoids
124123
any code that would be specific to any testing framework.
125124
"""
126-
class _ImageCheckMode(enum.Enum):
127-
TEST = enum.auto()
128-
GENERATE = enum.auto()
129-
130-
mode = (_ImageCheckMode.GENERATE
131-
if os.environ.get("MPLGENERATEBASELINE")
132-
else _ImageCheckMode.TEST
133-
)
134-
135125
def __init__(self, func, tol, remove_text, savefig_kwargs):
136126
self.func = func
137127
self.result_dir = _results_directory(func)
@@ -144,17 +134,7 @@ def __init__(self, func, tol, remove_text, savefig_kwargs):
144134
self.remove_text = remove_text
145135
self.savefig_kwargs = savefig_kwargs
146136

147-
def copy_baseline(self, baseline, extension):
148-
baseline_path = self.baseline_dir / baseline
149-
orig_expected_path = baseline_path.with_suffix(f'.{extension}')
150-
if extension == 'eps' and not orig_expected_path.exists():
151-
orig_expected_path = orig_expected_path.with_suffix('.pdf')
152-
153-
rel_path = orig_expected_path.relative_to(self.root_dir)
154-
if rel_path not in self.image_revs:
155-
raise ValueError(f'{rel_path!r} is not known.')
156-
if self.mode != self._ImageCheckMode.TEST:
157-
return orig_expected_path
137+
def copy_baseline(self, orig_expected_path):
158138
expected_fname = Path(make_test_filename(
159139
self.result_dir / orig_expected_path.name, 'expected'))
160140
try:
@@ -174,8 +154,26 @@ def copy_baseline(self, baseline, extension):
174154
f"{orig_expected_path}") from err
175155
return expected_fname
176156

177-
def compare(self, fig, baseline, extension, *, _lock=False):
178-
__tracebackhide__ = True
157+
# TODO add caching?
158+
def _get_md(self):
159+
if self.md_path.exists():
160+
with open(self.md_path) as fin:
161+
md = {Path(k): v for k, v in json.load(fin).items()}
162+
else:
163+
md = {}
164+
self.md_path.parent.mkdir(parents=True, exist_ok=True)
165+
return md
166+
167+
def _write_md(self, md):
168+
with open(self.md_path, 'w') as fout:
169+
json.dump(
170+
{str(PurePosixPath(*k.parts)): v for k, v in md.items()},
171+
fout,
172+
sort_keys=True,
173+
indent=' '
174+
)
175+
176+
def _prep_figure(self, fig, baseline, extension):
179177

180178
if self.remove_text:
181179
remove_ticks_and_titles(fig)
@@ -186,50 +184,75 @@ def compare(self, fig, baseline, extension, *, _lock=False):
186184
kwargs.setdefault('metadata',
187185
{'Creator': None, 'Producer': None,
188186
'CreationDate': None})
187+
orig_expected_path = self._compute_baseline_filename(baseline, extension)
189188

190-
lock = (cbook._lock_path(actual_path)
191-
if _lock else contextlib.nullcontext())
189+
return actual_path, kwargs, orig_expected_path
190+
191+
def _compute_baseline_filename(self, baseline, extension):
192+
baseline_path = self.baseline_dir / baseline
193+
orig_expected_path = baseline_path.with_suffix(f'.{extension}')
194+
rel_path = orig_expected_path.relative_to(self.root_dir)
195+
196+
if extension == 'eps' and rel_path not in self.image_revs:
197+
orig_expected_path = orig_expected_path.with_suffix('.pdf')
198+
rel_path = orig_expected_path.relative_to(self.root_dir)
199+
200+
if rel_path not in self.image_revs:
201+
raise ValueError(f'{rel_path!r} is not known.')
202+
return orig_expected_path
203+
204+
def _save_and_close(self, fig, actual_path, kwargs):
205+
try:
206+
fig.savefig(actual_path, **kwargs)
207+
finally:
208+
# Matplotlib has an autouse fixture to close figures, but this
209+
# makes things more convenient for third-party users.
210+
plt.close(fig)
211+
212+
def generate(self, fig, baseline, extension, *, _lock=False):
213+
__tracebackhide__ = True
214+
md = self._get_md()
215+
216+
actual_path, kwargs, orig_expected_path = self._prep_figure(
217+
fig, baseline, extension
218+
)
219+
220+
lock = (cbook._lock_path(actual_path) if _lock else contextlib.nullcontext())
192221
with lock:
193-
try:
194-
fig.savefig(actual_path, **kwargs)
195-
finally:
196-
# Matplotlib has an autouse fixture to close figures, but this
197-
# makes things more convenient for third-party users.
198-
plt.close(fig)
199-
expected_path = self.copy_baseline(baseline, extension)
200-
# TODO make sure the file exists (and cache?)
201-
if self.md_path.exists():
202-
with open(self.md_path) as fin:
203-
md = {Path(k): v for k, v in json.load(fin).items()}
222+
self._save_and_close(fig, actual_path, kwargs)
204223

205-
else:
206-
md = {}
207-
self.md_path.parent.mkdir(parents=True, exist_ok=True)
208-
if self.mode == self._ImageCheckMode.GENERATE:
209-
rel_path = expected_path.relative_to(self.root_dir)
210-
if rel_path not in md and rel_path.suffix == '.eps':
211-
rel_path = rel_path.with_suffix('.pdf')
212-
expected_path.parent.mkdir(parents=True, exist_ok=True)
213-
shutil.copyfile(actual_path, expected_path)
214-
215-
md[rel_path] = {
216-
'mpl_version': matplotlib.__version__,
217-
**{k: self.image_revs[rel_path][k]for k in ('sha', 'rev')}
218-
}
219-
with open(self.md_path, 'w') as fout:
220-
json.dump(
221-
{str(PurePosixPath(*k.parts)): v for k, v in md.items()},
222-
fout,
223-
sort_keys=True,
224-
indent=' '
225-
)
226-
else:
227-
rel_path = actual_path.relative_to(self.result_dir.parent)
228-
if rel_path not in md and rel_path.suffix == '.eps':
229-
rel_path = rel_path.with_suffix('.pdf')
230-
if md[rel_path]['sha'] != self.image_revs[rel_path]['sha']:
231-
raise RuntimeError("Baseline images do not match checkout.")
232-
_raise_on_image_difference(expected_path, actual_path, self.tol)
224+
rel_path = orig_expected_path.relative_to(self.root_dir)
225+
if rel_path not in self.image_revs and rel_path.suffix == '.eps':
226+
rel_path = rel_path.with_suffix('.pdf')
227+
orig_expected_path.parent.mkdir(parents=True, exist_ok=True)
228+
shutil.copyfile(actual_path, orig_expected_path)
229+
230+
md[rel_path] = {
231+
'mpl_version': matplotlib.__version__,
232+
**{k: self.image_revs[rel_path][k]for k in ('sha', 'rev')}
233+
}
234+
self._write_md(md)
235+
236+
def compare(self, fig, baseline, extension, *, _lock=False):
237+
__tracebackhide__ = True
238+
md = self._get_md()
239+
actual_path, kwargs, orig_expected_path = self._prep_figure(
240+
fig, baseline, extension
241+
)
242+
243+
lock = (cbook._lock_path(actual_path) if _lock else contextlib.nullcontext())
244+
with lock:
245+
self._save_and_close(fig, actual_path, kwargs)
246+
247+
expected_path = self.copy_baseline(orig_expected_path)
248+
249+
rel_path = actual_path.relative_to(self.result_dir.parent)
250+
if rel_path not in md and rel_path.suffix == '.eps':
251+
rel_path = rel_path.with_suffix('.pdf')
252+
if md[rel_path]['sha'] != self.image_revs[rel_path]['sha']:
253+
raise RuntimeError("Baseline images do not match checkout.")
254+
255+
_raise_on_image_difference(expected_path, actual_path, self.tol)
233256

234257

235258
def _pytest_image_comparison(baseline_images, extensions, tol,
@@ -293,8 +316,14 @@ def wrapper(*args, extension, request, **kwargs):
293316
assert len(figs) == len(our_baseline_images), (
294317
f"Test generated {len(figs)} images but there are "
295318
f"{len(our_baseline_images)} baseline images")
319+
320+
generating = bool(os.environ.get("MPLGENERATEBASELINE"))
321+
296322
for fig, baseline in zip(figs, our_baseline_images):
297-
img.compare(fig, baseline, extension, _lock=needs_lock)
323+
if generating:
324+
img.generate(fig, baseline, extension, _lock=needs_lock)
325+
else:
326+
img.compare(fig, baseline, extension, _lock=needs_lock)
298327

299328
parameters = list(old_sig.parameters.values())
300329
if 'extension' not in old_sig.parameters:

lib/matplotlib/tests/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
# Check that the test directories exist.
12-
if not (base_path).exists():
12+
if not (base_path).exists() and not os.environ.get("MPLGENERATEBASELINE"):
1313
raise OSError(
1414
f'The baseline image directory ({base_path!r}) does not exist. '
1515
'This is most likely because the test data is not installed. '

lib/mpl_toolkits/axes_grid1/tests/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212

1313
# Check that the test directories exist.
14-
if not (base_path).exists():
14+
if not (base_path).exists() and not os.environ.get("MPLGENERATEBASELINE"):
1515
raise OSError(
1616
f'The baseline image directory ({base_path!r}) does not exist. '
1717
'This is most likely because the test data is not installed. '

lib/mpl_toolkits/axisartist/tests/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
base_path = Path(__file__).parent / "baseline_images"
1111

1212
# Check that the test directories exist.
13-
if not (base_path).exists():
13+
if not (base_path).exists() and not os.environ.get("MPLGENERATEBASELINE"):
1414
raise OSError(
1515
f"The baseline image directory ({base_path!r}) does not exist. "
1616
"This is most likely because the test data is not installed. "

lib/mpl_toolkits/mplot3d/tests/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
# Check that the test directories exist.
12-
if not (base_path).exists():
12+
if not (base_path).exists() and not os.environ.get("MPLGENERATEBASELINE"):
1313
raise OSError(
1414
f"The baseline image directory ({base_path!r}) does not exist. "
1515
"This is most likely because the test data is not installed. "

0 commit comments

Comments
 (0)