diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index 11b42b1e1ac7..527ea9bd91d6 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -10,6 +10,9 @@ `SubFigure`) with `Figure.add_subfigure` or `Figure.subfigures` methods (provisional API v3.4). +`SubplotParams` + Control the default spacing between subplots. + Figures are typically created using pyplot methods `~.pyplot.figure`, `~.pyplot.subplots`, and `~.pyplot.subplot_mosaic`. @@ -48,7 +51,7 @@ import matplotlib.image as mimage from matplotlib.axes import Axes -from matplotlib.gridspec import GridSpec, SubplotParams +from matplotlib.gridspec import GridSpec from matplotlib.layout_engine import ( ConstrainedLayoutEngine, TightLayoutEngine, LayoutEngine, PlaceHolderLayoutEngine @@ -115,6 +118,66 @@ def __setstate__(self, state): self._counter = itertools.count(next_counter) +class SubplotParams: + """ + A class to hold the parameters for a subplot. + """ + + def __init__(self, left=None, bottom=None, right=None, top=None, + wspace=None, hspace=None): + """ + Defaults are given by :rc:`figure.subplot.[name]`. + + Parameters + ---------- + left : float + The position of the left edge of the subplots, + as a fraction of the figure width. + right : float + The position of the right edge of the subplots, + as a fraction of the figure width. + bottom : float + The position of the bottom edge of the subplots, + as a fraction of the figure height. + top : float + The position of the top edge of the subplots, + as a fraction of the figure height. + wspace : float + The width of the padding between subplots, + as a fraction of the average Axes width. + hspace : float + The height of the padding between subplots, + as a fraction of the average Axes height. + """ + for key in ["left", "bottom", "right", "top", "wspace", "hspace"]: + setattr(self, key, mpl.rcParams[f"figure.subplot.{key}"]) + self.update(left, bottom, right, top, wspace, hspace) + + def update(self, left=None, bottom=None, right=None, top=None, + wspace=None, hspace=None): + """ + Update the dimensions of the passed parameters. *None* means unchanged. + """ + if ((left if left is not None else self.left) + >= (right if right is not None else self.right)): + raise ValueError('left cannot be >= right') + if ((bottom if bottom is not None else self.bottom) + >= (top if top is not None else self.top)): + raise ValueError('bottom cannot be >= top') + if left is not None: + self.left = left + if right is not None: + self.right = right + if bottom is not None: + self.bottom = bottom + if top is not None: + self.top = top + if wspace is not None: + self.wspace = wspace + if hspace is not None: + self.hspace = hspace + + class FigureBase(Artist): """ Base class for `.Figure` and `.SubFigure` containing the methods that add @@ -1291,6 +1354,65 @@ def subplots_adjust(self, left=None, bottom=None, right=None, top=None, ax._set_position(ax.get_subplotspec().get_position(self)) self.stale = True + def set_subplotpars(self, subplotparams={}): + """ + Set the subplot layout parameters. + Accepts either a `.SubplotParams` object, from which the relevant + parameters are copied, or a dictionary of subplot layout parameters. + If a dictionary is provided, this function is a convenience wrapper for + `matplotlib.figure.Figure.subplots_adjust` + + Parameters + ---------- + subplotparams : `~matplotlib.figure.SubplotParams` or dict with keys + "left", "bottom", "right", 'top", "wspace", "hspace"] , optional + SubplotParams object to copy new subplot parameters from, or a dict + of SubplotParams constructor arguments. + By default, an empty dictionary is passed, which maintains the + current state of the figure's `.SubplotParams` + + See Also + -------- + matplotlib.figure.Figure.subplots_adjust + matplotlib.figure.Figure.get_subplotpars + """ + subplotparams_args = ["left", "bottom", "right", + "top", "wspace", "hspace"] + kwargs = {} + if isinstance(subplotparams, SubplotParams): + for key in subplotparams_args: + kwargs[key] = getattr(subplotparams, key) + elif isinstance(subplotparams, dict): + for key in subplotparams.keys(): + if key in subplotparams_args: + kwargs[key] = subplotparams[key] + else: + _api.warn_external( + f"'{key}' is not a valid key for set_subplotpars;" + " this key was ignored.") + else: + raise TypeError( + "subplotpars must be a dictionary of keyword-argument pairs or" + " an instance of SubplotParams()") + if kwargs == {}: + self.set_subplotpars(self.get_subplotpars()) + self.subplots_adjust(**kwargs) + + def get_subplotpars(self): + """ + Return the `.SubplotParams` object associated with the Figure. + + Returns + ------- + `.SubplotParams` + + See Also + -------- + matplotlib.figure.Figure.subplots_adjust + matplotlib.figure.Figure.get_subplotpars + """ + return self.subplotpars + def align_xlabels(self, axs=None): """ Align the xlabels of subplots in the same subplot column if label @@ -2306,6 +2428,10 @@ def draw(self, renderer): @_docstring.interpd +@_api.define_aliases({ + "size_inches": ["figsize"], + "layout_engine": ["layout"] +}) class Figure(FigureBase): """ The top level container for all the plot elements. @@ -2381,7 +2507,7 @@ def __init__(self, frameon : bool, default: :rc:`figure.frameon` If ``False``, suppress drawing the figure background patch. - subplotpars : `~matplotlib.gridspec.SubplotParams` + subplotpars : `SubplotParams` Subplot parameters. If not given, the default subplot parameters :rc:`figure.subplot.*` are used. diff --git a/lib/matplotlib/figure.pyi b/lib/matplotlib/figure.pyi index 687ae9e500d0..733223bde2e6 100644 --- a/lib/matplotlib/figure.pyi +++ b/lib/matplotlib/figure.pyi @@ -16,7 +16,7 @@ from matplotlib.backend_bases import ( from matplotlib.colors import Colormap, Normalize from matplotlib.colorbar import Colorbar from matplotlib.cm import ScalarMappable -from matplotlib.gridspec import GridSpec, SubplotSpec, SubplotParams as SubplotParams +from matplotlib.gridspec import GridSpec, SubplotSpec from matplotlib.image import _ImageBase, FigureImage from matplotlib.layout_engine import LayoutEngine from matplotlib.legend import Legend @@ -28,6 +28,32 @@ from .typing import ColorType, HashableList _T = TypeVar("_T") +class SubplotParams: + def __init__( + self, + left: float | None = ..., + bottom: float | None = ..., + right: float | None = ..., + top: float | None = ..., + wspace: float | None = ..., + hspace: float | None = ..., + ) -> None: ... + left: float + right: float + bottom: float + top: float + wspace: float + hspace: float + def update( + self, + left: float | None = ..., + bottom: float | None = ..., + right: float | None = ..., + top: float | None = ..., + wspace: float | None = ..., + hspace: float | None = ..., + ) -> None: ... + class FigureBase(Artist): artists: list[Artist] lines: list[Line2D] @@ -244,6 +270,13 @@ class FigureBase(Artist): gridspec_kw: dict[str, Any] | None = ..., ) -> dict[Hashable, Axes]: ... + def set_subplotpars( + self, + subplotparams: SubplotParams | dict[str, Any] = ..., + ) -> None: ... + + def get_subplotpars(self) -> SubplotParams: ... + class SubFigure(FigureBase): figure: Figure subplotpars: SubplotParams diff --git a/lib/matplotlib/tests/test_figure.py b/lib/matplotlib/tests/test_figure.py index 99b2602bc4a7..1c6c8b3165e8 100644 --- a/lib/matplotlib/tests/test_figure.py +++ b/lib/matplotlib/tests/test_figure.py @@ -1652,6 +1652,45 @@ def test_get_constrained_layout_pads(): assert fig.get_constrained_layout_pads() == expected +def test_get_subplot_params(): + fig = plt.figure() + subplotparams_keys = ["left", "bottom", "right", "top", "wspace", "hspace"] + subplotparams = fig.get_subplotpars() + test_dict = {} + for key in subplotparams_keys: + attr = getattr(subplotparams, key) + assert attr == mpl.rcParams[f"figure.subplot.{key}"] + test_dict[key] = attr * 2 + + fig.set_subplotpars(test_dict) + for key, value in test_dict.items(): + assert getattr(fig.get_subplotpars(), key) == value + + test_dict['foo'] = 'bar' + with pytest.warns(UserWarning, + match="'foo' is not a valid key for set_subplotpars;" + " this key was ignored"): + fig.set_subplotpars(test_dict) + + with pytest.raises(TypeError, + match="subplotpars must be a dictionary of " + "keyword-argument pairs or " + "an instance of SubplotParams()"): + fig.set_subplotpars(['foo']) + + assert fig.subplotpars == fig.get_subplotpars() + + +def test_fig_get_set(): + varnames = filter(lambda var: var not in ['self', 'kwargs', 'args', 'connect'], + Figure.__init__.__code__.co_varnames) + fig = plt.figure() + for var in varnames: + # if getattr fails then the getter and setter does not exist + getfunc = getattr(fig, f"get_{var}") + setfunc = getattr(fig, f"set_{var}") + + def test_not_visible_figure(): fig = Figure()