Skip to content

Commit 8d1a031

Browse files
committed
Templatize class factories.
This makes mpl_toolkits axes classes picklable (see test_pickle) by generalizing the machinery of _picklable_subplot_class_constructor, which would otherwise have had to be reimplemented for each class factory.
1 parent 876415d commit 8d1a031

File tree

7 files changed

+72
-103
lines changed

7 files changed

+72
-103
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
The private ``matplotlib.axes._subplots._subplot_classes`` dict has been removed
2+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

lib/matplotlib/axes/_subplots.py

+3-63
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import functools
2-
31
from matplotlib import _api, docstring
42
import matplotlib.artist as martist
53
from matplotlib.axes._axes import Axes
@@ -37,15 +35,6 @@ def __init__(self, fig, *args, **kwargs):
3735
# This will also update the axes position.
3836
self.set_subplotspec(SubplotSpec._from_subplot_args(fig, args))
3937

40-
def __reduce__(self):
41-
# get the first axes class which does not inherit from a subplotbase
42-
axes_class = next(
43-
c for c in type(self).__mro__
44-
if issubclass(c, Axes) and not issubclass(c, SubplotBase))
45-
return (_picklable_subplot_class_constructor,
46-
(axes_class,),
47-
self.__getstate__())
48-
4938
@_api.deprecated(
5039
"3.4", alternative="get_subplotspec",
5140
addendum="(get_subplotspec returns a SubplotSpec instance.)")
@@ -170,59 +159,10 @@ def _make_twin_axes(self, *args, **kwargs):
170159
return twin
171160

172161

173-
# this here to support cartopy which was using a private part of the
174-
# API to register their Axes subclasses.
175-
176-
# In 3.1 this should be changed to a dict subclass that warns on use
177-
# In 3.3 to a dict subclass that raises a useful exception on use
178-
# In 3.4 should be removed
179-
180-
# The slow timeline is to give cartopy enough time to get several
181-
# release out before we break them.
182-
_subplot_classes = {}
183-
184-
185-
@functools.lru_cache(None)
186-
def subplot_class_factory(axes_class=None):
187-
"""
188-
Make a new class that inherits from `.SubplotBase` and the
189-
given axes_class (which is assumed to be a subclass of `.axes.Axes`).
190-
This is perhaps a little bit roundabout to make a new class on
191-
the fly like this, but it means that a new Subplot class does
192-
not have to be created for every type of Axes.
193-
"""
194-
if axes_class is None:
195-
_api.warn_deprecated(
196-
"3.3", message="Support for passing None to subplot_class_factory "
197-
"is deprecated since %(since)s; explicitly pass the default Axes "
198-
"class instead. This will become an error %(removal)s.")
199-
axes_class = Axes
200-
try:
201-
# Avoid creating two different instances of GeoAxesSubplot...
202-
# Only a temporary backcompat fix. This should be removed in
203-
# 3.4
204-
return next(cls for cls in SubplotBase.__subclasses__()
205-
if cls.__bases__ == (SubplotBase, axes_class))
206-
except StopIteration:
207-
return type("%sSubplot" % axes_class.__name__,
208-
(SubplotBase, axes_class),
209-
{'_axes_class': axes_class})
210-
211-
162+
subplot_class_factory = cbook._make_class_factory(
163+
SubplotBase, "{}Subplot", "_axes_class", default_axes_class=Axes)
212164
Subplot = subplot_class_factory(Axes) # Provided for backward compatibility.
213165

214-
215-
def _picklable_subplot_class_constructor(axes_class):
216-
"""
217-
Stub factory that returns an empty instance of the appropriate subplot
218-
class when called with an axes class. This is purely to allow pickling of
219-
Axes and Subplots.
220-
"""
221-
subplot_class = subplot_class_factory(axes_class)
222-
return subplot_class.__new__(subplot_class)
223-
224-
225166
docstring.interpd.update(Axes_kwdoc=martist.kwdoc(Axes))
226-
docstring.dedent_interpd(Axes.__init__)
227-
228167
docstring.interpd.update(Subplot_kwdoc=martist.kwdoc(Axes))
168+
docstring.dedent_interpd(Axes.__init__)

lib/matplotlib/cbook/__init__.py

+51
Original file line numberDiff line numberDiff line change
@@ -2298,3 +2298,54 @@ def _unikey_or_keysym_to_mplkey(unikey, keysym):
22982298
"next": "pagedown", # Used by tk.
22992299
}.get(key, key)
23002300
return key
2301+
2302+
2303+
@functools.lru_cache(None)
2304+
def _make_class_factory(
2305+
mixin_class, fmt, axes_attr=None, *, default_axes_class=None):
2306+
"""
2307+
Return a function that creates picklable classes inheriting from a mixin.
2308+
2309+
After ::
2310+
2311+
factory = _make_class_factory(FooMixin, fmt, axes_attr)
2312+
FooAxes = factory(Axes)
2313+
2314+
``Foo`` is a class that inherits from ``FooMixin`` and ``Axes`` and **is
2315+
picklable** (picklability is what differentiates this from a plain call to
2316+
`type`). Its ``__name__`` is set to ``fmt.format(Axes.__name__)`` and the
2317+
base class is stored in the ``axes_attr`` attribute, if not None.
2318+
"""
2319+
2320+
@functools.lru_cache(None)
2321+
def class_factory(axes_class):
2322+
# default_axes_class should go away once the deprecation elapses.
2323+
if axes_class is None and default_axes_class is not None:
2324+
warn_deprecated(
2325+
"3.3", message="Support for passing None to class factories "
2326+
"is deprecated since %(since)s and will be removed "
2327+
"%(removal)s; explicitly pass the default Axes class instead.")
2328+
return class_factory(default_axes_class)
2329+
d = {"__reduce__":
2330+
lambda self: (_picklable_class_constructor,
2331+
(mixin_class, fmt, axes_attr,
2332+
default_axes_class, axes_class,),
2333+
self.__getstate__())}
2334+
if axes_attr is not None:
2335+
d[axes_attr] = axes_class
2336+
cls = type(
2337+
fmt.format(axes_class.__name__), (mixin_class, axes_class), d)
2338+
# Better in first approximation than __module__ = "matplotlib.cbook"...
2339+
cls.__module__ = mixin_class.__module__
2340+
return cls
2341+
2342+
return class_factory
2343+
2344+
2345+
def _picklable_class_constructor(
2346+
base_cls, fmt, axes_attr, default_axes_class, axes_class):
2347+
"""Internal helper for _make_class_factory."""
2348+
cls = _make_class_factory(
2349+
base_cls, fmt, axes_attr,
2350+
default_axes_class=default_axes_class)(axes_class)
2351+
return cls.__new__(cls)

lib/matplotlib/tests/test_axes.py

-15
Original file line numberDiff line numberDiff line change
@@ -6325,21 +6325,6 @@ def test_spines_properbbox_after_zoom():
63256325
np.testing.assert_allclose(bb.get_points(), bb2.get_points(), rtol=1e-6)
63266326

63276327

6328-
def test_cartopy_backcompat():
6329-
6330-
class Dummy(matplotlib.axes.Axes):
6331-
...
6332-
6333-
class DummySubplot(matplotlib.axes.SubplotBase, Dummy):
6334-
_axes_class = Dummy
6335-
6336-
matplotlib.axes._subplots._subplot_classes[Dummy] = DummySubplot
6337-
6338-
FactoryDummySubplot = matplotlib.axes.subplot_class_factory(Dummy)
6339-
6340-
assert DummySubplot is FactoryDummySubplot
6341-
6342-
63436328
def test_gettightbbox_ignore_nan():
63446329
fig, ax = plt.subplots()
63456330
remove_ticks_and_titles(fig)

lib/matplotlib/tests/test_pickle.py

+6
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import matplotlib.pyplot as plt
1111
import matplotlib.transforms as mtransforms
1212
import matplotlib.figure as mfigure
13+
from mpl_toolkits.axes_grid1 import parasite_axes
1314

1415

1516
def test_simple():
@@ -212,3 +213,8 @@ def test_unpickle_canvas():
212213
out.seek(0)
213214
fig2 = pickle.load(out)
214215
assert fig2.canvas is not None
216+
217+
218+
def test_mpl_toolkits():
219+
ax = parasite_axes.host_axes([0, 0, 1, 1])
220+
assert type(pickle.loads(pickle.dumps(ax))) == parasite_axes.HostAxes

lib/mpl_toolkits/axes_grid1/parasite_axes.py

+7-16
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,8 @@ def apply_aspect(self, position=None):
9494
# end of aux_transform support
9595

9696

97-
@functools.lru_cache(None)
98-
def parasite_axes_class_factory(axes_class):
99-
return type("%sParasite" % axes_class.__name__,
100-
(ParasiteAxesBase, axes_class), {})
101-
102-
97+
parasite_axes_class_factory = cbook._make_class_factory(
98+
ParasiteAxesBase, "{}Parasite")
10399
ParasiteAxes = parasite_axes_class_factory(Axes)
104100

105101

@@ -283,7 +279,7 @@ def _add_twin_axes(self, axes_class, **kwargs):
283279
*kwargs* are forwarded to the parasite axes constructor.
284280
"""
285281
if axes_class is None:
286-
axes_class = self._get_base_axes()
282+
axes_class = self._base_axes_class
287283
ax = parasite_axes_class_factory(axes_class)(self, **kwargs)
288284
self.parasites.append(ax)
289285
ax._remove_method = self._remove_any_twin
@@ -310,11 +306,10 @@ def get_tightbbox(self, renderer, call_axes_locator=True,
310306
return Bbox.union([b for b in bbs if b.width != 0 or b.height != 0])
311307

312308

313-
@functools.lru_cache(None)
314-
def host_axes_class_factory(axes_class):
315-
return type("%sHostAxes" % axes_class.__name__,
316-
(HostAxesBase, axes_class),
317-
{'_get_base_axes': lambda self: axes_class})
309+
host_axes_class_factory = cbook._make_class_factory(
310+
HostAxesBase, "{}HostAxes", "_base_axes_class")
311+
HostAxes = host_axes_class_factory(Axes)
312+
SubplotHost = subplot_class_factory(HostAxes)
318313

319314

320315
def host_subplot_class_factory(axes_class):
@@ -323,10 +318,6 @@ def host_subplot_class_factory(axes_class):
323318
return subplot_host_class
324319

325320

326-
HostAxes = host_axes_class_factory(Axes)
327-
SubplotHost = subplot_class_factory(HostAxes)
328-
329-
330321
def host_axes(*args, axes_class=Axes, figure=None, **kwargs):
331322
"""
332323
Create axes that can act as a hosts to parasitic axes.

lib/mpl_toolkits/axisartist/floating_axes.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,9 @@
55
# TODO :
66
# see if tick_iterator method can be simplified by reusing the parent method.
77

8-
import functools
9-
108
import numpy as np
119

10+
from matplotlib import cbook
1211
import matplotlib.patches as mpatches
1312
from matplotlib.path import Path
1413
import matplotlib.axes as maxes
@@ -360,13 +359,8 @@ def adjust_axes_lim(self):
360359
self.set_ylim(ymin-dy, ymax+dy)
361360

362361

363-
@functools.lru_cache(None)
364-
def floatingaxes_class_factory(axes_class):
365-
return type("Floating %s" % axes_class.__name__,
366-
(FloatingAxesBase, axes_class),
367-
{'_axes_class_floating': axes_class})
368-
369-
362+
floatingaxes_class_factory = cbook._make_class_factory(
363+
FloatingAxesBase, "Floating {}", "_axes_class_floating")
370364
FloatingAxes = floatingaxes_class_factory(
371365
host_axes_class_factory(axislines.Axes))
372366
FloatingSubplot = maxes.subplot_class_factory(FloatingAxes)

0 commit comments

Comments
 (0)