14
14
import matplotlib .style
15
15
import matplotlib .units
16
16
import matplotlib .testing
17
- from matplotlib import cbook
18
- from matplotlib import ft2font
19
- from matplotlib import pyplot as plt
20
- from matplotlib import ticker
17
+ from matplotlib import cbook , ft2font , pyplot as plt , ticker , _pylab_helpers
21
18
from .compare import comparable_formats , compare_images , make_test_filename
22
19
from .exceptions import ImageComparisonFailure
23
20
@@ -129,6 +126,28 @@ def remove_ticks(ax):
129
126
remove_ticks (ax )
130
127
131
128
129
+ @contextlib .contextmanager
130
+ def _collect_new_figures ():
131
+ """
132
+ After::
133
+
134
+ with _collect_new_figures() as figs: some_code()
135
+
136
+ *figs* contains the figures that have been created during the execution of
137
+ ``some_code``, sorted by figure number.
138
+ """
139
+ managers = _pylab_helpers .Gcf .figs
140
+ preexisting = [manager for manager in managers .values ()]
141
+ new_figs = []
142
+ try :
143
+ yield new_figs
144
+ finally :
145
+ new_managers = sorted ([manager for manager in managers .values ()
146
+ if manager not in preexisting ],
147
+ key = lambda manager : manager .num )
148
+ new_figs [:] = [manager .canvas .figure for manager in new_managers ]
149
+
150
+
132
151
def _raise_on_image_difference (expected , actual , tol ):
133
152
__tracebackhide__ = True
134
153
@@ -178,10 +197,8 @@ def copy_baseline(self, baseline, extension):
178
197
f"{ orig_expected_path } " ) from err
179
198
return expected_fname
180
199
181
- def compare (self , idx , baseline , extension , * , _lock = False ):
200
+ def compare (self , fig , baseline , extension , * , _lock = False ):
182
201
__tracebackhide__ = True
183
- fignum = plt .get_fignums ()[idx ]
184
- fig = plt .figure (fignum )
185
202
186
203
if self .remove_text :
187
204
remove_ticks_and_titles (fig )
@@ -196,7 +213,12 @@ def compare(self, idx, baseline, extension, *, _lock=False):
196
213
lock = (cbook ._lock_path (actual_path )
197
214
if _lock else contextlib .nullcontext ())
198
215
with lock :
199
- fig .savefig (actual_path , ** kwargs )
216
+ try :
217
+ fig .savefig (actual_path , ** kwargs )
218
+ finally :
219
+ # Matplotlib has an autouse fixture to close figures, but this
220
+ # makes things more convenient for third-party users.
221
+ plt .close (fig )
200
222
expected_path = self .copy_baseline (baseline , extension )
201
223
_raise_on_image_difference (expected_path , actual_path , self .tol )
202
224
@@ -235,7 +257,9 @@ def wrapper(*args, extension, request, **kwargs):
235
257
img = _ImageComparisonBase (func , tol = tol , remove_text = remove_text ,
236
258
savefig_kwargs = savefig_kwargs )
237
259
matplotlib .testing .set_font_settings_for_testing ()
238
- func (* args , ** kwargs )
260
+
261
+ with _collect_new_figures () as figs :
262
+ func (* args , ** kwargs )
239
263
240
264
# If the test is parametrized in any way other than applied via
241
265
# this decorator, then we need to use a lock to prevent two
@@ -252,11 +276,11 @@ def wrapper(*args, extension, request, **kwargs):
252
276
our_baseline_images = request .getfixturevalue (
253
277
'baseline_images' )
254
278
255
- assert len (plt . get_fignums () ) == len (our_baseline_images ), (
279
+ assert len (figs ) == len (our_baseline_images ), (
256
280
"Test generated {} images but there are {} baseline images"
257
- .format (len (plt . get_fignums () ), len (our_baseline_images )))
258
- for idx , baseline in enumerate ( our_baseline_images ):
259
- img .compare (idx , baseline , extension , _lock = needs_lock )
281
+ .format (len (figs ), len (our_baseline_images )))
282
+ for fig , baseline in zip ( figs , our_baseline_images ):
283
+ img .compare (fig , baseline , extension , _lock = needs_lock )
260
284
261
285
parameters = list (old_sig .parameters .values ())
262
286
if 'extension' not in old_sig .parameters :
@@ -427,11 +451,9 @@ def wrapper(*args, ext, request, **kwargs):
427
451
try :
428
452
fig_test = plt .figure ("test" )
429
453
fig_ref = plt .figure ("reference" )
430
- # Keep track of number of open figures, to make sure test
431
- # doesn't create any new ones
432
- n_figs = len (plt .get_fignums ())
433
- func (* args , fig_test = fig_test , fig_ref = fig_ref , ** kwargs )
434
- if len (plt .get_fignums ()) > n_figs :
454
+ with _collect_new_figures () as figs :
455
+ func (* args , fig_test = fig_test , fig_ref = fig_ref , ** kwargs )
456
+ if figs :
435
457
raise RuntimeError ('Number of open figures changed during '
436
458
'test. Make sure you are plotting to '
437
459
'fig_test or fig_ref, or if this is '
0 commit comments