diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index 9aed565bebc8..395794194a8b 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -1870,11 +1870,18 @@ def subplot_mosaic(self, mosaic, *, sharex=False, sharey=False, The Axes identifiers may be `str` or a non-iterable hashable object (e.g. `tuple` s may not be used). - sharex, sharey : bool, default: False - If True, the x-axis (*sharex*) or y-axis (*sharey*) will be shared - among all subplots. In that case, tick label visibility and axis - units behave as for `subplots`. If False, each subplot's x- or - y-axis will be independent. + sharex, sharey : bool or {'none', 'all', 'row', 'col'}, default: False + Controls sharing of x-axis (*sharex*) or y-axis (*sharey*): + + - True or 'all': x- or y-axis will be shared among all subplots. + - False or 'none': each subplot x- or y-axis will be independent. + - 'row': each subplot with the same rows span will share an x- or + y-axis. + - 'col': each subplot with the same column span will share an x- or + y-axis. + + Tick label visibility and axis units behave as for `subplots`. + width_ratios : array-like of length *ncols*, optional Defines the relative widths of the columns. Each column gets a @@ -1933,6 +1940,11 @@ def subplot_mosaic(self, mosaic, *, sharex=False, sharey=False, gridspec_kw = dict(gridspec_kw or {}) per_subplot_kw = per_subplot_kw or {} + if not isinstance(sharex, str): + sharex = "all" if sharex else "none" + if not isinstance(sharey, str): + sharey = "all" if sharey else "none" + if height_ratios is not None: if 'height_ratios' in gridspec_kw: raise ValueError("'height_ratios' must not be defined both as " @@ -1954,7 +1966,7 @@ def subplot_mosaic(self, mosaic, *, sharex=False, sharey=False, per_subplot_kw = self._norm_per_subplot_kw(per_subplot_kw) # Only accept strict bools to allow a possible future API expansion. - _api.check_isinstance(bool, sharex=sharex, sharey=sharey) + _api.check_isinstance((bool, str), sharex=sharex, sharey=sharey) def _make_array(inp): """ @@ -2014,6 +2026,33 @@ def _identify_keys_and_nested(mosaic): return tuple(unique_ids), nested + def _parse_mosaic_to_span(mosaic, unique_ids): + """ + Maps the mosaic label/ids to the row and column span. + + Returns + ------- + dict[str, (row_slice, col_slice)] + """ + ids_to_span = {} + for id_ in unique_ids: + # sort out where each axes starts/ends + indx = np.argwhere(mosaic == id_) + start_row, start_col = np.min(indx, axis=0) + end_row, end_col = np.max(indx, axis=0) + 1 + # and construct the slice object + slc = (slice(start_row, end_row), slice(start_col, end_col)) + + if (mosaic[slc] != id_).any(): + raise ValueError( + f"While trying to layout\n{mosaic!r}\n" + f"we found that the label {id_!r} specifies a " + "non-rectangular or non-contiguous area.") + + ids_to_span[id_] = slc + + return ids_to_span + def _do_layout(gs, mosaic, unique_ids, nested): """ Recursively do the mosaic. @@ -2044,22 +2083,14 @@ def _do_layout(gs, mosaic, unique_ids, nested): # nested mosaic) at this level this_level = dict() + label_to_span = _parse_mosaic_to_span(mosaic, unique_ids) + # go through the unique keys, - for name in unique_ids: - # sort out where each axes starts/ends - indx = np.argwhere(mosaic == name) - start_row, start_col = np.min(indx, axis=0) - end_row, end_col = np.max(indx, axis=0) + 1 - # and construct the slice object - slc = (slice(start_row, end_row), slice(start_col, end_col)) - # some light error checking - if (mosaic[slc] != name).any(): - raise ValueError( - f"While trying to layout\n{mosaic!r}\n" - f"we found that the label {name!r} specifies a " - "non-rectangular or non-contiguous area.") + for label, slc in label_to_span.items(): # and stash this slice for later - this_level[(start_row, start_col)] = (name, slc, 'axes') + start_row = slc[0].start + start_col = slc[1].start + this_level[(start_row, start_col)] = (label, slc, 'axes') # do the same thing for the nested mosaics (simpler because these # cannot be spans yet!) @@ -2069,24 +2100,25 @@ def _do_layout(gs, mosaic, unique_ids, nested): # now go through the things in this level and add them # in order left-to-right top-to-bottom for key in sorted(this_level): - name, arg, method = this_level[key] + label, arg, method = this_level[key] # we are doing some hokey function dispatch here based # on the 'method' string stashed above to sort out if this # element is an Axes or a nested mosaic. if method == 'axes': slc = arg # add a single axes - if name in output: - raise ValueError(f"There are duplicate keys {name} " + if label in output: + raise ValueError(f"There are duplicate keys {label} " f"in the layout\n{mosaic!r}") + ax = self.add_subplot( gs[slc], **{ - 'label': str(name), + 'label': str(label), **subplot_kw, - **per_subplot_kw.get(name, {}) + **per_subplot_kw.get(label, {}) } ) - output[name] = ax + output[label] = ax elif method == 'nested': nested_mosaic = arg j, k = key @@ -2113,14 +2145,81 @@ def _do_layout(gs, mosaic, unique_ids, nested): rows, cols = mosaic.shape gs = self.add_gridspec(rows, cols, **gridspec_kw) ret = _do_layout(gs, mosaic, *_identify_keys_and_nested(mosaic)) - ax0 = next(iter(ret.values())) - for ax in ret.values(): - if sharex: + + # Handle axes sharing + def _find_row_col_groups(mosaic, unique_labels): + label_to_span = _parse_mosaic_to_span(mosaic, unique_labels) + + row_group = {} + col_group = {} + for label, (row_slice, col_slice) in label_to_span.items(): + start_row, end_row = row_slice.start, row_slice.stop + start_col, end_col = col_slice.start, col_slice.stop + + if (start_col, end_col) not in col_group: + col_group[(start_col, end_col)] = [label] + else: + col_group[(start_col, end_col)].append(label) + + if (start_row, end_row) not in row_group: + row_group[(start_row, end_row)] = [label] + else: + row_group[(start_row, end_row)].append(label) + + return row_group, col_group + + # Pairs of axes where the first axes is meant to call sharex/sharey on + # the second axes. The second axes depends on if the sharex/sharey is + # set to "row", "col", or "all". + + def check_is_nested(): + gs0 = self.axes[0].get_subplotspec().get_gridspec() + for ax in self.axes[1:]: + if ax.get_subplotspec().get_gridspec() != gs0: + return True + return False + + share_axes_pairs = { + "all": tuple( + (ax, next(iter(ret.values()))) + for ax in ret.values() + ) + } + if sharex in ("row", "col") or sharey in ("row", "col"): + if check_is_nested(): + raise ValueError( + "Cannot share axes by row or column when using nested mosaic" + ) + else: + row_groups, col_groups = _find_row_col_groups(mosaic, ret.keys()) + + share_axes_pairs.update({ + "row": tuple( + (ret[label], ret[row_group[0]]) + for row_group in row_groups.values() + for label in row_group + ), + "col": tuple( + (ret[label], ret[col_group[0]]) + for col_group in col_groups.values() + for label in col_group + ), + }) + + if sharex in share_axes_pairs: + for ax, ax0 in share_axes_pairs[sharex]: ax.sharex(ax0) - ax._label_outer_xaxis(check_patch=True) - if sharey: + if sharex in ["col", "all"] and ax0 is not ax: + ax0._label_outer_xaxis(check_patch=True) + ax._label_outer_xaxis(check_patch=True) + + if sharey in share_axes_pairs: + for ax, ax0 in share_axes_pairs[sharey]: ax.sharey(ax0) - ax._label_outer_yaxis(check_patch=True) + if sharey in ["row", "all"] and ax0 is not ax: + ax0._label_outer_yaxis(check_patch=True) + ax._label_outer_yaxis(check_patch=True) + if extra := set(per_subplot_kw) - set(ret): raise ValueError( f"The keys {extra} are in *per_subplot_kw* " @@ -2154,7 +2253,7 @@ class SubFigure(FigureBase): See :doc:`/gallery/subplots_axes_and_figures/subfigures` .. note:: - The *subfigure* concept is new in v3.4, and the API is still provisional. + The *subfigure* concept is new in v3.4, and API is still provisional. """ def __init__(self, parent, subplotspec, *, @@ -3204,6 +3303,7 @@ def __setstate__(self, state): if restore_to_pylab: # lazy import to avoid circularity import matplotlib.pyplot as plt + import matplotlib._pylab_helpers as pylab_helpers allnums = plt.get_fignums() num = max(allnums) + 1 if allnums else 1 diff --git a/lib/matplotlib/figure.pyi b/lib/matplotlib/figure.pyi index 2b2bd9a49326..08d0b759ff66 100644 --- a/lib/matplotlib/figure.pyi +++ b/lib/matplotlib/figure.pyi @@ -231,8 +231,8 @@ class FigureBase(Artist): self, mosaic: str | HashableList, *, - sharex: bool = ..., - sharey: bool = ..., + sharex: bool | Literal["none", "col", "row", "all"] = ..., + sharey: bool | Literal["none", "col", "row", "all"] = ..., width_ratios: ArrayLike | None = ..., height_ratios: ArrayLike | None = ..., empty_sentinel: Any = ..., diff --git a/lib/matplotlib/tests/test_figure.py b/lib/matplotlib/tests/test_figure.py index cfd5f27de13a..009106890841 100644 --- a/lib/matplotlib/tests/test_figure.py +++ b/lib/matplotlib/tests/test_figure.py @@ -1234,7 +1234,8 @@ def test_nested_user_order(self): assert list(ax_dict) == list("ABCDEFGHI") assert list(fig.axes) == list(ax_dict.values()) - def test_share_all(self): + @pytest.mark.parametrize("sharex,sharey", [(True, True), ("all", "all")]) + def test_share_all(self, sharex, sharey): layout = [ ["A", [["B", "C"], ["D", "E"]]], @@ -1243,11 +1244,100 @@ def test_share_all(self): ["."]]]]] ] fig = plt.figure() - ax_dict = fig.subplot_mosaic(layout, sharex=True, sharey=True) + ax_dict = fig.subplot_mosaic(layout, sharex=sharex, sharey=sharey) ax_dict["A"].set(xscale="log", yscale="logit") assert all(ax.get_xscale() == "log" and ax.get_yscale() == "logit" for ax in ax_dict.values()) + @check_figures_equal(extensions=["png"]) + def test_sharex_row(self, fig_test, fig_ref): + fig_test.subplot_mosaic([["A", "B"], ["C", "D"]], + sharex="row", sharey=False) + + axes_ref = fig_ref.subplot_mosaic( + [ + ["A", "B"], + ["C", "D"] + ], + sharex=False, + sharey=False + ) + axes_ref["A"].sharex(axes_ref["B"]) + axes_ref["C"].sharex(axes_ref["D"]) + + @check_figures_equal(extensions=["png"]) + def test_sharey_row(self, fig_test, fig_ref): + fig_test.subplot_mosaic([["A", "B"], ["C", "D"]], + sharex=False, sharey="row") + + axes_ref = fig_ref.subplot_mosaic( + [ + ["A", "B"], + ["C", "D"] + ], + sharex=False, + sharey=False + ) + axes_ref["A"].sharey(axes_ref["B"]) + axes_ref["C"].sharey(axes_ref["D"]) + axes_ref["B"].yaxis.set_tick_params(which="both", labelleft=False) + axes_ref["D"].yaxis.set_tick_params(which="both", labelleft=False) + + @check_figures_equal(extensions=["png"]) + def test_sharex_col(self, fig_test, fig_ref): + fig_test.subplot_mosaic([["A", "B"], ["C", "D"]], + sharex="col", sharey=False) + axes_ref = fig_ref.subplot_mosaic( + [ + ["A", "B"], + ["C", "D"] + ], + sharex=False, + sharey=False + ) + axes_ref["A"].sharex(axes_ref["B"]) + axes_ref["B"].sharex(axes_ref["D"]) + axes_ref["A"].xaxis.set_tick_params(which="both", labelbottom=False) + axes_ref["B"].xaxis.set_tick_params(which="both", labelbottom=False) + + @check_figures_equal(extensions=["png"]) + def test_sharey_col(self, fig_test, fig_ref): + fig_test.subplot_mosaic([["A", "B"], ["C", "D"]], + sharex=False, sharey="col") + + axes_ref = fig_ref.subplot_mosaic( + [ + ["A", "B"], + ["C", "D"] + ], + sharex=False, + sharey=False + ) + axes_ref["A"].sharey(axes_ref["C"]) + axes_ref["B"].sharey(axes_ref["D"]) + + @pytest.mark.parametrize( + "sharex,sharey", + [ + ("row", False), + (False, "row"), + ("col", False), + (False, "col"), + ("row", "col") + ] + ) + def test_share_row_col_fails_if_nested_mosaic(self, sharex, sharey): + mosaic = [ + ["A", [["B", "C"], + ["D", "E"]]], + ["F", "G"], + [".", [["H", [["I"], + ["."]]]]] + ] + fig = plt.figure() + with pytest.raises(ValueError): + fig.subplot_mosaic(mosaic, sharex=sharex, sharey=sharey) + def test_reused_gridspec(): """Test that these all use the same gridspec"""