Skip to content

Adds row and column axes sharing to simple subplot mosaics #26327

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 12 commits into from
166 changes: 133 additions & 33 deletions lib/matplotlib/figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1870,11 +1870,18 @@ def subplot_mosaic(self, mosaic, *, sharex=False, sharey=False,
The Axes identifiers may be `str` or a non-iterable hashable
object (e.g. `tuple` s may not be used).

sharex, sharey : bool, default: False
If True, the x-axis (*sharex*) or y-axis (*sharey*) will be shared
among all subplots. In that case, tick label visibility and axis
units behave as for `subplots`. If False, each subplot's x- or
y-axis will be independent.
sharex, sharey : bool or {'none', 'all', 'row', 'col'}, default: False
Controls sharing of x-axis (*sharex*) or y-axis (*sharey*):

- True or 'all': x- or y-axis will be shared among all subplots.
- False or 'none': each subplot x- or y-axis will be independent.
- 'row': each subplot with the same rows span will share an x- or
y-axis.
- 'col': each subplot with the same column span will share an x- or
y-axis.

Tick label visibility and axis units behave as for `subplots`.


width_ratios : array-like of length *ncols*, optional
Defines the relative widths of the columns. Each column gets a
Expand Down Expand Up @@ -1933,6 +1940,11 @@ def subplot_mosaic(self, mosaic, *, sharex=False, sharey=False,
gridspec_kw = dict(gridspec_kw or {})
per_subplot_kw = per_subplot_kw or {}

if not isinstance(sharex, str):
sharex = "all" if sharex else "none"
if not isinstance(sharey, str):
sharey = "all" if sharey else "none"

if height_ratios is not None:
if 'height_ratios' in gridspec_kw:
raise ValueError("'height_ratios' must not be defined both as "
Expand All @@ -1954,7 +1966,7 @@ def subplot_mosaic(self, mosaic, *, sharex=False, sharey=False,
per_subplot_kw = self._norm_per_subplot_kw(per_subplot_kw)

# Only accept strict bools to allow a possible future API expansion.
_api.check_isinstance(bool, sharex=sharex, sharey=sharey)
_api.check_isinstance((bool, str), sharex=sharex, sharey=sharey)
Comment on lines 1968 to +1969
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This handling was intended to to disallow "truthy" (but not True, the bool) and "falsy" (but not False, the bool) things from giving the True/False case for sharex/y

Since it is happening after normalization to string (on lines 1943-7) this is not actually doing anything here.

I think that normalization should be moved after this check.


def _make_array(inp):
"""
Expand Down Expand Up @@ -2014,6 +2026,33 @@ def _identify_keys_and_nested(mosaic):

return tuple(unique_ids), nested

def _parse_mosaic_to_span(mosaic, unique_ids):
"""
Maps the mosaic label/ids to the row and column span.

Returns
-------
dict[str, (row_slice, col_slice)]
"""
ids_to_span = {}
for id_ in unique_ids:
# sort out where each axes starts/ends
indx = np.argwhere(mosaic == id_)
start_row, start_col = np.min(indx, axis=0)
end_row, end_col = np.max(indx, axis=0) + 1
# and construct the slice object
slc = (slice(start_row, end_row), slice(start_col, end_col))

if (mosaic[slc] != id_).any():
raise ValueError(
f"While trying to layout\n{mosaic!r}\n"
f"we found that the label {id_!r} specifies a "
"non-rectangular or non-contiguous area.")

ids_to_span[id_] = slc

return ids_to_span

def _do_layout(gs, mosaic, unique_ids, nested):
"""
Recursively do the mosaic.
Expand Down Expand Up @@ -2044,22 +2083,14 @@ def _do_layout(gs, mosaic, unique_ids, nested):
# nested mosaic) at this level
this_level = dict()

label_to_span = _parse_mosaic_to_span(mosaic, unique_ids)

# go through the unique keys,
for name in unique_ids:
# sort out where each axes starts/ends
indx = np.argwhere(mosaic == name)
start_row, start_col = np.min(indx, axis=0)
end_row, end_col = np.max(indx, axis=0) + 1
# and construct the slice object
slc = (slice(start_row, end_row), slice(start_col, end_col))
# some light error checking
if (mosaic[slc] != name).any():
raise ValueError(
f"While trying to layout\n{mosaic!r}\n"
f"we found that the label {name!r} specifies a "
"non-rectangular or non-contiguous area.")
for label, slc in label_to_span.items():
# and stash this slice for later
this_level[(start_row, start_col)] = (name, slc, 'axes')
start_row = slc[0].start
start_col = slc[1].start
this_level[(start_row, start_col)] = (label, slc, 'axes')

# do the same thing for the nested mosaics (simpler because these
# cannot be spans yet!)
Expand All @@ -2069,24 +2100,25 @@ def _do_layout(gs, mosaic, unique_ids, nested):
# now go through the things in this level and add them
# in order left-to-right top-to-bottom
for key in sorted(this_level):
name, arg, method = this_level[key]
label, arg, method = this_level[key]
# we are doing some hokey function dispatch here based
# on the 'method' string stashed above to sort out if this
# element is an Axes or a nested mosaic.
if method == 'axes':
slc = arg
# add a single axes
if name in output:
raise ValueError(f"There are duplicate keys {name} "
if label in output:
raise ValueError(f"There are duplicate keys {label} "
f"in the layout\n{mosaic!r}")

ax = self.add_subplot(
gs[slc], **{
'label': str(name),
'label': str(label),
**subplot_kw,
**per_subplot_kw.get(name, {})
**per_subplot_kw.get(label, {})
}
)
output[name] = ax
output[label] = ax
elif method == 'nested':
nested_mosaic = arg
j, k = key
Expand All @@ -2113,14 +2145,81 @@ def _do_layout(gs, mosaic, unique_ids, nested):
rows, cols = mosaic.shape
gs = self.add_gridspec(rows, cols, **gridspec_kw)
ret = _do_layout(gs, mosaic, *_identify_keys_and_nested(mosaic))
ax0 = next(iter(ret.values()))
for ax in ret.values():
if sharex:

# Handle axes sharing
def _find_row_col_groups(mosaic, unique_labels):
label_to_span = _parse_mosaic_to_span(mosaic, unique_labels)

row_group = {}
col_group = {}
for label, (row_slice, col_slice) in label_to_span.items():
start_row, end_row = row_slice.start, row_slice.stop
start_col, end_col = col_slice.start, col_slice.stop

if (start_col, end_col) not in col_group:
col_group[(start_col, end_col)] = [label]
else:
col_group[(start_col, end_col)].append(label)

if (start_row, end_row) not in row_group:
row_group[(start_row, end_row)] = [label]
else:
row_group[(start_row, end_row)].append(label)

return row_group, col_group

# Pairs of axes where the first axes is meant to call sharex/sharey on
# the second axes. The second axes depends on if the sharex/sharey is
# set to "row", "col", or "all".

def check_is_nested():
gs0 = self.axes[0].get_subplotspec().get_gridspec()
for ax in self.axes[1:]:
if ax.get_subplotspec().get_gridspec() != gs0:
return True
return False

share_axes_pairs = {
"all": tuple(
(ax, next(iter(ret.values())))
for ax in ret.values()
)
}
if sharex in ("row", "col") or sharey in ("row", "col"):
if check_is_nested():
raise ValueError(
"Cannot share axes by row or column when using nested mosaic"
)
else:
row_groups, col_groups = _find_row_col_groups(mosaic, ret.keys())

share_axes_pairs.update({
"row": tuple(
(ret[label], ret[row_group[0]])
for row_group in row_groups.values()
for label in row_group
),
"col": tuple(
(ret[label], ret[col_group[0]])
for col_group in col_groups.values()
for label in col_group
),
})

if sharex in share_axes_pairs:
for ax, ax0 in share_axes_pairs[sharex]:
ax.sharex(ax0)
ax._label_outer_xaxis(check_patch=True)
if sharey:
if sharex in ["col", "all"] and ax0 is not ax:
ax0._label_outer_xaxis(check_patch=True)
ax._label_outer_xaxis(check_patch=True)

if sharey in share_axes_pairs:
for ax, ax0 in share_axes_pairs[sharey]:
ax.sharey(ax0)
ax._label_outer_yaxis(check_patch=True)
if sharey in ["row", "all"] and ax0 is not ax:
ax0._label_outer_yaxis(check_patch=True)
ax._label_outer_yaxis(check_patch=True)

if extra := set(per_subplot_kw) - set(ret):
raise ValueError(
f"The keys {extra} are in *per_subplot_kw* "
Expand Down Expand Up @@ -2154,7 +2253,7 @@ class SubFigure(FigureBase):
See :doc:`/gallery/subplots_axes_and_figures/subfigures`

.. note::
The *subfigure* concept is new in v3.4, and the API is still provisional.
The *subfigure* concept is new in v3.4, and API is still provisional.
"""

def __init__(self, parent, subplotspec, *,
Expand Down Expand Up @@ -3204,6 +3303,7 @@ def __setstate__(self, state):
if restore_to_pylab:
# lazy import to avoid circularity
import matplotlib.pyplot as plt

import matplotlib._pylab_helpers as pylab_helpers
allnums = plt.get_fignums()
num = max(allnums) + 1 if allnums else 1
Expand Down
4 changes: 2 additions & 2 deletions lib/matplotlib/figure.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,8 @@ class FigureBase(Artist):
self,
mosaic: str | HashableList,
*,
sharex: bool = ...,
sharey: bool = ...,
sharex: bool | Literal["none", "col", "row", "all"] = ...,
sharey: bool | Literal["none", "col", "row", "all"] = ...,
width_ratios: ArrayLike | None = ...,
height_ratios: ArrayLike | None = ...,
empty_sentinel: Any = ...,
Expand Down
94 changes: 92 additions & 2 deletions lib/matplotlib/tests/test_figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1234,7 +1234,8 @@ def test_nested_user_order(self):
assert list(ax_dict) == list("ABCDEFGHI")
assert list(fig.axes) == list(ax_dict.values())

def test_share_all(self):
@pytest.mark.parametrize("sharex,sharey", [(True, True), ("all", "all")])
def test_share_all(self, sharex, sharey):
layout = [
["A", [["B", "C"],
["D", "E"]]],
Expand All @@ -1243,11 +1244,100 @@ def test_share_all(self):
["."]]]]]
]
fig = plt.figure()
ax_dict = fig.subplot_mosaic(layout, sharex=True, sharey=True)
ax_dict = fig.subplot_mosaic(layout, sharex=sharex, sharey=sharey)
ax_dict["A"].set(xscale="log", yscale="logit")
assert all(ax.get_xscale() == "log" and ax.get_yscale() == "logit"
for ax in ax_dict.values())

@check_figures_equal(extensions=["png"])
def test_sharex_row(self, fig_test, fig_ref):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consider doing the test similar to the test_share_all function above, which sets the scale of shared axes and asserts that all expected sharing is done.

Specifically, without anything actually plotted, there is no way for check_figures_equal (which does an image comparison) to know whether two things are shared or not, as each will default to a range of 0-1 anyway...

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I will add that. Because sharing axes affect tick labels, I'd like to keep the image comparison.

fig_test.subplot_mosaic([["A", "B"], ["C", "D"]],
sharex="row", sharey=False)

axes_ref = fig_ref.subplot_mosaic(
[
["A", "B"],
["C", "D"]
],
sharex=False,
sharey=False
)
axes_ref["A"].sharex(axes_ref["B"])
axes_ref["C"].sharex(axes_ref["D"])

@check_figures_equal(extensions=["png"])
def test_sharey_row(self, fig_test, fig_ref):
fig_test.subplot_mosaic([["A", "B"], ["C", "D"]],
sharex=False, sharey="row")

axes_ref = fig_ref.subplot_mosaic(
[
["A", "B"],
["C", "D"]
],
sharex=False,
sharey=False
)
axes_ref["A"].sharey(axes_ref["B"])
axes_ref["C"].sharey(axes_ref["D"])
axes_ref["B"].yaxis.set_tick_params(which="both", labelleft=False)
axes_ref["D"].yaxis.set_tick_params(which="both", labelleft=False)

@check_figures_equal(extensions=["png"])
def test_sharex_col(self, fig_test, fig_ref):
fig_test.subplot_mosaic([["A", "B"], ["C", "D"]],
sharex="col", sharey=False)
axes_ref = fig_ref.subplot_mosaic(
[
["A", "B"],
["C", "D"]
],
sharex=False,
sharey=False
)
axes_ref["A"].sharex(axes_ref["B"])
axes_ref["B"].sharex(axes_ref["D"])
axes_ref["A"].xaxis.set_tick_params(which="both", labelbottom=False)
axes_ref["B"].xaxis.set_tick_params(which="both", labelbottom=False)

@check_figures_equal(extensions=["png"])
def test_sharey_col(self, fig_test, fig_ref):
fig_test.subplot_mosaic([["A", "B"], ["C", "D"]],
sharex=False, sharey="col")

axes_ref = fig_ref.subplot_mosaic(
[
["A", "B"],
["C", "D"]
],
sharex=False,
sharey=False
)
axes_ref["A"].sharey(axes_ref["C"])
axes_ref["B"].sharey(axes_ref["D"])

@pytest.mark.parametrize(
"sharex,sharey",
[
("row", False),
(False, "row"),
("col", False),
(False, "col"),
("row", "col")
]
)
def test_share_row_col_fails_if_nested_mosaic(self, sharex, sharey):
mosaic = [
["A", [["B", "C"],
["D", "E"]]],
["F", "G"],
[".", [["H", [["I"],
["."]]]]]
]
fig = plt.figure()
with pytest.raises(ValueError):
fig.subplot_mosaic(mosaic, sharex=sharex, sharey=sharey)


def test_reused_gridspec():
"""Test that these all use the same gridspec"""
Expand Down