Skip to content

Adds row and column axes sharing to simple subplot mosaics #26411

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 145 additions & 50 deletions lib/matplotlib/figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,37 +30,41 @@
:ref:`figure_explanation`.
"""

from contextlib import ExitStack
import inspect
import itertools
import logging
from numbers import Integral
import threading
from contextlib import ExitStack
from numbers import Integral

import numpy as np

import matplotlib as mpl
from matplotlib import _blocking_input, backend_bases, _docstring, projections
from matplotlib.artist import (
Artist, allow_rasterization, _finalize_rasterization)
from matplotlib.backend_bases import (
DrawEvent, FigureCanvasBase, NonGuiException, MouseButton, _get_renderer)
import matplotlib._api as _api
import matplotlib.cbook as cbook
import matplotlib.colorbar as cbar
import matplotlib.image as mimage

import matplotlib.legend as mlegend
from matplotlib import _blocking_input, _docstring, backend_bases, projections
from matplotlib.artist import Artist, _finalize_rasterization, allow_rasterization
from matplotlib.axes import Axes
from matplotlib.backend_bases import (
DrawEvent,
FigureCanvasBase,
MouseButton,
NonGuiException,
_get_renderer,
)
from matplotlib.gridspec import GridSpec
from matplotlib.layout_engine import (
ConstrainedLayoutEngine, TightLayoutEngine, LayoutEngine,
PlaceHolderLayoutEngine
ConstrainedLayoutEngine,
LayoutEngine,
PlaceHolderLayoutEngine,
TightLayoutEngine,
)
import matplotlib.legend as mlegend
from matplotlib.patches import Rectangle
from matplotlib.text import Text
from matplotlib.transforms import (Affine2D, Bbox, BboxTransformTo,
TransformedBbox)
from matplotlib.transforms import Affine2D, Bbox, BboxTransformTo, TransformedBbox

_log = logging.getLogger(__name__)

Expand Down Expand Up @@ -1871,11 +1875,18 @@ def subplot_mosaic(self, mosaic, *, sharex=False, sharey=False,
The Axes identifiers may be `str` or a non-iterable hashable
object (e.g. `tuple` s may not be used).

sharex, sharey : bool, default: False
If True, the x-axis (*sharex*) or y-axis (*sharey*) will be shared
among all subplots. In that case, tick label visibility and axis
units behave as for `subplots`. If False, each subplot's x- or
y-axis will be independent.
sharex, sharey : bool or {'none', 'all', 'row', 'col'}, default: False
Controls sharing of x-axis (*sharex*) or y-axis (*sharey*):

- True or 'all': x- or y-axis will be shared among all subplots.
- False or 'none': each subplot x- or y-axis will be independent.
- 'row': each subplot with the same rows span will share an x- or
y-axis.
- 'col': each subplot with the same column span will share an x- or
y-axis.

Tick label visibility and axis units behave as for `subplots`.


width_ratios : array-like of length *ncols*, optional
Defines the relative widths of the columns. Each column gets a
Expand Down Expand Up @@ -1934,6 +1945,13 @@ def subplot_mosaic(self, mosaic, *, sharex=False, sharey=False,
gridspec_kw = dict(gridspec_kw or {})
per_subplot_kw = per_subplot_kw or {}

# Only accept strict bool and str to allow a possible future API expansion.
_api.check_isinstance((bool, str), sharex=sharex, sharey=sharey)
if not isinstance(sharex, str):
sharex = "all" if sharex else "none"
if not isinstance(sharey, str):
sharey = "all" if sharey else "none"

if height_ratios is not None:
if 'height_ratios' in gridspec_kw:
raise ValueError("'height_ratios' must not be defined both as "
Expand All @@ -1954,9 +1972,6 @@ def subplot_mosaic(self, mosaic, *, sharex=False, sharey=False,

per_subplot_kw = self._norm_per_subplot_kw(per_subplot_kw)

# Only accept strict bools to allow a possible future API expansion.
_api.check_isinstance(bool, sharex=sharex, sharey=sharey)

def _make_array(inp):
"""
Convert input into 2D array
Expand Down Expand Up @@ -2015,6 +2030,33 @@ def _identify_keys_and_nested(mosaic):

return tuple(unique_ids), nested

def _parse_mosaic_to_span(mosaic, unique_ids):
"""
Maps the mosaic label/ids to the row and column span.

Returns
-------
dict[str, (row_slice, col_slice)]
"""
ids_to_span = {}
for id_ in unique_ids:
# sort out where each axes starts/ends
indx = np.argwhere(mosaic == id_)
start_row, start_col = np.min(indx, axis=0)
end_row, end_col = np.max(indx, axis=0) + 1
# and construct the slice object
slc = (slice(start_row, end_row), slice(start_col, end_col))

if (mosaic[slc] != id_).any():
raise ValueError(
f"While trying to layout\n{mosaic!r}\n"
f"we found that the label {id_!r} specifies a "
"non-rectangular or non-contiguous area.")

ids_to_span[id_] = slc

return ids_to_span

def _do_layout(gs, mosaic, unique_ids, nested):
"""
Recursively do the mosaic.
Expand Down Expand Up @@ -2045,22 +2087,14 @@ def _do_layout(gs, mosaic, unique_ids, nested):
# nested mosaic) at this level
this_level = dict()

label_to_span = _parse_mosaic_to_span(mosaic, unique_ids)

# go through the unique keys,
for name in unique_ids:
# sort out where each axes starts/ends
indx = np.argwhere(mosaic == name)
start_row, start_col = np.min(indx, axis=0)
end_row, end_col = np.max(indx, axis=0) + 1
# and construct the slice object
slc = (slice(start_row, end_row), slice(start_col, end_col))
# some light error checking
if (mosaic[slc] != name).any():
raise ValueError(
f"While trying to layout\n{mosaic!r}\n"
f"we found that the label {name!r} specifies a "
"non-rectangular or non-contiguous area.")
for label, slc in label_to_span.items():
# and stash this slice for later
this_level[(start_row, start_col)] = (name, slc, 'axes')
start_row = slc[0].start
start_col = slc[1].start
this_level[(start_row, start_col)] = (label, slc, 'axes')

# do the same thing for the nested mosaics (simpler because these
# cannot be spans yet!)
Expand All @@ -2070,24 +2104,25 @@ def _do_layout(gs, mosaic, unique_ids, nested):
# now go through the things in this level and add them
# in order left-to-right top-to-bottom
for key in sorted(this_level):
name, arg, method = this_level[key]
label, arg, method = this_level[key]
# we are doing some hokey function dispatch here based
# on the 'method' string stashed above to sort out if this
# element is an Axes or a nested mosaic.
if method == 'axes':
slc = arg
# add a single axes
if name in output:
raise ValueError(f"There are duplicate keys {name} "
if label in output:
raise ValueError(f"There are duplicate keys {label} "
f"in the layout\n{mosaic!r}")

ax = self.add_subplot(
gs[slc], **{
'label': str(name),
'label': str(label),
**subplot_kw,
**per_subplot_kw.get(name, {})
**per_subplot_kw.get(label, {})
}
)
output[name] = ax
output[label] = ax
elif method == 'nested':
nested_mosaic = arg
j, k = key
Expand All @@ -2113,15 +2148,75 @@ def _do_layout(gs, mosaic, unique_ids, nested):
mosaic = _make_array(mosaic)
rows, cols = mosaic.shape
gs = self.add_gridspec(rows, cols, **gridspec_kw)
ret = _do_layout(gs, mosaic, *_identify_keys_and_nested(mosaic))
ax0 = next(iter(ret.values()))
for ax in ret.values():
if sharex:
unique_labels, nested_coord_to_labels = _identify_keys_and_nested(mosaic)
ret = _do_layout(gs, mosaic, unique_labels, nested_coord_to_labels)

# Handle axes sharing

def _find_row_col_groups(mosaic, unique_labels):
label_to_span = _parse_mosaic_to_span(mosaic, unique_labels)

row_group = {}
col_group = {}
for label, (row_slice, col_slice) in label_to_span.items():
start_row, end_row = row_slice.start, row_slice.stop
start_col, end_col = col_slice.start, col_slice.stop

if (start_col, end_col) not in col_group:
col_group[(start_col, end_col)] = [label]
else:
col_group[(start_col, end_col)].append(label)

if (start_row, end_row) not in row_group:
row_group[(start_row, end_row)] = [label]
else:
row_group[(start_row, end_row)].append(label)

return row_group, col_group

# Pairs of axes where the first axes is meant to call sharex/sharey on
# the second axes. The second axes depends on if the sharex/sharey is
# set to "row", "col", or "all".
share_axes_pairs = {
"all": tuple(
(ax, next(iter(ret.values()))) for ax in ret.values()
)
}
if sharex in ("row", "col") or sharey in ("row", "col"):
if nested_coord_to_labels:
raise ValueError(
"Cannot share axes by row or column when using nested mosaic"
)
else:
row_groups, col_groups = _find_row_col_groups(mosaic, unique_labels)

share_axes_pairs.update({
"row": tuple(
(ret[label], ret[row_group[0]])
for row_group in row_groups.values()
for label in row_group
),
"col": tuple(
(ret[label], ret[col_group[0]])
for col_group in col_groups.values()
for label in col_group
),
})

if sharex in share_axes_pairs:
for ax, ax0 in share_axes_pairs[sharex]:
ax.sharex(ax0)
ax._label_outer_xaxis(skip_non_rectangular_axes=True)
if sharey:
if sharex in ["col", "all"] and ax0 is not ax:
ax0._label_outer_xaxis(skip_non_rectangular_axes=True)
ax._label_outer_xaxis(skip_non_rectangular_axes=True)

if sharey in share_axes_pairs:
for ax, ax0 in share_axes_pairs[sharey]:
ax.sharey(ax0)
ax._label_outer_yaxis(skip_non_rectangular_axes=True)
if sharey in ["row", "all"] and ax0 is not ax:
ax0._label_outer_yaxis(skip_non_rectangular_axes=True)
ax._label_outer_yaxis(skip_non_rectangular_axes=True)

if extra := set(per_subplot_kw) - set(ret):
raise ValueError(
f"The keys {extra} are in *per_subplot_kw* "
Expand Down Expand Up @@ -2155,7 +2250,7 @@ class SubFigure(FigureBase):
See :doc:`/gallery/subplots_axes_and_figures/subfigures`

.. note::
The *subfigure* concept is new in v3.4, and the API is still provisional.
The *subfigure* concept is new in v3.4, and API is still provisional.
"""

def __init__(self, parent, subplotspec, *,
Expand Down Expand Up @@ -3213,8 +3308,8 @@ def __setstate__(self, state):

if restore_to_pylab:
# lazy import to avoid circularity
import matplotlib.pyplot as plt
import matplotlib._pylab_helpers as pylab_helpers
import matplotlib.pyplot as plt
allnums = plt.get_fignums()
num = max(allnums) + 1 if allnums else 1
backend = plt._get_backend_mod()
Expand Down
4 changes: 2 additions & 2 deletions lib/matplotlib/figure.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,8 @@ class FigureBase(Artist):
self,
mosaic: str | HashableList,
*,
sharex: bool = ...,
sharey: bool = ...,
sharex: bool | Literal["none", "col", "row", "all"] = ...,
sharey: bool | Literal["none", "col", "row", "all"] = ...,
width_ratios: ArrayLike | None = ...,
height_ratios: ArrayLike | None = ...,
empty_sentinel: Any = ...,
Expand Down
Loading