Skip to content

ENH: make subplot reuse gridspecs if they fit #11441

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 71 additions & 25 deletions lib/matplotlib/axes/_subplots.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import numpy as np
import warnings

from matplotlib import docstring
Expand Down Expand Up @@ -31,39 +32,39 @@ def __init__(self, fig, *args, **kwargs):
"""

self.figure = fig

if len(args) == 1:
if isinstance(args[0], SubplotSpec):
self._subplotspec = args[0]
else:
if len(args) == 1 and isinstance(args[0], SubplotSpec):
self._subplotspec = args[0]
else:
# we need to make the subplotspec either from a new gridspec or
# an existing one:
if len(args) == 1:
# 223-style argument...
try:
s = str(int(args[0]))
rows, cols, num = map(int, s)
except ValueError:
raise ValueError('Single argument to subplot must be '
'a 3-digit integer')
self._subplotspec = GridSpec(rows, cols,
figure=self.figure)[num - 1]
# num - 1 for converting from MATLAB to python indexing
elif len(args) == 3:
rows, cols, num = args
rows = int(rows)
cols = int(cols)
if isinstance(num, tuple) and len(num) == 2:
num = [int(n) for n in num]
self._subplotspec = GridSpec(
rows, cols,
figure=self.figure)[(num[0] - 1):num[1]]
num = [int(num), int(num)]
elif len(args) == 3:
rows, cols, num = args
rows = int(rows)
cols = int(cols)
if isinstance(num, tuple) and len(num) == 2:
num = [int(n) for n in num]
else:
if num < 1 or num > rows*cols:
raise ValueError(
("num must be 1 <= num <= {maxn}, not {num}"
).format(maxn=rows*cols, num=num))
num = [int(num), int(num)]
else:
if num < 1 or num > rows*cols:
raise ValueError(
("num must be 1 <= num <= {maxn}, not {num}"
).format(maxn=rows*cols, num=num))
self._subplotspec = GridSpec(
rows, cols, figure=self.figure)[int(num) - 1]
raise ValueError('Illegal argument(s) to subplot: %s' %
(args,))
gs, num = self._make_subplotspec(rows, cols, num,
figure=self.figure)
self._subplotspec = gs[(num[0] - 1):num[1]]
# num - 1 for converting from MATLAB to python indexing
else:
raise ValueError('Illegal argument(s) to subplot: %s' % (args,))

self.update_params()

Expand All @@ -87,6 +88,51 @@ def __init__(self, fig, *args, **kwargs):
name=self._layoutbox.name+'.pos',
pos=True, subplot=True, artist=self)

def _make_subplotspec(self, rows, cols, num, figure=None):
"""
Return the subplotspec for this subplot, but reuse an old
GridSpec if it exists and if the new gridspec "fits".
"""
axs = figure.get_axes()
for ax in axs:
if hasattr(ax, 'get_subplotspec'):
gs = ax.get_subplotspec().get_gridspec()
if hasattr(gs, 'get_topmost_subplotspec'):
# This is needed for colorbar gridspec layouts.
# This is probably OK becase this whole logic tree
# is for when the user is doing simple things with the
# add_subplot command. Complicated stuff, the proper
# gridspec is passed in...
gs = gs.get_topmost_subplotspec().get_gridspec()

(nrow, ncol) = gs.get_geometry()
if (not (nrow % rows) and not (ncol % cols)):
# this gridspec "fits"...
# now we have to see if we need to modify num...
rowfac = int(nrow / rows)
colfac = int(ncol / cols)
if (not isinstance(num, tuple) and
not isinstance(num, list)):
num = [num, num]
# converting between num and rows/cols is a PITA:
newnum = num
row = int(np.floor((num[0]-1) / cols))
col = (num[0]-1) - row * cols
row *= rowfac
col *= colfac
newnum[0] = row * ncol + col + 1
row = int(np.floor((num[1]-1) / cols))
col = (num[1]-1) - row * cols
row *= rowfac
col *= colfac
row = row + (rowfac - 1)
col = col + (colfac - 1)
newnum[1] = row * ncol + col + 1
return gs, newnum
# no axes fit with the new subplot specification so make a
# new one...
return GridSpec(rows, cols, figure=figure), num

def __reduce__(self):
# get the first axes class which does not inherit from a subplotbase
axes_class = next(
Expand Down
2 changes: 2 additions & 0 deletions lib/matplotlib/figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1197,6 +1197,8 @@ def add_subplot(self, *args, **kwargs):
fig.add_subplot(111, projection='polar')

# add Subplot instance sub
gs = gridspec.GridSpec(2, 3)
sub = gs[1, 1]
fig.add_subplot(sub)

See Also
Expand Down
2 changes: 1 addition & 1 deletion lib/matplotlib/gridspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def __init__(self, gridspec, num1, num2=None):
self._gridspec = gridspec
self.num1 = num1
self.num2 = num2
if gridspec._layoutbox is not None:
if hasattr(gridspec, '_layoutbox') and gridspec._layoutbox is not None:
glb = gridspec._layoutbox
# So note that here we don't assign any layout yet,
# just make the layoutbox that will conatin all items
Expand Down
7 changes: 5 additions & 2 deletions lib/matplotlib/tests/test_figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,17 @@ def test_gca():
assert fig.gca() is ax3

# the final request for a polar axes will end up creating one
# with a spec of 111.
# with a spec of 121. The 2 stays in there, because we reuse the
# grid spec of the 12x calls...
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
# Changing the projection will throw a warning
assert fig.gca(polar=True) is not ax3
assert len(w) == 1
assert fig.gca(polar=True) is not ax1
assert fig.gca(polar=True) is not ax2
assert fig.gca().get_geometry() == (1, 1, 1)
assert fig.gca(polar=True) is not ax3
# assert fig.gca().get_geometry() == (1, 1, 1)

fig.sca(ax1)
assert fig.gca(projection='rectilinear') is ax1
Expand Down