Skip to content

Commit fcf3f72

Browse files
committed
Make image_comparison work even without the autoclose fixture.
... by capturing only the images created during the execution of the test, and closing them after the comparison.
1 parent 66687d0 commit fcf3f72

File tree

1 file changed

+41
-18
lines changed

1 file changed

+41
-18
lines changed

lib/matplotlib/testing/decorators.py

+41-18
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,7 @@
1414
import matplotlib.style
1515
import matplotlib.units
1616
import matplotlib.testing
17-
from matplotlib import cbook
18-
from matplotlib import ft2font
19-
from matplotlib import pyplot as plt
20-
from matplotlib import ticker
17+
from matplotlib import cbook, ft2font, pyplot as plt, ticker, _pylab_helpers
2118
from .compare import comparable_formats, compare_images, make_test_filename
2219
from .exceptions import ImageComparisonFailure
2320

@@ -129,6 +126,29 @@ def remove_ticks(ax):
129126
remove_ticks(ax)
130127

131128

129+
@contextlib.contextmanager
130+
def _collect_new_figures():
131+
"""
132+
After::
133+
134+
with _collect_new_figures() as figs:
135+
some_code()
136+
137+
the list *figs* contains the figures that have been created during the
138+
execution of ``some_code``, sorted by figure number.
139+
"""
140+
managers = _pylab_helpers.Gcf.figs
141+
preexisting = [manager for manager in managers.values()]
142+
new_figs = []
143+
try:
144+
yield new_figs
145+
finally:
146+
new_managers = sorted([manager for manager in managers.values()
147+
if manager not in preexisting],
148+
key=lambda manager: manager.num)
149+
new_figs[:] = [manager.canvas.figure for manager in new_managers]
150+
151+
132152
def _raise_on_image_difference(expected, actual, tol):
133153
__tracebackhide__ = True
134154

@@ -178,10 +198,8 @@ def copy_baseline(self, baseline, extension):
178198
f"{orig_expected_path}") from err
179199
return expected_fname
180200

181-
def compare(self, idx, baseline, extension, *, _lock=False):
201+
def compare(self, fig, baseline, extension, *, _lock=False):
182202
__tracebackhide__ = True
183-
fignum = plt.get_fignums()[idx]
184-
fig = plt.figure(fignum)
185203

186204
if self.remove_text:
187205
remove_ticks_and_titles(fig)
@@ -196,7 +214,12 @@ def compare(self, idx, baseline, extension, *, _lock=False):
196214
lock = (cbook._lock_path(actual_path)
197215
if _lock else contextlib.nullcontext())
198216
with lock:
199-
fig.savefig(actual_path, **kwargs)
217+
try:
218+
fig.savefig(actual_path, **kwargs)
219+
finally:
220+
# Matplotlib has an autouse fixture to close figures, but this
221+
# makes things more convenient for third-party users.
222+
plt.close(fig)
200223
expected_path = self.copy_baseline(baseline, extension)
201224
_raise_on_image_difference(expected_path, actual_path, self.tol)
202225

@@ -235,7 +258,9 @@ def wrapper(*args, extension, request, **kwargs):
235258
img = _ImageComparisonBase(func, tol=tol, remove_text=remove_text,
236259
savefig_kwargs=savefig_kwargs)
237260
matplotlib.testing.set_font_settings_for_testing()
238-
func(*args, **kwargs)
261+
262+
with _collect_new_figures() as figs:
263+
func(*args, **kwargs)
239264

240265
# If the test is parametrized in any way other than applied via
241266
# this decorator, then we need to use a lock to prevent two
@@ -252,11 +277,11 @@ def wrapper(*args, extension, request, **kwargs):
252277
our_baseline_images = request.getfixturevalue(
253278
'baseline_images')
254279

255-
assert len(plt.get_fignums()) == len(our_baseline_images), (
280+
assert len(figs) == len(our_baseline_images), (
256281
"Test generated {} images but there are {} baseline images"
257-
.format(len(plt.get_fignums()), len(our_baseline_images)))
258-
for idx, baseline in enumerate(our_baseline_images):
259-
img.compare(idx, baseline, extension, _lock=needs_lock)
282+
.format(len(figs), len(our_baseline_images)))
283+
for fig, baseline in zip(figs, our_baseline_images):
284+
img.compare(fig, baseline, extension, _lock=needs_lock)
260285

261286
parameters = list(old_sig.parameters.values())
262287
if 'extension' not in old_sig.parameters:
@@ -427,11 +452,9 @@ def wrapper(*args, ext, request, **kwargs):
427452
try:
428453
fig_test = plt.figure("test")
429454
fig_ref = plt.figure("reference")
430-
# Keep track of number of open figures, to make sure test
431-
# doesn't create any new ones
432-
n_figs = len(plt.get_fignums())
433-
func(*args, fig_test=fig_test, fig_ref=fig_ref, **kwargs)
434-
if len(plt.get_fignums()) > n_figs:
455+
with _collect_new_figures() as figs:
456+
func(*args, fig_test=fig_test, fig_ref=fig_ref, **kwargs)
457+
if figs:
435458
raise RuntimeError('Number of open figures changed during '
436459
'test. Make sure you are plotting to '
437460
'fig_test or fig_ref, or if this is '

0 commit comments

Comments
 (0)