Skip to content

Commit b3e576a

Browse files
committed
Merge pull request #5809 from anntzer/cleanup-generative-tests
Support generative tests in @cleanup.
2 parents 3675841 + b441be2 commit b3e576a

File tree

2 files changed

+26
-12
lines changed

2 files changed

+26
-12
lines changed

lib/matplotlib/testing/decorators.py

+25-11
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import functools
77
import gc
8+
import inspect
89
import os
910
import sys
1011
import shutil
@@ -129,18 +130,31 @@ def cleanup(style=None):
129130
# writing a decorator with optional arguments.
130131

131132
def make_cleanup(func):
132-
@functools.wraps(func)
133-
def wrapped_function(*args, **kwargs):
134-
original_units_registry = matplotlib.units.registry.copy()
135-
original_settings = mpl.rcParams.copy()
136-
matplotlib.style.use(style)
137-
try:
138-
func(*args, **kwargs)
139-
finally:
140-
_do_cleanup(original_units_registry,
141-
original_settings)
133+
if inspect.isgenerator(func):
134+
@functools.wraps(func)
135+
def wrapped_callable(*args, **kwargs):
136+
original_units_registry = matplotlib.units.registry.copy()
137+
original_settings = mpl.rcParams.copy()
138+
matplotlib.style.use(style)
139+
try:
140+
for yielded in func(*args, **kwargs):
141+
yield yielded
142+
finally:
143+
_do_cleanup(original_units_registry,
144+
original_settings)
145+
else:
146+
@functools.wraps(func)
147+
def wrapped_callable(*args, **kwargs):
148+
original_units_registry = matplotlib.units.registry.copy()
149+
original_settings = mpl.rcParams.copy()
150+
matplotlib.style.use(style)
151+
try:
152+
func(*args, **kwargs)
153+
finally:
154+
_do_cleanup(original_units_registry,
155+
original_settings)
142156

143-
return wrapped_function
157+
return wrapped_callable
144158

145159
if isinstance(style, six.string_types):
146160
return make_cleanup

lib/matplotlib/tests/test_axes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4278,7 +4278,7 @@ def _helper_y(ax):
42784278
orig_xlim = ax_lst[0][1].get_xlim()
42794279
ax.remove()
42804280
ax.set_xlim(0, 5)
4281-
assert assert_array_equal(ax_lst[0][1].get_xlim(), orig_xlim)
4281+
assert_array_equal(ax_lst[0][1].get_xlim(), orig_xlim)
42824282

42834283

42844284
@cleanup

0 commit comments

Comments
 (0)