From ce1351dcff59058b3bd054feccaf780f303e193c Mon Sep 17 00:00:00 2001 From: Keto Zhang Date: Sat, 12 Aug 2023 14:14:11 -0700 Subject: [PATCH] Adds row and column axes sharing to simple subplot mosaics --- lib/matplotlib/figure.py | 195 +++++++++++++++++++++------- lib/matplotlib/figure.pyi | 4 +- lib/matplotlib/tests/test_figure.py | 106 +++++++++++++-- 3 files changed, 244 insertions(+), 61 deletions(-) diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index ce263c3d8d1c..4a4ca68f2a19 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -30,37 +30,41 @@ :ref:`figure_explanation`. """ -from contextlib import ExitStack import inspect import itertools import logging -from numbers import Integral import threading +from contextlib import ExitStack +from numbers import Integral import numpy as np import matplotlib as mpl -from matplotlib import _blocking_input, backend_bases, _docstring, projections -from matplotlib.artist import ( - Artist, allow_rasterization, _finalize_rasterization) -from matplotlib.backend_bases import ( - DrawEvent, FigureCanvasBase, NonGuiException, MouseButton, _get_renderer) import matplotlib._api as _api import matplotlib.cbook as cbook import matplotlib.colorbar as cbar import matplotlib.image as mimage - +import matplotlib.legend as mlegend +from matplotlib import _blocking_input, _docstring, backend_bases, projections +from matplotlib.artist import Artist, _finalize_rasterization, allow_rasterization from matplotlib.axes import Axes +from matplotlib.backend_bases import ( + DrawEvent, + FigureCanvasBase, + MouseButton, + NonGuiException, + _get_renderer, +) from matplotlib.gridspec import GridSpec from matplotlib.layout_engine import ( - ConstrainedLayoutEngine, TightLayoutEngine, LayoutEngine, - PlaceHolderLayoutEngine + ConstrainedLayoutEngine, + LayoutEngine, + PlaceHolderLayoutEngine, + TightLayoutEngine, ) -import matplotlib.legend as mlegend from matplotlib.patches import Rectangle from matplotlib.text import Text -from matplotlib.transforms import (Affine2D, Bbox, BboxTransformTo, - TransformedBbox) +from matplotlib.transforms import Affine2D, Bbox, BboxTransformTo, TransformedBbox _log = logging.getLogger(__name__) @@ -1871,11 +1875,18 @@ def subplot_mosaic(self, mosaic, *, sharex=False, sharey=False, The Axes identifiers may be `str` or a non-iterable hashable object (e.g. `tuple` s may not be used). - sharex, sharey : bool, default: False - If True, the x-axis (*sharex*) or y-axis (*sharey*) will be shared - among all subplots. In that case, tick label visibility and axis - units behave as for `subplots`. If False, each subplot's x- or - y-axis will be independent. + sharex, sharey : bool or {'none', 'all', 'row', 'col'}, default: False + Controls sharing of x-axis (*sharex*) or y-axis (*sharey*): + + - True or 'all': x- or y-axis will be shared among all subplots. + - False or 'none': each subplot x- or y-axis will be independent. + - 'row': each subplot with the same rows span will share an x- or + y-axis. + - 'col': each subplot with the same column span will share an x- or + y-axis. + + Tick label visibility and axis units behave as for `subplots`. + width_ratios : array-like of length *ncols*, optional Defines the relative widths of the columns. Each column gets a @@ -1934,6 +1945,13 @@ def subplot_mosaic(self, mosaic, *, sharex=False, sharey=False, gridspec_kw = dict(gridspec_kw or {}) per_subplot_kw = per_subplot_kw or {} + # Only accept strict bool and str to allow a possible future API expansion. + _api.check_isinstance((bool, str), sharex=sharex, sharey=sharey) + if not isinstance(sharex, str): + sharex = "all" if sharex else "none" + if not isinstance(sharey, str): + sharey = "all" if sharey else "none" + if height_ratios is not None: if 'height_ratios' in gridspec_kw: raise ValueError("'height_ratios' must not be defined both as " @@ -1954,9 +1972,6 @@ def subplot_mosaic(self, mosaic, *, sharex=False, sharey=False, per_subplot_kw = self._norm_per_subplot_kw(per_subplot_kw) - # Only accept strict bools to allow a possible future API expansion. - _api.check_isinstance(bool, sharex=sharex, sharey=sharey) - def _make_array(inp): """ Convert input into 2D array @@ -2015,6 +2030,33 @@ def _identify_keys_and_nested(mosaic): return tuple(unique_ids), nested + def _parse_mosaic_to_span(mosaic, unique_ids): + """ + Maps the mosaic label/ids to the row and column span. + + Returns + ------- + dict[str, (row_slice, col_slice)] + """ + ids_to_span = {} + for id_ in unique_ids: + # sort out where each axes starts/ends + indx = np.argwhere(mosaic == id_) + start_row, start_col = np.min(indx, axis=0) + end_row, end_col = np.max(indx, axis=0) + 1 + # and construct the slice object + slc = (slice(start_row, end_row), slice(start_col, end_col)) + + if (mosaic[slc] != id_).any(): + raise ValueError( + f"While trying to layout\n{mosaic!r}\n" + f"we found that the label {id_!r} specifies a " + "non-rectangular or non-contiguous area.") + + ids_to_span[id_] = slc + + return ids_to_span + def _do_layout(gs, mosaic, unique_ids, nested): """ Recursively do the mosaic. @@ -2045,22 +2087,14 @@ def _do_layout(gs, mosaic, unique_ids, nested): # nested mosaic) at this level this_level = dict() + label_to_span = _parse_mosaic_to_span(mosaic, unique_ids) + # go through the unique keys, - for name in unique_ids: - # sort out where each axes starts/ends - indx = np.argwhere(mosaic == name) - start_row, start_col = np.min(indx, axis=0) - end_row, end_col = np.max(indx, axis=0) + 1 - # and construct the slice object - slc = (slice(start_row, end_row), slice(start_col, end_col)) - # some light error checking - if (mosaic[slc] != name).any(): - raise ValueError( - f"While trying to layout\n{mosaic!r}\n" - f"we found that the label {name!r} specifies a " - "non-rectangular or non-contiguous area.") + for label, slc in label_to_span.items(): # and stash this slice for later - this_level[(start_row, start_col)] = (name, slc, 'axes') + start_row = slc[0].start + start_col = slc[1].start + this_level[(start_row, start_col)] = (label, slc, 'axes') # do the same thing for the nested mosaics (simpler because these # cannot be spans yet!) @@ -2070,24 +2104,25 @@ def _do_layout(gs, mosaic, unique_ids, nested): # now go through the things in this level and add them # in order left-to-right top-to-bottom for key in sorted(this_level): - name, arg, method = this_level[key] + label, arg, method = this_level[key] # we are doing some hokey function dispatch here based # on the 'method' string stashed above to sort out if this # element is an Axes or a nested mosaic. if method == 'axes': slc = arg # add a single axes - if name in output: - raise ValueError(f"There are duplicate keys {name} " + if label in output: + raise ValueError(f"There are duplicate keys {label} " f"in the layout\n{mosaic!r}") + ax = self.add_subplot( gs[slc], **{ - 'label': str(name), + 'label': str(label), **subplot_kw, - **per_subplot_kw.get(name, {}) + **per_subplot_kw.get(label, {}) } ) - output[name] = ax + output[label] = ax elif method == 'nested': nested_mosaic = arg j, k = key @@ -2113,15 +2148,75 @@ def _do_layout(gs, mosaic, unique_ids, nested): mosaic = _make_array(mosaic) rows, cols = mosaic.shape gs = self.add_gridspec(rows, cols, **gridspec_kw) - ret = _do_layout(gs, mosaic, *_identify_keys_and_nested(mosaic)) - ax0 = next(iter(ret.values())) - for ax in ret.values(): - if sharex: + unique_labels, nested_coord_to_labels = _identify_keys_and_nested(mosaic) + ret = _do_layout(gs, mosaic, unique_labels, nested_coord_to_labels) + + # Handle axes sharing + + def _find_row_col_groups(mosaic, unique_labels): + label_to_span = _parse_mosaic_to_span(mosaic, unique_labels) + + row_group = {} + col_group = {} + for label, (row_slice, col_slice) in label_to_span.items(): + start_row, end_row = row_slice.start, row_slice.stop + start_col, end_col = col_slice.start, col_slice.stop + + if (start_col, end_col) not in col_group: + col_group[(start_col, end_col)] = [label] + else: + col_group[(start_col, end_col)].append(label) + + if (start_row, end_row) not in row_group: + row_group[(start_row, end_row)] = [label] + else: + row_group[(start_row, end_row)].append(label) + + return row_group, col_group + + # Pairs of axes where the first axes is meant to call sharex/sharey on + # the second axes. The second axes depends on if the sharex/sharey is + # set to "row", "col", or "all". + share_axes_pairs = { + "all": tuple( + (ax, next(iter(ret.values()))) for ax in ret.values() + ) + } + if sharex in ("row", "col") or sharey in ("row", "col"): + if nested_coord_to_labels: + raise ValueError( + "Cannot share axes by row or column when using nested mosaic" + ) + else: + row_groups, col_groups = _find_row_col_groups(mosaic, unique_labels) + + share_axes_pairs.update({ + "row": tuple( + (ret[label], ret[row_group[0]]) + for row_group in row_groups.values() + for label in row_group + ), + "col": tuple( + (ret[label], ret[col_group[0]]) + for col_group in col_groups.values() + for label in col_group + ), + }) + + if sharex in share_axes_pairs: + for ax, ax0 in share_axes_pairs[sharex]: ax.sharex(ax0) - ax._label_outer_xaxis(skip_non_rectangular_axes=True) - if sharey: + if sharex in ["col", "all"] and ax0 is not ax: + ax0._label_outer_xaxis(skip_non_rectangular_axes=True) + ax._label_outer_xaxis(skip_non_rectangular_axes=True) + + if sharey in share_axes_pairs: + for ax, ax0 in share_axes_pairs[sharey]: ax.sharey(ax0) - ax._label_outer_yaxis(skip_non_rectangular_axes=True) + if sharey in ["row", "all"] and ax0 is not ax: + ax0._label_outer_yaxis(skip_non_rectangular_axes=True) + ax._label_outer_yaxis(skip_non_rectangular_axes=True) + if extra := set(per_subplot_kw) - set(ret): raise ValueError( f"The keys {extra} are in *per_subplot_kw* " @@ -2155,7 +2250,7 @@ class SubFigure(FigureBase): See :doc:`/gallery/subplots_axes_and_figures/subfigures` .. note:: - The *subfigure* concept is new in v3.4, and the API is still provisional. + The *subfigure* concept is new in v3.4, and API is still provisional. """ def __init__(self, parent, subplotspec, *, @@ -3213,8 +3308,8 @@ def __setstate__(self, state): if restore_to_pylab: # lazy import to avoid circularity - import matplotlib.pyplot as plt import matplotlib._pylab_helpers as pylab_helpers + import matplotlib.pyplot as plt allnums = plt.get_fignums() num = max(allnums) + 1 if allnums else 1 backend = plt._get_backend_mod() diff --git a/lib/matplotlib/figure.pyi b/lib/matplotlib/figure.pyi index 40e8fce0321f..72b1679cdc2f 100644 --- a/lib/matplotlib/figure.pyi +++ b/lib/matplotlib/figure.pyi @@ -232,8 +232,8 @@ class FigureBase(Artist): self, mosaic: str | HashableList, *, - sharex: bool = ..., - sharey: bool = ..., + sharex: bool | Literal["none", "col", "row", "all"] = ..., + sharey: bool | Literal["none", "col", "row", "all"] = ..., width_ratios: ArrayLike | None = ..., height_ratios: ArrayLike | None = ..., empty_sentinel: Any = ..., diff --git a/lib/matplotlib/tests/test_figure.py b/lib/matplotlib/tests/test_figure.py index 6d6a3d772f4e..9607cf61963d 100644 --- a/lib/matplotlib/tests/test_figure.py +++ b/lib/matplotlib/tests/test_figure.py @@ -8,22 +8,23 @@ from types import SimpleNamespace import warnings -import numpy as np -import pytest from PIL import Image +import pytest + +import matplotlib.pyplot as plt +import numpy as np import matplotlib as mpl from matplotlib import gridspec -from matplotlib.testing.decorators import image_comparison, check_figures_equal from matplotlib.axes import Axes from matplotlib.backend_bases import KeyEvent, MouseEvent +import matplotlib.dates as mdates from matplotlib.figure import Figure, FigureBase from matplotlib.layout_engine import (ConstrainedLayoutEngine, - TightLayoutEngine, - PlaceHolderLayoutEngine) + PlaceHolderLayoutEngine, + TightLayoutEngine) +from matplotlib.testing.decorators import check_figures_equal, image_comparison from matplotlib.ticker import AutoMinorLocator, FixedFormatter, ScalarFormatter -import matplotlib.pyplot as plt -import matplotlib.dates as mdates @image_comparison(['figure_align_labels'], extensions=['png', 'svg'], @@ -1234,7 +1235,8 @@ def test_nested_user_order(self): assert list(ax_dict) == list("ABCDEFGHI") assert list(fig.axes) == list(ax_dict.values()) - def test_share_all(self): + @pytest.mark.parametrize("sharex,sharey", [(True, True), ("all", "all")]) + def test_share_all(self, sharex, sharey): layout = [ ["A", [["B", "C"], ["D", "E"]]], @@ -1243,11 +1245,97 @@ def test_share_all(self): ["."]]]]] ] fig = plt.figure() - ax_dict = fig.subplot_mosaic(layout, sharex=True, sharey=True) + ax_dict = fig.subplot_mosaic(layout, sharex=sharex, sharey=sharey) ax_dict["A"].set(xscale="log", yscale="logit") assert all(ax.get_xscale() == "log" and ax.get_yscale() == "logit" for ax in ax_dict.values()) + @check_figures_equal(extensions=["png"]) + @pytest.mark.parametrize("sharex,sharey", [("row", "none"), ("none", "row")]) + def test_share_row(self, fig_test, fig_ref, sharex, sharey): + axes_test = fig_test.subplot_mosaic( + [["A", "B"], ["C", "D"]], + sharex=sharex, sharey=sharey + ) + axes_test["A"].set(xscale="log", yscale="logit") + if sharex == "row": + assert axes_test["A"].get_xscale() == axes_test["B"].get_xscale() + if sharey == "row": + assert axes_test["A"].get_yscale() == axes_test["B"].get_yscale() + + axes_ref = fig_ref.subplot_mosaic( + [ + ["A", "B"], + ["C", "D"] + ], + sharex=False, + sharey=False + ) + if sharex == "row": + axes_ref["A"].sharex(axes_ref["B"]) + axes_ref["C"].sharex(axes_ref["D"]) + if sharey == "row": + axes_ref["A"].sharey(axes_ref["B"]) + axes_ref["C"].sharey(axes_ref["D"]) + axes_ref["B"].yaxis.set_tick_params(which="both", labelleft=False) + axes_ref["D"].yaxis.set_tick_params(which="both", labelleft=False) + + axes_ref["A"].set(xscale="log", yscale="logit") + + @check_figures_equal(extensions=["png"]) + @pytest.mark.parametrize("sharex,sharey", [("col", "none"), ("none", "col")]) + def test_share_col(self, fig_test, fig_ref, sharex, sharey): + axes_test = fig_test.subplot_mosaic( + [["A", "B"], ["C", "D"]], + sharex=sharex, sharey=sharey + ) + axes_test["A"].set(xscale="log", yscale="logit") + if sharex == "col": + assert axes_test["A"].get_xscale() == axes_test["C"].get_xscale() + if sharey == "col": + assert axes_test["A"].get_yscale() == axes_test["C"].get_yscale() + + axes_ref = fig_ref.subplot_mosaic( + [ + ["A", "B"], + ["C", "D"] + ], + sharex=False, + sharey=False + ) + if sharex == "col": + axes_ref["A"].sharex(axes_ref["C"]) + axes_ref["B"].sharex(axes_ref["D"]) + axes_ref["A"].xaxis.set_tick_params(which="both", labelbottom=False) + axes_ref["B"].xaxis.set_tick_params(which="both", labelbottom=False) + if sharey == "col": + axes_ref["A"].sharey(axes_ref["C"]) + axes_ref["B"].sharey(axes_ref["D"]) + + axes_ref["A"].set(xscale="log", yscale="logit") + + @pytest.mark.parametrize( + "sharex,sharey", + [ + ("row", False), + (False, "row"), + ("col", False), + (False, "col"), + ("row", "col") + ] + ) + def test_share_row_col_fails_if_nested_mosaic(self, sharex, sharey): + mosaic = [ + ["A", [["B", "C"], + ["D", "E"]]], + ["F", "G"], + [".", [["H", [["I"], + ["."]]]]] + ] + fig = plt.figure() + with pytest.raises(ValueError): + fig.subplot_mosaic(mosaic, sharex=sharex, sharey=sharey) + def test_reused_gridspec(): """Test that these all use the same gridspec"""