Skip to content

subplots() returns AxesArray #25937

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 69 additions & 2 deletions lib/matplotlib/gridspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,10 +309,10 @@ def subplots(self, *, sharex=False, sharey=False, squeeze=True,
if squeeze:
# Discarding unneeded dimensions that equal 1. If we only have one
# subplot, just return it instead of a 1-element array.
return axarr.item() if axarr.size == 1 else axarr.squeeze()
return axarr.item() if axarr.size == 1 else AxesArray(axarr.squeeze())
else:
# Returned axis array will be always 2-d, even if nrows=ncols=1.
return axarr
return AxesArray(axarr)


class GridSpec(GridSpecBase):
Expand Down Expand Up @@ -736,3 +736,70 @@ def subgridspec(self, nrows, ncols, **kwargs):
fig.add_subplot(gssub[0, i])
"""
return GridSpecFromSubplotSpec(nrows, ncols, self, **kwargs)


class AxesArray:
"""
A container for a 1D or 2D grid arrangement of Axes.

This is used as the return type of ``subplots()``.

Formerly, ``subplots()`` returned a numpy array of Axes. For a transition period,
AxesArray will act like a numpy array, but all functions and properties that
are not listed explicitly below are deprecated.
"""
def __init__(self, array):
self._array = array

@staticmethod
def _ensure_wrapped(ax_or_axs):
if isinstance(ax_or_axs, np.ndarray):
return AxesArray(ax_or_axs)
else:
return ax_or_axs

def __getitem__(self, index):
return self._ensure_wrapped(self._array[index])

@property
def __array_struct__(self):
return self._array.__array_struct__

@property
def ndim(self):
return self._array.ndim

@property
def shape(self):
return self._array.shape

@property
def size(self):
return self._array.size

@property
def flat(self):
return self._array.flat
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could add ravel() here too.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess there is no good reason to use flatten() over flat, but inexperienced me did so in lots of places. If we don’t want to support flatten(), it might be good to have a more detailed deprecation message saying what to use instead.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree with adding flatten() here too. It really doesn't cost anything to support it too and should save quite a bit of churn in old codebases.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the feedback. Will add flatten() and ravel() as "discouraged" but keep for maximum compatibility. I also will switch to pending deprecation. Cutting the behavior is both difficult (it's hard to overview the whole array API) and not urgent. So, let's take it really slow.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect there's also users of expressions such as axarr.T or axarr.ravel("F") (or similar...) that want to iterate in column-major order.


@property
def flatten(self):
"""[Disouraged] Use ``axs.flat`` instead."""
return self._array.flatten

@property
def ravel(self):
"""[Disouraged] Use ``axs.flat`` instead."""
return self._array.ravel

@property
def __iter__(self):
return iter([self._ensure_wrapped(row) for row in self._array])

def __getattr__(self, item):
# forward all other attributes to the underlying array
# (this is a temporary measure to allow a smooth transition)
attr = getattr(self._array, item)
_api.warn_deprecated("3.9",
message=f"Using {item!r} on AxesArray is deprecated.",
pending=True)
return attr
78 changes: 78 additions & 0 deletions lib/matplotlib/tests/test_subplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import pytest

import matplotlib as mpl
from matplotlib.axes import Axes, SubplotBase
import matplotlib.pyplot as plt
from matplotlib.testing.decorators import check_figures_equal, image_comparison
Expand Down Expand Up @@ -283,3 +284,80 @@ def test_old_subplot_compat():
assert not isinstance(fig.add_axes(rect=[0, 0, 1, 1]), SubplotBase)
with pytest.raises(TypeError):
Axes(fig, [0, 0, 1, 1], rect=[0, 0, 1, 1])


class TestAxesArray:
@staticmethod
def contain_same_axes(axs1, axs2):
return all(ax1 is ax2 for ax1, ax2 in zip(axs1.flat, axs2.flat))

def test_1d(self):
axs = plt.figure().subplots(1, 3)
# shape and size
assert axs.shape == (3,)
assert axs.size == 3
assert axs.ndim == 1
# flat
assert all(isinstance(ax, Axes) for ax in axs.flat)
assert len(set(id(ax) for ax in axs.flat)) == 3
# flatten
assert all(isinstance(ax, Axes) for ax in axs.flatten())
assert len(set(id(ax) for ax in axs.flatten())) == 3
# ravel
assert all(isinstance(ax, Axes) for ax in axs.ravel())
assert len(set(id(ax) for ax in axs.ravel())) == 3
# single index
assert all(isinstance(axs[i], Axes) for i in range(axs.size))
assert len(set(axs[i] for i in range(axs.size))) == 3
# iteration
assert all(ax1 is ax2 for ax1, ax2 in zip(axs, axs.flat))

def test_1d_no_squeeze(self):
axs = plt.figure().subplots(1, 3, squeeze=False)
# shape and size
assert axs.shape == (1, 3)
assert axs.size == 3
assert axs.ndim == 2
# flat
assert all(isinstance(ax, Axes) for ax in axs.flat)
assert len(set(id(ax) for ax in axs.flat)) == 3
# 2d indexing
assert axs[0, 0] is axs.flat[0]
assert axs[0, 2] is axs.flat[-1]
# single index
axs_type = type(axs)
assert type(axs[0]) is axs_type
assert axs[0].shape == (3,)
# iteration
assert all(self.contain_same_axes(axi, axs[i]) for i, axi in enumerate(axs))

def test_2d(self):
axs = plt.figure().subplots(2, 3)
# shape and size
assert axs.shape == (2, 3)
assert axs.size == 6
assert axs.ndim == 2
# flat
assert all(isinstance(ax, Axes) for ax in axs.flat)
assert len(set(id(ax) for ax in axs.flat)) == 6
# flatten
assert all(isinstance(ax, Axes) for ax in axs.flatten())
assert len(set(id(ax) for ax in axs.flatten())) == 6
# ravel
assert all(isinstance(ax, Axes) for ax in axs.ravel())
assert len(set(id(ax) for ax in axs.ravel())) == 6
# 2d indexing
assert axs[0, 0] is axs.flat[0]
assert axs[1, 2] is axs.flat[-1]
# single index
axs_type = type(axs)
assert type(axs[0]) is axs_type
assert axs[0].shape == (3,)
# iteration
assert all(self.contain_same_axes(axi, axs[i]) for i, axi in enumerate(axs))

def test_deprecated(self):
axs = plt.figure().subplots(2, 2)
with pytest.warns(PendingDeprecationWarning,
match="Using 'diagonal' on AxesArray"):
axs.diagonal()