From 878de0e296dee8e91e82cb85ad4dff90f25e1d88 Mon Sep 17 00:00:00 2001 From: "Keto D. Zhang" Date: Sat, 15 Jul 2023 11:28:13 -0500 Subject: [PATCH 01/12] Add sharex on row and col using simple list iteration --- lib/matplotlib/figure.py | 67 +++++++++++++++++++++++++++------------- 1 file changed, 45 insertions(+), 22 deletions(-) diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index 9aed565bebc8..e1f88729c298 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -39,22 +39,20 @@ import numpy as np import matplotlib as mpl -from matplotlib import _blocking_input, backend_bases, _docstring, projections -from matplotlib.artist import ( - Artist, allow_rasterization, _finalize_rasterization) -from matplotlib.backend_bases import ( - DrawEvent, FigureCanvasBase, NonGuiException, MouseButton, _get_renderer) +from matplotlib import _blocking_input, _docstring, backend_bases, projections import matplotlib._api as _api +from matplotlib.artist import (Artist, _finalize_rasterization, + allow_rasterization) +from matplotlib.axes import Axes +from matplotlib.backend_bases import (DrawEvent, FigureCanvasBase, MouseButton, + NonGuiException, _get_renderer) import matplotlib.cbook as cbook import matplotlib.colorbar as cbar -import matplotlib.image as mimage - -from matplotlib.axes import Axes from matplotlib.gridspec import GridSpec -from matplotlib.layout_engine import ( - ConstrainedLayoutEngine, TightLayoutEngine, LayoutEngine, - PlaceHolderLayoutEngine -) +import matplotlib.image as mimage +from matplotlib.layout_engine import (ConstrainedLayoutEngine, LayoutEngine, + PlaceHolderLayoutEngine, + TightLayoutEngine) import matplotlib.legend as mlegend from matplotlib.patches import Rectangle from matplotlib.text import Text @@ -1954,7 +1952,7 @@ def subplot_mosaic(self, mosaic, *, sharex=False, sharey=False, per_subplot_kw = self._norm_per_subplot_kw(per_subplot_kw) # Only accept strict bools to allow a possible future API expansion. - _api.check_isinstance(bool, sharex=sharex, sharey=sharey) + # _api.check_isinstance(bool, sharex=sharex, sharey=sharey) def _make_array(inp): """ @@ -2113,14 +2111,38 @@ def _do_layout(gs, mosaic, unique_ids, nested): rows, cols = mosaic.shape gs = self.add_gridspec(rows, cols, **gridspec_kw) ret = _do_layout(gs, mosaic, *_identify_keys_and_nested(mosaic)) - ax0 = next(iter(ret.values())) - for ax in ret.values(): - if sharex: - ax.sharex(ax0) - ax._label_outer_xaxis(check_patch=True) - if sharey: - ax.sharey(ax0) - ax._label_outer_yaxis(check_patch=True) + if sharex: + if sharex == "row": + for row in range(mosaic.shape[0]): + for col in range(mosaic.shape[1]): + if col == 0: + continue + ax0 = ret[mosaic[row][0]] + ax = ret[mosaic[row][col]] + ax.sharex(ax0) + ax._label_outer_xaxis(check_patch=True) + elif sharex == "col": + for col in range(mosaic.shape[1]): + for row in range(mosaic.shape[0]): + if row == 0: + continue + ax0 = ret[mosaic[0][col]] + ax = ret[mosaic[row][col]] + ax.sharex(ax0) + ax._label_outer_xaxis(check_patch=True) + elif sharex is True: + ax0 = next(iter(ret.values())) + for ax in ret.values(): + if sharex: + ax.sharex(ax0) + ax._label_outer_xaxis(check_patch=True) + + if sharey: + ax0 = next(iter(ret.values())) + for ax in ret.values(): + if sharey: + ax.sharey(ax0) + ax._label_outer_yaxis(check_patch=True) if extra := set(per_subplot_kw) - set(ret): raise ValueError( f"The keys {extra} are in *per_subplot_kw* " @@ -2154,7 +2176,7 @@ class SubFigure(FigureBase): See :doc:`/gallery/subplots_axes_and_figures/subfigures` .. note:: - The *subfigure* concept is new in v3.4, and the API is still provisional. + The *subfigure* concept is new in v3.4, and API is still provisional. """ def __init__(self, parent, subplotspec, *, @@ -3204,6 +3226,7 @@ def __setstate__(self, state): if restore_to_pylab: # lazy import to avoid circularity import matplotlib.pyplot as plt + import matplotlib._pylab_helpers as pylab_helpers allnums = plt.get_fignums() num = max(allnums) + 1 if allnums else 1 From 1d29bdd48a46a7c8601c59ca90f30014d11dd760 Mon Sep 17 00:00:00 2001 From: "Keto D. Zhang" Date: Sat, 15 Jul 2023 11:49:45 -0500 Subject: [PATCH 02/12] Fixed outer axis on sharex='col' only --- lib/matplotlib/figure.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index e1f88729c298..2dd19c29c2f9 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -2081,6 +2081,7 @@ def _do_layout(gs, mosaic, unique_ids, nested): gs[slc], **{ 'label': str(name), **subplot_kw, + # **{**subplot_kw, "sharex": "row"}, **per_subplot_kw.get(name, {}) } ) @@ -2120,7 +2121,6 @@ def _do_layout(gs, mosaic, unique_ids, nested): ax0 = ret[mosaic[row][0]] ax = ret[mosaic[row][col]] ax.sharex(ax0) - ax._label_outer_xaxis(check_patch=True) elif sharex == "col": for col in range(mosaic.shape[1]): for row in range(mosaic.shape[0]): @@ -2130,13 +2130,14 @@ def _do_layout(gs, mosaic, unique_ids, nested): ax = ret[mosaic[row][col]] ax.sharex(ax0) ax._label_outer_xaxis(check_patch=True) + ax0._label_outer_xaxis(check_patch=True) elif sharex is True: ax0 = next(iter(ret.values())) for ax in ret.values(): if sharex: ax.sharex(ax0) ax._label_outer_xaxis(check_patch=True) - + ax0._label_outer_xaxis(check_patch=True) if sharey: ax0 = next(iter(ret.values())) for ax in ret.values(): From 48c0520ee7c9a8703a1b7c96f9805b9ef47657d3 Mon Sep 17 00:00:00 2001 From: "Keto D. Zhang" Date: Sat, 15 Jul 2023 15:15:47 -0500 Subject: [PATCH 03/12] Reimplement using start and end indices to define row and col groups --- lib/matplotlib/figure.py | 89 +++++++++++++++++++++++++--------------- 1 file changed, 56 insertions(+), 33 deletions(-) diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index 2dd19c29c2f9..b5d6a6af6fae 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -2081,7 +2081,6 @@ def _do_layout(gs, mosaic, unique_ids, nested): gs[slc], **{ 'label': str(name), **subplot_kw, - # **{**subplot_kw, "sharex": "row"}, **per_subplot_kw.get(name, {}) } ) @@ -2108,42 +2107,66 @@ def _do_layout(gs, mosaic, unique_ids, nested): raise RuntimeError("This should never happen") return output + def _find_row_col_groups(unique_ids): + row_group = {} + col_group = {} + for name in unique_ids: + indx = np.argwhere(mosaic == name) + start_row, start_col = np.min(indx, axis=0) + end_row, end_col = np.max(indx, axis=0) + 1 + # sort out where each axes starts/ends + # and construct the slice object + + if (start_col, end_col) not in col_group: + col_group[(start_col, end_col)] = [name] + else: + col_group[(start_col, end_col)].append(name) + + if (start_row, end_row) not in row_group: + row_group[(start_row, end_row)] = [name] + else: + row_group[(start_row, end_row)].append(name) + + return row_group, col_group + + mosaic = _make_array(mosaic) rows, cols = mosaic.shape gs = self.add_gridspec(rows, cols, **gridspec_kw) ret = _do_layout(gs, mosaic, *_identify_keys_and_nested(mosaic)) - if sharex: - if sharex == "row": - for row in range(mosaic.shape[0]): - for col in range(mosaic.shape[1]): - if col == 0: - continue - ax0 = ret[mosaic[row][0]] - ax = ret[mosaic[row][col]] - ax.sharex(ax0) - elif sharex == "col": - for col in range(mosaic.shape[1]): - for row in range(mosaic.shape[0]): - if row == 0: - continue - ax0 = ret[mosaic[0][col]] - ax = ret[mosaic[row][col]] - ax.sharex(ax0) - ax._label_outer_xaxis(check_patch=True) - ax0._label_outer_xaxis(check_patch=True) - elif sharex is True: - ax0 = next(iter(ret.values())) - for ax in ret.values(): - if sharex: - ax.sharex(ax0) - ax._label_outer_xaxis(check_patch=True) - ax0._label_outer_xaxis(check_patch=True) - if sharey: - ax0 = next(iter(ret.values())) - for ax in ret.values(): - if sharey: - ax.sharey(ax0) - ax._label_outer_yaxis(check_patch=True) + row_groups, col_groups = _find_row_col_groups(_identify_keys_and_nested(mosaic)[0]) + if sharex == "row": + for row_group in row_groups.values(): + for name in row_group[1:]: + ret[name].sharex(ret[row_group[0]]) + # ret[name]._label_outer_xaxis(check_patch=True) + + # ret[row_group[0]]._label_outer_xaxis(check_patch=True) + + if sharex == "col": + for col_group in col_groups.values(): + for name in col_group[1:]: + ret[name].sharex(ret[col_group[0]]) + ret[name]._label_outer_xaxis(check_patch=True) + + if len(col_group) > 1: + ret[col_group[0]]._label_outer_xaxis(check_patch=True) + + if sharey == "row": + for row_group in row_groups.values(): + for name in row_group[1:]: + ret[name].sharey(ret[row_group[0]]) + ret[name]._label_outer_yaxis(check_patch=True) + + if len(row_group) > 1: + ret[row_group[0]]._label_outer_yaxis(check_patch=True) + + if sharey == "col": + for col_group in col_groups.values(): + for name in col_group[1:]: + ret[name].sharey(ret[col_group[0]]) + # ret[name]._label_outer_yaxis(check_patch=True) + if extra := set(per_subplot_kw) - set(ret): raise ValueError( f"The keys {extra} are in *per_subplot_kw* " From 890750ba1201de29934f2c4696b8348e23ea3886 Mon Sep 17 00:00:00 2001 From: "Keto D. Zhang" Date: Sat, 15 Jul 2023 17:38:55 -0500 Subject: [PATCH 04/12] Combine handling shared axis 'all', 'row', 'col' into one loop --- lib/matplotlib/figure.py | 150 ++++++++++++++++++++++----------------- 1 file changed, 86 insertions(+), 64 deletions(-) diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index b5d6a6af6fae..7d741c282df7 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -1931,6 +1931,11 @@ def subplot_mosaic(self, mosaic, *, sharex=False, sharey=False, gridspec_kw = dict(gridspec_kw or {}) per_subplot_kw = per_subplot_kw or {} + if not isinstance(sharex, str): + sharex = "all" if sharex else "none" + if not isinstance(sharey, str): + sharey = "all" if sharey else "none" + if height_ratios is not None: if 'height_ratios' in gridspec_kw: raise ValueError("'height_ratios' must not be defined both as " @@ -2012,6 +2017,33 @@ def _identify_keys_and_nested(mosaic): return tuple(unique_ids), nested + def _parse_mosaic_to_span(mosaic, unique_ids): + """ + Maps the mosaic label/ids to the row and column span. + + Returns + ------- + dict[str, (row_slice, col_slice)] + """ + ids_to_span = {} + for id_ in unique_ids: + # sort out where each axes starts/ends + indx = np.argwhere(mosaic == id_) + start_row, start_col = np.min(indx, axis=0) + end_row, end_col = np.max(indx, axis=0) + 1 + # and construct the slice object + slc = (slice(start_row, end_row), slice(start_col, end_col)) + + if (mosaic[slc] != id_).any(): + raise ValueError( + f"While trying to layout\n{mosaic!r}\n" + f"we found that the label {id_!r} specifies a " + "non-rectangular or non-contiguous area.") + + ids_to_span[id_] = slc + + return ids_to_span + def _do_layout(gs, mosaic, unique_ids, nested): """ Recursively do the mosaic. @@ -2042,22 +2074,16 @@ def _do_layout(gs, mosaic, unique_ids, nested): # nested mosaic) at this level this_level = dict() + label_to_span = _parse_mosaic_to_span(mosaic, unique_ids) + # go through the unique keys, - for name in unique_ids: - # sort out where each axes starts/ends - indx = np.argwhere(mosaic == name) - start_row, start_col = np.min(indx, axis=0) - end_row, end_col = np.max(indx, axis=0) + 1 - # and construct the slice object - slc = (slice(start_row, end_row), slice(start_col, end_col)) - # some light error checking - if (mosaic[slc] != name).any(): - raise ValueError( - f"While trying to layout\n{mosaic!r}\n" - f"we found that the label {name!r} specifies a " - "non-rectangular or non-contiguous area.") + for label, slc in label_to_span.items(): # and stash this slice for later - this_level[(start_row, start_col)] = (name, slc, 'axes') + start_row = slc[0].start + start_col = slc[1].start + this_level[(start_row, start_col)] = (label, slc, 'axes') + + # do the same thing for the nested mosaics (simpler because these # cannot be spans yet!) @@ -2067,24 +2093,26 @@ def _do_layout(gs, mosaic, unique_ids, nested): # now go through the things in this level and add them # in order left-to-right top-to-bottom for key in sorted(this_level): - name, arg, method = this_level[key] + label, arg, method = this_level[key] # we are doing some hokey function dispatch here based # on the 'method' string stashed above to sort out if this # element is an Axes or a nested mosaic. if method == 'axes': slc = arg # add a single axes - if name in output: - raise ValueError(f"There are duplicate keys {name} " + if label in output: + raise ValueError(f"There are duplicate keys {label} " f"in the layout\n{mosaic!r}") + + # shared_with = {"none": None, "all": "all", "row": "row", "col": "col"} ax = self.add_subplot( gs[slc], **{ - 'label': str(name), + 'label': str(label), **subplot_kw, - **per_subplot_kw.get(name, {}) + **per_subplot_kw.get(label, {}) } ) - output[name] = ax + output[label] = ax elif method == 'nested': nested_mosaic = arg j, k = key @@ -2107,65 +2135,59 @@ def _do_layout(gs, mosaic, unique_ids, nested): raise RuntimeError("This should never happen") return output - def _find_row_col_groups(unique_ids): + def _find_row_col_groups(mosaic, unique_labels): + label_to_span = _parse_mosaic_to_span(mosaic, unique_labels) + row_group = {} col_group = {} - for name in unique_ids: - indx = np.argwhere(mosaic == name) - start_row, start_col = np.min(indx, axis=0) - end_row, end_col = np.max(indx, axis=0) + 1 - # sort out where each axes starts/ends - # and construct the slice object + for label, (row_slice, col_slice) in label_to_span.items(): + start_row, end_row = row_slice.start, row_slice.stop + start_col, end_col = col_slice.start, col_slice.stop if (start_col, end_col) not in col_group: - col_group[(start_col, end_col)] = [name] + col_group[(start_col, end_col)] = [label] else: - col_group[(start_col, end_col)].append(name) + col_group[(start_col, end_col)].append(label) if (start_row, end_row) not in row_group: - row_group[(start_row, end_row)] = [name] + row_group[(start_row, end_row)] = [label] else: - row_group[(start_row, end_row)].append(name) + row_group[(start_row, end_row)].append(label) + - return row_group, col_group + return ( + {v[i]: v[0] for v in row_group.values() for i in range(len(v))}, + {v[i]: v[0] for v in col_group.values() for i in range(len(v))} + ) mosaic = _make_array(mosaic) rows, cols = mosaic.shape gs = self.add_gridspec(rows, cols, **gridspec_kw) ret = _do_layout(gs, mosaic, *_identify_keys_and_nested(mosaic)) - row_groups, col_groups = _find_row_col_groups(_identify_keys_and_nested(mosaic)[0]) - if sharex == "row": - for row_group in row_groups.values(): - for name in row_group[1:]: - ret[name].sharex(ret[row_group[0]]) - # ret[name]._label_outer_xaxis(check_patch=True) - - # ret[row_group[0]]._label_outer_xaxis(check_patch=True) - - if sharex == "col": - for col_group in col_groups.values(): - for name in col_group[1:]: - ret[name].sharex(ret[col_group[0]]) - ret[name]._label_outer_xaxis(check_patch=True) - - if len(col_group) > 1: - ret[col_group[0]]._label_outer_xaxis(check_patch=True) - - if sharey == "row": - for row_group in row_groups.values(): - for name in row_group[1:]: - ret[name].sharey(ret[row_group[0]]) - ret[name]._label_outer_yaxis(check_patch=True) - - if len(row_group) > 1: - ret[row_group[0]]._label_outer_yaxis(check_patch=True) - - if sharey == "col": - for col_group in col_groups.values(): - for name in col_group[1:]: - ret[name].sharey(ret[col_group[0]]) - # ret[name]._label_outer_yaxis(check_patch=True) + + # Handle axes sharing + row_groups, col_groups = _find_row_col_groups(mosaic, ret.keys()) + + for label, ax in ret.items(): + shared_with = { + "all": next(iter(ret.values())), + "row": ret[row_groups[label]], + "col": ret[col_groups[label]] + } + + if sharex in shared_with: + ax0 = shared_with[sharex] + ax.sharex(ax0) + if sharex in ["col", "all"] and ax0 is not ax: + ax0._label_outer_xaxis(check_patch=True) + ax._label_outer_xaxis(check_patch=True) + if sharey in shared_with: + ax0 = shared_with[sharey] + ax.sharey(ax0) + if sharey in ["row", "all"] and ax0 is not ax: + ax0._label_outer_yaxis(check_patch=True) + ax._label_outer_yaxis(check_patch=True) if extra := set(per_subplot_kw) - set(ret): raise ValueError( From b5afe0e3500e512201f919a7dd9d7c6c6dad3dfb Mon Sep 17 00:00:00 2001 From: "Keto D. Zhang" Date: Sat, 15 Jul 2023 22:32:03 -0500 Subject: [PATCH 05/12] Improvents to reduce dict creation {row, col, all} -> axes pairing to one --- lib/matplotlib/figure.py | 44 ++++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index 7d741c282df7..4ec7dec09ae0 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -2135,6 +2135,14 @@ def _do_layout(gs, mosaic, unique_ids, nested): raise RuntimeError("This should never happen") return output + + mosaic = _make_array(mosaic) + rows, cols = mosaic.shape + gs = self.add_gridspec(rows, cols, **gridspec_kw) + ret = _do_layout(gs, mosaic, *_identify_keys_and_nested(mosaic)) + + # Handle axes sharing + def _find_row_col_groups(mosaic, unique_labels): label_to_span = _parse_mosaic_to_span(mosaic, unique_labels) @@ -2155,35 +2163,27 @@ def _find_row_col_groups(mosaic, unique_labels): row_group[(start_row, end_row)].append(label) - return ( - {v[i]: v[0] for v in row_group.values() for i in range(len(v))}, - {v[i]: v[0] for v in col_group.values() for i in range(len(v))} - ) - - - mosaic = _make_array(mosaic) - rows, cols = mosaic.shape - gs = self.add_gridspec(rows, cols, **gridspec_kw) - ret = _do_layout(gs, mosaic, *_identify_keys_and_nested(mosaic)) + return row_group, col_group - # Handle axes sharing row_groups, col_groups = _find_row_col_groups(mosaic, ret.keys()) - for label, ax in ret.items(): - shared_with = { - "all": next(iter(ret.values())), - "row": ret[row_groups[label]], - "col": ret[col_groups[label]] - } - - if sharex in shared_with: - ax0 = shared_with[sharex] + # Pairs of axes where the first axes is meant to call sharex/sharey on + # the second axes. The second axes depends on if the sharex/sharey is + # set to "row", "col", or "all". + share_axes_pairs = { + "row": ((ret[label], ret[row_group[0]]) for row_group in row_groups.values() for label in row_group), + "col": ((ret[label], ret[col_group[0]]) for col_group in col_groups.values() for label in col_group), + "all": ((ax, next(iter(ret.values()))) for ax in ret.values()) + } + if sharex in share_axes_pairs: + for ax, ax0 in share_axes_pairs[sharex]: ax.sharex(ax0) if sharex in ["col", "all"] and ax0 is not ax: ax0._label_outer_xaxis(check_patch=True) ax._label_outer_xaxis(check_patch=True) - if sharey in shared_with: - ax0 = shared_with[sharey] + + if sharey in share_axes_pairs: + for ax, ax0 in share_axes_pairs[sharey]: ax.sharey(ax0) if sharey in ["row", "all"] and ax0 is not ax: ax0._label_outer_yaxis(check_patch=True) From da9d3538dd2231fb25c6e9dc9a2a43d21ed004c4 Mon Sep 17 00:00:00 2001 From: "Keto D. Zhang" Date: Sat, 15 Jul 2023 22:48:37 -0500 Subject: [PATCH 06/12] Add back api check and update sharex, sharey docstring --- lib/matplotlib/figure.py | 41 ++++++++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index 4ec7dec09ae0..644f9a3ae7a8 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -39,20 +39,22 @@ import numpy as np import matplotlib as mpl -from matplotlib import _blocking_input, _docstring, backend_bases, projections +from matplotlib import _blocking_input, backend_bases, _docstring, projections +from matplotlib.artist import ( + Artist, allow_rasterization, _finalize_rasterization) +from matplotlib.backend_bases import ( + DrawEvent, FigureCanvasBase, NonGuiException, MouseButton, _get_renderer) import matplotlib._api as _api -from matplotlib.artist import (Artist, _finalize_rasterization, - allow_rasterization) -from matplotlib.axes import Axes -from matplotlib.backend_bases import (DrawEvent, FigureCanvasBase, MouseButton, - NonGuiException, _get_renderer) import matplotlib.cbook as cbook import matplotlib.colorbar as cbar -from matplotlib.gridspec import GridSpec import matplotlib.image as mimage -from matplotlib.layout_engine import (ConstrainedLayoutEngine, LayoutEngine, - PlaceHolderLayoutEngine, - TightLayoutEngine) + +from matplotlib.axes import Axes +from matplotlib.gridspec import GridSpec +from matplotlib.layout_engine import ( + ConstrainedLayoutEngine, TightLayoutEngine, LayoutEngine, + PlaceHolderLayoutEngine +) import matplotlib.legend as mlegend from matplotlib.patches import Rectangle from matplotlib.text import Text @@ -1868,11 +1870,18 @@ def subplot_mosaic(self, mosaic, *, sharex=False, sharey=False, The Axes identifiers may be `str` or a non-iterable hashable object (e.g. `tuple` s may not be used). - sharex, sharey : bool, default: False - If True, the x-axis (*sharex*) or y-axis (*sharey*) will be shared - among all subplots. In that case, tick label visibility and axis - units behave as for `subplots`. If False, each subplot's x- or - y-axis will be independent. + sharex, sharey : bool or {'none', 'all', 'row', 'col'}, default: False + Controls sharing of x-axis (*sharex*) or y-axis (*sharey*): + + - True or 'all': x- or y-axis will be shared among all subplots. + - False or 'none': each subplot x- or y-axis will be independent. + - 'row': each subplot with the same rows span will share an x- or + y-axis. + - 'col': each subplot with the same column span will share an x- or + y-axis. + + Tick label visibility and axis units behave as for `subplots`. + width_ratios : array-like of length *ncols*, optional Defines the relative widths of the columns. Each column gets a @@ -1957,7 +1966,7 @@ def subplot_mosaic(self, mosaic, *, sharex=False, sharey=False, per_subplot_kw = self._norm_per_subplot_kw(per_subplot_kw) # Only accept strict bools to allow a possible future API expansion. - # _api.check_isinstance(bool, sharex=sharex, sharey=sharey) + _api.check_isinstance((bool, str), sharex=sharex, sharey=sharey) def _make_array(inp): """ From 0bbb3fe97c444f24766e74e000b11db94169d93f Mon Sep 17 00:00:00 2001 From: "Keto D. Zhang" Date: Sun, 16 Jul 2023 10:47:32 -0500 Subject: [PATCH 07/12] Update type hinting --- lib/matplotlib/figure.pyi | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/matplotlib/figure.pyi b/lib/matplotlib/figure.pyi index 2b2bd9a49326..08d0b759ff66 100644 --- a/lib/matplotlib/figure.pyi +++ b/lib/matplotlib/figure.pyi @@ -231,8 +231,8 @@ class FigureBase(Artist): self, mosaic: str | HashableList, *, - sharex: bool = ..., - sharey: bool = ..., + sharex: bool | Literal["none", "col", "row", "all"] = ..., + sharey: bool | Literal["none", "col", "row", "all"] = ..., width_ratios: ArrayLike | None = ..., height_ratios: ArrayLike | None = ..., empty_sentinel: Any = ..., From 5294ebd3cfa8fa0fa1043c1df65c230b11110b70 Mon Sep 17 00:00:00 2001 From: "Keto D. Zhang" Date: Sun, 16 Jul 2023 11:03:22 -0500 Subject: [PATCH 08/12] Add tests --- lib/matplotlib/tests/test_figure.py | 43 +++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/lib/matplotlib/tests/test_figure.py b/lib/matplotlib/tests/test_figure.py index cfd5f27de13a..4f99aff80ffe 100644 --- a/lib/matplotlib/tests/test_figure.py +++ b/lib/matplotlib/tests/test_figure.py @@ -1248,6 +1248,49 @@ def test_share_all(self): assert all(ax.get_xscale() == "log" and ax.get_yscale() == "logit" for ax in ax_dict.values()) + @check_figures_equal(extensions=["png"]) + def test_sharex_row(self, fig_test, fig_ref): + fig_test.subplot_mosaic([["A", "B"], ["C", "D"]], + sharex="row", sharey=False) + + axes_ref = fig_ref.subplot_mosaic([["A", "B"], ["C", "D"]], + sharex=False, sharey=False) + axes_ref["A"].sharex(axes_ref["B"]) + axes_ref["C"].sharex(axes_ref["D"]) + + @check_figures_equal(extensions=["png"]) + def test_sharey_row(self, fig_test, fig_ref): + fig_test.subplot_mosaic([["A", "B"], ["C", "D"]], + sharex=False, sharey="row") + + axes_ref = fig_ref.subplot_mosaic([["A", "B"], ["C", "D"]], + sharex=False, sharey=False) + axes_ref["A"].sharey(axes_ref["B"]) + axes_ref["C"].sharey(axes_ref["D"]) + axes_ref["B"].yaxis.set_tick_params(which="both", labelleft=False) + axes_ref["D"].yaxis.set_tick_params(which="both", labelleft=False) + + @check_figures_equal(extensions=["png"]) + def test_sharex_col(self, fig_test, fig_ref): + fig_test.subplot_mosaic([["A", "B"], ["C", "D"]], + sharex="col", sharey=False) + + axes_ref = fig_ref.subplot_mosaic([["A", "B"], ["C", "D"]], + sharex=False, sharey=False) + axes_ref["A"].sharex(axes_ref["B"]) + axes_ref["B"].sharex(axes_ref["D"]) + axes_ref["A"].xaxis.set_tick_params(which="both", labelbottom=False) + axes_ref["B"].xaxis.set_tick_params(which="both", labelbottom=False) + + @check_figures_equal(extensions=["png"]) + def test_sharey_col(self, fig_test, fig_ref): + fig_test.subplot_mosaic([["A", "B"], ["C", "D"]], + sharex=False, sharey="col") + + axes_ref = fig_ref.subplot_mosaic([["A", "B"], ["C", "D"]], + sharex=False, sharey=False) + axes_ref["A"].sharey(axes_ref["C"]) + axes_ref["B"].sharey(axes_ref["D"]) def test_reused_gridspec(): """Test that these all use the same gridspec""" From edfcb3fdcf46418de70120a1bab8b8ec0077673e Mon Sep 17 00:00:00 2001 From: "Keto D. Zhang" Date: Sun, 16 Jul 2023 11:03:43 -0500 Subject: [PATCH 09/12] Turn off row/col share axis if nested --- lib/matplotlib/figure.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index 644f9a3ae7a8..17b50b81b864 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -2148,7 +2148,8 @@ def _do_layout(gs, mosaic, unique_ids, nested): mosaic = _make_array(mosaic) rows, cols = mosaic.shape gs = self.add_gridspec(rows, cols, **gridspec_kw) - ret = _do_layout(gs, mosaic, *_identify_keys_and_nested(mosaic)) + unique_labels, nested_coord_to_labels = _identify_keys_and_nested(mosaic) + ret = _do_layout(gs, mosaic, unique_labels, nested_coord_to_labels) # Handle axes sharing @@ -2174,16 +2175,22 @@ def _find_row_col_groups(mosaic, unique_labels): return row_group, col_group - row_groups, col_groups = _find_row_col_groups(mosaic, ret.keys()) - # Pairs of axes where the first axes is meant to call sharex/sharey on # the second axes. The second axes depends on if the sharex/sharey is # set to "row", "col", or "all". - share_axes_pairs = { - "row": ((ret[label], ret[row_group[0]]) for row_group in row_groups.values() for label in row_group), - "col": ((ret[label], ret[col_group[0]]) for col_group in col_groups.values() for label in col_group), - "all": ((ax, next(iter(ret.values()))) for ax in ret.values()) - } + + share_axes_pairs = {"all": tuple((ax, next(iter(ret.values()))) for ax in ret.values())} + if sharex in ("row", "col") or sharey in ("row", "col"): + if nested_coord_to_labels: + raise ValueError("Cannot share axes by row or column when using nested mosaic") + else: + row_groups, col_groups = _find_row_col_groups(mosaic, unique_labels) + + share_axes_pairs.update({ + "row": tuple((ret[label], ret[row_group[0]]) for row_group in row_groups.values() for label in row_group), + "col": tuple((ret[label], ret[col_group[0]]) for col_group in col_groups.values() for label in col_group), + }) + if sharex in share_axes_pairs: for ax, ax0 in share_axes_pairs[sharex]: ax.sharex(ax0) From 4bf4c69fa7cf09a5ff2957b633ab30c86aa27859 Mon Sep 17 00:00:00 2001 From: "Keto D. Zhang" Date: Sun, 16 Jul 2023 11:24:14 -0500 Subject: [PATCH 10/12] Add check if nested using gridspec equality check --- lib/matplotlib/figure.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index 17b50b81b864..c04d76fdca4b 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -2148,8 +2148,7 @@ def _do_layout(gs, mosaic, unique_ids, nested): mosaic = _make_array(mosaic) rows, cols = mosaic.shape gs = self.add_gridspec(rows, cols, **gridspec_kw) - unique_labels, nested_coord_to_labels = _identify_keys_and_nested(mosaic) - ret = _do_layout(gs, mosaic, unique_labels, nested_coord_to_labels) + ret = _do_layout(gs, mosaic, *_identify_keys_and_nested(mosaic)) # Handle axes sharing @@ -2179,12 +2178,19 @@ def _find_row_col_groups(mosaic, unique_labels): # the second axes. The second axes depends on if the sharex/sharey is # set to "row", "col", or "all". + def check_is_nested(): + gs0 = self.axes[0].get_subplotspec().get_gridspec() + for ax in self.axes[1:]: + if ax.get_subplotspec().get_gridspec() != gs0: + return True + return False + share_axes_pairs = {"all": tuple((ax, next(iter(ret.values()))) for ax in ret.values())} if sharex in ("row", "col") or sharey in ("row", "col"): - if nested_coord_to_labels: + if check_is_nested(): raise ValueError("Cannot share axes by row or column when using nested mosaic") else: - row_groups, col_groups = _find_row_col_groups(mosaic, unique_labels) + row_groups, col_groups = _find_row_col_groups(mosaic, ret.keys()) share_axes_pairs.update({ "row": tuple((ret[label], ret[row_group[0]]) for row_group in row_groups.values() for label in row_group), From af4f6aeffacf7d2e6e5d8e017e9e627fdf97cd47 Mon Sep 17 00:00:00 2001 From: "Keto D. Zhang" Date: Sun, 16 Jul 2023 11:58:18 -0500 Subject: [PATCH 11/12] Fix lint --- lib/matplotlib/figure.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index c04d76fdca4b..395794194a8b 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -2092,8 +2092,6 @@ def _do_layout(gs, mosaic, unique_ids, nested): start_col = slc[1].start this_level[(start_row, start_col)] = (label, slc, 'axes') - - # do the same thing for the nested mosaics (simpler because these # cannot be spans yet!) for (j, k), nested_mosaic in nested.items(): @@ -2113,7 +2111,6 @@ def _do_layout(gs, mosaic, unique_ids, nested): raise ValueError(f"There are duplicate keys {label} " f"in the layout\n{mosaic!r}") - # shared_with = {"none": None, "all": "all", "row": "row", "col": "col"} ax = self.add_subplot( gs[slc], **{ 'label': str(label), @@ -2144,14 +2141,12 @@ def _do_layout(gs, mosaic, unique_ids, nested): raise RuntimeError("This should never happen") return output - mosaic = _make_array(mosaic) rows, cols = mosaic.shape gs = self.add_gridspec(rows, cols, **gridspec_kw) ret = _do_layout(gs, mosaic, *_identify_keys_and_nested(mosaic)) # Handle axes sharing - def _find_row_col_groups(mosaic, unique_labels): label_to_span = _parse_mosaic_to_span(mosaic, unique_labels) @@ -2171,7 +2166,6 @@ def _find_row_col_groups(mosaic, unique_labels): else: row_group[(start_row, end_row)].append(label) - return row_group, col_group # Pairs of axes where the first axes is meant to call sharex/sharey on @@ -2185,16 +2179,31 @@ def check_is_nested(): return True return False - share_axes_pairs = {"all": tuple((ax, next(iter(ret.values()))) for ax in ret.values())} + share_axes_pairs = { + "all": tuple( + (ax, next(iter(ret.values()))) + for ax in ret.values() + ) + } if sharex in ("row", "col") or sharey in ("row", "col"): if check_is_nested(): - raise ValueError("Cannot share axes by row or column when using nested mosaic") + raise ValueError( + "Cannot share axes by row or column when using nested mosaic" + ) else: row_groups, col_groups = _find_row_col_groups(mosaic, ret.keys()) share_axes_pairs.update({ - "row": tuple((ret[label], ret[row_group[0]]) for row_group in row_groups.values() for label in row_group), - "col": tuple((ret[label], ret[col_group[0]]) for col_group in col_groups.values() for label in col_group), + "row": tuple( + (ret[label], ret[row_group[0]]) + for row_group in row_groups.values() + for label in row_group + ), + "col": tuple( + (ret[label], ret[col_group[0]]) + for col_group in col_groups.values() + for label in col_group + ), }) if sharex in share_axes_pairs: From 4f8f73a0daa17114bb94cb53af29650bce3390f0 Mon Sep 17 00:00:00 2001 From: "Keto D. Zhang" Date: Sun, 16 Jul 2023 11:58:53 -0500 Subject: [PATCH 12/12] Add test mosaic fails if nested and share is row/col --- lib/matplotlib/tests/test_figure.py | 69 ++++++++++++++++++++++++----- 1 file changed, 58 insertions(+), 11 deletions(-) diff --git a/lib/matplotlib/tests/test_figure.py b/lib/matplotlib/tests/test_figure.py index 4f99aff80ffe..009106890841 100644 --- a/lib/matplotlib/tests/test_figure.py +++ b/lib/matplotlib/tests/test_figure.py @@ -1234,7 +1234,8 @@ def test_nested_user_order(self): assert list(ax_dict) == list("ABCDEFGHI") assert list(fig.axes) == list(ax_dict.values()) - def test_share_all(self): + @pytest.mark.parametrize("sharex,sharey", [(True, True), ("all", "all")]) + def test_share_all(self, sharex, sharey): layout = [ ["A", [["B", "C"], ["D", "E"]]], @@ -1243,7 +1244,7 @@ def test_share_all(self): ["."]]]]] ] fig = plt.figure() - ax_dict = fig.subplot_mosaic(layout, sharex=True, sharey=True) + ax_dict = fig.subplot_mosaic(layout, sharex=sharex, sharey=sharey) ax_dict["A"].set(xscale="log", yscale="logit") assert all(ax.get_xscale() == "log" and ax.get_yscale() == "logit" for ax in ax_dict.values()) @@ -1253,8 +1254,14 @@ def test_sharex_row(self, fig_test, fig_ref): fig_test.subplot_mosaic([["A", "B"], ["C", "D"]], sharex="row", sharey=False) - axes_ref = fig_ref.subplot_mosaic([["A", "B"], ["C", "D"]], - sharex=False, sharey=False) + axes_ref = fig_ref.subplot_mosaic( + [ + ["A", "B"], + ["C", "D"] + ], + sharex=False, + sharey=False + ) axes_ref["A"].sharex(axes_ref["B"]) axes_ref["C"].sharex(axes_ref["D"]) @@ -1263,8 +1270,14 @@ def test_sharey_row(self, fig_test, fig_ref): fig_test.subplot_mosaic([["A", "B"], ["C", "D"]], sharex=False, sharey="row") - axes_ref = fig_ref.subplot_mosaic([["A", "B"], ["C", "D"]], - sharex=False, sharey=False) + axes_ref = fig_ref.subplot_mosaic( + [ + ["A", "B"], + ["C", "D"] + ], + sharex=False, + sharey=False + ) axes_ref["A"].sharey(axes_ref["B"]) axes_ref["C"].sharey(axes_ref["D"]) axes_ref["B"].yaxis.set_tick_params(which="both", labelleft=False) @@ -1274,9 +1287,14 @@ def test_sharey_row(self, fig_test, fig_ref): def test_sharex_col(self, fig_test, fig_ref): fig_test.subplot_mosaic([["A", "B"], ["C", "D"]], sharex="col", sharey=False) - - axes_ref = fig_ref.subplot_mosaic([["A", "B"], ["C", "D"]], - sharex=False, sharey=False) + axes_ref = fig_ref.subplot_mosaic( + [ + ["A", "B"], + ["C", "D"] + ], + sharex=False, + sharey=False + ) axes_ref["A"].sharex(axes_ref["B"]) axes_ref["B"].sharex(axes_ref["D"]) axes_ref["A"].xaxis.set_tick_params(which="both", labelbottom=False) @@ -1287,11 +1305,40 @@ def test_sharey_col(self, fig_test, fig_ref): fig_test.subplot_mosaic([["A", "B"], ["C", "D"]], sharex=False, sharey="col") - axes_ref = fig_ref.subplot_mosaic([["A", "B"], ["C", "D"]], - sharex=False, sharey=False) + axes_ref = fig_ref.subplot_mosaic( + [ + ["A", "B"], + ["C", "D"] + ], + sharex=False, + sharey=False + ) axes_ref["A"].sharey(axes_ref["C"]) axes_ref["B"].sharey(axes_ref["D"]) + @pytest.mark.parametrize( + "sharex,sharey", + [ + ("row", False), + (False, "row"), + ("col", False), + (False, "col"), + ("row", "col") + ] + ) + def test_share_row_col_fails_if_nested_mosaic(self, sharex, sharey): + mosaic = [ + ["A", [["B", "C"], + ["D", "E"]]], + ["F", "G"], + [".", [["H", [["I"], + ["."]]]]] + ] + fig = plt.figure() + with pytest.raises(ValueError): + fig.subplot_mosaic(mosaic, sharex=sharex, sharey=sharey) + + def test_reused_gridspec(): """Test that these all use the same gridspec""" fig = plt.figure()