Skip to content

ENH: Compressed layout for fixed-aspect axes #17246

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 17 commits into from
93 changes: 93 additions & 0 deletions examples/subplots_axes_and_figures/compress_axes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""
===========================
Compress axes layout option
===========================

If a grid of subplot axes have fixed aspect ratios, the axes will usually
be too far apart in one of the dimensions. For simple layouts
the ``compress_layout=True`` option to `.Figure` or `.Figure.subplots`
can try to compress that dimension so the axes are a similar distance apart in
both dimensions.
"""

import matplotlib.pyplot as plt
import numpy as np

#############################################################################
#
# The default behavior with constrained_layout. Note how there is a large
# horizontal space between the axes.

fig, axs = plt.subplots(2, 2, figsize=(5, 3), facecolor='0.75',
sharex=True, sharey=True, constrained_layout=True)
for ax in axs.flat:
ax.set_aspect(1.0)
plt.show()

#############################################################################
#
# Adding ``compress_layout=True`` attempts to collapse this space:

fig, axs = plt.subplots(2, 2, figsize=(5, 3), facecolor='0.75',
sharex=True, sharey=True, constrained_layout=True,
compress_layout=True)
for ax in axs.flat:
ax.set_aspect(1.0)
plt.show()

#############################################################################
#
# Compatibility
# -------------
#
# Currently this works with ``constrained_layout=True`` for simple layouts
# that do not have nested gridspec layouts
# (:doc:`/gallery/subplots_axes_and_figures/gridspec_nested`). This
# includes simple colorbar layouts with ``constrained_layout``:

fig, axs = plt.subplots(2, 2, figsize=(5, 3), facecolor='0.75',
sharex=True, sharey=True, constrained_layout=True,
compress_layout=True)
for ax in axs.flat:
ax.set_aspect(1.0)
pc = ax.pcolormesh(np.random.randn(20, 20))
fig.colorbar(pc, ax=ax)
plt.show()

fig, axs = plt.subplots(2, 2, figsize=(5, 3), facecolor='0.75',
sharex=True, sharey=True, constrained_layout=True,
compress_layout=True)
for ax in axs.flat:
ax.set_aspect(1.0)
pc = ax.pcolormesh(np.random.randn(20, 20))
fig.colorbar(pc, ax=axs)
plt.show()

#############################################################################
# Compatibility is currently not as good with ``tight_layout`` or no layout
# manager, primarily because colorbars are implimented with nested gridspecs.

for tl in [True, False]:
fig, axs = plt.subplots(2, 2, figsize=(5, 3), facecolor='0.75',
sharex=True, sharey=True, tight_layout=tl,
compress_layout=True)
for ax in axs.flat:
ax.set_aspect(1.0)
pc = ax.pcolormesh(np.random.randn(20, 20))
fig.colorbar(pc, ax=ax)
fig.suptitle(f'Tight Layout: {tl}')
plt.show()

#############################################################################
# However, both work with simple layouts that do not have colorbars.

for tl in [True, False]:
fig, axs = plt.subplots(2, 2, figsize=(5, 3), facecolor='0.75',
sharex=True, sharey=True, tight_layout=tl,
compress_layout=True)
for ax in axs.flat:
ax.set_aspect(1.0)
pc = ax.pcolormesh(np.random.randn(20, 20))
# fig.colorbar(pc, ax=ax)
fig.suptitle(f'Tight Layout: {tl}')
plt.show()
223 changes: 223 additions & 0 deletions lib/matplotlib/_compress_layout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
import numpy as np

"""
This code attemprs to compress axes if they have excessive space between
axes, usually because the axes have fixed aspect ratios.
"""


def compress_layout(fig, *, bboxes=None, w_pad=0.05, h_pad=0.05,
wspace=0, hspace=0):
"""
Utility that will attempt to compress axes on a figure together.

w_pad, h_pad are inches and are the half distance to the next
axes in width and height respectively.
"""

w, h = fig.get_size_inches()
w_pad = w_pad / w * 2
h_pad = h_pad / h * 2

renderer = fig.canvas.get_renderer()
gss = set()
invTransFig = fig.transFigure.inverted().transform_bbox

if bboxes is None:
bboxes = dict()
for ax in fig.axes:
bboxes[ax] = invTransFig(ax.get_tightbbox(renderer))

colorbars = []
for ax in fig.axes:
if hasattr(ax, '_colorbar_info'):
colorbars += [ax]
elif hasattr(ax, 'get_subplotspec'):
gs = ax.get_subplotspec().get_gridspec()
gss.add(gs)
for cba in ax._colorbars:
# get the bbox including the colorbar for this axis
if cba._colorbar_info['location'] == 'right':
bboxes[ax].x1 = bboxes[cba].x1
if cba._colorbar_info['location'] == 'left':
bboxes[ax].x0 = bboxes[cba].x0
if cba._colorbar_info['location'] == 'top':
bboxes[ax].y1 = bboxes[cba].y1
if cba._colorbar_info['location'] == 'bottom':
bboxes[ax].y0 = bboxes[cba].y0

# we placed everything, but what if there are huge gaps...
for gs in gss:
axs = [ax for ax in fig.axes
if (hasattr(ax, 'get_subplotspec')
and ax.get_subplotspec().get_gridspec() == gs)]
nrows, ncols = gs.get_geometry()
# get widths:
dxs = np.zeros((nrows, ncols))
dys = np.zeros((nrows, ncols))
margxs = np.zeros((nrows, ncols))
margys = np.zeros((nrows, ncols))

for i in range(nrows):
for j in range(ncols):
for ax in axs:
ss = ax.get_subplotspec()
if (i in ss.rowspan) and (j in ss.colspan):
orpos = ax.get_position(original=True)
di = ss.colspan[-1] - ss.colspan[0] + 1
dj = ss.rowspan[-1] - ss.rowspan[0] + 1
dxs[i, j] = bboxes[ax].bounds[2] / di
pad = np.max([w_pad, wspace * bboxes[ax].bounds[2]])
if ss.colspan[-1] < ncols - 1:
dxs[i, j] = dxs[i, j] + pad / di
dys[i, j] = bboxes[ax].bounds[3] / dj
pad = np.max([h_pad, hspace * bboxes[ax].bounds[3]])
if ss.rowspan[0] > 0:
dys[i, j] = dys[i, j] + pad / dj
margxs[i, j] = orpos.x0 - bboxes[ax].x0
margys[i, j] = orpos.y0 - bboxes[ax].y0

margxs = np.flipud(margxs)
margys = np.flipud(margys)
dys = np.flipud(dys)
dxs = np.flipud(dxs)

ddxs = np.max(dxs, axis=0)
ddys = np.max(dys, axis=1)
dx = np.sum(ddxs)
dy = np.sum(ddys)
x1 = y1 = -np.Inf
x0 = y0 = np.Inf

if (dx < dy) and (dx < 0.9):
margx = np.max(margxs, axis=0)
# Squish x!
extra = (1 - dx) / 2
for ax in axs:
ss = ax.get_subplotspec()
orpos = ax.get_position(original=True)
x = extra
for j in range(0, ss.colspan[0]):
x += ddxs[j]
deltax = -orpos.x0 + x + margx[ss.colspan[0]]
orpos.x1 = orpos.x1 + deltax
orpos.x0 = orpos.x0 + deltax
# keep track of new bbox edges for placing colorbars
if bboxes[ax].x1 + deltax > x1:
x1 = bboxes[ax].x1 + deltax
if bboxes[ax].x0 + deltax < x0:
x0 = bboxes[ax].x0 + deltax
bboxes[ax].x0 = bboxes[ax].x0 + deltax
bboxes[ax].x1 = bboxes[ax].x1 + deltax
# Now set the new position.
ax._set_position(orpos, which='original')
# shift any colorbars belongig to this axis
for cba in ax._colorbars:
pos = cba.get_position(original=True)
if cba._colorbar_info['location'] in ['bottom', 'top']:
# shrink to make same size as active...
posac = ax.get_position(original=False)
dx = ((1 - cba._colorbar_info['shrink']) *
(posac.x1 - posac.x0) / 2)
pos.x0 = posac.x0 + dx
pos.x1 = posac.x1 - dx
else:
pos.x0 = pos.x0 + deltax
pos.x1 = pos.x1 + deltax
cba._set_position(pos, which='original')
colorbars.remove(cba)
for cb in colorbars:
# shift any colorbars belonging to the gridspec.
pos = cb.get_position(original=True)
bbox = bboxes[cb]
if cb._colorbar_info['location'] == 'right':
marg = bbox.x0 - pos.x0
x = x1 + marg + w_pad
pos.x1 = pos.x1 - pos.x0 + x
pos.x0 = x
elif cb._colorbar_info['location'] == 'left':
marg = bbox.x1 - pos.x1
# left edge:
x = x0 - marg - w_pad
_dx = pos.x1 - pos.x0
pos.x1 = x - marg
pos.x0 = x - marg - _dx
else:
ddx = (x1 - x0) * (1 - cb._colorbar_info['shrink']) / 2
marg = bbox.x0 - pos.x0
pos.x0 = x0 - marg + ddx
marg = bbox.x1 - pos.x1
pos.x1 = x1 - marg - ddx
cb._set_position(pos, which='original')

if (dx > dy) and (dy < 0.9):
margy = np.min(margys, axis=1)
# Squish y!
extra = (1 - dy) / 2
for nn, ax in enumerate(axs):
ss = ax.get_subplotspec()
orpos = ax.get_position(original=True)
y = extra
for j in range(0, nrows - ss.rowspan[-1] - 1):
y += ddys[j]
deltay = -orpos.y0 + y + margy[nrows - ss.rowspan[-1] - 1]
orpos.y1 = orpos.y1 + deltay
orpos.y0 = orpos.y0 + deltay
ax._set_position(orpos, which='original')
# keep track of new bbox edges for placing colorbars
if bboxes[ax].y1 + deltay > y1:
y1 = bboxes[ax].y1 + deltay
if bboxes[ax].y0 + deltay < y0:
y0 = bboxes[ax].y0 + deltay
bboxes[ax].y0 = bboxes[ax].y0 + deltay
bboxes[ax].y1 = bboxes[ax].y1 + deltay
# shift any colorbars belongig to this axis
for cba in ax._colorbars:
pos = cba.get_position(original=True)
if cba._colorbar_info['location'] in ['right', 'left']:
# shrink to make same size as active...
posac = ax.get_position(original=False)
ddy = ((1 - cba._colorbar_info['shrink']) *
(posac.y1 - posac.y0) / 2)
pos.y0 = posac.y0 + ddy
pos.y1 = posac.y1 - ddy
else:
pos.y0 = pos.y0 + deltay
pos.y1 = pos.y1 + deltay
cba._set_position(pos, which='original')
colorbars.remove(cba)
# print(colorbars)
for cb in colorbars:
# shift any colorbars belonging to the gridspec.
pos = cb.get_position(original=True)
bbox = bboxes[cb]
if cb._colorbar_info['location'] == 'top':
marg = bbox.y0 - pos.y0
y = y1 + marg + h_pad
pos.y1 = pos.y1 - pos.y0 + y
pos.y0 = y
elif cb._colorbar_info['location'] == 'bottom':
marg = bbox.y1 - pos.y1
# left edge:
y = y0 - marg - h_pad
_dy = pos.y1 - pos.y0
pos.y1 = y - marg
pos.y0 = y - marg - _dy
else:
ddy = (y1 - y0) * (1 - cb._colorbar_info['shrink']) / 2
marg = bbox.y0 - pos.y0
pos.y0 = y0 - marg + ddy
marg = bbox.y1 - pos.y1
pos.y1 = y1 - marg - ddy

cb._set_position(pos, which='original')
# need to do suptitles:
suptitle = fig._suptitle
do_suptitle = (suptitle is not None and
suptitle._layoutbox is not None and
suptitle.get_in_layout())
if do_suptitle:
x, y = suptitle.get_position()
bbox = invTransFig(suptitle.get_window_extent(renderer))
marg = y - bbox.y0
suptitle.set_y(y1 + marg + h_pad)
16 changes: 14 additions & 2 deletions lib/matplotlib/_constrained_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import matplotlib.cbook as cbook
import matplotlib._layoutbox as layoutbox


_log = logging.getLogger(__name__)


Expand All @@ -69,7 +70,7 @@ def _axes_all_finite_sized(fig):

######################################################
def do_constrained_layout(fig, renderer, h_pad, w_pad,
hspace=None, wspace=None):
hspace=None, wspace=None, squish=False):
"""
Do the constrained_layout. Called at draw time in
``figure.constrained_layout()``
Expand All @@ -89,6 +90,8 @@ def do_constrained_layout(fig, renderer, h_pad, w_pad,
hspace, wspace : float
are in fractions of the subplot sizes.

squish : bool, default False
try to compress fixed-aspect axes.
"""

# Steps:
Expand Down Expand Up @@ -152,12 +155,16 @@ def do_constrained_layout(fig, renderer, h_pad, w_pad,
# change size after the first re-position (i.e. x/yticklabels get
# larger/smaller). This second reposition tends to be much milder,
# so doing twice makes things work OK.
bboxes = {} # need these for packing the layout later...
for ax in fig.axes:
_log.debug(ax._layoutbox)
if ax._layoutbox is not None:
# make margins for each layout box based on the size of
# the decorators.
_make_layout_margins(ax, renderer, h_pad, w_pad)
bbox = _make_layout_margins(ax, renderer, h_pad, w_pad)
else:
bbox = None
bboxes[ax] = bbox

# do layout for suptitle.
suptitle = fig._suptitle
Expand Down Expand Up @@ -198,6 +205,7 @@ def do_constrained_layout(fig, renderer, h_pad, w_pad,
_align_spines(fig, gs)

fig._layoutbox.constrained_layout_called += 1
# call the kiwi solver:
fig._layoutbox.update_variables()

# check if any axes collapsed to zero. If not, don't change positions:
Expand All @@ -220,6 +228,7 @@ def do_constrained_layout(fig, renderer, h_pad, w_pad,
else:
cbook._warn_external('constrained_layout not applied. At least '
'one axes collapsed to zero width or height.')
return bboxes


def _make_ghost_gridspec_slots(fig, gs):
Expand Down Expand Up @@ -254,6 +263,8 @@ def _make_layout_margins(ax, renderer, h_pad, w_pad):
For each axes, make a margin between the *pos* layoutbox and the
*axes* layoutbox be a minimum size that can accommodate the
decorations on the axis.

Returns the bbox for some width/height calcs outside this loop.
"""
fig = ax.figure
invTransFig = fig.transFigure.inverted().transform_bbox
Expand Down Expand Up @@ -302,6 +313,7 @@ def _make_layout_margins(ax, renderer, h_pad, w_pad):
ax._poslayoutbox.constrain_bottom_margin(0, strength='weak')
ax._poslayoutbox.constrain_right_margin(0, strength='weak')
ax._poslayoutbox.constrain_left_margin(0, strength='weak')
return bbox


def _align_spines(fig, gs):
Expand Down
Loading