Skip to content

Add GridSpec.subplots() #14421

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions doc/users/next_whats_new/2019-06-01-AL.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
GridSpec.subplots()
```````````````````

The `.GridSpec` class gained a `~.GridSpecBase.subplots` method, so that one
can write ::

fig.add_gridspec(2, 2, height_ratios=[3, 1]).subplots()

as an alternative to ::

fig.subplots(2, 2, gridspec_kw={"height_ratios": [3, 1]})
5 changes: 3 additions & 2 deletions examples/lines_bars_and_markers/linestyles.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,9 @@ def plot_linestyles(ax, linestyles, title):
color="blue", fontsize=8, ha="right", family="monospace")


fig, (ax0, ax1) = plt.subplots(2, 1, gridspec_kw={'height_ratios': [1, 3]},
figsize=(10, 8))
ax0, ax1 = (plt.figure(figsize=(10, 8))
.add_gridspec(2, 1, height_ratios=[1, 3])
.subplots())

plot_linestyles(ax0, linestyle_str[::-1], title='Named linestyles')
plot_linestyles(ax1, linestyle_tuple[::-1], title='Parametrized linestyles')
Expand Down
17 changes: 10 additions & 7 deletions examples/subplots_axes_and_figures/subplots_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,14 +143,17 @@
# labels of inner Axes are automatically removed by *sharex* and *sharey*.
# Still there remains an unused empty space between the subplots.
#
# The parameter *gridspec_kw* of `.pyplot.subplots` controls the grid
# properties (see also `.GridSpec`). For example, we can reduce the height
# between vertical subplots using ``gridspec_kw={'hspace': 0}``.
# To precisely control the positioning of the subplots, one can explicitly
# create a `.GridSpec` with `.add_gridspec`, and then call its
# `~.GridSpecBase.subplots` method. For example, we can reduce the height
# between vertical subplots using ``add_gridspec(hspace=0)``.
#
# `.label_outer` is a handy method to remove labels and ticks from subplots
# that are not at the edge of the grid.

fig, axs = plt.subplots(3, sharex=True, sharey=True, gridspec_kw={'hspace': 0})
fig = plt.figure()
gs = fig.add_gridspec(3, hspace=0)
axs = gs.subplots(sharex=True, sharey=True)
fig.suptitle('Sharing both axes')
axs[0].plot(x, y ** 2)
axs[1].plot(x, 0.3 * y, 'o')
Expand All @@ -164,9 +167,9 @@
# Apart from ``True`` and ``False``, both *sharex* and *sharey* accept the
# values 'row' and 'col' to share the values only per row or column.

fig, axs = plt.subplots(2, 2, sharex='col', sharey='row',
gridspec_kw={'hspace': 0, 'wspace': 0})
(ax1, ax2), (ax3, ax4) = axs
fig = plt.figure()
gs = fig.add_gridspec(2, 2, hspace=0, wspace=0)
(ax1, ax2), (ax3, ax4) = gs.subplots(sharex='col', sharey='row')
fig.suptitle('Sharing x per column, y per row')
ax1.plot(x, y)
ax2.plot(x, y**2, 'tab:orange')
Expand Down
9 changes: 4 additions & 5 deletions examples/userdemo/demo_gridspec06.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,10 @@ def squiggle_xy(a, b, c, d):
for b in range(4):
# gridspec inside gridspec
inner_grid = outer_grid[a, b].subgridspec(3, 3, wspace=0, hspace=0)
for c in range(3):
for d in range(3):
ax = fig.add_subplot(inner_grid[c, d])
ax.plot(*squiggle_xy(a + 1, b + 1, c + 1, d + 1))
ax.set(xticks=[], yticks=[])
axs = inner_grid.subplots() # Create all subplots for the inner grid.
for (c, d), ax in np.ndenumerate(axs):
ax.plot(*squiggle_xy(a + 1, b + 1, c + 1, d + 1))
ax.set(xticks=[], yticks=[])

# show only the outside spines
for ax in fig.get_axes():
Expand Down
69 changes: 6 additions & 63 deletions lib/matplotlib/figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1528,68 +1528,11 @@ def subplots(self, nrows=1, ncols=1, sharex=False, sharey=False,
# Note that this is the same as
fig.subplots(2, 2, sharex=True, sharey=True)
"""

if isinstance(sharex, bool):
sharex = "all" if sharex else "none"
if isinstance(sharey, bool):
sharey = "all" if sharey else "none"
# This check was added because it is very easy to type
# `subplots(1, 2, 1)` when `subplot(1, 2, 1)` was intended.
# In most cases, no error will ever occur, but mysterious behavior
# will result because what was intended to be the subplot index is
# instead treated as a bool for sharex.
if isinstance(sharex, Integral):
cbook._warn_external(
"sharex argument to subplots() was an integer. Did you "
"intend to use subplot() (without 's')?")
cbook._check_in_list(["all", "row", "col", "none"],
sharex=sharex, sharey=sharey)
if subplot_kw is None:
subplot_kw = {}
if gridspec_kw is None:
gridspec_kw = {}
# don't mutate kwargs passed by user...
subplot_kw = subplot_kw.copy()
gridspec_kw = gridspec_kw.copy()

if self.get_constrained_layout():
gs = GridSpec(nrows, ncols, figure=self, **gridspec_kw)
else:
# this should turn constrained_layout off if we don't want it
gs = GridSpec(nrows, ncols, figure=None, **gridspec_kw)
self._gridspecs.append(gs)

# Create array to hold all axes.
axarr = np.empty((nrows, ncols), dtype=object)
for row in range(nrows):
for col in range(ncols):
shared_with = {"none": None, "all": axarr[0, 0],
"row": axarr[row, 0], "col": axarr[0, col]}
subplot_kw["sharex"] = shared_with[sharex]
subplot_kw["sharey"] = shared_with[sharey]
axarr[row, col] = self.add_subplot(gs[row, col], **subplot_kw)

# turn off redundant tick labeling
if sharex in ["col", "all"]:
# turn off all but the bottom row
for ax in axarr[:-1, :].flat:
ax.xaxis.set_tick_params(which='both',
labelbottom=False, labeltop=False)
ax.xaxis.offsetText.set_visible(False)
if sharey in ["row", "all"]:
# turn off all but the first column
for ax in axarr[:, 1:].flat:
ax.yaxis.set_tick_params(which='both',
labelleft=False, labelright=False)
ax.yaxis.offsetText.set_visible(False)

if squeeze:
# Discarding unneeded dimensions that equal 1. If we only have one
# subplot, just return it instead of a 1-element array.
return axarr.item() if axarr.size == 1 else axarr.squeeze()
else:
# Returned axis array will be always 2-d, even if nrows=ncols=1.
return axarr
return (self.add_gridspec(nrows, ncols, figure=self, **gridspec_kw)
.subplots(sharex=sharex, sharey=sharey, squeeze=squeeze,
subplot_kw=subplot_kw))

def delaxes(self, ax):
"""
Expand Down Expand Up @@ -2599,17 +2542,17 @@ def align_labels(self, axs=None):
self.align_xlabels(axs=axs)
self.align_ylabels(axs=axs)

def add_gridspec(self, nrows, ncols, **kwargs):
def add_gridspec(self, nrows=1, ncols=1, **kwargs):
"""
Return a `.GridSpec` that has this figure as a parent. This allows
complex layout of axes in the figure.

Parameters
----------
nrows : int
nrows : int, default: 1
Number of rows in grid.

ncols : int
ncols : int, default: 1
Number or columns in grid.

Returns
Expand Down
122 changes: 122 additions & 0 deletions lib/matplotlib/gridspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import copy
import logging
from numbers import Integral

import numpy as np

Expand Down Expand Up @@ -234,6 +235,127 @@ def _normalize(key, size, axis): # Includes last index.

return SubplotSpec(self, num1, num2)

def subplots(self, sharex=False, sharey=False, squeeze=True,
subplot_kw=None):
"""
Add all subplots specified by this `GridSpec` to its parent figure.

This utility wrapper makes it convenient to create common layouts of
subplots in a single call.

Parameters
----------
sharex, sharey : bool or {'none', 'all', 'row', 'col'}, default: False
Controls sharing of properties among x (*sharex*) or y (*sharey*)
axes:

- True or 'all': x- or y-axis will be shared among all
subplots.
- False or 'none': each subplot x- or y-axis will be
independent.
- 'row': each subplot row will share an x- or y-axis.
- 'col': each subplot column will share an x- or y-axis.

When subplots have a shared x-axis along a column, only the x tick
labels of the bottom subplot are created. Similarly, when subplots
have a shared y-axis along a row, only the y tick labels of the
first column subplot are created. To later turn other subplots'
ticklabels on, use `~matplotlib.axes.Axes.tick_params`.

squeeze : bool, optional, default: True
- If True, extra dimensions are squeezed out from the returned
array of Axes:

- if only one subplot is constructed (nrows=ncols=1), the
resulting single Axes object is returned as a scalar.
- for Nx1 or 1xM subplots, the returned object is a 1D numpy
object array of Axes objects.
- for NxM, subplots with N>1 and M>1 are returned
as a 2D array.

- If False, no squeezing at all is done: the returned Axes object
is always a 2D array containing Axes instances, even if it ends
up being 1x1.

subplot_kw : dict, optional
Dict with keywords passed to the
:meth:`~matplotlib.figure.Figure.add_subplot` call used to create
each subplot.

Returns
-------
ax : `~.axes.Axes` object or array of Axes objects.
*ax* can be either a single `~matplotlib.axes.Axes` object or
an array of Axes objects if more than one subplot was created. The
dimensions of the resulting array can be controlled with the
squeeze keyword, see above.

See Also
--------
.pyplot.subplots
.Figure.add_subplot
.pyplot.subplot
"""

figure = self[0, 0].get_topmost_subplotspec().get_gridspec().figure

if figure is None:
raise ValueError("GridSpec.subplots() only works for GridSpecs "
"created with a parent figure")

if isinstance(sharex, bool):
sharex = "all" if sharex else "none"
if isinstance(sharey, bool):
sharey = "all" if sharey else "none"
# This check was added because it is very easy to type
# `subplots(1, 2, 1)` when `subplot(1, 2, 1)` was intended.
# In most cases, no error will ever occur, but mysterious behavior
# will result because what was intended to be the subplot index is
# instead treated as a bool for sharex.
if isinstance(sharex, Integral):
cbook._warn_external(
"sharex argument to subplots() was an integer. Did you "
"intend to use subplot() (without 's')?")
cbook._check_in_list(["all", "row", "col", "none"],
sharex=sharex, sharey=sharey)
if subplot_kw is None:
subplot_kw = {}
# don't mutate kwargs passed by user...
subplot_kw = subplot_kw.copy()

# Create array to hold all axes.
axarr = np.empty((self._nrows, self._ncols), dtype=object)
for row in range(self._nrows):
for col in range(self._ncols):
shared_with = {"none": None, "all": axarr[0, 0],
"row": axarr[row, 0], "col": axarr[0, col]}
subplot_kw["sharex"] = shared_with[sharex]
subplot_kw["sharey"] = shared_with[sharey]
axarr[row, col] = figure.add_subplot(
self[row, col], **subplot_kw)

# turn off redundant tick labeling
if sharex in ["col", "all"]:
# turn off all but the bottom row
for ax in axarr[:-1, :].flat:
ax.xaxis.set_tick_params(which='both',
labelbottom=False, labeltop=False)
ax.xaxis.offsetText.set_visible(False)
if sharey in ["row", "all"]:
# turn off all but the first column
for ax in axarr[:, 1:].flat:
ax.yaxis.set_tick_params(which='both',
labelleft=False, labelright=False)
ax.yaxis.offsetText.set_visible(False)

if squeeze:
# Discarding unneeded dimensions that equal 1. If we only have one
# subplot, just return it instead of a 1-element array.
return axarr.item() if axarr.size == 1 else axarr.squeeze()
else:
# Returned axis array will be always 2-d, even if nrows=ncols=1.
return axarr


class GridSpec(GridSpecBase):
"""
Expand Down
9 changes: 4 additions & 5 deletions tutorials/intermediate/gridspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,11 +239,10 @@ def squiggle_xy(a, b, c, d, i=np.arange(0.0, 2*np.pi, 0.05)):
for b in range(4):
# gridspec inside gridspec
inner_grid = outer_grid[a, b].subgridspec(3, 3, wspace=0, hspace=0)
for c in range(3):
for d in range(3):
ax = fig11.add_subplot(inner_grid[c, d])
ax.plot(*squiggle_xy(a + 1, b + 1, c + 1, d + 1))
ax.set(xticks=[], yticks=[])
axs = inner_grid.subplots() # Create all subplots for the inner grid.
for (c, d), ax in np.ndenumerate(axs):
ax.plot(*squiggle_xy(a + 1, b + 1, c + 1, d + 1))
ax.set(xticks=[], yticks=[])

# show only the outside spines
for ax in fig11.get_axes():
Expand Down