14
14
import matplotlib .style
15
15
import matplotlib .units
16
16
import matplotlib .testing
17
- from matplotlib import cbook
18
- from matplotlib import ft2font
19
- from matplotlib import pyplot as plt
20
- from matplotlib import ticker
17
+ from matplotlib import cbook , ft2font , pyplot as plt , ticker , _pylab_helpers
21
18
from .compare import comparable_formats , compare_images , make_test_filename
22
19
from .exceptions import ImageComparisonFailure
23
20
@@ -129,6 +126,29 @@ def remove_ticks(ax):
129
126
remove_ticks (ax )
130
127
131
128
129
+ @contextlib .contextmanager
130
+ def _collect_new_figures ():
131
+ """
132
+ After::
133
+
134
+ with _collect_new_figures() as figs:
135
+ some_code()
136
+
137
+ the list *figs* contains the figures that have been created during the
138
+ execution of ``some_code``, sorted by figure number.
139
+ """
140
+ managers = _pylab_helpers .Gcf .figs
141
+ preexisting = [manager for manager in managers .values ()]
142
+ new_figs = []
143
+ try :
144
+ yield new_figs
145
+ finally :
146
+ new_managers = sorted ([manager for manager in managers .values ()
147
+ if manager not in preexisting ],
148
+ key = lambda manager : manager .num )
149
+ new_figs [:] = [manager .canvas .figure for manager in new_managers ]
150
+
151
+
132
152
def _raise_on_image_difference (expected , actual , tol ):
133
153
__tracebackhide__ = True
134
154
@@ -178,10 +198,8 @@ def copy_baseline(self, baseline, extension):
178
198
f"{ orig_expected_path } " ) from err
179
199
return expected_fname
180
200
181
- def compare (self , idx , baseline , extension , * , _lock = False ):
201
+ def compare (self , fig , baseline , extension , * , _lock = False ):
182
202
__tracebackhide__ = True
183
- fignum = plt .get_fignums ()[idx ]
184
- fig = plt .figure (fignum )
185
203
186
204
if self .remove_text :
187
205
remove_ticks_and_titles (fig )
@@ -196,7 +214,12 @@ def compare(self, idx, baseline, extension, *, _lock=False):
196
214
lock = (cbook ._lock_path (actual_path )
197
215
if _lock else contextlib .nullcontext ())
198
216
with lock :
199
- fig .savefig (actual_path , ** kwargs )
217
+ try :
218
+ fig .savefig (actual_path , ** kwargs )
219
+ finally :
220
+ # Matplotlib has an autouse fixture to close figures, but this
221
+ # makes things more convenient for third-party users.
222
+ plt .close (fig )
200
223
expected_path = self .copy_baseline (baseline , extension )
201
224
_raise_on_image_difference (expected_path , actual_path , self .tol )
202
225
@@ -235,7 +258,9 @@ def wrapper(*args, extension, request, **kwargs):
235
258
img = _ImageComparisonBase (func , tol = tol , remove_text = remove_text ,
236
259
savefig_kwargs = savefig_kwargs )
237
260
matplotlib .testing .set_font_settings_for_testing ()
238
- func (* args , ** kwargs )
261
+
262
+ with _collect_new_figures () as figs :
263
+ func (* args , ** kwargs )
239
264
240
265
# If the test is parametrized in any way other than applied via
241
266
# this decorator, then we need to use a lock to prevent two
@@ -252,11 +277,11 @@ def wrapper(*args, extension, request, **kwargs):
252
277
our_baseline_images = request .getfixturevalue (
253
278
'baseline_images' )
254
279
255
- assert len (plt . get_fignums () ) == len (our_baseline_images ), (
280
+ assert len (figs ) == len (our_baseline_images ), (
256
281
"Test generated {} images but there are {} baseline images"
257
- .format (len (plt . get_fignums () ), len (our_baseline_images )))
258
- for idx , baseline in enumerate ( our_baseline_images ):
259
- img .compare (idx , baseline , extension , _lock = needs_lock )
282
+ .format (len (figs ), len (our_baseline_images )))
283
+ for fig , baseline in zip ( figs , our_baseline_images ):
284
+ img .compare (fig , baseline , extension , _lock = needs_lock )
260
285
261
286
parameters = list (old_sig .parameters .values ())
262
287
if 'extension' not in old_sig .parameters :
@@ -427,11 +452,9 @@ def wrapper(*args, ext, request, **kwargs):
427
452
try :
428
453
fig_test = plt .figure ("test" )
429
454
fig_ref = plt .figure ("reference" )
430
- # Keep track of number of open figures, to make sure test
431
- # doesn't create any new ones
432
- n_figs = len (plt .get_fignums ())
433
- func (* args , fig_test = fig_test , fig_ref = fig_ref , ** kwargs )
434
- if len (plt .get_fignums ()) > n_figs :
455
+ with _collect_new_figures () as figs :
456
+ func (* args , fig_test = fig_test , fig_ref = fig_ref , ** kwargs )
457
+ if figs :
435
458
raise RuntimeError ('Number of open figures changed during '
436
459
'test. Make sure you are plotting to '
437
460
'fig_test or fig_ref, or if this is '
0 commit comments