Skip to content

Commit b35f0bc

Browse files
committed
Make image_comparison work even without the autoclose fixture.
... by capturing only the images created during the execution of the test, and closing them after the comparison.
1 parent 66687d0 commit b35f0bc

File tree

1 file changed

+40
-18
lines changed

1 file changed

+40
-18
lines changed

lib/matplotlib/testing/decorators.py

+40-18
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,7 @@
1414
import matplotlib.style
1515
import matplotlib.units
1616
import matplotlib.testing
17-
from matplotlib import cbook
18-
from matplotlib import ft2font
19-
from matplotlib import pyplot as plt
20-
from matplotlib import ticker
17+
from matplotlib import cbook, ft2font, pyplot as plt, ticker, _pylab_helpers
2118
from .compare import comparable_formats, compare_images, make_test_filename
2219
from .exceptions import ImageComparisonFailure
2320

@@ -129,6 +126,28 @@ def remove_ticks(ax):
129126
remove_ticks(ax)
130127

131128

129+
@contextlib.contextmanager
130+
def _collect_new_figures():
131+
"""
132+
After::
133+
134+
with _collect_new_figures() as figs: some_code()
135+
136+
*figs* contains the figures that have been created during the execution of
137+
``some_code``, sorted by figure number.
138+
"""
139+
managers = _pylab_helpers.Gcf.figs
140+
preexisting = [manager for manager in managers.values()]
141+
new_figs = []
142+
try:
143+
yield new_figs
144+
finally:
145+
new_managers = sorted([manager for manager in managers.values()
146+
if manager not in preexisting],
147+
key=lambda manager: manager.num)
148+
new_figs[:] = [manager.canvas.figure for manager in new_managers]
149+
150+
132151
def _raise_on_image_difference(expected, actual, tol):
133152
__tracebackhide__ = True
134153

@@ -178,10 +197,8 @@ def copy_baseline(self, baseline, extension):
178197
f"{orig_expected_path}") from err
179198
return expected_fname
180199

181-
def compare(self, idx, baseline, extension, *, _lock=False):
200+
def compare(self, fig, baseline, extension, *, _lock=False):
182201
__tracebackhide__ = True
183-
fignum = plt.get_fignums()[idx]
184-
fig = plt.figure(fignum)
185202

186203
if self.remove_text:
187204
remove_ticks_and_titles(fig)
@@ -196,7 +213,12 @@ def compare(self, idx, baseline, extension, *, _lock=False):
196213
lock = (cbook._lock_path(actual_path)
197214
if _lock else contextlib.nullcontext())
198215
with lock:
199-
fig.savefig(actual_path, **kwargs)
216+
try:
217+
fig.savefig(actual_path, **kwargs)
218+
finally:
219+
# Matplotlib has an autouse fixture to close figures, but this
220+
# makes things more convenient for third-party users.
221+
plt.close(fig)
200222
expected_path = self.copy_baseline(baseline, extension)
201223
_raise_on_image_difference(expected_path, actual_path, self.tol)
202224

@@ -235,7 +257,9 @@ def wrapper(*args, extension, request, **kwargs):
235257
img = _ImageComparisonBase(func, tol=tol, remove_text=remove_text,
236258
savefig_kwargs=savefig_kwargs)
237259
matplotlib.testing.set_font_settings_for_testing()
238-
func(*args, **kwargs)
260+
261+
with _collect_new_figures() as figs:
262+
func(*args, **kwargs)
239263

240264
# If the test is parametrized in any way other than applied via
241265
# this decorator, then we need to use a lock to prevent two
@@ -252,11 +276,11 @@ def wrapper(*args, extension, request, **kwargs):
252276
our_baseline_images = request.getfixturevalue(
253277
'baseline_images')
254278

255-
assert len(plt.get_fignums()) == len(our_baseline_images), (
279+
assert len(figs) == len(our_baseline_images), (
256280
"Test generated {} images but there are {} baseline images"
257-
.format(len(plt.get_fignums()), len(our_baseline_images)))
258-
for idx, baseline in enumerate(our_baseline_images):
259-
img.compare(idx, baseline, extension, _lock=needs_lock)
281+
.format(len(figs), len(our_baseline_images)))
282+
for fig, baseline in zip(figs, our_baseline_images):
283+
img.compare(fig, baseline, extension, _lock=needs_lock)
260284

261285
parameters = list(old_sig.parameters.values())
262286
if 'extension' not in old_sig.parameters:
@@ -427,11 +451,9 @@ def wrapper(*args, ext, request, **kwargs):
427451
try:
428452
fig_test = plt.figure("test")
429453
fig_ref = plt.figure("reference")
430-
# Keep track of number of open figures, to make sure test
431-
# doesn't create any new ones
432-
n_figs = len(plt.get_fignums())
433-
func(*args, fig_test=fig_test, fig_ref=fig_ref, **kwargs)
434-
if len(plt.get_fignums()) > n_figs:
454+
with _collect_new_figures() as figs:
455+
func(*args, fig_test=fig_test, fig_ref=fig_ref, **kwargs)
456+
if figs:
435457
raise RuntimeError('Number of open figures changed during '
436458
'test. Make sure you are plotting to '
437459
'fig_test or fig_ref, or if this is '

0 commit comments

Comments
 (0)