Skip to content

Commit 44a27d2

Browse files
committed
subplots() returns AxesArray
subplots() used to return a numpy array of Axes, which has some drawbacks. The numpy array is mainly used as a 2D container structure that allows 2D indexing. Apart from that, it's not particularly well suited: - Many of the numpy functions do not work on Axes. - Some functions work, but have awkward semantics; e.g. len() gives the number of rows. - We can't add our own functionality. AxesArray introduces a facade to the underlying array to allow us to customize the API. For the beginning, the API is 100% compatible with the previous numpy array behavior, but we deprecate everything except for a few reasonable methods.
1 parent a7b47f4 commit 44a27d2

File tree

2 files changed

+147
-2
lines changed

2 files changed

+147
-2
lines changed

lib/matplotlib/gridspec.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,10 +309,10 @@ def subplots(self, *, sharex=False, sharey=False, squeeze=True,
309309
if squeeze:
310310
# Discarding unneeded dimensions that equal 1. If we only have one
311311
# subplot, just return it instead of a 1-element array.
312-
return axarr.item() if axarr.size == 1 else axarr.squeeze()
312+
return axarr.item() if axarr.size == 1 else AxesArray(axarr.squeeze())
313313
else:
314314
# Returned axis array will be always 2-d, even if nrows=ncols=1.
315-
return axarr
315+
return AxesArray(axarr)
316316

317317

318318
class GridSpec(GridSpecBase):
@@ -736,3 +736,70 @@ def subgridspec(self, nrows, ncols, **kwargs):
736736
fig.add_subplot(gssub[0, i])
737737
"""
738738
return GridSpecFromSubplotSpec(nrows, ncols, self, **kwargs)
739+
740+
741+
class AxesArray:
742+
"""
743+
A container for a 1D or 2D grid arrangement of Axes.
744+
745+
This is used as the return type of ``subplots()``.
746+
747+
Formerly, ``subplots()`` returned a numpy array of Axes. For a transition period,
748+
AxesArray will act like a numpy array, but all functions and properties that
749+
are not listed explicitly below are deprecated.
750+
"""
751+
def __init__(self, array):
752+
self._array = array
753+
754+
@staticmethod
755+
def _ensure_wrapped(ax_or_axs):
756+
if isinstance(ax_or_axs, np.ndarray):
757+
return AxesArray(ax_or_axs)
758+
else:
759+
return ax_or_axs
760+
761+
def __getitem__(self, index):
762+
return self._ensure_wrapped(self._array[index])
763+
764+
@property
765+
def __array_struct__(self):
766+
return self._array.__array_struct__
767+
768+
@property
769+
def ndim(self):
770+
return self._array.ndim
771+
772+
@property
773+
def shape(self):
774+
return self._array.shape
775+
776+
@property
777+
def size(self):
778+
return self._array.size
779+
780+
@property
781+
def flat(self):
782+
return self._array.flat
783+
784+
@property
785+
def flatten(self):
786+
"""[Disouraged] Use ``axs.flat`` instead."""
787+
return self._array.flatten
788+
789+
@property
790+
def ravel(self):
791+
"""[Disouraged] Use ``axs.flat`` instead."""
792+
return self._array.ravel
793+
794+
@property
795+
def __iter__(self):
796+
return iter([self._ensure_wrapped(row) for row in self._array])
797+
798+
def __getattr__(self, item):
799+
# forward all other attributes to the underlying array
800+
# (this is a temporary measure to allow a smooth transition)
801+
attr = getattr(self._array, item)
802+
_api.warn_deprecated("3.9",
803+
message=f"Using {item!r} on AxesArray is deprecated.",
804+
pending=True)
805+
return attr

lib/matplotlib/tests/test_subplots.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
import pytest
55

6+
import matplotlib as mpl
67
from matplotlib.axes import Axes, SubplotBase
78
import matplotlib.pyplot as plt
89
from matplotlib.testing.decorators import check_figures_equal, image_comparison
@@ -283,3 +284,80 @@ def test_old_subplot_compat():
283284
assert not isinstance(fig.add_axes(rect=[0, 0, 1, 1]), SubplotBase)
284285
with pytest.raises(TypeError):
285286
Axes(fig, [0, 0, 1, 1], rect=[0, 0, 1, 1])
287+
288+
289+
class TestAxesArray:
290+
@staticmethod
291+
def contain_same_axes(axs1, axs2):
292+
return all(ax1 is ax2 for ax1, ax2 in zip(axs1.flat, axs2.flat))
293+
294+
def test_1d(self):
295+
axs = plt.figure().subplots(1, 3)
296+
# shape and size
297+
assert axs.shape == (3,)
298+
assert axs.size == 3
299+
assert axs.ndim == 1
300+
# flat
301+
assert all(isinstance(ax, Axes) for ax in axs.flat)
302+
assert len(set(id(ax) for ax in axs.flat)) == 3
303+
# flatten
304+
assert all(isinstance(ax, Axes) for ax in axs.flatten())
305+
assert len(set(id(ax) for ax in axs.flatten())) == 3
306+
# ravel
307+
assert all(isinstance(ax, Axes) for ax in axs.ravel())
308+
assert len(set(id(ax) for ax in axs.ravel())) == 3
309+
# single index
310+
assert all(isinstance(axs[i], Axes) for i in range(axs.size))
311+
assert len(set(axs[i] for i in range(axs.size))) == 3
312+
# iteration
313+
assert all(ax1 is ax2 for ax1, ax2 in zip(axs, axs.flat))
314+
315+
def test_1d_no_squeeze(self):
316+
axs = plt.figure().subplots(1, 3, squeeze=False)
317+
# shape and size
318+
assert axs.shape == (1, 3)
319+
assert axs.size == 3
320+
assert axs.ndim == 2
321+
# flat
322+
assert all(isinstance(ax, Axes) for ax in axs.flat)
323+
assert len(set(id(ax) for ax in axs.flat)) == 3
324+
# 2d indexing
325+
assert axs[0, 0] is axs.flat[0]
326+
assert axs[0, 2] is axs.flat[-1]
327+
# single index
328+
axs_type = type(axs)
329+
assert type(axs[0]) is axs_type
330+
assert axs[0].shape == (3,)
331+
# iteration
332+
assert all(self.contain_same_axes(axi, axs[i]) for i, axi in enumerate(axs))
333+
334+
def test_2d(self):
335+
axs = plt.figure().subplots(2, 3)
336+
# shape and size
337+
assert axs.shape == (2, 3)
338+
assert axs.size == 6
339+
assert axs.ndim == 2
340+
# flat
341+
assert all(isinstance(ax, Axes) for ax in axs.flat)
342+
assert len(set(id(ax) for ax in axs.flat)) == 6
343+
# flatten
344+
assert all(isinstance(ax, Axes) for ax in axs.flatten())
345+
assert len(set(id(ax) for ax in axs.flatten())) == 6
346+
# ravel
347+
assert all(isinstance(ax, Axes) for ax in axs.ravel())
348+
assert len(set(id(ax) for ax in axs.ravel())) == 6
349+
# 2d indexing
350+
assert axs[0, 0] is axs.flat[0]
351+
assert axs[1, 2] is axs.flat[-1]
352+
# single index
353+
axs_type = type(axs)
354+
assert type(axs[0]) is axs_type
355+
assert axs[0].shape == (3,)
356+
# iteration
357+
assert all(self.contain_same_axes(axi, axs[i]) for i, axi in enumerate(axs))
358+
359+
def test_deprecated(self):
360+
axs = plt.figure().subplots(2, 2)
361+
with pytest.warns(PendingDeprecationWarning,
362+
match="Using 'diagonal' on AxesArray"):
363+
axs.diagonal()

0 commit comments

Comments
 (0)