1
1
import contextlib
2
- import enum
3
2
import functools
4
3
import inspect
5
4
import json
@@ -123,15 +122,6 @@ class _ImageComparisonBase:
123
122
This class provides *just* the comparison-related functionality and avoids
124
123
any code that would be specific to any testing framework.
125
124
"""
126
- class _ImageCheckMode (enum .Enum ):
127
- TEST = enum .auto ()
128
- GENERATE = enum .auto ()
129
-
130
- mode = (_ImageCheckMode .GENERATE
131
- if os .environ .get ("MPLGENERATEBASELINE" )
132
- else _ImageCheckMode .TEST
133
- )
134
-
135
125
def __init__ (self , func , tol , remove_text , savefig_kwargs ):
136
126
self .func = func
137
127
self .result_dir = _results_directory (func )
@@ -144,17 +134,7 @@ def __init__(self, func, tol, remove_text, savefig_kwargs):
144
134
self .remove_text = remove_text
145
135
self .savefig_kwargs = savefig_kwargs
146
136
147
- def copy_baseline (self , baseline , extension ):
148
- baseline_path = self .baseline_dir / baseline
149
- orig_expected_path = baseline_path .with_suffix (f'.{ extension } ' )
150
- if extension == 'eps' and not orig_expected_path .exists ():
151
- orig_expected_path = orig_expected_path .with_suffix ('.pdf' )
152
-
153
- rel_path = orig_expected_path .relative_to (self .root_dir )
154
- if rel_path not in self .image_revs :
155
- raise ValueError (f'{ rel_path !r} is not known.' )
156
- if self .mode != self ._ImageCheckMode .TEST :
157
- return orig_expected_path
137
+ def copy_baseline (self , orig_expected_path ):
158
138
expected_fname = Path (make_test_filename (
159
139
self .result_dir / orig_expected_path .name , 'expected' ))
160
140
try :
@@ -174,8 +154,26 @@ def copy_baseline(self, baseline, extension):
174
154
f"{ orig_expected_path } " ) from err
175
155
return expected_fname
176
156
177
- def compare (self , fig , baseline , extension , * , _lock = False ):
178
- __tracebackhide__ = True
157
+ # TODO add caching?
158
+ def _get_md (self ):
159
+ if self .md_path .exists ():
160
+ with open (self .md_path ) as fin :
161
+ md = {Path (k ): v for k , v in json .load (fin ).items ()}
162
+ else :
163
+ md = {}
164
+ self .md_path .parent .mkdir (parents = True , exist_ok = True )
165
+ return md
166
+
167
+ def _write_md (self , md ):
168
+ with open (self .md_path , 'w' ) as fout :
169
+ json .dump (
170
+ {str (PurePosixPath (* k .parts )): v for k , v in md .items ()},
171
+ fout ,
172
+ sort_keys = True ,
173
+ indent = ' '
174
+ )
175
+
176
+ def _prep_figure (self , fig , baseline , extension ):
179
177
180
178
if self .remove_text :
181
179
remove_ticks_and_titles (fig )
@@ -186,50 +184,75 @@ def compare(self, fig, baseline, extension, *, _lock=False):
186
184
kwargs .setdefault ('metadata' ,
187
185
{'Creator' : None , 'Producer' : None ,
188
186
'CreationDate' : None })
187
+ orig_expected_path = self ._compute_baseline_filename (baseline , extension )
189
188
190
- lock = (cbook ._lock_path (actual_path )
191
- if _lock else contextlib .nullcontext ())
189
+ return actual_path , kwargs , orig_expected_path
190
+
191
+ def _compute_baseline_filename (self , baseline , extension ):
192
+ baseline_path = self .baseline_dir / baseline
193
+ orig_expected_path = baseline_path .with_suffix (f'.{ extension } ' )
194
+ rel_path = orig_expected_path .relative_to (self .root_dir )
195
+
196
+ if extension == 'eps' and rel_path not in self .image_revs :
197
+ orig_expected_path = orig_expected_path .with_suffix ('.pdf' )
198
+ rel_path = orig_expected_path .relative_to (self .root_dir )
199
+
200
+ if rel_path not in self .image_revs :
201
+ raise ValueError (f'{ rel_path !r} is not known.' )
202
+ return orig_expected_path
203
+
204
+ def _save_and_close (self , fig , actual_path , kwargs ):
205
+ try :
206
+ fig .savefig (actual_path , ** kwargs )
207
+ finally :
208
+ # Matplotlib has an autouse fixture to close figures, but this
209
+ # makes things more convenient for third-party users.
210
+ plt .close (fig )
211
+
212
+ def generate (self , fig , baseline , extension , * , _lock = False ):
213
+ __tracebackhide__ = True
214
+ md = self ._get_md ()
215
+
216
+ actual_path , kwargs , orig_expected_path = self ._prep_figure (
217
+ fig , baseline , extension
218
+ )
219
+
220
+ lock = (cbook ._lock_path (actual_path ) if _lock else contextlib .nullcontext ())
192
221
with lock :
193
- try :
194
- fig .savefig (actual_path , ** kwargs )
195
- finally :
196
- # Matplotlib has an autouse fixture to close figures, but this
197
- # makes things more convenient for third-party users.
198
- plt .close (fig )
199
- expected_path = self .copy_baseline (baseline , extension )
200
- # TODO make sure the file exists (and cache?)
201
- if self .md_path .exists ():
202
- with open (self .md_path ) as fin :
203
- md = {Path (k ): v for k , v in json .load (fin ).items ()}
222
+ self ._save_and_close (fig , actual_path , kwargs )
204
223
205
- else :
206
- md = {}
207
- self .md_path .parent .mkdir (parents = True , exist_ok = True )
208
- if self .mode == self ._ImageCheckMode .GENERATE :
209
- rel_path = expected_path .relative_to (self .root_dir )
210
- if rel_path not in md and rel_path .suffix == '.eps' :
211
- rel_path = rel_path .with_suffix ('.pdf' )
212
- expected_path .parent .mkdir (parents = True , exist_ok = True )
213
- shutil .copyfile (actual_path , expected_path )
214
-
215
- md [rel_path ] = {
216
- 'mpl_version' : matplotlib .__version__ ,
217
- ** {k : self .image_revs [rel_path ][k ]for k in ('sha' , 'rev' )}
218
- }
219
- with open (self .md_path , 'w' ) as fout :
220
- json .dump (
221
- {str (PurePosixPath (* k .parts )): v for k , v in md .items ()},
222
- fout ,
223
- sort_keys = True ,
224
- indent = ' '
225
- )
226
- else :
227
- rel_path = actual_path .relative_to (self .result_dir .parent )
228
- if rel_path not in md and rel_path .suffix == '.eps' :
229
- rel_path = rel_path .with_suffix ('.pdf' )
230
- if md [rel_path ]['sha' ] != self .image_revs [rel_path ]['sha' ]:
231
- raise RuntimeError ("Baseline images do not match checkout." )
232
- _raise_on_image_difference (expected_path , actual_path , self .tol )
224
+ rel_path = orig_expected_path .relative_to (self .root_dir )
225
+ if rel_path not in self .image_revs and rel_path .suffix == '.eps' :
226
+ rel_path = rel_path .with_suffix ('.pdf' )
227
+ orig_expected_path .parent .mkdir (parents = True , exist_ok = True )
228
+ shutil .copyfile (actual_path , orig_expected_path )
229
+
230
+ md [rel_path ] = {
231
+ 'mpl_version' : matplotlib .__version__ ,
232
+ ** {k : self .image_revs [rel_path ][k ]for k in ('sha' , 'rev' )}
233
+ }
234
+ self ._write_md (md )
235
+
236
+ def compare (self , fig , baseline , extension , * , _lock = False ):
237
+ __tracebackhide__ = True
238
+ md = self ._get_md ()
239
+ actual_path , kwargs , orig_expected_path = self ._prep_figure (
240
+ fig , baseline , extension
241
+ )
242
+
243
+ lock = (cbook ._lock_path (actual_path ) if _lock else contextlib .nullcontext ())
244
+ with lock :
245
+ self ._save_and_close (fig , actual_path , kwargs )
246
+
247
+ expected_path = self .copy_baseline (orig_expected_path )
248
+
249
+ rel_path = actual_path .relative_to (self .result_dir .parent )
250
+ if rel_path not in md and rel_path .suffix == '.eps' :
251
+ rel_path = rel_path .with_suffix ('.pdf' )
252
+ if md [rel_path ]['sha' ] != self .image_revs [rel_path ]['sha' ]:
253
+ raise RuntimeError ("Baseline images do not match checkout." )
254
+
255
+ _raise_on_image_difference (expected_path , actual_path , self .tol )
233
256
234
257
235
258
def _pytest_image_comparison (baseline_images , extensions , tol ,
@@ -293,8 +316,14 @@ def wrapper(*args, extension, request, **kwargs):
293
316
assert len (figs ) == len (our_baseline_images ), (
294
317
f"Test generated { len (figs )} images but there are "
295
318
f"{ len (our_baseline_images )} baseline images" )
319
+
320
+ generating = bool (os .environ .get ("MPLGENERATEBASELINE" ))
321
+
296
322
for fig , baseline in zip (figs , our_baseline_images ):
297
- img .compare (fig , baseline , extension , _lock = needs_lock )
323
+ if generating :
324
+ img .generate (fig , baseline , extension , _lock = needs_lock )
325
+ else :
326
+ img .compare (fig , baseline , extension , _lock = needs_lock )
298
327
299
328
parameters = list (old_sig .parameters .values ())
300
329
if 'extension' not in old_sig .parameters :
0 commit comments