Skip to content

Commit ce1351d

Browse files
committed
Adds row and column axes sharing to simple subplot mosaics
1 parent 5d5700a commit ce1351d

File tree

3 files changed

+244
-61
lines changed

3 files changed

+244
-61
lines changed

lib/matplotlib/figure.py

+145-50
Original file line numberDiff line numberDiff line change
@@ -30,37 +30,41 @@
3030
:ref:`figure_explanation`.
3131
"""
3232

33-
from contextlib import ExitStack
3433
import inspect
3534
import itertools
3635
import logging
37-
from numbers import Integral
3836
import threading
37+
from contextlib import ExitStack
38+
from numbers import Integral
3939

4040
import numpy as np
4141

4242
import matplotlib as mpl
43-
from matplotlib import _blocking_input, backend_bases, _docstring, projections
44-
from matplotlib.artist import (
45-
Artist, allow_rasterization, _finalize_rasterization)
46-
from matplotlib.backend_bases import (
47-
DrawEvent, FigureCanvasBase, NonGuiException, MouseButton, _get_renderer)
4843
import matplotlib._api as _api
4944
import matplotlib.cbook as cbook
5045
import matplotlib.colorbar as cbar
5146
import matplotlib.image as mimage
52-
47+
import matplotlib.legend as mlegend
48+
from matplotlib import _blocking_input, _docstring, backend_bases, projections
49+
from matplotlib.artist import Artist, _finalize_rasterization, allow_rasterization
5350
from matplotlib.axes import Axes
51+
from matplotlib.backend_bases import (
52+
DrawEvent,
53+
FigureCanvasBase,
54+
MouseButton,
55+
NonGuiException,
56+
_get_renderer,
57+
)
5458
from matplotlib.gridspec import GridSpec
5559
from matplotlib.layout_engine import (
56-
ConstrainedLayoutEngine, TightLayoutEngine, LayoutEngine,
57-
PlaceHolderLayoutEngine
60+
ConstrainedLayoutEngine,
61+
LayoutEngine,
62+
PlaceHolderLayoutEngine,
63+
TightLayoutEngine,
5864
)
59-
import matplotlib.legend as mlegend
6065
from matplotlib.patches import Rectangle
6166
from matplotlib.text import Text
62-
from matplotlib.transforms import (Affine2D, Bbox, BboxTransformTo,
63-
TransformedBbox)
67+
from matplotlib.transforms import Affine2D, Bbox, BboxTransformTo, TransformedBbox
6468

6569
_log = logging.getLogger(__name__)
6670

@@ -1871,11 +1875,18 @@ def subplot_mosaic(self, mosaic, *, sharex=False, sharey=False,
18711875
The Axes identifiers may be `str` or a non-iterable hashable
18721876
object (e.g. `tuple` s may not be used).
18731877
1874-
sharex, sharey : bool, default: False
1875-
If True, the x-axis (*sharex*) or y-axis (*sharey*) will be shared
1876-
among all subplots. In that case, tick label visibility and axis
1877-
units behave as for `subplots`. If False, each subplot's x- or
1878-
y-axis will be independent.
1878+
sharex, sharey : bool or {'none', 'all', 'row', 'col'}, default: False
1879+
Controls sharing of x-axis (*sharex*) or y-axis (*sharey*):
1880+
1881+
- True or 'all': x- or y-axis will be shared among all subplots.
1882+
- False or 'none': each subplot x- or y-axis will be independent.
1883+
- 'row': each subplot with the same rows span will share an x- or
1884+
y-axis.
1885+
- 'col': each subplot with the same column span will share an x- or
1886+
y-axis.
1887+
1888+
Tick label visibility and axis units behave as for `subplots`.
1889+
18791890
18801891
width_ratios : array-like of length *ncols*, optional
18811892
Defines the relative widths of the columns. Each column gets a
@@ -1934,6 +1945,13 @@ def subplot_mosaic(self, mosaic, *, sharex=False, sharey=False,
19341945
gridspec_kw = dict(gridspec_kw or {})
19351946
per_subplot_kw = per_subplot_kw or {}
19361947

1948+
# Only accept strict bool and str to allow a possible future API expansion.
1949+
_api.check_isinstance((bool, str), sharex=sharex, sharey=sharey)
1950+
if not isinstance(sharex, str):
1951+
sharex = "all" if sharex else "none"
1952+
if not isinstance(sharey, str):
1953+
sharey = "all" if sharey else "none"
1954+
19371955
if height_ratios is not None:
19381956
if 'height_ratios' in gridspec_kw:
19391957
raise ValueError("'height_ratios' must not be defined both as "
@@ -1954,9 +1972,6 @@ def subplot_mosaic(self, mosaic, *, sharex=False, sharey=False,
19541972

19551973
per_subplot_kw = self._norm_per_subplot_kw(per_subplot_kw)
19561974

1957-
# Only accept strict bools to allow a possible future API expansion.
1958-
_api.check_isinstance(bool, sharex=sharex, sharey=sharey)
1959-
19601975
def _make_array(inp):
19611976
"""
19621977
Convert input into 2D array
@@ -2015,6 +2030,33 @@ def _identify_keys_and_nested(mosaic):
20152030

20162031
return tuple(unique_ids), nested
20172032

2033+
def _parse_mosaic_to_span(mosaic, unique_ids):
2034+
"""
2035+
Maps the mosaic label/ids to the row and column span.
2036+
2037+
Returns
2038+
-------
2039+
dict[str, (row_slice, col_slice)]
2040+
"""
2041+
ids_to_span = {}
2042+
for id_ in unique_ids:
2043+
# sort out where each axes starts/ends
2044+
indx = np.argwhere(mosaic == id_)
2045+
start_row, start_col = np.min(indx, axis=0)
2046+
end_row, end_col = np.max(indx, axis=0) + 1
2047+
# and construct the slice object
2048+
slc = (slice(start_row, end_row), slice(start_col, end_col))
2049+
2050+
if (mosaic[slc] != id_).any():
2051+
raise ValueError(
2052+
f"While trying to layout\n{mosaic!r}\n"
2053+
f"we found that the label {id_!r} specifies a "
2054+
"non-rectangular or non-contiguous area.")
2055+
2056+
ids_to_span[id_] = slc
2057+
2058+
return ids_to_span
2059+
20182060
def _do_layout(gs, mosaic, unique_ids, nested):
20192061
"""
20202062
Recursively do the mosaic.
@@ -2045,22 +2087,14 @@ def _do_layout(gs, mosaic, unique_ids, nested):
20452087
# nested mosaic) at this level
20462088
this_level = dict()
20472089

2090+
label_to_span = _parse_mosaic_to_span(mosaic, unique_ids)
2091+
20482092
# go through the unique keys,
2049-
for name in unique_ids:
2050-
# sort out where each axes starts/ends
2051-
indx = np.argwhere(mosaic == name)
2052-
start_row, start_col = np.min(indx, axis=0)
2053-
end_row, end_col = np.max(indx, axis=0) + 1
2054-
# and construct the slice object
2055-
slc = (slice(start_row, end_row), slice(start_col, end_col))
2056-
# some light error checking
2057-
if (mosaic[slc] != name).any():
2058-
raise ValueError(
2059-
f"While trying to layout\n{mosaic!r}\n"
2060-
f"we found that the label {name!r} specifies a "
2061-
"non-rectangular or non-contiguous area.")
2093+
for label, slc in label_to_span.items():
20622094
# and stash this slice for later
2063-
this_level[(start_row, start_col)] = (name, slc, 'axes')
2095+
start_row = slc[0].start
2096+
start_col = slc[1].start
2097+
this_level[(start_row, start_col)] = (label, slc, 'axes')
20642098

20652099
# do the same thing for the nested mosaics (simpler because these
20662100
# cannot be spans yet!)
@@ -2070,24 +2104,25 @@ def _do_layout(gs, mosaic, unique_ids, nested):
20702104
# now go through the things in this level and add them
20712105
# in order left-to-right top-to-bottom
20722106
for key in sorted(this_level):
2073-
name, arg, method = this_level[key]
2107+
label, arg, method = this_level[key]
20742108
# we are doing some hokey function dispatch here based
20752109
# on the 'method' string stashed above to sort out if this
20762110
# element is an Axes or a nested mosaic.
20772111
if method == 'axes':
20782112
slc = arg
20792113
# add a single axes
2080-
if name in output:
2081-
raise ValueError(f"There are duplicate keys {name} "
2114+
if label in output:
2115+
raise ValueError(f"There are duplicate keys {label} "
20822116
f"in the layout\n{mosaic!r}")
2117+
20832118
ax = self.add_subplot(
20842119
gs[slc], **{
2085-
'label': str(name),
2120+
'label': str(label),
20862121
**subplot_kw,
2087-
**per_subplot_kw.get(name, {})
2122+
**per_subplot_kw.get(label, {})
20882123
}
20892124
)
2090-
output[name] = ax
2125+
output[label] = ax
20912126
elif method == 'nested':
20922127
nested_mosaic = arg
20932128
j, k = key
@@ -2113,15 +2148,75 @@ def _do_layout(gs, mosaic, unique_ids, nested):
21132148
mosaic = _make_array(mosaic)
21142149
rows, cols = mosaic.shape
21152150
gs = self.add_gridspec(rows, cols, **gridspec_kw)
2116-
ret = _do_layout(gs, mosaic, *_identify_keys_and_nested(mosaic))
2117-
ax0 = next(iter(ret.values()))
2118-
for ax in ret.values():
2119-
if sharex:
2151+
unique_labels, nested_coord_to_labels = _identify_keys_and_nested(mosaic)
2152+
ret = _do_layout(gs, mosaic, unique_labels, nested_coord_to_labels)
2153+
2154+
# Handle axes sharing
2155+
2156+
def _find_row_col_groups(mosaic, unique_labels):
2157+
label_to_span = _parse_mosaic_to_span(mosaic, unique_labels)
2158+
2159+
row_group = {}
2160+
col_group = {}
2161+
for label, (row_slice, col_slice) in label_to_span.items():
2162+
start_row, end_row = row_slice.start, row_slice.stop
2163+
start_col, end_col = col_slice.start, col_slice.stop
2164+
2165+
if (start_col, end_col) not in col_group:
2166+
col_group[(start_col, end_col)] = [label]
2167+
else:
2168+
col_group[(start_col, end_col)].append(label)
2169+
2170+
if (start_row, end_row) not in row_group:
2171+
row_group[(start_row, end_row)] = [label]
2172+
else:
2173+
row_group[(start_row, end_row)].append(label)
2174+
2175+
return row_group, col_group
2176+
2177+
# Pairs of axes where the first axes is meant to call sharex/sharey on
2178+
# the second axes. The second axes depends on if the sharex/sharey is
2179+
# set to "row", "col", or "all".
2180+
share_axes_pairs = {
2181+
"all": tuple(
2182+
(ax, next(iter(ret.values()))) for ax in ret.values()
2183+
)
2184+
}
2185+
if sharex in ("row", "col") or sharey in ("row", "col"):
2186+
if nested_coord_to_labels:
2187+
raise ValueError(
2188+
"Cannot share axes by row or column when using nested mosaic"
2189+
)
2190+
else:
2191+
row_groups, col_groups = _find_row_col_groups(mosaic, unique_labels)
2192+
2193+
share_axes_pairs.update({
2194+
"row": tuple(
2195+
(ret[label], ret[row_group[0]])
2196+
for row_group in row_groups.values()
2197+
for label in row_group
2198+
),
2199+
"col": tuple(
2200+
(ret[label], ret[col_group[0]])
2201+
for col_group in col_groups.values()
2202+
for label in col_group
2203+
),
2204+
})
2205+
2206+
if sharex in share_axes_pairs:
2207+
for ax, ax0 in share_axes_pairs[sharex]:
21202208
ax.sharex(ax0)
2121-
ax._label_outer_xaxis(skip_non_rectangular_axes=True)
2122-
if sharey:
2209+
if sharex in ["col", "all"] and ax0 is not ax:
2210+
ax0._label_outer_xaxis(skip_non_rectangular_axes=True)
2211+
ax._label_outer_xaxis(skip_non_rectangular_axes=True)
2212+
2213+
if sharey in share_axes_pairs:
2214+
for ax, ax0 in share_axes_pairs[sharey]:
21232215
ax.sharey(ax0)
2124-
ax._label_outer_yaxis(skip_non_rectangular_axes=True)
2216+
if sharey in ["row", "all"] and ax0 is not ax:
2217+
ax0._label_outer_yaxis(skip_non_rectangular_axes=True)
2218+
ax._label_outer_yaxis(skip_non_rectangular_axes=True)
2219+
21252220
if extra := set(per_subplot_kw) - set(ret):
21262221
raise ValueError(
21272222
f"The keys {extra} are in *per_subplot_kw* "
@@ -2155,7 +2250,7 @@ class SubFigure(FigureBase):
21552250
See :doc:`/gallery/subplots_axes_and_figures/subfigures`
21562251
21572252
.. note::
2158-
The *subfigure* concept is new in v3.4, and the API is still provisional.
2253+
The *subfigure* concept is new in v3.4, and API is still provisional.
21592254
"""
21602255

21612256
def __init__(self, parent, subplotspec, *,
@@ -3213,8 +3308,8 @@ def __setstate__(self, state):
32133308

32143309
if restore_to_pylab:
32153310
# lazy import to avoid circularity
3216-
import matplotlib.pyplot as plt
32173311
import matplotlib._pylab_helpers as pylab_helpers
3312+
import matplotlib.pyplot as plt
32183313
allnums = plt.get_fignums()
32193314
num = max(allnums) + 1 if allnums else 1
32203315
backend = plt._get_backend_mod()

lib/matplotlib/figure.pyi

+2-2
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,8 @@ class FigureBase(Artist):
232232
self,
233233
mosaic: str | HashableList,
234234
*,
235-
sharex: bool = ...,
236-
sharey: bool = ...,
235+
sharex: bool | Literal["none", "col", "row", "all"] = ...,
236+
sharey: bool | Literal["none", "col", "row", "all"] = ...,
237237
width_ratios: ArrayLike | None = ...,
238238
height_ratios: ArrayLike | None = ...,
239239
empty_sentinel: Any = ...,

0 commit comments

Comments
 (0)