diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index e5cf88131178..b9454451312f 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -2893,6 +2893,111 @@ def set_tight_layout(self, tight): self.set_layout_engine(_tight, **_tight_parameters) self.stale = True + def set_subplotparams(self, subplotparams={}): + """ + Set the subplot layout parameters. + Accepts either a `.SubplotParams` object, from which the relevant + parameters are copied, or a dictionary of subplot layout parameters. + If a dictionary is provided, this function is a convenience wrapper for + `matplotlib.figure.Figure.subplots_adjust` + + Parameters + ---------- + subplotparams : `~matplotlib.figure.SubplotParams` or dict with keys + "left", "bottom", "right", 'top", "wspace", "hspace"] , optional + SubplotParams object to copy new subplot parameters from, or a dict + of SubplotParams constructor arguments. + By default, an empty dictionary is passed, which maintains the + current state of the figure's `.SubplotParams` + + See Also + -------- + matplotlib.figure.Figure.subplots_adjust + matplotlib.figure.Figure.get_subplotparams + """ + + subplotparams_args = ["left", "bottom", "right", + "top", "wspace", "hspace"] + kwargs = {} + if isinstance(subplotparams, SubplotParams): + for key in subplotparams_args: + kwargs[key] = getattr(subplotparams, key) + elif isinstance(subplotparams, dict): + for key in subplotparams.keys(): + if key in subplotparams_args: + kwargs[key] = subplotparams[key] + else: + _api.warn_external( + f"'{key}' is not a valid key for set_subplotparams;" + " this key was ignored.") + else: + raise TypeError( + "subplotparams must be a dictionary of keyword-argument pairs or" + " an instance of SubplotParams()") + if kwargs == {}: + self.set_subplotparams(self.get_subplotparams()) + self.subplots_adjust(**kwargs) + + def get_subplotparams(self): + """ + Return the `.SubplotParams` object associated with the Figure. + + Returns + ------- + .SubplotParams` + + See Also + -------- + matplotlib.figure.Figure.subplots_adjust + matplotlib.figure.Figure.get_subplotparams + """ + return self.subplotpars + + def get_figsize(self): + """ + Return the current size of the figure. + + Returns + ------- + ndarray + The size (width, height) of the figure in inches. + + See Also + -------- + matplotlib.figure.Figure.get_size_inches + matplotlib.figure.Figure.set_size_inches + matplotlib.figure.Figure.get_figwidth + matplotlib.figure.Figure.get_figheight + + Notes + ----- + The size in pixels can be obtained by multiplying with `Figure.dpi`. + """ + return self.get_size_inches() + + def set_figsize(self, w, h=None, forward=True): + """ + Set the figure size. + + Parameters + ---------- + w : (float, float) or float + Width and height in inches (if height not specified as a separate + argument) or width. + h : float + Height in inches. + forward : bool, default: True + If ``True``, the canvas size is automatically updated, e.g., + you can resize the figure window from the shell. + + See Also + -------- + matplotlib.figure.Figure.get_figsize + matplotlib.figure.Figure.get_size_inches + matplotlib.figure.Figure.set_size_inches + """ + self.set_size_inches(w, h, forward=True) + def get_constrained_layout(self): """ Return whether constrained layout is being used. diff --git a/lib/matplotlib/figure.pyi b/lib/matplotlib/figure.pyi index 08bf1505532b..be961493a62b 100644 --- a/lib/matplotlib/figure.pyi +++ b/lib/matplotlib/figure.pyi @@ -362,6 +362,8 @@ class Figure(FigureBase): def get_constrained_layout(self) -> bool: ... canvas: FigureCanvasBase def set_canvas(self, canvas: FigureCanvasBase) -> None: ... + def set_subplotparams(self, subplotparams: SubplotParams | dict = {}) -> None: ... + def get_subplotparams(self) -> Any: ... def figimage( self, X: ArrayLike, @@ -382,6 +384,10 @@ class Figure(FigureBase): self, w: float | tuple[float, float], h: float | None = ..., forward: bool = ... ) -> None: ... def get_size_inches(self) -> np.ndarray: ... + def set_figsize( + self, w: float | tuple[float, float], h: float | None = ..., forward: bool = ... + ) -> None: ... + def get_figsize(self) -> np.ndarray: ... def get_figwidth(self) -> float: ... def get_figheight(self) -> float: ... def get_dpi(self) -> float: ... diff --git a/lib/matplotlib/tests/test_figure.py b/lib/matplotlib/tests/test_figure.py index edf5ea05f119..869187c913e1 100644 --- a/lib/matplotlib/tests/test_figure.py +++ b/lib/matplotlib/tests/test_figure.py @@ -265,6 +265,66 @@ def test_gca(): assert fig.gca() is ax1 +def test_get_subplotparams(): + fig = plt.figure() + subplotparams_keys = ["left", "bottom", "right", "top", "wspace", "hspace"] + subplotparams = fig.get_subplotparams() + for key in subplotparams_keys: + attr = getattr(subplotparams, key) + assert attr == mpl.rcParams[f"figure.subplot.{key}"] + + +def test_set_subplotparams(): + fig = plt.figure() + subplotparams_keys = ["left", "bottom", "right", "top", "wspace", "hspace"] + subplotparams = fig.get_subplotparams() + test_dict = {} + for key in subplotparams_keys: + attr = getattr(subplotparams, key) + assert attr == mpl.rcParams[f"figure.subplot.{key}"] + test_dict[key] = attr * 2 + + fig.set_subplotparams() + for key in subplotparams_keys: + attr = getattr(subplotparams, key) + assert attr == mpl.rcParams[f"figure.subplot.{key}"] + + fig.set_subplotparams(fig.get_subplotparams()) + for key in subplotparams_keys: + attr = getattr(subplotparams, key) + assert attr == mpl.rcParams[f"figure.subplot.{key}"] + + fig.set_subplotparams(test_dict) + for key, value in test_dict.items(): + assert getattr(fig.get_subplotparams(), key) == value + + test_dict['foo'] = 'bar' + with pytest.warns(UserWarning, + match="'foo' is not a valid key for set_subplotparams;" + " this key was ignored"): + fig.set_subplotparams(test_dict) + + with pytest.raises(TypeError, + match="subplotparams must be a dictionary of " + "keyword-argument pairs or " + "an instance of SubplotParams()"): + fig.set_subplotparams(['foo']) + + +def test_set_figsize(): + fig = plt.figure() + fig.set_figsize(2, 4) + assert fig.get_figwidth() == 2 + assert fig.get_figheight() == 4 + + +def test_get_figsize(): + fig = plt.figure() + # check using tuple to first argument + fig.set_figsize((1, 3)) + assert np.array_equal(fig.get_figsize(), [1, 3]) + + def test_add_subplot_subclass(): fig = plt.figure() fig.add_subplot(axes_class=Axes)