Skip to content

Commit cdecca3

Browse files
committed
Templatize class factories.
This makes mpl_toolkits axes classes picklable (see test_pickle) by generalizing the machinery of _picklable_subplot_class_constructor, which would otherwise have had to be reimplemented for each class factory.
1 parent ca6e9dc commit cdecca3

File tree

7 files changed

+76
-101
lines changed

7 files changed

+76
-101
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
The private ``matplotlib.axes._subplots._subplot_classes`` dict has been removed
2+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
3+
4+
Support for passing ``None`` to ``subplot_class_factory`` is deprecated
5+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
6+
Explicitly pass in the base `~matplotlib.axes.Axes` class instead.

lib/matplotlib/axes/_subplots.py

+3-61
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
import functools
2-
3-
from matplotlib import _api
1+
from matplotlib import _api, cbook
42
from matplotlib.axes._axes import Axes
53
from matplotlib.gridspec import GridSpec, SubplotSpec
64

@@ -36,15 +34,6 @@ def __init__(self, fig, *args, **kwargs):
3634
# This will also update the axes position.
3735
self.set_subplotspec(SubplotSpec._from_subplot_args(fig, args))
3836

39-
def __reduce__(self):
40-
# get the first axes class which does not inherit from a subplotbase
41-
axes_class = next(
42-
c for c in type(self).__mro__
43-
if issubclass(c, Axes) and not issubclass(c, SubplotBase))
44-
return (_picklable_subplot_class_constructor,
45-
(axes_class,),
46-
self.__getstate__())
47-
4837
@_api.deprecated(
4938
"3.4", alternative="get_subplotspec",
5039
addendum="(get_subplotspec returns a SubplotSpec instance.)")
@@ -169,53 +158,6 @@ def _make_twin_axes(self, *args, **kwargs):
169158
return twin
170159

171160

172-
# this here to support cartopy which was using a private part of the
173-
# API to register their Axes subclasses.
174-
175-
# In 3.1 this should be changed to a dict subclass that warns on use
176-
# In 3.3 to a dict subclass that raises a useful exception on use
177-
# In 3.4 should be removed
178-
179-
# The slow timeline is to give cartopy enough time to get several
180-
# release out before we break them.
181-
_subplot_classes = {}
182-
183-
184-
@functools.lru_cache(None)
185-
def subplot_class_factory(axes_class=None):
186-
"""
187-
Make a new class that inherits from `.SubplotBase` and the
188-
given axes_class (which is assumed to be a subclass of `.axes.Axes`).
189-
This is perhaps a little bit roundabout to make a new class on
190-
the fly like this, but it means that a new Subplot class does
191-
not have to be created for every type of Axes.
192-
"""
193-
if axes_class is None:
194-
_api.warn_deprecated(
195-
"3.3", message="Support for passing None to subplot_class_factory "
196-
"is deprecated since %(since)s; explicitly pass the default Axes "
197-
"class instead. This will become an error %(removal)s.")
198-
axes_class = Axes
199-
try:
200-
# Avoid creating two different instances of GeoAxesSubplot...
201-
# Only a temporary backcompat fix. This should be removed in
202-
# 3.4
203-
return next(cls for cls in SubplotBase.__subclasses__()
204-
if cls.__bases__ == (SubplotBase, axes_class))
205-
except StopIteration:
206-
return type("%sSubplot" % axes_class.__name__,
207-
(SubplotBase, axes_class),
208-
{'_axes_class': axes_class})
209-
210-
161+
subplot_class_factory = cbook._make_class_factory(
162+
SubplotBase, "{}Subplot", "_axes_class")
211163
Subplot = subplot_class_factory(Axes) # Provided for backward compatibility.
212-
213-
214-
def _picklable_subplot_class_constructor(axes_class):
215-
"""
216-
Stub factory that returns an empty instance of the appropriate subplot
217-
class when called with an axes class. This is purely to allow pickling of
218-
Axes and Subplots.
219-
"""
220-
subplot_class = subplot_class_factory(axes_class)
221-
return subplot_class.__new__(subplot_class)

lib/matplotlib/cbook/__init__.py

+50
Original file line numberDiff line numberDiff line change
@@ -2188,3 +2188,53 @@ def _unikey_or_keysym_to_mplkey(unikey, keysym):
21882188
"next": "pagedown", # Used by tk.
21892189
}.get(key, key)
21902190
return key
2191+
2192+
2193+
@functools.lru_cache(None)
2194+
def _make_class_factory(mixin_class, fmt, attr_name=None):
2195+
"""
2196+
Return a function that creates picklable classes inheriting from a mixin.
2197+
2198+
After ::
2199+
2200+
factory = _make_class_factory(FooMixin, fmt, attr_name)
2201+
FooAxes = factory(Axes)
2202+
2203+
``Foo`` is a class that inherits from ``FooMixin`` and ``Axes`` and **is
2204+
picklable** (picklability is what differentiates this from a plain call to
2205+
`type`). Its ``__name__`` is set to ``fmt.format(Axes.__name__)`` and the
2206+
base class is stored in the ``attr_name`` attribute, if not None.
2207+
2208+
Moreover, the return value of ``factory`` is memoized: calls with the same
2209+
``Axes`` class always return the same subclass.
2210+
"""
2211+
2212+
@functools.lru_cache(None)
2213+
def class_factory(axes_class):
2214+
# The parameter is named "axes_class" for backcompat but is really just
2215+
# a base class; no axes semantics are used.
2216+
base_class = axes_class
2217+
2218+
class subcls(mixin_class, base_class):
2219+
# Better approximation than __module__ = "matplotlib.cbook".
2220+
__module__ = mixin_class.__module__
2221+
2222+
def __reduce__(self):
2223+
return (_picklable_class_constructor,
2224+
(mixin_class, fmt, attr_name, base_class),
2225+
self.__getstate__())
2226+
2227+
subcls.__name__ = subcls.__qualname__ = fmt.format(base_class.__name__)
2228+
if attr_name is not None:
2229+
setattr(subcls, attr_name, base_class)
2230+
return subcls
2231+
2232+
class_factory.__module__ = mixin_class.__module__
2233+
return class_factory
2234+
2235+
2236+
def _picklable_class_constructor(mixin_class, fmt, attr_name, base_class):
2237+
"""Internal helper for _make_class_factory."""
2238+
factory = _make_class_factory(mixin_class, fmt, attr_name)
2239+
cls = factory(base_class)
2240+
return cls.__new__(cls)

lib/matplotlib/tests/test_axes.py

-15
Original file line numberDiff line numberDiff line change
@@ -6340,21 +6340,6 @@ def test_spines_properbbox_after_zoom():
63406340
np.testing.assert_allclose(bb.get_points(), bb2.get_points(), rtol=1e-6)
63416341

63426342

6343-
def test_cartopy_backcompat():
6344-
6345-
class Dummy(matplotlib.axes.Axes):
6346-
...
6347-
6348-
class DummySubplot(matplotlib.axes.SubplotBase, Dummy):
6349-
_axes_class = Dummy
6350-
6351-
matplotlib.axes._subplots._subplot_classes[Dummy] = DummySubplot
6352-
6353-
FactoryDummySubplot = matplotlib.axes.subplot_class_factory(Dummy)
6354-
6355-
assert DummySubplot is FactoryDummySubplot
6356-
6357-
63586343
def test_gettightbbox_ignore_nan():
63596344
fig, ax = plt.subplots()
63606345
remove_ticks_and_titles(fig)

lib/matplotlib/tests/test_pickle.py

+6
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import matplotlib.pyplot as plt
1111
import matplotlib.transforms as mtransforms
1212
import matplotlib.figure as mfigure
13+
from mpl_toolkits.axes_grid1 import parasite_axes
1314

1415

1516
def test_simple():
@@ -212,3 +213,8 @@ def test_unpickle_canvas():
212213
out.seek(0)
213214
fig2 = pickle.load(out)
214215
assert fig2.canvas is not None
216+
217+
218+
def test_mpl_toolkits():
219+
ax = parasite_axes.host_axes([0, 0, 1, 1])
220+
assert type(pickle.loads(pickle.dumps(ax))) == parasite_axes.HostAxes

lib/mpl_toolkits/axes_grid1/parasite_axes.py

+8-17
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import functools
22

3-
from matplotlib import _api
3+
from matplotlib import _api, cbook
44
import matplotlib.artist as martist
55
import matplotlib.image as mimage
66
import matplotlib.transforms as mtransforms
@@ -95,12 +95,8 @@ def apply_aspect(self, position=None):
9595
# end of aux_transform support
9696

9797

98-
@functools.lru_cache(None)
99-
def parasite_axes_class_factory(axes_class):
100-
return type("%sParasite" % axes_class.__name__,
101-
(ParasiteAxesBase, axes_class), {})
102-
103-
98+
parasite_axes_class_factory = cbook._make_class_factory(
99+
ParasiteAxesBase, "{}Parasite")
104100
ParasiteAxes = parasite_axes_class_factory(Axes)
105101

106102

@@ -277,7 +273,7 @@ def _add_twin_axes(self, axes_class, **kwargs):
277273
*kwargs* are forwarded to the parasite axes constructor.
278274
"""
279275
if axes_class is None:
280-
axes_class = self._get_base_axes()
276+
axes_class = self._base_axes_class
281277
ax = parasite_axes_class_factory(axes_class)(self, **kwargs)
282278
self.parasites.append(ax)
283279
ax._remove_method = self._remove_any_twin
@@ -304,11 +300,10 @@ def get_tightbbox(self, renderer, call_axes_locator=True,
304300
return Bbox.union([b for b in bbs if b.width != 0 or b.height != 0])
305301

306302

307-
@functools.lru_cache(None)
308-
def host_axes_class_factory(axes_class):
309-
return type("%sHostAxes" % axes_class.__name__,
310-
(HostAxesBase, axes_class),
311-
{'_get_base_axes': lambda self: axes_class})
303+
host_axes_class_factory = cbook._make_class_factory(
304+
HostAxesBase, "{}HostAxes", "_base_axes_class")
305+
HostAxes = host_axes_class_factory(Axes)
306+
SubplotHost = subplot_class_factory(HostAxes)
312307

313308

314309
def host_subplot_class_factory(axes_class):
@@ -317,10 +312,6 @@ def host_subplot_class_factory(axes_class):
317312
return subplot_host_class
318313

319314

320-
HostAxes = host_axes_class_factory(Axes)
321-
SubplotHost = subplot_class_factory(HostAxes)
322-
323-
324315
def host_axes(*args, axes_class=Axes, figure=None, **kwargs):
325316
"""
326317
Create axes that can act as a hosts to parasitic axes.

lib/mpl_toolkits/axisartist/floating_axes.py

+3-8
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,9 @@
55
# TODO :
66
# see if tick_iterator method can be simplified by reusing the parent method.
77

8-
import functools
9-
108
import numpy as np
119

10+
from matplotlib import cbook
1211
import matplotlib.patches as mpatches
1312
from matplotlib.path import Path
1413
import matplotlib.axes as maxes
@@ -351,12 +350,8 @@ def adjust_axes_lim(self):
351350
self.set_ylim(ymin-dy, ymax+dy)
352351

353352

354-
@functools.lru_cache(None)
355-
def floatingaxes_class_factory(axes_class):
356-
return type("Floating %s" % axes_class.__name__,
357-
(FloatingAxesBase, axes_class), {})
358-
359-
353+
floatingaxes_class_factory = cbook._make_class_factory(
354+
FloatingAxesBase, "Floating {}")
360355
FloatingAxes = floatingaxes_class_factory(
361356
host_axes_class_factory(axislines.Axes))
362357
FloatingSubplot = maxes.subplot_class_factory(FloatingAxes)

0 commit comments

Comments
 (0)